diff --git a/adapter/outbound/manager.go b/adapter/outbound/manager.go index 44ac8bc5..b58f5277 100644 --- a/adapter/outbound/manager.go +++ b/adapter/outbound/manager.go @@ -30,7 +30,7 @@ type Manager struct { outboundByTag map[string]adapter.Outbound dependByTag map[string][]string defaultOutbound adapter.Outbound - defaultOutboundFallback adapter.Outbound + defaultOutboundFallback func() (adapter.Outbound, error) } func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager { @@ -44,7 +44,7 @@ func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, } } -func (m *Manager) Initialize(defaultOutboundFallback adapter.Outbound) { +func (m *Manager) Initialize(defaultOutboundFallback func() (adapter.Outbound, error)) { m.defaultOutboundFallback = defaultOutboundFallback } @@ -55,18 +55,31 @@ func (m *Manager) Start(stage adapter.StartStage) error { } m.started = true m.stage = stage - outbounds := m.outbounds - m.access.Unlock() if stage == adapter.StartStateStart { + if m.defaultOutbound == nil { + directOutbound, err := m.defaultOutboundFallback() + if err != nil { + m.access.Unlock() + return E.Cause(err, "create direct outbound for fallback") + } + m.outbounds = append(m.outbounds, directOutbound) + m.outboundByTag[directOutbound.Tag()] = directOutbound + m.defaultOutbound = directOutbound + } if m.defaultTag != "" && m.defaultOutbound == nil { defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag) if !loaded { + m.access.Unlock() return E.New("default outbound not found: ", m.defaultTag) } m.defaultOutbound = defaultEndpoint } + outbounds := m.outbounds + m.access.Unlock() return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...)) } else { + outbounds := m.outbounds + m.access.Unlock() for _, outbound := range outbounds { err := adapter.LegacyStart(outbound, stage) if err != nil { @@ -187,11 +200,7 @@ func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) { func (m *Manager) Default() adapter.Outbound { m.access.RLock() defer m.access.RUnlock() - if m.defaultOutbound != nil { - return m.defaultOutbound - } else { - return m.defaultOutboundFallback - } + return m.defaultOutbound } func (m *Manager) Remove(tag string) error { diff --git a/box.go b/box.go index bfb0b47e..8a38f6ae 100644 --- a/box.go +++ b/box.go @@ -314,15 +314,15 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "initialize service[", i, "]") } } - outboundManager.Initialize(common.Must1( - direct.NewOutbound( + outboundManager.Initialize(func() (adapter.Outbound, error) { + return direct.NewOutbound( ctx, router, logFactory.NewLogger("outbound/direct"), "direct", option.DirectOutboundOptions{}, - ), - )) + ) + }) dnsTransportManager.Initialize(common.Must1( local.NewTransport( ctx, diff --git a/service/resolved/resolve1.go b/service/resolved/resolve1.go index d64619e6..8e6dd3fa 100644 --- a/service/resolved/resolve1.go +++ b/service/resolved/resolve1.go @@ -182,9 +182,9 @@ func (t *resolve1Manager) logRequest(sender dbus.Sender, message ...any) context } else if metadata.ProcessInfo.UserId != 0 { prefix = F.ToString("uid:", metadata.ProcessInfo.UserId) } - t.logger.InfoContext(ctx, "(", prefix, ") ", F.ToString(message...)) + t.logger.InfoContext(ctx, "(", prefix, ") ", strings.Join(F.MapToString(message), " ")) } else { - t.logger.InfoContext(ctx, F.ToString(message...)) + t.logger.InfoContext(ctx, strings.Join(F.MapToString(message), " ")) } return adapter.WithContext(ctx, &metadata) } @@ -280,7 +280,10 @@ func (t *resolve1Manager) ResolveAddress(sender dbus.Sender, ifIndex int32, fami }, } ctx := t.logRequest(sender, "ResolveAddress ", link.iif.Name, familyToString(family), addr, flags) - response, lookupErr := t.dnsRouter.Exchange(ctx, request, adapter.DNSQueryOptions{}) + var metadata adapter.InboundContext + metadata.InboundType = t.Type() + metadata.Inbound = t.Tag() + response, lookupErr := t.dnsRouter.Exchange(adapter.WithContext(ctx, &metadata), request, adapter.DNSQueryOptions{}) if lookupErr != nil { err = wrapError(err) return @@ -301,7 +304,7 @@ func (t *resolve1Manager) ResolveAddress(sender dbus.Sender, ifIndex int32, fami return } -func (t *resolve1Manager) ResolveRecord(sender dbus.Sender, ifIndex int32, family int32, hostname string, qClass uint16, qType uint16, flags uint64) (records []ResourceRecord, outflags uint64, err *dbus.Error) { +func (t *resolve1Manager) ResolveRecord(sender dbus.Sender, ifIndex int32, hostname string, qClass uint16, qType uint16, flags uint64) (records []ResourceRecord, outflags uint64, err *dbus.Error) { t.linkAccess.Lock() link, err := t.getLink(ifIndex) if err != nil { @@ -320,8 +323,11 @@ func (t *resolve1Manager) ResolveRecord(sender dbus.Sender, ifIndex int32, famil }, }, } - ctx := t.logRequest(sender, "ResolveRecord ", link.iif.Name, familyToString(family), hostname, mDNS.Class(qClass), mDNS.Type(qType), flags) - response, exchangeErr := t.dnsRouter.Exchange(ctx, request, adapter.DNSQueryOptions{}) + ctx := t.logRequest(sender, "ResolveRecord", link.iif.Name, hostname, mDNS.Class(qClass), mDNS.Type(qType), flags) + var metadata adapter.InboundContext + metadata.InboundType = t.Type() + metadata.Inbound = t.Tag() + response, exchangeErr := t.dnsRouter.Exchange(adapter.WithContext(ctx, &metadata), request, adapter.DNSQueryOptions{}) if exchangeErr != nil { err = wrapError(exchangeErr) return @@ -341,6 +347,7 @@ func (t *resolve1Manager) ResolveRecord(sender dbus.Sender, ifIndex int32, famil err = wrapError(unpackErr) } record.Data = data + records = append(records, record) } return } @@ -380,8 +387,10 @@ func (t *resolve1Manager) ResolveService(sender dbus.Sender, ifIndex int32, host }, }, } - - srvResponse, exchangeErr := t.dnsRouter.Exchange(ctx, srvRequest, adapter.DNSQueryOptions{}) + var metadata adapter.InboundContext + metadata.InboundType = t.Type() + metadata.Inbound = t.Tag() + srvResponse, exchangeErr := t.dnsRouter.Exchange(adapter.WithContext(ctx, &metadata), srvRequest, adapter.DNSQueryOptions{}) if exchangeErr != nil { err = wrapError(exchangeErr) return diff --git a/service/resolved/service.go b/service/resolved/service.go index 133cb1ab..eaedc09d 100644 --- a/service/resolved/service.go +++ b/service/resolved/service.go @@ -91,11 +91,6 @@ func (i *Service) Start(stage adapter.StartStage) error { return E.New("multiple resolved service are not supported") } } - case adapter.StartStateStart: - err := i.listener.Start() - if err != nil { - return err - } systemBus, err := dbus.SystemBus() if err != nil { return err @@ -117,6 +112,11 @@ func (i *Service) Start(stage adapter.StartStage) error { return E.New("unknown request name reply: ", reply) } i.networkUpdateCallback = i.network.NetworkMonitor().RegisterCallback(i.onNetworkUpdate) + case adapter.StartStateStart: + err := i.listener.Start() + if err != nil { + return err + } } return nil } @@ -167,6 +167,8 @@ func (i *Service) exchangePacket0(ctx context.Context, buffer *buf.Buffer, oob [ } var metadata adapter.InboundContext metadata.Source = source + metadata.InboundType = i.Type() + metadata.Inbound = i.Tag() response, err := i.dnsRouter.Exchange(adapter.WithContext(ctx, &metadata), &message, adapter.DNSQueryOptions{}) if err != nil { return err