From 905a2ded9376282598f5cd3bc81d8096fe12e671 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 14 Apr 2025 15:41:20 +0800 Subject: [PATCH] Add SSM API service --- adapter/ssm.go | 18 +++ constant/proxy.go | 1 + docs/configuration/service/ssm-api.md | 52 +++++++ include/registry.go | 2 + mkdocs.yml | 1 + option/shadowsocks.go | 1 + option/ssmapi.go | 11 ++ protocol/shadowsocks/inbound.go | 4 +- protocol/shadowsocks/inbound_multi.go | 47 +++++- service/ssmapi/api.go | 181 ++++++++++++++++++++++ service/ssmapi/server.go | 117 ++++++++++++++ service/ssmapi/traffic.go | 215 ++++++++++++++++++++++++++ service/ssmapi/user.go | 85 ++++++++++ 13 files changed, 726 insertions(+), 9 deletions(-) create mode 100644 adapter/ssm.go create mode 100644 docs/configuration/service/ssm-api.md create mode 100644 option/ssmapi.go create mode 100644 service/ssmapi/api.go create mode 100644 service/ssmapi/server.go create mode 100644 service/ssmapi/traffic.go create mode 100644 service/ssmapi/user.go diff --git a/adapter/ssm.go b/adapter/ssm.go new file mode 100644 index 00000000..caab9221 --- /dev/null +++ b/adapter/ssm.go @@ -0,0 +1,18 @@ +package adapter + +import ( + "net" + + N "github.com/sagernet/sing/common/network" +) + +type ManagedSSMServer interface { + Inbound + SetTracker(tracker SSMTracker) + UpdateUsers(users []string, uPSKs []string) error +} + +type SSMTracker interface { + TrackConnection(conn net.Conn, metadata InboundContext) net.Conn + TrackPacketConnection(conn N.PacketConn, metadata InboundContext) N.PacketConn +} diff --git a/constant/proxy.go b/constant/proxy.go index 4a09ab0b..cf12c48d 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -27,6 +27,7 @@ const ( TypeTailscale = "tailscale" TypeDERP = "derp" TypeResolved = "resolved" + TypeSSMAPI = "ssm-api" ) const ( diff --git a/docs/configuration/service/ssm-api.md b/docs/configuration/service/ssm-api.md new file mode 100644 index 00000000..854ec687 --- /dev/null +++ b/docs/configuration/service/ssm-api.md @@ -0,0 +1,52 @@ +--- +icon: material/new-box +--- + +!!! question "Since sing-box 1.12.0" + +# SSM API + +SSM API service is a RESTful API server for managing Shadowsocks servers. + +See https://github.com/Shadowsocks-NET/shadowsocks-specs/blob/main/2023-1-shadowsocks-server-management-api-v1.md + +### Structure + +```json +{ + "type": "ssm-api", + + ... // Listen Fields + + "servers": {}, + "tls": {} +} +``` + +### Listen Fields + +See [Listen Fields](/configuration/shared/listen/) for details. + +### Fields + +#### servers + +==Required== + +A mapping Object from HTTP endpoints to [Shadowsocks Inbound](/configuration/inbound/shadowsocks) tags. + +Selected Shadowsocks inbounds must be configured with [managed](/configuration/inbound/shadowsocks#managed) enabled. + +Example: + +```json +{ + "servers": { + "/": "ss-in" + } +} +``` + +#### tls + +TLS configuration, see [TLS](/configuration/shared/tls/#inbound). diff --git a/include/registry.go b/include/registry.go index 11069d5a..4c9ad449 100644 --- a/include/registry.go +++ b/include/registry.go @@ -35,6 +35,7 @@ import ( "github.com/sagernet/sing-box/protocol/vless" "github.com/sagernet/sing-box/protocol/vmess" "github.com/sagernet/sing-box/service/resolved" + "github.com/sagernet/sing-box/service/ssmapi" E "github.com/sagernet/sing/common/exceptions" ) @@ -125,6 +126,7 @@ func ServiceRegistry() *service.Registry { registry := service.NewRegistry() resolved.RegisterService(registry) + ssmapi.RegisterService(registry) registerDERPService(registry) diff --git a/mkdocs.yml b/mkdocs.yml index 35d1f1e3..951d9504 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -174,6 +174,7 @@ nav: - configuration/service/index.md - DERP: configuration/service/derp.md - Resolved: configuration/service/resolved.md + - SSM API: configuration/service/ssm-api.md markdown_extensions: - pymdownx.inlinehilite - pymdownx.snippets diff --git a/option/shadowsocks.go b/option/shadowsocks.go index 187b9b63..7cb656f3 100644 --- a/option/shadowsocks.go +++ b/option/shadowsocks.go @@ -8,6 +8,7 @@ type ShadowsocksInboundOptions struct { Users []ShadowsocksUser `json:"users,omitempty"` Destinations []ShadowsocksDestination `json:"destinations,omitempty"` Multiplex *InboundMultiplexOptions `json:"multiplex,omitempty"` + Managed bool `json:"managed,omitempty"` } type ShadowsocksUser struct { diff --git a/option/ssmapi.go b/option/ssmapi.go new file mode 100644 index 00000000..2fbdc1bc --- /dev/null +++ b/option/ssmapi.go @@ -0,0 +1,11 @@ +package option + +import ( + "github.com/sagernet/sing/common/json/badjson" +) + +type SSMAPIServiceOptions struct { + ListenOptions + Servers *badjson.TypedMap[string, string] `json:"servers"` + InboundTLSOptionsContainer +} diff --git a/protocol/shadowsocks/inbound.go b/protocol/shadowsocks/inbound.go index d921bfa6..52e2c524 100644 --- a/protocol/shadowsocks/inbound.go +++ b/protocol/shadowsocks/inbound.go @@ -32,8 +32,10 @@ func RegisterInbound(registry *inbound.Registry) { func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) { if len(options.Users) > 0 && len(options.Destinations) > 0 { return nil, E.New("users and destinations options must not be combined") + } else if options.Managed && (len(options.Users) > 0 || len(options.Destinations) > 0) { + return nil, E.New("users and destinations options are not supported in managed servers") } - if len(options.Users) > 0 { + if len(options.Users) > 0 || options.Managed { return newMultiInbound(ctx, router, logger, tag, options) } else if len(options.Destinations) > 0 { return newRelayInbound(ctx, router, logger, tag, options) diff --git a/protocol/shadowsocks/inbound_multi.go b/protocol/shadowsocks/inbound_multi.go index 5b604365..0120a08a 100644 --- a/protocol/shadowsocks/inbound_multi.go +++ b/protocol/shadowsocks/inbound_multi.go @@ -28,7 +28,10 @@ import ( "github.com/sagernet/sing/common/ntp" ) -var _ adapter.TCPInjectableInbound = (*MultiInbound)(nil) +var ( + _ adapter.TCPInjectableInbound = (*MultiInbound)(nil) + _ adapter.ManagedSSMServer = (*MultiInbound)(nil) +) type MultiInbound struct { inbound.Adapter @@ -38,6 +41,7 @@ type MultiInbound struct { listener *listener.Listener service shadowsocks.MultiService[int] users []option.ShadowsocksUser + tracker adapter.SSMTracker } func newMultiInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (*MultiInbound, error) { @@ -79,13 +83,15 @@ func newMultiInbound(ctx context.Context, router adapter.Router, logger log.Cont if err != nil { return nil, err } - err = service.UpdateUsersWithPasswords(common.MapIndexed(options.Users, func(index int, user option.ShadowsocksUser) int { - return index - }), common.Map(options.Users, func(user option.ShadowsocksUser) string { - return user.Password - })) - if err != nil { - return nil, err + if len(options.Users) > 0 { + err = service.UpdateUsersWithPasswords(common.MapIndexed(options.Users, func(index int, user option.ShadowsocksUser) int { + return index + }), common.Map(options.Users, func(user option.ShadowsocksUser) string { + return user.Password + })) + if err != nil { + return nil, err + } } inbound.service = service inbound.users = options.Users @@ -112,6 +118,25 @@ func (h *MultiInbound) Close() error { return h.listener.Close() } +func (h *MultiInbound) SetTracker(tracker adapter.SSMTracker) { + h.tracker = tracker +} + +func (h *MultiInbound) UpdateUsers(users []string, uPSKs []string) error { + err := h.service.UpdateUsersWithPasswords(common.MapIndexed(users, func(index int, user string) int { + return index + }), uPSKs) + if err != nil { + return err + } + h.users = common.Map(users, func(user string) option.ShadowsocksUser { + return option.ShadowsocksUser{ + Name: user, + } + }) + return nil +} + //nolint:staticcheck func (h *MultiInbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { err := h.service.NewConnection(ctx, conn, adapter.UpstreamMetadata(metadata)) @@ -151,6 +176,9 @@ func (h *MultiInbound) newConnection(ctx context.Context, conn net.Conn, metadat metadata.InboundDetour = h.listener.ListenOptions().Detour //nolint:staticcheck metadata.InboundOptions = h.listener.ListenOptions().InboundOptions + if h.tracker != nil { + conn = h.tracker.TrackConnection(conn, metadata) + } return h.router.RouteConnection(ctx, conn, metadata) } @@ -174,6 +202,9 @@ func (h *MultiInbound) newPacketConnection(ctx context.Context, conn N.PacketCon metadata.InboundDetour = h.listener.ListenOptions().Detour //nolint:staticcheck metadata.InboundOptions = h.listener.ListenOptions().InboundOptions + if h.tracker != nil { + conn = h.tracker.TrackPacketConnection(conn, metadata) + } return h.router.RoutePacketConnection(ctx, conn, metadata) } diff --git a/service/ssmapi/api.go b/service/ssmapi/api.go new file mode 100644 index 00000000..b9b753a4 --- /dev/null +++ b/service/ssmapi/api.go @@ -0,0 +1,181 @@ +package ssmapi + +import ( + "net/http" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common/logger" + sHTTP "github.com/sagernet/sing/protocol/http" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" +) + +type APIServer struct { + logger logger.Logger + traffic *TrafficManager + user *UserManager +} + +func NewAPIServer(logger logger.Logger, traffic *TrafficManager, user *UserManager) *APIServer { + return &APIServer{ + logger: logger, + traffic: traffic, + user: user, + } +} + +func (s *APIServer) Route(r chi.Router) { + r.Route("/server/v1", func(r chi.Router) { + r.Use(func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + s.logger.Debug(request.Method, " ", request.RequestURI, " ", sHTTP.SourceAddress(request)) + handler.ServeHTTP(writer, request) + }) + }) + r.Get("/", s.getServerInfo) + r.Get("/users", s.listUser) + r.Post("/users", s.addUser) + r.Get("/users/{username}", s.getUser) + r.Put("/users/{username}", s.updateUser) + r.Delete("/users/{username}", s.deleteUser) + r.Get("/stats", s.getStats) + }) +} + +func (s *APIServer) getServerInfo(writer http.ResponseWriter, request *http.Request) { + render.JSON(writer, request, render.M{ + "server": "sing-box " + C.Version, + "apiVersion": "v1", + }) +} + +type UserObject struct { + UserName string `json:"username"` + Password string `json:"uPSK,omitempty"` + DownlinkBytes int64 `json:"downlinkBytes"` + UplinkBytes int64 `json:"uplinkBytes"` + DownlinkPackets int64 `json:"downlinkPackets"` + UplinkPackets int64 `json:"uplinkPackets"` + TCPSessions int64 `json:"tcpSessions"` + UDPSessions int64 `json:"udpSessions"` +} + +func (s *APIServer) listUser(writer http.ResponseWriter, request *http.Request) { + render.JSON(writer, request, render.M{ + "users": s.user.List(), + }) +} + +func (s *APIServer) addUser(writer http.ResponseWriter, request *http.Request) { + var addRequest struct { + UserName string `json:"username"` + Password string `json:"uPSK"` + } + err := render.DecodeJSON(request.Body, &addRequest) + if err != nil { + render.Status(request, http.StatusBadRequest) + render.PlainText(writer, request, err.Error()) + return + } + err = s.user.Add(addRequest.UserName, addRequest.Password) + if err != nil { + render.Status(request, http.StatusBadRequest) + render.PlainText(writer, request, err.Error()) + return + } + writer.WriteHeader(http.StatusCreated) +} + +func (s *APIServer) getUser(writer http.ResponseWriter, request *http.Request) { + userName := chi.URLParam(request, "username") + if userName == "" { + writer.WriteHeader(http.StatusBadRequest) + return + } + uPSK, loaded := s.user.Get(userName) + if !loaded { + writer.WriteHeader(http.StatusNotFound) + return + } + user := UserObject{ + UserName: userName, + Password: uPSK, + } + s.traffic.ReadUser(&user) + render.JSON(writer, request, user) +} + +func (s *APIServer) updateUser(writer http.ResponseWriter, request *http.Request) { + userName := chi.URLParam(request, "username") + if userName == "" { + writer.WriteHeader(http.StatusBadRequest) + return + } + var updateRequest struct { + Password string `json:"uPSK"` + } + err := render.DecodeJSON(request.Body, &updateRequest) + if err != nil { + render.Status(request, http.StatusBadRequest) + render.PlainText(writer, request, err.Error()) + return + } + _, loaded := s.user.Get(userName) + if !loaded { + writer.WriteHeader(http.StatusNotFound) + return + } + err = s.user.Update(userName, updateRequest.Password) + if err != nil { + render.Status(request, http.StatusBadRequest) + render.PlainText(writer, request, err.Error()) + return + } + writer.WriteHeader(http.StatusNoContent) +} + +func (s *APIServer) deleteUser(writer http.ResponseWriter, request *http.Request) { + userName := chi.URLParam(request, "username") + if userName == "" { + writer.WriteHeader(http.StatusBadRequest) + return + } + _, loaded := s.user.Get(userName) + if !loaded { + writer.WriteHeader(http.StatusNotFound) + return + } + err := s.user.Delete(userName) + if err != nil { + render.Status(request, http.StatusBadRequest) + render.PlainText(writer, request, err.Error()) + return + } + writer.WriteHeader(http.StatusNoContent) +} + +func (s *APIServer) getStats(writer http.ResponseWriter, request *http.Request) { + requireClear := chi.URLParam(request, "clear") == "true" + + users := s.user.List() + s.traffic.ReadUsers(users) + for i := range users { + users[i].Password = "" + } + uplinkBytes, downlinkBytes, uplinkPackets, downlinkPackets, tcpSessions, udpSessions := s.traffic.ReadGlobal() + + if requireClear { + s.traffic.Clear() + } + + render.JSON(writer, request, render.M{ + "uplinkBytes": uplinkBytes, + "downlinkBytes": downlinkBytes, + "uplinkPackets": uplinkPackets, + "downlinkPackets": downlinkPackets, + "tcpSessions": tcpSessions, + "udpSessions": udpSessions, + "users": users, + }) +} diff --git a/service/ssmapi/server.go b/service/ssmapi/server.go new file mode 100644 index 00000000..92d7354f --- /dev/null +++ b/service/ssmapi/server.go @@ -0,0 +1,117 @@ +package ssmapi + +import ( + "context" + "errors" + "net/http" + + "github.com/sagernet/sing-box/adapter" + boxService "github.com/sagernet/sing-box/adapter/service" + "github.com/sagernet/sing-box/common/listener" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" + "github.com/sagernet/sing/service" + + "github.com/go-chi/chi/v5" + "golang.org/x/net/http2" +) + +func RegisterService(registry *boxService.Registry) { + boxService.Register[option.SSMAPIServiceOptions](registry, C.TypeSSMAPI, NewService) +} + +type Service struct { + boxService.Adapter + ctx context.Context + logger log.ContextLogger + listener *listener.Listener + tlsConfig tls.ServerConfig + httpServer *http.Server +} + +func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.SSMAPIServiceOptions) (adapter.Service, error) { + chiRouter := chi.NewRouter() + s := &Service{ + Adapter: boxService.NewAdapter(C.TypeSSMAPI, tag), + ctx: ctx, + logger: logger, + listener: listener.New(listener.Options{ + Context: ctx, + Logger: logger, + Network: []string{N.NetworkTCP}, + Listen: options.ListenOptions, + }), + httpServer: &http.Server{ + Handler: chiRouter, + }, + } + inboundManager := service.FromContext[adapter.InboundManager](ctx) + if options.Servers.Size() == 0 { + return nil, E.New("missing servers") + } + for i, entry := range options.Servers.Entries() { + inbound, loaded := inboundManager.Get(entry.Value) + if !loaded { + return nil, E.New("parse SSM server[", i, "]: inbound ", entry.Value, " not found") + } + managedServer, isManaged := inbound.(adapter.ManagedSSMServer) + if !isManaged { + return nil, E.New("parse SSM server[", i, "]: inbound/", inbound.Type(), "[", inbound.Tag(), "] is not a SSM server") + } + traffic := NewTrafficManager() + managedServer.SetTracker(traffic) + user := NewUserManager(managedServer, traffic) + chiRouter.Route(entry.Key, NewAPIServer(logger, traffic, user).Route) + } + if options.TLS != nil { + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + s.tlsConfig = tlsConfig + } + return s, nil +} + +func (s *Service) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + if s.tlsConfig != nil { + err := s.tlsConfig.Start() + if err != nil { + return E.Cause(err, "create TLS config") + } + } + tcpListener, err := s.listener.ListenTCP() + if err != nil { + return err + } + if s.tlsConfig != nil { + if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) { + s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...)) + } + tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig) + } + go func() { + err = s.httpServer.Serve(tcpListener) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("serve error: ", err) + } + }() + return nil +} + +func (s *Service) Close() error { + return common.Close( + common.PtrOrNil(s.httpServer), + common.PtrOrNil(s.listener), + s.tlsConfig, + ) +} diff --git a/service/ssmapi/traffic.go b/service/ssmapi/traffic.go new file mode 100644 index 00000000..7f3f103e --- /dev/null +++ b/service/ssmapi/traffic.go @@ -0,0 +1,215 @@ +package ssmapi + +import ( + "net" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/bufio" + N "github.com/sagernet/sing/common/network" +) + +var _ adapter.SSMTracker = (*TrafficManager)(nil) + +type TrafficManager struct { + globalUplink atomic.Int64 + globalDownlink atomic.Int64 + globalUplinkPackets atomic.Int64 + globalDownlinkPackets atomic.Int64 + globalTCPSessions atomic.Int64 + globalUDPSessions atomic.Int64 + userAccess sync.Mutex + userUplink map[string]*atomic.Int64 + userDownlink map[string]*atomic.Int64 + userUplinkPackets map[string]*atomic.Int64 + userDownlinkPackets map[string]*atomic.Int64 + userTCPSessions map[string]*atomic.Int64 + userUDPSessions map[string]*atomic.Int64 +} + +func NewTrafficManager() *TrafficManager { + manager := &TrafficManager{ + userUplink: make(map[string]*atomic.Int64), + userDownlink: make(map[string]*atomic.Int64), + userUplinkPackets: make(map[string]*atomic.Int64), + userDownlinkPackets: make(map[string]*atomic.Int64), + userTCPSessions: make(map[string]*atomic.Int64), + userUDPSessions: make(map[string]*atomic.Int64), + } + return manager +} + +func (s *TrafficManager) UpdateUsers(users []string) { + s.userAccess.Lock() + defer s.userAccess.Unlock() + newUserUplink := make(map[string]*atomic.Int64) + newUserDownlink := make(map[string]*atomic.Int64) + newUserUplinkPackets := make(map[string]*atomic.Int64) + newUserDownlinkPackets := make(map[string]*atomic.Int64) + newUserTCPSessions := make(map[string]*atomic.Int64) + newUserUDPSessions := make(map[string]*atomic.Int64) + for _, user := range users { + newUserUplink[user] = s.userUplinkPackets[user] + newUserDownlink[user] = s.userDownlinkPackets[user] + newUserUplinkPackets[user] = s.userUplinkPackets[user] + newUserDownlinkPackets[user] = s.userDownlinkPackets[user] + newUserTCPSessions[user] = s.userTCPSessions[user] + newUserUDPSessions[user] = s.userUDPSessions[user] + } + s.userUplink = newUserUplink + s.userDownlink = newUserDownlink + s.userUplinkPackets = newUserUplinkPackets + s.userDownlinkPackets = newUserDownlinkPackets + s.userTCPSessions = newUserTCPSessions + s.userUDPSessions = newUserUDPSessions +} + +func (s *TrafficManager) userCounter(user string) (*atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64) { + s.userAccess.Lock() + defer s.userAccess.Unlock() + upCounter, loaded := s.userUplink[user] + if !loaded { + upCounter = new(atomic.Int64) + s.userUplink[user] = upCounter + } + downCounter, loaded := s.userDownlink[user] + if !loaded { + downCounter = new(atomic.Int64) + s.userDownlink[user] = downCounter + } + upPacketsCounter, loaded := s.userUplinkPackets[user] + if !loaded { + upPacketsCounter = new(atomic.Int64) + s.userUplinkPackets[user] = upPacketsCounter + } + downPacketsCounter, loaded := s.userDownlinkPackets[user] + if !loaded { + downPacketsCounter = new(atomic.Int64) + s.userDownlinkPackets[user] = downPacketsCounter + } + tcpSessionsCounter, loaded := s.userTCPSessions[user] + if !loaded { + tcpSessionsCounter = new(atomic.Int64) + s.userTCPSessions[user] = tcpSessionsCounter + } + udpSessionsCounter, loaded := s.userUDPSessions[user] + if !loaded { + udpSessionsCounter = new(atomic.Int64) + s.userUDPSessions[user] = udpSessionsCounter + } + return upCounter, downCounter, upPacketsCounter, downPacketsCounter, tcpSessionsCounter, udpSessionsCounter +} + +func (s *TrafficManager) TrackConnection(conn net.Conn, metadata adapter.InboundContext) net.Conn { + s.globalTCPSessions.Add(1) + var readCounter []*atomic.Int64 + var writeCounter []*atomic.Int64 + readCounter = append(readCounter, &s.globalUplink) + writeCounter = append(writeCounter, &s.globalDownlink) + upCounter, downCounter, _, _, tcpSessionCounter, _ := s.userCounter(metadata.User) + readCounter = append(readCounter, upCounter) + writeCounter = append(writeCounter, downCounter) + tcpSessionCounter.Add(1) + return bufio.NewInt64CounterConn(conn, readCounter, writeCounter) +} + +func (s *TrafficManager) TrackPacketConnection(conn N.PacketConn, metadata adapter.InboundContext) N.PacketConn { + s.globalUDPSessions.Add(1) + var readCounter []*atomic.Int64 + var readPacketCounter []*atomic.Int64 + var writeCounter []*atomic.Int64 + var writePacketCounter []*atomic.Int64 + readCounter = append(readCounter, &s.globalUplink) + writeCounter = append(writeCounter, &s.globalDownlink) + readPacketCounter = append(readPacketCounter, &s.globalUplinkPackets) + writePacketCounter = append(writePacketCounter, &s.globalDownlinkPackets) + upCounter, downCounter, upPacketsCounter, downPacketsCounter, _, udpSessionCounter := s.userCounter(metadata.User) + readCounter = append(readCounter, upCounter) + writeCounter = append(writeCounter, downCounter) + readPacketCounter = append(readPacketCounter, upPacketsCounter) + writePacketCounter = append(writePacketCounter, downPacketsCounter) + udpSessionCounter.Add(1) + return bufio.NewInt64CounterPacketConn(conn, append(readCounter, readPacketCounter...), append(writeCounter, writePacketCounter...)) +} + +func (s *TrafficManager) ReadUser(user *UserObject) { + s.userAccess.Lock() + defer s.userAccess.Unlock() + s.readUser(user) +} + +func (s *TrafficManager) readUser(user *UserObject) { + if counter, loaded := s.userUplink[user.UserName]; loaded { + user.UplinkBytes = counter.Load() + } + if counter, loaded := s.userDownlink[user.UserName]; loaded { + user.DownlinkBytes = counter.Load() + } + if counter, loaded := s.userUplinkPackets[user.UserName]; loaded { + user.UplinkPackets = counter.Load() + } + if counter, loaded := s.userDownlinkPackets[user.UserName]; loaded { + user.DownlinkPackets = counter.Load() + } + if counter, loaded := s.userTCPSessions[user.UserName]; loaded { + user.TCPSessions = counter.Load() + } + if counter, loaded := s.userUDPSessions[user.UserName]; loaded { + user.UDPSessions = counter.Load() + } +} + +func (s *TrafficManager) ReadUsers(users []*UserObject) { + s.userAccess.Lock() + defer s.userAccess.Unlock() + for _, user := range users { + s.readUser(user) + } + return +} + +func (s *TrafficManager) ReadGlobal() ( + uplinkBytes int64, + downlinkBytes int64, + uplinkPackets int64, + downlinkPackets int64, + tcpSessions int64, + udpSessions int64, +) { + return s.globalUplink.Load(), + s.globalDownlink.Load(), + s.globalUplinkPackets.Load(), + s.globalDownlinkPackets.Load(), + s.globalTCPSessions.Load(), + s.globalUDPSessions.Load() +} + +func (s *TrafficManager) Clear() { + s.globalUplink.Store(0) + s.globalDownlink.Store(0) + s.globalUplinkPackets.Store(0) + s.globalDownlinkPackets.Store(0) + s.globalTCPSessions.Store(0) + s.globalUDPSessions.Store(0) + s.userAccess.Lock() + defer s.userAccess.Unlock() + for _, counter := range s.userUplink { + counter.Store(0) + } + for _, counter := range s.userDownlink { + counter.Store(0) + } + for _, counter := range s.userUplinkPackets { + counter.Store(0) + } + for _, counter := range s.userDownlinkPackets { + counter.Store(0) + } + for _, counter := range s.userTCPSessions { + counter.Store(0) + } + for _, counter := range s.userUDPSessions { + counter.Store(0) + } +} diff --git a/service/ssmapi/user.go b/service/ssmapi/user.go new file mode 100644 index 00000000..a8eb27fb --- /dev/null +++ b/service/ssmapi/user.go @@ -0,0 +1,85 @@ +package ssmapi + +import ( + "sync" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" +) + +type UserManager struct { + access sync.Mutex + usersMap map[string]string + server adapter.ManagedSSMServer + trafficManager *TrafficManager +} + +func NewUserManager(inbound adapter.ManagedSSMServer, trafficManager *TrafficManager) *UserManager { + return &UserManager{ + usersMap: make(map[string]string), + server: inbound, + trafficManager: trafficManager, + } +} + +func (m *UserManager) postUpdate() error { + users := make([]string, 0, len(m.usersMap)) + uPSKs := make([]string, 0, len(m.usersMap)) + for username, password := range m.usersMap { + users = append(users, username) + uPSKs = append(uPSKs, password) + } + err := m.server.UpdateUsers(users, uPSKs) + if err != nil { + return err + } + m.trafficManager.UpdateUsers(users) + return nil +} + +func (m *UserManager) List() []*UserObject { + m.access.Lock() + defer m.access.Unlock() + + users := make([]*UserObject, 0, len(m.usersMap)) + for username, password := range m.usersMap { + users = append(users, &UserObject{ + UserName: username, + Password: password, + }) + } + return users +} + +func (m *UserManager) Add(username string, password string) error { + m.access.Lock() + defer m.access.Unlock() + if _, found := m.usersMap[username]; found { + return E.New("user ", username, " already exists") + } + m.usersMap[username] = password + return m.postUpdate() +} + +func (m *UserManager) Get(username string) (string, bool) { + m.access.Lock() + defer m.access.Unlock() + if password, found := m.usersMap[username]; found { + return password, true + } + return "", false +} + +func (m *UserManager) Update(username string, password string) error { + m.access.Lock() + defer m.access.Unlock() + m.usersMap[username] = password + return m.postUpdate() +} + +func (m *UserManager) Delete(username string) error { + m.access.Lock() + defer m.access.Unlock() + delete(m.usersMap, username) + return m.postUpdate() +}