From 610ed9e2ff7541a596f7fe214a6ac6babcaf6a26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 17 Feb 2025 22:00:57 +0800 Subject: [PATCH] Add proxy support for ICMP echo request --- adapter/outbound.go | 8 ++ adapter/router.go | 4 +- common/dialer/default.go | 8 ++ dns/transport/local/local_resolved_linux.go | 2 +- docs/configuration/route/rule.md | 13 ++- docs/configuration/route/rule.zh.md | 13 ++- go.mod | 8 +- go.sum | 16 +-- protocol/direct/outbound.go | 17 ++- protocol/tailscale/endpoint.go | 83 ++++++++++++--- protocol/tun/inbound.go | 20 +++- protocol/wireguard/endpoint.go | 27 ++++- protocol/wireguard/outbound.go | 8 +- route/route.go | 40 +++++-- transport/wireguard/device.go | 10 +- transport/wireguard/device_nat.go | 103 ++++++++++++++++++ transport/wireguard/device_stack.go | 112 ++++++++++++++------ transport/wireguard/device_system.go | 46 ++++++-- transport/wireguard/device_system_stack.go | 67 +++++++++++- transport/wireguard/endpoint.go | 24 ++++- 20 files changed, 535 insertions(+), 94 deletions(-) create mode 100644 transport/wireguard/device_nat.go diff --git a/adapter/outbound.go b/adapter/outbound.go index 517aa0fe..91fb9c65 100644 --- a/adapter/outbound.go +++ b/adapter/outbound.go @@ -3,9 +3,11 @@ package adapter import ( "context" "net/netip" + "time" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-tun" N "github.com/sagernet/sing/common/network" ) @@ -20,10 +22,16 @@ type Outbound interface { } type OutboundWithPreferredRoutes interface { + Outbound PreferredDomain(domain string) bool PreferredAddress(address netip.Addr) bool } +type DirectRouteOutbound interface { + Outbound + NewDirectRouteConnection(metadata InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) +} + type OutboundRegistry interface { option.OutboundOptionsRegistry CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error) diff --git a/adapter/router.go b/adapter/router.go index 0b7c8f4f..522a0d9d 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -6,8 +6,10 @@ import ( "net" "net/http" "sync" + "time" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-tun" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" @@ -19,7 +21,7 @@ import ( type Router interface { Lifecycle ConnectionRouter - PreMatch(metadata InboundContext) error + PreMatch(metadata InboundContext, context tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) ConnectionRouterEx RuleSet(tag string) (RuleSet, bool) NeedWIFIState() bool diff --git a/common/dialer/default.go b/common/dialer/default.go index dd1e90e5..ae3908a6 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -317,6 +317,14 @@ func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksadd } } +func (d *DefaultDialer) DialerForICMPDestination(destination netip.Addr) net.Dialer { + if !destination.Is6() { + return dialerFromTCPDialer(d.dialer6) + } else { + return dialerFromTCPDialer(d.dialer4) + } +} + func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { if strategy == nil { strategy = d.networkStrategy diff --git a/dns/transport/local/local_resolved_linux.go b/dns/transport/local/local_resolved_linux.go index 279f9c8e..6fbd0a0e 100644 --- a/dns/transport/local/local_resolved_linux.go +++ b/dns/transport/local/local_resolved_linux.go @@ -5,12 +5,12 @@ import ( "errors" "os" "sync" + "sync/atomic" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/service/resolved" "github.com/sagernet/sing-tun" - "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" diff --git a/docs/configuration/route/rule.md b/docs/configuration/route/rule.md index 89eecaaf..a6e89d7b 100644 --- a/docs/configuration/route/rule.md +++ b/docs/configuration/route/rule.md @@ -7,7 +7,8 @@ icon: material/new-box :material-plus: [interface_address](#interface_address) :material-plus: [network_interface_address](#network_interface_address) :material-plus: [default_interface_address](#default_interface_address) - :material-plus: [preferred_by](#preferred_by) + :material-plus: [preferred_by](#preferred_by) + :material-alert: [network](#network) !!! quote "Changes in sing-box 1.11.0" @@ -226,7 +227,15 @@ Sniffed client type, see [Protocol Sniff](/configuration/route/sniff/) for detai #### network -`tcp` or `udp`. +!!! quote "Changes in sing-box 1.13.0" + + Since sing-box 1.13.0, you can match ICMP echo (ping) requests via the new `icmp` network. + + Such traffic originates from `TUN`, `WireGuard`, and `Tailscale` inbounds and can be routed to `Direct`, `WireGuard`, and `Tailscale` outbounds. + +Match network type. + +`tcp`, `udp` or `icmp`. #### domain diff --git a/docs/configuration/route/rule.zh.md b/docs/configuration/route/rule.zh.md index 100344da..a90607d4 100644 --- a/docs/configuration/route/rule.zh.md +++ b/docs/configuration/route/rule.zh.md @@ -7,7 +7,8 @@ icon: material/new-box :material-plus: [interface_address](#interface_address) :material-plus: [network_interface_address](#network_interface_address) :material-plus: [default_interface_address](#default_interface_address) - :material-plus: [preferred_by](#preferred_by) + :material-plus: [preferred_by](#preferred_by) + :material-alert: [network](#network) !!! quote "sing-box 1.11.0 中的更改" @@ -223,7 +224,15 @@ icon: material/new-box #### network -`tcp` 或 `udp`。 +!!! quote "sing-box 1.13.0 中的更改" + + 自 sing-box 1.13.0 起,您可以通过新的 `icmp` 网络匹配 ICMP 回显(ping)请求。 + + 此类流量源自 `TUN`、`WireGuard` 和 `Tailscale` 入站,并可路由至 `Direct`、`WireGuard` 和 `Tailscale` 出站。 + +匹配网络类型。 + +`tcp`、`udp` 或 `icmp`。 #### domain diff --git a/go.mod b/go.mod index dccba18f..7d3d14a2 100644 --- a/go.mod +++ b/go.mod @@ -25,18 +25,18 @@ require ( github.com/sagernet/cors v1.2.1 github.com/sagernet/fswatch v0.1.1 github.com/sagernet/gomobile v0.1.8 - github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb + github.com/sagernet/gvisor v0.0.0-20250822052253-5558536cf237 github.com/sagernet/quic-go v0.52.0-beta.1 - github.com/sagernet/sing v0.7.6-0.20250825114712-2aeec120ce28 + github.com/sagernet/sing v0.7.6-0.20250825141840-811aa328e57b github.com/sagernet/sing-mux v0.3.3 github.com/sagernet/sing-quic v0.5.0 github.com/sagernet/sing-shadowsocks v0.2.8 github.com/sagernet/sing-shadowsocks2 v0.2.1 github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 - github.com/sagernet/sing-tun v0.7.0-beta.1 + github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250826030950-79e2d3b56d01 github.com/sagernet/sing-vmess v0.2.7 github.com/sagernet/smux v1.5.34-mod.2 - github.com/sagernet/tailscale v1.80.3-mod.5 + github.com/sagernet/tailscale v1.80.3-mod.6 github.com/sagernet/wireguard-go v0.0.1-beta.7 github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 github.com/spf13/cobra v1.9.1 diff --git a/go.sum b/go.sum index 86e4b287..bc2ec8fc 100644 --- a/go.sum +++ b/go.sum @@ -158,8 +158,8 @@ github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQ github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o= github.com/sagernet/gomobile v0.1.8 h1:vXgoN0pjsMONAaYCTdsKBX2T1kxuS7sbT/mZ7PElGoo= github.com/sagernet/gomobile v0.1.8/go.mod h1:A8l3FlHi2D/+mfcd4HHvk5DGFPW/ShFb9jHP5VmSiDY= -github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb h1:pprQtDqNgqXkRsXn+0E8ikKOemzmum8bODjSfDene38= -github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb/go.mod h1:QkkPEJLw59/tfxgapHta14UL5qMUah5NXhO0Kw2Kan4= +github.com/sagernet/gvisor v0.0.0-20250822052253-5558536cf237 h1:SUPFNB+vSP4RBPrSEgNII+HkfqC8hKMpYLodom4o4EU= +github.com/sagernet/gvisor v0.0.0-20250822052253-5558536cf237/go.mod h1:QkkPEJLw59/tfxgapHta14UL5qMUah5NXhO0Kw2Kan4= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= @@ -167,8 +167,8 @@ github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/l github.com/sagernet/quic-go v0.52.0-beta.1 h1:hWkojLg64zjV+MJOvJU/kOeWndm3tiEfBLx5foisszs= github.com/sagernet/quic-go v0.52.0-beta.1/go.mod h1:OV+V5kEBb8kJS7k29MzDu6oj9GyMc7HA07sE1tedxz4= github.com/sagernet/sing v0.6.9/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= -github.com/sagernet/sing v0.7.6-0.20250825114712-2aeec120ce28 h1:C8Lnqd0Q+C15kwaMiDsfq5S45rhhaQMBG91TT+6oFVo= -github.com/sagernet/sing v0.7.6-0.20250825114712-2aeec120ce28/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.7.6-0.20250825141840-811aa328e57b h1:RCfo1Q6VDAXfumNupRyqTomKzDODhASswkxVCqM8l2M= +github.com/sagernet/sing v0.7.6-0.20250825141840-811aa328e57b/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-mux v0.3.3 h1:YFgt9plMWzH994BMZLmyKL37PdIVaIilwP0Jg+EcLfw= github.com/sagernet/sing-mux v0.3.3/go.mod h1:pht8iFY4c9Xltj7rhVd208npkNaeCxzyXCgulDPLUDA= github.com/sagernet/sing-quic v0.5.0 h1:jNLIyVk24lFPvu8A4x+ZNEnZdI+Tg1rp7eCJ6v0Csak= @@ -179,14 +179,14 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA= -github.com/sagernet/sing-tun v0.7.0-beta.1 h1:mBIFXYAnGO5ey/HcCYanqnBx61E7yF8zTFGRZonGYmY= -github.com/sagernet/sing-tun v0.7.0-beta.1/go.mod h1:AHJuRrLbNRJuivuFZ2VhXwDj4ViYp14szG5EkkKAqRQ= +github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250826030950-79e2d3b56d01 h1:eUVH7DY/1P/EwNSV5fwgkT3IlXY9AyxFThgi0liGFmI= +github.com/sagernet/sing-tun v0.7.0-beta.1.0.20250826030950-79e2d3b56d01/go.mod h1:LokZYuEV3crByjQc/XRohLgfNvybtXdx5qe/I4W6S7k= github.com/sagernet/sing-vmess v0.2.7 h1:2ee+9kO0xW5P4mfe6TYVWf9VtY8k1JhNysBqsiYj0sk= github.com/sagernet/sing-vmess v0.2.7/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs= github.com/sagernet/smux v1.5.34-mod.2 h1:gkmBjIjlJ2zQKpLigOkFur5kBKdV6bNRoFu2WkltRQ4= github.com/sagernet/smux v1.5.34-mod.2/go.mod h1:0KW0+R+ycvA2INW4gbsd7BNyg+HEfLIAxa5N02/28Zc= -github.com/sagernet/tailscale v1.80.3-mod.5 h1:7V7z+p2C//TGtff20pPnDCt3qP6uFyY62peJoKF9z/A= -github.com/sagernet/tailscale v1.80.3-mod.5/go.mod h1:EBxXsWu4OH2ELbQLq32WoBeIubG8KgDrg4/Oaxjs6lI= +github.com/sagernet/tailscale v1.80.3-mod.6 h1:oJs0jpRNS/12+mPf3r9maxWl9dWy1RanugLNmsF74Gs= +github.com/sagernet/tailscale v1.80.3-mod.6/go.mod h1:EBxXsWu4OH2ELbQLq32WoBeIubG8KgDrg4/Oaxjs6lI= github.com/sagernet/wireguard-go v0.0.1-beta.7 h1:ltgBwYHfr+9Wz1eG59NiWnHrYEkDKHG7otNZvu85DXI= github.com/sagernet/wireguard-go v0.0.1-beta.7/go.mod h1:jGXij2Gn2wbrWuYNUmmNhf1dwcZtvyAvQoe8Xd8MbUo= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index 84838bc0..cd937053 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -13,6 +13,8 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" @@ -29,10 +31,12 @@ var ( _ N.ParallelDialer = (*Outbound)(nil) _ dialer.ParallelNetworkDialer = (*Outbound)(nil) _ dialer.DirectDialer = (*Outbound)(nil) + _ adapter.DirectRouteOutbound = (*Outbound)(nil) ) type Outbound struct { outbound.Adapter + ctx context.Context logger logger.ContextLogger dialer dialer.ParallelInterfaceDialer domainStrategy C.DomainStrategy @@ -58,7 +62,8 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, options.DialerOptions), + ctx: ctx, logger: logger, //nolint:staticcheck domainStrategy: C.DomainStrategy(options.DomainStrategy), @@ -146,6 +151,16 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n return conn, nil } +func (h *Outbound) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + ctx := log.ContextWithNewID(h.ctx) + destination, err := ping.ConnectDestination(ctx, h.logger, common.MustCast[*dialer.DefaultDialer](h.dialer).DialerForICMPDestination(metadata.Destination.Addr).Control, metadata.Destination.Addr, routeContext, timeout) + if err != nil { + return nil, err + } + h.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString()) + return destination, nil +} + func (h *Outbound) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) { ctx, metadata := adapter.ExtendContext(ctx) metadata.Outbound = h.Tag() diff --git a/protocol/tailscale/endpoint.go b/protocol/tailscale/endpoint.go index 44c882de..09ec38cd 100644 --- a/protocol/tailscale/endpoint.go +++ b/protocol/tailscale/endpoint.go @@ -13,6 +13,7 @@ import ( "reflect" "runtime" "strings" + "sync/atomic" "syscall" "time" @@ -20,6 +21,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/sing-box/adapter" @@ -29,9 +31,10 @@ import ( "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" @@ -56,7 +59,10 @@ import ( "go4.org/netipx" ) -var _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) +var ( + _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) + _ adapter.DirectRouteOutbound = (*Endpoint)(nil) +) func init() { version.SetVersion("sing-box " + C.Version) @@ -76,12 +82,13 @@ type Endpoint struct { platformInterface platform.Interface server *tsnet.Server stack *stack.Stack + icmpForwarder *tun.ICMPForwarder filter *atomic.Pointer[filter.Filter] onReconfigHook wgengine.ReconfigListener cfg *wgcfg.Config dnsCfg *tsDNS.Config - routeDomains atomic.TypedValue[map[string]bool] + routeDomains common.TypedValue[map[string]bool] routePrefixes atomic.Pointer[netipx.IPSet] acceptRoutes bool @@ -175,7 +182,7 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL }, } return &Endpoint{ - Adapter: endpoint.NewAdapter(C.TypeTailscale, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil), + Adapter: endpoint.NewAdapter(C.TypeTailscale, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, nil), ctx: ctx, router: router, logger: logger, @@ -240,9 +247,12 @@ func (t *Endpoint) Start(stage adapter.StartStage) error { return gonet.TranslateNetstackError(gErr) } ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(t.ctx, ipStack, t).HandlePacket) - udpForwarder := tun.NewUDPForwarder(t.ctx, ipStack, t, t.udpTimeout) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(t.ctx, ipStack, t, t.udpTimeout).HandlePacket) + icmpForwarder := tun.NewICMPForwarder(t.ctx, ipStack, t, t.udpTimeout) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) t.stack = ipStack + t.icmpForwarder = icmpForwarder localBackend := t.server.ExportLocalBackend() perfs := &ipn.MaskedPrefs{ @@ -263,7 +273,7 @@ func (t *Endpoint) Start(stage adapter.StartStage) error { if err != nil { return E.Cause(err, "update prefs") } - t.filter = atomic.PointerForm(localBackend.ExportFilter()) + t.filter = localBackend.ExportFilter() go t.watchState() return nil } @@ -415,7 +425,7 @@ func (t *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n return udpConn, nil } -func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error { +func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { tsFilter := t.filter.Load() if tsFilter != nil { var ipProto ipproto.Proto @@ -424,22 +434,41 @@ func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destina ipProto = ipproto.TCP case N.NetworkUDP: ipProto = ipproto.UDP + case N.NetworkICMP: + if !destination.IsIPv6() { + ipProto = ipproto.ICMPv4 + } else { + ipProto = ipproto.ICMPv6 + } } response := tsFilter.Check(source.Addr, destination.Addr, destination.Port, ipProto) switch response { case filter.Drop: - return syscall.ECONNRESET + return nil, syscall.ECONNREFUSED case filter.DropSilently: - return tun.ErrDrop + return nil, tun.ErrDrop } } - return t.router.PreMatch(adapter.InboundContext{ + var ipVersion uint8 + if !destination.IsIPv6() { + ipVersion = 4 + } else { + ipVersion = 6 + } + routeDestination, err := t.router.PreMatch(adapter.InboundContext{ Inbound: t.Tag(), InboundType: t.Type(), + IPVersion: ipVersion, Network: network, Source: source, Destination: destination, - }) + }, routeContext, timeout) + if err != nil { + if !rule.IsRejected(err) { + t.logger.Warn(E.Cause(err, "link ", network, " connection from ", source.AddrString(), " to ", destination.AddrString())) + } + } + return routeDestination, err } func (t *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { @@ -482,6 +511,27 @@ func (t *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, t.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) } +func (t *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + inet4Address, inet6Address := t.server.TailscaleIPs() + if metadata.Destination.Addr.Is4() && !inet4Address.IsValid() || metadata.Destination.Addr.Is6() && !inet6Address.IsValid() { + return nil, E.New("Tailscale is not ready yet") + } + ctx := log.ContextWithNewID(t.ctx) + destination, err := ping.ConnectGVisor( + ctx, t.logger, + metadata.Source.Addr, metadata.Destination.Addr, + routeContext, + t.stack, + inet4Address, inet6Address, + timeout, + ) + if err != nil { + return nil, err + } + t.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString()) + return destination, nil +} + func (t *Endpoint) PreferredDomain(domain string) bool { routeDomains := t.routeDomains.Load() if routeDomains == nil { @@ -509,6 +559,15 @@ func (t *Endpoint) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCf if (t.cfg != nil && reflect.DeepEqual(t.cfg, cfg)) && (t.dnsCfg != nil && reflect.DeepEqual(t.dnsCfg, dnsCfg)) { return } + var inet4Address, inet6Address netip.Addr + for _, address := range cfg.Addresses { + if address.Addr().Is4() { + inet4Address = address.Addr() + } else if address.Addr().Is6() { + inet6Address = address.Addr() + } + } + t.icmpForwarder.SetLocalAddresses(inet4Address, inet6Address) t.cfg = cfg t.dnsCfg = dnsCfg diff --git a/protocol/tun/inbound.go b/protocol/tun/inbound.go index fc69309c..3f013598 100644 --- a/protocol/tun/inbound.go +++ b/protocol/tun/inbound.go @@ -18,6 +18,7 @@ import ( "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" @@ -454,15 +455,28 @@ func (t *Inbound) Close() error { ) } -func (t *Inbound) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error { - return t.router.PreMatch(adapter.InboundContext{ +func (t *Inbound) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + var ipVersion uint8 + if !destination.IsIPv6() { + ipVersion = 4 + } else { + ipVersion = 6 + } + routeDestination, err := t.router.PreMatch(adapter.InboundContext{ Inbound: t.tag, InboundType: C.TypeTun, + IPVersion: ipVersion, Network: network, Source: source, Destination: destination, InboundOptions: t.inboundOptions, - }) + }, routeContext, timeout) + if err != nil { + if !rule.IsRejected(err) { + t.logger.Warn(E.Cause(err, "link ", network, " connection from ", source.AddrString(), " to ", destination.AddrString())) + } + } + return routeDestination, err } func (t *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index 207670c2..811c6bb4 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -12,7 +12,9 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-box/transport/wireguard" + "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" @@ -40,7 +42,7 @@ type Endpoint struct { func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) { ep := &Endpoint{ - Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), + Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, options.DialerOptions), ctx: ctx, router: router, dnsRouter: service.FromContext[adapter.DNSRouter](ctx), @@ -124,14 +126,27 @@ func (w *Endpoint) Close() error { return w.endpoint.Close() } -func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error { - return w.router.PreMatch(adapter.InboundContext{ +func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + var ipVersion uint8 + if !destination.IsIPv6() { + ipVersion = 4 + } else { + ipVersion = 6 + } + routeDestination, err := w.router.PreMatch(adapter.InboundContext{ Inbound: w.Tag(), InboundType: w.Type(), + IPVersion: ipVersion, Network: network, Source: source, Destination: destination, - }) + }, routeContext, timeout) + if err != nil { + if !rule.IsRejected(err) { + w.logger.Warn(E.Cause(err, "link ", network, " connection from ", source.AddrString(), " to ", destination.AddrString())) + } + } + return routeDestination, err } func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { @@ -220,3 +235,7 @@ func (w *Endpoint) PreferredDomain(domain string) bool { func (w *Endpoint) PreferredAddress(address netip.Addr) bool { return w.endpoint.Lookup(address) != nil } + +func (w *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + return w.endpoint.NewDirectRouteConnection(metadata, routeContext, timeout) +} diff --git a/protocol/wireguard/outbound.go b/protocol/wireguard/outbound.go index fa58d959..5b08c6a7 100644 --- a/protocol/wireguard/outbound.go +++ b/protocol/wireguard/outbound.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/outbound" @@ -13,6 +14,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/wireguard" + tun "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" @@ -42,7 +44,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL deprecated.Report(ctx, deprecated.OptionWireGuardGSO) } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, options.DialerOptions), ctx: ctx, dnsRouter: service.FromContext[adapter.DNSRouter](ctx), logger: logger, @@ -168,3 +170,7 @@ func (o *Outbound) PreferredDomain(domain string) bool { func (o *Outbound) PreferredAddress(address netip.Addr) bool { return o.endpoint.Lookup(address) != nil } + +func (o *Outbound) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + return o.endpoint.NewDirectRouteConnection(metadata, routeContext, timeout) +} diff --git a/route/route.go b/route/route.go index 20fbf4ec..9bca7304 100644 --- a/route/route.go +++ b/route/route.go @@ -17,6 +17,7 @@ import ( "github.com/sagernet/sing-box/option" R "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-mux" + "github.com/sagernet/sing-tun" "github.com/sagernet/sing-vmess" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -258,19 +259,37 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m return nil } -func (r *Router) PreMatch(metadata adapter.InboundContext) error { +func (r *Router) PreMatch(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil) if err != nil { - return err + return nil, err } - if selectedRule == nil { - return nil + if selectedRule != nil { + switch action := selectedRule.Action().(type) { + case *R.RuleActionReject: + return nil, action.Error(context.Background()) + case *R.RuleActionRoute: + if routeContext == nil { + return nil, nil + } + outbound, loaded := r.outbound.Outbound(action.Outbound) + if !loaded { + return nil, E.New("outbound not found: ", action.Outbound) + } + if !common.Contains(outbound.Network(), metadata.Network) { + return nil, E.New(metadata.Network, " is not supported by outbound: ", action.Outbound) + } + return outbound.(adapter.DirectRouteOutbound).NewDirectRouteConnection(metadata, routeContext, timeout) + } } - rejectAction, isReject := selectedRule.Action().(*R.RuleActionReject) - if !isReject { - return nil + if selectedRule != nil || metadata.Network != N.NetworkICMP { + return nil, nil } - return rejectAction.Error(context.Background()) + defaultOutbound := r.outbound.Default() + if !common.Contains(defaultOutbound.Network(), metadata.Network) { + return nil, E.New(metadata.Network, " is not supported by default outbound: ", defaultOutbound.Tag()) + } + return defaultOutbound.(adapter.DirectRouteOutbound).NewDirectRouteConnection(metadata, routeContext, timeout) } func (r *Router) matchRule( @@ -463,7 +482,7 @@ match: } else if len(newPacketBuffers) > 0 { packetBuffers = append(packetBuffers, newPacketBuffers...) } - } else { + } else if metadata.Network != N.NetworkICMP { selectedRule = currentRule selectedRuleIndex = currentRuleIndex break match @@ -477,8 +496,7 @@ match: actionType := currentRule.Action().Type() if actionType == C.RuleActionTypeRoute || actionType == C.RuleActionTypeReject || - actionType == C.RuleActionTypeHijackDNS || - (actionType == C.RuleActionTypeSniff && preMatch) { + actionType == C.RuleActionTypeHijackDNS { selectedRule = currentRule selectedRuleIndex = currentRuleIndex break match diff --git a/transport/wireguard/device.go b/transport/wireguard/device.go index 7a17b8f3..4dd615c5 100644 --- a/transport/wireguard/device.go +++ b/transport/wireguard/device.go @@ -5,6 +5,7 @@ import ( "net/netip" "time" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/logger" N "github.com/sagernet/sing/common/network" @@ -17,6 +18,8 @@ type Device interface { N.Dialer Start() error SetDevice(device *device.Device) + Inet4Address() netip.Addr + Inet6Address() netip.Addr } type DeviceOptions struct { @@ -35,9 +38,14 @@ type DeviceOptions struct { func NewDevice(options DeviceOptions) (Device, error) { if !options.System { return newStackDevice(options) - } else if options.Handler == nil { + } else if !tun.WithGVisor { return newSystemDevice(options) } else { return newSystemStackDevice(options) } } + +type NatDevice interface { + Device + CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) +} diff --git a/transport/wireguard/device_nat.go b/transport/wireguard/device_nat.go new file mode 100644 index 00000000..e5a28c1b --- /dev/null +++ b/transport/wireguard/device_nat.go @@ -0,0 +1,103 @@ +package wireguard + +import ( + "context" + "sync/atomic" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/logger" +) + +var _ Device = (*natDeviceWrapper)(nil) + +type natDeviceWrapper struct { + Device + ctx context.Context + logger logger.ContextLogger + packetOutbound chan *buf.Buffer + rewriter *ping.Rewriter + buffer [][]byte +} + +func NewNATDevice(ctx context.Context, logger logger.ContextLogger, upstream Device) NatDevice { + wrapper := &natDeviceWrapper{ + Device: upstream, + ctx: ctx, + logger: logger, + packetOutbound: make(chan *buf.Buffer, 256), + rewriter: ping.NewRewriter(ctx, logger, upstream.Inet4Address(), upstream.Inet6Address()), + } + return wrapper +} + +func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + select { + case packet := <-d.packetOutbound: + defer packet.Release() + sizes[0] = copy(bufs[0][offset:], packet.Bytes()) + return 1, nil + default: + } + return d.Device.Read(bufs, sizes, offset) +} + +func (d *natDeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { + for _, buffer := range bufs { + handled, err := d.rewriter.WriteBack(buffer[offset:]) + if handled { + if err != nil { + return 0, err + } + } else { + d.buffer = append(d.buffer, buffer) + } + } + if len(d.buffer) > 0 { + _, err := d.Device.Write(d.buffer, offset) + if err != nil { + return 0, err + } + d.buffer = d.buffer[:0] + } + return 0, nil +} + +func (d *natDeviceWrapper) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + ctx := log.ContextWithNewID(d.ctx) + session := tun.DirectRouteSession{ + Source: metadata.Source.Addr, + Destination: metadata.Destination.Addr, + } + d.rewriter.CreateSession(session, routeContext) + d.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString()) + return &natDestination{device: d, session: session}, nil +} + +var _ tun.DirectRouteDestination = (*natDestination)(nil) + +type natDestination struct { + device *natDeviceWrapper + session tun.DirectRouteSession + closed atomic.Bool +} + +func (d *natDestination) WritePacket(buffer *buf.Buffer) error { + d.device.rewriter.RewritePacket(buffer.Bytes()) + d.device.packetOutbound <- buffer + return nil +} + +func (d *natDestination) Close() error { + d.closed.Store(true) + d.device.rewriter.DeleteSession(d.session) + return nil +} + +func (d *natDestination) IsClosed() bool { + return d.closed.Load() +} diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index f9440f02..8b7c40cd 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -5,7 +5,9 @@ package wireguard import ( "context" "net" + "net/netip" "os" + "time" "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" @@ -14,9 +16,14 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -24,30 +31,40 @@ import ( wgTun "github.com/sagernet/wireguard-go/tun" ) -var _ Device = (*stackDevice)(nil) +var _ NatDevice = (*stackDevice)(nil) type stackDevice struct { - stack *stack.Stack - mtu uint32 - events chan wgTun.Event - outbound chan *stack.PacketBuffer - done chan struct{} - dispatcher stack.NetworkDispatcher - addr4 tcpip.Address - addr6 tcpip.Address + ctx context.Context + logger log.ContextLogger + stack *stack.Stack + mtu uint32 + events chan wgTun.Event + outbound chan *stack.PacketBuffer + packetOutbound chan *buf.Buffer + done chan struct{} + dispatcher stack.NetworkDispatcher + inet4Address netip.Addr + inet6Address netip.Addr } func newStackDevice(options DeviceOptions) (*stackDevice, error) { tunDevice := &stackDevice{ - mtu: options.MTU, - events: make(chan wgTun.Event, 1), - outbound: make(chan *stack.PacketBuffer, 256), - done: make(chan struct{}), + ctx: options.Context, + logger: options.Logger, + mtu: options.MTU, + events: make(chan wgTun.Event, 1), + outbound: make(chan *stack.PacketBuffer, 256), + packetOutbound: make(chan *buf.Buffer, 256), + done: make(chan struct{}), } - ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice)) + ipStack, err := tun.NewGVisorStackWithOptions((*wireEndpoint)(tunDevice), stack.NICOptions{}, true) if err != nil { return nil, err } + var ( + inet4Address netip.Addr + inet6Address netip.Addr + ) for _, prefix := range options.Address { addr := tun.AddressFromAddr(prefix.Addr()) protoAddr := tcpip.ProtocolAddress{ @@ -57,10 +74,12 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) { }, } if prefix.Addr().Is4() { - tunDevice.addr4 = addr + inet4Address = prefix.Addr() + tunDevice.inet4Address = inet4Address protoAddr.Protocol = ipv4.ProtocolNumber } else { - tunDevice.addr6 = addr + inet6Address = prefix.Addr() + tunDevice.inet6Address = inet6Address protoAddr.Protocol = ipv6.ProtocolNumber } gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{}) @@ -72,6 +91,10 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) { if options.Handler != nil { ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) + icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout) + icmpForwarder.SetLocalAddresses(inet4Address, inet6Address) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) } return tunDevice, nil } @@ -88,10 +111,10 @@ func (w *stackDevice) DialContext(ctx context.Context, network string, destinati var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() { networkProtocol = header.IPv4ProtocolNumber - bind.Addr = w.addr4 + bind.Addr = tun.AddressFromAddr(w.inet4Address) } else { networkProtocol = header.IPv6ProtocolNumber - bind.Addr = w.addr6 + bind.Addr = tun.AddressFromAddr(w.inet4Address) } switch N.NetworkName(network) { case N.NetworkTCP: @@ -118,10 +141,10 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() { networkProtocol = header.IPv4ProtocolNumber - bind.Addr = w.addr4 + bind.Addr = tun.AddressFromAddr(w.inet4Address) } else { networkProtocol = header.IPv6ProtocolNumber - bind.Addr = w.addr6 + bind.Addr = tun.AddressFromAddr(w.inet4Address) } udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol) if err != nil { @@ -130,6 +153,14 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) return udpConn, nil } +func (w *stackDevice) Inet4Address() netip.Addr { + return w.inet4Address +} + +func (w *stackDevice) Inet6Address() netip.Addr { + return w.inet6Address +} + func (w *stackDevice) SetDevice(device *device.Device) { } @@ -144,20 +175,24 @@ func (w *stackDevice) File() *os.File { func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { select { - case packetBuffer, ok := <-w.outbound: + case packet, ok := <-w.outbound: if !ok { return 0, os.ErrClosed } - defer packetBuffer.DecRef() - p := bufs[0] - p = p[offset:] - n := 0 - for _, slice := range packetBuffer.AsSlices() { - n += copy(p[n:], slice) + defer packet.DecRef() + var copyN int + /*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) { + copyN += copy(bufs[0][offset+copyN:], view.AsSlice()) + })*/ + for _, view := range packet.AsSlices() { + copyN += copy(bufs[0][offset+copyN:], view) } - sizes[0] = n - count = 1 - return + sizes[0] = copyN + return 1, nil + case packet := <-w.packetOutbound: + defer packet.Release() + sizes[0] = copy(bufs[0][offset:], packet.Bytes()) + return 1, nil case <-w.done: return 0, os.ErrClosed } @@ -217,6 +252,23 @@ func (w *stackDevice) BatchSize() int { return 1 } +func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + ctx := log.ContextWithNewID(w.ctx) + destination, err := ping.ConnectGVisor( + ctx, w.logger, + metadata.Source.Addr, metadata.Destination.Addr, + routeContext, + w.stack, + w.inet4Address, w.inet6Address, + timeout, + ) + if err != nil { + return nil, err + } + w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString()) + return destination, nil +} + var _ stack.LinkEndpoint = (*wireEndpoint)(nil) type wireEndpoint stackDevice diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index fa54f332..162a5cbf 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -22,22 +22,42 @@ import ( var _ Device = (*systemDevice)(nil) type systemDevice struct { - options DeviceOptions - dialer N.Dialer - device tun.Tun - batchDevice tun.LinuxTUN - events chan wgTun.Event - closeOnce sync.Once + options DeviceOptions + dialer N.Dialer + device tun.Tun + batchDevice tun.LinuxTUN + events chan wgTun.Event + closeOnce sync.Once + inet4Address netip.Addr + inet6Address netip.Addr } func newSystemDevice(options DeviceOptions) (*systemDevice, error) { if options.Name == "" { options.Name = tun.CalculateInterfaceName("wg") } + var inet4Address netip.Addr + var inet6Address netip.Addr + if len(options.Address) > 0 { + if prefix := common.Find(options.Address, func(it netip.Prefix) bool { + return it.Addr().Is4() + }); prefix.IsValid() { + inet4Address = prefix.Addr() + } + } + if len(options.Address) > 0 { + if prefix := common.Find(options.Address, func(it netip.Prefix) bool { + return it.Addr().Is6() + }); prefix.IsValid() { + inet6Address = prefix.Addr() + } + } return &systemDevice{ - options: options, - dialer: options.CreateDialer(options.Name), - events: make(chan wgTun.Event, 1), + options: options, + dialer: options.CreateDialer(options.Name), + events: make(chan wgTun.Event, 1), + inet4Address: inet4Address, + inet6Address: inet6Address, }, nil } @@ -49,6 +69,14 @@ func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr return w.dialer.ListenPacket(ctx, destination) } +func (w *systemDevice) Inet4Address() netip.Addr { + return w.inet4Address +} + +func (w *systemDevice) Inet6Address() netip.Addr { + return w.inet6Address +} + func (w *systemDevice) SetDevice(device *device.Device) { } diff --git a/transport/wireguard/device_system_stack.go b/transport/wireguard/device_system_stack.go index 4249e53e..94fd6f4f 100644 --- a/transport/wireguard/device_system_stack.go +++ b/transport/wireguard/device_system_stack.go @@ -3,16 +3,26 @@ package wireguard import ( + "context" "net/netip" + "time" "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" + "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/ping" "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" "github.com/sagernet/wireguard-go/device" ) @@ -20,6 +30,8 @@ var _ Device = (*systemStackDevice)(nil) type systemStackDevice struct { *systemDevice + ctx context.Context + logger logger.ContextLogger stack *stack.Stack endpoint *deviceEndpoint writeBufs [][]byte @@ -34,13 +46,45 @@ func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) { mtu: options.MTU, done: make(chan struct{}), } - ipStack, err := tun.NewGVisorStack(endpoint) + ipStack, err := tun.NewGVisorStackWithOptions(endpoint, stack.NICOptions{}, true) if err != nil { return nil, err } - ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) + var ( + inet4Address netip.Addr + inet6Address netip.Addr + ) + for _, prefix := range options.Address { + addr := tun.AddressFromAddr(prefix.Addr()) + protoAddr := tcpip.ProtocolAddress{ + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: prefix.Bits(), + }, + } + if prefix.Addr().Is4() { + inet4Address = prefix.Addr() + protoAddr.Protocol = ipv4.ProtocolNumber + } else { + inet6Address = prefix.Addr() + protoAddr.Protocol = ipv6.ProtocolNumber + } + gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{}) + if gErr != nil { + return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String()) + } + } + if options.Handler != nil { + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) + icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout) + icmpForwarder.SetLocalAddresses(inet4Address, inet6Address) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) + } return &systemStackDevice{ + ctx: options.Context, + logger: options.Logger, systemDevice: system, stack: ipStack, endpoint: endpoint, @@ -116,6 +160,23 @@ func (w *systemStackDevice) writeStack(packet []byte) bool { return true } +func (w *systemStackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + ctx := log.ContextWithNewID(w.ctx) + destination, err := ping.ConnectGVisor( + ctx, w.logger, + metadata.Source.Addr, metadata.Destination.Addr, + routeContext, + w.stack, + w.inet4Address, w.inet6Address, + timeout, + ) + if err != nil { + return nil, err + } + w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString()) + return destination, nil +} + type deviceEndpoint struct { mtu uint32 done chan struct{} diff --git a/transport/wireguard/endpoint.go b/transport/wireguard/endpoint.go index 2adf7832..12718b91 100644 --- a/transport/wireguard/endpoint.go +++ b/transport/wireguard/endpoint.go @@ -10,8 +10,11 @@ import ( "os" "reflect" "strings" + "time" "unsafe" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" @@ -31,6 +34,7 @@ type Endpoint struct { ipcConf string allowedAddress []netip.Prefix tunDevice Device + natDevice NatDevice device *device.Device allowedIPs *device.AllowedIPs pause pause.Manager @@ -114,12 +118,17 @@ func NewEndpoint(options EndpointOptions) (*Endpoint, error) { if err != nil { return nil, E.Cause(err, "create WireGuard device") } + natDevice, isNatDevice := tunDevice.(NatDevice) + if !isNatDevice { + natDevice = NewNATDevice(options.Context, options.Logger, tunDevice) + } return &Endpoint{ options: options, peers: peers, ipcConf: ipcConf, allowedAddress: allowedAddresses, tunDevice: tunDevice, + natDevice: natDevice, }, nil } @@ -179,7 +188,13 @@ func (e *Endpoint) Start(resolve bool) error { e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...)) }, } - wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers) + var deviceInput Device + if e.natDevice != nil { + deviceInput = e.natDevice + } else { + deviceInput = e.tunDevice + } + wgDevice := device.NewDevice(e.options.Context, deviceInput, bind, logger, e.options.Workers) e.tunDevice.SetDevice(wgDevice) ipcConf := e.ipcConf for _, peer := range e.peers { @@ -229,6 +244,13 @@ func (e *Endpoint) Lookup(address netip.Addr) *device.Peer { return e.allowedIPs.Lookup(address.AsSlice()) } +func (e *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + if e.natDevice == nil { + return nil, os.ErrInvalid + } + return e.natDevice.CreateDestination(metadata, routeContext, timeout) +} + func (e *Endpoint) onPauseUpdated(event int) { switch event { case pause.EventDevicePaused, pause.EventNetworkPause: