Fix atomic pointer usages

This commit is contained in:
世界 2025-08-14 01:48:42 +08:00
parent 378e39f70c
commit 4cd9555df6
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
7 changed files with 100 additions and 78 deletions

View File

@ -10,6 +10,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -22,7 +24,7 @@ type slowOpenConn struct {
ctx context.Context ctx context.Context
network string network string
destination M.Socksaddr destination M.Socksaddr
conn net.Conn conn atomic.Pointer[net.TCPConn]
create chan struct{} create chan struct{}
done chan struct{} done chan struct{}
access sync.Mutex access sync.Mutex
@ -50,22 +52,25 @@ func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, des
} }
func (c *slowOpenConn) Read(b []byte) (n int, err error) { func (c *slowOpenConn) Read(b []byte) (n int, err error) {
if c.conn == nil { conn := c.conn.Load()
select { if conn != nil {
case <-c.create: return conn.Read(b)
if c.err != nil { }
return 0, c.err select {
} case <-c.create:
case <-c.done: if c.err != nil {
return 0, os.ErrClosed return 0, c.err
} }
return c.conn.Load().Read(b)
case <-c.done:
return 0, os.ErrClosed
} }
return c.conn.Read(b)
} }
func (c *slowOpenConn) Write(b []byte) (n int, err error) { func (c *slowOpenConn) Write(b []byte) (n int, err error) {
if c.conn != nil { tcpConn := c.conn.Load()
return c.conn.Write(b) if tcpConn != nil {
return tcpConn.Write(b)
} }
c.access.Lock() c.access.Lock()
defer c.access.Unlock() defer c.access.Unlock()
@ -74,7 +79,7 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
if c.err != nil { if c.err != nil {
return 0, c.err return 0, c.err
} }
return c.conn.Write(b) return c.conn.Load().Write(b)
case <-c.done: case <-c.done:
return 0, os.ErrClosed return 0, os.ErrClosed
default: default:
@ -83,7 +88,7 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
if err != nil { if err != nil {
c.err = err c.err = err
} else { } else {
c.conn = conn c.conn.Store(conn.(*net.TCPConn))
} }
n = len(b) n = len(b)
close(c.create) close(c.create)
@ -93,70 +98,77 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
func (c *slowOpenConn) Close() error { func (c *slowOpenConn) Close() error {
c.closeOnce.Do(func() { c.closeOnce.Do(func() {
close(c.done) close(c.done)
if c.conn != nil { conn := c.conn.Load()
c.conn.Close() if conn != nil {
conn.Close()
} }
}) })
return nil return nil
} }
func (c *slowOpenConn) LocalAddr() net.Addr { func (c *slowOpenConn) LocalAddr() net.Addr {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
return M.Socksaddr{} return M.Socksaddr{}
} }
return c.conn.LocalAddr() return conn.LocalAddr()
} }
func (c *slowOpenConn) RemoteAddr() net.Addr { func (c *slowOpenConn) RemoteAddr() net.Addr {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
return M.Socksaddr{} return M.Socksaddr{}
} }
return c.conn.RemoteAddr() return conn.RemoteAddr()
} }
func (c *slowOpenConn) SetDeadline(t time.Time) error { func (c *slowOpenConn) SetDeadline(t time.Time) error {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
return os.ErrInvalid return os.ErrInvalid
} }
return c.conn.SetDeadline(t) return conn.SetDeadline(t)
} }
func (c *slowOpenConn) SetReadDeadline(t time.Time) error { func (c *slowOpenConn) SetReadDeadline(t time.Time) error {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
return os.ErrInvalid return os.ErrInvalid
} }
return c.conn.SetReadDeadline(t) return conn.SetReadDeadline(t)
} }
func (c *slowOpenConn) SetWriteDeadline(t time.Time) error { func (c *slowOpenConn) SetWriteDeadline(t time.Time) error {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
return os.ErrInvalid return os.ErrInvalid
} }
return c.conn.SetWriteDeadline(t) return conn.SetWriteDeadline(t)
} }
func (c *slowOpenConn) Upstream() any { func (c *slowOpenConn) Upstream() any {
return c.conn return common.PtrOrNil(c.conn.Load())
} }
func (c *slowOpenConn) ReaderReplaceable() bool { func (c *slowOpenConn) ReaderReplaceable() bool {
return c.conn != nil return c.conn.Load() != nil
} }
func (c *slowOpenConn) WriterReplaceable() bool { func (c *slowOpenConn) WriterReplaceable() bool {
return c.conn != nil return c.conn.Load() != nil
} }
func (c *slowOpenConn) LazyHeadroom() bool { func (c *slowOpenConn) LazyHeadroom() bool {
return c.conn == nil return c.conn.Load() == nil
} }
func (c *slowOpenConn) NeedHandshake() bool { func (c *slowOpenConn) NeedHandshake() bool {
return c.conn == nil return c.conn.Load() == nil
} }
func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) { func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
select { select {
case <-c.create: case <-c.create:
if c.err != nil { if c.err != nil {
@ -166,5 +178,5 @@ func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
return 0, c.err return 0, c.err
} }
} }
return bufio.Copy(w, c.conn) return bufio.Copy(w, c.conn.Load())
} }

