sing-box/common/ktls/ktls.go
2025-09-08 09:35:55 +08:00

100 lines
2.0 KiB
Go

//go:build linux && go1.25 && !without_badtls
package ktls
import (
"crypto/tls"
"io"
"net"
"os"
"syscall"
"github.com/sagernet/sing-box/common/badtls"
// C "github.com/sagernet/sing-box/constant"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
)
// Conn wraps an established aTLS.Conn whose record layer is (partially)
// handed to the kernel by setupKernel. The embedded aTLS.Conn provides the
// default method set; methods defined on *Conn take precedence.
type Conn struct {
	aTLS.Conn
	// conn is the transport beneath the TLS layer, i.e. the value of
	// NetConn() on the wrapped aTLS.Conn when NewConn ran. Upstream returns it.
	conn net.Conn
	// rawConn gives access to crypto/tls internals through badtls: the
	// negotiated version, buffered raw input, handshake buffer, and the
	// In/Out locks used by ReaderReplaceable/WriterReplaceable.
	rawConn *badtls.RawConn
	// rawSyscallConn is the syscall.RawConn obtained from the transport in
	// NewConn.
	rawSyscallConn syscall.RawConn
	// NOTE(review): not referenced in this chunk; presumably used by a
	// read-waiter implementation elsewhere in the package — confirm.
	readWaitOptions N.ReadWaitOptions
	// kernelTx / kernelRx: presumably set by setupKernel when transmit /
	// receive offload is enabled (setupKernel is not visible here).
	kernelTx bool
	kernelRx bool
	// kernelDidRead is read under rawConn.In's lock by ReaderReplaceable;
	// kernelDidWrite is read under rawConn.Out's lock by WriterReplaceable.
	// Their writers live elsewhere in the package.
	kernelDidRead  bool
	kernelDidWrite bool
}
// NewConn wraps an established TLS 1.3 connection in a *Conn and asks
// setupKernel to enable kernel offload for the selected directions.
//
// txOffload and rxOffload are passed through to setupKernel unchanged.
// NewConn returns an error when:
//   - Load fails;
//   - conn.NetConn() cannot be cast to a reader that also implements
//     syscall.Conn (os.ErrInvalid);
//   - obtaining the syscall.RawConn or the badtls view of conn fails;
//   - the negotiated protocol version is not TLS 1.3 (os.ErrInvalid);
//   - processing already-buffered records or post-handshake messages fails;
//   - setupKernel fails.
func NewConn(conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
	if err := Load(); err != nil {
		return nil, err
	}

	// The transport beneath the TLS layer must expose syscall.Conn; its
	// RawConn is kept in the returned Conn as rawSyscallConn.
	type syscallReader interface {
		io.Reader
		syscall.Conn
	}
	sysConn, ok := N.CastReader[syscallReader](conn.NetConn())
	if !ok {
		return nil, os.ErrInvalid
	}
	sockConn, err := sysConn.SyscallConn()
	if err != nil {
		return nil, err
	}

	tlsInternals, err := badtls.NewRawConn(conn)
	if err != nil {
		return nil, err
	}
	if *tlsInternals.Vers != tls.VersionTLS13 {
		return nil, os.ErrInvalid
	}

	// Process every record crypto/tls has already pulled off the socket,
	// along with any handshake messages those records carry, before the
	// kernel is configured. NOTE(review): presumably this is required so
	// the kernel takes over at a record boundary; confirm against setupKernel.
	for tlsInternals.RawInput.Len() != 0 {
		if err = tlsInternals.ReadRecord(); err != nil {
			return nil, err
		}
		for tlsInternals.Hand.Len() != 0 {
			if err = tlsInternals.HandlePostHandshakeMessage(); err != nil {
				return nil, E.Cause(err, "ktls: failed to handle post-handshake messages")
			}
		}
	}

	wrapped := &Conn{
		Conn:           conn,
		conn:           conn.NetConn(),
		rawConn:        tlsInternals,
		rawSyscallConn: sockConn,
	}
	if err = wrapped.setupKernel(txOffload, rxOffload); err != nil {
		return nil, err
	}
	return wrapped, nil
}
// Upstream returns the transport connection beneath the TLS layer, i.e. the
// net.Conn that NewConn captured from the wrapped connection's NetConn().
func (c *Conn) Upstream() any {
	var upstream any = c.conn
	return upstream
}
// ReaderReplaceable reports whether the read side of c may be replaced by
// its upstream connection (NOTE(review): per sing's network
// ReaderReplaceable convention — confirm). It is true only when kernelRx
// is set and kernelDidRead is not. kernelDidRead is sampled while holding
// the TLS input lock (rawConn.In).
func (c *Conn) ReaderReplaceable() bool {
	if c.kernelRx {
		c.rawConn.In.Lock()
		defer c.rawConn.In.Unlock()
		return !c.kernelDidRead
	}
	return false
}
// WriterReplaceable reports whether the write side of c may be replaced by
// its upstream connection (NOTE(review): per sing's network
// WriterReplaceable convention — confirm). It is true only when kernelTx
// is set and kernelDidWrite is not. kernelDidWrite is sampled while holding
// the TLS output lock (rawConn.Out).
func (c *Conn) WriterReplaceable() bool {
	if c.kernelTx {
		c.rawConn.Out.Lock()
		defer c.rawConn.Out.Unlock()
		return !c.kernelDidWrite
	}
	return false
}