diff --git a/adapter/inbound.go b/adapter/inbound.go index 1218c049..fd07ff5c 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -74,6 +74,7 @@ type InboundContext struct { UDPTimeout time.Duration TLSFragment bool TLSFragmentFallbackDelay time.Duration + TLSRecordFragment bool NetworkStrategy *C.NetworkStrategy NetworkType []C.InterfaceType diff --git a/common/tlsfragment/conn.go b/common/tlsfragment/conn.go index 6f2a3dad..c224b689 100644 --- a/common/tlsfragment/conn.go +++ b/common/tlsfragment/conn.go @@ -1,7 +1,9 @@ package tf import ( + "bytes" "context" + "encoding/binary" "math/rand" "net" "strings" @@ -17,17 +19,19 @@ type Conn struct { tcpConn *net.TCPConn ctx context.Context firstPacketWritten bool + splitRecord bool fallbackDelay time.Duration } -func NewConn(conn net.Conn, ctx context.Context, fallbackDelay time.Duration) (*Conn, error) { +func NewConn(conn net.Conn, ctx context.Context, splitRecord bool, fallbackDelay time.Duration) *Conn { tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn) return &Conn{ Conn: conn, tcpConn: tcpConn, ctx: ctx, + splitRecord: splitRecord, fallbackDelay: fallbackDelay, - }, nil + } } func (c *Conn) Write(b []byte) (n int, err error) { @@ -37,10 +41,12 @@ func (c *Conn) Write(b []byte) (n int, err error) { }() serverName := indexTLSServerName(b) if serverName != nil { - if c.tcpConn != nil { - err = c.tcpConn.SetNoDelay(true) - if err != nil { - return + if !c.splitRecord { + if c.tcpConn != nil { + err = c.tcpConn.SetNoDelay(true) + if err != nil { + return + } } } splits := strings.Split(serverName.ServerName, ".") @@ -61,16 +67,25 @@ func (c *Conn) Write(b []byte) (n int, err error) { currentIndex++ } } + var buffer bytes.Buffer for i := 0; i <= len(splitIndexes); i++ { var payload []byte if i == 0 { payload = b[:splitIndexes[i]] + if c.splitRecord { + payload = payload[recordLayerHeaderLen:] + } } else if i == len(splitIndexes) { payload = b[splitIndexes[i-1]:] } else { payload = b[splitIndexes[i-1]:splitIndexes[i]] } - if c.tcpConn != nil && i != len(splitIndexes) { + if c.splitRecord { + payloadLen := uint16(len(payload)) + buffer.Write(b[:3]) + binary.Write(&buffer, binary.BigEndian, payloadLen) + buffer.Write(payload) + } else if c.tcpConn != nil && i != len(splitIndexes) { err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay) if err != nil { return @@ -82,11 +97,18 @@ func (c *Conn) Write(b []byte) (n int, err error) { } } } - if c.tcpConn != nil { - err = c.tcpConn.SetNoDelay(false) + if c.splitRecord { + _, err = c.tcpConn.Write(buffer.Bytes()) if err != nil { return } + } else { + if c.tcpConn != nil { + err = c.tcpConn.SetNoDelay(false) + if err != nil { + return + } + } } return len(b), nil } diff --git a/common/tlsfragment/conn_test.go b/common/tlsfragment/conn_test.go new file mode 100644 index 00000000..21e2fcb2 --- /dev/null +++ b/common/tlsfragment/conn_test.go @@ -0,0 +1,32 @@ +package tf_test + +import ( + "context" + "crypto/tls" + "net" + "testing" + + tf "github.com/sagernet/sing-box/common/tlsfragment" + + "github.com/stretchr/testify/require" +) + +func TestTLSFragment(t *testing.T) { + t.Parallel() + tcpConn, err := net.Dial("tcp", "1.1.1.1:443") + require.NoError(t, err) + tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), false, 0), &tls.Config{ + ServerName: "www.cloudflare.com", + }) + require.NoError(t, tlsConn.Handshake()) +} + +func TestTLSRecordFragment(t *testing.T) { + t.Parallel() + tcpConn, err := net.Dial("tcp", "1.1.1.1:443") + require.NoError(t, err) + tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), true, 0), &tls.Config{ + ServerName: "www.cloudflare.com", + }) + require.NoError(t, tlsConn.Handshake()) +} diff --git a/option/rule_action.go b/option/rule_action.go index 7c05dce6..914edb84 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -158,6 +158,7 @@ type RawRouteOptionsActionOptions struct { TLSFragment bool `json:"tls_fragment,omitempty"` TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"` + TLSRecordFragment bool `json:"tls_record_fragment,omitempty"` } type RouteOptionsActionOptions RawRouteOptionsActionOptions @@ -170,6 +171,9 @@ func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error { if *r == (RouteOptionsActionOptions{}) { return E.New("empty route option action") } + if r.TLSFragment && r.TLSRecordFragment { + return E.New("`tls_fragment` and `tls_record_fragment` are mutually exclusive") + } return nil } diff --git a/route/conn.go b/route/conn.go index f46283ad..d5f914b8 100644 --- a/route/conn.go +++ b/route/conn.go @@ -95,15 +95,9 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co if fallbackDelay == 0 { fallbackDelay = C.TLSFragmentFallbackDelay } - var newConn *tf.Conn - newConn, err = tf.NewConn(remoteConn, ctx, fallbackDelay) - if err != nil { - conn.Close() - remoteConn.Close() - m.logger.ErrorContext(ctx, err) - return - } - remoteConn = newConn + remoteConn = tf.NewConn(remoteConn, ctx, false, fallbackDelay) + } else if metadata.TLSRecordFragment { + remoteConn = tf.NewConn(remoteConn, ctx, true, 0) } m.access.Lock() element := m.connections.PushBack(conn) diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index 098b9d3a..a2fcf911 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -40,6 +40,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti UDPConnect: action.RouteOptions.UDPConnect, TLSFragment: action.RouteOptions.TLSFragment, TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay), + TLSRecordFragment: action.RouteOptions.TLSRecordFragment, }, }, nil case C.RuleActionTypeRouteOptions: @@ -53,6 +54,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout), TLSFragment: action.RouteOptionsOptions.TLSFragment, TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay), + TLSRecordFragment: action.RouteOptionsOptions.TLSRecordFragment, }, nil case C.RuleActionTypeDirect: directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false) @@ -152,15 +154,7 @@ func (r *RuleActionRoute) Type() string { func (r *RuleActionRoute) String() string { var descriptions []string descriptions = append(descriptions, r.Outbound) - if r.UDPDisableDomainUnmapping { - descriptions = append(descriptions, "udp-disable-domain-unmapping") - } - if r.UDPConnect { - descriptions = append(descriptions, "udp-connect") - } - if r.TLSFragment { - descriptions = append(descriptions, "tls-fragment") - } + descriptions = append(descriptions, r.Descriptions()...) return F.ToString("route(", strings.Join(descriptions, ","), ")") } @@ -176,6 +170,7 @@ type RuleActionRouteOptions struct { UDPTimeout time.Duration TLSFragment bool TLSFragmentFallbackDelay time.Duration + TLSRecordFragment bool } func (r *RuleActionRouteOptions) Type() string { @@ -183,6 +178,10 @@ func (r *RuleActionRouteOptions) Type() string { } func (r *RuleActionRouteOptions) String() string { + return F.ToString("route-options(", strings.Join(r.Descriptions(), ","), ")") +} + +func (r *RuleActionRouteOptions) Descriptions() []string { var descriptions []string if r.OverrideAddress.IsValid() { descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.AddrString())) @@ -211,7 +210,16 @@ func (r *RuleActionRouteOptions) String() string { if r.UDPTimeout > 0 { descriptions = append(descriptions, "udp-timeout") } - return F.ToString("route-options(", strings.Join(descriptions, ","), ")") + if r.TLSFragment { + descriptions = append(descriptions, "tls-fragment") + } + if r.TLSFragmentFallbackDelay > 0 { + descriptions = append(descriptions, F.ToString("tls-fragment-fallback-delay=", r.TLSFragmentFallbackDelay.String())) + } + if r.TLSRecordFragment { + descriptions = append(descriptions, "tls-record-fragment") + } + return descriptions } type RuleActionDNSRoute struct {