From 6e6ab676019d007ef7f972a36960fe60e680f043 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 7 Sep 2025 21:03:32 +0800 Subject: [PATCH] Add support for kTLS Reference: https://gitlab.com/go-extension/tls --- .github/workflows/build.yml | 4 +- Dockerfile | 2 +- cmd/internal/build_libbox/main.go | 8 +- common/badtls/raw_conn.go | 169 ++++++++++++ common/badtls/raw_half_conn.go | 121 +++++++++ common/badtls/read_wait.go | 122 ++------- common/badtls/read_wait_stub.go | 2 +- common/badtls/read_wait_utls.go | 36 --- common/badtls/registry.go | 62 +++++ common/badtls/registry_utls.go | 56 ++++ common/ktls/ktls.go | 84 ++++++ common/ktls/ktls_alert.go | 80 ++++++ common/ktls/ktls_cipher_suites_linux.go | 326 ++++++++++++++++++++++++ common/ktls/ktls_close.go | 67 +++++ common/ktls/ktls_const.go | 24 ++ common/ktls/ktls_handshake_messages.go | 238 +++++++++++++++++ common/ktls/ktls_key_update.go | 173 +++++++++++++ common/ktls/ktls_linux.go | 311 ++++++++++++++++++++++ common/ktls/ktls_prf.go | 24 ++ common/ktls/ktls_read.go | 292 +++++++++++++++++++++ common/ktls/ktls_read_wait.go | 41 +++ common/ktls/ktls_stub.go | 13 + common/ktls/ktls_write.go | 154 +++++++++++ common/tls/client.go | 8 + common/tls/config.go | 6 + common/tls/server.go | 8 + common/tls/std_client.go | 29 ++- common/tls/std_server.go | 33 ++- common/tls/utls_client.go | 19 +- go.mod | 3 +- go.sum | 6 +- option/tls.go | 4 + release/local/debug.sh | 2 +- release/local/install.sh | 2 +- release/local/reinstall.sh | 2 +- route/conn.go | 36 +-- transport/trojan/protocol.go | 8 + 37 files changed, 2397 insertions(+), 178 deletions(-) create mode 100644 common/badtls/raw_conn.go create mode 100644 common/badtls/raw_half_conn.go delete mode 100644 common/badtls/read_wait_utls.go create mode 100644 common/badtls/registry.go create mode 100644 common/badtls/registry_utls.go create mode 100644 common/ktls/ktls.go create mode 100644 common/ktls/ktls_alert.go create mode 100644 common/ktls/ktls_cipher_suites_linux.go create mode 100644 common/ktls/ktls_close.go create mode 100644 common/ktls/ktls_const.go create mode 100644 common/ktls/ktls_handshake_messages.go create mode 100644 common/ktls/ktls_key_update.go create mode 100644 common/ktls/ktls_linux.go create mode 100644 common/ktls/ktls_prf.go create mode 100644 common/ktls/ktls_read.go create mode 100644 common/ktls/ktls_read_wait.go create mode 100644 common/ktls/ktls_stub.go create mode 100644 common/ktls/ktls_write.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 25114dff..7983a0a4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -154,7 +154,7 @@ jobs: set -xeuo pipefail mkdir -p dist go build -v -trimpath -o dist/sing-box -tags "${BUILD_TAGS}" \ - -ldflags '-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }}' \ + -ldflags '-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }} -checklinkname=0' \ ./cmd/sing-box env: CGO_ENABLED: "0" @@ -174,7 +174,7 @@ jobs: export CXX="${CC}++" mkdir -p dist GOOS=$BUILD_GOOS GOARCH=$BUILD_GOARCH build go build -v -trimpath -o dist/sing-box -tags "${BUILD_TAGS}" \ - -ldflags '-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }}' \ + -ldflags '-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }} -checklinkname=0' \ ./cmd/sing-box env: CGO_ENABLED: "1" diff --git a/Dockerfile b/Dockerfile index 359ae293..1c8dbec9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ RUN set -ex \ && go build -v -trimpath -tags \ "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale" \ -o /go/bin/sing-box \ - -ldflags "-X \"github.com/sagernet/sing-box/constant.Version=$VERSION\" -s -w -buildid=" \ + -ldflags "-X \"github.com/sagernet/sing-box/constant.Version=$VERSION\" -s -w -buildid= -checklinkname=0" \ ./cmd/sing-box FROM --platform=$TARGETPLATFORM alpine AS dist LABEL maintainer="nekohasekai " diff --git a/cmd/internal/build_libbox/main.go b/cmd/internal/build_libbox/main.go index 0f4db850..483bebed 100644 --- a/cmd/internal/build_libbox/main.go +++ b/cmd/internal/build_libbox/main.go @@ -59,8 +59,8 @@ func init() { if err != nil { currentTag = "unknown" } - sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -s -w -buildid=") - debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag) + sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -s -w -buildid= -checklinkname=0") + debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -checklinkname=0") sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_utls", "with_clash_api", "with_conntrack") macOSTags = append(macOSTags, "with_dhcp") @@ -107,10 +107,10 @@ func buildAndroid() { } if !debugEnabled { - sharedFlags[3] = sharedFlags[3] + " -checklinkname=0" + // sharedFlags[3] = sharedFlags[3] + " -checklinkname=0" args = append(args, sharedFlags...) } else { - debugFlags[1] = debugFlags[1] + " -checklinkname=0" + // debugFlags[1] = debugFlags[1] + " -checklinkname=0" args = append(args, debugFlags...) } diff --git a/common/badtls/raw_conn.go b/common/badtls/raw_conn.go new file mode 100644 index 00000000..3a60bdbe --- /dev/null +++ b/common/badtls/raw_conn.go @@ -0,0 +1,169 @@ +//go:build go1.25 && !without_badtls + +package badtls + +import ( + "bytes" + "os" + "reflect" + "sync/atomic" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/tls" +) + +type RawConn struct { + pointer unsafe.Pointer + methods *Methods + + IsClient *bool + IsHandshakeComplete *atomic.Bool + Vers *uint16 + CipherSuite *uint16 + + RawInput *bytes.Buffer + Input *bytes.Reader + Hand *bytes.Buffer + + CloseNotifySent *bool + CloseNotifyErr *error + + In *RawHalfConn + Out *RawHalfConn + + BytesSent *int64 + PacketsSent *int64 + + ActiveCall *atomic.Int32 +} + +func NewRawConn(rawTLSConn tls.Conn) (*RawConn, error) { + var ( + pointer unsafe.Pointer + methods *Methods + loaded bool + ) + for _, tlsCreator := range methodRegistry { + pointer, methods, loaded = tlsCreator(rawTLSConn) + if loaded { + break + } + } + if !loaded { + return nil, os.ErrInvalid + } + + conn := &RawConn{ + pointer: pointer, + methods: methods, + } + + rawConn := reflect.Indirect(reflect.ValueOf(rawTLSConn)) + + rawIsClient := rawConn.FieldByName("isClient") + if !rawIsClient.IsValid() || rawIsClient.Kind() != reflect.Bool { + return nil, E.New("invalid Conn.isClient") + } + conn.IsClient = (*bool)(unsafe.Pointer(rawIsClient.UnsafeAddr())) + + rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete") + if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct { + return nil, E.New("invalid Conn.isHandshakeComplete") + } + conn.IsHandshakeComplete = (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr())) + + rawVers := rawConn.FieldByName("vers") + if !rawVers.IsValid() || rawVers.Kind() != reflect.Uint16 { + return nil, E.New("invalid Conn.vers") + } + conn.Vers = (*uint16)(unsafe.Pointer(rawVers.UnsafeAddr())) + + rawCipherSuite := rawConn.FieldByName("cipherSuite") + if !rawCipherSuite.IsValid() || rawCipherSuite.Kind() != reflect.Uint16 { + return nil, E.New("invalid Conn.cipherSuite") + } + conn.CipherSuite = (*uint16)(unsafe.Pointer(rawCipherSuite.UnsafeAddr())) + + rawRawInput := rawConn.FieldByName("rawInput") + if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct { + return nil, E.New("invalid Conn.rawInput") + } + conn.RawInput = (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr())) + + rawInput := rawConn.FieldByName("input") + if !rawInput.IsValid() || rawInput.Kind() != reflect.Struct { + return nil, E.New("invalid Conn.input") + } + conn.Input = (*bytes.Reader)(unsafe.Pointer(rawInput.UnsafeAddr())) + + rawHand := rawConn.FieldByName("hand") + if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct { + return nil, E.New("invalid Conn.hand") + } + conn.Hand = (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr())) + + rawCloseNotifySent := rawConn.FieldByName("closeNotifySent") + if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool { + return nil, E.New("invalid Conn.closeNotifySent") + } + conn.CloseNotifySent = (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr())) + + rawCloseNotifyErr := rawConn.FieldByName("closeNotifyErr") + if !rawCloseNotifyErr.IsValid() || rawCloseNotifyErr.Kind() != reflect.Interface { + return nil, E.New("invalid Conn.closeNotifyErr") + } + conn.CloseNotifyErr = (*error)(unsafe.Pointer(rawCloseNotifyErr.UnsafeAddr())) + + rawIn := rawConn.FieldByName("in") + if !rawIn.IsValid() || rawIn.Kind() != reflect.Struct { + return nil, E.New("invalid Conn.in") + } + halfIn, err := NewRawHalfConn(rawIn, methods) + if err != nil { + return nil, E.Cause(err, "invalid Conn.in") + } + conn.In = halfIn + + rawOut := rawConn.FieldByName("out") + if !rawOut.IsValid() || rawOut.Kind() != reflect.Struct { + return nil, E.New("invalid Conn.out") + } + halfOut, err := NewRawHalfConn(rawOut, methods) + if err != nil { + return nil, E.Cause(err, "invalid Conn.out") + } + conn.Out = halfOut + + rawBytesSent := rawConn.FieldByName("bytesSent") + if !rawBytesSent.IsValid() || rawBytesSent.Kind() != reflect.Int64 { + return nil, E.New("invalid Conn.bytesSent") + } + conn.BytesSent = (*int64)(unsafe.Pointer(rawBytesSent.UnsafeAddr())) + + rawPacketsSent := rawConn.FieldByName("packetsSent") + if !rawPacketsSent.IsValid() || rawPacketsSent.Kind() != reflect.Int64 { + return nil, E.New("invalid Conn.packetsSent") + } + conn.PacketsSent = (*int64)(unsafe.Pointer(rawPacketsSent.UnsafeAddr())) + + rawActiveCall := rawConn.FieldByName("activeCall") + if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct { + return nil, E.New("invalid Conn.activeCall") + } + conn.ActiveCall = (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr())) + + return conn, nil +} + +func (c *RawConn) ReadRecord() error { + return c.methods.readRecord(c.pointer) +} + +func (c *RawConn) HandlePostHandshakeMessage() error { + return c.methods.handlePostHandshakeMessage(c.pointer) +} + +func (c *RawConn) WriteRecordLocked(typ uint16, data []byte) (int, error) { + return c.methods.writeRecordLocked(c.pointer, typ, data) +} diff --git a/common/badtls/raw_half_conn.go b/common/badtls/raw_half_conn.go new file mode 100644 index 00000000..dd5e249e --- /dev/null +++ b/common/badtls/raw_half_conn.go @@ -0,0 +1,121 @@ +//go:build go1.25 && !without_badtls + +package badtls + +import ( + "hash" + "reflect" + "sync" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" +) + +type RawHalfConn struct { + pointer unsafe.Pointer + methods *Methods + *sync.Mutex + Err *error + Version *uint16 + Cipher *any + Seq *[8]byte + ScratchBuf *[13]byte + TrafficSecret *[]byte + Mac *hash.Hash + RawKey *[]byte + RawIV *[]byte + RawMac *[]byte +} + +func NewRawHalfConn(rawHalfConn reflect.Value, methods *Methods) (*RawHalfConn, error) { + halfConn := &RawHalfConn{ + pointer: (unsafe.Pointer)(rawHalfConn.UnsafeAddr()), + methods: methods, + } + + rawMutex := rawHalfConn.FieldByName("Mutex") + if !rawMutex.IsValid() || rawMutex.Kind() != reflect.Struct { + return nil, E.New("badtls: invalid halfConn.Mutex") + } + halfConn.Mutex = (*sync.Mutex)(unsafe.Pointer(rawMutex.UnsafeAddr())) + + rawErr := rawHalfConn.FieldByName("err") + if !rawErr.IsValid() || rawErr.Kind() != reflect.Interface { + return nil, E.New("badtls: invalid halfConn.err") + } + halfConn.Err = (*error)(unsafe.Pointer(rawErr.UnsafeAddr())) + + rawVersion := rawHalfConn.FieldByName("version") + if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 { + return nil, E.New("badtls: invalid halfConn.version") + } + halfConn.Version = (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr())) + + rawCipher := rawHalfConn.FieldByName("cipher") + if !rawCipher.IsValid() || rawCipher.Kind() != reflect.Interface { + return nil, E.New("badtls: invalid halfConn.cipher") + } + halfConn.Cipher = (*any)(unsafe.Pointer(rawCipher.UnsafeAddr())) + + rawSeq := rawHalfConn.FieldByName("seq") + if !rawSeq.IsValid() || rawSeq.Kind() != reflect.Array || rawSeq.Len() != 8 || rawSeq.Type().Elem().Kind() != reflect.Uint8 { + return nil, E.New("badtls: invalid halfConn.seq") + } + halfConn.Seq = (*[8]byte)(unsafe.Pointer(rawSeq.UnsafeAddr())) + + rawScratchBuf := rawHalfConn.FieldByName("scratchBuf") + if !rawScratchBuf.IsValid() || rawScratchBuf.Kind() != reflect.Array || rawScratchBuf.Len() != 13 || rawScratchBuf.Type().Elem().Kind() != reflect.Uint8 { + return nil, E.New("badtls: invalid halfConn.scratchBuf") + } + halfConn.ScratchBuf = (*[13]byte)(unsafe.Pointer(rawScratchBuf.UnsafeAddr())) + + rawTrafficSecret := rawHalfConn.FieldByName("trafficSecret") + if !rawTrafficSecret.IsValid() || rawTrafficSecret.Kind() != reflect.Slice || rawTrafficSecret.Type().Elem().Kind() != reflect.Uint8 { + return nil, E.New("badtls: invalid halfConn.trafficSecret") + } + halfConn.TrafficSecret = (*[]byte)(unsafe.Pointer(rawTrafficSecret.UnsafeAddr())) + + rawMac := rawHalfConn.FieldByName("mac") + if !rawMac.IsValid() || rawMac.Kind() != reflect.Interface { + return nil, E.New("badtls: invalid halfConn.mac") + } + halfConn.Mac = (*hash.Hash)(unsafe.Pointer(rawMac.UnsafeAddr())) + + rawKey := rawHalfConn.FieldByName("rawKey") + if rawKey.IsValid() { + if /*!rawKey.IsValid() || */ rawKey.Kind() != reflect.Slice || rawKey.Type().Elem().Kind() != reflect.Uint8 { + return nil, E.New("badtls: invalid halfConn.rawKey") + } + halfConn.RawKey = (*[]byte)(unsafe.Pointer(rawKey.UnsafeAddr())) + + rawIV := rawHalfConn.FieldByName("rawIV") + if !rawIV.IsValid() || rawIV.Kind() != reflect.Slice || rawIV.Type().Elem().Kind() != reflect.Uint8 { + return nil, E.New("badtls: invalid halfConn.rawIV") + } + halfConn.RawIV = (*[]byte)(unsafe.Pointer(rawIV.UnsafeAddr())) + + rawMAC := rawHalfConn.FieldByName("rawMac") + if !rawMAC.IsValid() || rawMAC.Kind() != reflect.Slice || rawMAC.Type().Elem().Kind() != reflect.Uint8 { + return nil, E.New("badtls: invalid halfConn.rawMac") + } + halfConn.RawMac = (*[]byte)(unsafe.Pointer(rawMAC.UnsafeAddr())) + } + + return halfConn, nil +} + +func (hc *RawHalfConn) Decrypt(record []byte) ([]byte, uint8, error) { + return hc.methods.decrypt(hc.pointer, record) +} + +func (hc *RawHalfConn) SetErrorLocked(err error) error { + return hc.methods.setErrorLocked(hc.pointer, err) +} + +func (hc *RawHalfConn) SetTrafficSecret(suite unsafe.Pointer, level int, secret []byte) { + hc.methods.setTrafficSecret(hc.pointer, suite, level, secret) +} + +func (hc *RawHalfConn) ExplicitNonceLen() int { + return hc.methods.explicitNonceLen(hc.pointer) +} diff --git a/common/badtls/read_wait.go b/common/badtls/read_wait.go index 9508a7e3..47d5b65e 100644 --- a/common/badtls/read_wait.go +++ b/common/badtls/read_wait.go @@ -1,18 +1,9 @@ -//go:build go1.21 && !without_badtls +//go:build go1.25 && !without_badtls package badtls import ( - "bytes" - "context" - "net" - "os" - "reflect" - "sync" - "unsafe" - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/tls" ) @@ -21,63 +12,18 @@ var _ N.ReadWaiter = (*ReadWaitConn)(nil) type ReadWaitConn struct { tls.Conn - halfAccess *sync.Mutex - rawInput *bytes.Buffer - input *bytes.Reader - hand *bytes.Buffer - readWaitOptions N.ReadWaitOptions - tlsReadRecord func() error - tlsHandlePostHandshakeMessage func() error + rawConn *RawConn + readWaitOptions N.ReadWaitOptions } func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) { - var ( - loaded bool - tlsReadRecord func() error - tlsHandlePostHandshakeMessage func() error - ) - for _, tlsCreator := range tlsRegistry { - loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn) - if loaded { - break - } + rawConn, err := NewRawConn(conn) + if err != nil { + return nil, err } - if !loaded { - return nil, os.ErrInvalid - } - rawConn := reflect.Indirect(reflect.ValueOf(conn)) - rawHalfConn := rawConn.FieldByName("in") - if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { - return nil, E.New("badtls: invalid half conn") - } - rawHalfMutex := rawHalfConn.FieldByName("Mutex") - if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct { - return nil, E.New("badtls: invalid half mutex") - } - halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr())) - rawRawInput := rawConn.FieldByName("rawInput") - if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct { - return nil, E.New("badtls: invalid raw input") - } - rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr())) - rawInput0 := rawConn.FieldByName("input") - if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct { - return nil, E.New("badtls: invalid input") - } - input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr())) - rawHand := rawConn.FieldByName("hand") - if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct { - return nil, E.New("badtls: invalid hand") - } - hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr())) return &ReadWaitConn{ - Conn: conn, - halfAccess: halfAccess, - rawInput: rawInput, - input: input, - hand: hand, - tlsReadRecord: tlsReadRecord, - tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage, + Conn: conn, + rawConn: rawConn, }, nil } @@ -87,36 +33,36 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy } func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { - err = c.HandshakeContext(context.Background()) - if err != nil { - return - } - c.halfAccess.Lock() - defer c.halfAccess.Unlock() - for c.input.Len() == 0 { - err = c.tlsReadRecord() + //err = c.HandshakeContext(context.Background()) + //if err != nil { + // return + //} + c.rawConn.In.Lock() + defer c.rawConn.In.Unlock() + for c.rawConn.Input.Len() == 0 { + err = c.rawConn.ReadRecord() if err != nil { return } - for c.hand.Len() > 0 { - err = c.tlsHandlePostHandshakeMessage() + for c.rawConn.Hand.Len() > 0 { + err = c.rawConn.HandlePostHandshakeMessage() if err != nil { return } } } buffer = c.readWaitOptions.NewBuffer() - n, err := c.input.Read(buffer.FreeBytes()) + n, err := c.rawConn.Input.Read(buffer.FreeBytes()) if err != nil { buffer.Release() return } buffer.Truncate(n) - if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && - // recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { - c.rawInput.Bytes()[0] == 21 { - _ = c.tlsReadRecord() + if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.Input.Len() > 0 && + // recordType(c.RawInput.Bytes()[0]) == recordTypeAlert { + c.rawConn.RawInput.Bytes()[0] == 21 { + _ = c.rawConn.ReadRecord() // return n, err // will be io.EOF on closeNotify } @@ -131,25 +77,3 @@ func (c *ReadWaitConn) Upstream() any { func (c *ReadWaitConn) ReaderReplaceable() bool { return true } - -var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) - -func init() { - tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) { - tlsConn, loaded := conn.(*tls.STDConn) - if !loaded { - return - } - return true, func() error { - return stdTLSReadRecord(tlsConn) - }, func() error { - return stdTLSHandlePostHandshakeMessage(tlsConn) - } - }) -} - -//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord -func stdTLSReadRecord(c *tls.STDConn) error - -//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage -func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error diff --git a/common/badtls/read_wait_stub.go b/common/badtls/read_wait_stub.go index c5c9946f..4bd4fc60 100644 --- a/common/badtls/read_wait_stub.go +++ b/common/badtls/read_wait_stub.go @@ -1,4 +1,4 @@ -//go:build !go1.21 || without_badtls +//go:build !go1.25 || without_badtls package badtls diff --git a/common/badtls/read_wait_utls.go b/common/badtls/read_wait_utls.go deleted file mode 100644 index 1facd30b..00000000 --- a/common/badtls/read_wait_utls.go +++ /dev/null @@ -1,36 +0,0 @@ -//go:build go1.21 && !without_badtls && with_utls - -package badtls - -import ( - "net" - _ "unsafe" - - "github.com/metacubex/utls" -) - -func init() { - tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) { - switch tlsConn := conn.(type) { - case *tls.UConn: - return true, func() error { - return utlsReadRecord(tlsConn.Conn) - }, func() error { - return utlsHandlePostHandshakeMessage(tlsConn.Conn) - } - case *tls.Conn: - return true, func() error { - return utlsReadRecord(tlsConn) - }, func() error { - return utlsHandlePostHandshakeMessage(tlsConn) - } - } - return - }) -} - -//go:linkname utlsReadRecord github.com/metacubex/utls.(*Conn).readRecord -func utlsReadRecord(c *tls.Conn) error - -//go:linkname utlsHandlePostHandshakeMessage github.com/metacubex/utls.(*Conn).handlePostHandshakeMessage -func utlsHandlePostHandshakeMessage(c *tls.Conn) error diff --git a/common/badtls/registry.go b/common/badtls/registry.go new file mode 100644 index 00000000..cc11a16e --- /dev/null +++ b/common/badtls/registry.go @@ -0,0 +1,62 @@ +//go:build go1.25 && !without_badtls + +package badtls + +import ( + "crypto/tls" + "net" + "unsafe" +) + +type Methods struct { + readRecord func(c unsafe.Pointer) error + handlePostHandshakeMessage func(c unsafe.Pointer) error + writeRecordLocked func(c unsafe.Pointer, typ uint16, data []byte) (int, error) + + setErrorLocked func(hc unsafe.Pointer, err error) error + decrypt func(hc unsafe.Pointer, record []byte) ([]byte, uint8, error) + setTrafficSecret func(hc unsafe.Pointer, suite unsafe.Pointer, level int, secret []byte) + explicitNonceLen func(hc unsafe.Pointer) int +} + +var methodRegistry []func(conn net.Conn) (unsafe.Pointer, *Methods, bool) + +func init() { + methodRegistry = append(methodRegistry, func(conn net.Conn) (unsafe.Pointer, *Methods, bool) { + tlsConn, loaded := conn.(*tls.Conn) + if !loaded { + return nil, nil, false + } + return unsafe.Pointer(tlsConn), &Methods{ + readRecord: stdTLSReadRecord, + handlePostHandshakeMessage: stdTLSHandlePostHandshakeMessage, + writeRecordLocked: stdWriteRecordLocked, + + setErrorLocked: stdSetErrorLocked, + decrypt: stdDecrypt, + setTrafficSecret: stdSetTrafficSecret, + explicitNonceLen: stdExplicitNonceLen, + }, true + }) +} + +//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord +func stdTLSReadRecord(c unsafe.Pointer) error + +//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage +func stdTLSHandlePostHandshakeMessage(c unsafe.Pointer) error + +//go:linkname stdWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked +func stdWriteRecordLocked(c unsafe.Pointer, typ uint16, data []byte) (int, error) + +//go:linkname stdSetErrorLocked crypto/tls.(*halfConn).setErrorLocked +func stdSetErrorLocked(hc unsafe.Pointer, err error) error + +//go:linkname stdDecrypt crypto/tls.(*halfConn).decrypt +func stdDecrypt(hc unsafe.Pointer, record []byte) ([]byte, uint8, error) + +//go:linkname stdSetTrafficSecret crypto/tls.(*halfConn).setTrafficSecret +func stdSetTrafficSecret(hc unsafe.Pointer, suite unsafe.Pointer, level int, secret []byte) + +//go:linkname stdExplicitNonceLen crypto/tls.(*halfConn).explicitNonceLen +func stdExplicitNonceLen(hc unsafe.Pointer) int diff --git a/common/badtls/registry_utls.go b/common/badtls/registry_utls.go new file mode 100644 index 00000000..c0454355 --- /dev/null +++ b/common/badtls/registry_utls.go @@ -0,0 +1,56 @@ +//go:build go1.25 && !without_badtls + +package badtls + +import ( + "net" + "unsafe" + + N "github.com/sagernet/sing/common/network" + + "github.com/metacubex/utls" +) + +func init() { + methodRegistry = append(methodRegistry, func(conn net.Conn) (unsafe.Pointer, *Methods, bool) { + var pointer unsafe.Pointer + if uConn, loaded := N.CastReader[*tls.Conn](conn); loaded { + pointer = unsafe.Pointer(uConn) + } else if uConn, loaded := N.CastReader[*tls.UConn](conn); loaded { + pointer = unsafe.Pointer(uConn.Conn) + } else { + return nil, nil, false + } + return pointer, &Methods{ + readRecord: utlsReadRecord, + handlePostHandshakeMessage: utlsHandlePostHandshakeMessage, + writeRecordLocked: utlsWriteRecordLocked, + + setErrorLocked: utlsSetErrorLocked, + decrypt: utlsDecrypt, + setTrafficSecret: utlsSetTrafficSecret, + explicitNonceLen: utlsExplicitNonceLen, + }, true + }) +} + +//go:linkname utlsReadRecord github.com/metacubex/utls.(*Conn).readRecord +func utlsReadRecord(c unsafe.Pointer) error + +//go:linkname utlsHandlePostHandshakeMessage github.com/metacubex/utls.(*Conn).handlePostHandshakeMessage +func utlsHandlePostHandshakeMessage(c unsafe.Pointer) error + +//go:linkname utlsWriteRecordLocked github.com/metacubex/utls.(*Conn).writeRecordLocked +func utlsWriteRecordLocked(hc unsafe.Pointer, typ uint16, data []byte) (int, error) + +//go:linkname utlsSetErrorLocked github.com/metacubex/utls.(*halfConn).setErrorLocked +func utlsSetErrorLocked(hc unsafe.Pointer, err error) error + +//go:linkname utlsDecrypt github.com/metacubex/utls.(*halfConn).decrypt +func utlsDecrypt(hc unsafe.Pointer, record []byte) ([]byte, uint8, error) + +//go:linkname utlsSetTrafficSecret github.com/metacubex/utls.(*halfConn).setTrafficSecret +func utlsSetTrafficSecret(hc unsafe.Pointer, suite unsafe.Pointer, level int, secret []byte) + +//go:linkname utlsExplicitNonceLen github.com/metacubex/utls.(*halfConn).explicitNonceLen +func utlsExplicitNonceLen(hc unsafe.Pointer) int diff --git a/common/ktls/ktls.go b/common/ktls/ktls.go new file mode 100644 index 00000000..a3d629b3 --- /dev/null +++ b/common/ktls/ktls.go @@ -0,0 +1,84 @@ +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "crypto/tls" + "io" + "net" + "os" + "syscall" + + "github.com/sagernet/sing-box/common/badtls" + // C "github.com/sagernet/sing-box/constant" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" +) + +type Conn struct { + aTLS.Conn + conn net.Conn + rawConn *badtls.RawConn + rawSyscallConn syscall.RawConn + readWaitOptions N.ReadWaitOptions + kernelTx bool + kernelRx bool + tmp [16]byte +} + +func NewConn(conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) { + syscallConn, isSyscallConn := N.CastReader[interface { + io.Reader + syscall.Conn + }](conn.NetConn()) + if !isSyscallConn { + return nil, os.ErrInvalid + } + rawSyscallConn, err := syscallConn.SyscallConn() + if err != nil { + return nil, err + } + rawConn, err := badtls.NewRawConn(conn) + if err != nil { + return nil, err + } + if *rawConn.Vers != tls.VersionTLS13 { + return nil, os.ErrInvalid + } + for rawConn.RawInput.Len() > 0 { + err = rawConn.ReadRecord() + if err != nil { + return nil, err + } + for rawConn.Hand.Len() > 0 { + err = rawConn.HandlePostHandshakeMessage() + if err != nil { + return nil, E.Cause(err, "ktls: failed to handle post-handshake messages") + } + } + } + kConn := &Conn{ + Conn: conn, + conn: conn.NetConn(), + rawConn: rawConn, + rawSyscallConn: rawSyscallConn, + } + err = kConn.setupKernel(txOffload, rxOffload) + if err != nil { + return nil, err + } + return kConn, nil +} + +func (c *Conn) Upstream() any { + return c.conn +} + +func (c *Conn) ReaderReplaceable() bool { + return c.kernelRx +} + +func (c *Conn) WriterReplaceable() bool { + return c.kernelTx +} diff --git a/common/ktls/ktls_alert.go b/common/ktls/ktls_alert.go new file mode 100644 index 00000000..a60f8f2f --- /dev/null +++ b/common/ktls/ktls_alert.go @@ -0,0 +1,80 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "crypto/tls" + "net" +) + +const ( + // alert level + alertLevelWarning = 1 + alertLevelError = 2 +) + +const ( + alertCloseNotify = 0 + alertUnexpectedMessage = 10 + alertBadRecordMAC = 20 + alertDecryptionFailed = 21 + alertRecordOverflow = 22 + alertDecompressionFailure = 30 + alertHandshakeFailure = 40 + alertBadCertificate = 42 + alertUnsupportedCertificate = 43 + alertCertificateRevoked = 44 + alertCertificateExpired = 45 + alertCertificateUnknown = 46 + alertIllegalParameter = 47 + alertUnknownCA = 48 + alertAccessDenied = 49 + alertDecodeError = 50 + alertDecryptError = 51 + alertExportRestriction = 60 + alertProtocolVersion = 70 + alertInsufficientSecurity = 71 + alertInternalError = 80 + alertInappropriateFallback = 86 + alertUserCanceled = 90 + alertNoRenegotiation = 100 + alertMissingExtension = 109 + alertUnsupportedExtension = 110 + alertCertificateUnobtainable = 111 + alertUnrecognizedName = 112 + alertBadCertificateStatusResponse = 113 + alertBadCertificateHashValue = 114 + alertUnknownPSKIdentity = 115 + alertCertificateRequired = 116 + alertNoApplicationProtocol = 120 + alertECHRequired = 121 +) + +func (c *Conn) sendAlertLocked(err uint8) error { + switch err { + case alertNoRenegotiation, alertCloseNotify: + c.tmp[0] = alertLevelWarning + default: + c.tmp[0] = alertLevelError + } + c.tmp[1] = byte(err) + + _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr + } + + return c.rawConn.Out.SetErrorLocked(&net.OpError{Op: "local error", Err: tls.AlertError(err)}) +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlert(err uint8) error { + c.rawConn.Out.Lock() + defer c.rawConn.Out.Unlock() + return c.sendAlertLocked(err) +} diff --git a/common/ktls/ktls_cipher_suites_linux.go b/common/ktls/ktls_cipher_suites_linux.go new file mode 100644 index 00000000..348b1c42 --- /dev/null +++ b/common/ktls/ktls_cipher_suites_linux.go @@ -0,0 +1,326 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "crypto/tls" + "unsafe" + + "github.com/sagernet/sing-box/common/badtls" +) + +type kernelCryptoCipherType uint16 + +const ( + TLS_CIPHER_AES_GCM_128 kernelCryptoCipherType = 51 + TLS_CIPHER_AES_GCM_128_IV_SIZE kernelCryptoCipherType = 8 + TLS_CIPHER_AES_GCM_128_KEY_SIZE kernelCryptoCipherType = 16 + TLS_CIPHER_AES_GCM_128_SALT_SIZE kernelCryptoCipherType = 4 + TLS_CIPHER_AES_GCM_128_TAG_SIZE kernelCryptoCipherType = 16 + TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8 + + TLS_CIPHER_AES_GCM_256 kernelCryptoCipherType = 52 + TLS_CIPHER_AES_GCM_256_IV_SIZE kernelCryptoCipherType = 8 + TLS_CIPHER_AES_GCM_256_KEY_SIZE kernelCryptoCipherType = 32 + TLS_CIPHER_AES_GCM_256_SALT_SIZE kernelCryptoCipherType = 4 + TLS_CIPHER_AES_GCM_256_TAG_SIZE kernelCryptoCipherType = 16 + TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE kernelCryptoCipherType = 8 + + TLS_CIPHER_AES_CCM_128 kernelCryptoCipherType = 53 + TLS_CIPHER_AES_CCM_128_IV_SIZE kernelCryptoCipherType = 8 + TLS_CIPHER_AES_CCM_128_KEY_SIZE kernelCryptoCipherType = 16 + TLS_CIPHER_AES_CCM_128_SALT_SIZE kernelCryptoCipherType = 4 + TLS_CIPHER_AES_CCM_128_TAG_SIZE kernelCryptoCipherType = 16 + TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8 + + TLS_CIPHER_CHACHA20_POLY1305 kernelCryptoCipherType = 54 + TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE kernelCryptoCipherType = 12 + TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE kernelCryptoCipherType = 32 + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE kernelCryptoCipherType = 0 + TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE kernelCryptoCipherType = 16 + TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE kernelCryptoCipherType = 8 + + // TLS_CIPHER_SM4_GCM kernelCryptoCipherType = 55 + // TLS_CIPHER_SM4_GCM_IV_SIZE kernelCryptoCipherType = 8 + // TLS_CIPHER_SM4_GCM_KEY_SIZE kernelCryptoCipherType = 16 + // TLS_CIPHER_SM4_GCM_SALT_SIZE kernelCryptoCipherType = 4 + // TLS_CIPHER_SM4_GCM_TAG_SIZE kernelCryptoCipherType = 16 + // TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE kernelCryptoCipherType = 8 + + // TLS_CIPHER_SM4_CCM kernelCryptoCipherType = 56 + // TLS_CIPHER_SM4_CCM_IV_SIZE kernelCryptoCipherType = 8 + // TLS_CIPHER_SM4_CCM_KEY_SIZE kernelCryptoCipherType = 16 + // TLS_CIPHER_SM4_CCM_SALT_SIZE kernelCryptoCipherType = 4 + // TLS_CIPHER_SM4_CCM_TAG_SIZE kernelCryptoCipherType = 16 + // TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE kernelCryptoCipherType = 8 + + TLS_CIPHER_ARIA_GCM_128 kernelCryptoCipherType = 57 + TLS_CIPHER_ARIA_GCM_128_IV_SIZE kernelCryptoCipherType = 8 + TLS_CIPHER_ARIA_GCM_128_KEY_SIZE kernelCryptoCipherType = 16 + TLS_CIPHER_ARIA_GCM_128_SALT_SIZE kernelCryptoCipherType = 4 + TLS_CIPHER_ARIA_GCM_128_TAG_SIZE kernelCryptoCipherType = 16 + TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8 + + TLS_CIPHER_ARIA_GCM_256 kernelCryptoCipherType = 58 + TLS_CIPHER_ARIA_GCM_256_IV_SIZE kernelCryptoCipherType = 8 + TLS_CIPHER_ARIA_GCM_256_KEY_SIZE kernelCryptoCipherType = 32 + TLS_CIPHER_ARIA_GCM_256_SALT_SIZE kernelCryptoCipherType = 4 + TLS_CIPHER_ARIA_GCM_256_TAG_SIZE kernelCryptoCipherType = 16 + TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE kernelCryptoCipherType = 8 +) + +type kernelCrypto interface { + String() string +} + +type kernelCryptoInfo struct { + version uint16 + cipher_type kernelCryptoCipherType +} + +var _ kernelCrypto = &kernelCryptoAES128GCM{} + +type kernelCryptoAES128GCM struct { + kernelCryptoInfo + iv [TLS_CIPHER_AES_GCM_128_IV_SIZE]byte + key [TLS_CIPHER_AES_GCM_128_KEY_SIZE]byte + salt [TLS_CIPHER_AES_GCM_128_SALT_SIZE]byte + rec_seq [TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE]byte +} + +func (crypto *kernelCryptoAES128GCM) String() string { + crypto.cipher_type = TLS_CIPHER_AES_GCM_128 + return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:]) +} + +var _ kernelCrypto = &kernelCryptoAES256GCM{} + +type kernelCryptoAES256GCM struct { + kernelCryptoInfo + iv [TLS_CIPHER_AES_GCM_256_IV_SIZE]byte + key [TLS_CIPHER_AES_GCM_256_KEY_SIZE]byte + salt [TLS_CIPHER_AES_GCM_256_SALT_SIZE]byte + rec_seq [TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE]byte +} + +func (crypto *kernelCryptoAES256GCM) String() string { + crypto.cipher_type = TLS_CIPHER_AES_GCM_256 + return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:]) +} + +var _ kernelCrypto = &kernelCryptoAES128CCM{} + +type kernelCryptoAES128CCM struct { + kernelCryptoInfo + iv [TLS_CIPHER_AES_CCM_128_IV_SIZE]byte + key [TLS_CIPHER_AES_CCM_128_KEY_SIZE]byte + salt [TLS_CIPHER_AES_CCM_128_SALT_SIZE]byte + rec_seq [TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE]byte +} + +func (crypto *kernelCryptoAES128CCM) String() string { + crypto.cipher_type = TLS_CIPHER_AES_CCM_128 + return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:]) +} + +var _ kernelCrypto = &kernelCryptoChacha20Poly1035{} + +type kernelCryptoChacha20Poly1035 struct { + kernelCryptoInfo + iv [TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE]byte + key [TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE]byte + salt [TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE]byte + rec_seq [TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE]byte +} + +func (crypto *kernelCryptoChacha20Poly1035) String() string { + crypto.cipher_type = TLS_CIPHER_CHACHA20_POLY1305 + return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:]) +} + +// var _ kernelCrypto = &kernelCryptoSM4GCM{} + +// type kernelCryptoSM4GCM struct { +// kernelCryptoInfo +// iv [TLS_CIPHER_SM4_GCM_IV_SIZE]byte +// key [TLS_CIPHER_SM4_GCM_KEY_SIZE]byte +// salt [TLS_CIPHER_SM4_GCM_SALT_SIZE]byte +// rec_seq [TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE]byte +// } + +// func (crypto *kernelCryptoSM4GCM) String() string { +// crypto.cipher_type = TLS_CIPHER_SM4_GCM +// return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:]) +// } + +// var _ kernelCrypto = &kernelCryptoSM4CCM{} + +// type kernelCryptoSM4CCM struct { +// kernelCryptoInfo +// iv [TLS_CIPHER_SM4_CCM_IV_SIZE]byte +// key [TLS_CIPHER_SM4_CCM_KEY_SIZE]byte +// salt [TLS_CIPHER_SM4_CCM_SALT_SIZE]byte +// rec_seq [TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE]byte +// } + +// func (crypto *kernelCryptoSM4CCM) String() string { +// crypto.cipher_type = TLS_CIPHER_SM4_CCM +// return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:]) +// } + +var _ kernelCrypto = &kernelCryptoARIA128GCM{} + +type kernelCryptoARIA128GCM struct { + kernelCryptoInfo + iv [TLS_CIPHER_ARIA_GCM_128_IV_SIZE]byte + key [TLS_CIPHER_ARIA_GCM_128_KEY_SIZE]byte + salt [TLS_CIPHER_ARIA_GCM_128_SALT_SIZE]byte + rec_seq [TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE]byte +} + +func (crypto *kernelCryptoARIA128GCM) String() string { + crypto.cipher_type = TLS_CIPHER_ARIA_GCM_128 + return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:]) +} + +var _ kernelCrypto = &kernelCryptoARIA256GCM{} + +type kernelCryptoARIA256GCM struct { + kernelCryptoInfo + iv [TLS_CIPHER_ARIA_GCM_256_IV_SIZE]byte + key [TLS_CIPHER_ARIA_GCM_256_KEY_SIZE]byte + salt [TLS_CIPHER_ARIA_GCM_256_SALT_SIZE]byte + rec_seq [TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE]byte +} + +func (crypto *kernelCryptoARIA256GCM) String() string { + crypto.cipher_type = TLS_CIPHER_ARIA_GCM_256 + return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:]) +} + +func kernelCipher(kernel *Support, hc *badtls.RawHalfConn, cipherSuite uint16, isRX bool) kernelCrypto { + if !kernel.TLS { + return nil + } + + switch *hc.Version { + case tls.VersionTLS12: + if isRX && !kernel.TLS_Version13_RX { + return nil + } + + case tls.VersionTLS13: + if !kernel.TLS_Version13 { + return nil + } + + if isRX && !kernel.TLS_Version13_RX { + return nil + } + + default: + return nil + } + + var key, iv []byte + if *hc.Version == tls.VersionTLS13 { + key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), *hc.TrafficSecret) + /*if isRX { + key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), keyLog.RemoteTrafficSecret) + } else { + key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), keyLog.TrafficSecret) + }*/ + } else { + // csPtr := cipherSuiteByID(cipherSuite) + // keysFromMasterSecret(*hc.Version, csPtr, keyLog.Secret, keyLog.Random) + return nil + } + + switch cipherSuite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + crypto := new(kernelCryptoAES128GCM) + + crypto.version = *hc.Version + copy(crypto.key[:], key) + copy(crypto.iv[:], iv[4:]) + copy(crypto.salt[:], iv[:4]) + crypto.rec_seq = *hc.Seq + + return crypto + case tls.TLS_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: + if !kernel.TLS_AES_256_GCM { + return nil + } + + crypto := new(kernelCryptoAES256GCM) + + crypto.version = *hc.Version + copy(crypto.key[:], key) + copy(crypto.iv[:], iv[4:]) + copy(crypto.salt[:], iv[:4]) + crypto.rec_seq = *hc.Seq + + return crypto + //case tls.TLS_AES_128_CCM_SHA256, tls.TLS_RSA_WITH_AES_128_CCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_SHA256: + // if !kernel.TLS_AES_128_CCM { + // return nil + // } + // + // crypto := new(kernelCryptoAES128CCM) + // + // crypto.version = *hc.Version + // copy(crypto.key[:], key) + // copy(crypto.iv[:], iv[4:]) + // copy(crypto.salt[:], iv[:4]) + // crypto.rec_seq = *hc.Seq + // + // return crypto + case tls.TLS_CHACHA20_POLY1305_SHA256, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: + if !kernel.TLS_CHACHA20_POLY1305 { + return nil + } + + crypto := new(kernelCryptoChacha20Poly1035) + + crypto.version = *hc.Version + copy(crypto.key[:], key) + copy(crypto.iv[:], iv) + crypto.rec_seq = *hc.Seq + + return crypto + //case tls.TLS_RSA_WITH_ARIA_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256: + // if !kernel.TLS_ARIA_GCM { + // return nil + // } + // + // crypto := new(kernelCryptoARIA128GCM) + // + // crypto.version = *hc.Version + // copy(crypto.key[:], key) + // copy(crypto.iv[:], iv[4:]) + // copy(crypto.salt[:], iv[:4]) + // crypto.rec_seq = *hc.Seq + // + // return crypto + //case tls.TLS_RSA_WITH_ARIA_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384: + // if !kernel.TLS_ARIA_GCM { + // return nil + // } + // + // crypto := new(kernelCryptoARIA256GCM) + // + // crypto.version = *hc.Version + // copy(crypto.key[:], key) + // copy(crypto.iv[:], iv[4:]) + // copy(crypto.salt[:], iv[:4]) + // crypto.rec_seq = *hc.Seq + // + // return crypto + default: + return nil + } +} diff --git a/common/ktls/ktls_close.go b/common/ktls/ktls_close.go new file mode 100644 index 00000000..f9392a56 --- /dev/null +++ b/common/ktls/ktls_close.go @@ -0,0 +1,67 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "fmt" + "net" + "time" +) + +func (c *Conn) Close() error { + if !c.kernelTx { + return c.Conn.Close() + } + + // Interlock with Conn.Write above. + var x int32 + for { + x = c.rawConn.ActiveCall.Load() + if x&1 != 0 { + return net.ErrClosed + } + if c.rawConn.ActiveCall.CompareAndSwap(x, x|1) { + break + } + } + if x != 0 { + // io.Writer and io.Closer should not be used concurrently. + // If Close is called while a Write is currently in-flight, + // interpret that as a sign that this Close is really just + // being used to break the Write and/or clean up resources and + // avoid sending the alertCloseNotify, which may block + // waiting on handshakeMutex or the c.out mutex. + return c.conn.Close() + } + + var alertErr error + if c.rawConn.IsHandshakeComplete.Load() { + if err := c.closeNotify(); err != nil { + alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) + } + } + + if err := c.conn.Close(); err != nil { + return err + } + return alertErr +} + +func (c *Conn) closeNotify() error { + c.rawConn.Out.Lock() + defer c.rawConn.Out.Unlock() + + if !*c.rawConn.CloseNotifySent { + // Set a Write Deadline to prevent possibly blocking forever. + c.SetWriteDeadline(time.Now().Add(time.Second * 5)) + *c.rawConn.CloseNotifyErr = c.sendAlertLocked(alertCloseNotify) + *c.rawConn.CloseNotifySent = true + // Any subsequent writes will fail. + c.SetWriteDeadline(time.Now()) + } + return *c.rawConn.CloseNotifyErr +} diff --git a/common/ktls/ktls_const.go b/common/ktls/ktls_const.go new file mode 100644 index 00000000..0b8e72e8 --- /dev/null +++ b/common/ktls/ktls_const.go @@ -0,0 +1,24 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxHandshakeCertificateMsg = 262144 // maximum certificate message size (256 KiB) + maxUselessRecords = 16 // maximum number of consecutive non-advancing records +) + +const ( + recordTypeChangeCipherSpec = 20 + recordTypeAlert = 21 + recordTypeHandshake = 22 + recordTypeApplicationData = 23 +) diff --git a/common/ktls/ktls_handshake_messages.go b/common/ktls/ktls_handshake_messages.go new file mode 100644 index 00000000..e80531e1 --- /dev/null +++ b/common/ktls/ktls_handshake_messages.go @@ -0,0 +1,238 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "fmt" + + "golang.org/x/crypto/cryptobyte" +) + +// The marshalingFunction type is an adapter to allow the use of ordinary +// functions as cryptobyte.MarshalingValue. +type marshalingFunction func(b *cryptobyte.Builder) error + +func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { + return f(b) +} + +// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If +// the length of the sequence is not the value specified, it produces an error. +func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { + b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { + if len(v) != n { + return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) + } + b.AddBytes(v) + return nil + })) +} + +// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. +func addUint64(b *cryptobyte.Builder, v uint64) { + b.AddUint32(uint32(v >> 32)) + b.AddUint32(uint32(v)) +} + +// readUint64 decodes a big-endian, 64-bit value into out and advances over it. +// It reports whether the read was successful. +func readUint64(s *cryptobyte.String, out *uint64) bool { + var hi, lo uint32 + if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { + return false + } + *out = uint64(hi)<<32 | uint64(lo) + return true +} + +// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) +} + +type keyUpdateMsg struct { + updateRequested bool +} + +func (m *keyUpdateMsg) marshal() ([]byte, error) { + var b cryptobyte.Builder + b.AddUint8(typeKeyUpdate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.updateRequested { + b.AddUint8(1) + } else { + b.AddUint8(0) + } + }) + + return b.Bytes() +} + +func (m *keyUpdateMsg) unmarshal(data []byte) bool { + s := cryptobyte.String(data) + + var updateRequested uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&updateRequested) || !s.Empty() { + return false + } + switch updateRequested { + case 0: + m.updateRequested = false + case 1: + m.updateRequested = true + default: + return false + } + return true +} + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 + typeEncryptedExtensions uint8 = 8 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeKeyUpdate uint8 = 24 + typeCompressedCertificate uint8 = 25 + typeMessageHash uint8 = 254 // synthetic message +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionPadding uint16 = 21 + extensionExtendedMasterSecret uint16 = 23 + extensionCompressCertificate uint16 = 27 // compress_certificate in TLS 1.3 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionCertificateAuthorities uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionQUICTransportParameters uint16 = 57 + extensionALPS uint16 = 17513 + extensionRenegotiationInfo uint16 = 0xff01 + extensionECHOuterExtensions uint16 = 0xfd00 + extensionEncryptedClientHello uint16 = 0xfe0d +) + +type handshakeMessage interface { + marshal() ([]byte, error) + unmarshal([]byte) bool +} +type newSessionTicketMsgTLS13 struct { + lifetime uint32 + ageAdd uint32 + nonce []byte + label []byte + maxEarlyData uint32 +} + +func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { + var b cryptobyte.Builder + b.AddUint8(typeNewSessionTicket) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.lifetime) + b.AddUint32(m.ageAdd) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.nonce) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.label) + }) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.maxEarlyData > 0 { + b.AddUint16(extensionEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.maxEarlyData) + }) + } + }) + }) + + return b.Bytes() +} + +func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { + *m = newSessionTicketMsgTLS13{} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint32(&m.lifetime) || + !s.ReadUint32(&m.ageAdd) || + !readUint8LengthPrefixed(&s, &m.nonce) || + !readUint16LengthPrefixed(&s, &m.label) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionEarlyData: + if !extData.ReadUint32(&m.maxEarlyData) { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} diff --git a/common/ktls/ktls_key_update.go b/common/ktls/ktls_key_update.go new file mode 100644 index 00000000..9e0d0ee1 --- /dev/null +++ b/common/ktls/ktls_key_update.go @@ -0,0 +1,173 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "crypto/tls" + "errors" + "fmt" + "io" + "os" +) + +// handlePostHandshakeMessage processes a handshake message arrived after the +// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. +func (c *Conn) handlePostHandshakeMessage() error { + if *c.rawConn.Vers != tls.VersionTLS13 { + return errors.New("ktls: kernel does not support TLS 1.2 renegotiation") + } + + msg, err := c.readHandshake(nil) + if err != nil { + return err + } + //c.retryCount++ + //if c.retryCount > maxUselessRecords { + // c.sendAlert(alertUnexpectedMessage) + // return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) + //} + + switch msg := msg.(type) { + case *newSessionTicketMsgTLS13: + // return errors.New("ktls: received new session ticket") + return nil + case *keyUpdateMsg: + return c.handleKeyUpdate(msg) + } + // The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest + // as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an + // unexpected_message alert here doesn't provide it with enough information to distinguish + // this condition from other unexpected messages. This is probably fine. + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) +} + +func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { + //if c.quic != nil { + // c.sendAlert(alertUnexpectedMessage) + // return c.in.setErrorLocked(errors.New("tls: received unexpected key update message")) + //} + + cipherSuite := cipherSuiteTLS13ByID(*c.rawConn.CipherSuite) + if cipherSuite == nil { + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertInternalError)) + } + + newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.In.TrafficSecret) + c.rawConn.In.SetTrafficSecret(cipherSuite, 0 /*tls.QUICEncryptionLevelInitial*/, newSecret) + + err := c.resetupRX() + if err != nil { + c.sendAlert(alertInternalError) + return c.rawConn.In.SetErrorLocked(fmt.Errorf("ktls: resetupRX failed: %w", err)) + } + + if keyUpdate.updateRequested { + c.rawConn.Out.Lock() + defer c.rawConn.Out.Unlock() + + resetup, err := c.resetupTX() + if err != nil { + c.sendAlertLocked(alertInternalError) + return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err)) + } + + msg := &keyUpdateMsg{} + msgBytes, err := msg.marshal() + if err != nil { + return err + } + _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes) + if err != nil { + // Surface the error at the next write. + c.rawConn.Out.SetErrorLocked(err) + return nil + } + + newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.Out.TrafficSecret) + c.rawConn.Out.SetTrafficSecret(cipherSuite, 0 /*QUICEncryptionLevelInitial*/, newSecret) + + err = resetup() + if err != nil { + return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err)) + } + } + + return nil +} + +func (c *Conn) readHandshakeBytes(n int) error { + //if c.quic != nil { + // return c.quicReadHandshakeBytes(n) + //} + for c.rawConn.Hand.Len() < n { + if err := c.readRecord(); err != nil { + return err + } + } + return nil +} + +func (c *Conn) readHandshake(transcript io.Writer) (any, error) { + if err := c.readHandshakeBytes(4); err != nil { + return nil, err + } + data := c.rawConn.Hand.Bytes() + + maxHandshakeSize := maxHandshake + // hasVers indicates we're past the first message, forcing someone trying to + // make us just allocate a large buffer to at least do the initial part of + // the handshake first. + //if c.haveVers && data[0] == typeCertificate { + // Since certificate messages are likely to be the only messages that + // can be larger than maxHandshake, we use a special limit for just + // those messages. + //maxHandshakeSize = maxHandshakeCertificateMsg + //} + + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshakeSize { + c.sendAlertLocked(alertInternalError) + return nil, c.rawConn.In.SetErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize)) + } + if err := c.readHandshakeBytes(4 + n); err != nil { + return nil, err + } + data = c.rawConn.Hand.Next(4 + n) + return c.unmarshalHandshakeMessage(data, transcript) +} + +func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript io.Writer) (any, error) { + var m handshakeMessage + switch data[0] { + case typeNewSessionTicket: + if *c.rawConn.Vers == tls.VersionTLS13 { + m = new(newSessionTicketMsgTLS13) + } else { + return nil, os.ErrInvalid + } + case typeKeyUpdate: + m = new(keyUpdateMsg) + default: + return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // The handshake message unmarshalers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = append([]byte(nil), data...) + + if !m.unmarshal(data) { + return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecodeError)) + } + + if transcript != nil { + transcript.Write(data) + } + + return m, nil +} diff --git a/common/ktls/ktls_linux.go b/common/ktls/ktls_linux.go new file mode 100644 index 00000000..7b37948b --- /dev/null +++ b/common/ktls/ktls_linux.go @@ -0,0 +1,311 @@ +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "crypto/tls" + "errors" + "io" + "os" + "strings" + "sync" + "syscall" + "unsafe" + + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/blang/semver/v4" + "golang.org/x/sys/unix" +) + +// mod from https://gitlab.com/go-extension/tls + +const ( + TLS_TX = 1 + TLS_RX = 2 + TLS_TX_ZEROCOPY_RO = 3 // TX zerocopy (only sendfile now) + TLS_RX_EXPECT_NO_PAD = 4 // Attempt opportunistic zero-copy, TLS 1.3 only + + TLS_SET_RECORD_TYPE = 1 + TLS_GET_RECORD_TYPE = 2 +) + +type Support struct { + TLS, TLS_RX bool + TLS_Version13, TLS_Version13_RX bool + + TLS_TX_ZEROCOPY bool + TLS_RX_NOPADDING bool + + TLS_AES_256_GCM bool + TLS_AES_128_CCM bool + TLS_CHACHA20_POLY1305 bool + TLS_SM4 bool + TLS_ARIA_GCM bool + + TLS_Version13_KeyUpdate bool +} + +var KernelSupport = sync.OnceValues(func() (*Support, error) { + _, err := os.Stat("/sys/module/tls") + if err != nil { + return nil, E.New("ktls: kernel module tls not found") + } + + var uname unix.Utsname + err = unix.Uname(&uname) + if err != nil { + return nil, err + } + + kernelVersion, err := semver.Parse(strings.Trim(string(uname.Release[:]), "\x00")) + if err != nil { + return nil, err + } + kernelVersion.Pre = nil + kernelVersion.Build = nil + + var support Support + + switch { + case kernelVersion.GTE(semver.Version{Major: 6, Minor: 14}): + support.TLS_Version13_KeyUpdate = true + fallthrough + case kernelVersion.GTE(semver.Version{Major: 6, Minor: 1}): + support.TLS_ARIA_GCM = true + fallthrough + case kernelVersion.GTE(semver.Version{Major: 6}): + support.TLS_Version13_RX = true + support.TLS_RX_NOPADDING = true + fallthrough + case kernelVersion.GTE(semver.Version{Major: 5, Minor: 19}): + support.TLS_TX_ZEROCOPY = true + fallthrough + case kernelVersion.GTE(semver.Version{Major: 5, Minor: 16}): + support.TLS_SM4 = true + fallthrough + case kernelVersion.GTE(semver.Version{Major: 5, Minor: 11}): + support.TLS_CHACHA20_POLY1305 = true + fallthrough + case kernelVersion.GTE(semver.Version{Major: 5, Minor: 2}): + support.TLS_AES_128_CCM = true + fallthrough + case kernelVersion.GTE(semver.Version{Major: 5, Minor: 1}): + support.TLS_AES_256_GCM = true + support.TLS_Version13 = true + fallthrough + case kernelVersion.GTE(semver.Version{Major: 4, Minor: 17}): + support.TLS_RX = true + fallthrough + case kernelVersion.GTE(semver.Version{Major: 4, Minor: 13}): + support.TLS = true + } + + return &support, nil +}) + +func (c *Conn) setupKernel(txOffload, rxOffload bool) error { + if !txOffload && !rxOffload { + return nil + } + support, err := KernelSupport() + if err != nil { + return err + } + if !support.TLS { + return nil + } + c.rawConn.Out.Lock() + defer c.rawConn.Out.Unlock() + err = control.Raw(c.rawSyscallConn, func(fd uintptr) error { + return syscall.SetsockoptString(int(fd), unix.SOL_TCP, unix.TCP_ULP, "tls") + }) + if err != nil { + return E.Cause(err, "initialize kernel TLS") + } + + if rxOffload { + rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true) + if rxCrypto == nil { + return E.New("kTLS: unsupported cipher suite") + } + err = control.Raw(c.rawSyscallConn, func(fd uintptr) error { + return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String()) + }) + if err != nil { + return err + } + if /*config.KernelRXExpectNoPad &&*/ *c.rawConn.Vers >= tls.VersionTLS13 && support.TLS_RX_NOPADDING { + err = control.Raw(c.rawSyscallConn, func(fd uintptr) error { + return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_RX_EXPECT_NO_PAD, 1) + }) + if err != nil { + return err + } + } + c.kernelRx = true + } + + if txOffload { + txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false) + if txCrypto == nil { + return E.New("kTLS: unsupported cipher suite") + } + err = control.Raw(c.rawSyscallConn, func(fd uintptr) error { + return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String()) + }) + if err != nil { + return err + } + if support.TLS_TX_ZEROCOPY { + err = control.Raw(c.rawSyscallConn, func(fd uintptr) error { + return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_TX_ZEROCOPY_RO, 1) + }) + if err != nil { + return err + } + } + c.kernelTx = true + } + + return nil +} + +func (c *Conn) resetupTX() (func() error, error) { + if !c.kernelTx { + return nil, nil + } + support, err := KernelSupport() + if err != nil { + return nil, err + } + if !support.TLS_Version13_KeyUpdate { + return nil, errors.New("ktls: kernel does not support rekey") + } + txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false) + if txCrypto == nil { + return nil, errors.New("ktls: set kernelCipher on unsupported tls session") + } + return func() error { + return control.Raw(c.rawSyscallConn, func(fd uintptr) error { + return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String()) + }) + }, nil +} + +func (c *Conn) resetupRX() error { + if !c.kernelRx { + return nil + } + support, err := KernelSupport() + if err != nil { + return err + } + if !support.TLS_Version13_KeyUpdate { + return errors.New("ktls: kernel does not support rekey") + } + rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true) + if rxCrypto == nil { + return errors.New("ktls: set kernelCipher on unsupported tls session") + } + return control.Raw(c.rawSyscallConn, func(fd uintptr) error { + return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String()) + }) +} + +func (c *Conn) readKernelRecord() (uint8, []byte, error) { + if c.rawConn.RawInput.Len() < maxPlaintext { + c.rawConn.RawInput.Grow(maxPlaintext - c.rawConn.RawInput.Len()) + } + + data := c.rawConn.RawInput.Bytes()[:maxPlaintext] + + // cmsg for record type + buffer := make([]byte, unix.CmsgSpace(1)) + cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0])) + cmsg.SetLen(unix.CmsgLen(1)) + + var iov unix.Iovec + iov.Base = &data[0] + iov.SetLen(len(data)) + + var msg unix.Msghdr + msg.Control = &buffer[0] + msg.Controllen = cmsg.Len + msg.Iov = &iov + msg.Iovlen = 1 + + var n int + var err error + er := c.rawSyscallConn.Read(func(fd uintptr) bool { + n, err = recvmsg(int(fd), &msg, 0) + return err != unix.EAGAIN + }) + if er != nil { + return 0, nil, er + } + switch err { + case nil: + case syscall.EINVAL: + return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertProtocolVersion)) + case syscall.EMSGSIZE: + return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow)) + case syscall.EBADMSG: + return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecryptError)) + default: + return 0, nil, err + } + + if n <= 0 { + return 0, nil, io.EOF + } + + if cmsg.Level == unix.SOL_TLS && cmsg.Type == TLS_GET_RECORD_TYPE { + typ := buffer[unix.CmsgLen(0)] + return typ, data[:n], nil + } + + return recordTypeApplicationData, data[:n], nil +} + +func (c *Conn) writeKernelRecord(typ uint16, data []byte) (int, error) { + if typ == recordTypeApplicationData { + return c.conn.Write(data) + } + + // cmsg for record type + buffer := make([]byte, unix.CmsgSpace(1)) + cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0])) + cmsg.SetLen(unix.CmsgLen(1)) + buffer[unix.CmsgLen(0)] = byte(typ) + cmsg.Level = unix.SOL_TLS + cmsg.Type = TLS_SET_RECORD_TYPE + + var iov unix.Iovec + iov.Base = &data[0] + iov.SetLen(len(data)) + + var msg unix.Msghdr + msg.Control = &buffer[0] + msg.Controllen = cmsg.Len + msg.Iov = &iov + msg.Iovlen = 1 + + var n int + var err error + ew := c.rawSyscallConn.Write(func(fd uintptr) bool { + n, err = sendmsg(int(fd), &msg, 0) + return err != unix.EAGAIN + }) + if ew != nil { + return 0, ew + } + return n, err +} + +//go:linkname recvmsg golang.org/x/sys/unix.recvmsg +func recvmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error) + +//go:linkname sendmsg golang.org/x/sys/unix.sendmsg +func sendmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error) diff --git a/common/ktls/ktls_prf.go b/common/ktls/ktls_prf.go new file mode 100644 index 00000000..f74a4876 --- /dev/null +++ b/common/ktls/ktls_prf.go @@ -0,0 +1,24 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +import "unsafe" + +//go:linkname cipherSuiteByID github.com/metacubex/utls.cipherSuiteByID +func cipherSuiteByID(id uint16) unsafe.Pointer + +//go:linkname keysFromMasterSecret github.com/metacubex/utls.keysFromMasterSecret +func keysFromMasterSecret(version uint16, suite unsafe.Pointer, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) + +//go:linkname cipherSuiteTLS13ByID github.com/metacubex/utls.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) unsafe.Pointer + +//go:linkname nextTrafficSecret github.com/metacubex/utls.(*cipherSuiteTLS13).nextTrafficSecret +func nextTrafficSecret(cs unsafe.Pointer, trafficSecret []byte) []byte + +//go:linkname trafficKey github.com/metacubex/utls.(*cipherSuiteTLS13).trafficKey +func trafficKey(cs unsafe.Pointer, trafficSecret []byte) (key, iv []byte) diff --git a/common/ktls/ktls_read.go b/common/ktls/ktls_read.go new file mode 100644 index 00000000..45350441 --- /dev/null +++ b/common/ktls/ktls_read.go @@ -0,0 +1,292 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "bytes" + "crypto/tls" + "fmt" + "io" + "net" +) + +func (c *Conn) Read(b []byte) (int, error) { + if !c.kernelRx { + return c.Conn.Read(b) + } + + if len(b) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return 0, nil + } + + c.rawConn.In.Lock() + defer c.rawConn.In.Unlock() + + for c.rawConn.Input.Len() == 0 { + if err := c.readRecord(); err != nil { + return 0, err + } + for c.rawConn.Hand.Len() > 0 { + if err := c.handlePostHandshakeMessage(); err != nil { + return 0, err + } + } + } + + n, _ := c.rawConn.Input.Read(b) + + // If a close-notify alert is waiting, read it so that we can return (n, + // EOF) instead of (n, nil), to signal to the HTTP response reading + // goroutine that the connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would otherwise not observe + // the EOF until its next read, by which time a client goroutine might + // have already tried to reuse the HTTP connection for a new request. + // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 + if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.RawInput.Len() > 0 && + c.rawConn.RawInput.Bytes()[0] == recordTypeAlert { + if err := c.readRecord(); err != nil { + return n, err // will be io.EOF on closeNotify + } + } + + return n, nil +} + +func (c *Conn) readRecord() error { + if *c.rawConn.In.Err != nil { + return *c.rawConn.In.Err + } + + typ, data, err := c.readRawRecord() + if err != nil { + return err + } + + if len(data) > maxPlaintext { + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow)) + } + + // Application Data messages are always protected. + if c.rawConn.In.Cipher == nil && typ == recordTypeApplicationData { + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + //if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { + // This is a state-advancing message: reset the retry count. + // c.retryCount = 0 + //} + + // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. + if *c.rawConn.Vers == tls.VersionTLS13 && typ != recordTypeHandshake && c.rawConn.Hand.Len() > 0 { + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + switch typ { + default: + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage)) + case recordTypeAlert: + //if c.quic != nil { + // return c.rawConn.In.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + //} + if len(data) != 2 { + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if data[1] == alertCloseNotify { + return c.rawConn.In.SetErrorLocked(io.EOF) + } + if *c.rawConn.Vers == tls.VersionTLS13 { + // TLS 1.3 removed warning-level alerts except for alertUserCanceled + // (RFC 8446, ยง 6.1). Since at least one major implementation + // (https://bugs.openjdk.org/browse/JDK-8323517) misuses this alert, + // many TLS stacks now ignore it outright when seen in a TLS 1.3 + // handshake (e.g. BoringSSL, NSS, Rustls). + if data[1] == alertUserCanceled { + // Like TLS 1.2 alertLevelWarning alerts, we drop the record and retry. + return c.retryReadRecord( /*expectChangeCipherSpec*/ ) + } + return c.rawConn.In.SetErrorLocked(&net.OpError{Op: "remote error", Err: tls.AlertError(data[1])}) + } + switch data[0] { + case alertLevelWarning: + // Drop the record on the floor and retry. + return c.retryReadRecord( /*expectChangeCipherSpec*/ ) + case alertLevelError: + return c.rawConn.In.SetErrorLocked(&net.OpError{Op: "remote error", Err: tls.AlertError(data[1])}) + default: + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + case recordTypeChangeCipherSpec: + if len(data) != 1 || data[0] != 1 { + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecodeError)) + } + // Handshake messages are not allowed to fragment across the CCS. + if c.rawConn.Hand.Len() > 0 { + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // In TLS 1.3, change_cipher_spec records are ignored until the + // Finished. See RFC 8446, Appendix D.4. Note that according to Section + // 5, a server can send a ChangeCipherSpec before its ServerHello, when + // c.vers is still unset. That's not useful though and suspicious if the + // server then selects a lower protocol version, so don't allow that. + if *c.rawConn.Vers == tls.VersionTLS13 { + return c.retryReadRecord( /*expectChangeCipherSpec*/ ) + } + // if !expectChangeCipherSpec { + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage)) + //} + //if err := c.rawConn.In.changeCipherSpec(); err != nil { + // return c.rawConn.In.setErrorLocked(c.sendAlert(err.(alert))) + //} + + case recordTypeApplicationData: + // Some OpenSSL servers send empty records in order to randomize the + // CBC RawIV. Ignore a limited number of empty records. + if len(data) == 0 { + return c.retryReadRecord( /*expectChangeCipherSpec*/ ) + } + // Note that data is owned by c.rawInput, following the Next call above, + // to avoid copying the plaintext. This is safe because c.rawInput is + // not read from or written to until c.input is drained. + c.rawConn.Input.Reset(data) + case recordTypeHandshake: + if len(data) == 0 { + return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + c.rawConn.Hand.Write(data) + } + + return nil +} + +//nolint:staticcheck +func (c *Conn) readRawRecord() (typ uint8, data []byte, err error) { + // Read from kernel. + if c.kernelRx { + return c.readKernelRecord() + } + + // Read header, payload. + if err = c.readFromUntil(c.conn, recordHeaderLen); err != nil { + // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify + // is an error, but popular web sites seem to do this, so we accept it + // if and only if at the record boundary. + if err == io.ErrUnexpectedEOF && c.rawConn.RawInput.Len() == 0 { + err = io.EOF + } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.rawConn.In.SetErrorLocked(err) + } + return + } + hdr := c.rawConn.RawInput.Bytes()[:recordHeaderLen] + typ = hdr[0] + + vers := uint16(hdr[1])<<8 | uint16(hdr[2]) + expectedVers := *c.rawConn.Vers + if expectedVers == tls.VersionTLS13 { + // All TLS 1.3 records are expected to have 0x0303 (1.2) after + // the initial hello (RFC 8446 Section 5.1). + expectedVers = tls.VersionTLS12 + } + n := int(hdr[3])<<8 | int(hdr[4]) + if /*c.haveVers && */ vers != expectedVers { + c.sendAlert(alertProtocolVersion) + msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers) + err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(nil, msg)) + return + } + //if !c.haveVers { + // // First message, be extra suspicious: this might not be a TLS + // // client. Bail out before reading a full 'body', if possible. + // // The current max version is 3.3 so if the version is >= 16.0, + // // it's probably not real. + // if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { + // err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) + // return + // } + //} + if *c.rawConn.Vers == tls.VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + msg := fmt.Sprintf("oversized record received with length %d", n) + err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(nil, msg)) + return + } + if err = c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.rawConn.In.SetErrorLocked(err) + } + return + } + + // Process message. + record := c.rawConn.RawInput.Next(recordHeaderLen + n) + data, typ, err = c.rawConn.In.Decrypt(record) + if err != nil { + err = c.rawConn.In.SetErrorLocked(c.sendAlert(uint8(err.(tls.AlertError)))) + return + } + return +} + +// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like +// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. +func (c *Conn) retryReadRecord( /*expectChangeCipherSpec bool*/ ) error { + //c.retryCount++ + //if c.retryCount > maxUselessRecords { + // c.sendAlert(alertUnexpectedMessage) + // return c.in.setErrorLocked(errors.New("tls: too many ignored records")) + //} + return c.readRecord( /*expectChangeCipherSpec*/ ) +} + +// atLeastReader reads from R, stopping with EOF once at least N bytes have been +// read. It is different from an io.LimitedReader in that it doesn't cut short +// the last Read call, and in that it considers an early EOF an error. +type atLeastReader struct { + R io.Reader + N int64 +} + +func (r *atLeastReader) Read(p []byte) (int, error) { + if r.N <= 0 { + return 0, io.EOF + } + n, err := r.R.Read(p) + r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 + if r.N > 0 && err == io.EOF { + return n, io.ErrUnexpectedEOF + } + if r.N <= 0 && err == nil { + return n, io.EOF + } + return n, err +} + +// readFromUntil reads from r into c.rawConn.RawInput until c.rawConn.RawInput contains +// at least n bytes or else returns an error. +func (c *Conn) readFromUntil(r io.Reader, n int) error { + if c.rawConn.RawInput.Len() >= n { + return nil + } + needs := n - c.rawConn.RawInput.Len() + // There might be extra input waiting on the wire. Make a best effort + // attempt to fetch it so that it can be used in (*Conn).Read to + // "predict" closeNotify alerts. + c.rawConn.RawInput.Grow(needs + bytes.MinRead) + _, err := c.rawConn.RawInput.ReadFrom(&atLeastReader{r, int64(needs)}) + return err +} + +func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err tls.RecordHeaderError) { + err.Msg = msg + err.Conn = conn + copy(err.RecordHeader[:], c.rawConn.RawInput.Bytes()) + return err +} diff --git a/common/ktls/ktls_read_wait.go b/common/ktls/ktls_read_wait.go new file mode 100644 index 00000000..8c1a8ff0 --- /dev/null +++ b/common/ktls/ktls_read_wait.go @@ -0,0 +1,41 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" +) + +func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *Conn) WaitReadBuffer() (buffer *buf.Buffer, err error) { + c.rawConn.In.Lock() + defer c.rawConn.In.Unlock() + for c.rawConn.Input.Len() == 0 { + err = c.readRecord() + if err != nil { + return + } + } + buffer = c.readWaitOptions.NewBuffer() + n, err := c.rawConn.Input.Read(buffer.FreeBytes()) + if err != nil { + buffer.Release() + return + } + buffer.Truncate(n) + if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.Input.Len() > 0 && + c.rawConn.RawInput.Bytes()[0] == recordTypeAlert { + _ = c.rawConn.ReadRecord() + } + c.readWaitOptions.PostReturn(buffer) + return +} diff --git a/common/ktls/ktls_stub.go b/common/ktls/ktls_stub.go new file mode 100644 index 00000000..ab4fc0b4 --- /dev/null +++ b/common/ktls/ktls_stub.go @@ -0,0 +1,13 @@ +//go:build !linux || !go1.25 || without_badtls + +package ktls + +import ( + "os" + + aTLS "github.com/sagernet/sing/common/tls" +) + +func NewConn(conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) { + return nil, os.ErrInvalid +} diff --git a/common/ktls/ktls_write.go b/common/ktls/ktls_write.go new file mode 100644 index 00000000..6f04ca29 --- /dev/null +++ b/common/ktls/ktls_write.go @@ -0,0 +1,154 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && go1.25 && !without_badtls + +package ktls + +import ( + "crypto/cipher" + "crypto/tls" + "errors" + "net" +) + +func (c *Conn) Write(b []byte) (int, error) { + if !c.kernelTx { + return c.Conn.Write(b) + } + // interlock with Close below + for { + x := c.rawConn.ActiveCall.Load() + if x&1 != 0 { + return 0, net.ErrClosed + } + if c.rawConn.ActiveCall.CompareAndSwap(x, x+2) { + break + } + } + defer c.rawConn.ActiveCall.Add(-2) + + //if err := c.Conn.HandshakeContext(context.Background()); err != nil { + // return 0, err + //} + + c.rawConn.Out.Lock() + defer c.rawConn.Out.Unlock() + + if err := *c.rawConn.Out.Err; err != nil { + return 0, err + } + + if !c.rawConn.IsHandshakeComplete.Load() { + return 0, tls.AlertError(alertInternalError) + } + + if *c.rawConn.CloseNotifySent { + // return 0, errShutdown + return 0, errors.New("tls: protocol is shutdown") + } + + // TLS 1.0 is susceptible to a chosen-plaintext + // attack when using block mode ciphers due to predictable IVs. + // This can be prevented by splitting each Application Data + // record into two records, effectively randomizing the RawIV. + // + // https://www.openssl.org/~bodo/tls-cbc.txt + // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 + // https://www.imperialviolet.org/2012/01/15/beastfollowup.html + + var m int + if len(b) > 1 && *c.rawConn.Vers == tls.VersionTLS10 { + if _, ok := (*c.rawConn.Out.Cipher).(cipher.BlockMode); ok { + n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) + if err != nil { + return n, c.rawConn.Out.SetErrorLocked(err) + } + m, b = 1, b[1:] + } + } + + n, err := c.writeRecordLocked(recordTypeApplicationData, b) + return n + m, c.rawConn.Out.SetErrorLocked(err) +} + +func (c *Conn) writeRecordLocked(typ uint16, data []byte) (n int, err error) { + if !c.kernelTx { + return c.rawConn.WriteRecordLocked(typ, data) + } + /*for len(data) > 0 { + m := len(data) + if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { + m = maxPayload + } + _, err = c.writeKernelRecord(typ, data[:m]) + if err != nil { + return + } + n += m + data = data[m:] + }*/ + return c.writeKernelRecord(typ, data) +} + +const ( + // tcpMSSEstimate is a conservative estimate of the TCP maximum segment + // size (MSS). A constant is used, rather than querying the kernel for + // the actual MSS, to avoid complexity. The value here is the IPv6 + // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 + // bytes) and a TCP header with timestamps (32 bytes). + tcpMSSEstimate = 1208 + + // recordSizeBoostThreshold is the number of bytes of application data + // sent after which the TLS record size will be increased to the + // maximum. + recordSizeBoostThreshold = 128 * 1024 +) + +func (c *Conn) maxPayloadSizeForWrite(typ uint16) int { + if /*c.config.DynamicRecordSizingDisabled ||*/ typ != recordTypeApplicationData { + return maxPlaintext + } + + if *c.rawConn.PacketsSent >= recordSizeBoostThreshold { + return maxPlaintext + } + + // Subtract TLS overheads to get the maximum payload size. + payloadBytes := tcpMSSEstimate - recordHeaderLen - c.rawConn.Out.ExplicitNonceLen() + if rawCipher := *c.rawConn.Out.Cipher; rawCipher != nil { + switch ciph := rawCipher.(type) { + case cipher.Stream: + payloadBytes -= (*c.rawConn.Out.Mac).Size() + case cipher.AEAD: + payloadBytes -= ciph.Overhead() + /*case cbcMode: + blockSize := ciph.BlockSize() + // The payload must fit in a multiple of blockSize, with + // room for at least one padding byte. + payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 + // The RawMac is appended before padding so affects the + // payload size directly. + payloadBytes -= c.out.mac.Size()*/ + default: + panic("unknown cipher type") + } + } + if *c.rawConn.Vers == tls.VersionTLS13 { + payloadBytes-- // encrypted ContentType + } + + // Allow packet growth in arithmetic progression up to max. + pkt := *c.rawConn.PacketsSent + *c.rawConn.PacketsSent++ + if pkt > 1000 { + return maxPlaintext // avoid overflow in multiply below + } + + n := payloadBytes * int(pkt+1) + if n > maxPlaintext { + n = maxPlaintext + } + return n +} diff --git a/common/tls/client.go b/common/tls/client.go index d45d6173..d42bebd4 100644 --- a/common/tls/client.go +++ b/common/tls/client.go @@ -8,8 +8,10 @@ import ( "os" "github.com/sagernet/sing-box/common/badtls" + "github.com/sagernet/sing-box/common/ktls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" aTLS "github.com/sagernet/sing/common/tls" @@ -45,6 +47,12 @@ func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, e if err != nil { return nil, err } + if kConfig, isKConfig := config.(KTLSCapableConfig); isKConfig && (kConfig.KernelTx() || kConfig.KernelRx()) { + if !C.IsLinux { + return nil, E.New("kTLS is only supported on Linux") + } + return ktls.NewConn(tlsConn, kConfig.KernelTx(), kConfig.KernelRx()) + } readWaitConn, err := badtls.NewReadWaitConn(tlsConn) if err == nil { return readWaitConn, nil diff --git a/common/tls/config.go b/common/tls/config.go index 72bbd194..e3f19239 100644 --- a/common/tls/config.go +++ b/common/tls/config.go @@ -21,6 +21,12 @@ type ( CurveID = tls.CurveID ) +type KTLSCapableConfig interface { + Config + KernelTx() bool + KernelRx() bool +} + func ParseTLSVersion(version string) (uint16, error) { switch version { case "1.0": diff --git a/common/tls/server.go b/common/tls/server.go index bcc5ddfa..7265b6d3 100644 --- a/common/tls/server.go +++ b/common/tls/server.go @@ -6,9 +6,11 @@ import ( "os" "github.com/sagernet/sing-box/common/badtls" + "github.com/sagernet/sing-box/common/ktls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" aTLS "github.com/sagernet/sing/common/tls" ) @@ -29,6 +31,12 @@ func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (C if err != nil { return nil, err } + if kConfig, isKConfig := config.(KTLSCapableConfig); isKConfig && (kConfig.KernelTx() || kConfig.KernelRx()) { + if !C.IsLinux { + return nil, E.New("kTLS is only supported on Linux") + } + return ktls.NewConn(tlsConn, kConfig.KernelTx(), kConfig.KernelRx()) + } readWaitConn, err := badtls.NewReadWaitConn(tlsConn) if err == nil { return readWaitConn, nil diff --git a/common/tls/std_client.go b/common/tls/std_client.go index 0f855228..f801b590 100644 --- a/common/tls/std_client.go +++ b/common/tls/std_client.go @@ -22,6 +22,7 @@ type STDClientConfig struct { fragment bool fragmentFallbackDelay time.Duration recordFragment bool + kernelTx, kernelRx bool } func (c *STDClientConfig) ServerName() string { @@ -52,7 +53,15 @@ func (c *STDClientConfig) Client(conn net.Conn) (Conn, error) { } func (c *STDClientConfig) Clone() Config { - return &STDClientConfig{c.ctx, c.config.Clone(), c.fragment, c.fragmentFallbackDelay, c.recordFragment} + return &STDClientConfig{ + ctx: c.ctx, + config: c.config.Clone(), + fragment: c.fragment, + fragmentFallbackDelay: c.fragmentFallbackDelay, + recordFragment: c.recordFragment, + kernelTx: c.kernelTx, + kernelRx: c.kernelRx, + } } func (c *STDClientConfig) ECHConfigList() []byte { @@ -63,6 +72,14 @@ func (c *STDClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byte c.config.EncryptedClientHelloConfigList = EncryptedClientHelloConfigList } +func (c *STDClientConfig) KernelTx() bool { + return c.kernelTx +} + +func (c *STDClientConfig) KernelRx() bool { + return c.kernelRx +} + func NewSTDClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) { var serverName string if options.ServerName != "" { @@ -146,7 +163,15 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb } tlsConfig.RootCAs = certPool } - stdConfig := &STDClientConfig{ctx, &tlsConfig, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment} + stdConfig := &STDClientConfig{ + ctx: ctx, + config: &tlsConfig, + fragment: options.Fragment, + fragmentFallbackDelay: time.Duration(options.FragmentFallbackDelay), + recordFragment: options.RecordFragment, + kernelTx: options.KernelTx, + kernelRx: options.KernelRx, + } if options.ECH != nil && options.ECH.Enabled { return parseECHClientConfig(ctx, stdConfig, options) } else { diff --git a/common/tls/std_server.go b/common/tls/std_server.go index 82ba71ed..39a959f1 100644 --- a/common/tls/std_server.go +++ b/common/tls/std_server.go @@ -20,15 +20,16 @@ import ( var errInsecureUnused = E.New("tls: insecure unused") type STDServerConfig struct { - config *tls.Config - logger log.Logger - acmeService adapter.SimpleLifecycle - certificate []byte - key []byte - certificatePath string - keyPath string - echKeyPath string - watcher *fswatch.Watcher + config *tls.Config + logger log.Logger + kernelTx, kernelRx bool + acmeService adapter.SimpleLifecycle + certificate []byte + key []byte + certificatePath string + keyPath string + echKeyPath string + watcher *fswatch.Watcher } func (c *STDServerConfig) ServerName() string { @@ -69,10 +70,20 @@ func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) { func (c *STDServerConfig) Clone() Config { return &STDServerConfig{ - config: c.config.Clone(), + config: c.config.Clone(), + kernelTx: c.kernelTx, + kernelRx: c.kernelRx, } } +func (c *STDServerConfig) KernelTx() bool { + return c.kernelTx +} + +func (c *STDServerConfig) KernelRx() bool { + return c.kernelRx +} + func (c *STDServerConfig) Start() error { if c.acmeService != nil { return c.acmeService.Start() @@ -265,6 +276,8 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound return &STDServerConfig{ config: tlsConfig, logger: logger, + kernelTx: options.KernelTx, + kernelRx: options.KernelRx, acmeService: acmeService, certificate: certificate, key: key, diff --git a/common/tls/utls_client.go b/common/tls/utls_client.go index fceb15b8..88ea1ada 100644 --- a/common/tls/utls_client.go +++ b/common/tls/utls_client.go @@ -29,6 +29,8 @@ type UTLSClientConfig struct { fragment bool fragmentFallbackDelay time.Duration recordFragment bool + kernelTx bool + kernelRx bool } func (c *UTLSClientConfig) ServerName() string { @@ -67,7 +69,7 @@ func (c *UTLSClientConfig) SetSessionIDGenerator(generator func(clientHello []by func (c *UTLSClientConfig) Clone() Config { return &UTLSClientConfig{ - c.ctx, c.config.Clone(), c.id, c.fragment, c.fragmentFallbackDelay, c.recordFragment, + c.ctx, c.config.Clone(), c.id, c.fragment, c.fragmentFallbackDelay, c.recordFragment, c.kernelTx, c.kernelRx, } } @@ -79,6 +81,14 @@ func (c *UTLSClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byt c.config.EncryptedClientHelloConfigList = EncryptedClientHelloConfigList } +func (c *UTLSClientConfig) KernelTx() bool { + return c.kernelTx +} + +func (c *UTLSClientConfig) KernelRx() bool { + return c.kernelRx +} + type utlsConnWrapper struct { *utls.UConn } @@ -214,7 +224,12 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out if err != nil { return nil, err } - uConfig := &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment} + uConfig := &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment, options.KernelTx, options.KernelRx} + if uConfig.kernelTx || uConfig.kernelRx { + if options.Reality != nil && options.Reality.Enabled { + return nil, E.New("Reality is conflict with kTLS") + } + } if options.ECH != nil && options.ECH.Enabled { if options.Reality != nil && options.Reality.Enabled { return nil, E.New("Reality is conflict with ECH") diff --git a/go.mod b/go.mod index 56ffd5e0..036ae027 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.23.1 require ( github.com/anytls/sing-anytls v0.0.8 + github.com/blang/semver/v4 v4.0.0 github.com/caddyserver/certmagic v0.23.0 github.com/coder/websocket v1.8.13 github.com/cretz/bine v0.2.0 @@ -27,7 +28,7 @@ require ( github.com/sagernet/gomobile v0.1.8 github.com/sagernet/gvisor v0.0.0-20250822052253-5558536cf237 github.com/sagernet/quic-go v0.52.0-beta.1 - github.com/sagernet/sing v0.7.8-0.20250906004629-421beb6473ea + github.com/sagernet/sing v0.7.8-0.20250907125815-3d24f9b5ff7c github.com/sagernet/sing-mux v0.3.3 github.com/sagernet/sing-quic v0.5.1 github.com/sagernet/sing-shadowsocks v0.2.8 diff --git a/go.sum b/go.sum index 9fa2811f..db3b7df0 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/anytls/sing-anytls v0.0.8 h1:1u/fnH1HoeeMV5mX7/eUOjLBvPdkd1UJRmXiRi6V github.com/anytls/sing-anytls v0.0.8/go.mod h1:7rjN6IukwysmdusYsrV51Fgu1uW6vsrdd6ctjnEAln8= github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= +github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/caddyserver/certmagic v0.23.0 h1:CfpZ/50jMfG4+1J/u2LV6piJq4HOfO6ppOnOf7DkFEU= github.com/caddyserver/certmagic v0.23.0/go.mod h1:9mEZIWqqWoI+Gf+4Trh04MOVPD0tGSxtqsxg87hAIH4= github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA= @@ -167,8 +169,8 @@ github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/l github.com/sagernet/quic-go v0.52.0-beta.1 h1:hWkojLg64zjV+MJOvJU/kOeWndm3tiEfBLx5foisszs= github.com/sagernet/quic-go v0.52.0-beta.1/go.mod h1:OV+V5kEBb8kJS7k29MzDu6oj9GyMc7HA07sE1tedxz4= github.com/sagernet/sing v0.6.9/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= -github.com/sagernet/sing v0.7.8-0.20250906004629-421beb6473ea h1:CDRl4q5Y2dM6MQE1MwukhrxbObfK/rj0QtK7vnJhST0= -github.com/sagernet/sing v0.7.8-0.20250906004629-421beb6473ea/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.7.8-0.20250907125815-3d24f9b5ff7c h1:7jbIoVt3p1vlgQoyWIlzRvwInNwDxe6+tSJME07LTmI= +github.com/sagernet/sing v0.7.8-0.20250907125815-3d24f9b5ff7c/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-mux v0.3.3 h1:YFgt9plMWzH994BMZLmyKL37PdIVaIilwP0Jg+EcLfw= github.com/sagernet/sing-mux v0.3.3/go.mod h1:pht8iFY4c9Xltj7rhVd208npkNaeCxzyXCgulDPLUDA= github.com/sagernet/sing-quic v0.5.1 h1:o+mX/schfy6fbbU2rnb6ouUYOL+iUBjA4jOZqyIvDsU= diff --git a/option/tls.go b/option/tls.go index 1c09527c..db51ed1a 100644 --- a/option/tls.go +++ b/option/tls.go @@ -14,6 +14,8 @@ type InboundTLSOptions struct { CertificatePath string `json:"certificate_path,omitempty"` Key badoption.Listable[string] `json:"key,omitempty"` KeyPath string `json:"key_path,omitempty"` + KernelTx bool `json:"kernel_tx,omitempty"` + KernelRx bool `json:"kernel_rx,omitempty"` ACME *InboundACMEOptions `json:"acme,omitempty"` ECH *InboundECHOptions `json:"ech,omitempty"` Reality *InboundRealityOptions `json:"reality,omitempty"` @@ -50,6 +52,8 @@ type OutboundTLSOptions struct { Fragment bool `json:"fragment,omitempty"` FragmentFallbackDelay badoption.Duration `json:"fragment_fallback_delay,omitempty"` RecordFragment bool `json:"record_fragment,omitempty"` + KernelTx bool `json:"kernel_tx,omitempty"` + KernelRx bool `json:"kernel_rx,omitempty"` ECH *OutboundECHOptions `json:"ech,omitempty"` UTLS *OutboundUTLSOptions `json:"utls,omitempty"` Reality *OutboundRealityOptions `json:"reality,omitempty"` diff --git a/release/local/debug.sh b/release/local/debug.sh index d6bd3057..d649bed4 100755 --- a/release/local/debug.sh +++ b/release/local/debug.sh @@ -13,7 +13,7 @@ pushd $PROJECT git fetch git reset FETCH_HEAD --hard git clean -fdx -go install -v -trimpath -ldflags "-s -w -buildid=" -tags with_quic,with_acme,debug ./cmd/sing-box +go install -v -trimpath -ldflags "-s -w -buildid= -checklinkname=0" -tags with_quic,with_acme,debug ./cmd/sing-box popd sudo systemctl stop sing-box diff --git a/release/local/install.sh b/release/local/install.sh index 24e9d006..3aa3d976 100755 --- a/release/local/install.sh +++ b/release/local/install.sh @@ -10,7 +10,7 @@ DIR=$(dirname "$0") PROJECT=$DIR/../.. pushd $PROJECT -go install -v -trimpath -ldflags "-s -w -buildid=" -tags with_quic,with_wireguard,with_acme ./cmd/sing-box +go install -v -trimpath -ldflags "-s -w -buildid= -checklinkname=0" -tags with_quic,with_wireguard,with_acme ./cmd/sing-box popd sudo cp $(go env GOPATH)/bin/sing-box /usr/local/bin/ diff --git a/release/local/reinstall.sh b/release/local/reinstall.sh index 71d07109..04cef16b 100755 --- a/release/local/reinstall.sh +++ b/release/local/reinstall.sh @@ -10,7 +10,7 @@ DIR=$(dirname "$0") PROJECT=$DIR/../.. pushd $PROJECT -go install -v -trimpath -ldflags "-s -w -buildid=" -tags with_quic,with_wireguard,with_acme ./cmd/sing-box +go install -v -trimpath -ldflags "-s -w -buildid= -checklinkname=0" -tags with_quic,with_wireguard,with_acme ./cmd/sing-box popd sudo systemctl stop sing-box diff --git a/route/conn.go b/route/conn.go index 3d2b8f05..e6fbaafd 100644 --- a/route/conn.go +++ b/route/conn.go @@ -102,6 +102,8 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co m.connections.Remove(element) }) var done atomic.Bool + m.preConnectionCopy(ctx, conn, remoteConn, false, &done, onClose) + m.preConnectionCopy(ctx, remoteConn, conn, true, &done, onClose) go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose) go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose) } @@ -224,6 +226,24 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) } +func (m *ConnectionManager) preConnectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { + if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destination); isEarlyConn && earlyConn.NeedHandshake() { + err := m.connectionCopyEarly(source, destination) + if err != nil { + if done.Swap(true) { + onClose(err) + } + common.Close(source, destination) + if !direction { + m.logger.ErrorContext(ctx, "connection upload handshake: ", err) + } else { + m.logger.ErrorContext(ctx, "connection download handshake: ", err) + } + return + } + } +} + func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { var ( sourceReader io.Reader = source @@ -262,21 +282,7 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, } break } - if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() { - err := m.connectionCopyEarly(source, destination) - if err != nil { - if done.Swap(true) { - onClose(err) - } - common.Close(source, destination) - if !direction { - m.logger.ErrorContext(ctx, "connection upload handshake: ", err) - } else { - m.logger.ErrorContext(ctx, "connection download handshake: ", err) - } - return - } - } + _, err := bufio.CopyWithCounters(destinationWriter, sourceReader, source, readCounters, writeCounters, bufio.DefaultIncreaseBufferAfter, bufio.DefaultBatchSize) if err != nil { common.Close(source, destination) diff --git a/transport/trojan/protocol.go b/transport/trojan/protocol.go index e13dda67..7c12201e 100644 --- a/transport/trojan/protocol.go +++ b/transport/trojan/protocol.go @@ -83,6 +83,14 @@ func (c *ClientConn) Upstream() any { return c.ExtendedConn } +func (c *ClientConn) ReaderReplaceable() bool { + return c.headerWritten +} + +func (c *ClientConn) WriterReplaceable() bool { + return c.headerWritten +} + type ClientPacketConn struct { net.Conn access sync.Mutex