Add TLS record fragment support

This commit is contained in:
世界 2025-05-12 12:03:34 +08:00
parent c2b7d5bd12
commit 0be2233795
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
6 changed files with 89 additions and 28 deletions

View File

@ -74,6 +74,7 @@ type InboundContext struct {
UDPTimeout time.Duration UDPTimeout time.Duration
TLSFragment bool TLSFragment bool
TLSFragmentFallbackDelay time.Duration TLSFragmentFallbackDelay time.Duration
TLSRecordFragment bool
NetworkStrategy *C.NetworkStrategy NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType NetworkType []C.InterfaceType

View File

@ -1,7 +1,9 @@
package tf package tf
import ( import (
"bytes"
"context" "context"
"encoding/binary"
"math/rand" "math/rand"
"net" "net"
"strings" "strings"
@ -17,17 +19,19 @@ type Conn struct {
tcpConn *net.TCPConn tcpConn *net.TCPConn
ctx context.Context ctx context.Context
firstPacketWritten bool firstPacketWritten bool
splitRecord bool
fallbackDelay time.Duration fallbackDelay time.Duration
} }
func NewConn(conn net.Conn, ctx context.Context, fallbackDelay time.Duration) (*Conn, error) { func NewConn(conn net.Conn, ctx context.Context, splitRecord bool, fallbackDelay time.Duration) *Conn {
tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn) tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn)
return &Conn{ return &Conn{
Conn: conn, Conn: conn,
tcpConn: tcpConn, tcpConn: tcpConn,
ctx: ctx, ctx: ctx,
splitRecord: splitRecord,
fallbackDelay: fallbackDelay, fallbackDelay: fallbackDelay,
}, nil }
} }
func (c *Conn) Write(b []byte) (n int, err error) { func (c *Conn) Write(b []byte) (n int, err error) {
@ -37,10 +41,12 @@ func (c *Conn) Write(b []byte) (n int, err error) {
}() }()
serverName := indexTLSServerName(b) serverName := indexTLSServerName(b)
if serverName != nil { if serverName != nil {
if c.tcpConn != nil { if !c.splitRecord {
err = c.tcpConn.SetNoDelay(true) if c.tcpConn != nil {
if err != nil { err = c.tcpConn.SetNoDelay(true)
return if err != nil {
return
}
} }
} }
splits := strings.Split(serverName.ServerName, ".") splits := strings.Split(serverName.ServerName, ".")
@ -61,16 +67,25 @@ func (c *Conn) Write(b []byte) (n int, err error) {
currentIndex++ currentIndex++
} }
} }
var buffer bytes.Buffer
for i := 0; i <= len(splitIndexes); i++ { for i := 0; i <= len(splitIndexes); i++ {
var payload []byte var payload []byte
if i == 0 { if i == 0 {
payload = b[:splitIndexes[i]] payload = b[:splitIndexes[i]]
if c.splitRecord {
payload = payload[recordLayerHeaderLen:]
}
} else if i == len(splitIndexes) { } else if i == len(splitIndexes) {
payload = b[splitIndexes[i-1]:] payload = b[splitIndexes[i-1]:]
} else { } else {
payload = b[splitIndexes[i-1]:splitIndexes[i]] payload = b[splitIndexes[i-1]:splitIndexes[i]]
} }
if c.tcpConn != nil && i != len(splitIndexes) { if c.splitRecord {
payloadLen := uint16(len(payload))
buffer.Write(b[:3])
binary.Write(&buffer, binary.BigEndian, payloadLen)
buffer.Write(payload)
} else if c.tcpConn != nil && i != len(splitIndexes) {
err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay) err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay)
if err != nil { if err != nil {
return return
@ -82,11 +97,18 @@ func (c *Conn) Write(b []byte) (n int, err error) {
} }
} }
} }
if c.tcpConn != nil { if c.splitRecord {
err = c.tcpConn.SetNoDelay(false) _, err = c.tcpConn.Write(buffer.Bytes())
if err != nil { if err != nil {
return return
} }
} else {
if c.tcpConn != nil {
err = c.tcpConn.SetNoDelay(false)
if err != nil {
return
}
}
} }
return len(b), nil return len(b), nil
} }

View File

@ -0,0 +1,32 @@
package tf_test
import (
"context"
"crypto/tls"
"net"
"testing"
tf "github.com/sagernet/sing-box/common/tlsfragment"
"github.com/stretchr/testify/require"
)
func TestTLSFragment(t *testing.T) {
t.Parallel()
tcpConn, err := net.Dial("tcp", "1.1.1.1:443")
require.NoError(t, err)
tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), false, 0), &tls.Config{
ServerName: "www.cloudflare.com",
})
require.NoError(t, tlsConn.Handshake())
}
func TestTLSRecordFragment(t *testing.T) {
t.Parallel()
tcpConn, err := net.Dial("tcp", "1.1.1.1:443")
require.NoError(t, err)
tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), true, 0), &tls.Config{
ServerName: "www.cloudflare.com",
})
require.NoError(t, tlsConn.Handshake())
}

View File

