diff --git a/adapter/outbound.go b/adapter/outbound.go index 2c2b1091..517aa0fe 100644 --- a/adapter/outbound.go +++ b/adapter/outbound.go @@ -2,6 +2,7 @@ package adapter import ( "context" + "net/netip" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" @@ -18,6 +19,11 @@ type Outbound interface { N.Dialer } +type OutboundWithPreferredRoutes interface { + PreferredDomain(domain string) bool + PreferredAddress(address netip.Addr) bool +} + type OutboundRegistry interface { option.OutboundOptionsRegistry CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error) diff --git a/option/rule.go b/option/rule.go index d12b679d..44927e3a 100644 --- a/option/rule.go +++ b/option/rule.go @@ -103,6 +103,7 @@ type RawDefaultRule struct { InterfaceAddress *badjson.TypedMap[string, badoption.Listable[badoption.Prefixable]] `json:"interface_address,omitempty"` NetworkInterfaceAddress *badjson.TypedMap[InterfaceType, badoption.Listable[badoption.Prefixable]] `json:"network_interface_address,omitempty"` DefaultInterfaceAddress badoption.Listable[badoption.Prefixable] `json:"default_interface_address,omitempty"` + PreferredBy badoption.Listable[string] `json:"preferred_by,omitempty"` RuleSet badoption.Listable[string] `json:"rule_set,omitempty"` RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"` Invert bool `json:"invert,omitempty"` diff --git a/protocol/tailscale/dns_transport.go b/protocol/tailscale/dns_transport.go index 3447b6b2..51115717 100644 --- a/protocol/tailscale/dns_transport.go +++ b/protocol/tailscale/dns_transport.go @@ -7,7 +7,6 @@ import ( "net/netip" "net/url" "os" - "reflect" "strings" "sync" @@ -47,8 +46,6 @@ type DNSTransport struct { acceptDefaultResolvers bool dnsRouter adapter.DNSRouter endpointManager adapter.EndpointManager - cfg *wgcfg.Config - dnsCfg *nDNS.Config endpoint *Endpoint routePrefixes []netip.Prefix routes map[string][]adapter.DNSTransport @@ -83,10 +80,10 @@ func (t *DNSTransport) Start(stage adapter.StartStage) error { if !isTailscale { return E.New("endpoint is not Tailscale: ", t.endpointTag) } - if ep.onReconfig != nil { + if ep.onReconfigHook != nil { return E.New("only one Tailscale DNS server is allowed for single endpoint") } - ep.onReconfig = t.onReconfig + ep.onReconfigHook = t.onReconfig t.endpoint = ep return nil } @@ -95,14 +92,6 @@ func (t *DNSTransport) Reset() { } func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) { - if cfg == nil || dnsCfg == nil { - return - } - if (t.cfg != nil && reflect.DeepEqual(t.cfg, cfg)) && (t.dnsCfg != nil && reflect.DeepEqual(t.dnsCfg, dnsCfg)) { - return - } - t.cfg = cfg - t.dnsCfg = dnsCfg err := t.updateDNSServers(routerCfg, dnsCfg) if err != nil { t.logger.Error(E.Cause(err, "update DNS servers")) diff --git a/protocol/tailscale/endpoint.go b/protocol/tailscale/endpoint.go index 695811f2..44c882de 100644 --- a/protocol/tailscale/endpoint.go +++ b/protocol/tailscale/endpoint.go @@ -10,9 +10,9 @@ import ( "net/url" "os" "path/filepath" + "reflect" "runtime" "strings" - "sync/atomic" "syscall" "time" @@ -31,6 +31,7 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" @@ -49,8 +50,14 @@ import ( "github.com/sagernet/tailscale/version" "github.com/sagernet/tailscale/wgengine" "github.com/sagernet/tailscale/wgengine/filter" + "github.com/sagernet/tailscale/wgengine/router" + "github.com/sagernet/tailscale/wgengine/wgcfg" + + "go4.org/netipx" ) +var _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) + func init() { version.SetVersion("sing-box " + C.Version) } @@ -70,7 +77,12 @@ type Endpoint struct { server *tsnet.Server stack *stack.Stack filter *atomic.Pointer[filter.Filter] - onReconfig wgengine.ReconfigListener + onReconfigHook wgengine.ReconfigListener + + cfg *wgcfg.Config + dnsCfg *tsDNS.Config + routeDomains atomic.TypedValue[map[string]bool] + routePrefixes atomic.Pointer[netipx.IPSet] acceptRoutes bool exitNode string @@ -216,9 +228,7 @@ func (t *Endpoint) Start(stage adapter.StartStage) error { if err != nil { return err } - if t.onReconfig != nil { - t.server.ExportLocalBackend().ExportEngine().(wgengine.ExportedUserspaceEngine).SetOnReconfigListener(t.onReconfig) - } + t.server.ExportLocalBackend().ExportEngine().(wgengine.ExportedUserspaceEngine).SetOnReconfigListener(t.onReconfig) ipStack := t.server.ExportNetstack().ExportIPStack() gErr := ipStack.SetSpoofing(tun.DefaultNIC, true) @@ -253,8 +263,7 @@ func (t *Endpoint) Start(stage adapter.StartStage) error { if err != nil { return E.Cause(err, "update prefs") } - t.filter = localBackend.ExportFilter() - + t.filter = atomic.PointerForm(localBackend.ExportFilter()) go t.watchState() return nil } @@ -473,10 +482,58 @@ func (t *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, t.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) } +func (t *Endpoint) PreferredDomain(domain string) bool { + routeDomains := t.routeDomains.Load() + if routeDomains == nil { + return false + } + return routeDomains[strings.ToLower(domain)] +} + +func (t *Endpoint) PreferredAddress(address netip.Addr) bool { + routePrefixes := t.routePrefixes.Load() + if routePrefixes == nil { + return false + } + return routePrefixes.Contains(address) +} + func (t *Endpoint) Server() *tsnet.Server { return t.server } +func (t *Endpoint) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *tsDNS.Config) { + if cfg == nil || dnsCfg == nil { + return + } + if (t.cfg != nil && reflect.DeepEqual(t.cfg, cfg)) && (t.dnsCfg != nil && reflect.DeepEqual(t.dnsCfg, dnsCfg)) { + return + } + t.cfg = cfg + t.dnsCfg = dnsCfg + + routeDomains := make(map[string]bool) + for fqdn := range dnsCfg.Routes { + routeDomains[fqdn.WithoutTrailingDot()] = true + } + for _, fqdn := range dnsCfg.SearchDomains { + routeDomains[fqdn.WithoutTrailingDot()] = true + } + t.routeDomains.Store(routeDomains) + + var builder netipx.IPSetBuilder + for _, peer := range cfg.Peers { + for _, allowedIP := range peer.AllowedIPs { + builder.AddPrefix(allowedIP) + } + } + t.routePrefixes.Store(common.Must1(builder.IPSet())) + + if t.onReconfigHook != nil { + t.onReconfigHook(cfg, routerCfg, dnsCfg) + } +} + func addressFromAddr(destination netip.Addr) tcpip.Address { if destination.Is6() { return tcpip.AddrFrom16(destination.As16()) diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index 4165d126..207670c2 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -22,6 +22,8 @@ import ( "github.com/sagernet/sing/service" ) +var _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) + func RegisterEndpoint(registry *endpoint.Registry) { endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint) } @@ -210,3 +212,11 @@ func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n } return w.endpoint.ListenPacket(ctx, destination) } + +func (w *Endpoint) PreferredDomain(domain string) bool { + return false +} + +func (w *Endpoint) PreferredAddress(address netip.Addr) bool { + return w.endpoint.Lookup(address) != nil +} diff --git a/protocol/wireguard/outbound.go b/protocol/wireguard/outbound.go index 129e69b8..fa58d959 100644 --- a/protocol/wireguard/outbound.go +++ b/protocol/wireguard/outbound.go @@ -21,6 +21,8 @@ import ( "github.com/sagernet/sing/service" ) +var _ adapter.OutboundWithPreferredRoutes = (*Outbound)(nil) + func RegisterOutbound(registry *outbound.Registry) { outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound) } @@ -158,3 +160,11 @@ func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n } return o.endpoint.ListenPacket(ctx, destination) } + +func (o *Outbound) PreferredDomain(domain string) bool { + return false +} + +func (o *Outbound) PreferredAddress(address netip.Addr) bool { + return o.endpoint.Lookup(address) != nil +} diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index e0677b97..66a6e5a7 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -117,7 +117,7 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio if len(options.DomainRegex) > 0 { item, err := NewDomainRegexItem(options.DomainRegex) if err != nil { - return nil, E.Cause(err, "domain_regex") + return nil, err } rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.allItems = append(rule.allItems, item) @@ -261,6 +261,11 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.PreferredBy) > 0 { + item := NewPreferredByItem(ctx, options.PreferredBy) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.RuleSet) > 0 { var matchSource bool if options.RuleSetIPCIDRMatchSource { diff --git a/route/rule/rule_item_preferred_by.go b/route/rule/rule_item_preferred_by.go new file mode 100644 index 00000000..42c8a627 --- /dev/null +++ b/route/rule/rule_item_preferred_by.go @@ -0,0 +1,86 @@ +package rule + +import ( + "context" + "strings" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/service" +) + +var _ RuleItem = (*PreferredByItem)(nil) + +type PreferredByItem struct { + ctx context.Context + outboundTags []string + outbounds []adapter.OutboundWithPreferredRoutes +} + +func NewPreferredByItem(ctx context.Context, outboundTags []string) *PreferredByItem { + return &PreferredByItem{ + ctx: ctx, + outboundTags: outboundTags, + } +} + +func (r *PreferredByItem) Start() error { + outboundManager := service.FromContext[adapter.OutboundManager](r.ctx) + for _, outboundTag := range r.outboundTags { + rawOutbound, loaded := outboundManager.Outbound(outboundTag) + if !loaded { + return E.New("outbound not found: ", outboundTag) + } + outboundWithPreferredRoutes, withRoutes := rawOutbound.(adapter.OutboundWithPreferredRoutes) + if !withRoutes { + return E.New("outbound type does not support preferred routes: ", rawOutbound.Type()) + } + r.outbounds = append(r.outbounds, outboundWithPreferredRoutes) + } + return nil +} + +func (r *PreferredByItem) Match(metadata *adapter.InboundContext) bool { + var domainHost string + if metadata.Domain != "" { + domainHost = metadata.Domain + } else { + domainHost = metadata.Destination.Fqdn + } + if domainHost != "" { + for _, outbound := range r.outbounds { + if outbound.PreferredDomain(domainHost) { + return true + } + } + } + if metadata.Destination.IsIP() { + for _, outbound := range r.outbounds { + if outbound.PreferredAddress(metadata.Destination.Addr) { + return true + } + } + } + if len(metadata.DestinationAddresses) > 0 { + for _, address := range metadata.DestinationAddresses { + for _, outbound := range r.outbounds { + if outbound.PreferredAddress(address) { + return true + } + } + } + } + return false +} + +func (r *PreferredByItem) String() string { + description := "preferred_by=" + pLen := len(r.outboundTags) + if pLen == 1 { + description += F.ToString(r.outboundTags[0]) + } else { + description += "[" + strings.Join(F.MapToString(r.outboundTags), " ") + "]" + } + return description +} diff --git a/transport/wireguard/endpoint.go b/transport/wireguard/endpoint.go index 3801640f..f4c37c0c 100644 --- a/transport/wireguard/endpoint.go +++ b/transport/wireguard/endpoint.go @@ -8,7 +8,9 @@ import ( "net" "net/netip" "os" + "reflect" "strings" + "unsafe" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" @@ -30,6 +32,7 @@ type Endpoint struct { allowedAddress []netip.Prefix tunDevice Device device *device.Device + allowedIPs *device.AllowedIPs pause pause.Manager pauseCallback *list.Element[pause.Callback] } @@ -191,6 +194,7 @@ func (e *Endpoint) Start(resolve bool) error { if e.pause != nil { e.pauseCallback = e.pause.RegisterCallback(e.onPauseUpdated) } + e.allowedIPs = (*device.AllowedIPs)(unsafe.Pointer(reflect.Indirect(reflect.ValueOf(wgDevice)).FieldByName("allowedips").UnsafeAddr())) return nil } @@ -218,6 +222,10 @@ func (e *Endpoint) Close() error { return nil } +func (e *Endpoint) Lookup(address netip.Addr) *device.Peer { + return e.allowedIPs.Lookup(address.AsSlice()) +} + func (e *Endpoint) onPauseUpdated(event int) { switch event { case pause.EventDevicePaused, pause.EventNetworkPause: