From 43a9016c8304a4a15d2206ac1d9a8293f000fa36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 6 Jun 2025 14:28:09 +0800 Subject: [PATCH] Fix leak in hijack-dns --- route/dns.go | 12 +++++++++++- route/route.go | 4 ++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/route/dns.go b/route/dns.go index 2c6efefe..8d57c646 100644 --- a/route/dns.go +++ b/route/dns.go @@ -31,7 +31,7 @@ func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata ad } } -func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) { +func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { if natConn, isNatConn := conn.(udpnat.Conn); isNatConn { metadata.Destination = M.Socksaddr{} for _, packet := range packetBuffers { @@ -45,10 +45,12 @@ func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetB conn: conn, ctx: ctx, metadata: metadata, + onClose: onClose, }) return } err := dnsOutbound.NewDNSPacketConnection(ctx, r, conn, packetBuffers, metadata) + N.CloseOnHandshakeFailure(conn, onClose, err) if err != nil && !E.IsClosedOrCanceled(err) { r.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection")) } @@ -85,8 +87,16 @@ type dnsHijacker struct { conn N.PacketConn ctx context.Context metadata adapter.InboundContext + onClose N.CloseHandlerFunc } func (h *dnsHijacker) NewPacketEx(buffer *buf.Buffer, destination M.Socksaddr) { go ExchangeDNSPacket(h.ctx, h.router, h.conn, buffer, h.metadata, destination) } + +func (h *dnsHijacker) Close() error { + if h.onClose != nil { + h.onClose(nil) + } + return nil +} diff --git a/route/route.go b/route/route.go index d0f93e0b..d9bf2638 100644 --- a/route/route.go +++ b/route/route.go @@ -120,7 +120,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad for _, buffer := range buffers { conn = bufio.NewCachedConn(conn, buffer) } - r.hijackDNSStream(ctx, conn, metadata) + N.CloseOnHandshakeFailure(conn, onClose, r.hijackDNSStream(ctx, conn, metadata)) return nil } } @@ -233,7 +233,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) return nil case *rule.RuleActionHijackDNS: - r.hijackDNSPacket(ctx, conn, packetBuffers, metadata) + r.hijackDNSPacket(ctx, conn, packetBuffers, metadata, onClose) return nil } }