From 4cd9555df6587e816e28d7594f654114518d6ff8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 14 Aug 2025 01:48:42 +0800 Subject: [PATCH] Fix atomic pointer usages --- common/dialer/tfo.go | 80 ++++++++++++++++++-------------- experimental/v2rayapi/stats.go | 2 +- go.mod | 4 +- go.sum | 8 ++-- transport/v2raygrpc/client.go | 17 ++++--- transport/v2rayquic/client.go | 15 +++--- transport/v2raywebsocket/conn.go | 52 ++++++++++++--------- 7 files changed, 100 insertions(+), 78 deletions(-) diff --git a/common/dialer/tfo.go b/common/dialer/tfo.go index 8ea59ca6..cd1a2a22 100644 --- a/common/dialer/tfo.go +++ b/common/dialer/tfo.go @@ -10,6 +10,8 @@ import ( "sync" "time" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -22,7 +24,7 @@ type slowOpenConn struct { ctx context.Context network string destination M.Socksaddr - conn net.Conn + conn atomic.Pointer[net.TCPConn] create chan struct{} done chan struct{} access sync.Mutex @@ -50,22 +52,25 @@ func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, des } func (c *slowOpenConn) Read(b []byte) (n int, err error) { - if c.conn == nil { - select { - case <-c.create: - if c.err != nil { - return 0, c.err - } - case <-c.done: - return 0, os.ErrClosed - } + conn := c.conn.Load() + if conn != nil { + return conn.Read(b) + } + select { + case <-c.create: + if c.err != nil { + return 0, c.err + } + return c.conn.Load().Read(b) + case <-c.done: + return 0, os.ErrClosed } - return c.conn.Read(b) } func (c *slowOpenConn) Write(b []byte) (n int, err error) { - if c.conn != nil { - return c.conn.Write(b) + tcpConn := c.conn.Load() + if tcpConn != nil { + return tcpConn.Write(b) } c.access.Lock() defer c.access.Unlock() @@ -74,7 +79,7 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) { if c.err != nil { return 0, c.err } - return c.conn.Write(b) + return c.conn.Load().Write(b) case <-c.done: return 0, os.ErrClosed default: @@ -83,7 +88,7 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) { if err != nil { c.err = err } else { - c.conn = conn + c.conn.Store(conn.(*net.TCPConn)) } n = len(b) close(c.create) @@ -93,70 +98,77 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) { func (c *slowOpenConn) Close() error { c.closeOnce.Do(func() { close(c.done) - if c.conn != nil { - c.conn.Close() + conn := c.conn.Load() + if conn != nil { + conn.Close() } }) return nil } func (c *slowOpenConn) LocalAddr() net.Addr { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return M.Socksaddr{} } - return c.conn.LocalAddr() + return conn.LocalAddr() } func (c *slowOpenConn) RemoteAddr() net.Addr { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return M.Socksaddr{} } - return c.conn.RemoteAddr() + return conn.RemoteAddr() } func (c *slowOpenConn) SetDeadline(t time.Time) error { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return os.ErrInvalid } - return c.conn.SetDeadline(t) + return conn.SetDeadline(t) } func (c *slowOpenConn) SetReadDeadline(t time.Time) error { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return os.ErrInvalid } - return c.conn.SetReadDeadline(t) + return conn.SetReadDeadline(t) } func (c *slowOpenConn) SetWriteDeadline(t time.Time) error { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return os.ErrInvalid } - return c.conn.SetWriteDeadline(t) + return conn.SetWriteDeadline(t) } func (c *slowOpenConn) Upstream() any { - return c.conn + return common.PtrOrNil(c.conn.Load()) } func (c *slowOpenConn) ReaderReplaceable() bool { - return c.conn != nil + return c.conn.Load() != nil } func (c *slowOpenConn) WriterReplaceable() bool { - return c.conn != nil + return c.conn.Load() != nil } func (c *slowOpenConn) LazyHeadroom() bool { - return c.conn == nil + return c.conn.Load() == nil } func (c *slowOpenConn) NeedHandshake() bool { - return c.conn == nil + return c.conn.Load() == nil } func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { select { case <-c.create: if c.err != nil { @@ -166,5 +178,5 @@ func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) { return 0, c.err } } - return bufio.Copy(w, c.conn) + return bufio.Copy(w, c.conn.Load()) } diff --git a/experimental/v2rayapi/stats.go b/experimental/v2rayapi/stats.go index a85e190f..16d44114 100644 --- a/experimental/v2rayapi/stats.go +++ b/experimental/v2rayapi/stats.go @@ -115,7 +115,7 @@ func (s *StatsService) RoutedPacketConnection(ctx context.Context, conn N.Packet writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink")) } s.access.Unlock() - return bufio.NewInt64CounterPacketConn(conn, readCounter, writeCounter) + return bufio.NewInt64CounterPacketConn(conn, readCounter, nil, writeCounter, nil) } func (s *StatsService) GetStats(ctx context.Context, request *GetStatsRequest) (*GetStatsResponse, error) { diff --git a/go.mod b/go.mod index 67425190..ea43a4e1 100644 --- a/go.mod +++ b/go.mod @@ -27,9 +27,9 @@ require ( github.com/sagernet/gomobile v0.1.8 github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb github.com/sagernet/quic-go v0.52.0-beta.1 - github.com/sagernet/sing v0.7.5 + github.com/sagernet/sing v0.7.6-0.20250813174432-bf460591becd github.com/sagernet/sing-mux v0.3.3 - github.com/sagernet/sing-quic v0.5.0-beta.3 + github.com/sagernet/sing-quic v0.5.0 github.com/sagernet/sing-shadowsocks v0.2.8 github.com/sagernet/sing-shadowsocks2 v0.2.1 github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 diff --git a/go.sum b/go.sum index 1f926629..1d21f925 100644 --- a/go.sum +++ b/go.sum @@ -167,12 +167,12 @@ github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/l github.com/sagernet/quic-go v0.52.0-beta.1 h1:hWkojLg64zjV+MJOvJU/kOeWndm3tiEfBLx5foisszs= github.com/sagernet/quic-go v0.52.0-beta.1/go.mod h1:OV+V5kEBb8kJS7k29MzDu6oj9GyMc7HA07sE1tedxz4= github.com/sagernet/sing v0.6.9/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= -github.com/sagernet/sing v0.7.5 h1:gNMwZCLPqR+4e0g6dwi0sSsrvOmoMjpZgqxKsuJZatc= -github.com/sagernet/sing v0.7.5/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.7.6-0.20250813174432-bf460591becd h1:wfu8wOtIQ+BOYWDDh3n6Ue47J3Vac8IIKjaGbGYGB6k= +github.com/sagernet/sing v0.7.6-0.20250813174432-bf460591becd/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-mux v0.3.3 h1:YFgt9plMWzH994BMZLmyKL37PdIVaIilwP0Jg+EcLfw= github.com/sagernet/sing-mux v0.3.3/go.mod h1:pht8iFY4c9Xltj7rhVd208npkNaeCxzyXCgulDPLUDA= -github.com/sagernet/sing-quic v0.5.0-beta.3 h1:X/acRNsqQNfDlmwE7SorHfaZiny5e67hqIzM/592ric= -github.com/sagernet/sing-quic v0.5.0-beta.3/go.mod h1:SAv/qdeDN+75msGG5U5ZIwG+3Ua50jVIKNrRSY8pkx0= +github.com/sagernet/sing-quic v0.5.0 h1:jNLIyVk24lFPvu8A4x+ZNEnZdI+Tg1rp7eCJ6v0Csak= +github.com/sagernet/sing-quic v0.5.0/go.mod h1:SAv/qdeDN+75msGG5U5ZIwG+3Ua50jVIKNrRSY8pkx0= github.com/sagernet/sing-shadowsocks v0.2.8 h1:PURj5PRoAkqeHh2ZW205RWzN9E9RtKCVCzByXruQWfE= github.com/sagernet/sing-shadowsocks v0.2.8/go.mod h1:lo7TWEMDcN5/h5B8S0ew+r78ZODn6SwVaFhvB6H+PTI= github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnqqs2gQ2/Qioo= diff --git a/transport/v2raygrpc/client.go b/transport/v2raygrpc/client.go index e649aa8f..2bbaa627 100644 --- a/transport/v2raygrpc/client.go +++ b/transport/v2raygrpc/client.go @@ -4,6 +4,7 @@ import ( "context" "net" "sync" + "sync/atomic" "time" "github.com/sagernet/sing-box/adapter" @@ -29,7 +30,7 @@ type Client struct { serverAddr string serviceName string dialOptions []grpc.DialOption - conn *grpc.ClientConn + conn atomic.Pointer[grpc.ClientConn] connAccess sync.Mutex } @@ -74,13 +75,13 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt } func (c *Client) connect() (*grpc.ClientConn, error) { - conn := c.conn + conn := c.conn.Load() if conn != nil && conn.GetState() != connectivity.Shutdown { return conn, nil } c.connAccess.Lock() defer c.connAccess.Unlock() - conn = c.conn + conn = c.conn.Load() if conn != nil && conn.GetState() != connectivity.Shutdown { return conn, nil } @@ -89,7 +90,7 @@ func (c *Client) connect() (*grpc.ClientConn, error) { if err != nil { return nil, err } - c.conn = conn + c.conn.Store(conn) return conn, nil } @@ -109,11 +110,9 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { } func (c *Client) Close() error { - c.connAccess.Lock() - defer c.connAccess.Unlock() - if c.conn != nil { - c.conn.Close() - c.conn = nil + conn := c.conn.Swap(nil) + if conn != nil { + conn.Close() } return nil } diff --git a/transport/v2rayquic/client.go b/transport/v2rayquic/client.go index f9556211..3d1d916e 100644 --- a/transport/v2rayquic/client.go +++ b/transport/v2rayquic/client.go @@ -15,6 +15,7 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-quic" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -29,7 +30,7 @@ type Client struct { tlsConfig tls.Config quicConfig *quic.Config connAccess sync.Mutex - conn quic.Connection + conn atomic.TypedValue[quic.Connection] rawConn net.Conn } @@ -50,13 +51,13 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt } func (c *Client) offer() (quic.Connection, error) { - conn := c.conn + conn := c.conn.Load() if conn != nil && !common.Done(conn.Context()) { return conn, nil } c.connAccess.Lock() defer c.connAccess.Unlock() - conn = c.conn + conn = c.conn.Load() if conn != nil && !common.Done(conn.Context()) { return conn, nil } @@ -78,7 +79,7 @@ func (c *Client) offerNew() (quic.Connection, error) { packetConn.Close() return nil, err } - c.conn = quicConn + c.conn.Store(quicConn) c.rawConn = udpConn return quicConn, nil } @@ -98,13 +99,13 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { func (c *Client) Close() error { c.connAccess.Lock() defer c.connAccess.Unlock() - if c.conn != nil { - c.conn.CloseWithError(0, "") + conn := c.conn.Swap(nil) + if conn != nil { + conn.CloseWithError(0, "") } if c.rawConn != nil { c.rawConn.Close() } - c.conn = nil c.rawConn = nil return nil } diff --git a/transport/v2raywebsocket/conn.go b/transport/v2raywebsocket/conn.go index 7f347dc9..009cadd8 100644 --- a/transport/v2raywebsocket/conn.go +++ b/transport/v2raywebsocket/conn.go @@ -8,6 +8,7 @@ import ( "net" "os" "sync" + "sync/atomic" "time" C "github.com/sagernet/sing-box/constant" @@ -135,20 +136,22 @@ func (c *WebsocketConn) Upstream() any { type EarlyWebsocketConn struct { *Client ctx context.Context - conn *WebsocketConn + conn atomic.Pointer[WebsocketConn] access sync.Mutex create chan struct{} err error } func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { <-c.create if c.err != nil { return 0, c.err } + conn = c.conn.Load() } - return wrapWsError0(c.conn.Read(b)) + return wrapWsError0(conn.Read(b)) } func (c *EarlyWebsocketConn) writeRequest(content []byte) error { @@ -187,21 +190,23 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error { return err } } - c.conn = conn + c.conn.Store(conn) return nil } func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { - if c.conn != nil { - return wrapWsError0(c.conn.Write(b)) + conn := c.conn.Load() + if conn != nil { + return wrapWsError0(conn.Write(b)) } c.access.Lock() defer c.access.Unlock() + conn = c.conn.Load() if c.err != nil { return 0, c.err } - if c.conn != nil { - return wrapWsError0(c.conn.Write(b)) + if conn != nil { + return wrapWsError0(conn.Write(b)) } err = c.writeRequest(b) c.err = err @@ -213,17 +218,19 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { } func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error { - if c.conn != nil { - return wrapWsError(c.conn.WriteBuffer(buffer)) + conn := c.conn.Load() + if conn != nil { + return wrapWsError(conn.WriteBuffer(buffer)) } c.access.Lock() defer c.access.Unlock() - if c.conn != nil { - return wrapWsError(c.conn.WriteBuffer(buffer)) - } if c.err != nil { return c.err } + conn = c.conn.Load() + if conn != nil { + return wrapWsError(conn.WriteBuffer(buffer)) + } err := c.writeRequest(buffer.Bytes()) c.err = err close(c.create) @@ -231,24 +238,27 @@ func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error { } func (c *EarlyWebsocketConn) Close() error { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return nil } - return c.conn.Close() + return conn.Close() } func (c *EarlyWebsocketConn) LocalAddr() net.Addr { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return M.Socksaddr{} } - return c.conn.LocalAddr() + return conn.LocalAddr() } func (c *EarlyWebsocketConn) RemoteAddr() net.Addr { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return M.Socksaddr{} } - return c.conn.RemoteAddr() + return conn.RemoteAddr() } func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error { @@ -268,11 +278,11 @@ func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool { } func (c *EarlyWebsocketConn) Upstream() any { - return common.PtrOrNil(c.conn) + return common.PtrOrNil(c.conn.Load()) } func (c *EarlyWebsocketConn) LazyHeadroom() bool { - return c.conn == nil + return c.conn.Load() == nil } func wrapWsError(err error) error {