mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-08-26 20:17:35 +08:00
Add TLS record fragment support
This commit is contained in:
parent
c2b7d5bd12
commit
0be2233795
@ -74,6 +74,7 @@ type InboundContext struct {
|
||||
UDPTimeout time.Duration
|
||||
TLSFragment bool
|
||||
TLSFragmentFallbackDelay time.Duration
|
||||
TLSRecordFragment bool
|
||||
|
||||
NetworkStrategy *C.NetworkStrategy
|
||||
NetworkType []C.InterfaceType
|
||||
|
@ -1,7 +1,9 @@
|
||||
package tf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strings"
|
||||
@ -17,17 +19,19 @@ type Conn struct {
|
||||
tcpConn *net.TCPConn
|
||||
ctx context.Context
|
||||
firstPacketWritten bool
|
||||
splitRecord bool
|
||||
fallbackDelay time.Duration
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, ctx context.Context, fallbackDelay time.Duration) (*Conn, error) {
|
||||
func NewConn(conn net.Conn, ctx context.Context, splitRecord bool, fallbackDelay time.Duration) *Conn {
|
||||
tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn)
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
tcpConn: tcpConn,
|
||||
ctx: ctx,
|
||||
splitRecord: splitRecord,
|
||||
fallbackDelay: fallbackDelay,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
@ -37,10 +41,12 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
}()
|
||||
serverName := indexTLSServerName(b)
|
||||
if serverName != nil {
|
||||
if c.tcpConn != nil {
|
||||
err = c.tcpConn.SetNoDelay(true)
|
||||
if err != nil {
|
||||
return
|
||||
if !c.splitRecord {
|
||||
if c.tcpConn != nil {
|
||||
err = c.tcpConn.SetNoDelay(true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
splits := strings.Split(serverName.ServerName, ".")
|
||||
@ -61,16 +67,25 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
currentIndex++
|
||||
}
|
||||
}
|
||||
var buffer bytes.Buffer
|
||||
for i := 0; i <= len(splitIndexes); i++ {
|
||||
var payload []byte
|
||||
if i == 0 {
|
||||
payload = b[:splitIndexes[i]]
|
||||
if c.splitRecord {
|
||||
payload = payload[recordLayerHeaderLen:]
|
||||
}
|
||||
} else if i == len(splitIndexes) {
|
||||
payload = b[splitIndexes[i-1]:]
|
||||
} else {
|
||||
payload = b[splitIndexes[i-1]:splitIndexes[i]]
|
||||
}
|
||||
if c.tcpConn != nil && i != len(splitIndexes) {
|
||||
if c.splitRecord {
|
||||
payloadLen := uint16(len(payload))
|
||||
buffer.Write(b[:3])
|
||||
binary.Write(&buffer, binary.BigEndian, payloadLen)
|
||||
buffer.Write(payload)
|
||||
} else if c.tcpConn != nil && i != len(splitIndexes) {
|
||||
err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay)
|
||||
if err != nil {
|
||||
return
|
||||
@ -82,11 +97,18 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if c.tcpConn != nil {
|
||||
err = c.tcpConn.SetNoDelay(false)
|
||||
if c.splitRecord {
|
||||
_, err = c.tcpConn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if c.tcpConn != nil {
|
||||
err = c.tcpConn.SetNoDelay(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
32
common/tlsfragment/conn_test.go
Normal file
32
common/tlsfragment/conn_test.go
Normal file
@ -0,0 +1,32 @@
|
||||
package tf_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
tf "github.com/sagernet/sing-box/common/tlsfragment"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTLSFragment(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcpConn, err := net.Dial("tcp", "1.1.1.1:443")
|
||||
require.NoError(t, err)
|
||||
tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), false, 0), &tls.Config{
|
||||
ServerName: "www.cloudflare.com",
|
||||
})
|
||||
require.NoError(t, tlsConn.Handshake())
|
||||
}
|
||||
|
||||
func TestTLSRecordFragment(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcpConn, err := net.Dial("tcp", "1.1.1.1:443")
|
||||
require.NoError(t, err)
|
||||
tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), true, 0), &tls.Config{
|
||||
ServerName: "www.cloudflare.com",
|
||||
})
|
||||
require.NoError(t, tlsConn.Handshake())
|
||||
}
|
@ -158,6 +158,7 @@ type RawRouteOptionsActionOptions struct {
|
||||
|
||||
TLSFragment bool `json:"tls_fragment,omitempty"`
|
||||
TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"`
|
||||
TLSRecordFragment bool `json:"tls_record_fragment,omitempty"`
|
||||
}
|
||||
|
||||
type RouteOptionsActionOptions RawRouteOptionsActionOptions
|
||||
@ -170,6 +171,9 @@ func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error {
|
||||
if *r == (RouteOptionsActionOptions{}) {
|
||||
return E.New("empty route option action")
|
||||
}
|
||||
if r.TLSFragment && r.TLSRecordFragment {
|
||||
return E.New("`tls_fragment` and `tls_record_fragment` are mutually exclusive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -95,15 +95,9 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
|
||||
if fallbackDelay == 0 {
|
||||
fallbackDelay = C.TLSFragmentFallbackDelay
|
||||
}
|
||||
var newConn *tf.Conn
|
||||
newConn, err = tf.NewConn(remoteConn, ctx, fallbackDelay)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
remoteConn.Close()
|
||||
m.logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
remoteConn = newConn
|
||||
remoteConn = tf.NewConn(remoteConn, ctx, false, fallbackDelay)
|
||||
} else if metadata.TLSRecordFragment {
|
||||
remoteConn = tf.NewConn(remoteConn, ctx, true, 0)
|
||||
}
|
||||
m.access.Lock()
|
||||
element := m.connections.PushBack(conn)
|
||||
|
@ -40,6 +40,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
|
||||
UDPConnect: action.RouteOptions.UDPConnect,
|
||||
TLSFragment: action.RouteOptions.TLSFragment,
|
||||
TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay),
|
||||
TLSRecordFragment: action.RouteOptions.TLSRecordFragment,
|
||||
},
|
||||
}, nil
|
||||
case C.RuleActionTypeRouteOptions:
|
||||
@ -53,6 +54,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
|
||||
UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout),
|
||||
TLSFragment: action.RouteOptionsOptions.TLSFragment,
|
||||
TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay),
|
||||
TLSRecordFragment: action.RouteOptionsOptions.TLSRecordFragment,
|
||||
}, nil
|
||||
case C.RuleActionTypeDirect:
|
||||
directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false)
|
||||
@ -152,15 +154,7 @@ func (r *RuleActionRoute) Type() string {
|
||||
func (r *RuleActionRoute) String() string {
|
||||
var descriptions []string
|
||||
descriptions = append(descriptions, r.Outbound)
|
||||
if r.UDPDisableDomainUnmapping {
|
||||
descriptions = append(descriptions, "udp-disable-domain-unmapping")
|
||||
}
|
||||
if r.UDPConnect {
|
||||
descriptions = append(descriptions, "udp-connect")
|
||||
}
|
||||
if r.TLSFragment {
|
||||
descriptions = append(descriptions, "tls-fragment")
|
||||
}
|
||||
descriptions = append(descriptions, r.Descriptions()...)
|
||||
return F.ToString("route(", strings.Join(descriptions, ","), ")")
|
||||
}
|
||||
|
||||
@ -176,6 +170,7 @@ type RuleActionRouteOptions struct {
|
||||
UDPTimeout time.Duration
|
||||
TLSFragment bool
|
||||
TLSFragmentFallbackDelay time.Duration
|
||||
TLSRecordFragment bool
|
||||
}
|
||||
|
||||
func (r *RuleActionRouteOptions) Type() string {
|
||||
@ -183,6 +178,10 @@ func (r *RuleActionRouteOptions) Type() string {
|
||||
}
|
||||
|
||||
func (r *RuleActionRouteOptions) String() string {
|
||||
return F.ToString("route-options(", strings.Join(r.Descriptions(), ","), ")")
|
||||
}
|
||||
|
||||
func (r *RuleActionRouteOptions) Descriptions() []string {
|
||||
var descriptions []string
|
||||
if r.OverrideAddress.IsValid() {
|
||||
descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.AddrString()))
|
||||
@ -211,7 +210,16 @@ func (r *RuleActionRouteOptions) String() string {
|
||||
if r.UDPTimeout > 0 {
|
||||
descriptions = append(descriptions, "udp-timeout")
|
||||
}
|
||||
return F.ToString("route-options(", strings.Join(descriptions, ","), ")")
|
||||
if r.TLSFragment {
|
||||
descriptions = append(descriptions, "tls-fragment")
|
||||
}
|
||||
if r.TLSFragmentFallbackDelay > 0 {
|
||||
descriptions = append(descriptions, F.ToString("tls-fragment-fallback-delay=", r.TLSFragmentFallbackDelay.String()))
|
||||
}
|
||||
if r.TLSRecordFragment {
|
||||
descriptions = append(descriptions, "tls-record-fragment")
|
||||
}
|
||||
return descriptions
|
||||
}
|
||||
|
||||
type RuleActionDNSRoute struct {
|
||||
|
Loading…
x
Reference in New Issue
Block a user