From 3145f8c54cc0c5ea6a20069fc90b88e7ae13057c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 24 Aug 2025 11:22:49 +0800 Subject: [PATCH] Implement ping support for WireGuard and Tailscale --- go.mod | 4 +- go.sum | 8 +- protocol/tailscale/endpoint.go | 41 +++- transport/wireguard/device.go | 2 +- transport/wireguard/device_nat.go | 48 ++-- transport/wireguard/device_nat_gvisor.go | 48 ---- transport/wireguard/device_nat_non_gvisor.go | 20 -- transport/wireguard/device_stack.go | 217 ++++--------------- transport/wireguard/device_system.go | 30 +-- transport/wireguard/device_system_stack.go | 65 +++++- transport/wireguard/endpoint.go | 2 +- 11 files changed, 187 insertions(+), 298 deletions(-) delete mode 100644 transport/wireguard/device_nat_gvisor.go delete mode 100644 transport/wireguard/device_nat_non_gvisor.go diff --git a/go.mod b/go.mod index 7e0ddb7a..17beddc7 100644 --- a/go.mod +++ b/go.mod @@ -33,10 +33,10 @@ require ( github.com/sagernet/sing-shadowsocks v0.2.8 github.com/sagernet/sing-shadowsocks2 v0.2.1 github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 - github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250823082506-db38fc041d02 + github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250824024715-dd18aa2b8633 github.com/sagernet/sing-vmess v0.2.7 github.com/sagernet/smux v1.5.34-mod.2 - github.com/sagernet/tailscale v1.80.3-mod.5 + github.com/sagernet/tailscale v1.80.3-mod.6 github.com/sagernet/wireguard-go v0.0.1-beta.7 github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 github.com/spf13/cobra v1.9.1 diff --git a/go.sum b/go.sum index 719db075..77a502bf 100644 --- a/go.sum +++ b/go.sum @@ -179,14 +179,14 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA= -github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250823082506-db38fc041d02 h1:vnbpigObkH2tT7OdAuxR0jbk6lOnF5BlDRb4A9hF0xM= -github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250823082506-db38fc041d02/go.mod h1:z1lkiAE5ex5gHBzh5+G9TFsyM9grOaSsRx33mVfWfVI= +github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250824024715-dd18aa2b8633 h1:cqm3Gd253bpnQV5qQvvrFEcO0dzUrfsiOQRTtSFM8cs= +github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250824024715-dd18aa2b8633/go.mod h1:z1lkiAE5ex5gHBzh5+G9TFsyM9grOaSsRx33mVfWfVI= github.com/sagernet/sing-vmess v0.2.7 h1:2ee+9kO0xW5P4mfe6TYVWf9VtY8k1JhNysBqsiYj0sk= github.com/sagernet/sing-vmess v0.2.7/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs= github.com/sagernet/smux v1.5.34-mod.2 h1:gkmBjIjlJ2zQKpLigOkFur5kBKdV6bNRoFu2WkltRQ4= github.com/sagernet/smux v1.5.34-mod.2/go.mod h1:0KW0+R+ycvA2INW4gbsd7BNyg+HEfLIAxa5N02/28Zc= -github.com/sagernet/tailscale v1.80.3-mod.5 h1:7V7z+p2C//TGtff20pPnDCt3qP6uFyY62peJoKF9z/A= -github.com/sagernet/tailscale v1.80.3-mod.5/go.mod h1:EBxXsWu4OH2ELbQLq32WoBeIubG8KgDrg4/Oaxjs6lI= +github.com/sagernet/tailscale v1.80.3-mod.6 h1:oJs0jpRNS/12+mPf3r9maxWl9dWy1RanugLNmsF74Gs= +github.com/sagernet/tailscale v1.80.3-mod.6/go.mod h1:EBxXsWu4OH2ELbQLq32WoBeIubG8KgDrg4/Oaxjs6lI= github.com/sagernet/wireguard-go v0.0.1-beta.7 h1:ltgBwYHfr+9Wz1eG59NiWnHrYEkDKHG7otNZvu85DXI= github.com/sagernet/wireguard-go v0.0.1-beta.7/go.mod h1:jGXij2Gn2wbrWuYNUmmNhf1dwcZtvyAvQoe8Xd8MbUo= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= diff --git a/protocol/tailscale/endpoint.go b/protocol/tailscale/endpoint.go index 55f690c5..743a793c 100644 --- a/protocol/tailscale/endpoint.go +++ b/protocol/tailscale/endpoint.go @@ -31,6 +31,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/bufio" @@ -57,7 +58,10 @@ import ( "go4.org/netipx" ) -var _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) +var ( + _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) + _ adapter.DirectRouteOutbound = (*Endpoint)(nil) +) func init() { version.SetVersion("sing-box " + C.Version) @@ -77,6 +81,7 @@ type Endpoint struct { platformInterface platform.Interface server *tsnet.Server stack *stack.Stack + icmpForwarder *tun.ICMPForwarder filter *atomic.Pointer[filter.Filter] onReconfigHook wgengine.ReconfigListener @@ -176,7 +181,7 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL }, } return &Endpoint{ - Adapter: endpoint.NewAdapter(C.TypeTailscale, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil), + Adapter: endpoint.NewAdapter(C.TypeTailscale, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMPv4, N.NetworkICMPv6}, nil), ctx: ctx, router: router, logger: logger, @@ -242,10 +247,11 @@ func (t *Endpoint) Start(stage adapter.StartStage) error { } ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(t.ctx, ipStack, t).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(t.ctx, ipStack, t, t.udpTimeout).HandlePacket) - icmpForwarder := tun.NewICMPForwarder(t.ctx, ipStack, netip.Addr{}, netip.Addr{}, t, t.udpTimeout) + icmpForwarder := tun.NewICMPForwarder(t.ctx, ipStack, t, t.udpTimeout) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) t.stack = ipStack + t.icmpForwarder = icmpForwarder localBackend := t.server.ExportLocalBackend() perfs := &ipn.MaskedPrefs{ @@ -485,6 +491,26 @@ func (t *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, t.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) } +func (t *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { + inet4Address, inet6Address := t.server.TailscaleIPs() + if metadata.Destination.Addr.Is4() && !inet4Address.IsValid() || metadata.Destination.Addr.Is6() && !inet6Address.IsValid() { + return nil, E.New("Tailscale is not ready yet") + } + ctx := log.ContextWithNewID(t.ctx) + destination, err := ping.ConnectGVisor( + ctx, t.logger, + metadata.Source.Addr, metadata.Destination.Addr, + routeContext, + t.stack, + inet4Address, inet6Address, + ) + if err != nil { + return nil, err + } + t.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString()) + return destination, nil +} + func (t *Endpoint) PreferredDomain(domain string) bool { routeDomains := t.routeDomains.Load() if routeDomains == nil { @@ -512,6 +538,15 @@ func (t *Endpoint) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCf if (t.cfg != nil && reflect.DeepEqual(t.cfg, cfg)) && (t.dnsCfg != nil && reflect.DeepEqual(t.dnsCfg, dnsCfg)) { return } + var inet4Address, inet6Address netip.Addr + for _, address := range cfg.Addresses { + if address.Addr().Is4() { + inet4Address = address.Addr() + } else if address.Addr().Is6() { + inet6Address = address.Addr() + } + } + t.icmpForwarder.SetLocalAddresses(inet4Address, inet6Address) t.cfg = cfg t.dnsCfg = dnsCfg diff --git a/transport/wireguard/device.go b/transport/wireguard/device.go index 4e2d6a24..0ef0407b 100644 --- a/transport/wireguard/device.go +++ b/transport/wireguard/device.go @@ -38,7 +38,7 @@ type DeviceOptions struct { func NewDevice(options DeviceOptions) (Device, error) { if !options.System { return newStackDevice(options) - } else if options.Handler == nil { + } else if !tun.WithGVisor { return newSystemDevice(options) } else { return newSystemStackDevice(options) diff --git a/transport/wireguard/device_nat.go b/transport/wireguard/device_nat.go index 2c482d30..cff5c29d 100644 --- a/transport/wireguard/device_nat.go +++ b/transport/wireguard/device_nat.go @@ -3,6 +3,7 @@ package wireguard import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" "github.com/sagernet/sing/common/buf" ) @@ -10,29 +11,34 @@ var _ Device = (*natDeviceWrapper)(nil) type natDeviceWrapper struct { Device - gVisorOutbound packetOutbound chan *buf.Buffer - mapping *tun.NatMapping - writer *tun.NatWriter + rewriter *ping.Rewriter buffer [][]byte } -func NewNATDevice(upstream Device, ipRewrite bool) NatDevice { +func NewNATDevice(upstream Device) NatDevice { wrapper := &natDeviceWrapper{ Device: upstream, - gVisorOutbound: newGVisorOutbound(), packetOutbound: make(chan *buf.Buffer, 256), - mapping: tun.NewNatMapping(ipRewrite), - } - if ipRewrite { - wrapper.writer = tun.NewNatWriter(upstream.Inet4Address(), upstream.Inet6Address()) + rewriter: ping.NewRewriter(upstream.Inet4Address(), upstream.Inet6Address()), } return wrapper } +func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + select { + case packet := <-d.packetOutbound: + defer packet.Release() + sizes[0] = copy(bufs[0][offset:], packet.Bytes()) + return 1, nil + default: + } + return d.Device.Read(bufs, sizes, offset) +} + func (d *natDeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { for _, buffer := range bufs { - handled, err := d.mapping.WritePacket(buffer[offset:]) + handled, err := d.rewriter.WriteBack(buffer[offset:]) if handled { if err != nil { return 0, err @@ -56,30 +62,24 @@ func (d *natDeviceWrapper) CreateDestination(metadata adapter.InboundContext, ro Source: metadata.Source.Addr, Destination: metadata.Destination.Addr, } - d.mapping.CreateSession(session, routeContext) - return &natDestinationWrapper{d, session}, nil + d.rewriter.CreateSession(session, routeContext) + return &natDestination{d, session}, nil } -var _ tun.DirectRouteDestination = (*natDestinationWrapper)(nil) +var _ tun.DirectRouteDestination = (*natDestination)(nil) -type natDestinationWrapper struct { +type natDestination struct { device *natDeviceWrapper session tun.DirectRouteSession } -func (d *natDestinationWrapper) WritePacket(buffer *buf.Buffer) error { - if d.device.writer != nil { - d.device.writer.RewritePacket(buffer.Bytes()) - } +func (d *natDestination) WritePacket(buffer *buf.Buffer) error { + d.device.rewriter.RewritePacket(buffer.Bytes()) d.device.packetOutbound <- buffer return nil } -func (d *natDestinationWrapper) Close() error { - d.device.mapping.DeleteSession(d.session) +func (d *natDestination) Close() error { + d.device.rewriter.DeleteSession(d.session) return nil } - -func (d *natDestinationWrapper) Timeout() bool { - return false -} diff --git a/transport/wireguard/device_nat_gvisor.go b/transport/wireguard/device_nat_gvisor.go deleted file mode 100644 index edecba34..00000000 --- a/transport/wireguard/device_nat_gvisor.go +++ /dev/null @@ -1,48 +0,0 @@ -//go:build with_gvisor - -package wireguard - -import ( - "github.com/sagernet/gvisor/pkg/tcpip/stack" -) - -type gVisorOutbound struct { - outbound chan *stack.PacketBuffer -} - -func newGVisorOutbound() gVisorOutbound { - return gVisorOutbound{ - outbound: make(chan *stack.PacketBuffer, 256), - } -} - -func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { - select { - case packet := <-d.outbound: - defer packet.DecRef() - var copyN int - /*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) { - copyN += copy(bufs[0][offset+copyN:], view.AsSlice()) - })*/ - for _, view := range packet.AsSlices() { - copyN += copy(bufs[0][offset+copyN:], view) - } - sizes[0] = copyN - return 1, nil - case packet := <-d.packetOutbound: - defer packet.Release() - sizes[0] = copy(bufs[0][offset:], packet.Bytes()) - return 1, nil - default: - } - return d.Device.Read(bufs, sizes, offset) -} - -func (d *natDestinationWrapper) WritePacketBuffer(packetBuffer *stack.PacketBuffer) error { - println("read from wg") - if d.device.writer != nil { - d.device.writer.RewritePacketBuffer(packetBuffer) - } - d.device.outbound <- packetBuffer - return nil -} diff --git a/transport/wireguard/device_nat_non_gvisor.go b/transport/wireguard/device_nat_non_gvisor.go deleted file mode 100644 index e81e1e31..00000000 --- a/transport/wireguard/device_nat_non_gvisor.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build !with_gvisor - -package wireguard - -type gVisorOutbound struct{} - -func newGVisorOutbound() gVisorOutbound { - return gVisorOutbound{} -} - -func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { - select { - case packet := <-d.packetOutbound: - defer packet.Release() - sizes[0] = copy(bufs[0][offset:], packet.Bytes()) - return 1, nil - default: - } - return d.Device.Read(bufs, sizes, offset) -} diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index eb759a2b..94fa39f8 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -19,7 +19,9 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -31,6 +33,8 @@ import ( var _ NatDevice = (*stackDevice)(nil) type stackDevice struct { + ctx context.Context + logger log.ContextLogger stack *stack.Stack mtu uint32 events chan wgTun.Event @@ -38,25 +42,28 @@ type stackDevice struct { packetOutbound chan *buf.Buffer done chan struct{} dispatcher stack.NetworkDispatcher - addr4 tcpip.Address - addr6 tcpip.Address - mapping *tun.NatMapping - writer *tun.NatWriter + inet4Address netip.Addr + inet6Address netip.Addr } func newStackDevice(options DeviceOptions) (*stackDevice, error) { tunDevice := &stackDevice{ + ctx: options.Context, + logger: options.Logger, mtu: options.MTU, events: make(chan wgTun.Event, 1), outbound: make(chan *stack.PacketBuffer, 256), packetOutbound: make(chan *buf.Buffer, 256), done: make(chan struct{}), - mapping: tun.NewNatMapping(true), } - ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice)) + ipStack, err := tun.NewGVisorStackWithOptions((*wireEndpoint)(tunDevice), stack.NICOptions{}, true) if err != nil { return nil, err } + var ( + inet4Address netip.Addr + inet6Address netip.Addr + ) for _, prefix := range options.Address { addr := tun.AddressFromAddr(prefix.Addr()) protoAddr := tcpip.ProtocolAddress{ @@ -66,10 +73,12 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) { }, } if prefix.Addr().Is4() { - tunDevice.addr4 = addr + inet4Address = prefix.Addr() + tunDevice.inet4Address = inet4Address protoAddr.Protocol = ipv4.ProtocolNumber } else { - tunDevice.addr6 = addr + inet6Address = prefix.Addr() + tunDevice.inet6Address = inet6Address protoAddr.Protocol = ipv6.ProtocolNumber } gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{}) @@ -77,12 +86,12 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) { return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String()) } } - tunDevice.writer = tun.NewNatWriter(tunDevice.Inet4Address(), tunDevice.Inet6Address()) tunDevice.stack = ipStack if options.Handler != nil { ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout) + icmpForwarder.SetLocalAddresses(inet4Address, inet6Address) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) } @@ -101,10 +110,10 @@ func (w *stackDevice) DialContext(ctx context.Context, network string, destinati var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() { networkProtocol = header.IPv4ProtocolNumber - bind.Addr = w.addr4 + bind.Addr = tun.AddressFromAddr(w.inet4Address) } else { networkProtocol = header.IPv6ProtocolNumber - bind.Addr = w.addr6 + bind.Addr = tun.AddressFromAddr(w.inet4Address) } switch N.NetworkName(network) { case N.NetworkTCP: @@ -131,10 +140,10 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() { networkProtocol = header.IPv4ProtocolNumber - bind.Addr = w.addr4 + bind.Addr = tun.AddressFromAddr(w.inet4Address) } else { networkProtocol = header.IPv6ProtocolNumber - bind.Addr = w.addr6 + bind.Addr = tun.AddressFromAddr(w.inet4Address) } udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol) if err != nil { @@ -144,11 +153,11 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) } func (w *stackDevice) Inet4Address() netip.Addr { - return netip.AddrFrom4(w.addr4.As4()) + return w.inet4Address } func (w *stackDevice) Inet6Address() netip.Addr { - return netip.AddrFrom16(w.addr6.As16()) + return w.inet6Address } func (w *stackDevice) SetDevice(device *device.Device) { @@ -194,14 +203,6 @@ func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) { if len(b) == 0 { continue } - handled, err := w.mapping.WritePacket(b) - if handled { - if err != nil { - return count, err - } - count++ - continue - } var networkProtocol tcpip.NetworkProtocolNumber switch header.IPVersion(b) { case header.IPv4Version: @@ -250,6 +251,22 @@ func (w *stackDevice) BatchSize() int { return 1 } +func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { + ctx := log.ContextWithNewID(w.ctx) + destination, err := ping.ConnectGVisor( + ctx, w.logger, + metadata.Source.Addr, metadata.Destination.Addr, + routeContext, + w.stack, + w.inet4Address, w.inet6Address, + ) + if err != nil { + return nil, err + } + w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString()) + return destination, nil +} + var _ stack.LinkEndpoint = (*wireEndpoint)(nil) type wireEndpoint stackDevice @@ -315,157 +332,3 @@ func (ep *wireEndpoint) Close() { func (ep *wireEndpoint) SetOnCloseAction(f func()) { } - -func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { - /* var wq waiter.Queue - ep, err := raw.NewEndpoint(w.stack, ipv4.ProtocolNumber, icmp.ProtocolNumber4, &wq) - if err != nil { - return nil, E.Cause(gonet.TranslateNetstackError(err), "create endpoint") - } - err = ep.Connect(tcpip.FullAddress{ - NIC: tun.DefaultNIC, - Port: metadata.Destination.Port, - Addr: tun.AddressFromAddr(metadata.Destination.Addr), - }) - if err != nil { - ep.Close() - return nil, E.Cause(gonet.TranslateNetstackError(err), "ICMP connect ", metadata.Destination) - } - fmt.Println("linked ", metadata.Network, " connection to ", metadata.Destination.AddrString()) - destination := &endpointNatDestination{ - ep: ep, - wq: &wq, - context: routeContext, - } - go destination.loopRead() - return destination, nil*/ - session := tun.DirectRouteSession{ - Source: metadata.Source.Addr, - Destination: metadata.Destination.Addr, - } - w.mapping.CreateSession(session, routeContext) - return &stackNatDestination{ - device: w, - session: session, - }, nil -} - -type stackNatDestination struct { - device *stackDevice - session tun.DirectRouteSession -} - -func (d *stackNatDestination) WritePacket(buffer *buf.Buffer) error { - if d.device.writer != nil { - d.device.writer.RewritePacket(buffer.Bytes()) - } - d.device.packetOutbound <- buffer - return nil -} - -func (d *stackNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error { - if d.device.writer != nil { - d.device.writer.RewritePacketBuffer(buffer) - } - d.device.outbound <- buffer - return nil -} - -func (d *stackNatDestination) Close() error { - d.device.mapping.DeleteSession(d.session) - return nil -} - -func (d *stackNatDestination) Timeout() bool { - return false -} - -/*type endpointNatDestination struct { - ep tcpip.Endpoint - wq *waiter.Queue - networkProto tcpip.NetworkProtocolNumber - context tun.DirectRouteContext - done chan struct{} -} - -func (d *endpointNatDestination) loopRead() { - for { - println("start read") - buffer, err := commonRead(d.ep, d.wq, d.done) - if err != nil { - log.Error(err) - return - } - println("done read") - ipHdr := header.IPv4(buffer.Bytes()) - if ipHdr.TransportProtocol() != header.ICMPv4ProtocolNumber { - buffer.Release() - continue - } - icmpHdr := header.ICMPv4(ipHdr.Payload()) - if icmpHdr.Type() != header.ICMPv4EchoReply { - buffer.Release() - continue - } - fmt.Println("read echo reply") - _ = d.context.WritePacket(ipHdr) - buffer.Release() - } -} - -func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, done chan struct{}) (*buf.Buffer, error) { - buffer := buf.NewPacket() - result, err := ep.Read(buffer, tcpip.ReadOptions{}) - if err != nil { - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&waitEntry) - defer wq.EventUnregister(&waitEntry) - for { - result, err = ep.Read(buffer, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - break - } - select { - case <-notifyCh: - case <-done: - buffer.Release() - return nil, context.DeadlineExceeded - } - } - } - return nil, gonet.TranslateNetstackError(err) - } - buffer.Truncate(result.Count) - return buffer, nil -} - -func (d *endpointNatDestination) WritePacket(buffer *buf.Buffer) error { - _, err := d.ep.Write(buffer, tcpip.WriteOptions{}) - if err != nil { - return gonet.TranslateNetstackError(err) - } - return nil -} - -func (d *endpointNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error { - data := buffer.ToView().AsSlice() - println("write echo request buffer :" + fmt.Sprint(data)) - _, err := d.ep.Write(bytes.NewReader(data), tcpip.WriteOptions{}) - if err != nil { - log.Error(err) - return gonet.TranslateNetstackError(err) - } - return nil -} - -func (d *endpointNatDestination) Close() error { - d.ep.Abort() - close(d.done) - return nil -} - -func (d *endpointNatDestination) Timeout() bool { - return false -} -*/ diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index 90abee4b..162a5cbf 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -22,14 +22,14 @@ import ( var _ Device = (*systemDevice)(nil) type systemDevice struct { - options DeviceOptions - dialer N.Dialer - device tun.Tun - batchDevice tun.LinuxTUN - events chan wgTun.Event - closeOnce sync.Once - addr4 netip.Addr - addr6 netip.Addr + options DeviceOptions + dialer N.Dialer + device tun.Tun + batchDevice tun.LinuxTUN + events chan wgTun.Event + closeOnce sync.Once + inet4Address netip.Addr + inet6Address netip.Addr } func newSystemDevice(options DeviceOptions) (*systemDevice, error) { @@ -53,11 +53,11 @@ func newSystemDevice(options DeviceOptions) (*systemDevice, error) { } } return &systemDevice{ - options: options, - dialer: options.CreateDialer(options.Name), - events: make(chan wgTun.Event, 1), - addr4: inet4Address, - addr6: inet6Address, + options: options, + dialer: options.CreateDialer(options.Name), + events: make(chan wgTun.Event, 1), + inet4Address: inet4Address, + inet6Address: inet6Address, }, nil } @@ -70,11 +70,11 @@ func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr } func (w *systemDevice) Inet4Address() netip.Addr { - return w.addr4 + return w.inet4Address } func (w *systemDevice) Inet6Address() netip.Addr { - return w.addr6 + return w.inet6Address } func (w *systemDevice) SetDevice(device *device.Device) { diff --git a/transport/wireguard/device_system_stack.go b/transport/wireguard/device_system_stack.go index 4249e53e..16f56ebd 100644 --- a/transport/wireguard/device_system_stack.go +++ b/transport/wireguard/device_system_stack.go @@ -3,16 +3,25 @@ package wireguard import ( + "context" "net/netip" "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" + "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" "github.com/sagernet/wireguard-go/device" ) @@ -20,6 +29,8 @@ var _ Device = (*systemStackDevice)(nil) type systemStackDevice struct { *systemDevice + ctx context.Context + logger logger.ContextLogger stack *stack.Stack endpoint *deviceEndpoint writeBufs [][]byte @@ -34,13 +45,45 @@ func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) { mtu: options.MTU, done: make(chan struct{}), } - ipStack, err := tun.NewGVisorStack(endpoint) + ipStack, err := tun.NewGVisorStackWithOptions(endpoint, stack.NICOptions{}, true) if err != nil { return nil, err } - ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) + var ( + inet4Address netip.Addr + inet6Address netip.Addr + ) + for _, prefix := range options.Address { + addr := tun.AddressFromAddr(prefix.Addr()) + protoAddr := tcpip.ProtocolAddress{ + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: prefix.Bits(), + }, + } + if prefix.Addr().Is4() { + inet4Address = prefix.Addr() + protoAddr.Protocol = ipv4.ProtocolNumber + } else { + inet6Address = prefix.Addr() + protoAddr.Protocol = ipv6.ProtocolNumber + } + gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{}) + if gErr != nil { + return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String()) + } + } + if options.Handler != nil { + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) + icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout) + icmpForwarder.SetLocalAddresses(inet4Address, inet6Address) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) + } return &systemStackDevice{ + ctx: options.Context, + logger: options.Logger, systemDevice: system, stack: ipStack, endpoint: endpoint, @@ -116,6 +159,22 @@ func (w *systemStackDevice) writeStack(packet []byte) bool { return true } +func (w *systemStackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { + ctx := log.ContextWithNewID(w.ctx) + destination, err := ping.ConnectGVisor( + ctx, w.logger, + metadata.Source.Addr, metadata.Destination.Addr, + routeContext, + w.stack, + w.inet4Address, w.inet6Address, + ) + if err != nil { + return nil, err + } + w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString()) + return destination, nil +} + type deviceEndpoint struct { mtu uint32 done chan struct{} diff --git a/transport/wireguard/endpoint.go b/transport/wireguard/endpoint.go index fe3a5a5c..b2b41956 100644 --- a/transport/wireguard/endpoint.go +++ b/transport/wireguard/endpoint.go @@ -119,7 +119,7 @@ func NewEndpoint(options EndpointOptions) (*Endpoint, error) { } natDevice, isNatDevice := tunDevice.(NatDevice) if !isNatDevice { - natDevice = NewNATDevice(tunDevice, true) + natDevice = NewNATDevice(tunDevice) } return &Endpoint{ options: options,