Add TLS record fragment support

This commit is contained in:
世界 2025-05-12 12:03:34 +08:00
parent c2b7d5bd12
commit 0be2233795
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
6 changed files with 89 additions and 28 deletions

View File

@ -74,6 +74,7 @@ type InboundContext struct {
UDPTimeout time.Duration
TLSFragment bool
TLSFragmentFallbackDelay time.Duration
TLSRecordFragment bool
NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType

View File

@ -1,7 +1,9 @@
package tf
import (
"bytes"
"context"
"encoding/binary"
"math/rand"
"net"
"strings"
@ -17,17 +19,19 @@ type Conn struct {
tcpConn *net.TCPConn
ctx context.Context
firstPacketWritten bool
splitRecord bool
fallbackDelay time.Duration
}
func NewConn(conn net.Conn, ctx context.Context, fallbackDelay time.Duration) (*Conn, error) {
func NewConn(conn net.Conn, ctx context.Context, splitRecord bool, fallbackDelay time.Duration) *Conn {
tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn)
return &Conn{
Conn: conn,
tcpConn: tcpConn,
ctx: ctx,
splitRecord: splitRecord,
fallbackDelay: fallbackDelay,
}, nil
}
}
func (c *Conn) Write(b []byte) (n int, err error) {
@ -37,12 +41,14 @@ func (c *Conn) Write(b []byte) (n int, err error) {
}()
serverName := indexTLSServerName(b)
if serverName != nil {
if !c.splitRecord {
if c.tcpConn != nil {
err = c.tcpConn.SetNoDelay(true)
if err != nil {
return
}
}
}
splits := strings.Split(serverName.ServerName, ".")
currentIndex := serverName.Index
if publicSuffix := publicsuffix.List.PublicSuffix(serverName.ServerName); publicSuffix != "" {
@ -61,16 +67,25 @@ func (c *Conn) Write(b []byte) (n int, err error) {
currentIndex++
}
}
var buffer bytes.Buffer
for i := 0; i <= len(splitIndexes); i++ {
var payload []byte
if i == 0 {
payload = b[:splitIndexes[i]]
if c.splitRecord {
payload = payload[recordLayerHeaderLen:]
}
} else if i == len(splitIndexes) {
payload = b[splitIndexes[i-1]:]
} else {
payload = b[splitIndexes[i-1]:splitIndexes[i]]
}
if c.tcpConn != nil && i != len(splitIndexes) {
if c.splitRecord {
payloadLen := uint16(len(payload))
buffer.Write(b[:3])
binary.Write(&buffer, binary.BigEndian, payloadLen)
buffer.Write(payload)
} else if c.tcpConn != nil && i != len(splitIndexes) {
err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay)
if err != nil {
return
@ -82,12 +97,19 @@ func (c *Conn) Write(b []byte) (n int, err error) {
}
}
}
if c.splitRecord {
_, err = c.tcpConn.Write(buffer.Bytes())
if err != nil {
return
}
} else {
if c.tcpConn != nil {
err = c.tcpConn.SetNoDelay(false)
if err != nil {
return
}
}
}
return len(b), nil
}
}

View File

@ -0,0 +1,32 @@
package tf_test
import (
"context"
"crypto/tls"
"net"
"testing"
tf "github.com/sagernet/sing-box/common/tlsfragment"
"github.com/stretchr/testify/require"
)
func TestTLSFragment(t *testing.T) {
t.Parallel()
tcpConn, err := net.Dial("tcp", "1.1.1.1:443")
require.NoError(t, err)
tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), false, 0), &tls.Config{
ServerName: "www.cloudflare.com",
})
require.NoError(t, tlsConn.Handshake())
}
func TestTLSRecordFragment(t *testing.T) {
t.Parallel()
tcpConn, err := net.Dial("tcp", "1.1.1.1:443")
require.NoError(t, err)
tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), true, 0), &tls.Config{
ServerName: "www.cloudflare.com",
})
require.NoError(t, tlsConn.Handshake())
}