View File

@ -115,7 +115,7 @@ func (s *StatsService) RoutedPacketConnection(ctx context.Context, conn N.Packet
writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink")) writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink"))
} }
s.access.Unlock() s.access.Unlock()
return bufio.NewInt64CounterPacketConn(conn, readCounter, writeCounter) return bufio.NewInt64CounterPacketConn(conn, readCounter, nil, writeCounter, nil)
} }
func (s *StatsService) GetStats(ctx context.Context, request *GetStatsRequest) (*GetStatsResponse, error) { func (s *StatsService) GetStats(ctx context.Context, request *GetStatsRequest) (*GetStatsResponse, error) {

4
go.mod
View File

@ -27,9 +27,9 @@ require (
github.com/sagernet/gomobile v0.1.8 github.com/sagernet/gomobile v0.1.8
github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb
github.com/sagernet/quic-go v0.52.0-beta.1 github.com/sagernet/quic-go v0.52.0-beta.1
github.com/sagernet/sing v0.7.5 github.com/sagernet/sing v0.7.6-0.20250813174432-bf460591becd
github.com/sagernet/sing-mux v0.3.3 github.com/sagernet/sing-mux v0.3.3
github.com/sagernet/sing-quic v0.5.0-beta.3 github.com/sagernet/sing-quic v0.5.0
github.com/sagernet/sing-shadowsocks v0.2.8 github.com/sagernet/sing-shadowsocks v0.2.8
github.com/sagernet/sing-shadowsocks2 v0.2.1 github.com/sagernet/sing-shadowsocks2 v0.2.1
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11

8
go.sum
View File

@ -167,12 +167,12 @@ github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/l
github.com/sagernet/quic-go v0.52.0-beta.1 h1:hWkojLg64zjV+MJOvJU/kOeWndm3tiEfBLx5foisszs= github.com/sagernet/quic-go v0.52.0-beta.1 h1:hWkojLg64zjV+MJOvJU/kOeWndm3tiEfBLx5foisszs=
github.com/sagernet/quic-go v0.52.0-beta.1/go.mod h1:OV+V5kEBb8kJS7k29MzDu6oj9GyMc7HA07sE1tedxz4= github.com/sagernet/quic-go v0.52.0-beta.1/go.mod h1:OV+V5kEBb8kJS7k29MzDu6oj9GyMc7HA07sE1tedxz4=
github.com/sagernet/sing v0.6.9/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing v0.6.9/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.7.5 h1:gNMwZCLPqR+4e0g6dwi0sSsrvOmoMjpZgqxKsuJZatc= github.com/sagernet/sing v0.7.6-0.20250813174432-bf460591becd h1:wfu8wOtIQ+BOYWDDh3n6Ue47J3Vac8IIKjaGbGYGB6k=
github.com/sagernet/sing v0.7.5/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing v0.7.6-0.20250813174432-bf460591becd/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing-mux v0.3.3 h1:YFgt9plMWzH994BMZLmyKL37PdIVaIilwP0Jg+EcLfw= github.com/sagernet/sing-mux v0.3.3 h1:YFgt9plMWzH994BMZLmyKL37PdIVaIilwP0Jg+EcLfw=
github.com/sagernet/sing-mux v0.3.3/go.mod h1:pht8iFY4c9Xltj7rhVd208npkNaeCxzyXCgulDPLUDA= github.com/sagernet/sing-mux v0.3.3/go.mod h1:pht8iFY4c9Xltj7rhVd208npkNaeCxzyXCgulDPLUDA=
github.com/sagernet/sing-quic v0.5.0-beta.3 h1:X/acRNsqQNfDlmwE7SorHfaZiny5e67hqIzM/592ric= github.com/sagernet/sing-quic v0.5.0 h1:jNLIyVk24lFPvu8A4x+ZNEnZdI+Tg1rp7eCJ6v0Csak=
github.com/sagernet/sing-quic v0.5.0-beta.3/go.mod h1:SAv/qdeDN+75msGG5U5ZIwG+3Ua50jVIKNrRSY8pkx0= github.com/sagernet/sing-quic v0.5.0/go.mod h1:SAv/qdeDN+75msGG5U5ZIwG+3Ua50jVIKNrRSY8pkx0=
github.com/sagernet/sing-shadowsocks v0.2.8 h1:PURj5PRoAkqeHh2ZW205RWzN9E9RtKCVCzByXruQWfE= github.com/sagernet/sing-shadowsocks v0.2.8 h1:PURj5PRoAkqeHh2ZW205RWzN9E9RtKCVCzByXruQWfE=
github.com/sagernet/sing-shadowsocks v0.2.8/go.mod h1:lo7TWEMDcN5/h5B8S0ew+r78ZODn6SwVaFhvB6H+PTI= github.com/sagernet/sing-shadowsocks v0.2.8/go.mod h1:lo7TWEMDcN5/h5B8S0ew+r78ZODn6SwVaFhvB6H+PTI=
github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnqqs2gQ2/Qioo= github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnqqs2gQ2/Qioo=

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
@ -29,7 +30,7 @@ type Client struct {
serverAddr string serverAddr string
serviceName string serviceName string
dialOptions []grpc.DialOption dialOptions []grpc.DialOption
conn *grpc.ClientConn conn atomic.Pointer[grpc.ClientConn]
connAccess sync.Mutex connAccess sync.Mutex
} }
@ -74,13 +75,13 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
} }
func (c *Client) connect() (*grpc.ClientConn, error) { func (c *Client) connect() (*grpc.ClientConn, error) {
conn := c.conn conn := c.conn.Load()
if conn != nil && conn.GetState() != connectivity.Shutdown { if conn != nil && conn.GetState() != connectivity.Shutdown {
return conn, nil return conn, nil
} }
c.connAccess.Lock() c.connAccess.Lock()
defer c.connAccess.Unlock() defer c.connAccess.Unlock()
conn = c.conn conn = c.conn.Load()
if conn != nil && conn.GetState() != connectivity.Shutdown { if conn != nil && conn.GetState() != connectivity.Shutdown {
return conn, nil return conn, nil
} }
@ -89,7 +90,7 @@ func (c *Client) connect() (*grpc.ClientConn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.conn = conn c.conn.Store(conn)
return conn, nil return conn, nil
} }
@ -109,11 +110,9 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
} }
func (c *Client) Close() error { func (c *Client) Close() error {
c.connAccess.Lock() conn := c.conn.Swap(nil)
defer c.connAccess.Unlock() if conn != nil {
if c.conn != nil { conn.Close()
c.conn.Close()
c.conn = nil
} }
return nil return nil
} }

