diff --git a/common/badtls/read_wait.go b/common/badtls/read_wait.go index 334bcfa8..9508a7e3 100644 --- a/common/badtls/read_wait.go +++ b/common/badtls/read_wait.go @@ -128,6 +128,10 @@ func (c *ReadWaitConn) Upstream() any { return c.Conn } +func (c *ReadWaitConn) ReaderReplaceable() bool { + return true +} + var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) func init() { diff --git a/common/badtls/read_wait_utls.go b/common/badtls/read_wait_utls.go index bba016e4..1facd30b 100644 --- a/common/badtls/read_wait_utls.go +++ b/common/badtls/read_wait_utls.go @@ -6,22 +6,26 @@ import ( "net" _ "unsafe" - "github.com/sagernet/sing/common" - "github.com/metacubex/utls" ) func init() { tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) { - tlsConn, loaded := common.Cast[*tls.UConn](conn) - if !loaded { - return + switch tlsConn := conn.(type) { + case *tls.UConn: + return true, func() error { + return utlsReadRecord(tlsConn.Conn) + }, func() error { + return utlsHandlePostHandshakeMessage(tlsConn.Conn) + } + case *tls.Conn: + return true, func() error { + return utlsReadRecord(tlsConn) + }, func() error { + return utlsHandlePostHandshakeMessage(tlsConn) + } } - return true, func() error { - return utlsReadRecord(tlsConn.Conn) - }, func() error { - return utlsHandlePostHandshakeMessage(tlsConn.Conn) - } + return }) }