Fix multiple sniff

This commit is contained in:
世界 2025-09-02 18:09:48 +08:00
parent 0ef7e8eca2
commit cbf48e9b8c
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
3 changed files with 75 additions and 41 deletions

View File

@ -57,6 +57,7 @@ type InboundContext struct {
Domain string Domain string
Client string Client string
SniffContext any SniffContext any
SnifferNames []string
SniffError error SniffError error
// cache // cache

View File

@ -27,6 +27,8 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot" "github.com/sagernet/sing/common/uot"
"golang.org/x/exp/slices"
) )
// Deprecated: use RouteConnectionEx instead. // Deprecated: use RouteConnectionEx instead.
@ -345,16 +347,16 @@ func (r *Router) matchRule(
newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &R.RuleActionSniff{ newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &R.RuleActionSniff{
OverrideDestination: metadata.InboundOptions.SniffOverrideDestination, OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
Timeout: time.Duration(metadata.InboundOptions.SniffTimeout), Timeout: time.Duration(metadata.InboundOptions.SniffTimeout),
}, inputConn, inputPacketConn, nil) }, inputConn, inputPacketConn, nil, nil)
if newErr != nil {
fatalErr = newErr
return
}
if newBuffer != nil { if newBuffer != nil {
buffers = []*buf.Buffer{newBuffer} buffers = []*buf.Buffer{newBuffer}
} else if len(newPackerBuffers) > 0 { } else if len(newPackerBuffers) > 0 {
packetBuffers = newPackerBuffers packetBuffers = newPackerBuffers
} }
if newErr != nil {
fatalErr = newErr
return
}
} }
if C.DomainStrategy(metadata.InboundOptions.DomainStrategy) != C.DomainStrategyAsIS { if C.DomainStrategy(metadata.InboundOptions.DomainStrategy) != C.DomainStrategyAsIS {
fatalErr = r.actionResolve(ctx, metadata, &R.RuleActionResolve{ fatalErr = r.actionResolve(ctx, metadata, &R.RuleActionResolve{
@ -453,16 +455,16 @@ match:
switch action := currentRule.Action().(type) { switch action := currentRule.Action().(type) {
case *R.RuleActionSniff: case *R.RuleActionSniff:
if !preMatch { if !preMatch {
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers) newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers, packetBuffers)
if newErr != nil {
fatalErr = newErr
return
}
if newBuffer != nil { if newBuffer != nil {
buffers = append(buffers, newBuffer) buffers = append(buffers, newBuffer)
} else if len(newPacketBuffers) > 0 { } else if len(newPacketBuffers) > 0 {
packetBuffers = append(packetBuffers, newPacketBuffers...) packetBuffers = append(packetBuffers, newPacketBuffers...)
} }
if newErr != nil {
fatalErr = newErr
return
}
} else { } else {
selectedRule = currentRule selectedRule = currentRule
selectedRuleIndex = currentRuleIndex selectedRuleIndex = currentRuleIndex
@ -489,7 +491,7 @@ match:
func (r *Router) actionSniff( func (r *Router) actionSniff(
ctx context.Context, metadata *adapter.InboundContext, action *R.RuleActionSniff, ctx context.Context, metadata *adapter.InboundContext, action *R.RuleActionSniff,
inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer, inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer, inputPacketBuffers []*N.PacketBuffer,
) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) { ) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) {
if sniff.Skip(metadata) { if sniff.Skip(metadata) {
r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first") r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first")
@ -501,7 +503,7 @@ func (r *Router) actionSniff(
if inputConn != nil { if inputConn != nil {
if len(action.StreamSniffers) == 0 && len(action.PacketSniffers) > 0 { if len(action.StreamSniffers) == 0 && len(action.PacketSniffers) > 0 {
return return
} else if metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) { } else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError) r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError)
return return
} }
@ -528,6 +530,7 @@ func (r *Router) actionSniff(
action.Timeout, action.Timeout,
streamSniffers..., streamSniffers...,
) )
metadata.SnifferNames = action.SnifferNames
metadata.SniffError = err metadata.SniffError = err
if err == nil { if err == nil {
//goland:noinspection GoDeprecation //goland:noinspection GoDeprecation
@ -553,10 +556,13 @@ func (r *Router) actionSniff(
} else if inputPacketConn != nil { } else if inputPacketConn != nil {
if len(action.PacketSniffers) == 0 && len(action.StreamSniffers) > 0 { if len(action.PacketSniffers) == 0 && len(action.StreamSniffers) > 0 {
return return
} else if metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) { } else if slices.Equal(metadata.SnifferNames, action.SnifferNames) && metadata.SniffError != nil && !errors.Is(metadata.SniffError, sniff.ErrNeedMoreData) {
r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError) r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.SniffError)
return return
} }
quicMoreData := func() bool {
return slices.Equal(metadata.SnifferNames, action.SnifferNames) && errors.Is(metadata.SniffError, sniff.ErrNeedMoreData)
}
var packetSniffers []sniff.PacketSniffer var packetSniffers []sniff.PacketSniffer
if len(action.PacketSniffers) > 0 { if len(action.PacketSniffers) > 0 {
packetSniffers = action.PacketSniffers packetSniffers = action.PacketSniffers
@ -571,12 +577,37 @@ func (r *Router) actionSniff(
sniff.NTP, sniff.NTP,
} }
} }
var err error
for _, packetBuffer := range inputPacketBuffers {
if quicMoreData() {
err = sniff.PeekPacket(
ctx,
metadata,
packetBuffer.Buffer.Bytes(),
sniff.QUICClientHello,
)
} else {
err = sniff.PeekPacket(
ctx, metadata,
packetBuffer.Buffer.Bytes(),
packetSniffers...,
)
}
metadata.SnifferNames = action.SnifferNames
metadata.SniffError = err
if errors.Is(err, sniff.ErrNeedMoreData) {
// TODO: replace with generic message when there are more multi-packet protocols
r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
continue
}
goto finally
}
packetBuffers = inputPacketBuffers
for { for {
var ( var (
sniffBuffer = buf.NewPacket() sniffBuffer = buf.NewPacket()
destination M.Socksaddr destination M.Socksaddr
done = make(chan struct{}) done = make(chan struct{})
err error
) )
go func() { go func() {
sniffTimeout := C.ReadPayloadTimeout sniffTimeout := C.ReadPayloadTimeout
@ -602,7 +633,7 @@ func (r *Router) actionSniff(
return return
} }
} else { } else {
if len(packetBuffers) > 0 || metadata.SniffError != nil { if quicMoreData() {
err = sniff.PeekPacket( err = sniff.PeekPacket(
ctx, ctx,
metadata, metadata,
@ -622,32 +653,34 @@ func (r *Router) actionSniff(
Destination: destination, Destination: destination,
} }
packetBuffers = append(packetBuffers, packetBuffer) packetBuffers = append(packetBuffers, packetBuffer)
metadata.SnifferNames = action.SnifferNames
metadata.SniffError = err metadata.SniffError = err
if errors.Is(err, sniff.ErrNeedMoreData) { if errors.Is(err, sniff.ErrNeedMoreData) {
// TODO: replace with generic message when there are more multi-packet protocols // TODO: replace with generic message when there are more multi-packet protocols
r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello") r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
continue continue
} }
if metadata.Protocol != "" { }
//goland:noinspection GoDeprecation goto finally
if action.OverrideDestination && M.IsDomainName(metadata.Domain) { }
metadata.Destination = M.Socksaddr{ finally:
Fqdn: metadata.Domain, if err == nil {
Port: metadata.Destination.Port, //goland:noinspection GoDeprecation
} if action.OverrideDestination && M.IsDomainName(metadata.Domain) {
} metadata.Destination = M.Socksaddr{
if metadata.Domain != "" && metadata.Client != "" { Fqdn: metadata.Domain,
r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client) Port: metadata.Destination.Port,
} else if metadata.Domain != "" {
r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
} else if metadata.Client != "" {
r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", client: ", metadata.Client)
} else {
r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol)
}
} }
} }
break if metadata.Domain != "" && metadata.Client != "" {
r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain, ", client: ", metadata.Client)
} else if metadata.Domain != "" {
r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
} else if metadata.Client != "" {
r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", client: ", metadata.Client)
} else {
r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol)
}
} }
} }
return return

