diff --git a/common/tls/client.go b/common/tls/client.go index 5e05c990..d45d6173 100644 --- a/common/tls/client.go +++ b/common/tls/client.go @@ -2,10 +2,11 @@ package tls import ( "context" + "crypto/tls" + "errors" "net" "os" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/badtls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" @@ -14,7 +15,7 @@ import ( aTLS "github.com/sagernet/sing/common/tls" ) -func NewDialerFromOptions(ctx context.Context, router adapter.Router, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) { +func NewDialerFromOptions(ctx context.Context, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) { if !options.Enabled { return dialer, nil } @@ -53,26 +54,57 @@ func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, e return tlsConn, nil } -type Dialer struct { +type Dialer interface { + N.Dialer + DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error) +} + +type defaultDialer struct { dialer N.Dialer config Config } -func NewDialer(dialer N.Dialer, config Config) N.Dialer { - return &Dialer{dialer, config} +func NewDialer(dialer N.Dialer, config Config) Dialer { + return &defaultDialer{dialer, config} } -func (d *Dialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - if network != N.NetworkTCP { +func (d *defaultDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if N.NetworkName(network) != N.NetworkTCP { return nil, os.ErrInvalid } - conn, err := d.dialer.DialContext(ctx, network, destination) + return d.DialTLSContext(ctx, destination) +} + +func (d *defaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return nil, os.ErrInvalid +} + +func (d *defaultDialer) DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error) { + return d.dialContext(ctx, destination, true) +} + +func (d *defaultDialer) dialContext(ctx context.Context, destination M.Socksaddr, echRetry bool) (Conn, error) { + conn, err := d.dialer.DialContext(ctx, N.NetworkTCP, destination) if err != nil { return nil, err } - return ClientHandshake(ctx, conn, d.config) + tlsConn, err := ClientHandshake(ctx, conn, d.config) + if err == nil { + return tlsConn, nil + } + conn.Close() + if echRetry { + var echErr *tls.ECHRejectionError + if errors.As(err, &echErr) && len(echErr.RetryConfigList) > 0 { + if echConfig, isECH := d.config.(ECHCapableConfig); isECH { + echConfig.SetECHConfigList(echErr.RetryConfigList) + } + } + return d.dialContext(ctx, destination, false) + } + return nil, err } -func (d *Dialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return nil, os.ErrInvalid +func (d *defaultDialer) Upstream() any { + return d.dialer } diff --git a/dns/transport/tls.go b/dns/transport/tls.go index afa988cc..9cb35fd9 100644 --- a/dns/transport/tls.go +++ b/dns/transport/tls.go @@ -30,7 +30,7 @@ func RegisterTLS(registry *dns.TransportRegistry) { type TLSTransport struct { dns.TransportAdapter logger logger.ContextLogger - dialer N.Dialer + dialer tls.Dialer serverAddr M.Socksaddr tlsConfig tls.Config access sync.Mutex @@ -67,7 +67,7 @@ func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer return &TLSTransport{ TransportAdapter: adapter, logger: logger, - dialer: dialer, + dialer: tls.NewDialer(dialer, tlsConfig), serverAddr: serverAddr, tlsConfig: tlsConfig, } @@ -100,15 +100,10 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return response, nil } } - tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) + tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr) if err != nil { return nil, err } - tlsConn, err := tls.ClientHandshake(ctx, tcpConn, t.tlsConfig) - if err != nil { - tcpConn.Close() - return nil, err - } return t.exchange(message, &tlsDNSConn{Conn: tlsConn}) } diff --git a/protocol/anytls/outbound.go b/protocol/anytls/outbound.go index b026ed18..ce71ed61 100644 --- a/protocol/anytls/outbound.go +++ b/protocol/anytls/outbound.go @@ -26,7 +26,7 @@ func RegisterOutbound(registry *outbound.Registry) { type Outbound struct { outbound.Adapter - dialer N.Dialer + dialer tls.Dialer server M.Socksaddr tlsConfig tls.Config client *anytls.Client @@ -58,7 +58,8 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } - outbound.dialer = outboundDialer + + outbound.dialer = tls.NewDialer(outboundDialer, tlsConfig) client, err := anytls.NewClient(ctx, anytls.ClientConfig{ Password: options.Password, @@ -91,16 +92,7 @@ func (d anytlsDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) } func (h *Outbound) dialOut(ctx context.Context) (net.Conn, error) { - conn, err := h.dialer.DialContext(ctx, N.NetworkTCP, h.server) - if err != nil { - return nil, err - } - tlsConn, err := tls.ClientHandshake(ctx, conn, h.tlsConfig) - if err != nil { - common.Close(tlsConn, conn) - return nil, err - } - return tlsConn, nil + return h.dialer.DialTLSContext(ctx, h.server) } func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { diff --git a/protocol/http/outbound.go b/protocol/http/outbound.go index 0570dde5..3b631b39 100644 --- a/protocol/http/outbound.go +++ b/protocol/http/outbound.go @@ -34,7 +34,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } - detour, err := tls.NewDialerFromOptions(ctx, router, outboundDialer, options.Server, common.PtrValueOrDefault(options.TLS)) + detour, err := tls.NewDialerFromOptions(ctx, outboundDialer, options.Server, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/protocol/trojan/outbound.go b/protocol/trojan/outbound.go index cd290386..dc2e0fe4 100644 --- a/protocol/trojan/outbound.go +++ b/protocol/trojan/outbound.go @@ -34,6 +34,7 @@ type Outbound struct { key [56]byte multiplexDialer *mux.Client tlsConfig tls.Config + tlsDialer tls.Dialer transport adapter.V2RayClientTransport } @@ -54,6 +55,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } + outbound.tlsDialer = tls.NewDialer(outboundDialer, outbound.tlsConfig) } if options.Transport != nil { outbound.transport, err = v2ray.NewClientTransport(ctx, outbound.dialer, outbound.serverAddr, common.PtrValueOrDefault(options.Transport), outbound.tlsConfig) @@ -121,11 +123,10 @@ func (h *trojanDialer) DialContext(ctx context.Context, network string, destinat var err error if h.transport != nil { conn, err = h.transport.DialContext(ctx) + } else if h.tlsDialer != nil { + conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr) } else { conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr) - if err == nil && h.tlsConfig != nil { - conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig) - } } if err != nil { common.Close(conn) diff --git a/protocol/vless/outbound.go b/protocol/vless/outbound.go index b95a36f7..a96d12a0 100644 --- a/protocol/vless/outbound.go +++ b/protocol/vless/outbound.go @@ -35,6 +35,7 @@ type Outbound struct { serverAddr M.Socksaddr multiplexDialer *mux.Client tlsConfig tls.Config + tlsDialer tls.Dialer transport adapter.V2RayClientTransport packetAddr bool xudp bool @@ -56,6 +57,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } + outbound.tlsDialer = tls.NewDialer(outboundDialer, outbound.tlsConfig) } if options.Transport != nil { outbound.transport, err = v2ray.NewClientTransport(ctx, outbound.dialer, outbound.serverAddr, common.PtrValueOrDefault(options.Transport), outbound.tlsConfig) @@ -140,11 +142,10 @@ func (h *vlessDialer) DialContext(ctx context.Context, network string, destinati var err error if h.transport != nil { conn, err = h.transport.DialContext(ctx) + } else if h.tlsDialer != nil { + conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr) } else { conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr) - if err == nil && h.tlsConfig != nil { - conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig) - } } if err != nil { return nil, err @@ -183,11 +184,10 @@ func (h *vlessDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) var err error if h.transport != nil { conn, err = h.transport.DialContext(ctx) + } else if h.tlsDialer != nil { + conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr) } else { conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr) - if err == nil && h.tlsConfig != nil { - conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig) - } } if err != nil { common.Close(conn) diff --git a/protocol/vmess/outbound.go b/protocol/vmess/outbound.go index bf76ab3d..716570f3 100644 --- a/protocol/vmess/outbound.go +++ b/protocol/vmess/outbound.go @@ -35,6 +35,7 @@ type Outbound struct { serverAddr M.Socksaddr multiplexDialer *mux.Client tlsConfig tls.Config + tlsDialer tls.Dialer transport adapter.V2RayClientTransport packetAddr bool xudp bool @@ -56,6 +57,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } + outbound.tlsDialer = tls.NewDialer(outboundDialer, outbound.tlsConfig) } if options.Transport != nil { outbound.transport, err = v2ray.NewClientTransport(ctx, outbound.dialer, outbound.serverAddr, common.PtrValueOrDefault(options.Transport), outbound.tlsConfig) @@ -154,11 +156,10 @@ func (h *vmessDialer) DialContext(ctx context.Context, network string, destinati var err error if h.transport != nil { conn, err = h.transport.DialContext(ctx) + } else if h.tlsDialer != nil { + conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr) } else { conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr) - if err == nil && h.tlsConfig != nil { - conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig) - } } if err != nil { common.Close(conn) @@ -182,11 +183,10 @@ func (h *vmessDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) var err error if h.transport != nil { conn, err = h.transport.DialContext(ctx) + } else if h.tlsDialer != nil { + conn, err = h.tlsDialer.DialTLSContext(ctx, h.serverAddr) } else { conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr) - if err == nil && h.tlsConfig != nil { - conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig) - } } if err != nil { return nil, err diff --git a/transport/v2raygrpclite/client.go b/transport/v2raygrpclite/client.go index de8915a1..b2aab911 100644 --- a/transport/v2raygrpclite/client.go +++ b/transport/v2raygrpclite/client.go @@ -29,7 +29,6 @@ var defaultClientHeader = http.Header{ type Client struct { ctx context.Context - dialer N.Dialer serverAddr M.Socksaddr transport *http2.Transport options option.V2RayGRPCOptions @@ -46,7 +45,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt } client := &Client{ ctx: ctx, - dialer: dialer, serverAddr: serverAddr, options: options, transport: &http2.Transport{ @@ -62,7 +60,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt }, host: host, } - if tlsConfig == nil { client.transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) { return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) @@ -71,12 +68,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt if len(tlsConfig.NextProtos()) == 0 { tlsConfig.SetNextProtos([]string{http2.NextProtoTLS}) } + tlsDialer := tls.NewDialer(dialer, tlsConfig) client.transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) { - conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - if err != nil { - return nil, err - } - return tls.ClientHandshake(ctx, conn, tlsConfig) + return tlsDialer.DialTLSContext(ctx, M.ParseSocksaddr(addr)) } } diff --git a/transport/v2rayhttp/client.go b/transport/v2rayhttp/client.go index a105a4f3..6c327cd6 100644 --- a/transport/v2rayhttp/client.go +++ b/transport/v2rayhttp/client.go @@ -47,15 +47,12 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt if len(tlsConfig.NextProtos()) == 0 { tlsConfig.SetNextProtos([]string{http2.NextProtoTLS}) } + tlsDialer := tls.NewDialer(dialer, tlsConfig) transport = &http2.Transport{ ReadIdleTimeout: time.Duration(options.IdleTimeout), PingTimeout: time.Duration(options.PingTimeout), DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) { - conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - if err != nil { - return nil, err - } - return tls.ClientHandshake(ctx, conn, tlsConfig) + return tlsDialer.DialTLSContext(ctx, M.ParseSocksaddr(addr)) }, } } diff --git a/transport/v2rayhttpupgrade/client.go b/transport/v2rayhttpupgrade/client.go index e2b86b1f..f282d3f6 100644 --- a/transport/v2rayhttpupgrade/client.go +++ b/transport/v2rayhttpupgrade/client.go @@ -23,7 +23,6 @@ var _ adapter.V2RayClientTransport = (*Client)(nil) type Client struct { dialer N.Dialer - tlsConfig tls.Config serverAddr M.Socksaddr requestURL url.URL headers http.Header @@ -35,6 +34,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt if len(tlsConfig.NextProtos()) == 0 { tlsConfig.SetNextProtos([]string{"http/1.1"}) } + dialer = tls.NewDialer(dialer, tlsConfig) } var host string if options.Host != "" { @@ -65,7 +65,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt } return &Client{ dialer: dialer, - tlsConfig: tlsConfig, serverAddr: serverAddr, requestURL: requestURL, headers: headers, @@ -78,12 +77,6 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { if err != nil { return nil, err } - if c.tlsConfig != nil { - conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig) - if err != nil { - return nil, err - } - } request := &http.Request{ Method: http.MethodGet, URL: &c.requestURL, diff --git a/transport/v2raywebsocket/client.go b/transport/v2raywebsocket/client.go index 748bae4c..e5630109 100644 --- a/transport/v2raywebsocket/client.go +++ b/transport/v2raywebsocket/client.go @@ -26,7 +26,6 @@ var _ adapter.V2RayClientTransport = (*Client)(nil) type Client struct { dialer N.Dialer - tlsConfig tls.Config serverAddr M.Socksaddr requestURL url.URL headers http.Header @@ -39,6 +38,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt if len(tlsConfig.NextProtos()) == 0 { tlsConfig.SetNextProtos([]string{"http/1.1"}) } + dialer = tls.NewDialer(dialer, tlsConfig) } var requestURL url.URL if tlsConfig == nil { @@ -65,7 +65,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt } return &Client{ dialer, - tlsConfig, serverAddr, requestURL, headers, @@ -79,12 +78,6 @@ func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers h if err != nil { return nil, err } - if c.tlsConfig != nil { - conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig) - if err != nil { - return nil, err - } - } var deadlineConn net.Conn if deadline.NeedAdditionalReadDeadline(conn) { deadlineConn = deadline.NewConn(conn)