View File

@ -158,6 +158,7 @@ type RawRouteOptionsActionOptions struct {
TLSFragment bool `json:"tls_fragment,omitempty"`
TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"`
TLSRecordFragment bool `json:"tls_record_fragment,omitempty"`
}
type RouteOptionsActionOptions RawRouteOptionsActionOptions
@ -170,6 +171,9 @@ func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error {
if *r == (RouteOptionsActionOptions{}) {
return E.New("empty route option action")
}
if r.TLSFragment && r.TLSRecordFragment {
return E.New("`tls_fragment` and `tls_record_fragment` are mutually exclusive")
}
return nil
}

View File

@ -95,15 +95,9 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
if fallbackDelay == 0 {
fallbackDelay = C.TLSFragmentFallbackDelay
}
var newConn *tf.Conn
newConn, err = tf.NewConn(remoteConn, ctx, fallbackDelay)
if err != nil {
conn.Close()
remoteConn.Close()
m.logger.ErrorContext(ctx, err)
return
}
remoteConn = newConn
remoteConn = tf.NewConn(remoteConn, ctx, false, fallbackDelay)
} else if metadata.TLSRecordFragment {
remoteConn = tf.NewConn(remoteConn, ctx, true, 0)
}
m.access.Lock()
element := m.connections.PushBack(conn)

View File

@ -40,6 +40,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPConnect: action.RouteOptions.UDPConnect,
TLSFragment: action.RouteOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay),
TLSRecordFragment: action.RouteOptions.TLSRecordFragment,
},
}, nil
case C.RuleActionTypeRouteOptions:
@ -53,6 +54,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout),
TLSFragment: action.RouteOptionsOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay),
TLSRecordFragment: action.RouteOptionsOptions.TLSRecordFragment,
}, nil
case C.RuleActionTypeDirect:
directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false)
@ -152,15 +154,7 @@ func (r *RuleActionRoute) Type() string {
func (r *RuleActionRoute) String() string {
var descriptions []string
descriptions = append(descriptions, r.Outbound)
if r.UDPDisableDomainUnmapping {
descriptions = append(descriptions, "udp-disable-domain-unmapping")
}
if r.UDPConnect {
descriptions = append(descriptions, "udp-connect")
}
if r.TLSFragment {
descriptions = append(descriptions, "tls-fragment")
}
descriptions = append(descriptions, r.Descriptions()...)
return F.ToString("route(", strings.Join(descriptions, ","), ")")
}
@ -176,6 +170,7 @@ type RuleActionRouteOptions struct {
UDPTimeout time.Duration
TLSFragment bool
TLSFragmentFallbackDelay time.Duration
TLSRecordFragment bool
}
func (r *RuleActionRouteOptions) Type() string {
@ -183,6 +178,10 @@ func (r *RuleActionRouteOptions) Type() string {
}
func (r *RuleActionRouteOptions) String() string {
return F.ToString("route-options(", strings.Join(r.Descriptions(), ","), ")")
}
func (r *RuleActionRouteOptions) Descriptions() []string {
var descriptions []string
if r.OverrideAddress.IsValid() {
descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.AddrString()))
@ -211,7 +210,16 @@ func (r *RuleActionRouteOptions) String() string {
if r.UDPTimeout > 0 {
descriptions = append(descriptions, "udp-timeout")
}
return F.ToString("route-options(", strings.Join(descriptions, ","), ")")
if r.TLSFragment {
descriptions = append(descriptions, "tls-fragment")
}
if r.TLSFragmentFallbackDelay > 0 {
descriptions = append(descriptions, F.ToString("tls-fragment-fallback-delay=", r.TLSFragmentFallbackDelay.String()))
}
if r.TLSRecordFragment {
descriptions = append(descriptions, "tls-record-fragment")
}
return descriptions
}
type RuleActionDNSRoute struct {