View File

@ -15,6 +15,7 @@ import (
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-quic" "github.com/sagernet/sing-quic"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -29,7 +30,7 @@ type Client struct {
tlsConfig tls.Config tlsConfig tls.Config
quicConfig *quic.Config quicConfig *quic.Config
connAccess sync.Mutex connAccess sync.Mutex
conn quic.Connection conn atomic.TypedValue[quic.Connection]
rawConn net.Conn rawConn net.Conn
} }
@ -50,13 +51,13 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
} }
func (c *Client) offer() (quic.Connection, error) { func (c *Client) offer() (quic.Connection, error) {
conn := c.conn conn := c.conn.Load()
if conn != nil && !common.Done(conn.Context()) { if conn != nil && !common.Done(conn.Context()) {
return conn, nil return conn, nil
} }
c.connAccess.Lock() c.connAccess.Lock()
defer c.connAccess.Unlock() defer c.connAccess.Unlock()
conn = c.conn conn = c.conn.Load()
if conn != nil && !common.Done(conn.Context()) { if conn != nil && !common.Done(conn.Context()) {
return conn, nil return conn, nil
} }
@ -78,7 +79,7 @@ func (c *Client) offerNew() (quic.Connection, error) {
packetConn.Close() packetConn.Close()
return nil, err return nil, err
} }
c.conn = quicConn c.conn.Store(quicConn)
c.rawConn = udpConn c.rawConn = udpConn
return quicConn, nil return quicConn, nil
} }
@ -98,13 +99,13 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
func (c *Client) Close() error { func (c *Client) Close() error {
c.connAccess.Lock() c.connAccess.Lock()
defer c.connAccess.Unlock() defer c.connAccess.Unlock()
if c.conn != nil { conn := c.conn.Swap(nil)
c.conn.CloseWithError(0, "") if conn != nil {
conn.CloseWithError(0, "")
} }
if c.rawConn != nil { if c.rawConn != nil {
c.rawConn.Close() c.rawConn.Close()
} }
c.conn = nil
c.rawConn = nil c.rawConn = nil
return nil return nil
} }

