mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-06-08 01:22:07 +08:00
246 lines
5.7 KiB
Go
246 lines
5.7 KiB
Go
package conntrack
|
|
|
|
import (
|
|
"net"
|
|
"net/netip"
|
|
runtimeDebug "runtime/debug"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing/common"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
"github.com/sagernet/sing/common/memory"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
"github.com/sagernet/sing/common/x/list"
|
|
)
|
|
|
|
var _ Tracker = (*DefaultTracker)(nil)
|
|
|
|
type DefaultTracker struct {
|
|
connAccess sync.RWMutex
|
|
connList list.List[net.Conn]
|
|
connAddress map[netip.AddrPort]netip.AddrPort
|
|
|
|
packetConnAccess sync.RWMutex
|
|
packetConnList list.List[AbstractPacketConn]
|
|
packetConnAddress map[netip.AddrPort]bool
|
|
|
|
pendingAccess sync.RWMutex
|
|
pendingList list.List[netip.AddrPort]
|
|
|
|
killerEnabled bool
|
|
memoryLimit uint64
|
|
killerLastCheck time.Time
|
|
}
|
|
|
|
func NewDefaultTracker(killerEnabled bool, memoryLimit uint64) *DefaultTracker {
|
|
return &DefaultTracker{
|
|
connAddress: make(map[netip.AddrPort]netip.AddrPort),
|
|
packetConnAddress: make(map[netip.AddrPort]bool),
|
|
killerEnabled: killerEnabled,
|
|
memoryLimit: memoryLimit,
|
|
}
|
|
}
|
|
|
|
func (t *DefaultTracker) NewConn(conn net.Conn) (net.Conn, error) {
|
|
err := t.KillerCheck()
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
t.connAccess.Lock()
|
|
element := t.connList.PushBack(conn)
|
|
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
|
|
t.connAccess.Unlock()
|
|
return &Conn{
|
|
Conn: conn,
|
|
closeFunc: common.OnceFunc(func() {
|
|
t.removeConn(element)
|
|
}),
|
|
}, nil
|
|
}
|
|
|
|
func (t *DefaultTracker) NewConnEx(conn net.Conn) (N.CloseHandlerFunc, error) {
|
|
err := t.KillerCheck()
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
t.connAccess.Lock()
|
|
element := t.connList.PushBack(conn)
|
|
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
|
|
t.connAccess.Unlock()
|
|
return N.OnceClose(func(it error) {
|
|
t.removeConn(element)
|
|
}), nil
|
|
}
|
|
|
|
func (t *DefaultTracker) NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
|
|
err := t.KillerCheck()
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
t.packetConnAccess.Lock()
|
|
element := t.packetConnList.PushBack(conn)
|
|
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
|
|
t.packetConnAccess.Unlock()
|
|
return &PacketConn{
|
|
PacketConn: conn,
|
|
closeFunc: common.OnceFunc(func() {
|
|
t.removePacketConn(element)
|
|
}),
|
|
}, nil
|
|
}
|
|
|
|
func (t *DefaultTracker) NewPacketConnEx(conn AbstractPacketConn) (N.CloseHandlerFunc, error) {
|
|
err := t.KillerCheck()
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
t.packetConnAccess.Lock()
|
|
element := t.packetConnList.PushBack(conn)
|
|
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
|
|
t.packetConnAccess.Unlock()
|
|
return N.OnceClose(func(it error) {
|
|
t.removePacketConn(element)
|
|
}), nil
|
|
}
|
|
|
|
func (t *DefaultTracker) CheckConn(source netip.AddrPort, destination netip.AddrPort) bool {
|
|
t.connAccess.RLock()
|
|
defer t.connAccess.RUnlock()
|
|
return t.connAddress[source] == destination
|
|
}
|
|
|
|
func (t *DefaultTracker) CheckPacketConn(source netip.AddrPort) bool {
|
|
t.packetConnAccess.RLock()
|
|
defer t.packetConnAccess.RUnlock()
|
|
return t.packetConnAddress[source]
|
|
}
|
|
|
|
func (t *DefaultTracker) AddPendingDestination(destination netip.AddrPort) func() {
|
|
t.pendingAccess.Lock()
|
|
defer t.pendingAccess.Unlock()
|
|
element := t.pendingList.PushBack(destination)
|
|
return func() {
|
|
t.pendingAccess.Lock()
|
|
defer t.pendingAccess.Unlock()
|
|
t.pendingList.Remove(element)
|
|
}
|
|
}
|
|
|
|
func (t *DefaultTracker) CheckDestination(destination netip.AddrPort) bool {
|
|
t.pendingAccess.RLock()
|
|
defer t.pendingAccess.RUnlock()
|
|
for element := t.pendingList.Front(); element != nil; element = element.Next() {
|
|
if element.Value == destination {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (t *DefaultTracker) KillerCheck() error {
|
|
if !t.killerEnabled {
|
|
return nil
|
|
}
|
|
nowTime := time.Now()
|
|
if nowTime.Sub(t.killerLastCheck) < 3*time.Second {
|
|
return nil
|
|
}
|
|
t.killerLastCheck = nowTime
|
|
if memory.Total() > t.memoryLimit {
|
|
t.Close()
|
|
go func() {
|
|
time.Sleep(time.Second)
|
|
runtimeDebug.FreeOSMemory()
|
|
}()
|
|
return E.New("out of memory")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *DefaultTracker) Count() int {
|
|
t.connAccess.RLock()
|
|
defer t.connAccess.RUnlock()
|
|
t.packetConnAccess.RLock()
|
|
defer t.packetConnAccess.RUnlock()
|
|
return t.connList.Len() + t.packetConnList.Len()
|
|
}
|
|
|
|
func (t *DefaultTracker) Close() {
|
|
t.connAccess.Lock()
|
|
for element := t.connList.Front(); element != nil; element = element.Next() {
|
|
element.Value.Close()
|
|
}
|
|
t.connList.Init()
|
|
t.connAccess.Unlock()
|
|
t.packetConnAccess.Lock()
|
|
for element := t.packetConnList.Front(); element != nil; element = element.Next() {
|
|
element.Value.Close()
|
|
}
|
|
t.packetConnList.Init()
|
|
t.packetConnAccess.Unlock()
|
|
}
|
|
|
|
func (t *DefaultTracker) removeConn(element *list.Element[net.Conn]) {
|
|
t.connAccess.Lock()
|
|
defer t.connAccess.Unlock()
|
|
delete(t.connAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
|
|
t.connList.Remove(element)
|
|
}
|
|
|
|
func (t *DefaultTracker) removePacketConn(element *list.Element[AbstractPacketConn]) {
|
|
t.packetConnAccess.Lock()
|
|
defer t.packetConnAccess.Unlock()
|
|
delete(t.packetConnAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
|
|
t.packetConnList.Remove(element)
|
|
}
|
|
|
|
type Conn struct {
|
|
net.Conn
|
|
closeFunc func()
|
|
}
|
|
|
|
func (c *Conn) Close() error {
|
|
c.closeFunc()
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
func (c *Conn) Upstream() any {
|
|
return c.Conn
|
|
}
|
|
|
|
func (c *Conn) ReaderReplaceable() bool {
|
|
return true
|
|
}
|
|
|
|
func (c *Conn) WriterReplaceable() bool {
|
|
return true
|
|
}
|
|
|
|
type PacketConn struct {
|
|
net.PacketConn
|
|
closeFunc func()
|
|
}
|
|
|
|
func (c *PacketConn) Close() error {
|
|
c.closeFunc()
|
|
return c.PacketConn.Close()
|
|
}
|
|
|
|
func (c *PacketConn) Upstream() any {
|
|
return c.PacketConn
|
|
}
|
|
|
|
func (c *PacketConn) ReaderReplaceable() bool {
|
|
return true
|
|
}
|
|
|
|
func (c *PacketConn) WriterReplaceable() bool {
|
|
return true
|
|
}
|