diff --git a/dns/transport/dhcp/dhcp.go b/dns/transport/dhcp/dhcp.go index 92dd1f8b..ffe2f097 100644 --- a/dns/transport/dhcp/dhcp.go +++ b/dns/transport/dhcp/dhcp.go @@ -9,10 +9,8 @@ import ( "time" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" - "github.com/sagernet/sing-box/dns/transport" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" @@ -29,6 +27,7 @@ import ( "github.com/insomniacslk/dhcp/dhcpv4" mDNS "github.com/miekg/dns" + "golang.org/x/exp/slices" ) func RegisterTransport(registry *dns.TransportRegistry) { @@ -45,9 +44,12 @@ type Transport struct { networkManager adapter.NetworkManager interfaceName string interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback] - transports []adapter.DNSTransport - updateAccess sync.Mutex + transportLock sync.RWMutex updatedAt time.Time + servers []M.Socksaddr + search []string + ndots int + attempts int } func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.DHCPDNSServerOptions) (adapter.DNSTransport, error) { @@ -62,16 +64,28 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt logger: logger, networkManager: service.FromContext[adapter.NetworkManager](ctx), interfaceName: options.Interface, + ndots: 1, + attempts: 2, }, nil } +func NewRawTransport(transportAdapter dns.TransportAdapter, ctx context.Context, dialer N.Dialer, logger log.ContextLogger) *Transport { + return &Transport{ + TransportAdapter: transportAdapter, + ctx: ctx, + dialer: dialer, + logger: logger, + networkManager: service.FromContext[adapter.NetworkManager](ctx), + } +} + func (t *Transport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - err := t.fetchServers() + _, err := t.Fetch() if err != nil { - return err + t.logger.Error(E.Cause(err, "fetch DNS servers")) } if t.interfaceName == "" { t.interfaceCallback = t.networkManager.InterfaceMonitor().RegisterCallback(t.interfaceUpdated) @@ -80,9 +94,6 @@ func (t *Transport) Start(stage adapter.StartStage) error { } func (t *Transport) Close() error { - for _, transport := range t.transports { - transport.Close() - } if t.interfaceCallback != nil { t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback) } @@ -90,23 +101,44 @@ func (t *Transport) Close() error { } func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - err := t.fetchServers() + servers, err := t.Fetch() if err != nil { return nil, err } - - if len(t.transports) == 0 { + if len(servers) == 0 { return nil, E.New("dhcp: empty DNS servers from response") } + return t.Exchange0(ctx, message, servers) +} - var response *mDNS.Msg - for _, transport := range t.transports { - response, err = transport.Exchange(ctx, message) - if err == nil { - return response, nil - } +func (t *Transport) Exchange0(ctx context.Context, message *mDNS.Msg, servers []M.Socksaddr) (*mDNS.Msg, error) { + question := message.Question[0] + domain := dns.FqdnToDomain(question.Name) + if len(servers) == 1 || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) { + return t.exchangeSingleRequest(ctx, servers, message, domain) + } else { + return t.exchangeParallel(ctx, servers, message, domain) } - return nil, err +} + +func (t *Transport) Fetch() ([]M.Socksaddr, error) { + t.transportLock.RLock() + updatedAt := t.updatedAt + servers := t.servers + t.transportLock.RUnlock() + if time.Since(updatedAt) < C.DHCPTTL { + return servers, nil + } + t.transportLock.Lock() + defer t.transportLock.Unlock() + if time.Since(t.updatedAt) < C.DHCPTTL { + return t.servers, nil + } + err := t.updateServers() + if err != nil { + return nil, err + } + return t.servers, nil } func (t *Transport) fetchInterface() (*control.Interface, error) { @@ -124,18 +156,6 @@ func (t *Transport) fetchInterface() (*control.Interface, error) { } } -func (t *Transport) fetchServers() error { - if time.Since(t.updatedAt) < C.DHCPTTL { - return nil - } - t.updateAccess.Lock() - defer t.updateAccess.Unlock() - if time.Since(t.updatedAt) < C.DHCPTTL { - return nil - } - return t.updateServers() -} - func (t *Transport) updateServers() error { iface, err := t.fetchInterface() if err != nil { @@ -148,7 +168,7 @@ func (t *Transport) updateServers() error { cancel() if err != nil { return err - } else if len(t.transports) == 0 { + } else if len(t.servers) == 0 { return E.New("dhcp: empty DNS servers response") } else { t.updatedAt = time.Now() @@ -177,7 +197,7 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *control.Interface) } defer packetConn.Close() - discovery, err := dhcpv4.NewDiscovery(iface.HardwareAddr, dhcpv4.WithBroadcast(true), dhcpv4.WithRequestedOptions(dhcpv4.OptionDomainNameServer)) + discovery, err := dhcpv4.NewDiscovery(iface.HardwareAddr, dhcpv4.WithBroadcast(true), dhcpv4.WithRequestedOptions(dhcpv4.OptionDomainNameServer, dhcpv4.OptionDNSDomainSearchList)) if err != nil { return err } @@ -223,31 +243,21 @@ func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn ne continue } - dns := dhcpPacket.DNS() - if len(dns) == 0 { - return nil - } - return t.recreateServers(iface, common.Map(dns, func(it net.IP) M.Socksaddr { - return M.SocksaddrFrom(M.AddrFromIP(it), 53) - })) + return t.recreateServers(iface, dhcpPacket) } } -func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []M.Socksaddr) error { - if len(serverAddrs) > 0 { - t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, M.Socksaddr.String), ","), "]") +func (t *Transport) recreateServers(iface *control.Interface, dhcpPacket *dhcpv4.DHCPv4) error { + searchList := dhcpPacket.DomainSearch() + if searchList != nil { + t.search = searchList.Labels } - serverDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{ - BindInterface: iface.Name, - UDPFragmentDefault: true, - })) - var transports []adapter.DNSTransport - for _, serverAddr := range serverAddrs { - transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, serverAddr)) + serverAddrs := common.Map(dhcpPacket.DNS(), func(it net.IP) M.Socksaddr { + return M.SocksaddrFrom(M.AddrFromIP(it), 53) + }) + if len(serverAddrs) > 0 && !slices.Equal(t.servers, serverAddrs) { + t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, M.Socksaddr.String), ","), "], search: [", strings.Join(t.search, ","), "]") } - for _, transport := range t.transports { - transport.Close() - } - t.transports = transports + t.servers = serverAddrs return nil } diff --git a/dns/transport/dhcp/dhcp_shared.go b/dns/transport/dhcp/dhcp_shared.go new file mode 100644 index 00000000..e470b1da --- /dev/null +++ b/dns/transport/dhcp/dhcp_shared.go @@ -0,0 +1,202 @@ +package dhcp + +import ( + "context" + "math/rand" + "strings" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + mDNS "github.com/miekg/dns" +) + +const ( + // net.maxDNSPacketSize + maxDNSPacketSize = 1232 +) + +func (t *Transport) exchangeSingleRequest(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { + var lastErr error + for _, fqdn := range t.nameList(domain) { + response, err := t.tryOneName(ctx, servers, fqdn, message) + if err != nil { + lastErr = err + continue + } + return response, nil + } + return nil, lastErr +} + +func (t *Transport) exchangeParallel(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { + returned := make(chan struct{}) + defer close(returned) + type queryResult struct { + response *mDNS.Msg + err error + } + results := make(chan queryResult) + startRacer := func(ctx context.Context, fqdn string) { + response, err := t.tryOneName(ctx, servers, fqdn, message) + if err == nil { + if response.Rcode != mDNS.RcodeSuccess { + err = dns.RcodeError(response.Rcode) + } else if len(dns.MessageToAddresses(response)) == 0 { + err = E.New(fqdn, ": empty result") + } + } + select { + case results <- queryResult{response, err}: + case <-returned: + } + } + queryCtx, queryCancel := context.WithCancel(ctx) + defer queryCancel() + var nameCount int + for _, fqdn := range t.nameList(domain) { + nameCount++ + go startRacer(queryCtx, fqdn) + } + var errors []error + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-results: + if result.err == nil { + return result.response, nil + } + errors = append(errors, result.err) + if len(errors) == nameCount { + return nil, E.Errors(errors...) + } + } + } +} + +func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) { + sLen := len(servers) + var lastErr error + for i := 0; i < t.attempts; i++ { + for j := 0; j < sLen; j++ { + server := servers[j%sLen] + question := message.Question[0] + question.Name = fqdn + response, err := t.exchangeOne(ctx, server, question, C.DNSTimeout, false, true) + if err != nil { + lastErr = err + continue + } + return response, nil + } + } + return nil, E.Cause(lastErr, fqdn) +} + +func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) { + if server.Port == 0 { + server.Port = 53 + } + var networks []string + if useTCP { + networks = []string{N.NetworkTCP} + } else { + networks = []string{N.NetworkUDP, N.NetworkTCP} + } + request := &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: uint16(rand.Uint32()), + RecursionDesired: true, + AuthenticatedData: ad, + }, + Question: []mDNS.Question{question}, + Compress: true, + } + request.SetEdns0(maxDNSPacketSize, false) + buffer := buf.Get(buf.UDPBufferSize) + defer buf.Put(buffer) + for _, network := range networks { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + conn, err := t.dialer.DialContext(ctx, network, server) + if err != nil { + return nil, err + } + defer conn.Close() + if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { + conn.SetDeadline(deadline) + } + rawMessage, err := request.PackBuffer(buffer) + if err != nil { + return nil, E.Cause(err, "pack request") + } + _, err = conn.Write(rawMessage) + if err != nil { + return nil, E.Cause(err, "write request") + } + n, err := conn.Read(buffer) + if err != nil { + return nil, E.Cause(err, "read response") + } + var response mDNS.Msg + err = response.Unpack(buffer[:n]) + if err != nil { + return nil, E.Cause(err, "unpack response") + } + if response.Truncated && network == N.NetworkUDP { + continue + } + return &response, nil + } + panic("unexpected") +} + +func (t *Transport) nameList(name string) []string { + l := len(name) + rooted := l > 0 && name[l-1] == '.' + if l > 254 || l == 254 && !rooted { + return nil + } + + if rooted { + if avoidDNS(name) { + return nil + } + return []string{name} + } + + hasNdots := strings.Count(name, ".") >= t.ndots + name += "." + // l++ + + names := make([]string, 0, 1+len(t.search)) + if hasNdots && !avoidDNS(name) { + names = append(names, name) + } + for _, suffix := range t.search { + fqdn := name + suffix + if !avoidDNS(fqdn) && len(fqdn) <= 254 { + names = append(names, fqdn) + } + } + if !hasNdots && !avoidDNS(name) { + names = append(names, name) + } + return names +} + +func avoidDNS(name string) bool { + if name == "" { + return true + } + if name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return strings.HasSuffix(name, ".onion") +} diff --git a/dns/transport/local/local.go b/dns/transport/local/local.go index 74c9c702..28d49aed 100644 --- a/dns/transport/local/local.go +++ b/dns/transport/local/local.go @@ -1,9 +1,9 @@ +//go:build !darwin + package local import ( "context" - "math/rand" - "time" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" @@ -11,10 +11,8 @@ import ( "github.com/sagernet/sing-box/dns/transport/hosts" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" mDNS "github.com/miekg/dns" @@ -37,9 +35,6 @@ type Transport struct { } func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) { - if C.IsDarwin && !options.PreferGo { - return NewResolvTransport(ctx, logger, tag) - } transportDialer, err := dns.NewLocalDialer(ctx, options) if err != nil { return nil, err @@ -94,147 +89,5 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil } } - systemConfig := getSystemDNSConfig(t.ctx) - if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) { - return t.exchangeSingleRequest(ctx, systemConfig, message, domain) - } else { - return t.exchangeParallel(ctx, systemConfig, message, domain) - } -} - -func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { - var lastErr error - for _, fqdn := range systemConfig.nameList(domain) { - response, err := t.tryOneName(ctx, systemConfig, fqdn, message) - if err != nil { - lastErr = err - continue - } - return response, nil - } - return nil, lastErr -} - -func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { - returned := make(chan struct{}) - defer close(returned) - type queryResult struct { - response *mDNS.Msg - err error - } - results := make(chan queryResult) - startRacer := func(ctx context.Context, fqdn string) { - response, err := t.tryOneName(ctx, systemConfig, fqdn, message) - if err == nil { - if response.Rcode != mDNS.RcodeSuccess { - err = dns.RcodeError(response.Rcode) - } else if len(dns.MessageToAddresses(response)) == 0 { - err = E.New(fqdn, ": empty result") - } - } - select { - case results <- queryResult{response, err}: - case <-returned: - } - } - queryCtx, queryCancel := context.WithCancel(ctx) - defer queryCancel() - var nameCount int - for _, fqdn := range systemConfig.nameList(domain) { - nameCount++ - go startRacer(queryCtx, fqdn) - } - var errors []error - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case result := <-results: - if result.err == nil { - return result.response, nil - } - errors = append(errors, result.err) - if len(errors) == nameCount { - return nil, E.Errors(errors...) - } - } - } -} - -func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) { - serverOffset := config.serverOffset() - sLen := uint32(len(config.servers)) - var lastErr error - for i := 0; i < config.attempts; i++ { - for j := uint32(0); j < sLen; j++ { - server := config.servers[(serverOffset+j)%sLen] - question := message.Question[0] - question.Name = fqdn - response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD) - if err != nil { - lastErr = err - continue - } - return response, nil - } - } - return nil, E.Cause(lastErr, fqdn) -} - -func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) { - if server.Port == 0 { - server.Port = 53 - } - var networks []string - if useTCP { - networks = []string{N.NetworkTCP} - } else { - networks = []string{N.NetworkUDP, N.NetworkTCP} - } - request := &mDNS.Msg{ - MsgHdr: mDNS.MsgHdr{ - Id: uint16(rand.Uint32()), - RecursionDesired: true, - AuthenticatedData: ad, - }, - Question: []mDNS.Question{question}, - Compress: true, - } - request.SetEdns0(maxDNSPacketSize, false) - buffer := buf.Get(buf.UDPBufferSize) - defer buf.Put(buffer) - for _, network := range networks { - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) - defer cancel() - conn, err := t.dialer.DialContext(ctx, network, server) - if err != nil { - return nil, err - } - defer conn.Close() - if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { - conn.SetDeadline(deadline) - } - rawMessage, err := request.PackBuffer(buffer) - if err != nil { - return nil, E.Cause(err, "pack request") - } - _, err = conn.Write(rawMessage) - if err != nil { - return nil, E.Cause(err, "write request") - } - n, err := conn.Read(buffer) - if err != nil { - return nil, E.Cause(err, "read response") - } - var response mDNS.Msg - err = response.Unpack(buffer[:n]) - if err != nil { - return nil, E.Cause(err, "unpack response") - } - if response.Truncated && network == N.NetworkUDP { - continue - } - return &response, nil - } - panic("unexpected") + return t.exchange(ctx, message, domain) } diff --git a/dns/transport/local/local_darwin.go b/dns/transport/local/local_darwin.go new file mode 100644 index 00000000..6bf9e455 --- /dev/null +++ b/dns/transport/local/local_darwin.go @@ -0,0 +1,135 @@ +//go:build darwin + +package local + +import ( + "context" + "errors" + "net" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport/hosts" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" + + mDNS "github.com/miekg/dns" +) + +func RegisterTransport(registry *dns.TransportRegistry) { + dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewTransport) +} + +var _ adapter.DNSTransport = (*Transport)(nil) + +type Transport struct { + dns.TransportAdapter + ctx context.Context + logger logger.ContextLogger + hosts *hosts.File + dialer N.Dialer + preferGo bool + fallback bool + dhcpTransport dhcpTransport + resolver net.Resolver +} + +type dhcpTransport interface { + adapter.DNSTransport + Fetch() ([]M.Socksaddr, error) + Exchange0(ctx context.Context, message *mDNS.Msg, servers []M.Socksaddr) (*mDNS.Msg, error) +} + +func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) { + transportDialer, err := dns.NewLocalDialer(ctx, options) + if err != nil { + return nil, err + } + transportAdapter := dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options) + return &Transport{ + TransportAdapter: transportAdapter, + ctx: ctx, + logger: logger, + hosts: hosts.NewFile(hosts.DefaultPath), + dialer: transportDialer, + preferGo: options.PreferGo, + }, nil +} + +func (t *Transport) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + inboundManager := service.FromContext[adapter.InboundManager](t.ctx) + for _, inbound := range inboundManager.Inbounds() { + if inbound.Type() == C.TypeTun { + t.fallback = true + break + } + } + if t.fallback { + t.dhcpTransport = newDHCPTransport(t.TransportAdapter, log.ContextWithOverrideLevel(t.ctx, log.LevelDebug), t.dialer, t.logger) + if t.dhcpTransport != nil { + err := t.dhcpTransport.Start(stage) + if err != nil { + return err + } + } + } + return nil +} + +func (t *Transport) Close() error { + return common.Close( + t.dhcpTransport, + ) +} + +func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + question := message.Question[0] + domain := dns.FqdnToDomain(question.Name) + if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA { + addresses := t.hosts.Lookup(domain) + if len(addresses) > 0 { + return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil + } + } + if !t.fallback { + return t.exchange(ctx, message, domain) + } + if t.dhcpTransport != nil { + dhcpTransports, _ := t.dhcpTransport.Fetch() + if len(dhcpTransports) > 0 { + return t.dhcpTransport.Exchange0(ctx, message, dhcpTransports) + } + } + if t.preferGo { + // Assuming the user knows what they are doing, we still execute the query which will fail. + return t.exchange(ctx, message, domain) + } + if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA { + var network string + if question.Qtype == mDNS.TypeA { + network = "ip4" + } else { + network = "ip6" + } + addresses, err := t.resolver.LookupNetIP(ctx, network, domain) + if err != nil { + var dnsError *net.DNSError + if errors.As(err, &dnsError) && dnsError.IsNotFound { + return nil, dns.RcodeRefused + } + return nil, err + } + return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil + } + return nil, E.New("only A and AAAA queries are supported on Apple platforms when using TUN and DHCP unavailable.") +} diff --git a/dns/transport/local/local_darwin_dhcp.go b/dns/transport/local/local_darwin_dhcp.go new file mode 100644 index 00000000..b228b76a --- /dev/null +++ b/dns/transport/local/local_darwin_dhcp.go @@ -0,0 +1,16 @@ +//go:build darwin && with_dhcp + +package local + +import ( + "context" + + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport/dhcp" + "github.com/sagernet/sing-box/log" + N "github.com/sagernet/sing/common/network" +) + +func newDHCPTransport(transportAdapter dns.TransportAdapter, ctx context.Context, dialer N.Dialer, logger log.ContextLogger) dhcpTransport { + return dhcp.NewRawTransport(transportAdapter, ctx, dialer, logger) +} diff --git a/dns/transport/local/local_darwin_nodhcp.go b/dns/transport/local/local_darwin_nodhcp.go new file mode 100644 index 00000000..5ce84690 --- /dev/null +++ b/dns/transport/local/local_darwin_nodhcp.go @@ -0,0 +1,15 @@ +//go:build darwin && !with_dhcp + +package local + +import ( + "context" + + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/log" + N "github.com/sagernet/sing/common/network" +) + +func newDHCPTransport(transportAdapter dns.TransportAdapter, ctx context.Context, dialer N.Dialer, logger log.ContextLogger) dhcpTransport { + return nil +} diff --git a/dns/transport/local/local_resolv.go b/dns/transport/local/local_resolv.go deleted file mode 100644 index cf7bcfba..00000000 --- a/dns/transport/local/local_resolv.go +++ /dev/null @@ -1,46 +0,0 @@ -//go:build darwin - -package local - -import ( - "context" - - "github.com/sagernet/sing-box/adapter" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/dns" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common/logger" - - mDNS "github.com/miekg/dns" -) - -var _ adapter.DNSTransport = (*ResolvTransport)(nil) - -type ResolvTransport struct { - dns.TransportAdapter - ctx context.Context - logger logger.ContextLogger -} - -func NewResolvTransport(ctx context.Context, logger log.ContextLogger, tag string) (adapter.DNSTransport, error) { - return &ResolvTransport{ - TransportAdapter: dns.NewTransportAdapter(C.DNSTypeLocal, tag, nil), - ctx: ctx, - logger: logger, - }, nil -} - -func (t *ResolvTransport) Start(stage adapter.StartStage) error { - return nil -} - -func (t *ResolvTransport) Close() error { - return nil -} - -func (t *ResolvTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - question := message.Question[0] - return doBlockingWithCtx(ctx, func() (*mDNS.Msg, error) { - return cgoResSearch(question.Name, int(question.Qtype), int(question.Qclass)) - }) -} diff --git a/dns/transport/local/local_resolv_linkname.go b/dns/transport/local/local_resolv_linkname.go deleted file mode 100644 index 1495ae1d..00000000 --- a/dns/transport/local/local_resolv_linkname.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2022 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build darwin - -package local - -import ( - "context" - "errors" - "runtime" - "syscall" - "unsafe" - _ "unsafe" - - E "github.com/sagernet/sing/common/exceptions" - - mDNS "github.com/miekg/dns" -) - -type ( - _C_char = byte - _C_int = int32 - _C_uchar = byte - _C_ushort = uint16 - _C_uint = uint32 - _C_ulong = uint64 - _C_struct___res_state = ResState - _C_struct_sockaddr = syscall.RawSockaddr -) - -func _C_free(p unsafe.Pointer) { runtime.KeepAlive(p) } - -func _C_malloc(n uintptr) unsafe.Pointer { - if n <= 0 { - n = 1 - } - return unsafe.Pointer(&make([]byte, n)[0]) -} - -const ( - MAXNS = 3 - MAXDNSRCH = 6 -) - -type ResState struct { - Retrans _C_int - Retry _C_int - Options _C_ulong - Nscount _C_int - Nsaddrlist [MAXNS]_C_struct_sockaddr - Id _C_ushort - Dnsrch [MAXDNSRCH + 1]*_C_char - Defname [256]_C_char - Pfcode _C_ulong - Ndots _C_uint - Nsort _C_uint - stub [128]byte -} - -//go:linkname ResNinit internal/syscall/unix.ResNinit -func ResNinit(state *_C_struct___res_state) error - -//go:linkname ResNsearch internal/syscall/unix.ResNsearch -func ResNsearch(state *_C_struct___res_state, dname *byte, class, typ int, ans *byte, anslen int) (int, error) - -//go:linkname ResNclose internal/syscall/unix.ResNclose -func ResNclose(state *_C_struct___res_state) - -//go:linkname GoString internal/syscall/unix.GoString -func GoString(p *byte) string - -// doBlockingWithCtx executes a blocking function in a separate goroutine when the provided -// context is cancellable. It is intended for use with calls that don't support context -// cancellation (cgo, syscalls). blocking func may still be running after this function finishes. -// For the duration of the execution of the blocking function, the thread is 'acquired' using [acquireThread], -// blocking might not be executed when the context gets canceled early. -func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) { - if err := acquireThread(ctx); err != nil { - var zero T - return zero, err - } - - if ctx.Done() == nil { - defer releaseThread() - return blocking() - } - - type result struct { - res T - err error - } - - res := make(chan result, 1) - go func() { - defer releaseThread() - var r result - r.res, r.err = blocking() - res <- r - }() - - select { - case r := <-res: - return r.res, r.err - case <-ctx.Done(): - var zero T - return zero, ctx.Err() - } -} - -//go:linkname acquireThread net.acquireThread -func acquireThread(ctx context.Context) error - -//go:linkname releaseThread net.releaseThread -func releaseThread() - -func cgoResSearch(hostname string, rtype, class int) (*mDNS.Msg, error) { - resStateSize := unsafe.Sizeof(_C_struct___res_state{}) - var state *_C_struct___res_state - if resStateSize > 0 { - mem := _C_malloc(resStateSize) - defer _C_free(mem) - memSlice := unsafe.Slice((*byte)(mem), resStateSize) - clear(memSlice) - state = (*_C_struct___res_state)(unsafe.Pointer(&memSlice[0])) - } - if err := ResNinit(state); err != nil { - return nil, errors.New("res_ninit failure: " + err.Error()) - } - defer ResNclose(state) - - bufSize := maxDNSPacketSize - buf := (*_C_uchar)(_C_malloc(uintptr(bufSize))) - defer _C_free(unsafe.Pointer(buf)) - - s, err := syscall.BytePtrFromString(hostname) - if err != nil { - return nil, err - } - - var size int - for { - size, _ = ResNsearch(state, s, class, rtype, buf, bufSize) - if size <= bufSize || size > 0xffff { - break - } - - // Allocate a bigger buffer to fit the entire msg. - _C_free(unsafe.Pointer(buf)) - bufSize = size - buf = (*_C_uchar)(_C_malloc(uintptr(bufSize))) - } - - var msg mDNS.Msg - if size == -1 { - // macOS's libresolv seems to directly return -1 for responses that are not success responses but are exchanged. - // However, we still need the response, so we fall back to parsing the entire buffer. - err = msg.Unpack(unsafe.Slice(buf, bufSize)) - if err != nil { - return nil, E.New("res_nsearch failure") - } - } else { - err = msg.Unpack(unsafe.Slice(buf, size)) - if err != nil { - return nil, err - } - } - return &msg, nil -} diff --git a/dns/transport/local/local_resolv_stub.go b/dns/transport/local/local_resolv_stub.go deleted file mode 100644 index 8486b87a..00000000 --- a/dns/transport/local/local_resolv_stub.go +++ /dev/null @@ -1,15 +0,0 @@ -//go:build !darwin - -package local - -import ( - "context" - "os" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/log" -) - -func NewResolvTransport(ctx context.Context, logger log.ContextLogger, tag string) (adapter.DNSTransport, error) { - return nil, os.ErrInvalid -} diff --git a/dns/transport/local/local_shared.go b/dns/transport/local/local_shared.go new file mode 100644 index 00000000..9059b91f --- /dev/null +++ b/dns/transport/local/local_shared.go @@ -0,0 +1,161 @@ +package local + +import ( + "context" + "math/rand" + "time" + + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + mDNS "github.com/miekg/dns" +) + +func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { + systemConfig := getSystemDNSConfig(t.ctx) + if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) { + return t.exchangeSingleRequest(ctx, systemConfig, message, domain) + } else { + return t.exchangeParallel(ctx, systemConfig, message, domain) + } +} + +func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { + var lastErr error + for _, fqdn := range systemConfig.nameList(domain) { + response, err := t.tryOneName(ctx, systemConfig, fqdn, message) + if err != nil { + lastErr = err + continue + } + return response, nil + } + return nil, lastErr +} + +func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { + returned := make(chan struct{}) + defer close(returned) + type queryResult struct { + response *mDNS.Msg + err error + } + results := make(chan queryResult) + startRacer := func(ctx context.Context, fqdn string) { + response, err := t.tryOneName(ctx, systemConfig, fqdn, message) + if err == nil { + if response.Rcode != mDNS.RcodeSuccess { + err = dns.RcodeError(response.Rcode) + } else if len(dns.MessageToAddresses(response)) == 0 { + err = E.New(fqdn, ": empty result") + } + } + select { + case results <- queryResult{response, err}: + case <-returned: + } + } + queryCtx, queryCancel := context.WithCancel(ctx) + defer queryCancel() + var nameCount int + for _, fqdn := range systemConfig.nameList(domain) { + nameCount++ + go startRacer(queryCtx, fqdn) + } + var errors []error + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-results: + if result.err == nil { + return result.response, nil + } + errors = append(errors, result.err) + if len(errors) == nameCount { + return nil, E.Errors(errors...) + } + } + } +} + +func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) { + serverOffset := config.serverOffset() + sLen := uint32(len(config.servers)) + var lastErr error + for i := 0; i < config.attempts; i++ { + for j := uint32(0); j < sLen; j++ { + server := config.servers[(serverOffset+j)%sLen] + question := message.Question[0] + question.Name = fqdn + response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD) + if err != nil { + lastErr = err + continue + } + return response, nil + } + } + return nil, E.Cause(lastErr, fqdn) +} + +func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) { + if server.Port == 0 { + server.Port = 53 + } + var networks []string + if useTCP { + networks = []string{N.NetworkTCP} + } else { + networks = []string{N.NetworkUDP, N.NetworkTCP} + } + request := &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: uint16(rand.Uint32()), + RecursionDesired: true, + AuthenticatedData: ad, + }, + Question: []mDNS.Question{question}, + Compress: true, + } + request.SetEdns0(maxDNSPacketSize, false) + buffer := buf.Get(buf.UDPBufferSize) + defer buf.Put(buffer) + for _, network := range networks { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + conn, err := t.dialer.DialContext(ctx, network, server) + if err != nil { + return nil, err + } + defer conn.Close() + if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { + conn.SetDeadline(deadline) + } + rawMessage, err := request.PackBuffer(buffer) + if err != nil { + return nil, E.Cause(err, "pack request") + } + _, err = conn.Write(rawMessage) + if err != nil { + return nil, E.Cause(err, "write request") + } + n, err := conn.Read(buffer) + if err != nil { + return nil, E.Cause(err, "read response") + } + var response mDNS.Msg + err = response.Unpack(buffer[:n]) + if err != nil { + return nil, E.Cause(err, "unpack response") + } + if response.Truncated && network == N.NetworkUDP { + continue + } + return &response, nil + } + panic("unexpected") +} diff --git a/dns/transport/local/resolv_darwin.go b/dns/transport/local/resolv_darwin.go deleted file mode 100644 index 396e40de..00000000 --- a/dns/transport/local/resolv_darwin.go +++ /dev/null @@ -1,72 +0,0 @@ -package local - -import ( - "context" - "net/netip" - "syscall" - "time" - "unsafe" - - E "github.com/sagernet/sing/common/exceptions" - - "github.com/miekg/dns" -) - -func dnsReadConfig(_ context.Context, _ string) *dnsConfig { - resStateSize := unsafe.Sizeof(_C_struct___res_state{}) - var state *_C_struct___res_state - if resStateSize > 0 { - mem := _C_malloc(resStateSize) - defer _C_free(mem) - memSlice := unsafe.Slice((*byte)(mem), resStateSize) - clear(memSlice) - state = (*_C_struct___res_state)(unsafe.Pointer(&memSlice[0])) - } - if err := ResNinit(state); err != nil { - return &dnsConfig{ - servers: defaultNS, - search: dnsDefaultSearch(), - ndots: 1, - timeout: 5 * time.Second, - attempts: 2, - err: E.Cause(err, "libresolv initialization failed"), - } - } - defer ResNclose(state) - conf := &dnsConfig{ - ndots: 1, - timeout: 5 * time.Second, - attempts: int(state.Retry), - } - for i := 0; i < int(state.Nscount); i++ { - addr := parseRawSockaddr(&state.Nsaddrlist[i]) - if addr.IsValid() { - conf.servers = append(conf.servers, addr.String()) - } - } - for i := 0; ; i++ { - search := state.Dnsrch[i] - if search == nil { - break - } - name := dns.Fqdn(GoString(search)) - if name == "" { - continue - } - conf.search = append(conf.search, name) - } - return conf -} - -func parseRawSockaddr(rawSockaddr *syscall.RawSockaddr) netip.Addr { - switch rawSockaddr.Family { - case syscall.AF_INET: - sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(rawSockaddr)) - return netip.AddrFrom4(sa.Addr) - case syscall.AF_INET6: - sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(rawSockaddr)) - return netip.AddrFrom16(sa.Addr) - default: - return netip.Addr{} - } -} diff --git a/dns/transport/local/resolv_unix.go b/dns/transport/local/resolv_unix.go index 99eb71e6..51512f65 100644 --- a/dns/transport/local/resolv_unix.go +++ b/dns/transport/local/resolv_unix.go @@ -1,4 +1,4 @@ -//go:build !windows && !darwin +//go:build !windows package local