View File

@ -8,6 +8,7 @@ import (
"net" "net"
"os" "os"
"sync" "sync"
"sync/atomic"
"time" "time"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
@ -135,20 +136,22 @@ func (c *WebsocketConn) Upstream() any {
type EarlyWebsocketConn struct { type EarlyWebsocketConn struct {
*Client *Client
ctx context.Context ctx context.Context
conn *WebsocketConn conn atomic.Pointer[WebsocketConn]
access sync.Mutex access sync.Mutex
create chan struct{} create chan struct{}
err error err error
} }
func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) { func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
<-c.create <-c.create
if c.err != nil { if c.err != nil {
return 0, c.err return 0, c.err
} }
conn = c.conn.Load()
} }
return wrapWsError0(c.conn.Read(b)) return wrapWsError0(conn.Read(b))
} }
func (c *EarlyWebsocketConn) writeRequest(content []byte) error { func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
@ -187,21 +190,23 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
return err return err
} }
} }
c.conn = conn c.conn.Store(conn)
return nil return nil
} }
func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
if c.conn != nil { conn := c.conn.Load()
return wrapWsError0(c.conn.Write(b)) if conn != nil {
return wrapWsError0(conn.Write(b))
} }
c.access.Lock() c.access.Lock()
defer c.access.Unlock() defer c.access.Unlock()
conn = c.conn.Load()
if c.err != nil { if c.err != nil {
return 0, c.err return 0, c.err
} }
if c.conn != nil { if conn != nil {
return wrapWsError0(c.conn.Write(b)) return wrapWsError0(conn.Write(b))
} }
err = c.writeRequest(b) err = c.writeRequest(b)
c.err = err c.err = err
@ -213,17 +218,19 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
} }
func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error { func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
if c.conn != nil { conn := c.conn.Load()
return wrapWsError(c.conn.WriteBuffer(buffer)) if conn != nil {
return wrapWsError(conn.WriteBuffer(buffer))
} }
c.access.Lock() c.access.Lock()
defer c.access.Unlock() defer c.access.Unlock()
if c.conn != nil {
return wrapWsError(c.conn.WriteBuffer(buffer))
}
if c.err != nil { if c.err != nil {
return c.err return c.err
} }
conn = c.conn.Load()
if conn != nil {
return wrapWsError(conn.WriteBuffer(buffer))
}
err := c.writeRequest(buffer.Bytes()) err := c.writeRequest(buffer.Bytes())
c.err = err c.err = err
close(c.create) close(c.create)
@ -231,24 +238,27 @@ func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
} }
func (c *EarlyWebsocketConn) Close() error { func (c *EarlyWebsocketConn) Close() error {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
return nil return nil
} }
return c.conn.Close() return conn.Close()
} }
func (c *EarlyWebsocketConn) LocalAddr() net.Addr { func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
return M.Socksaddr{} return M.Socksaddr{}
} }
return c.conn.LocalAddr() return conn.LocalAddr()
} }
func (c *EarlyWebsocketConn) RemoteAddr() net.Addr { func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
if c.conn == nil { conn := c.conn.Load()
if conn == nil {
return M.Socksaddr{} return M.Socksaddr{}
} }
return c.conn.RemoteAddr() return conn.RemoteAddr()
} }
func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error { func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
@ -268,11 +278,11 @@ func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool {
} }
func (c *EarlyWebsocketConn) Upstream() any { func (c *EarlyWebsocketConn) Upstream() any {
return common.PtrOrNil(c.conn) return common.PtrOrNil(c.conn.Load())
} }
func (c *EarlyWebsocketConn) LazyHeadroom() bool { func (c *EarlyWebsocketConn) LazyHeadroom() bool {
return c.conn == nil return c.conn.Load() == nil
} }
func wrapWsError(err error) error { func wrapWsError(err error) error {