From d78cdac640940e4a5cda7afcc133a57b6fe3feb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 5 Sep 2025 15:49:06 +0800 Subject: [PATCH] Fix DNS client --- .../clashapi => common}/compatible/map.go | 0 dns/client.go | 118 ++++++++++-------- .../clashapi/trafficontrol/manager.go | 2 +- 3 files changed, 69 insertions(+), 51 deletions(-) rename {experimental/clashapi => common}/compatible/map.go (100%) diff --git a/experimental/clashapi/compatible/map.go b/common/compatible/map.go similarity index 100% rename from experimental/clashapi/compatible/map.go rename to common/compatible/map.go diff --git a/dns/client.go b/dns/client.go index 9bdb5b69..89acc971 100644 --- a/dns/client.go +++ b/dns/client.go @@ -2,12 +2,14 @@ package dns import ( "context" + "errors" "net" "net/netip" "strings" "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/compatible" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" @@ -30,16 +32,18 @@ var ( var _ adapter.DNSClient = (*Client)(nil) type Client struct { - timeout time.Duration - disableCache bool - disableExpire bool - independentCache bool - clientSubnet netip.Prefix - rdrc adapter.RDRCStore - initRDRCFunc func() adapter.RDRCStore - logger logger.ContextLogger - cache freelru.Cache[dns.Question, *dns.Msg] - transportCache freelru.Cache[transportCacheKey, *dns.Msg] + timeout time.Duration + disableCache bool + disableExpire bool + independentCache bool + clientSubnet netip.Prefix + rdrc adapter.RDRCStore + initRDRCFunc func() adapter.RDRCStore + logger logger.ContextLogger + cache freelru.Cache[dns.Question, *dns.Msg] + cacheLock compatible.Map[dns.Question, chan struct{}] + transportCache freelru.Cache[transportCacheKey, *dns.Msg] + transportCacheLock compatible.Map[dns.Question, chan struct{}] } type ClientOptions struct { @@ -96,17 +100,15 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m if c.logger != nil { c.logger.WarnContext(ctx, "bad question size: ", len(message.Question)) } - responseMessage := dns.Msg{ - MsgHdr: dns.MsgHdr{ - Id: message.Id, - Response: true, - Rcode: dns.RcodeFormatError, - }, - Question: message.Question, - } - return &responseMessage, nil + return FixedResponseStatus(message, dns.RcodeFormatError), nil } question := message.Question[0] + if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only { + if c.logger != nil { + c.logger.DebugContext(ctx, "strategy rejected") + } + return FixedResponseStatus(message, dns.RcodeSuccess), nil + } clientSubnet := options.ClientSubnet if !clientSubnet.IsValid() { clientSubnet = c.clientSubnet @@ -120,6 +122,27 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m !options.ClientSubnet.IsValid() disableCache := !isSimpleRequest || c.disableCache || options.DisableCache if !disableCache { + if c.cache != nil { + cond, loaded := c.cacheLock.LoadOrStore(question, make(chan struct{})) + if loaded { + <-cond + } else { + defer func() { + c.cacheLock.Delete(question) + close(cond) + }() + } + } else if c.transportCache != nil { + cond, loaded := c.transportCacheLock.LoadOrStore(question, make(chan struct{})) + if loaded { + <-cond + } else { + defer func() { + c.transportCacheLock.Delete(question) + close(cond) + }() + } + } response, ttl := c.loadResponse(question, transport) if response != nil { logCachedResponse(c.logger, ctx, response, ttl) @@ -127,27 +150,14 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m return response, nil } } - if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only { - responseMessage := dns.Msg{ - MsgHdr: dns.MsgHdr{ - Id: message.Id, - Response: true, - Rcode: dns.RcodeSuccess, - }, - Question: []dns.Question{question}, - } - if c.logger != nil { - c.logger.DebugContext(ctx, "strategy rejected") - } - return &responseMessage, nil - } + messageId := message.Id contextTransport, clientSubnetLoaded := transportTagFromContext(ctx) if clientSubnetLoaded && transport.Tag() == contextTransport { return nil, E.New("DNS query loopback in transport[", contextTransport, "]") } ctx = contextWithTransportTag(ctx, transport.Tag()) - if responseChecker != nil && c.rdrc != nil { + if !disableCache && responseChecker != nil && c.rdrc != nil { rejected := c.rdrc.LoadRDRC(transport.Tag(), question.Name, question.Qtype) if rejected { return nil, ErrResponseRejectedCached @@ -157,7 +167,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m response, err := transport.Exchange(ctx, message) cancel() if err != nil { - return nil, err + var rcodeError RcodeError + if errors.As(err, &rcodeError) { + response = FixedResponseStatus(message, int(rcodeError)) + } else { + return nil, err + } } /*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { validResponse := response @@ -196,13 +211,14 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m }*/ if responseChecker != nil { var rejected bool - if !(response.Rcode == dns.RcodeSuccess || response.Rcode == dns.RcodeNameError) { + // TODO: add accept_any rule and support to check response instead of addresses + if response.Rcode != dns.RcodeSuccess || len(response.Answer) == 0 { rejected = true } else { rejected = !responseChecker(MessageToAddresses(response)) } if rejected { - if c.rdrc != nil { + if !disableCache && c.rdrc != nil { c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger) } logRejectedResponse(c.logger, ctx, response) @@ -305,8 +321,7 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom func (c *Client) ClearCache() { if c.cache != nil { c.cache.Purge() - } - if c.transportCache != nil { + } else if c.transportCache != nil { c.transportCache.Purge() } } @@ -390,15 +405,15 @@ func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Questio transportTag: transport.Tag(), }, message) } - return - } - if !c.independentCache { - c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive)) } else { - c.transportCache.AddWithLifetime(transportCacheKey{ - Question: question, - transportTag: transport.Tag(), - }, message, time.Second*time.Duration(timeToLive)) + if !c.independentCache { + c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive)) + } else { + c.transportCache.AddWithLifetime(transportCacheKey{ + Question: question, + transportTag: transport.Tag(), + }, message, time.Second*time.Duration(timeToLive)) + } } } @@ -564,9 +579,12 @@ func transportTagFromContext(ctx context.Context) (string, bool) { func FixedResponseStatus(message *dns.Msg, rcode int) *dns.Msg { return &dns.Msg{ MsgHdr: dns.MsgHdr{ - Id: message.Id, - Rcode: rcode, - Response: true, + Id: message.Id, + Response: true, + Authoritative: true, + RecursionDesired: true, + RecursionAvailable: true, + Rcode: rcode, }, Question: message.Question, } diff --git a/experimental/clashapi/trafficontrol/manager.go b/experimental/clashapi/trafficontrol/manager.go index 7b69b93d..bb4822df 100644 --- a/experimental/clashapi/trafficontrol/manager.go +++ b/experimental/clashapi/trafficontrol/manager.go @@ -6,8 +6,8 @@ import ( "sync/atomic" "time" + "github.com/sagernet/sing-box/common/compatible" C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/experimental/clashapi/compatible" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/x/list"