From 2edfed7d918bce9fd94777f9c50e48e0ded52a48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 2 Sep 2025 17:37:44 +0800 Subject: [PATCH] Improve DHCP DNS server --- dns/transport/dhcp/dhcp.go | 132 ++++++++++--------- dns/transport/dhcp/dhcp_shared.go | 202 ++++++++++++++++++++++++++++++ 2 files changed, 278 insertions(+), 56 deletions(-) create mode 100644 dns/transport/dhcp/dhcp_shared.go diff --git a/dns/transport/dhcp/dhcp.go b/dns/transport/dhcp/dhcp.go index 92dd1f8b..b56a60e7 100644 --- a/dns/transport/dhcp/dhcp.go +++ b/dns/transport/dhcp/dhcp.go @@ -9,10 +9,8 @@ import ( "time" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" - "github.com/sagernet/sing-box/dns/transport" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" @@ -29,6 +27,7 @@ import ( "github.com/insomniacslk/dhcp/dhcpv4" mDNS "github.com/miekg/dns" + "golang.org/x/exp/slices" ) func RegisterTransport(registry *dns.TransportRegistry) { @@ -45,9 +44,12 @@ type Transport struct { networkManager adapter.NetworkManager interfaceName string interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback] - transports []adapter.DNSTransport - updateAccess sync.Mutex + transportLock sync.RWMutex updatedAt time.Time + servers []M.Socksaddr + search []string + ndots int + attempts int } func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.DHCPDNSServerOptions) (adapter.DNSTransport, error) { @@ -62,27 +64,40 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt logger: logger, networkManager: service.FromContext[adapter.NetworkManager](ctx), interfaceName: options.Interface, + ndots: 1, + attempts: 2, }, nil } +func NewRawTransport(transportAdapter dns.TransportAdapter, ctx context.Context, dialer N.Dialer, logger log.ContextLogger) *Transport { + return &Transport{ + TransportAdapter: transportAdapter, + ctx: ctx, + dialer: dialer, + logger: logger, + networkManager: service.FromContext[adapter.NetworkManager](ctx), + ndots: 1, + attempts: 2, + } +} + func (t *Transport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - err := t.fetchServers() - if err != nil { - return err - } if t.interfaceName == "" { t.interfaceCallback = t.networkManager.InterfaceMonitor().RegisterCallback(t.interfaceUpdated) } + go func() { + _, err := t.Fetch() + if err != nil { + t.logger.Error(E.Cause(err, "fetch DNS servers")) + } + }() return nil } func (t *Transport) Close() error { - for _, transport := range t.transports { - transport.Close() - } if t.interfaceCallback != nil { t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback) } @@ -90,23 +105,44 @@ func (t *Transport) Close() error { } func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - err := t.fetchServers() + servers, err := t.Fetch() if err != nil { return nil, err } - - if len(t.transports) == 0 { + if len(servers) == 0 { return nil, E.New("dhcp: empty DNS servers from response") } + return t.Exchange0(ctx, message, servers) +} - var response *mDNS.Msg - for _, transport := range t.transports { - response, err = transport.Exchange(ctx, message) - if err == nil { - return response, nil - } +func (t *Transport) Exchange0(ctx context.Context, message *mDNS.Msg, servers []M.Socksaddr) (*mDNS.Msg, error) { + question := message.Question[0] + domain := dns.FqdnToDomain(question.Name) + if len(servers) == 1 || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) { + return t.exchangeSingleRequest(ctx, servers, message, domain) + } else { + return t.exchangeParallel(ctx, servers, message, domain) } - return nil, err +} + +func (t *Transport) Fetch() ([]M.Socksaddr, error) { + t.transportLock.RLock() + updatedAt := t.updatedAt + servers := t.servers + t.transportLock.RUnlock() + if time.Since(updatedAt) < C.DHCPTTL { + return servers, nil + } + t.transportLock.Lock() + defer t.transportLock.Unlock() + if time.Since(t.updatedAt) < C.DHCPTTL { + return t.servers, nil + } + err := t.updateServers() + if err != nil { + return nil, err + } + return t.servers, nil } func (t *Transport) fetchInterface() (*control.Interface, error) { @@ -124,18 +160,6 @@ func (t *Transport) fetchInterface() (*control.Interface, error) { } } -func (t *Transport) fetchServers() error { - if time.Since(t.updatedAt) < C.DHCPTTL { - return nil - } - t.updateAccess.Lock() - defer t.updateAccess.Unlock() - if time.Since(t.updatedAt) < C.DHCPTTL { - return nil - } - return t.updateServers() -} - func (t *Transport) updateServers() error { iface, err := t.fetchInterface() if err != nil { @@ -148,7 +172,7 @@ func (t *Transport) updateServers() error { cancel() if err != nil { return err - } else if len(t.transports) == 0 { + } else if len(t.servers) == 0 { return E.New("dhcp: empty DNS servers response") } else { t.updatedAt = time.Now() @@ -177,7 +201,11 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *control.Interface) } defer packetConn.Close() - discovery, err := dhcpv4.NewDiscovery(iface.HardwareAddr, dhcpv4.WithBroadcast(true), dhcpv4.WithRequestedOptions(dhcpv4.OptionDomainNameServer)) + discovery, err := dhcpv4.NewDiscovery(iface.HardwareAddr, dhcpv4.WithBroadcast(true), dhcpv4.WithRequestedOptions( + dhcpv4.OptionDomainName, + dhcpv4.OptionDomainNameServer, + dhcpv4.OptionDNSDomainSearchList, + )) if err != nil { return err } @@ -223,31 +251,23 @@ func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn ne continue } - dns := dhcpPacket.DNS() - if len(dns) == 0 { - return nil - } - return t.recreateServers(iface, common.Map(dns, func(it net.IP) M.Socksaddr { - return M.SocksaddrFrom(M.AddrFromIP(it), 53) - })) + return t.recreateServers(iface, dhcpPacket) } } -func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []M.Socksaddr) error { - if len(serverAddrs) > 0 { - t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, M.Socksaddr.String), ","), "]") +func (t *Transport) recreateServers(iface *control.Interface, dhcpPacket *dhcpv4.DHCPv4) error { + searchList := dhcpPacket.DomainSearch() + if searchList != nil && len(searchList.Labels) > 0 { + t.search = searchList.Labels + } else if dhcpPacket.DomainName() != "" { + t.search = []string{dhcpPacket.DomainName()} } - serverDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{ - BindInterface: iface.Name, - UDPFragmentDefault: true, - })) - var transports []adapter.DNSTransport - for _, serverAddr := range serverAddrs { - transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, serverAddr)) + serverAddrs := common.Map(dhcpPacket.DNS(), func(it net.IP) M.Socksaddr { + return M.SocksaddrFrom(M.AddrFromIP(it), 53) + }) + if len(serverAddrs) > 0 && !slices.Equal(t.servers, serverAddrs) { + t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, M.Socksaddr.String), ","), "], search: [", strings.Join(t.search, ","), "]") } - for _, transport := range t.transports { - transport.Close() - } - t.transports = transports + t.servers = serverAddrs return nil } diff --git a/dns/transport/dhcp/dhcp_shared.go b/dns/transport/dhcp/dhcp_shared.go new file mode 100644 index 00000000..31b92fb5 --- /dev/null +++ b/dns/transport/dhcp/dhcp_shared.go @@ -0,0 +1,202 @@ +package dhcp + +import ( + "context" + "math/rand" + "strings" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + mDNS "github.com/miekg/dns" +) + +const ( + // net.maxDNSPacketSize + maxDNSPacketSize = 1232 +) + +func (t *Transport) exchangeSingleRequest(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { + var lastErr error + for _, fqdn := range t.nameList(domain) { + response, err := t.tryOneName(ctx, servers, fqdn, message) + if err != nil { + lastErr = err + continue + } + return response, nil + } + return nil, lastErr +} + +func (t *Transport) exchangeParallel(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) { + returned := make(chan struct{}) + defer close(returned) + type queryResult struct { + response *mDNS.Msg + err error + } + results := make(chan queryResult) + startRacer := func(ctx context.Context, fqdn string) { + response, err := t.tryOneName(ctx, servers, fqdn, message) + if err == nil { + if response.Rcode != mDNS.RcodeSuccess { + err = dns.RcodeError(response.Rcode) + } else if len(dns.MessageToAddresses(response)) == 0 { + err = E.New(fqdn, ": empty result") + } + } + select { + case results <- queryResult{response, err}: + case <-returned: + } + } + queryCtx, queryCancel := context.WithCancel(ctx) + defer queryCancel() + var nameCount int + for _, fqdn := range t.nameList(domain) { + nameCount++ + go startRacer(queryCtx, fqdn) + } + var errors []error + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-results: + if result.err == nil { + return result.response, nil + } + errors = append(errors, result.err) + if len(errors) == nameCount { + return nil, E.Errors(errors...) + } + } + } +} + +func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) { + sLen := len(servers) + var lastErr error + for i := 0; i < t.attempts; i++ { + for j := 0; j < sLen; j++ { + server := servers[j] + question := message.Question[0] + question.Name = fqdn + response, err := t.exchangeOne(ctx, server, question, C.DNSTimeout, false, true) + if err != nil { + lastErr = err + continue + } + return response, nil + } + } + return nil, E.Cause(lastErr, fqdn) +} + +func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) { + if server.Port == 0 { + server.Port = 53 + } + var networks []string + if useTCP { + networks = []string{N.NetworkTCP} + } else { + networks = []string{N.NetworkUDP, N.NetworkTCP} + } + request := &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: uint16(rand.Uint32()), + RecursionDesired: true, + AuthenticatedData: ad, + }, + Question: []mDNS.Question{question}, + Compress: true, + } + request.SetEdns0(maxDNSPacketSize, false) + buffer := buf.Get(buf.UDPBufferSize) + defer buf.Put(buffer) + for _, network := range networks { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + conn, err := t.dialer.DialContext(ctx, network, server) + if err != nil { + return nil, err + } + defer conn.Close() + if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() { + conn.SetDeadline(deadline) + } + rawMessage, err := request.PackBuffer(buffer) + if err != nil { + return nil, E.Cause(err, "pack request") + } + _, err = conn.Write(rawMessage) + if err != nil { + return nil, E.Cause(err, "write request") + } + n, err := conn.Read(buffer) + if err != nil { + return nil, E.Cause(err, "read response") + } + var response mDNS.Msg + err = response.Unpack(buffer[:n]) + if err != nil { + return nil, E.Cause(err, "unpack response") + } + if response.Truncated && network == N.NetworkUDP { + continue + } + return &response, nil + } + panic("unexpected") +} + +func (t *Transport) nameList(name string) []string { + l := len(name) + rooted := l > 0 && name[l-1] == '.' + if l > 254 || l == 254 && !rooted { + return nil + } + + if rooted { + if avoidDNS(name) { + return nil + } + return []string{name} + } + + hasNdots := strings.Count(name, ".") >= t.ndots + name += "." + // l++ + + names := make([]string, 0, 1+len(t.search)) + if hasNdots && !avoidDNS(name) { + names = append(names, name) + } + for _, suffix := range t.search { + fqdn := name + suffix + if !avoidDNS(fqdn) && len(fqdn) <= 254 { + names = append(names, fqdn) + } + } + if !hasNdots && !avoidDNS(name) { + names = append(names, name) + } + return names +} + +func avoidDNS(name string) bool { + if name == "" { + return true + } + if name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return strings.HasSuffix(name, ".onion") +}