diff --git a/adapter/inbound.go b/adapter/inbound.go index edba8447..de30149f 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -57,6 +57,7 @@ type InboundContext struct { Domain string Client string SniffContext any + SnifferNames []string SniffError error // cache diff --git a/route/route.go b/route/route.go index 250b8fee..ef30a0b1 100644 --- a/route/route.go +++ b/route/route.go @@ -27,6 +27,8 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/uot" + + "golang.org/x/exp/slices" ) // Deprecated: use RouteConnectionEx instead. @@ -345,16 +347,16 @@ func (r *Router) matchRule( newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &R.RuleActionSniff{ OverrideDestination: metadata.InboundOptions.SniffOverrideDestination, Timeout: time.Duration(metadata.InboundOptions.SniffTimeout), - }, inputConn, inputPacketConn, nil) - if newErr != nil { - fatalErr = newErr - return - } + }, inputConn, inputPacketConn, nil, nil) if newBuffer != nil { buffers = []*buf.Buffer{newBuffer} } else if len(newPackerBuffers) > 0 { packetBuffers = newPackerBuffers } + if newErr != nil { + fatalErr = newErr + return + } } if C.DomainStrategy(metadata.InboundOptions.DomainStrategy) != C.DomainStrategyAsIS { fatalErr = r.actionResolve(ctx, metadata, &R.RuleActionResolve{ @@ -453,16 +455,16 @@ match: switch action := currentRule.Action().(type) { case *R.RuleActionSniff: if !preMatch { - newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers) - if newErr != nil { - fatalErr = newErr - return - } + newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers, packetBuffers) if newBuffer != nil { buffers = append(buffers, newBuffer) } else if len(newPacketBuffers) > 0 { packetBuffers = append(packetBuffers, newPacketBuffers...) } + if newErr != nil { + fatalErr = newErr + return + } } else { selectedRule = currentRule selectedRuleIndex = currentRuleIndex @@ -489,7 +491,7 @@ match: func (r *Router) actionSniff( ctx context.Context, metadata *adapter.InboundContext, action *R.RuleActionSniff, - inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer, + inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer, inputPacketBuffers []*N.PacketBuffer, ) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) { if sniff.Skip(metadata) { r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first") @@ -501,7 +503,7 @@ func (r *Router) actionSniff( if inputConn != nil { if len(action.StreamSniffers) == 0 && len(action.PacketSniffers) > 0 { return - } else if metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) { + } else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) { r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError) return } @@ -528,6 +530,7 @@ func (r *Router) actionSniff( action.Timeout, streamSniffers..., ) + metadata.SnifferNames = action.SnifferNames metadata.SniffError = err if err == nil { //goland:noinspection GoDeprecation @@ -553,10 +556,13 @@ func (r *Router) actionSniff( } else if inputPacketConn != nil { if len(action.PacketSniffers) == 0 && len(action.StreamSniffers) > 0 { return - } else if metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) { + } else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) { r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError) return } + quicMoreData := func() bool { + return slices.Equal(metadata.SnifferNames, action.SnifferNames) && errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) + } var packetSniffers []sniff.PacketSniffer if len(action.PacketSniffers) > 0 { packetSniffers = action.PacketSniffers @@ -571,12 +577,37 @@ func (r *Router) actionSniff( sniff.NTP, } } + var err error + for _, packetBuffer := range inputPacketBuffers { + if quicMoreData() { + err = sniff.PeekPacket( + ctx, + metadata, + packetBuffer.Buffer.Bytes(), + sniff.QUICClientHello, + ) + } else { + err = sniff.PeekPacket( + ctx, metadata, + packetBuffer.Buffer.Bytes(), + packetSniffers..., + ) + } + metadata.SnifferNames = action.SnifferNames + metadata.SniffError = err + if errors.Is(err, sniff.ErrNeedMoreData) { + // TODO: replace with generic message when there are more multi-packet protocols + r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello") + continue + } + goto finally + } + packetBuffers = inputPacketBuffers for { var ( sniffBuffer = buf.NewPacket() destination M.Socksaddr done = make(chan struct{}) - err error ) go func() { sniffTimeout := C.ReadPayloadTimeout @@ -602,7 +633,7 @@ func (r *Router) actionSniff( return } } else { - if len(packetBuffers) > 0 || metadata.SniffError != nil { + if quicMoreData() { err = sniff.PeekPacket( ctx, metadata, @@ -622,32 +653,34 @@ func (r *Router) actionSniff( Destination: destination, } packetBuffers = append(packetBuffers, packetBuffer) + metadata.SnifferNames = action.SnifferNames metadata.SniffError = err if errors.Is(err, sniff.ErrNeedMoreData) { // TODO: replace with generic message when there are more multi-packet protocols r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello") continue } - if metadata.Protocol != "" { - //goland:noinspection GoDeprecation - if action.OverrideDestination && M.IsDomainName(metadata.Domain) { - metadata.Destination = M.Socksaddr{ - Fqdn: metadata.Domain, - Port: metadata.Destination.Port, - } - } - if metadata.Domain != "" && metadata.Client != "" { - r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client) - } else if metadata.Domain != "" { - r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain) - } else if metadata.Client != "" { - r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", client: ", metadata.Client) - } else { - r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol) - } + } + goto finally + } + finally: + if err == nil { + //goland:noinspection GoDeprecation + if action.OverrideDestination && M.IsDomainName(metadata.Domain) { + metadata.Destination = M.Socksaddr{ + Fqdn: metadata.Domain, + Port: metadata.Destination.Port, } } - break + if metadata.Domain != "" && metadata.Client != "" { + r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client) + } else if metadata.Domain != "" { + r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain) + } else if metadata.Client != "" { + r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", client: ", metadata.Client) + } else { + r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol) + } } } return diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index ce50dad9..62d5410a 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -87,7 +87,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti return &RuleActionHijackDNS{}, nil case C.RuleActionTypeSniff: sniffAction := &RuleActionSniff{ - snifferNames: action.SniffOptions.Sniffer, + SnifferNames: action.SniffOptions.Sniffer, Timeout: time.Duration(action.SniffOptions.Timeout), } return sniffAction, sniffAction.build() @@ -361,7 +361,7 @@ func (r *RuleActionHijackDNS) String() string { } type RuleActionSniff struct { - snifferNames []string + SnifferNames []string StreamSniffers []sniff.StreamSniffer PacketSniffers []sniff.PacketSniffer Timeout time.Duration @@ -374,7 +374,7 @@ func (r *RuleActionSniff) Type() string { } func (r *RuleActionSniff) build() error { - for _, name := range r.snifferNames { + for _, name := range r.SnifferNames { switch name { case C.ProtocolTLS: r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello) @@ -407,14 +407,14 @@ func (r *RuleActionSniff) build() error { } func (r *RuleActionSniff) String() string { - if len(r.snifferNames) == 0 && r.Timeout == 0 { + if len(r.SnifferNames) == 0 && r.Timeout == 0 { return "sniff" - } else if len(r.snifferNames) > 0 && r.Timeout == 0 { - return F.ToString("sniff(", strings.Join(r.snifferNames, ","), ")") - } else if len(r.snifferNames) == 0 && r.Timeout > 0 { + } else if len(r.SnifferNames) > 0 && r.Timeout == 0 { + return F.ToString("sniff(", strings.Join(r.SnifferNames, ","), ")") + } else if len(r.SnifferNames) == 0 && r.Timeout > 0 { return F.ToString("sniff(", r.Timeout.String(), ")") } else { - return F.ToString("sniff(", strings.Join(r.snifferNames, ","), ",", r.Timeout.String(), ")") + return F.ToString("sniff(", strings.Join(r.SnifferNames, ","), ",", r.Timeout.String(), ")") } }