Implement ping support for WireGuard and Tailscale

This commit is contained in:
世界 2025-08-24 11:22:49 +08:00
parent 29862d8cce
commit 3145f8c54c
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
11 changed files with 187 additions and 298 deletions

4
go.mod
View File

@ -33,10 +33,10 @@ require (
github.com/sagernet/sing-shadowsocks v0.2.8 github.com/sagernet/sing-shadowsocks v0.2.8
github.com/sagernet/sing-shadowsocks2 v0.2.1 github.com/sagernet/sing-shadowsocks2 v0.2.1
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11
github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250823082506-db38fc041d02 github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250824024715-dd18aa2b8633
github.com/sagernet/sing-vmess v0.2.7 github.com/sagernet/sing-vmess v0.2.7
github.com/sagernet/smux v1.5.34-mod.2 github.com/sagernet/smux v1.5.34-mod.2
github.com/sagernet/tailscale v1.80.3-mod.5 github.com/sagernet/tailscale v1.80.3-mod.6
github.com/sagernet/wireguard-go v0.0.1-beta.7 github.com/sagernet/wireguard-go v0.0.1-beta.7
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
github.com/spf13/cobra v1.9.1 github.com/spf13/cobra v1.9.1

8
go.sum
View File

@ -179,14 +179,14 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq
github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ=
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w=
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA=
github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250823082506-db38fc041d02 h1:vnbpigObkH2tT7OdAuxR0jbk6lOnF5BlDRb4A9hF0xM= github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250824024715-dd18aa2b8633 h1:cqm3Gd253bpnQV5qQvvrFEcO0dzUrfsiOQRTtSFM8cs=
github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250823082506-db38fc041d02/go.mod h1:z1lkiAE5ex5gHBzh5+G9TFsyM9grOaSsRx33mVfWfVI= github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250824024715-dd18aa2b8633/go.mod h1:z1lkiAE5ex5gHBzh5+G9TFsyM9grOaSsRx33mVfWfVI=
github.com/sagernet/sing-vmess v0.2.7 h1:2ee+9kO0xW5P4mfe6TYVWf9VtY8k1JhNysBqsiYj0sk= github.com/sagernet/sing-vmess v0.2.7 h1:2ee+9kO0xW5P4mfe6TYVWf9VtY8k1JhNysBqsiYj0sk=
github.com/sagernet/sing-vmess v0.2.7/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs= github.com/sagernet/sing-vmess v0.2.7/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs=
github.com/sagernet/smux v1.5.34-mod.2 h1:gkmBjIjlJ2zQKpLigOkFur5kBKdV6bNRoFu2WkltRQ4= github.com/sagernet/smux v1.5.34-mod.2 h1:gkmBjIjlJ2zQKpLigOkFur5kBKdV6bNRoFu2WkltRQ4=
github.com/sagernet/smux v1.5.34-mod.2/go.mod h1:0KW0+R+ycvA2INW4gbsd7BNyg+HEfLIAxa5N02/28Zc= github.com/sagernet/smux v1.5.34-mod.2/go.mod h1:0KW0+R+ycvA2INW4gbsd7BNyg+HEfLIAxa5N02/28Zc=
github.com/sagernet/tailscale v1.80.3-mod.5 h1:7V7z+p2C//TGtff20pPnDCt3qP6uFyY62peJoKF9z/A= github.com/sagernet/tailscale v1.80.3-mod.6 h1:oJs0jpRNS/12+mPf3r9maxWl9dWy1RanugLNmsF74Gs=
github.com/sagernet/tailscale v1.80.3-mod.5/go.mod h1:EBxXsWu4OH2ELbQLq32WoBeIubG8KgDrg4/Oaxjs6lI= github.com/sagernet/tailscale v1.80.3-mod.6/go.mod h1:EBxXsWu4OH2ELbQLq32WoBeIubG8KgDrg4/Oaxjs6lI=
github.com/sagernet/wireguard-go v0.0.1-beta.7 h1:ltgBwYHfr+9Wz1eG59NiWnHrYEkDKHG7otNZvu85DXI= github.com/sagernet/wireguard-go v0.0.1-beta.7 h1:ltgBwYHfr+9Wz1eG59NiWnHrYEkDKHG7otNZvu85DXI=
github.com/sagernet/wireguard-go v0.0.1-beta.7/go.mod h1:jGXij2Gn2wbrWuYNUmmNhf1dwcZtvyAvQoe8Xd8MbUo= github.com/sagernet/wireguard-go v0.0.1-beta.7/go.mod h1:jGXij2Gn2wbrWuYNUmmNhf1dwcZtvyAvQoe8Xd8MbUo=
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=

View File

@ -31,6 +31,7 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/ping"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
@ -57,7 +58,10 @@ import (
"go4.org/netipx" "go4.org/netipx"
) )
var _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) var (
_ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil)
_ adapter.DirectRouteOutbound = (*Endpoint)(nil)
)
func init() { func init() {
version.SetVersion("sing-box " + C.Version) version.SetVersion("sing-box " + C.Version)
@ -77,6 +81,7 @@ type Endpoint struct {
platformInterface platform.Interface platformInterface platform.Interface
server *tsnet.Server server *tsnet.Server
stack *stack.Stack stack *stack.Stack
icmpForwarder *tun.ICMPForwarder
filter *atomic.Pointer[filter.Filter] filter *atomic.Pointer[filter.Filter]
onReconfigHook wgengine.ReconfigListener onReconfigHook wgengine.ReconfigListener
@ -176,7 +181,7 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
}, },
} }
return &Endpoint{ return &Endpoint{
Adapter: endpoint.NewAdapter(C.TypeTailscale, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil), Adapter: endpoint.NewAdapter(C.TypeTailscale, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMPv4, N.NetworkICMPv6}, nil),
ctx: ctx, ctx: ctx,
router: router, router: router,
logger: logger, logger: logger,
@ -242,10 +247,11 @@ func (t *Endpoint) Start(stage adapter.StartStage) error {
} }
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(t.ctx, ipStack, t).HandlePacket) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(t.ctx, ipStack, t).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(t.ctx, ipStack, t, t.udpTimeout).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(t.ctx, ipStack, t, t.udpTimeout).HandlePacket)
icmpForwarder := tun.NewICMPForwarder(t.ctx, ipStack, netip.Addr{}, netip.Addr{}, t, t.udpTimeout) icmpForwarder := tun.NewICMPForwarder(t.ctx, ipStack, t, t.udpTimeout)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
t.stack = ipStack t.stack = ipStack
t.icmpForwarder = icmpForwarder
localBackend := t.server.ExportLocalBackend() localBackend := t.server.ExportLocalBackend()
perfs := &ipn.MaskedPrefs{ perfs := &ipn.MaskedPrefs{
@ -485,6 +491,26 @@ func (t *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
t.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) t.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (t *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) {
inet4Address, inet6Address := t.server.TailscaleIPs()
if metadata.Destination.Addr.Is4() && !inet4Address.IsValid() || metadata.Destination.Addr.Is6() && !inet6Address.IsValid() {
return nil, E.New("Tailscale is not ready yet")
}
ctx := log.ContextWithNewID(t.ctx)
destination, err := ping.ConnectGVisor(
ctx, t.logger,
metadata.Source.Addr, metadata.Destination.Addr,
routeContext,
t.stack,
inet4Address, inet6Address,
)
if err != nil {
return nil, err
}
t.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
return destination, nil
}
func (t *Endpoint) PreferredDomain(domain string) bool { func (t *Endpoint) PreferredDomain(domain string) bool {
routeDomains := t.routeDomains.Load() routeDomains := t.routeDomains.Load()
if routeDomains == nil { if routeDomains == nil {
@ -512,6 +538,15 @@ func (t *Endpoint) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCf
if (t.cfg != nil && reflect.DeepEqual(t.cfg, cfg)) && (t.dnsCfg != nil && reflect.DeepEqual(t.dnsCfg, dnsCfg)) { if (t.cfg != nil && reflect.DeepEqual(t.cfg, cfg)) && (t.dnsCfg != nil && reflect.DeepEqual(t.dnsCfg, dnsCfg)) {
return return
} }
var inet4Address, inet6Address netip.Addr
for _, address := range cfg.Addresses {
if address.Addr().Is4() {
inet4Address = address.Addr()
} else if address.Addr().Is6() {
inet6Address = address.Addr()
}
}
t.icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
t.cfg = cfg t.cfg = cfg
t.dnsCfg = dnsCfg t.dnsCfg = dnsCfg

View File

@ -38,7 +38,7 @@ type DeviceOptions struct {
func NewDevice(options DeviceOptions) (Device, error) { func NewDevice(options DeviceOptions) (Device, error) {
if !options.System { if !options.System {
return newStackDevice(options) return newStackDevice(options)
} else if options.Handler == nil { } else if !tun.WithGVisor {
return newSystemDevice(options) return newSystemDevice(options)
} else { } else {
return newSystemStackDevice(options) return newSystemStackDevice(options)

View File

@ -3,6 +3,7 @@ package wireguard
import ( import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/ping"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
) )
@ -10,29 +11,34 @@ var _ Device = (*natDeviceWrapper)(nil)
type natDeviceWrapper struct { type natDeviceWrapper struct {
Device Device
gVisorOutbound
packetOutbound chan *buf.Buffer packetOutbound chan *buf.Buffer
mapping *tun.NatMapping rewriter *ping.Rewriter
writer *tun.NatWriter
buffer [][]byte buffer [][]byte
} }
func NewNATDevice(upstream Device, ipRewrite bool) NatDevice { func NewNATDevice(upstream Device) NatDevice {
wrapper := &natDeviceWrapper{ wrapper := &natDeviceWrapper{
Device: upstream, Device: upstream,
gVisorOutbound: newGVisorOutbound(),
packetOutbound: make(chan *buf.Buffer, 256), packetOutbound: make(chan *buf.Buffer, 256),
mapping: tun.NewNatMapping(ipRewrite), rewriter: ping.NewRewriter(upstream.Inet4Address(), upstream.Inet6Address()),
}
if ipRewrite {
wrapper.writer = tun.NewNatWriter(upstream.Inet4Address(), upstream.Inet6Address())
} }
return wrapper return wrapper
} }
func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
select {
case packet := <-d.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
return 1, nil
default:
}
return d.Device.Read(bufs, sizes, offset)
}
func (d *natDeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { func (d *natDeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
for _, buffer := range bufs { for _, buffer := range bufs {
handled, err := d.mapping.WritePacket(buffer[offset:]) handled, err := d.rewriter.WriteBack(buffer[offset:])
if handled { if handled {
if err != nil { if err != nil {
return 0, err return 0, err
@ -56,30 +62,24 @@ func (d *natDeviceWrapper) CreateDestination(metadata adapter.InboundContext, ro
Source: metadata.Source.Addr, Source: metadata.Source.Addr,
Destination: metadata.Destination.Addr, Destination: metadata.Destination.Addr,
} }
d.mapping.CreateSession(session, routeContext) d.rewriter.CreateSession(session, routeContext)
return &natDestinationWrapper{d, session}, nil return &natDestination{d, session}, nil
} }
var _ tun.DirectRouteDestination = (*natDestinationWrapper)(nil) var _ tun.DirectRouteDestination = (*natDestination)(nil)
type natDestinationWrapper struct { type natDestination struct {
device *natDeviceWrapper device *natDeviceWrapper
session tun.DirectRouteSession session tun.DirectRouteSession
} }
func (d *natDestinationWrapper) WritePacket(buffer *buf.Buffer) error { func (d *natDestination) WritePacket(buffer *buf.Buffer) error {
if d.device.writer != nil { d.device.rewriter.RewritePacket(buffer.Bytes())
d.device.writer.RewritePacket(buffer.Bytes())
}
d.device.packetOutbound <- buffer d.device.packetOutbound <- buffer
return nil return nil
} }
func (d *natDestinationWrapper) Close() error { func (d *natDestination) Close() error {
d.device.mapping.DeleteSession(d.session) d.device.rewriter.DeleteSession(d.session)
return nil return nil
} }
func (d *natDestinationWrapper) Timeout() bool {
return false
}

View File

@ -1,48 +0,0 @@
//go:build with_gvisor
package wireguard
import (
"github.com/sagernet/gvisor/pkg/tcpip/stack"
)
type gVisorOutbound struct {
outbound chan *stack.PacketBuffer
}
func newGVisorOutbound() gVisorOutbound {
return gVisorOutbound{
outbound: make(chan *stack.PacketBuffer, 256),
}
}
func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
select {
case packet := <-d.outbound:
defer packet.DecRef()
var copyN int
/*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) {
copyN += copy(bufs[0][offset+copyN:], view.AsSlice())
})*/
for _, view := range packet.AsSlices() {
copyN += copy(bufs[0][offset+copyN:], view)
}
sizes[0] = copyN
return 1, nil
case packet := <-d.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
return 1, nil
default:
}
return d.Device.Read(bufs, sizes, offset)
}
func (d *natDestinationWrapper) WritePacketBuffer(packetBuffer *stack.PacketBuffer) error {
println("read from wg")
if d.device.writer != nil {
d.device.writer.RewritePacketBuffer(packetBuffer)
}
d.device.outbound <- packetBuffer
return nil
}

View File

@ -1,20 +0,0 @@
//go:build !with_gvisor
package wireguard
type gVisorOutbound struct{}
func newGVisorOutbound() gVisorOutbound {
return gVisorOutbound{}
}
func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
select {
case packet := <-d.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
return 1, nil
default:
}
return d.Device.Read(bufs, sizes, offset)
}

View File

@ -19,7 +19,9 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/ping"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -31,6 +33,8 @@ import (
var _ NatDevice = (*stackDevice)(nil) var _ NatDevice = (*stackDevice)(nil)
type stackDevice struct { type stackDevice struct {
ctx context.Context
logger log.ContextLogger
stack *stack.Stack stack *stack.Stack
mtu uint32 mtu uint32
events chan wgTun.Event events chan wgTun.Event
@ -38,25 +42,28 @@ type stackDevice struct {
packetOutbound chan *buf.Buffer packetOutbound chan *buf.Buffer
done chan struct{} done chan struct{}
dispatcher stack.NetworkDispatcher dispatcher stack.NetworkDispatcher
addr4 tcpip.Address inet4Address netip.Addr
addr6 tcpip.Address inet6Address netip.Addr
mapping *tun.NatMapping
writer *tun.NatWriter
} }
func newStackDevice(options DeviceOptions) (*stackDevice, error) { func newStackDevice(options DeviceOptions) (*stackDevice, error) {
tunDevice := &stackDevice{ tunDevice := &stackDevice{
ctx: options.Context,
logger: options.Logger,
mtu: options.MTU, mtu: options.MTU,
events: make(chan wgTun.Event, 1), events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256), outbound: make(chan *stack.PacketBuffer, 256),
packetOutbound: make(chan *buf.Buffer, 256), packetOutbound: make(chan *buf.Buffer, 256),
done: make(chan struct{}), done: make(chan struct{}),
mapping: tun.NewNatMapping(true),
} }
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice)) ipStack, err := tun.NewGVisorStackWithOptions((*wireEndpoint)(tunDevice), stack.NICOptions{}, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var (
inet4Address netip.Addr
inet6Address netip.Addr
)
for _, prefix := range options.Address { for _, prefix := range options.Address {
addr := tun.AddressFromAddr(prefix.Addr()) addr := tun.AddressFromAddr(prefix.Addr())
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
@ -66,10 +73,12 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
}, },
} }
if prefix.Addr().Is4() { if prefix.Addr().Is4() {
tunDevice.addr4 = addr inet4Address = prefix.Addr()
tunDevice.inet4Address = inet4Address
protoAddr.Protocol = ipv4.ProtocolNumber protoAddr.Protocol = ipv4.ProtocolNumber
} else { } else {
tunDevice.addr6 = addr inet6Address = prefix.Addr()
tunDevice.inet6Address = inet6Address
protoAddr.Protocol = ipv6.ProtocolNumber protoAddr.Protocol = ipv6.ProtocolNumber
} }
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{}) gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
@ -77,12 +86,12 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String()) return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
} }
} }
tunDevice.writer = tun.NewNatWriter(tunDevice.Inet4Address(), tunDevice.Inet6Address())
tunDevice.stack = ipStack tunDevice.stack = ipStack
if options.Handler != nil { if options.Handler != nil {
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout) icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
} }
@ -101,10 +110,10 @@ func (w *stackDevice) DialContext(ctx context.Context, network string, destinati
var networkProtocol tcpip.NetworkProtocolNumber var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() { if destination.IsIPv4() {
networkProtocol = header.IPv4ProtocolNumber networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.addr4 bind.Addr = tun.AddressFromAddr(w.inet4Address)
} else { } else {
networkProtocol = header.IPv6ProtocolNumber networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.addr6 bind.Addr = tun.AddressFromAddr(w.inet4Address)
} }
switch N.NetworkName(network) { switch N.NetworkName(network) {
case N.NetworkTCP: case N.NetworkTCP:
@ -131,10 +140,10 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
var networkProtocol tcpip.NetworkProtocolNumber var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() { if destination.IsIPv4() {
networkProtocol = header.IPv4ProtocolNumber networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.addr4 bind.Addr = tun.AddressFromAddr(w.inet4Address)
} else { } else {
networkProtocol = header.IPv6ProtocolNumber networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.addr6 bind.Addr = tun.AddressFromAddr(w.inet4Address)
} }
udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol) udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
if err != nil { if err != nil {
@ -144,11 +153,11 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
} }
func (w *stackDevice) Inet4Address() netip.Addr { func (w *stackDevice) Inet4Address() netip.Addr {
return netip.AddrFrom4(w.addr4.As4()) return w.inet4Address
} }
func (w *stackDevice) Inet6Address() netip.Addr { func (w *stackDevice) Inet6Address() netip.Addr {
return netip.AddrFrom16(w.addr6.As16()) return w.inet6Address
} }
func (w *stackDevice) SetDevice(device *device.Device) { func (w *stackDevice) SetDevice(device *device.Device) {
@ -194,14 +203,6 @@ func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
if len(b) == 0 { if len(b) == 0 {
continue continue
} }
handled, err := w.mapping.WritePacket(b)
if handled {
if err != nil {
return count, err
}
count++
continue
}
var networkProtocol tcpip.NetworkProtocolNumber var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(b) { switch header.IPVersion(b) {
case header.IPv4Version: case header.IPv4Version:
@ -250,6 +251,22 @@ func (w *stackDevice) BatchSize() int {
return 1 return 1
} }
func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) {
ctx := log.ContextWithNewID(w.ctx)
destination, err := ping.ConnectGVisor(
ctx, w.logger,
metadata.Source.Addr, metadata.Destination.Addr,
routeContext,
w.stack,
w.inet4Address, w.inet6Address,
)
if err != nil {
return nil, err
}
w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
return destination, nil
}
var _ stack.LinkEndpoint = (*wireEndpoint)(nil) var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint stackDevice type wireEndpoint stackDevice
@ -315,157 +332,3 @@ func (ep *wireEndpoint) Close() {
func (ep *wireEndpoint) SetOnCloseAction(f func()) { func (ep *wireEndpoint) SetOnCloseAction(f func()) {
} }
func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) {
/* var wq waiter.Queue
ep, err := raw.NewEndpoint(w.stack, ipv4.ProtocolNumber, icmp.ProtocolNumber4, &wq)
if err != nil {
return nil, E.Cause(gonet.TranslateNetstackError(err), "create endpoint")
}
err = ep.Connect(tcpip.FullAddress{
NIC: tun.DefaultNIC,
Port: metadata.Destination.Port,
Addr: tun.AddressFromAddr(metadata.Destination.Addr),
})
if err != nil {
ep.Close()
return nil, E.Cause(gonet.TranslateNetstackError(err), "ICMP connect ", metadata.Destination)
}
fmt.Println("linked ", metadata.Network, " connection to ", metadata.Destination.AddrString())
destination := &endpointNatDestination{
ep: ep,
wq: &wq,
context: routeContext,
}
go destination.loopRead()
return destination, nil*/
session := tun.DirectRouteSession{
Source: metadata.Source.Addr,
Destination: metadata.Destination.Addr,
}
w.mapping.CreateSession(session, routeContext)
return &stackNatDestination{
device: w,
session: session,
}, nil
}
type stackNatDestination struct {
device *stackDevice
session tun.DirectRouteSession
}
func (d *stackNatDestination) WritePacket(buffer *buf.Buffer) error {
if d.device.writer != nil {
d.device.writer.RewritePacket(buffer.Bytes())
}
d.device.packetOutbound <- buffer
return nil
}
func (d *stackNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error {
if d.device.writer != nil {
d.device.writer.RewritePacketBuffer(buffer)
}
d.device.outbound <- buffer
return nil
}
func (d *stackNatDestination) Close() error {
d.device.mapping.DeleteSession(d.session)
return nil
}
func (d *stackNatDestination) Timeout() bool {
return false
}
/*type endpointNatDestination struct {
ep tcpip.Endpoint
wq *waiter.Queue
networkProto tcpip.NetworkProtocolNumber
context tun.DirectRouteContext
done chan struct{}
}
func (d *endpointNatDestination) loopRead() {
for {
println("start read")
buffer, err := commonRead(d.ep, d.wq, d.done)
if err != nil {
log.Error(err)
return
}
println("done read")
ipHdr := header.IPv4(buffer.Bytes())
if ipHdr.TransportProtocol() != header.ICMPv4ProtocolNumber {
buffer.Release()
continue
}
icmpHdr := header.ICMPv4(ipHdr.Payload())
if icmpHdr.Type() != header.ICMPv4EchoReply {
buffer.Release()
continue
}
fmt.Println("read echo reply")
_ = d.context.WritePacket(ipHdr)
buffer.Release()
}
}
func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, done chan struct{}) (*buf.Buffer, error) {
buffer := buf.NewPacket()
result, err := ep.Read(buffer, tcpip.ReadOptions{})
if err != nil {
if _, ok := err.(*tcpip.ErrWouldBlock); ok {
waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents)
wq.EventRegister(&waitEntry)
defer wq.EventUnregister(&waitEntry)
for {
result, err = ep.Read(buffer, tcpip.ReadOptions{})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
select {
case <-notifyCh:
case <-done:
buffer.Release()
return nil, context.DeadlineExceeded
}
}
}
return nil, gonet.TranslateNetstackError(err)
}
buffer.Truncate(result.Count)
return buffer, nil
}
func (d *endpointNatDestination) WritePacket(buffer *buf.Buffer) error {
_, err := d.ep.Write(buffer, tcpip.WriteOptions{})
if err != nil {
return gonet.TranslateNetstackError(err)
}
return nil
}
func (d *endpointNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error {
data := buffer.ToView().AsSlice()
println("write echo request buffer :" + fmt.Sprint(data))
_, err := d.ep.Write(bytes.NewReader(data), tcpip.WriteOptions{})
if err != nil {
log.Error(err)
return gonet.TranslateNetstackError(err)
}
return nil
}
func (d *endpointNatDestination) Close() error {
d.ep.Abort()
close(d.done)
return nil
}
func (d *endpointNatDestination) Timeout() bool {
return false
}
*/

View File

@ -22,14 +22,14 @@ import (
var _ Device = (*systemDevice)(nil) var _ Device = (*systemDevice)(nil)
type systemDevice struct { type systemDevice struct {
options DeviceOptions options DeviceOptions
dialer N.Dialer dialer N.Dialer
device tun.Tun device tun.Tun
batchDevice tun.LinuxTUN batchDevice tun.LinuxTUN
events chan wgTun.Event events chan wgTun.Event
closeOnce sync.Once closeOnce sync.Once
addr4 netip.Addr inet4Address netip.Addr
addr6 netip.Addr inet6Address netip.Addr
} }
func newSystemDevice(options DeviceOptions) (*systemDevice, error) { func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
@ -53,11 +53,11 @@ func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
} }
} }
return &systemDevice{ return &systemDevice{
options: options, options: options,
dialer: options.CreateDialer(options.Name), dialer: options.CreateDialer(options.Name),
events: make(chan wgTun.Event, 1), events: make(chan wgTun.Event, 1),
addr4: inet4Address, inet4Address: inet4Address,
addr6: inet6Address, inet6Address: inet6Address,
}, nil }, nil
} }
@ -70,11 +70,11 @@ func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr
} }
func (w *systemDevice) Inet4Address() netip.Addr { func (w *systemDevice) Inet4Address() netip.Addr {
return w.addr4 return w.inet4Address
} }
func (w *systemDevice) Inet6Address() netip.Addr { func (w *systemDevice) Inet6Address() netip.Addr {
return w.addr6 return w.inet6Address
} }
func (w *systemDevice) SetDevice(device *device.Device) { func (w *systemDevice) SetDevice(device *device.Device) {

View File

@ -3,16 +3,25 @@
package wireguard package wireguard
import ( import (
"context"
"net/netip" "net/netip"
"github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/ping"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/wireguard-go/device" "github.com/sagernet/wireguard-go/device"
) )
@ -20,6 +29,8 @@ var _ Device = (*systemStackDevice)(nil)
type systemStackDevice struct { type systemStackDevice struct {
*systemDevice *systemDevice
ctx context.Context
logger logger.ContextLogger
stack *stack.Stack stack *stack.Stack
endpoint *deviceEndpoint endpoint *deviceEndpoint
writeBufs [][]byte writeBufs [][]byte
@ -34,13 +45,45 @@ func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
mtu: options.MTU, mtu: options.MTU,
done: make(chan struct{}), done: make(chan struct{}),
} }
ipStack, err := tun.NewGVisorStack(endpoint) ipStack, err := tun.NewGVisorStackWithOptions(endpoint, stack.NICOptions{}, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) var (
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) inet4Address netip.Addr
inet6Address netip.Addr
)
for _, prefix := range options.Address {
addr := tun.AddressFromAddr(prefix.Addr())
protoAddr := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: prefix.Bits(),
},
}
if prefix.Addr().Is4() {
inet4Address = prefix.Addr()
protoAddr.Protocol = ipv4.ProtocolNumber
} else {
inet6Address = prefix.Addr()
protoAddr.Protocol = ipv6.ProtocolNumber
}
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
if gErr != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
}
}
if options.Handler != nil {
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
}
return &systemStackDevice{ return &systemStackDevice{
ctx: options.Context,
logger: options.Logger,
systemDevice: system, systemDevice: system,
stack: ipStack, stack: ipStack,
endpoint: endpoint, endpoint: endpoint,
@ -116,6 +159,22 @@ func (w *systemStackDevice) writeStack(packet []byte) bool {
return true return true
} }
func (w *systemStackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) {
ctx := log.ContextWithNewID(w.ctx)
destination, err := ping.ConnectGVisor(
ctx, w.logger,
metadata.Source.Addr, metadata.Destination.Addr,
routeContext,
w.stack,
w.inet4Address, w.inet6Address,
)
if err != nil {
return nil, err
}
w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
return destination, nil
}
type deviceEndpoint struct { type deviceEndpoint struct {
mtu uint32 mtu uint32
done chan struct{} done chan struct{}

View File

@ -119,7 +119,7 @@ func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
} }
natDevice, isNatDevice := tunDevice.(NatDevice) natDevice, isNatDevice := tunDevice.(NatDevice)
if !isNatDevice { if !isNatDevice {
natDevice = NewNATDevice(tunDevice, true) natDevice = NewNATDevice(tunDevice)
} }
return &Endpoint{ return &Endpoint{
options: options, options: options,