sing-box/protocol/ndis/endpoint.go
2025-01-03 18:38:04 +08:00

111 lines
2.6 KiB
Go

//go:build windows
package ndis
import (
"sync"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/wiresock/ndisapi-go"
"github.com/wiresock/ndisapi-go/driver"
)
// ndisEndpoint is a gVisor link endpoint whose outbound Ethernet frames are
// handed to an NDIS queued packet filter.
type ndisEndpoint struct {
	// filter receives every outbound frame via InsertPacketToMstcp.
	filter *driver.QueuedPacketFilter
	// mtu is the value reported by MTU.
	mtu uint32
	// address is the value reported by LinkAddress.
	address tcpip.LinkAddress
	// dispatcher is set by Attach; nil means the endpoint is detached.
	dispatcher stack.NetworkDispatcher
}

// Compile-time check that *ndisEndpoint satisfies stack.LinkEndpoint.
var _ stack.LinkEndpoint = (*ndisEndpoint)(nil)
// MTU implements stack.LinkEndpoint.MTU. It reports the value stored in
// e.mtu, which SetMTU never changes.
func (e *ndisEndpoint) MTU() uint32 { return e.mtu }
// SetMTU implements stack.LinkEndpoint.SetMTU. The requested value is
// ignored; MTU keeps reporting e.mtu.
func (e *ndisEndpoint) SetMTU(mtu uint32) {}
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It asks the
// stack to reserve room for the Ethernet header that AddHeader pushes.
func (e *ndisEndpoint) MaxHeaderLength() uint16 { return header.EthernetMinimumSize }
// LinkAddress implements stack.LinkEndpoint.LinkAddress. It reports the value
// stored in e.address, which SetLinkAddress never changes.
func (e *ndisEndpoint) LinkAddress() tcpip.LinkAddress { return e.address }
// SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress. The requested
// address is ignored; LinkAddress keeps reporting e.address.
func (e *ndisEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {}
// Capabilities implements stack.LinkEndpoint.Capabilities. No capability bits
// are advertised.
//
// NOTE(review): without stack.CapabilityResolutionRequired the stack
// presumably does not run ARP/NDP for this endpoint, so AddHeader relies on
// whatever RemoteLinkAddress the route already carries — confirm intended.
func (e *ndisEndpoint) Capabilities() stack.LinkEndpointCapabilities {
	return stack.LinkEndpointCapabilities(0)
}
// Attach implements stack.LinkEndpoint.Attach. It records dispatcher on the
// endpoint; passing nil detaches it, after which IsAttached reports false.
func (e *ndisEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
// IsAttached implements stack.LinkEndpoint.IsAttached. The endpoint counts as
// attached whenever Attach last stored a non-nil dispatcher.
func (e *ndisEndpoint) IsAttached() bool {
	attached := e.dispatcher != nil
	return attached
}
// Wait implements stack.LinkEndpoint.Wait. This endpoint starts no work of its
// own in this file, so there is nothing to wait for and it returns at once.
func (e *ndisEndpoint) Wait() {}
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. Frames on this
// endpoint carry Ethernet headers (see AddHeader), so it reports Ethernet.
func (e *ndisEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
// AddHeader implements stack.LinkEndpoint.AddHeader. It prepends an Ethernet
// header whose source and destination MACs come from the packet's egress
// route and whose EtherType is the packet's network protocol number.
func (e *ndisEndpoint) AddHeader(pkt *stack.PacketBuffer) {
	ethFields := header.EthernetFields{
		Type:    pkt.NetworkProtocolNumber,
		SrcAddr: pkt.EgressRoute.LocalLinkAddress,
		DstAddr: pkt.EgressRoute.RemoteLinkAddress,
	}
	linkHdr := pkt.LinkHeader().Push(header.EthernetMinimumSize)
	header.Ethernet(linkHdr).Encode(&ethFields)
}
// ParseHeader implements stack.LinkEndpoint.ParseHeader. It marks the leading
// Ethernet header as the link header and reports false when the packet is too
// short to contain one.
func (e *ndisEndpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
	if _, consumed := pkt.LinkHeader().Consume(header.EthernetMinimumSize); !consumed {
		return false
	}
	return true
}
// Close implements stack.LinkEndpoint.Close. It releases nothing and runs no
// callback.
//
// NOTE(review): SetOnCloseAction discards its callback, so no close action is
// ever invoked here — confirm no caller depends on it.
func (e *ndisEndpoint) Close() {}
// SetOnCloseAction implements stack.LinkEndpoint.SetOnCloseAction. The
// callback f is discarded and never invoked.
func (e *ndisEndpoint) SetOnCloseAction(f func()) {}
// bufferPool recycles the ndisapi.IntermediateBuffer values that WritePackets
// copies outbound frames into, so each packet does not allocate a new one.
var bufferPool = sync.Pool{
	New: func() any { return &ndisapi.IntermediateBuffer{} },
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
//
// Each packet in list, including its link-layer header, is copied into a
// pooled ndisapi.IntermediateBuffer and passed to e.filter.InsertPacketToMstcp,
// in order. Writing stops at the first failure. The function then returns the
// number of packets already handed to the filter, together with the error:
//   - tcpip.ErrMessageTooLong if a packet does not fit in the intermediate
//     buffer (it is rejected rather than truncated);
//   - tcpip.ErrAborted if InsertPacketToMstcp returns an error.
func (e *ndisEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
	for written, packetBuffer := range list.AsSlice() {
		if err := e.writePacketToMstcp(packetBuffer); err != nil {
			// Report the packets that did go out, not 0, so the caller's
			// count of written packets stays accurate.
			return written, err
		}
	}
	return list.Len(), nil
}

// writePacketToMstcp copies one packet into a buffer taken from bufferPool and
// injects it with e.filter.InsertPacketToMstcp. The buffer goes back to the
// pool before this function returns.
//
// NOTE(review): pooled buffers are not cleared. Only Buffer[:Length] and
// Length are rewritten here — confirm InsertPacketToMstcp does not read other
// IntermediateBuffer fields that an earlier use may have left set.
func (e *ndisEndpoint) writePacketToMstcp(packetBuffer *stack.PacketBuffer) tcpip.Error {
	ndisBuf := bufferPool.Get().(*ndisapi.IntermediateBuffer)
	defer bufferPool.Put(ndisBuf)

	// Reject oversized frames up front; copy would otherwise truncate them
	// silently and a malformed frame would be injected.
	if packetBuffer.Size() > len(ndisBuf.Buffer) {
		return &tcpip.ErrMessageTooLong{}
	}

	// AsViewList returns the backing views plus the offset where packet data
	// starts; bytes before that offset are skipped. Walking the list with a
	// nil check also handles a fully-consumed (empty) packet without a nil
	// dereference.
	viewList, offset := packetBuffer.AsViewList()
	length := 0
	var view *buffer.View
	for view = viewList.Front(); view != nil; view = view.Next() {
		data := view.AsSlice()
		if offset >= len(data) {
			offset -= len(data)
			continue
		}
		length += copy(ndisBuf.Buffer[length:], data[offset:])
		offset = 0
	}
	ndisBuf.Length = uint32(length)

	if err := e.filter.InsertPacketToMstcp(ndisBuf); err != nil {
		return &tcpip.ErrAborted{}
	}
	return nil
}