@ -158,6 +158,7 @@ type RawRouteOptionsActionOptions struct {
TLSFragment bool `json:"tls_fragment,omitempty"` TLSFragment bool `json:"tls_fragment,omitempty"`
TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"` TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"`
TLSRecordFragment bool `json:"tls_record_fragment,omitempty"`
} }
type RouteOptionsActionOptions RawRouteOptionsActionOptions type RouteOptionsActionOptions RawRouteOptionsActionOptions
@ -170,6 +171,9 @@ func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error {
if *r == (RouteOptionsActionOptions{}) { if *r == (RouteOptionsActionOptions{}) {
return E.New("empty route option action") return E.New("empty route option action")
} }
if r.TLSFragment && r.TLSRecordFragment {
return E.New("`tls_fragment` and `tls_record_fragment` are mutually exclusive")
}
return nil return nil
} }

View File

@ -95,15 +95,9 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
if fallbackDelay == 0 { if fallbackDelay == 0 {
fallbackDelay = C.TLSFragmentFallbackDelay fallbackDelay = C.TLSFragmentFallbackDelay
} }
var newConn *tf.Conn remoteConn = tf.NewConn(remoteConn, ctx, false, fallbackDelay)
newConn, err = tf.NewConn(remoteConn, ctx, fallbackDelay) } else if metadata.TLSRecordFragment {
if err != nil { remoteConn = tf.NewConn(remoteConn, ctx, true, 0)
conn.Close()
remoteConn.Close()
m.logger.ErrorContext(ctx, err)
return
}
remoteConn = newConn
} }
m.access.Lock() m.access.Lock()
element := m.connections.PushBack(conn) element := m.connections.PushBack(conn)

View File

@ -40,6 +40,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPConnect: action.RouteOptions.UDPConnect, UDPConnect: action.RouteOptions.UDPConnect,
TLSFragment: action.RouteOptions.TLSFragment, TLSFragment: action.RouteOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay), TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay),
TLSRecordFragment: action.RouteOptions.TLSRecordFragment,
}, },
}, nil }, nil
case C.RuleActionTypeRouteOptions: case C.RuleActionTypeRouteOptions:
@ -53,6 +54,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout), UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout),
TLSFragment: action.RouteOptionsOptions.TLSFragment, TLSFragment: action.RouteOptionsOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay), TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay),
TLSRecordFragment: action.RouteOptionsOptions.TLSRecordFragment,
}, nil }, nil
case C.RuleActionTypeDirect: case C.RuleActionTypeDirect:
directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false) directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false)
@ -152,15 +154,7 @@ func (r *RuleActionRoute) Type() string {
func (r *RuleActionRoute) String() string { func (r *RuleActionRoute) String() string {
var descriptions []string var descriptions []string
descriptions = append(descriptions, r.Outbound) descriptions = append(descriptions, r.Outbound)
if r.UDPDisableDomainUnmapping { descriptions = append(descriptions, r.Descriptions()...)
descriptions = append(descriptions, "udp-disable-domain-unmapping")
}
if r.UDPConnect {
descriptions = append(descriptions, "udp-connect")
}
if r.TLSFragment {
descriptions = append(descriptions, "tls-fragment")
}
return F.ToString("route(", strings.Join(descriptions, ","), ")") return F.ToString("route(", strings.Join(descriptions, ","), ")")
} }
@ -176,6 +170,7 @@ type RuleActionRouteOptions struct {
UDPTimeout time.Duration UDPTimeout time.Duration
TLSFragment bool TLSFragment bool
TLSFragmentFallbackDelay time.Duration TLSFragmentFallbackDelay time.Duration
TLSRecordFragment bool
} }
func (r *RuleActionRouteOptions) Type() string { func (r *RuleActionRouteOptions) Type() string {
@ -183,6 +178,10 @@ func (r *RuleActionRouteOptions) Type() string {
} }
func (r *RuleActionRouteOptions) String() string { func (r *RuleActionRouteOptions) String() string {
return F.ToString("route-options(", strings.Join(r.Descriptions(), ","), ")")
}
func (r *RuleActionRouteOptions) Descriptions() []string {
var descriptions []string var descriptions []string
if r.OverrideAddress.IsValid() { if r.OverrideAddress.IsValid() {
descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.AddrString())) descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.AddrString()))
@ -211,7 +210,16 @@ func (r *RuleActionRouteOptions) String() string {
if r.UDPTimeout > 0 { if r.UDPTimeout > 0 {
descriptions = append(descriptions, "udp-timeout") descriptions = append(descriptions, "udp-timeout")
} }
return F.ToString("route-options(", strings.Join(descriptions, ","), ")") if r.TLSFragment {
descriptions = append(descriptions, "tls-fragment")
}
if r.TLSFragmentFallbackDelay > 0 {
descriptions = append(descriptions, F.ToString("tls-fragment-fallback-delay=", r.TLSFragmentFallbackDelay.String()))
}
if r.TLSRecordFragment {
descriptions = append(descriptions, "tls-record-fragment")
}
return descriptions
} }
type RuleActionDNSRoute struct { type RuleActionDNSRoute struct {