View File

@ -87,7 +87,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
return &RuleActionHijackDNS{}, nil return &RuleActionHijackDNS{}, nil
case C.RuleActionTypeSniff: case C.RuleActionTypeSniff:
sniffAction := &RuleActionSniff{ sniffAction := &RuleActionSniff{
snifferNames: action.SniffOptions.Sniffer, SnifferNames: action.SniffOptions.Sniffer,
Timeout: time.Duration(action.SniffOptions.Timeout), Timeout: time.Duration(action.SniffOptions.Timeout),
} }
return sniffAction, sniffAction.build() return sniffAction, sniffAction.build()
@ -361,7 +361,7 @@ func (r *RuleActionHijackDNS) String() string {
} }
type RuleActionSniff struct { type RuleActionSniff struct {
snifferNames []string SnifferNames []string
StreamSniffers []sniff.StreamSniffer StreamSniffers []sniff.StreamSniffer
PacketSniffers []sniff.PacketSniffer PacketSniffers []sniff.PacketSniffer
Timeout time.Duration Timeout time.Duration
@ -374,7 +374,7 @@ func (r *RuleActionSniff) Type() string {
} }
func (r *RuleActionSniff) build() error { func (r *RuleActionSniff) build() error {
for _, name := range r.snifferNames { for _, name := range r.SnifferNames {
switch name { switch name {
case C.ProtocolTLS: case C.ProtocolTLS:
r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello) r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello)
@ -407,14 +407,14 @@ func (r *RuleActionSniff) build() error {
} }
func (r *RuleActionSniff) String() string { func (r *RuleActionSniff) String() string {
if len(r.snifferNames) == 0 && r.Timeout == 0 { if len(r.SnifferNames) == 0 && r.Timeout == 0 {
return "sniff" return "sniff"
} else if len(r.snifferNames) > 0 && r.Timeout == 0 { } else if len(r.SnifferNames) > 0 && r.Timeout == 0 {
return F.ToString("sniff(", strings.Join(r.snifferNames, ","), ")") return F.ToString("sniff(", strings.Join(r.SnifferNames, ","), ")")
} else if len(r.snifferNames) == 0 && r.Timeout > 0 { } else if len(r.SnifferNames) == 0 && r.Timeout > 0 {
return F.ToString("sniff(", r.Timeout.String(), ")") return F.ToString("sniff(", r.Timeout.String(), ")")
} else { } else {
return F.ToString("sniff(", strings.Join(r.snifferNames, ","), ",", r.Timeout.String(), ")") return F.ToString("sniff(", strings.Join(r.SnifferNames, ","), ",", r.Timeout.String(), ")")
} }
} }