Compare commits

..

35 Commits

Author SHA1 Message Date
世界
7bf80c591a
documentation: Bump version 2024-12-03 11:52:11 +08:00
世界
5ac24fd704
hysteria2: Add more masquerade options 2024-12-03 11:52:02 +08:00
世界
ddc4a9886f
Improve timeouts 2024-12-03 11:52:01 +08:00
世界
ed98de2406
Add UDP timeout route option 2024-12-03 11:52:01 +08:00
世界
d26aa021f2
Make GSO adaptive 2024-12-03 11:52:01 +08:00
世界
0636352d8b
Fix tests 2024-12-03 11:52:01 +08:00
世界
c5620e8a14
Fix lint 2024-12-03 11:52:01 +08:00
世界
5bbc763a4b
refactor: WireGuard endpoint 2024-12-03 11:52:00 +08:00
世界
cada4d5361
refactor: connection manager 2024-12-03 11:52:00 +08:00
世界
5a4c5c80af
documentation: Fix typo 2024-12-03 11:52:00 +08:00
世界
0a52b9e138
Add override destination to route options 2024-12-03 11:51:59 +08:00
世界
d9c7bd91ee
Add dns.cache_capacity 2024-12-03 11:51:59 +08:00
世界
7f3b810d6f
Refactor multi networks strategy 2024-12-03 11:51:58 +08:00
世界
ad6d67cf11
documentation: Remove unused titles 2024-12-03 11:51:58 +08:00
世界
6ab0f5a805
Add multi network dialing 2024-12-03 11:51:58 +08:00
世界
65f3349bdd
documentation: Merge route options to route actions 2024-12-03 11:51:57 +08:00
世界
373d119efe
Add network_[type/is_expensive/is_constrained] rule items 2024-12-03 11:51:57 +08:00
世界
09cd846652
Merge route options to route actions 2024-12-03 11:51:57 +08:00
世界
f5c662689c
refactor: Platform Interfaces 2024-12-03 11:51:56 +08:00
世界
1b1972407f
refactor: Extract services form router 2024-12-03 11:51:56 +08:00
世界
60acb45b2b
refactor: Modular network manager 2024-12-03 11:51:56 +08:00
世界
a504d6be52
refactor: Modular inbound/outbound manager 2024-12-03 11:51:54 +08:00
世界
3bc7d1677c
documentation: Add rule action 2024-12-03 11:51:54 +08:00
世界
c21ed43fb8
documentation: Update the scheduled removal time of deprecated features 2024-12-03 11:51:53 +08:00
世界
cbfbba8dc9
documentation: Remove outdated icons 2024-12-03 11:51:53 +08:00
世界
76b5749a44
Migrate bad options to library 2024-12-03 11:51:53 +08:00
世界
c4401a0d5d
Implement udp connect 2024-12-03 11:51:52 +08:00
世界
ef3e72b5ca
Implement new deprecated warnings 2024-12-03 11:51:52 +08:00
世界
8159c1f422
Improve rule actions 2024-12-03 11:51:52 +08:00
世界
eabe4dde4c
Remove unused reject methods 2024-12-03 11:51:51 +08:00
世界
caad9f950b
refactor: Modular inbounds/outbounds 2024-12-03 11:51:50 +08:00
世界
337977aa68
Implement dns-hijack 2024-12-03 11:51:50 +08:00
世界
15d5375d95
Implement resolve(server) 2024-12-03 11:51:50 +08:00
世界
819eb98c4a
Implement TCP and ICMP rejects 2024-12-03 11:51:50 +08:00
世界
0350d206ed
Crazy sekai overturns the small pond 2024-12-03 11:51:40 +08:00
504 changed files with 6982 additions and 32304 deletions

View File

@ -1,30 +0,0 @@
-s dir
--name sing-box
--category net
--license GPL-3.0-or-later
--description "The universal proxy platform."
--url "https://sing-box.sagernet.org/"
--maintainer "nekohasekai <contact-git@sekai.icu>"
--no-deb-generate-changes
--config-files /etc/config/sing-box
--config-files /etc/sing-box/config.json
--depends ca-bundle
--depends kmod-inet-diag
--depends kmod-tun
--depends firewall4
--before-remove release/config/openwrt.prerm
release/config/config.json=/etc/sing-box/config.json
release/config/openwrt.conf=/etc/config/sing-box
release/config/openwrt.init=/etc/init.d/sing-box
release/config/openwrt.keep=/lib/upgrade/keep.d/sing-box
release/completions/sing-box.bash=/usr/share/bash-completion/completions/sing-box.bash
release/completions/sing-box.fish=/usr/share/fish/vendor_completions.d/sing-box.fish
release/completions/sing-box.zsh=/usr/share/zsh/site-functions/_sing-box
LICENSE=/usr/share/licenses/sing-box/LICENSE

View File

@ -1,25 +0,0 @@
-s dir
--name sing-box
--category net
--license GPL-3.0-or-later
--description "The universal proxy platform."
--url "https://sing-box.sagernet.org/"
--maintainer "nekohasekai <contact-git@sekai.icu>"
--deb-field "Bug: https://github.com/SagerNet/sing-box/issues"
--no-deb-generate-changes
--config-files /etc/sing-box/config.json
--after-install release/config/sing-box.postinst
release/config/config.json=/etc/sing-box/config.json
release/config/sing-box.service=/usr/lib/systemd/system/sing-box.service
release/config/sing-box@.service=/usr/lib/systemd/system/sing-box@.service
release/config/sing-box.sysusers=/usr/lib/sysusers.d/sing-box.conf
release/config/sing-box.rules=usr/share/polkit-1/rules.d/sing-box.rules
release/config/sing-box-split-dns.xml=/usr/share/dbus-1/system.d/sing-box-split-dns.conf
release/completions/sing-box.bash=/usr/share/bash-completion/completions/sing-box.bash
release/completions/sing-box.fish=/usr/share/fish/vendor_completions.d/sing-box.fish
release/completions/sing-box.zsh=/usr/share/zsh/site-functions/_sing-box
LICENSE=/usr/share/licenses/sing-box/LICENSE

28
.github/deb2ipk.sh vendored
View File

@ -1,28 +0,0 @@
#!/usr/bin/env bash
# mod from https://gist.github.com/pldubouilh/c5703052986bfdd404005951dee54683
set -e -o pipefail
PROJECT=$(dirname "$0")/../..
TMP_PATH=`mktemp -d`
cp $2 $TMP_PATH
pushd $TMP_PATH
DEB_NAME=`ls *.deb`
ar x $DEB_NAME
mkdir control
pushd control
tar xf ../control.tar.gz
rm md5sums
sed "s/Architecture:\\ \w*/Architecture:\\ $1/g" ./control -i
cat control
tar czf ../control.tar.gz ./*
popd
DEB_NAME=${DEB_NAME%.deb}
tar czf $DEB_NAME.ipk control.tar.gz data.tar.gz debian-binary
popd
cp $TMP_PATH/$DEB_NAME.ipk $3
rm -r $TMP_PATH

View File

@ -1,25 +0,0 @@
#!/usr/bin/env bash
VERSION="1.23.12"
mkdir -p $HOME/go
cd $HOME/go
wget "https://dl.google.com/go/go${VERSION}.linux-amd64.tar.gz"
tar -xzf "go${VERSION}.linux-amd64.tar.gz"
mv go go_legacy
cd go_legacy
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.23.x
# that means after golang1.24 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.23/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
curl https://github.com/MetaCubeX/go/commit/9ac42137ef6730e8b7daca016ece831297a1d75b.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/21290de8a4c91408de7c2b5b68757b1e90af49dd.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/6a31d3fa8e47ddabc10bd97bff10d9a85f4cfb76.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/69e2eed6dd0f6d815ebf15797761c13f31213dd6.diff | patch --verbose -p 1

View File

@ -1,673 +0,0 @@
name: Build
on:
workflow_dispatch:
inputs:
version:
description: "Version name"
required: true
type: string
build:
description: "Build type"
required: true
type: choice
default: "All"
options:
- All
- Binary
- Android
- Apple
- app-store
- iOS
- macOS
- tvOS
- macOS-standalone
- publish-android
push:
branches:
- main-next
- dev-next
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }}-${{ inputs.build }}
cancel-in-progress: true
jobs:
calculate_version:
name: Calculate version
runs-on: ubuntu-latest
outputs:
version: ${{ steps.outputs.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.25.0
- name: Check input version
if: github.event_name == 'workflow_dispatch'
run: |-
echo "version=${{ inputs.version }}"
echo "version=${{ inputs.version }}" >> "$GITHUB_ENV"
- name: Calculate version
if: github.event_name != 'workflow_dispatch'
run: |-
go run -v ./cmd/internal/read_tag --ci --nightly
- name: Set outputs
id: outputs
run: |-
echo "version=$version" >> "$GITHUB_OUTPUT"
build:
name: Build binary
if: github.event_name != 'workflow_dispatch' || inputs.build == 'All' || inputs.build == 'Binary'
runs-on: ubuntu-latest
needs:
- calculate_version
strategy:
matrix:
include:
- { os: linux, arch: amd64, debian: amd64, rpm: x86_64, pacman: x86_64, openwrt: "x86_64" }
- { os: linux, arch: "386", go386: sse2, debian: i386, rpm: i386, openwrt: "i386_pentium4" }
- { os: linux, arch: "386", go386: softfloat, openwrt: "i386_pentium-mmx" }
- { os: linux, arch: arm64, debian: arm64, rpm: aarch64, pacman: aarch64, openwrt: "aarch64_cortex-a53 aarch64_cortex-a72 aarch64_cortex-a76 aarch64_generic" }
- { os: linux, arch: arm, goarm: "5", openwrt: "arm_arm926ej-s arm_cortex-a7 arm_cortex-a9 arm_fa526 arm_xscale" }
- { os: linux, arch: arm, goarm: "6", debian: armel, rpm: armv6hl, openwrt: "arm_arm1176jzf-s_vfp" }
- { os: linux, arch: arm, goarm: "7", debian: armhf, rpm: armv7hl, pacman: armv7hl, openwrt: "arm_cortex-a5_vfpv4 arm_cortex-a7_neon-vfpv4 arm_cortex-a7_vfpv4 arm_cortex-a8_vfpv3 arm_cortex-a9_neon arm_cortex-a9_vfpv3-d16 arm_cortex-a15_neon-vfpv4" }
- { os: linux, arch: mips, gomips: softfloat, openwrt: "mips_24kc mips_4kec mips_mips32" }
- { os: linux, arch: mipsle, gomips: hardfloat, debian: mipsel, rpm: mipsel, openwrt: "mipsel_24kc_24kf" }
- { os: linux, arch: mipsle, gomips: softfloat, openwrt: "mipsel_24kc mipsel_74kc mipsel_mips32" }
- { os: linux, arch: mips64, gomips: softfloat, openwrt: "mips64_mips64r2 mips64_octeonplus" }
- { os: linux, arch: mips64le, gomips: hardfloat, debian: mips64el, rpm: mips64el }
- { os: linux, arch: mips64le, gomips: softfloat, openwrt: "mips64el_mips64r2" }
- { os: linux, arch: s390x, debian: s390x, rpm: s390x }
- { os: linux, arch: ppc64le, debian: ppc64el, rpm: ppc64le }
- { os: linux, arch: riscv64, debian: riscv64, rpm: riscv64, openwrt: "riscv64_generic" }
- { os: linux, arch: loong64, debian: loongarch64, rpm: loongarch64, openwrt: "loongarch64_generic" }
- { os: windows, arch: amd64 }
- { os: windows, arch: amd64, legacy_go123: true, legacy_name: "windows-7" }
- { os: windows, arch: "386" }
- { os: windows, arch: "386", legacy_go123: true, legacy_name: "windows-7" }
- { os: windows, arch: arm64 }
- { os: darwin, arch: amd64 }
- { os: darwin, arch: arm64 }
- { os: darwin, arch: amd64, legacy_go124: true, legacy_name: "macos-11" }
- { os: android, arch: arm64, ndk: "aarch64-linux-android21" }
- { os: android, arch: arm, ndk: "armv7a-linux-androideabi21" }
- { os: android, arch: amd64, ndk: "x86_64-linux-android21" }
- { os: android, arch: "386", ndk: "i686-linux-android21" }
steps:
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go
if: ${{ ! (matrix.legacy_go123 || matrix.legacy_go124) }}
uses: actions/setup-go@v5
with:
go-version: ^1.25.0
- name: Setup Go 1.24
if: matrix.legacy_go124
uses: actions/setup-go@v5
with:
go-version: ~1.24.6
- name: Cache Go 1.23
if: matrix.legacy_go123
id: cache-legacy-go
uses: actions/cache@v4
with:
path: |
~/go/go_legacy
key: go_legacy_12312
- name: Setup Go 1.23
if: matrix.legacy_go123 && steps.cache-legacy-go.outputs.cache-hit != 'true'
run: |-
.github/setup_legacy_go.sh
- name: Setup Go 1.23
if: matrix.legacy_go123
run: |-
echo "PATH=$HOME/go/go_legacy/bin:$PATH" >> $GITHUB_ENV
echo "GOROOT=$HOME/go/go_legacy" >> $GITHUB_ENV
- name: Setup Android NDK
if: matrix.os == 'android'
uses: nttld/setup-ndk@v1
with:
ndk-version: r28
local-cache: true
- name: Set tag
run: |-
git ls-remote --exit-code --tags origin v${{ needs.calculate_version.outputs.version }} || echo "PUBLISHED=false" >> "$GITHUB_ENV"
git tag v${{ needs.calculate_version.outputs.version }} -f
- name: Set build tags
run: |
set -xeuo pipefail
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale'
echo "BUILD_TAGS=${TAGS}" >> "${GITHUB_ENV}"
- name: Build
if: matrix.os != 'android'
run: |
set -xeuo pipefail
mkdir -p dist
go build -v -trimpath -o dist/sing-box -tags "${BUILD_TAGS}" \
-ldflags '-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }} -checklinkname=0' \
./cmd/sing-box
env:
CGO_ENABLED: "0"
GOOS: ${{ matrix.os }}
GOARCH: ${{ matrix.arch }}
GO386: ${{ matrix.go386 }}
GOARM: ${{ matrix.goarm }}
GOMIPS: ${{ matrix.gomips }}
GOMIPS64: ${{ matrix.gomips }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Build Android
if: matrix.os == 'android'
run: |
set -xeuo pipefail
go install -v ./cmd/internal/build
export CC='${{ matrix.ndk }}-clang'
export CXX="${CC}++"
mkdir -p dist
GOOS=$BUILD_GOOS GOARCH=$BUILD_GOARCH build go build -v -trimpath -o dist/sing-box -tags "${BUILD_TAGS}" \
-ldflags '-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }} -checklinkname=0' \
./cmd/sing-box
env:
CGO_ENABLED: "1"
BUILD_GOOS: ${{ matrix.os }}
BUILD_GOARCH: ${{ matrix.arch }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Set name
run: |-
DIR_NAME="sing-box-${{ needs.calculate_version.outputs.version }}-${{ matrix.os }}-${{ matrix.arch }}"
if [[ -n "${{ matrix.goarm }}" ]]; then
DIR_NAME="${DIR_NAME}v${{ matrix.goarm }}"
elif [[ -n "${{ matrix.go386 }}" && "${{ matrix.go386 }}" != 'sse2' ]]; then
DIR_NAME="${DIR_NAME}-${{ matrix.go386 }}"
elif [[ -n "${{ matrix.gomips }}" && "${{ matrix.gomips }}" != 'hardfloat' ]]; then
DIR_NAME="${DIR_NAME}-${{ matrix.gomips }}"
elif [[ -n "${{ matrix.legacy_name }}" ]]; then
DIR_NAME="${DIR_NAME}-legacy-${{ matrix.legacy_name }}"
fi
echo "DIR_NAME=${DIR_NAME}" >> "${GITHUB_ENV}"
PKG_VERSION="${{ needs.calculate_version.outputs.version }}"
PKG_VERSION="${PKG_VERSION//-/\~}"
echo "PKG_VERSION=${PKG_VERSION}" >> "${GITHUB_ENV}"
- name: Package DEB
if: matrix.debian != ''
run: |
set -xeuo pipefail
sudo gem install fpm
sudo apt-get update
sudo apt-get install -y debsigs
cp .fpm_systemd .fpm
fpm -t deb \
-v "$PKG_VERSION" \
-p "dist/sing-box_${{ needs.calculate_version.outputs.version }}_${{ matrix.os }}_${{ matrix.debian }}.deb" \
--architecture ${{ matrix.debian }} \
dist/sing-box=/usr/bin/sing-box
curl -Lo '/tmp/debsigs.diff' 'https://gitlab.com/debsigs/debsigs/-/commit/160138f5de1ec110376d3c807b60a37388bc7c90.diff'
sudo patch /usr/bin/debsigs < '/tmp/debsigs.diff'
rm -rf $HOME/.gnupg
gpg --pinentry-mode loopback --passphrase "${{ secrets.GPG_PASSPHRASE }}" --import <<EOF
${{ secrets.GPG_KEY }}
EOF
debsigs --sign=origin -k ${{ secrets.GPG_KEY_ID }} --gpgopts '--pinentry-mode loopback --passphrase "${{ secrets.GPG_PASSPHRASE }}"' dist/*.deb
- name: Package RPM
if: matrix.rpm != ''
run: |-
set -xeuo pipefail
sudo gem install fpm
cp .fpm_systemd .fpm
fpm -t rpm \
-v "$PKG_VERSION" \
-p "dist/sing-box_${{ needs.calculate_version.outputs.version }}_${{ matrix.os }}_${{ matrix.rpm }}.rpm" \
--architecture ${{ matrix.rpm }} \
dist/sing-box=/usr/bin/sing-box
cat > $HOME/.rpmmacros <<EOF
%_gpg_name ${{ secrets.GPG_KEY_ID }}
%_gpg_sign_cmd_extra_args --pinentry-mode loopback --passphrase ${{ secrets.GPG_PASSPHRASE }}
EOF
gpg --pinentry-mode loopback --passphrase "${{ secrets.GPG_PASSPHRASE }}" --import <<EOF
${{ secrets.GPG_KEY }}
EOF
rpmsign --addsign dist/*.rpm
- name: Package Pacman
if: matrix.pacman != ''
run: |-
set -xeuo pipefail
sudo gem install fpm
sudo apt-get update
sudo apt-get install -y libarchive-tools
cp .fpm_systemd .fpm
fpm -t pacman \
-v "$PKG_VERSION" \
-p "dist/sing-box_${{ needs.calculate_version.outputs.version }}_${{ matrix.os }}_${{ matrix.pacman }}.pkg.tar.zst" \
--architecture ${{ matrix.pacman }} \
dist/sing-box=/usr/bin/sing-box
- name: Package OpenWrt
if: matrix.openwrt != ''
run: |-
set -xeuo pipefail
sudo gem install fpm
cp .fpm_openwrt .fpm
fpm -t deb \
-v "$PKG_VERSION" \
-p "dist/openwrt.deb" \
--architecture all \
dist/sing-box=/usr/bin/sing-box
for architecture in ${{ matrix.openwrt }}; do
.github/deb2ipk.sh "$architecture" "dist/openwrt.deb" "dist/sing-box_${{ needs.calculate_version.outputs.version }}_openwrt_${architecture}.ipk"
done
rm "dist/openwrt.deb"
- name: Archive
run: |
set -xeuo pipefail
cd dist
mkdir -p "${DIR_NAME}"
cp ../LICENSE "${DIR_NAME}"
if [ '${{ matrix.os }}' = 'windows' ]; then
cp sing-box "${DIR_NAME}/sing-box.exe"
zip -r "${DIR_NAME}.zip" "${DIR_NAME}"
else
cp sing-box "${DIR_NAME}"
tar -czvf "${DIR_NAME}.tar.gz" "${DIR_NAME}"
fi
rm -r "${DIR_NAME}"
- name: Cleanup
run: rm dist/sing-box
- name: Upload artifact
uses: actions/upload-artifact@v4
with:
name: binary-${{ matrix.os }}_${{ matrix.arch }}${{ matrix.goarm && format('v{0}', matrix.goarm) }}${{ matrix.go386 && format('_{0}', matrix.go386) }}${{ matrix.gomips && format('_{0}', matrix.gomips) }}${{ matrix.legacy_name && format('-legacy-{0}', matrix.legacy_name) }}
path: "dist"
build_android:
name: Build Android
if: github.event_name != 'workflow_dispatch' || inputs.build == 'All' || inputs.build == 'Android'
runs-on: ubuntu-latest
needs:
- calculate_version
steps:
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
submodules: 'recursive'
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.25.0
- name: Setup Android NDK
id: setup-ndk
uses: nttld/setup-ndk@v1
with:
ndk-version: r28
- name: Setup OpenJDK
run: |-
sudo apt update && sudo apt install -y openjdk-17-jdk-headless
/usr/lib/jvm/java-17-openjdk-amd64/bin/java --version
- name: Set tag
run: |-
git ls-remote --exit-code --tags origin v${{ needs.calculate_version.outputs.version }} || echo "PUBLISHED=false" >> "$GITHUB_ENV"
git tag v${{ needs.calculate_version.outputs.version }} -f
- name: Build library
run: |-
make lib_install
export PATH="$PATH:$(go env GOPATH)/bin"
make lib_android
env:
JAVA_HOME: /usr/lib/jvm/java-17-openjdk-amd64
ANDROID_NDK_HOME: ${{ steps.setup-ndk.outputs.ndk-path }}
- name: Checkout main branch
if: github.ref == 'refs/heads/main-next' && github.event_name != 'workflow_dispatch'
run: |-
cd clients/android
git checkout main
- name: Checkout dev branch
if: github.ref == 'refs/heads/dev-next'
run: |-
cd clients/android
git checkout dev
- name: Gradle cache
uses: actions/cache@v4
with:
path: ~/.gradle
key: gradle-${{ hashFiles('**/*.gradle') }}
- name: Update version
if: github.event_name == 'workflow_dispatch'
run: |-
go run -v ./cmd/internal/update_android_version --ci
- name: Update nightly version
if: github.event_name != 'workflow_dispatch'
run: |-
go run -v ./cmd/internal/update_android_version --ci --nightly
- name: Build
run: |-
mkdir clients/android/app/libs
cp libbox.aar clients/android/app/libs
cd clients/android
./gradlew :app:assemblePlayRelease :app:assembleOtherRelease
env:
JAVA_HOME: /usr/lib/jvm/java-17-openjdk-amd64
ANDROID_NDK_HOME: ${{ steps.setup-ndk.outputs.ndk-path }}
LOCAL_PROPERTIES: ${{ secrets.LOCAL_PROPERTIES }}
- name: Prepare upload
run: |-
mkdir -p dist
cp clients/android/app/build/outputs/apk/play/release/*.apk dist
cp clients/android/app/build/outputs/apk/other/release/*-universal.apk dist
- name: Upload artifact
uses: actions/upload-artifact@v4
with:
name: binary-android-apks
path: 'dist'
publish_android:
name: Publish Android
if: github.event_name == 'workflow_dispatch' && inputs.build == 'publish-android'
runs-on: ubuntu-latest
needs:
- calculate_version
steps:
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
submodules: 'recursive'
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.25.0
- name: Setup Android NDK
id: setup-ndk
uses: nttld/setup-ndk@v1
with:
ndk-version: r28
- name: Setup OpenJDK
run: |-
sudo apt update && sudo apt install -y openjdk-17-jdk-headless
/usr/lib/jvm/java-17-openjdk-amd64/bin/java --version
- name: Set tag
run: |-
git ls-remote --exit-code --tags origin v${{ needs.calculate_version.outputs.version }} || echo "PUBLISHED=false" >> "$GITHUB_ENV"
git tag v${{ needs.calculate_version.outputs.version }} -f
- name: Build library
run: |-
make lib_install
export PATH="$PATH:$(go env GOPATH)/bin"
make lib_android
env:
JAVA_HOME: /usr/lib/jvm/java-17-openjdk-amd64
ANDROID_NDK_HOME: ${{ steps.setup-ndk.outputs.ndk-path }}
- name: Checkout main branch
if: github.ref == 'refs/heads/main-next' && github.event_name != 'workflow_dispatch'
run: |-
cd clients/android
git checkout main
- name: Checkout dev branch
if: github.ref == 'refs/heads/dev-next'
run: |-
cd clients/android
git checkout dev
- name: Gradle cache
uses: actions/cache@v4
with:
path: ~/.gradle
key: gradle-${{ hashFiles('**/*.gradle') }}
- name: Build
run: |-
go run -v ./cmd/internal/update_android_version --ci
mkdir clients/android/app/libs
cp libbox.aar clients/android/app/libs
cd clients/android
echo -n "$SERVICE_ACCOUNT_CREDENTIALS" | base64 --decode > service-account-credentials.json
./gradlew :app:publishPlayReleaseBundle
env:
JAVA_HOME: /usr/lib/jvm/java-17-openjdk-amd64
ANDROID_NDK_HOME: ${{ steps.setup-ndk.outputs.ndk-path }}
LOCAL_PROPERTIES: ${{ secrets.LOCAL_PROPERTIES }}
SERVICE_ACCOUNT_CREDENTIALS: ${{ secrets.SERVICE_ACCOUNT_CREDENTIALS }}
build_apple:
name: Build Apple clients
runs-on: macos-15
needs:
- calculate_version
strategy:
matrix:
include:
- name: iOS
if: ${{ github.event_name != 'workflow_dispatch' || inputs.build == 'All' || inputs.build == 'Apple' || inputs.build == 'app-store'|| inputs.build == 'iOS' }}
platform: ios
scheme: SFI
destination: 'generic/platform=iOS'
archive: build/SFI.xcarchive
upload: SFI/Upload.plist
- name: macOS
if: ${{ github.event_name != 'workflow_dispatch' || inputs.build == 'All' || inputs.build == 'Apple' || inputs.build == 'app-store'|| inputs.build == 'macOS' }}
platform: macos
scheme: SFM
destination: 'generic/platform=macOS'
archive: build/SFM.xcarchive
upload: SFI/Upload.plist
- name: tvOS
if: ${{ github.event_name != 'workflow_dispatch' || inputs.build == 'All' || inputs.build == 'Apple' || inputs.build == 'app-store'|| inputs.build == 'tvOS' }}
platform: tvos
scheme: SFT
destination: 'generic/platform=tvOS'
archive: build/SFT.xcarchive
upload: SFI/Upload.plist
- name: macOS-standalone
if: ${{ github.event_name != 'workflow_dispatch' || inputs.build == 'All' || inputs.build == 'Apple' || inputs.build == 'macOS-standalone' }}
platform: macos
scheme: SFM.System
destination: 'generic/platform=macOS'
archive: build/SFM.System.xcarchive
export: SFM.System/Export.plist
export_path: build/SFM.System
steps:
- name: Checkout
if: matrix.if
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
submodules: 'recursive'
- name: Setup Go
if: matrix.if
uses: actions/setup-go@v5
with:
go-version: ^1.25.0
- name: Setup Xcode stable
if: matrix.if && github.ref == 'refs/heads/main-next'
run: |-
sudo xcode-select -s /Applications/Xcode_16.4.app
- name: Setup Xcode beta
if: matrix.if && github.ref == 'refs/heads/dev-next'
run: |-
sudo xcode-select -s /Applications/Xcode_16.4.app
- name: Set tag
if: matrix.if
run: |-
git ls-remote --exit-code --tags origin v${{ needs.calculate_version.outputs.version }} || echo "PUBLISHED=false" >> "$GITHUB_ENV"
git tag v${{ needs.calculate_version.outputs.version }} -f
echo "VERSION=${{ needs.calculate_version.outputs.version }}" >> "$GITHUB_ENV"
- name: Checkout main branch
if: matrix.if && github.ref == 'refs/heads/main-next' && github.event_name != 'workflow_dispatch'
run: |-
cd clients/apple
git checkout main
- name: Checkout dev branch
if: matrix.if && github.ref == 'refs/heads/dev-next'
run: |-
cd clients/apple
git checkout dev
- name: Setup certificates
if: matrix.if
run: |-
CERTIFICATE_PATH=$RUNNER_TEMP/Certificates.p12
KEYCHAIN_PATH=$RUNNER_TEMP/certificates.keychain-db
echo -n "$CERTIFICATES_P12" | base64 --decode -o $CERTIFICATE_PATH
security create-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH
security set-keychain-settings -lut 21600 $KEYCHAIN_PATH
security unlock-keychain -p "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH
security import $CERTIFICATE_PATH -P "$P12_PASSWORD" -A -t cert -f pkcs12 -k $KEYCHAIN_PATH
security set-key-partition-list -S apple-tool:,apple: -k "$KEYCHAIN_PASSWORD" $KEYCHAIN_PATH
security list-keychain -d user -s $KEYCHAIN_PATH
PROFILES_ZIP_PATH=$RUNNER_TEMP/Profiles.zip
echo -n "$PROVISIONING_PROFILES" | base64 --decode -o $PROFILES_ZIP_PATH
PROFILES_PATH="$HOME/Library/MobileDevice/Provisioning Profiles"
mkdir -p "$PROFILES_PATH"
unzip $PROFILES_ZIP_PATH -d "$PROFILES_PATH"
ASC_KEY_PATH=$RUNNER_TEMP/Key.p12
echo -n "$ASC_KEY" | base64 --decode -o $ASC_KEY_PATH
xcrun notarytool store-credentials "notarytool-password" \
--key $ASC_KEY_PATH \
--key-id $ASC_KEY_ID \
--issuer $ASC_KEY_ISSUER_ID
echo "ASC_KEY_PATH=$ASC_KEY_PATH" >> "$GITHUB_ENV"
echo "ASC_KEY_ID=$ASC_KEY_ID" >> "$GITHUB_ENV"
echo "ASC_KEY_ISSUER_ID=$ASC_KEY_ISSUER_ID" >> "$GITHUB_ENV"
env:
CERTIFICATES_P12: ${{ secrets.CERTIFICATES_P12 }}
P12_PASSWORD: ${{ secrets.P12_PASSWORD }}
KEYCHAIN_PASSWORD: ${{ secrets.P12_PASSWORD }}
PROVISIONING_PROFILES: ${{ secrets.PROVISIONING_PROFILES }}
ASC_KEY: ${{ secrets.ASC_KEY }}
ASC_KEY_ID: ${{ secrets.ASC_KEY_ID }}
ASC_KEY_ISSUER_ID: ${{ secrets.ASC_KEY_ISSUER_ID }}
- name: Build library
if: matrix.if
run: |-
make lib_install
export PATH="$PATH:$(go env GOPATH)/bin"
go run ./cmd/internal/build_libbox -target apple -platform ${{ matrix.platform }}
mv Libbox.xcframework clients/apple
- name: Update macOS version
if: matrix.if && matrix.name == 'macOS' && github.event_name == 'workflow_dispatch'
run: |-
MACOS_PROJECT_VERSION=$(go run -v ./cmd/internal/app_store_connect next_macos_project_version)
echo "MACOS_PROJECT_VERSION=$MACOS_PROJECT_VERSION"
echo "MACOS_PROJECT_VERSION=$MACOS_PROJECT_VERSION" >> "$GITHUB_ENV"
- name: Update version
if: matrix.if && matrix.name != 'iOS'
run: |-
go run -v ./cmd/internal/update_apple_version --ci
- name: Build
if: matrix.if
run: |-
cd clients/apple
xcodebuild archive \
-scheme "${{ matrix.scheme }}" \
-configuration Release \
-destination "${{ matrix.destination }}" \
-archivePath "${{ matrix.archive }}" \
-allowProvisioningUpdates \
-authenticationKeyPath $ASC_KEY_PATH \
-authenticationKeyID $ASC_KEY_ID \
-authenticationKeyIssuerID $ASC_KEY_ISSUER_ID
- name: Upload to App Store Connect
if: matrix.if && matrix.name != 'macOS-standalone' && github.event_name == 'workflow_dispatch'
run: |-
go run -v ./cmd/internal/app_store_connect cancel_app_store ${{ matrix.platform }}
cd clients/apple
xcodebuild -exportArchive \
-archivePath "${{ matrix.archive }}" \
-exportOptionsPlist ${{ matrix.upload }} \
-allowProvisioningUpdates \
-authenticationKeyPath $ASC_KEY_PATH \
-authenticationKeyID $ASC_KEY_ID \
-authenticationKeyIssuerID $ASC_KEY_ISSUER_ID
- name: Publish to TestFlight
if: matrix.if && matrix.name != 'macOS-standalone' && github.event_name == 'workflow_dispatch' && github.ref =='refs/heads/dev-next'
run: |-
go run -v ./cmd/internal/app_store_connect publish_testflight ${{ matrix.platform }}
- name: Build image
if: matrix.if && matrix.name == 'macOS-standalone' && github.event_name == 'workflow_dispatch'
run: |-
pushd clients/apple
xcodebuild -exportArchive \
-archivePath "${{ matrix.archive }}" \
-exportOptionsPlist ${{ matrix.export }} \
-exportPath "${{ matrix.export_path }}"
brew install create-dmg
create-dmg \
--volname "sing-box" \
--volicon "${{ matrix.export_path }}/SFM.app/Contents/Resources/AppIcon.icns" \
--icon "SFM.app" 0 0 \
--hide-extension "SFM.app" \
--app-drop-link 0 0 \
--skip-jenkins \
SFM.dmg "${{ matrix.export_path }}/SFM.app"
xcrun notarytool submit "SFM.dmg" --wait --keychain-profile "notarytool-password"
cd "${{ matrix.archive }}"
zip -r SFM.dSYMs.zip dSYMs
popd
mkdir -p dist
cp clients/apple/SFM.dmg "dist/SFM-${VERSION}-universal.dmg"
cp "clients/apple/${{ matrix.archive }}/SFM.dSYMs.zip" "dist/SFM-${VERSION}-universal.dSYMs.zip"
- name: Upload image
if: matrix.if && matrix.name == 'macOS-standalone' && github.event_name == 'workflow_dispatch'
uses: actions/upload-artifact@v4
with:
name: binary-macos-dmg
path: 'dist'
upload:
name: Upload builds
if: "!failure() && github.event_name == 'workflow_dispatch' && (inputs.build == 'All' || inputs.build == 'Binary' || inputs.build == 'Android' || inputs.build == 'Apple' || inputs.build == 'macOS-standalone')"
runs-on: ubuntu-latest
needs:
- calculate_version
- build
- build_android
- build_apple
steps:
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Cache ghr
uses: actions/cache@v4
id: cache-ghr
with:
path: |
~/go/bin/ghr
key: ghr
- name: Setup ghr
if: steps.cache-ghr.outputs.cache-hit != 'true'
run: |-
cd $HOME
git clone https://github.com/nekohasekai/ghr ghr
cd ghr
go install -v .
- name: Set tag
run: |-
git ls-remote --exit-code --tags origin v${{ needs.calculate_version.outputs.version }} || echo "PUBLISHED=false" >> "$GITHUB_ENV"
git tag v${{ needs.calculate_version.outputs.version }} -f
echo "VERSION=${{ needs.calculate_version.outputs.version }}" >> "$GITHUB_ENV"
- name: Download builds
uses: actions/download-artifact@v5
with:
path: dist
merge-multiple: true
- name: Upload builds
if: ${{ env.PUBLISHED == 'false' }}
run: |-
export PATH="$PATH:$HOME/go/bin"
ghr --replace --draft --prerelease -p 5 "v${VERSION}" dist
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Replace builds
if: ${{ env.PUBLISHED != 'false' }}
run: |-
export PATH="$PATH:$HOME/go/bin"
ghr --replace -p 5 "v${VERSION}" dist
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

219
.github/workflows/debug.yml vendored Normal file
View File

@ -0,0 +1,219 @@
name: Debug build
on:
push:
branches:
- stable-next
- main-next
- dev-next
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/debug.yml'
pull_request:
branches:
- stable-next
- main-next
- dev-next
jobs:
build:
name: Debug build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
- name: Run Test
run: |
go test -v ./...
build_go120:
name: Debug build (Go 1.20)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.20
- name: Cache go module
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
key: go120-${{ hashFiles('**/go.sum') }}
- name: Run Test
run: make ci_build_go120
build_go121:
name: Debug build (Go 1.21)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.21
- name: Cache go module
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
key: go121-${{ hashFiles('**/go.sum') }}
- name: Run Test
run: make ci_build
build_go122:
name: Debug build (Go 1.22)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.22
- name: Cache go module
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
key: go122-${{ hashFiles('**/go.sum') }}
- name: Run Test
run: make ci_build
cross:
strategy:
matrix:
include:
# windows
- name: windows-amd64
goos: windows
goarch: amd64
goamd64: v1
- name: windows-amd64-v3
goos: windows
goarch: amd64
goamd64: v3
- name: windows-386
goos: windows
goarch: 386
- name: windows-arm64
goos: windows
goarch: arm64
- name: windows-arm32v7
goos: windows
goarch: arm
goarm: 7
# linux
- name: linux-amd64
goos: linux
goarch: amd64
goamd64: v1
- name: linux-amd64-v3
goos: linux
goarch: amd64
goamd64: v3
- name: linux-386
goos: linux
goarch: 386
- name: linux-arm64
goos: linux
goarch: arm64
- name: linux-armv5
goos: linux
goarch: arm
goarm: 5
- name: linux-armv6
goos: linux
goarch: arm
goarm: 6
- name: linux-armv7
goos: linux
goarch: arm
goarm: 7
- name: linux-mips-softfloat
goos: linux
goarch: mips
gomips: softfloat
- name: linux-mips-hardfloat
goos: linux
goarch: mips
gomips: hardfloat
- name: linux-mipsel-softfloat
goos: linux
goarch: mipsle
gomips: softfloat
- name: linux-mipsel-hardfloat
goos: linux
goarch: mipsle
gomips: hardfloat
- name: linux-mips64
goos: linux
goarch: mips64
- name: linux-mips64el
goos: linux
goarch: mips64le
- name: linux-s390x
goos: linux
goarch: s390x
# darwin
- name: darwin-amd64
goos: darwin
goarch: amd64
goamd64: v1
- name: darwin-amd64-v3
goos: darwin
goarch: amd64
goamd64: v3
- name: darwin-arm64
goos: darwin
goarch: arm64
# freebsd
- name: freebsd-amd64
goos: freebsd
goarch: amd64
goamd64: v1
- name: freebsd-amd64-v3
goos: freebsd
goarch: amd64
goamd64: v3
- name: freebsd-386
goos: freebsd
goarch: 386
- name: freebsd-arm64
goos: freebsd
goarch: arm64
fail-fast: true
runs-on: ubuntu-latest
env:
GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }}
GOAMD64: ${{ matrix.goamd64 }}
GOARM: ${{ matrix.goarm }}
GOMIPS: ${{ matrix.gomips }}
CGO_ENABLED: 0
TAGS: with_gvisor,with_dhcp,with_wireguard,with_clash_api,with_quic,with_utls,with_ech
steps:
- name: Checkout
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.21
- name: Build
id: build
run: make

View File

@ -39,7 +39,7 @@ jobs:
echo "ref=$ref" echo "ref=$ref"
echo "ref=$ref" >> $GITHUB_OUTPUT echo "ref=$ref" >> $GITHUB_OUTPUT
- name: Checkout - name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with: with:
ref: ${{ steps.ref.outputs.ref }} ref: ${{ steps.ref.outputs.ref }}
fetch-depth: 0 fetch-depth: 0
@ -107,7 +107,7 @@ jobs:
echo "latest=$latest" echo "latest=$latest"
echo "latest=$latest" >> $GITHUB_OUTPUT echo "latest=$latest" >> $GITHUB_OUTPUT
- name: Download digests - name: Download digests
uses: actions/download-artifact@v5 uses: actions/download-artifact@v4
with: with:
path: /tmp/digests path: /tmp/digests
pattern: digests-* pattern: digests-*

View File

@ -22,17 +22,16 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: ~1.24.6 go-version: ^1.23
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v8 uses: golangci/golangci-lint-action@v6
with: with:
version: latest version: latest
args: --timeout=30m args: --timeout=30m
install-mode: binary install-mode: binary
verify: false

View File

@ -1,189 +1,39 @@
name: Build Linux Packages name: Release to Linux repository
on: on:
workflow_dispatch:
inputs:
version:
description: "Version name"
required: true
type: string
forceBeta:
description: "Force beta"
required: false
type: boolean
default: false
release: release:
types: types:
- published - published
jobs: jobs:
calculate_version:
name: Calculate version
runs-on: ubuntu-latest
outputs:
version: ${{ steps.outputs.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.25.0
- name: Check input version
if: github.event_name == 'workflow_dispatch'
run: |-
echo "version=${{ inputs.version }}"
echo "version=${{ inputs.version }}" >> "$GITHUB_ENV"
- name: Calculate version
if: github.event_name != 'workflow_dispatch'
run: |-
go run -v ./cmd/internal/read_tag --ci --nightly
- name: Set outputs
id: outputs
run: |-
echo "version=$version" >> "$GITHUB_OUTPUT"
build: build:
name: Build binary
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs:
- calculate_version
strategy:
matrix:
include:
- { os: linux, arch: amd64, debian: amd64, rpm: x86_64, pacman: x86_64 }
- { os: linux, arch: "386", debian: i386, rpm: i386 }
- { os: linux, arch: arm, goarm: "6", debian: armel, rpm: armv6hl }
- { os: linux, arch: arm, goarm: "7", debian: armhf, rpm: armv7hl, pacman: armv7hl }
- { os: linux, arch: arm64, debian: arm64, rpm: aarch64, pacman: aarch64 }
- { os: linux, arch: mips64le, debian: mips64el, rpm: mips64el }
- { os: linux, arch: mipsle, debian: mipsel, rpm: mipsel }
- { os: linux, arch: s390x, debian: s390x, rpm: s390x }
- { os: linux, arch: ppc64le, debian: ppc64el, rpm: ppc64le }
- { os: linux, arch: riscv64, debian: riscv64, rpm: riscv64 }
- { os: linux, arch: loong64, debian: loongarch64, rpm: loongarch64 }
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: ^1.25.0 go-version: ^1.23
- name: Setup Android NDK - name: Extract signing key
if: matrix.os == 'android'
uses: nttld/setup-ndk@v1
with:
ndk-version: r28
local-cache: true
- name: Set tag
run: |- run: |-
git ls-remote --exit-code --tags origin v${{ needs.calculate_version.outputs.version }} || echo "PUBLISHED=false" >> "$GITHUB_ENV" mkdir -p $HOME/.gnupg
git tag v${{ needs.calculate_version.outputs.version }} -f cat > $HOME/.gnupg/sagernet.key <<EOF
- name: Set build tags ${{ secrets.GPG_KEY }}
run: | echo "HOME=$HOME" >> "$GITHUB_ENV"
set -xeuo pipefail EOF
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale' echo "HOME=$HOME" >> "$GITHUB_ENV"
echo "BUILD_TAGS=${TAGS}" >> "${GITHUB_ENV}" - name: Publish release
- name: Build uses: goreleaser/goreleaser-action@v6
run: | with:
set -xeuo pipefail distribution: goreleaser-pro
mkdir -p dist version: latest
go build -v -trimpath -o dist/sing-box -tags "${BUILD_TAGS}" \ args: release -f .goreleaser.fury.yaml --clean
-ldflags '-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }}' \
./cmd/sing-box
env: env:
CGO_ENABLED: "0"
GOOS: ${{ matrix.os }}
GOARCH: ${{ matrix.arch }}
GOARM: ${{ matrix.goarm }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Set mtime GORELEASER_KEY: ${{ secrets.GORELEASER_KEY }}
run: |- FURY_TOKEN: ${{ secrets.FURY_TOKEN }}
TZ=UTC touch -t '197001010000' dist/sing-box NFPM_KEY_PATH: ${{ env.HOME }}/.gnupg/sagernet.key
- name: Set name NFPM_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
if: (! contains(needs.calculate_version.outputs.version, '-')) && !inputs.forceBeta
run: |-
echo "NAME=sing-box" >> "$GITHUB_ENV"
- name: Set beta name
if: contains(needs.calculate_version.outputs.version, '-') || inputs.forceBeta
run: |-
echo "NAME=sing-box-beta" >> "$GITHUB_ENV"
- name: Set version
run: |-
PKG_VERSION="${{ needs.calculate_version.outputs.version }}"
PKG_VERSION="${PKG_VERSION//-/\~}"
echo "PKG_VERSION=${PKG_VERSION}" >> "${GITHUB_ENV}"
- name: Package DEB
if: matrix.debian != ''
run: |
set -xeuo pipefail
sudo gem install fpm
sudo apt-get install -y debsigs
cp .fpm_systemd .fpm
fpm -t deb \
--name "${NAME}" \
-v "$PKG_VERSION" \
-p "dist/${NAME}_${{ needs.calculate_version.outputs.version }}_linux_${{ matrix.debian }}.deb" \
--architecture ${{ matrix.debian }} \
dist/sing-box=/usr/bin/sing-box
curl -Lo '/tmp/debsigs.diff' 'https://gitlab.com/debsigs/debsigs/-/commit/160138f5de1ec110376d3c807b60a37388bc7c90.diff'
sudo patch /usr/bin/debsigs < '/tmp/debsigs.diff'
rm -rf $HOME/.gnupg
gpg --pinentry-mode loopback --passphrase "${{ secrets.GPG_PASSPHRASE }}" --import <<EOF
${{ secrets.GPG_KEY }}
EOF
debsigs --sign=origin -k ${{ secrets.GPG_KEY_ID }} --gpgopts '--pinentry-mode loopback --passphrase "${{ secrets.GPG_PASSPHRASE }}"' dist/*.deb
- name: Package RPM
if: matrix.rpm != ''
run: |-
set -xeuo pipefail
sudo gem install fpm
cp .fpm_systemd .fpm
fpm -t rpm \
--name "${NAME}" \
-v "$PKG_VERSION" \
-p "dist/${NAME}_${{ needs.calculate_version.outputs.version }}_linux_${{ matrix.rpm }}.rpm" \
--architecture ${{ matrix.rpm }} \
dist/sing-box=/usr/bin/sing-box
cat > $HOME/.rpmmacros <<EOF
%_gpg_name ${{ secrets.GPG_KEY_ID }}
%_gpg_sign_cmd_extra_args --pinentry-mode loopback --passphrase ${{ secrets.GPG_PASSPHRASE }}
EOF
gpg --pinentry-mode loopback --passphrase "${{ secrets.GPG_PASSPHRASE }}" --import <<EOF
${{ secrets.GPG_KEY }}
EOF
rpmsign --addsign dist/*.rpm
- name: Cleanup
run: rm dist/sing-box
- name: Upload artifact
uses: actions/upload-artifact@v4
with:
name: binary-${{ matrix.os }}_${{ matrix.arch }}${{ matrix.goarm && format('v{0}', matrix.goarm) }}${{ matrix.legacy_go && '-legacy' || '' }}
path: "dist"
upload:
name: Upload builds
runs-on: ubuntu-latest
needs:
- calculate_version
- build
steps:
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Set tag
run: |-
git ls-remote --exit-code --tags origin v${{ needs.calculate_version.outputs.version }} || echo "PUBLISHED=false" >> "$GITHUB_ENV"
git tag v${{ needs.calculate_version.outputs.version }} -f
echo "VERSION=${{ needs.calculate_version.outputs.version }}" >> "$GITHUB_ENV"
- name: Download builds
uses: actions/download-artifact@v5
with:
path: dist
merge-multiple: true
- name: Publish packages
run: |-
ls dist | xargs -I {} curl -F "package=@dist/{}" https://${{ secrets.FURY_TOKEN }}@push.fury.io/sagernet/

View File

@ -1,59 +1,28 @@
version: "2"
run:
go: "1.24"
build-tags:
- with_gvisor
- with_quic
- with_dhcp
- with_wireguard
- with_utls
- with_acme
- with_clash_api
linters: linters:
default: none disable-all: true
enable: enable:
- govet
- ineffassign
- paralleltest
- staticcheck
settings:
staticcheck:
checks:
- all
- -S1000
- -S1008
- -S1017
- -ST1003
- -QF1001
- -QF1003
- -QF1008
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
paths:
- transport/simple-obfs
- third_party$
- builtin$
- examples$
formatters:
enable:
- gci
- gofumpt - gofumpt
settings: - govet
gci: - gci
sections: - staticcheck
- standard - paralleltest
- prefix(github.com/sagernet/) - ineffassign
- default
custom-order: true linters-settings:
exclusions: gci:
generated: lax custom-order: true
paths: sections:
- transport/simple-obfs - standard
- third_party$ - prefix(github.com/sagernet/)
- builtin$ - default
- examples$ staticcheck:
checks:
- all
- -SA1003
run:
go: "1.23"
issues:
exclude-dirs:
- transport/simple-obfs

View File

@ -6,18 +6,17 @@ builds:
- -v - -v
- -trimpath - -trimpath
ldflags: ldflags:
- -X github.com/sagernet/sing-box/constant.Version={{ .Version }} - -X github.com/sagernet/sing-box/constant.Version={{ .Version }} -s -w -buildid=
- -s
- -buildid=
tags: tags:
- with_gvisor - with_gvisor
- with_quic - with_quic
- with_dhcp - with_dhcp
- with_wireguard - with_wireguard
- with_ech
- with_utls - with_utls
- with_reality_server
- with_acme - with_acme
- with_clash_api - with_clash_api
- with_tailscale
env: env:
- CGO_ENABLED=0 - CGO_ENABLED=0
targets: targets:
@ -49,18 +48,12 @@ nfpms:
contents: contents:
- src: release/config/config.json - src: release/config/config.json
dst: /etc/sing-box/config.json dst: /etc/sing-box/config.json
type: "config|noreplace" type: config
- src: release/config/sing-box.service - src: release/config/sing-box.service
dst: /usr/lib/systemd/system/sing-box.service dst: /usr/lib/systemd/system/sing-box.service
- src: release/config/sing-box@.service - src: release/config/sing-box@.service
dst: /usr/lib/systemd/system/sing-box@.service dst: /usr/lib/systemd/system/sing-box@.service
- src: release/config/sing-box.sysusers
dst: /usr/lib/sysusers.d/sing-box.conf
- src: release/config/sing-box.rules
dst: /usr/share/polkit-1/rules.d/sing-box.rules
- src: release/config/sing-box-split-dns.xml
dst: /usr/share/dbus-1/system.d/sing-box-split-dns.conf
- src: release/completions/sing-box.bash - src: release/completions/sing-box.bash
dst: /usr/share/bash-completion/completions/sing-box.bash dst: /usr/share/bash-completion/completions/sing-box.bash

View File

@ -16,13 +16,13 @@ builds:
- with_quic - with_quic
- with_dhcp - with_dhcp
- with_wireguard - with_wireguard
- with_ech
- with_utls - with_utls
- with_reality_server
- with_acme - with_acme
- with_clash_api - with_clash_api
- with_tailscale
env: env:
- CGO_ENABLED=0 - CGO_ENABLED=0
- GOTOOLCHAIN=local
targets: targets:
- linux_386 - linux_386
- linux_amd64_v1 - linux_amd64_v1
@ -46,21 +46,21 @@ builds:
- with_dhcp - with_dhcp
- with_wireguard - with_wireguard
- with_utls - with_utls
- with_reality_server
- with_acme - with_acme
- with_clash_api - with_clash_api
- with_tailscale
env: env:
- CGO_ENABLED=0 - CGO_ENABLED=0
- GOROOT={{ .Env.GOPATH }}/go_legacy - GOROOT={{ .Env.GOPATH }}/go1.20.14
tool: "{{ .Env.GOPATH }}/go_legacy/bin/go" gobinary: "{{ .Env.GOPATH }}/go1.20.14/bin/go"
targets: targets:
- windows_amd64_v1 - windows_amd64_v1
- windows_386 - windows_386
- darwin_amd64_v1
- id: android - id: android
<<: *template <<: *template
env: env:
- CGO_ENABLED=1 - CGO_ENABLED=1
- GOTOOLCHAIN=local
overrides: overrides:
- goos: android - goos: android
goarch: arm goarch: arm
@ -95,12 +95,10 @@ archives:
builds: builds:
- main - main
- android - android
formats: format: tar.gz
- tar.gz
format_overrides: format_overrides:
- goos: windows - goos: windows
formats: format: zip
- zip
wrap_in_directory: true wrap_in_directory: true
files: files:
- LICENSE - LICENSE
@ -130,18 +128,12 @@ nfpms:
contents: contents:
- src: release/config/config.json - src: release/config/config.json
dst: /etc/sing-box/config.json dst: /etc/sing-box/config.json
type: "config|noreplace" type: config
- src: release/config/sing-box.service - src: release/config/sing-box.service
dst: /usr/lib/systemd/system/sing-box.service dst: /usr/lib/systemd/system/sing-box.service
- src: release/config/sing-box@.service - src: release/config/sing-box@.service
dst: /usr/lib/systemd/system/sing-box@.service dst: /usr/lib/systemd/system/sing-box@.service
- src: release/config/sing-box.sysusers
dst: /usr/lib/sysusers.d/sing-box.conf
- src: release/config/sing-box.rules
dst: /usr/share/polkit-1/rules.d/sing-box.rules
- src: release/config/sing-box-split-dns.xml
dst: /usr/share/dbus-1/system.d/sing-box-split-dns.conf
- src: release/completions/sing-box.bash - src: release/completions/sing-box.bash
dst: /usr/share/bash-completion/completions/sing-box.bash dst: /usr/share/bash-completion/completions/sing-box.bash
@ -208,6 +200,4 @@ release:
ids: ids:
- archive - archive
- package - package
skip_upload: true skip_upload: true
partial:
by: target

View File

@ -1,4 +1,4 @@
FROM --platform=$BUILDPLATFORM golang:1.25-alpine AS builder FROM --platform=$BUILDPLATFORM golang:1.23-alpine AS builder
LABEL maintainer="nekohasekai <contact-git@sekai.icu>" LABEL maintainer="nekohasekai <contact-git@sekai.icu>"
COPY . /go/src/github.com/sagernet/sing-box COPY . /go/src/github.com/sagernet/sing-box
WORKDIR /go/src/github.com/sagernet/sing-box WORKDIR /go/src/github.com/sagernet/sing-box
@ -13,9 +13,9 @@ RUN set -ex \
&& export COMMIT=$(git rev-parse --short HEAD) \ && export COMMIT=$(git rev-parse --short HEAD) \
&& export VERSION=$(go run ./cmd/internal/read_tag) \ && export VERSION=$(go run ./cmd/internal/read_tag) \
&& go build -v -trimpath -tags \ && go build -v -trimpath -tags \
"with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale" \ "with_gvisor,with_quic,with_dhcp,with_wireguard,with_ech,with_utls,with_reality_server,with_acme,with_clash_api" \
-o /go/bin/sing-box \ -o /go/bin/sing-box \
-ldflags "-X \"github.com/sagernet/sing-box/constant.Version=$VERSION\" -s -w -buildid= -checklinkname=0" \ -ldflags "-X \"github.com/sagernet/sing-box/constant.Version=$VERSION\" -s -w -buildid=" \
./cmd/sing-box ./cmd/sing-box
FROM --platform=$TARGETPLATFORM alpine AS dist FROM --platform=$TARGETPLATFORM alpine AS dist
LABEL maintainer="nekohasekai <contact-git@sekai.icu>" LABEL maintainer="nekohasekai <contact-git@sekai.icu>"

View File

@ -1,29 +1,34 @@
NAME = sing-box NAME = sing-box
COMMIT = $(shell git rev-parse --short HEAD) COMMIT = $(shell git rev-parse --short HEAD)
TAGS ?= with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale TAGS_GO120 = with_gvisor,with_dhcp,with_wireguard,with_reality_server,with_clash_api,with_quic,with_utls
TAGS_GO121 = with_ech
TAGS ?= $(TAGS_GO118),$(TAGS_GO120),$(TAGS_GO121)
TAGS_TEST ?= with_gvisor,with_quic,with_wireguard,with_grpc,with_ech,with_utls,with_reality_server
GOHOSTOS = $(shell go env GOHOSTOS) GOHOSTOS = $(shell go env GOHOSTOS)
GOHOSTARCH = $(shell go env GOHOSTARCH) GOHOSTARCH = $(shell go env GOHOSTARCH)
VERSION=$(shell CGO_ENABLED=0 GOOS=$(GOHOSTOS) GOARCH=$(GOHOSTARCH) go run github.com/sagernet/sing-box/cmd/internal/read_tag@latest) VERSION=$(shell CGO_ENABLED=0 GOOS=$(GOHOSTOS) GOARCH=$(GOHOSTARCH) go run ./cmd/internal/read_tag)
PARAMS = -v -trimpath -ldflags "-X 'github.com/sagernet/sing-box/constant.Version=$(VERSION)' -s -w -buildid= -checklinkname=0" PARAMS = -v -trimpath -ldflags "-X 'github.com/sagernet/sing-box/constant.Version=$(VERSION)' -s -w -buildid="
MAIN_PARAMS = $(PARAMS) -tags "$(TAGS)" MAIN_PARAMS = $(PARAMS) -tags $(TAGS)
MAIN = ./cmd/sing-box MAIN = ./cmd/sing-box
PREFIX ?= $(shell go env GOPATH) PREFIX ?= $(shell go env GOPATH)
.PHONY: test release docs build .PHONY: test release docs build
build: build:
export GOTOOLCHAIN=local && \
go build $(MAIN_PARAMS) $(MAIN) go build $(MAIN_PARAMS) $(MAIN)
ci_build_go120:
go build $(PARAMS) $(MAIN)
go build $(PARAMS) -tags "$(TAGS_GO120)" $(MAIN)
ci_build: ci_build:
export GOTOOLCHAIN=local && \ go build $(PARAMS) $(MAIN)
go build $(PARAMS) $(MAIN) && \
go build $(MAIN_PARAMS) $(MAIN) go build $(MAIN_PARAMS) $(MAIN)
generate_completions: generate_completions:
go run -v --tags "$(TAGS),generate,generate_completions" $(MAIN) go run -v --tags generate,generate_completions $(MAIN)
install: install:
go build -o $(PREFIX)/bin/$(NAME) $(MAIN_PARAMS) $(MAIN) go build -o $(PREFIX)/bin/$(NAME) $(MAIN_PARAMS) $(MAIN)
@ -45,7 +50,7 @@ lint:
GOOS=freebsd golangci-lint run ./... GOOS=freebsd golangci-lint run ./...
lint_install: lint_install:
go install -v github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@latest
proto: proto:
@go run ./cmd/internal/protogen @go run ./cmd/internal/protogen
@ -56,9 +61,6 @@ proto_install:
go install -v google.golang.org/protobuf/cmd/protoc-gen-go@latest go install -v google.golang.org/protobuf/cmd/protoc-gen-go@latest
go install -v google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest go install -v google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
update_certificates:
go run ./cmd/internal/update_certificates
release: release:
go run ./cmd/internal/build goreleaser release --clean --skip publish go run ./cmd/internal/build goreleaser release --clean --skip publish
mkdir dist/release mkdir dist/release
@ -69,7 +71,7 @@ release:
dist/*_amd64.pkg.tar.zst \ dist/*_amd64.pkg.tar.zst \
dist/*_arm64.pkg.tar.zst \ dist/*_arm64.pkg.tar.zst \
dist/release dist/release
ghr --replace --draft --prerelease -p 5 "v${VERSION}" dist/release ghr --replace --draft --prerelease -p 3 "v${VERSION}" dist/release
rm -r dist/release rm -r dist/release
release_repo: release_repo:
@ -88,7 +90,7 @@ upload_android:
mkdir -p dist/release_android mkdir -p dist/release_android
cp ../sing-box-for-android/app/build/outputs/apk/play/release/*.apk dist/release_android cp ../sing-box-for-android/app/build/outputs/apk/play/release/*.apk dist/release_android
cp ../sing-box-for-android/app/build/outputs/apk/other/release/*-universal.apk dist/release_android cp ../sing-box-for-android/app/build/outputs/apk/other/release/*-universal.apk dist/release_android
ghr --replace --draft --prerelease -p 5 "v${VERSION}" dist/release_android ghr --replace --draft --prerelease -p 3 "v${VERSION}" dist/release_android
rm -rf dist/release_android rm -rf dist/release_android
release_android: lib_android update_android_version build_android upload_android release_android: lib_android update_android_version build_android upload_android
@ -108,16 +110,6 @@ upload_ios_app_store:
cd ../sing-box-for-apple && \ cd ../sing-box-for-apple && \
xcodebuild -exportArchive -archivePath build/SFI.xcarchive -exportOptionsPlist SFI/Upload.plist -allowProvisioningUpdates xcodebuild -exportArchive -archivePath build/SFI.xcarchive -exportOptionsPlist SFI/Upload.plist -allowProvisioningUpdates
export_ios_ipa:
cd ../sing-box-for-apple && \
xcodebuild -exportArchive -archivePath build/SFI.xcarchive -exportOptionsPlist SFI/Export.plist -allowProvisioningUpdates -exportPath build/SFI && \
cp build/SFI/sing-box.ipa dist/SFI.ipa
upload_ios_ipa:
cd dist && \
cp SFI.ipa "SFI-${VERSION}.ipa" && \
ghr --replace --draft --prerelease "v${VERSION}" "SFI-${VERSION}.ipa"
release_ios: build_ios upload_ios_app_store release_ios: build_ios upload_ios_app_store
build_macos: build_macos:
@ -185,37 +177,15 @@ upload_tvos_app_store:
cd ../sing-box-for-apple && \ cd ../sing-box-for-apple && \
xcodebuild -exportArchive -archivePath "build/SFT.xcarchive" -exportOptionsPlist SFI/Upload.plist -allowProvisioningUpdates xcodebuild -exportArchive -archivePath "build/SFT.xcarchive" -exportOptionsPlist SFI/Upload.plist -allowProvisioningUpdates
export_tvos_ipa:
cd ../sing-box-for-apple && \
xcodebuild -exportArchive -archivePath "build/SFT.xcarchive" -exportOptionsPlist SFI/Export.plist -allowProvisioningUpdates -exportPath build/SFT && \
cp build/SFT/sing-box.ipa dist/SFT.ipa
upload_tvos_ipa:
cd dist && \
cp SFT.ipa "SFT-${VERSION}.ipa" && \
ghr --replace --draft --prerelease "v${VERSION}" "SFT-${VERSION}.ipa"
release_tvos: build_tvos upload_tvos_app_store release_tvos: build_tvos upload_tvos_app_store
update_apple_version: update_apple_version:
go run ./cmd/internal/update_apple_version go run ./cmd/internal/update_apple_version
update_macos_version:
MACOS_PROJECT_VERSION=$(shell go run -v ./cmd/internal/app_store_connect next_macos_project_version) go run ./cmd/internal/update_apple_version
release_apple: lib_ios update_apple_version release_ios release_macos release_tvos release_macos_standalone release_apple: lib_ios update_apple_version release_ios release_macos release_tvos release_macos_standalone
release_apple_beta: update_apple_version release_ios release_macos release_tvos release_apple_beta: update_apple_version release_ios release_macos release_tvos
publish_testflight:
go run -v ./cmd/internal/app_store_connect publish_testflight
prepare_app_store:
go run -v ./cmd/internal/app_store_connect prepare_app_store
publish_app_store:
go run -v ./cmd/internal/app_store_connect publish_app_store
test: test:
@go test -v ./... && \ @go test -v ./... && \
cd test && \ cd test && \
@ -231,22 +201,16 @@ test_stdio:
lib_android: lib_android:
go run ./cmd/internal/build_libbox -target android go run ./cmd/internal/build_libbox -target android
lib_android_debug:
go run ./cmd/internal/build_libbox -target android -debug
lib_apple:
go run ./cmd/internal/build_libbox -target apple
lib_ios: lib_ios:
go run ./cmd/internal/build_libbox -target apple -platform ios -debug go run ./cmd/internal/build_libbox -target ios
lib: lib:
go run ./cmd/internal/build_libbox -target android go run ./cmd/internal/build_libbox -target android
go run ./cmd/internal/build_libbox -target ios go run ./cmd/internal/build_libbox -target ios
lib_install: lib_install:
go install -v github.com/sagernet/gomobile/cmd/gomobile@v0.1.8 go install -v github.com/sagernet/gomobile/cmd/gomobile@v0.1.4
go install -v github.com/sagernet/gomobile/cmd/gobind@v0.1.8 go install -v github.com/sagernet/gomobile/cmd/gobind@v0.1.4
docs: docs:
venv/bin/mkdocs serve venv/bin/mkdocs serve
@ -265,4 +229,4 @@ clean:
update: update:
git fetch git fetch
git reset FETCH_HEAD --hard git reset FETCH_HEAD --hard
git clean -fdx git clean -fdx

View File

@ -1,21 +0,0 @@
package adapter
import (
"context"
"crypto/x509"
"github.com/sagernet/sing/service"
)
type CertificateStore interface {
LifecycleService
Pool() *x509.CertPool
}
func RootPoolFromContext(ctx context.Context) *x509.CertPool {
store := service.FromContext[CertificateStore](ctx)
if store == nil {
return nil
}
return store.Pool()
}

View File

@ -1,94 +0,0 @@
package adapter
import (
"context"
"net/netip"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
"github.com/miekg/dns"
)
type DNSRouter interface {
Lifecycle
Exchange(ctx context.Context, message *dns.Msg, options DNSQueryOptions) (*dns.Msg, error)
Lookup(ctx context.Context, domain string, options DNSQueryOptions) ([]netip.Addr, error)
ClearCache()
LookupReverseMapping(ip netip.Addr) (string, bool)
ResetNetwork()
}
type DNSClient interface {
Start()
Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error)
Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error)
LookupCache(domain string, strategy C.DomainStrategy) ([]netip.Addr, bool)
ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool)
ClearCache()
}
type DNSQueryOptions struct {
Transport DNSTransport
Strategy C.DomainStrategy
LookupStrategy C.DomainStrategy
DisableCache bool
RewriteTTL *uint32
ClientSubnet netip.Prefix
}
func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptions) (*DNSQueryOptions, error) {
if options == nil {
return &DNSQueryOptions{}, nil
}
transportManager := service.FromContext[DNSTransportManager](ctx)
transport, loaded := transportManager.Transport(options.Server)
if !loaded {
return nil, E.New("domain resolver not found: " + options.Server)
}
return &DNSQueryOptions{
Transport: transport,
Strategy: C.DomainStrategy(options.Strategy),
DisableCache: options.DisableCache,
RewriteTTL: options.RewriteTTL,
ClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
}, nil
}
type RDRCStore interface {
LoadRDRC(transportName string, qName string, qType uint16) (rejected bool)
SaveRDRC(transportName string, qName string, qType uint16) error
SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger)
}
type DNSTransport interface {
Lifecycle
Type() string
Tag() string
Dependencies() []string
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
}
type LegacyDNSTransport interface {
LegacyStrategy() C.DomainStrategy
LegacyClientSubnet() netip.Prefix
}
type DNSTransportRegistry interface {
option.DNSTransportOptionsRegistry
CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error)
}
type DNSTransportManager interface {
Lifecycle
Transports() []DNSTransport
Transport(tag string) (DNSTransport, bool)
Default() DNSTransport
FakeIP() FakeIPTransport
Remove(tag string) error
Create(ctx context.Context, logger log.ContextLogger, tag string, outboundType string, options any) error
}

View File

@ -6,6 +6,8 @@ import (
"encoding/binary" "encoding/binary"
"time" "time"
"github.com/sagernet/sing-box/common/urltest"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/varbin" "github.com/sagernet/sing/common/varbin"
) )
@ -14,20 +16,7 @@ type ClashServer interface {
ConnectionTracker ConnectionTracker
Mode() string Mode() string
ModeList() []string ModeList() []string
HistoryStorage() URLTestHistoryStorage HistoryStorage() *urltest.HistoryStorage
}
type URLTestHistory struct {
Time time.Time `json:"time"`
Delay uint16 `json:"delay"`
}
type URLTestHistoryStorage interface {
SetHook(hook chan<- struct{})
LoadURLTestHistory(tag string) *URLTestHistory
DeleteURLTestHistory(tag string)
StoreURLTestHistory(tag string, history *URLTestHistory)
Close() error
} }
type V2RayServer interface { type V2RayServer interface {
@ -42,7 +31,7 @@ type CacheFile interface {
FakeIPStorage FakeIPStorage
StoreRDRC() bool StoreRDRC() bool
RDRCStore dns.RDRCStore
LoadMode() string LoadMode() string
StoreMode(mode string) error StoreMode(mode string) error
@ -50,17 +39,17 @@ type CacheFile interface {
StoreSelected(group string, selected string) error StoreSelected(group string, selected string) error
LoadGroupExpand(group string) (isExpand bool, loaded bool) LoadGroupExpand(group string) (isExpand bool, loaded bool)
StoreGroupExpand(group string, expand bool) error StoreGroupExpand(group string, expand bool) error
LoadRuleSet(tag string) *SavedBinary LoadRuleSet(tag string) *SavedRuleSet
SaveRuleSet(tag string, set *SavedBinary) error SaveRuleSet(tag string, set *SavedRuleSet) error
} }
type SavedBinary struct { type SavedRuleSet struct {
Content []byte Content []byte
LastUpdated time.Time LastUpdated time.Time
LastEtag string LastEtag string
} }
func (s *SavedBinary) MarshalBinary() ([]byte, error) { func (s *SavedRuleSet) MarshalBinary() ([]byte, error) {
var buffer bytes.Buffer var buffer bytes.Buffer
err := binary.Write(&buffer, binary.BigEndian, uint8(1)) err := binary.Write(&buffer, binary.BigEndian, uint8(1))
if err != nil { if err != nil {
@ -81,7 +70,7 @@ func (s *SavedBinary) MarshalBinary() ([]byte, error) {
return buffer.Bytes(), nil return buffer.Bytes(), nil
} }
func (s *SavedBinary) UnmarshalBinary(data []byte) error { func (s *SavedRuleSet) UnmarshalBinary(data []byte) error {
reader := bytes.NewReader(data) reader := bytes.NewReader(data)
var version uint8 var version uint8
err := binary.Read(reader, binary.BigEndian, &version) err := binary.Read(reader, binary.BigEndian, &version)

View File

@ -3,11 +3,12 @@ package adapter
import ( import (
"net/netip" "net/netip"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
) )
type FakeIPStore interface { type FakeIPStore interface {
SimpleLifecycle Service
Contains(address netip.Addr) bool Contains(address netip.Addr) bool
Create(domain string, isIPv6 bool) (netip.Addr, error) Create(domain string, isIPv6 bool) (netip.Addr, error)
Lookup(address netip.Addr) (string, bool) Lookup(address netip.Addr) (string, bool)
@ -26,6 +27,6 @@ type FakeIPStorage interface {
} }
type FakeIPTransport interface { type FakeIPTransport interface {
DNSTransport dns.Transport
Store() FakeIPStore Store() FakeIPStore
} }

View File

@ -57,8 +57,6 @@ type InboundContext struct {
Domain string Domain string
Client string Client string
SniffContext any SniffContext any
SnifferNames []string
SniffError error
// cache // cache
@ -73,15 +71,14 @@ type InboundContext struct {
UDPDisableDomainUnmapping bool UDPDisableDomainUnmapping bool
UDPConnect bool UDPConnect bool
UDPTimeout time.Duration UDPTimeout time.Duration
TLSFragment bool
TLSFragmentFallbackDelay time.Duration
TLSRecordFragment bool
NetworkStrategy *C.NetworkStrategy NetworkStrategy C.NetworkStrategy
NetworkType []C.InterfaceType NetworkType []C.InterfaceType
FallbackNetworkType []C.InterfaceType FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration FallbackDelay time.Duration
DNSServer string
DestinationAddresses []netip.Addr DestinationAddresses []netip.Addr
SourceGeoIPCode string SourceGeoIPCode string
GeoIPCode string GeoIPCode string
@ -136,7 +133,8 @@ func ExtendContext(ctx context.Context) (context.Context, *InboundContext) {
func OverrideContext(ctx context.Context) context.Context { func OverrideContext(ctx context.Context) context.Context {
if metadata := ContextFrom(ctx); metadata != nil { if metadata := ContextFrom(ctx); metadata != nil {
newMetadata := *metadata var newMetadata InboundContext
newMetadata = *metadata
return WithContext(ctx, &newMetadata) return WithContext(ctx, &newMetadata)
} }
return ctx return ctx

View File

@ -37,14 +37,13 @@ func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry, endp
func (m *Manager) Start(stage adapter.StartStage) error { func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock() m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage { if m.started && m.stage >= stage {
panic("already started") panic("already started")
} }
m.started = true m.started = true
m.stage = stage m.stage = stage
inbounds := m.inbounds for _, inbound := range m.inbounds {
m.access.Unlock()
for _, inbound := range inbounds {
err := adapter.LegacyStart(inbound, stage) err := adapter.LegacyStart(inbound, stage)
if err != nil { if err != nil {
return E.Cause(err, stage, " inbound/", inbound.Type(), "[", inbound.Tag(), "]") return E.Cause(err, stage, " inbound/", inbound.Type(), "[", inbound.Tag(), "]")

View File

@ -2,11 +2,6 @@ package adapter
import E "github.com/sagernet/sing/common/exceptions" import E "github.com/sagernet/sing/common/exceptions"
type SimpleLifecycle interface {
Start() error
Close() error
}
type StartStage uint8 type StartStage uint8
const ( const (

View File

@ -28,14 +28,14 @@ func LegacyStart(starter any, stage StartStage) error {
} }
type lifecycleServiceWrapper struct { type lifecycleServiceWrapper struct {
SimpleLifecycle Service
name string name string
} }
func NewLifecycleService(service SimpleLifecycle, name string) LifecycleService { func NewLifecycleService(service Service, name string) LifecycleService {
return &lifecycleServiceWrapper{ return &lifecycleServiceWrapper{
SimpleLifecycle: service, Service: service,
name: name, name: name,
} }
} }
@ -44,9 +44,9 @@ func (l *lifecycleServiceWrapper) Name() string {
} }
func (l *lifecycleServiceWrapper) Start(stage StartStage) error { func (l *lifecycleServiceWrapper) Start(stage StartStage) error {
return LegacyStart(l.SimpleLifecycle, stage) return LegacyStart(l.Service, stage)
} }
func (l *lifecycleServiceWrapper) Close() error { func (l *lifecycleServiceWrapper) Close() error {
return l.SimpleLifecycle.Close() return l.Service.Close()
} }

View File

@ -20,24 +20,20 @@ type NetworkManager interface {
DefaultOptions() NetworkOptions DefaultOptions() NetworkOptions
RegisterAutoRedirectOutputMark(mark uint32) error RegisterAutoRedirectOutputMark(mark uint32) error
AutoRedirectOutputMark() uint32 AutoRedirectOutputMark() uint32
AutoRedirectOutputMarkFunc() control.Func
NetworkMonitor() tun.NetworkUpdateMonitor NetworkMonitor() tun.NetworkUpdateMonitor
InterfaceMonitor() tun.DefaultInterfaceMonitor InterfaceMonitor() tun.DefaultInterfaceMonitor
PackageManager() tun.PackageManager PackageManager() tun.PackageManager
WIFIState() WIFIState WIFIState() WIFIState
ResetNetwork() ResetNetwork()
UpdateWIFIState()
} }
type NetworkOptions struct { type NetworkOptions struct {
BindInterface string NetworkStrategy C.NetworkStrategy
RoutingMark uint32 NetworkType []C.InterfaceType
DomainResolver string FallbackNetworkType []C.InterfaceType
DomainResolveOptions DNSQueryOptions FallbackDelay time.Duration
NetworkStrategy *C.NetworkStrategy BindInterface string
NetworkType []C.InterfaceType RoutingMark uint32
FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration
} }
type InterfaceUpdateListener interface { type InterfaceUpdateListener interface {

View File

@ -2,12 +2,9 @@ package adapter
import ( import (
"context" "context"
"net/netip"
"time"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
@ -21,17 +18,6 @@ type Outbound interface {
N.Dialer N.Dialer
} }
type OutboundWithPreferredRoutes interface {
Outbound
PreferredDomain(domain string) bool
PreferredAddress(address netip.Addr) bool
}
type DirectRouteOutbound interface {
Outbound
NewDirectRouteConnection(metadata InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error)
}
type OutboundRegistry interface { type OutboundRegistry interface {
option.OutboundOptionsRegistry option.OutboundOptionsRegistry
CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error) CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error)

View File

@ -23,14 +23,14 @@ type Manager struct {
registry adapter.OutboundRegistry registry adapter.OutboundRegistry
endpoint adapter.EndpointManager endpoint adapter.EndpointManager
defaultTag string defaultTag string
access sync.RWMutex access sync.Mutex
started bool started bool
stage adapter.StartStage stage adapter.StartStage
outbounds []adapter.Outbound outbounds []adapter.Outbound
outboundByTag map[string]adapter.Outbound outboundByTag map[string]adapter.Outbound
dependByTag map[string][]string dependByTag map[string][]string
defaultOutbound adapter.Outbound defaultOutbound adapter.Outbound
defaultOutboundFallback func() (adapter.Outbound, error) defaultOutboundFallback adapter.Outbound
} }
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager { func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
@ -44,7 +44,7 @@ func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry,
} }
} }
func (m *Manager) Initialize(defaultOutboundFallback func() (adapter.Outbound, error)) { func (m *Manager) Initialize(defaultOutboundFallback adapter.Outbound) {
m.defaultOutboundFallback = defaultOutboundFallback m.defaultOutboundFallback = defaultOutboundFallback
} }
@ -55,31 +55,18 @@ func (m *Manager) Start(stage adapter.StartStage) error {
} }
m.started = true m.started = true
m.stage = stage m.stage = stage
outbounds := m.outbounds
m.access.Unlock()
if stage == adapter.StartStateStart { if stage == adapter.StartStateStart {
if m.defaultTag != "" && m.defaultOutbound == nil { if m.defaultTag != "" && m.defaultOutbound == nil {
defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag) defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
if !loaded { if !loaded {
m.access.Unlock()
return E.New("default outbound not found: ", m.defaultTag) return E.New("default outbound not found: ", m.defaultTag)
} }
m.defaultOutbound = defaultEndpoint m.defaultOutbound = defaultEndpoint
} }
if m.defaultOutbound == nil {
directOutbound, err := m.defaultOutboundFallback()
if err != nil {
m.access.Unlock()
return E.Cause(err, "create direct outbound for fallback")
}
m.outbounds = append(m.outbounds, directOutbound)
m.outboundByTag[directOutbound.Tag()] = directOutbound
m.defaultOutbound = directOutbound
}
outbounds := m.outbounds
m.access.Unlock()
return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...)) return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
} else { } else {
outbounds := m.outbounds
m.access.Unlock()
for _, outbound := range outbounds { for _, outbound := range outbounds {
err := adapter.LegacyStart(outbound, stage) err := adapter.LegacyStart(outbound, stage)
if err != nil { if err != nil {
@ -182,15 +169,15 @@ func (m *Manager) Close() error {
} }
func (m *Manager) Outbounds() []adapter.Outbound { func (m *Manager) Outbounds() []adapter.Outbound {
m.access.RLock() m.access.Lock()
defer m.access.RUnlock() defer m.access.Unlock()
return m.outbounds return m.outbounds
} }
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) { func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
m.access.RLock() m.access.Lock()
outbound, found := m.outboundByTag[tag] outbound, found := m.outboundByTag[tag]
m.access.RUnlock() m.access.Unlock()
if found { if found {
return outbound, true return outbound, true
} }
@ -198,16 +185,20 @@ func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
} }
func (m *Manager) Default() adapter.Outbound { func (m *Manager) Default() adapter.Outbound {
m.access.RLock() m.access.Lock()
defer m.access.RUnlock() defer m.access.Unlock()
return m.defaultOutbound if m.defaultOutbound != nil {
return m.defaultOutbound
} else {
return m.defaultOutboundFallback
}
} }
func (m *Manager) Remove(tag string) error { func (m *Manager) Remove(tag string) error {
m.access.Lock() m.access.Lock()
defer m.access.Unlock()
outbound, found := m.outboundByTag[tag] outbound, found := m.outboundByTag[tag]
if !found { if !found {
m.access.Unlock()
return os.ErrInvalid return os.ErrInvalid
} }
delete(m.outboundByTag, tag) delete(m.outboundByTag, tag)
@ -241,6 +232,7 @@ func (m *Manager) Remove(tag string) error {
}) })
} }
} }
m.access.Unlock()
if started { if started {
return common.Close(outbound) return common.Close(outbound)
} }
@ -255,6 +247,8 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.
if err != nil { if err != nil {
return err return err
} }
m.access.Lock()
defer m.access.Unlock()
if m.started { if m.started {
for _, stage := range adapter.ListStartStages { for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(outbound, stage) err = adapter.LegacyStart(outbound, stage)
@ -263,8 +257,6 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.
} }
} }
} }
m.access.Lock()
defer m.access.Unlock()
if existsOutbound, loaded := m.outboundByTag[tag]; loaded { if existsOutbound, loaded := m.outboundByTag[tag]; loaded {
if m.started { if m.started {
err = common.Close(existsOutbound) err = common.Close(existsOutbound)

View File

@ -2,31 +2,44 @@ package adapter
import ( import (
"context" "context"
"crypto/tls"
"net" "net"
"net/http" "net/http"
"net/netip"
"sync" "sync"
"time"
"github.com/sagernet/sing-box/common/geoip"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-dns"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
mdns "github.com/miekg/dns"
"go4.org/netipx" "go4.org/netipx"
) )
type Router interface { type Router interface {
Lifecycle Lifecycle
FakeIPStore() FakeIPStore
ConnectionRouter ConnectionRouter
PreMatch(metadata InboundContext, context tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) PreMatch(metadata InboundContext) error
ConnectionRouterEx ConnectionRouterEx
GeoIPReader() *geoip.Reader
LoadGeosite(code string) (Rule, error)
RuleSet(tag string) (RuleSet, bool) RuleSet(tag string) (RuleSet, bool)
NeedWIFIState() bool NeedWIFIState() bool
Exchange(ctx context.Context, message *mdns.Msg) (*mdns.Msg, error)
Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error)
LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error)
ClearDNSCache()
Rules() []Rule Rules() []Rule
AppendTracker(tracker ConnectionTracker)
SetTracker(tracker ConnectionTracker)
ResetNetwork() ResetNetwork()
} }
@ -70,14 +83,12 @@ type RuleSetMetadata struct {
ContainsIPCIDRRule bool ContainsIPCIDRRule bool
} }
type HTTPStartContext struct { type HTTPStartContext struct {
ctx context.Context
access sync.Mutex access sync.Mutex
httpClientCache map[string]*http.Client httpClientCache map[string]*http.Client
} }
func NewHTTPStartContext(ctx context.Context) *HTTPStartContext { func NewHTTPStartContext() *HTTPStartContext {
return &HTTPStartContext{ return &HTTPStartContext{
ctx: ctx,
httpClientCache: make(map[string]*http.Client), httpClientCache: make(map[string]*http.Client),
} }
} }
@ -95,10 +106,6 @@ func (c *HTTPStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Clie
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
}, },
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(c.ctx),
RootCAs: RootPoolFromContext(c.ctx),
},
}, },
} }
c.httpClientCache[detour] = httpClient c.httpClientCache[detour] = httpClient

View File

@ -11,8 +11,9 @@ type HeadlessRule interface {
type Rule interface { type Rule interface {
HeadlessRule HeadlessRule
SimpleLifecycle Service
Type() string Type() string
UpdateGeosite() error
Action() RuleAction Action() RuleAction
} }

View File

@ -1,27 +1,6 @@
package adapter package adapter
import (
"context"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
type Service interface { type Service interface {
Lifecycle Start() error
Type() string Close() error
Tag() string
}
type ServiceRegistry interface {
option.ServiceOptionsRegistry
Create(ctx context.Context, logger log.ContextLogger, tag string, serviceType string, options any) (Service, error)
}
type ServiceManager interface {
Lifecycle
Services() []Service
Get(tag string) (Service, bool)
Remove(tag string) error
Create(ctx context.Context, logger log.ContextLogger, tag string, serviceType string, options any) error
} }

View File

@ -1,21 +0,0 @@
package service
type Adapter struct {
serviceType string
serviceTag string
}
func NewAdapter(serviceType string, serviceTag string) Adapter {
return Adapter{
serviceType: serviceType,
serviceTag: serviceTag,
}
}
func (a *Adapter) Type() string {
return a.serviceType
}
func (a *Adapter) Tag() string {
return a.serviceTag
}

View File

@ -1,144 +0,0 @@
package service
import (
"context"
"os"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
var _ adapter.ServiceManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.ServiceRegistry
access sync.Mutex
started bool
stage adapter.StartStage
services []adapter.Service
serviceByTag map[string]adapter.Service
}
func NewManager(logger log.ContextLogger, registry adapter.ServiceRegistry) *Manager {
return &Manager{
logger: logger,
registry: registry,
serviceByTag: make(map[string]adapter.Service),
}
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
services := m.services
m.access.Unlock()
for _, service := range services {
err := adapter.LegacyStart(service, stage)
if err != nil {
return E.Cause(err, stage, " service/", service.Type(), "[", service.Tag(), "]")
}
}
return nil
}
func (m *Manager) Close() error {
m.access.Lock()
defer m.access.Unlock()
if !m.started {
return nil
}
m.started = false
services := m.services
m.services = nil
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, service := range services {
monitor.Start("close service/", service.Type(), "[", service.Tag(), "]")
err = E.Append(err, service.Close(), func(err error) error {
return E.Cause(err, "close service/", service.Type(), "[", service.Tag(), "]")
})
monitor.Finish()
}
return nil
}
func (m *Manager) Services() []adapter.Service {
m.access.Lock()
defer m.access.Unlock()
return m.services
}
func (m *Manager) Get(tag string) (adapter.Service, bool) {
m.access.Lock()
service, found := m.serviceByTag[tag]
m.access.Unlock()
return service, found
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
service, found := m.serviceByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.serviceByTag, tag)
index := common.Index(m.services, func(it adapter.Service) bool {
return it == service
})
if index == -1 {
panic("invalid service index")
}
m.services = append(m.services[:index], m.services[index+1:]...)
started := m.started
m.access.Unlock()
if started {
return service.Close()
}
return nil
}
func (m *Manager) Create(ctx context.Context, logger log.ContextLogger, tag string, serviceType string, options any) error {
service, err := m.registry.Create(ctx, logger, tag, serviceType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(service, stage)
if err != nil {
return E.Cause(err, stage, " service/", service.Type(), "[", service.Tag(), "]")
}
}
}
if existsService, loaded := m.serviceByTag[tag]; loaded {
if m.started {
err = existsService.Close()
if err != nil {
return E.Cause(err, "close service/", existsService.Type(), "[", existsService.Tag(), "]")
}
}
existsIndex := common.Index(m.services, func(it adapter.Service) bool {
return it == existsService
})
if existsIndex == -1 {
panic("invalid service index")
}
m.services = append(m.services[:existsIndex], m.services[existsIndex+1:]...)
}
m.services = append(m.services, service)
m.serviceByTag[tag] = service
return nil
}

View File

@ -1,72 +0,0 @@
package service
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type ConstructorFunc[T any] func(ctx context.Context, logger log.ContextLogger, tag string, options T) (adapter.Service, error)
func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
registry.register(outboundType, func() any {
return new(Options)
}, func(ctx context.Context, logger log.ContextLogger, tag string, rawOptions any) (adapter.Service, error) {
var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, logger, tag, common.PtrValueOrDefault(options))
})
}
var _ adapter.ServiceRegistry = (*Registry)(nil)
type (
optionsConstructorFunc func() any
constructorFunc func(ctx context.Context, logger log.ContextLogger, tag string, options any) (adapter.Service, error)
)
type Registry struct {
access sync.Mutex
optionsType map[string]optionsConstructorFunc
constructor map[string]constructorFunc
}
func NewRegistry() *Registry {
return &Registry{
optionsType: make(map[string]optionsConstructorFunc),
constructor: make(map[string]constructorFunc),
}
}
func (m *Registry) CreateOptions(outboundType string) (any, bool) {
m.access.Lock()
defer m.access.Unlock()
optionsConstructor, loaded := m.optionsType[outboundType]
if !loaded {
return nil, false
}
return optionsConstructor(), true
}
func (m *Registry) Create(ctx context.Context, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Service, error) {
m.access.Lock()
defer m.access.Unlock()
constructor, loaded := m.constructor[outboundType]
if !loaded {
return nil, E.New("outbound type not found: " + outboundType)
}
return constructor(ctx, logger, tag, options)
}
func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
m.access.Lock()
defer m.access.Unlock()
m.optionsType[outboundType] = optionsConstructor
m.constructor[outboundType] = constructor
}

View File

@ -1,18 +0,0 @@
package adapter
import (
"net"
N "github.com/sagernet/sing/common/network"
)
type ManagedSSMServer interface {
Inbound
SetTracker(tracker SSMTracker)
UpdateUsers(users []string, uPSKs []string) error
}
type SSMTracker interface {
TrackConnection(conn net.Conn, metadata InboundContext) net.Conn
TrackPacketConnection(conn N.PacketConn, metadata InboundContext) N.PacketConn
}

View File

@ -3,6 +3,6 @@ package adapter
import "time" import "time"
type TimeService interface { type TimeService interface {
SimpleLifecycle Service
TimeFunc() func() time.Time TimeFunc() func() time.Time
} }

View File

@ -78,8 +78,8 @@ func (w *myUpstreamHandlerWrapper) NewError(ctx context.Context, err error) {
// Deprecated: removed // Deprecated: removed
func UpstreamMetadata(metadata InboundContext) M.Metadata { func UpstreamMetadata(metadata InboundContext) M.Metadata {
return M.Metadata{ return M.Metadata{
Source: metadata.Source.Unwrap(), Source: metadata.Source,
Destination: metadata.Destination.Unwrap(), Destination: metadata.Destination,
} }
} }

224
box.go
View File

@ -12,14 +12,9 @@ import (
"github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/certificate"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/taskmonitor" "github.com/sagernet/sing-box/common/taskmonitor"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport/local"
"github.com/sagernet/sing-box/experimental" "github.com/sagernet/sing-box/experimental"
"github.com/sagernet/sing-box/experimental/cachefile" "github.com/sagernet/sing-box/experimental/cachefile"
"github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/experimental/libbox/platform"
@ -35,23 +30,20 @@ import (
"github.com/sagernet/sing/service/pause" "github.com/sagernet/sing/service/pause"
) )
var _ adapter.SimpleLifecycle = (*Box)(nil) var _ adapter.Service = (*Box)(nil)
type Box struct { type Box struct {
createdAt time.Time createdAt time.Time
logFactory log.Factory logFactory log.Factory
logger log.ContextLogger logger log.ContextLogger
network *route.NetworkManager network *route.NetworkManager
endpoint *endpoint.Manager endpoint *endpoint.Manager
inbound *inbound.Manager inbound *inbound.Manager
outbound *outbound.Manager outbound *outbound.Manager
service *boxService.Manager connection *route.ConnectionManager
dnsTransport *dns.TransportManager router *route.Router
dnsRouter *dns.Router services []adapter.LifecycleService
connection *route.ConnectionManager done chan struct{}
router *route.Router
internalService []adapter.LifecycleService
done chan struct{}
} }
type Options struct { type Options struct {
@ -65,8 +57,6 @@ func Context(
inboundRegistry adapter.InboundRegistry, inboundRegistry adapter.InboundRegistry,
outboundRegistry adapter.OutboundRegistry, outboundRegistry adapter.OutboundRegistry,
endpointRegistry adapter.EndpointRegistry, endpointRegistry adapter.EndpointRegistry,
dnsTransportRegistry adapter.DNSTransportRegistry,
serviceRegistry adapter.ServiceRegistry,
) context.Context { ) context.Context {
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.InboundRegistry](ctx) == nil { service.FromContext[adapter.InboundRegistry](ctx) == nil {
@ -83,14 +73,6 @@ func Context(
ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry) ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry)
ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry) ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry)
} }
if service.FromContext[adapter.DNSTransportRegistry](ctx) == nil {
ctx = service.ContextWith[option.DNSTransportOptionsRegistry](ctx, dnsTransportRegistry)
ctx = service.ContextWith[adapter.DNSTransportRegistry](ctx, dnsTransportRegistry)
}
if service.FromContext[adapter.ServiceRegistry](ctx) == nil {
ctx = service.ContextWith[option.ServiceOptionsRegistry](ctx, serviceRegistry)
ctx = service.ContextWith[adapter.ServiceRegistry](ctx, serviceRegistry)
}
return ctx return ctx
} }
@ -105,8 +87,6 @@ func New(options Options) (*Box, error) {
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx) endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx)
serviceRegistry := service.FromContext[adapter.ServiceRegistry](ctx)
if endpointRegistry == nil { if endpointRegistry == nil {
return nil, E.New("missing endpoint registry in context") return nil, E.New("missing endpoint registry in context")
@ -117,12 +97,6 @@ func New(options Options) (*Box, error) {
if outboundRegistry == nil { if outboundRegistry == nil {
return nil, E.New("missing outbound registry in context") return nil, E.New("missing outbound registry in context")
} }
if dnsTransportRegistry == nil {
return nil, E.New("missing DNS transport registry in context")
}
if serviceRegistry == nil {
return nil, E.New("missing service registry in context")
}
ctx = pause.WithDefaultManager(ctx) ctx = pause.WithDefaultManager(ctx)
experimentalOptions := common.PtrValueOrDefault(options.Experimental) experimentalOptions := common.PtrValueOrDefault(options.Experimental)
@ -156,34 +130,14 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "create log factory") return nil, E.Cause(err, "create log factory")
} }
var internalServices []adapter.LifecycleService
certificateOptions := common.PtrValueOrDefault(options.Certificate)
if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem ||
len(certificateOptions.Certificate) > 0 ||
len(certificateOptions.CertificatePath) > 0 ||
len(certificateOptions.CertificateDirectoryPath) > 0 {
certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), certificateOptions)
if err != nil {
return nil, err
}
service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
internalServices = append(internalServices, certificateStore)
}
routeOptions := common.PtrValueOrDefault(options.Route) routeOptions := common.PtrValueOrDefault(options.Route)
dnsOptions := common.PtrValueOrDefault(options.DNS)
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry) endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager) inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final) outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final)
serviceManager := boxService.NewManager(logFactory.NewLogger("service"), serviceRegistry)
service.MustRegister[adapter.EndpointManager](ctx, endpointManager) service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
service.MustRegister[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager)
service.MustRegister[adapter.OutboundManager](ctx, outboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager)
service.MustRegister[adapter.ServiceManager](ctx, serviceManager)
dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions)
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions) networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions)
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize network manager") return nil, E.Cause(err, "initialize network manager")
@ -191,40 +145,10 @@ func New(options Options) (*Box, error) {
service.MustRegister[adapter.NetworkManager](ctx, networkManager) service.MustRegister[adapter.NetworkManager](ctx, networkManager)
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection")) connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager) service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions) router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS))
service.MustRegister[adapter.Router](ctx, router)
err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet)
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize router") return nil, E.Cause(err, "initialize router")
} }
ntpOptions := common.PtrValueOrDefault(options.NTP)
var timeService *tls.TimeServiceWrapper
if ntpOptions.Enabled {
timeService = new(tls.TimeServiceWrapper)
service.MustRegister[ntp.TimeService](ctx, timeService)
}
for i, transportOptions := range dnsOptions.Servers {
var tag string
if transportOptions.Tag != "" {
tag = transportOptions.Tag
} else {
tag = F.ToString(i)
}
err = dnsTransportManager.Create(
ctx,
logFactory.NewLogger(F.ToString("dns/", transportOptions.Type, "[", tag, "]")),
tag,
transportOptions.Type,
transportOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize DNS server[", i, "]")
}
}
err = dnsRouter.Initialize(dnsOptions.Rules)
if err != nil {
return nil, E.Cause(err, "initialize dns router")
}
for i, endpointOptions := range options.Endpoints { for i, endpointOptions := range options.Endpoints {
var tag string var tag string
if endpointOptions.Tag != "" { if endpointOptions.Tag != "" {
@ -232,15 +156,7 @@ func New(options Options) (*Box, error) {
} else { } else {
tag = F.ToString(i) tag = F.ToString(i)
} }
endpointCtx := ctx err = endpointManager.Create(ctx,
if tag != "" {
// TODO: remove this
endpointCtx = adapter.WithContext(endpointCtx, &adapter.InboundContext{
Outbound: tag,
})
}
err = endpointManager.Create(
endpointCtx,
router, router,
logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")), logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")),
tag, tag,
@ -248,7 +164,7 @@ func New(options Options) (*Box, error) {
endpointOptions.Options, endpointOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize endpoint[", i, "]") return nil, E.Cause(err, "initialize inbound[", i, "]")
} }
} }
for i, inboundOptions := range options.Inbounds { for i, inboundOptions := range options.Inbounds {
@ -258,8 +174,7 @@ func New(options Options) (*Box, error) {
} else { } else {
tag = F.ToString(i) tag = F.ToString(i)
} }
err = inboundManager.Create( err = inboundManager.Create(ctx,
ctx,
router, router,
logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")), logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")),
tag, tag,
@ -296,51 +211,26 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "initialize outbound[", i, "]") return nil, E.Cause(err, "initialize outbound[", i, "]")
} }
} }
for i, serviceOptions := range options.Services { outboundManager.Initialize(common.Must1(
var tag string direct.NewOutbound(
if serviceOptions.Tag != "" {
tag = serviceOptions.Tag
} else {
tag = F.ToString(i)
}
err = serviceManager.Create(
ctx,
logFactory.NewLogger(F.ToString("service/", serviceOptions.Type, "[", tag, "]")),
tag,
serviceOptions.Type,
serviceOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize service[", i, "]")
}
}
outboundManager.Initialize(func() (adapter.Outbound, error) {
return direct.NewOutbound(
ctx, ctx,
router, router,
logFactory.NewLogger("outbound/direct"), logFactory.NewLogger("outbound/direct"),
"direct", "direct",
option.DirectOutboundOptions{}, option.DirectOutboundOptions{},
) ),
}) ))
dnsTransportManager.Initialize(func() (adapter.DNSTransport, error) {
return local.NewTransport(
ctx,
logFactory.NewLogger("dns/local"),
"local",
option.LocalDNSServerOptions{},
)
})
if platformInterface != nil { if platformInterface != nil {
err = platformInterface.Initialize(networkManager) err = platformInterface.Initialize(networkManager)
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize platform interface") return nil, E.Cause(err, "initialize platform interface")
} }
} }
var services []adapter.LifecycleService
if needCacheFile { if needCacheFile {
cacheFile := cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile)) cacheFile := cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile))
service.MustRegister[adapter.CacheFile](ctx, cacheFile) service.MustRegister[adapter.CacheFile](ctx, cacheFile)
internalServices = append(internalServices, cacheFile) services = append(services, cacheFile)
} }
if needClashAPI { if needClashAPI {
clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI) clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI)
@ -349,9 +239,9 @@ func New(options Options) (*Box, error) {
if err != nil { if err != nil {
return nil, E.Cause(err, "create clash-server") return nil, E.Cause(err, "create clash-server")
} }
router.AppendTracker(clashServer) router.SetTracker(clashServer)
service.MustRegister[adapter.ClashServer](ctx, clashServer) service.MustRegister[adapter.ClashServer](ctx, clashServer)
internalServices = append(internalServices, clashServer) services = append(services, clashServer)
} }
if needV2RayAPI { if needV2RayAPI {
v2rayServer, err := experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(experimentalOptions.V2RayAPI)) v2rayServer, err := experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(experimentalOptions.V2RayAPI))
@ -359,17 +249,18 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "create v2ray-server") return nil, E.Cause(err, "create v2ray-server")
} }
if v2rayServer.StatsService() != nil { if v2rayServer.StatsService() != nil {
router.AppendTracker(v2rayServer.StatsService()) router.SetTracker(v2rayServer.StatsService())
internalServices = append(internalServices, v2rayServer) services = append(services, v2rayServer)
service.MustRegister[adapter.V2RayServer](ctx, v2rayServer) service.MustRegister[adapter.V2RayServer](ctx, v2rayServer)
} }
} }
ntpOptions := common.PtrValueOrDefault(options.NTP)
if ntpOptions.Enabled { if ntpOptions.Enabled {
ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions, ntpOptions.ServerIsDomain()) ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions)
if err != nil { if err != nil {
return nil, E.Cause(err, "create NTP service") return nil, E.Cause(err, "create NTP service")
} }
ntpService := ntp.NewService(ntp.Options{ timeService := ntp.NewService(ntp.Options{
Context: ctx, Context: ctx,
Dialer: ntpDialer, Dialer: ntpDialer,
Logger: logFactory.NewLogger("ntp"), Logger: logFactory.NewLogger("ntp"),
@ -377,24 +268,21 @@ func New(options Options) (*Box, error) {
Interval: time.Duration(ntpOptions.Interval), Interval: time.Duration(ntpOptions.Interval),
WriteToSystem: ntpOptions.WriteToSystem, WriteToSystem: ntpOptions.WriteToSystem,
}) })
timeService.TimeService = ntpService service.MustRegister[ntp.TimeService](ctx, timeService)
internalServices = append(internalServices, adapter.NewLifecycleService(ntpService, "ntp service")) services = append(services, adapter.NewLifecycleService(timeService, "ntp service"))
} }
return &Box{ return &Box{
network: networkManager, network: networkManager,
endpoint: endpointManager, endpoint: endpointManager,
inbound: inboundManager, inbound: inboundManager,
outbound: outboundManager, outbound: outboundManager,
dnsTransport: dnsTransportManager, connection: connectionManager,
service: serviceManager, router: router,
dnsRouter: dnsRouter, createdAt: createdAt,
connection: connectionManager, logFactory: logFactory,
router: router, logger: logFactory.Logger(),
createdAt: createdAt, services: services,
logFactory: logFactory, done: make(chan struct{}),
logger: logFactory.Logger(),
internalService: internalServices,
done: make(chan struct{}),
}, nil }, nil
} }
@ -444,15 +332,15 @@ func (s *Box) preStart() error {
if err != nil { if err != nil {
return E.Cause(err, "start logger") return E.Cause(err, "start logger")
} }
err = adapter.StartNamed(adapter.StartStateInitialize, s.internalService) // cache-file clash-api v2ray-api err = adapter.StartNamed(adapter.StartStateInitialize, s.services) // cache-file clash-api v2ray-api
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service) err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router) err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.connection, s.router)
if err != nil { if err != nil {
return err return err
} }
@ -464,27 +352,31 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.StartNamed(adapter.StartStateStart, s.internalService) err = adapter.StartNamed(adapter.StartStateStart, s.services)
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateStart, s.inbound, s.endpoint, s.service) err = s.inbound.Start(adapter.StartStateStart)
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint, s.service) err = adapter.Start(adapter.StartStateStart, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
err = adapter.StartNamed(adapter.StartStatePostStart, s.internalService) err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service) err = adapter.StartNamed(adapter.StartStatePostStart, s.services)
if err != nil { if err != nil {
return err return err
} }
err = adapter.StartNamed(adapter.StartStateStarted, s.internalService) err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
err = adapter.StartNamed(adapter.StartStateStarted, s.services)
if err != nil { if err != nil {
return err return err
} }
@ -499,9 +391,9 @@ func (s *Box) Close() error {
close(s.done) close(s.done)
} }
err := common.Close( err := common.Close(
s.service, s.endpoint, s.inbound, s.outbound, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network, s.inbound, s.outbound, s.router, s.connection, s.network,
) )
for _, lifecycleService := range s.internalService { for _, lifecycleService := range s.services {
err = E.Append(err, lifecycleService.Close(), func(err error) error { err = E.Append(err, lifecycleService.Close(), func(err error) error {
return E.Cause(err, "close ", lifecycleService.Name()) return E.Cause(err, "close ", lifecycleService.Name())
}) })

@ -1 +1 @@
Subproject commit 9d71ede01a1a51bb459847bb5d4444b2846c77de Subproject commit cff12c57dd3ace2d2776ad607fa0c6bb93873649

@ -1 +1 @@
Subproject commit c5734677bdfcba5d2d4faf10c4f10475077a82ab Subproject commit fa107e3b7ca03e1acc5d250aabf82da6ddff997c

View File

@ -1,450 +0,0 @@
package main
import (
"context"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/sagernet/asc-go/asc"
"github.com/sagernet/sing-box/cmd/internal/build_shared"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
func main() {
ctx := context.Background()
switch os.Args[1] {
case "next_macos_project_version":
err := fetchMacOSVersion(ctx)
if err != nil {
log.Fatal(err)
}
case "publish_testflight":
err := publishTestflight(ctx)
if err != nil {
log.Fatal(err)
}
case "cancel_app_store":
err := cancelAppStore(ctx, os.Args[2])
if err != nil {
log.Fatal(err)
}
case "prepare_app_store":
err := prepareAppStore(ctx)
if err != nil {
log.Fatal(err)
}
case "publish_app_store":
err := publishAppStore(ctx)
if err != nil {
log.Fatal(err)
}
default:
log.Fatal("unknown action: ", os.Args[1])
}
}
const (
appID = "6673731168"
groupID = "5c5f3b78-b7a0-40c0-bcad-e6ef87bbefda"
)
func createClient(expireDuration time.Duration) *asc.Client {
privateKey, err := os.ReadFile(os.Getenv("ASC_KEY_PATH"))
if err != nil {
log.Fatal(err)
}
tokenConfig, err := asc.NewTokenConfig(os.Getenv("ASC_KEY_ID"), os.Getenv("ASC_KEY_ISSUER_ID"), expireDuration, privateKey)
if err != nil {
log.Fatal(err)
}
return asc.NewClient(tokenConfig.Client())
}
func fetchMacOSVersion(ctx context.Context) error {
client := createClient(time.Minute)
versions, _, err := client.Apps.ListAppStoreVersionsForApp(ctx, appID, &asc.ListAppStoreVersionsQuery{
FilterPlatform: []string{"MAC_OS"},
})
if err != nil {
return err
}
var versionID string
findVersion:
for _, version := range versions.Data {
switch *version.Attributes.AppStoreState {
case asc.AppStoreVersionStateReadyForSale,
asc.AppStoreVersionStatePendingDeveloperRelease:
versionID = version.ID
break findVersion
}
}
if versionID == "" {
return E.New("no version found")
}
latestBuild, _, err := client.Builds.GetBuildForAppStoreVersion(ctx, versionID, &asc.GetBuildForAppStoreVersionQuery{})
if err != nil {
return err
}
versionInt, err := strconv.Atoi(*latestBuild.Data.Attributes.Version)
if err != nil {
return E.Cause(err, "parse version code")
}
os.Stdout.WriteString(F.ToString(versionInt+1, "\n"))
return nil
}
func publishTestflight(ctx context.Context) error {
tagVersion, err := build_shared.ReadTagVersion()
if err != nil {
return err
}
tag := tagVersion.VersionString()
client := createClient(20 * time.Minute)
log.Info(tag, " list build IDs")
buildIDsResponse, _, err := client.TestFlight.ListBuildIDsForBetaGroup(ctx, groupID, nil)
if err != nil {
return err
}
buildIDs := common.Map(buildIDsResponse.Data, func(it asc.RelationshipData) string {
return it.ID
})
var platforms []asc.Platform
if len(os.Args) == 3 {
switch os.Args[2] {
case "ios":
platforms = []asc.Platform{asc.PlatformIOS}
case "macos":
platforms = []asc.Platform{asc.PlatformMACOS}
case "tvos":
platforms = []asc.Platform{asc.PlatformTVOS}
default:
return E.New("unknown platform: ", os.Args[2])
}
} else {
platforms = []asc.Platform{
asc.PlatformIOS,
asc.PlatformMACOS,
asc.PlatformTVOS,
}
}
for _, platform := range platforms {
log.Info(string(platform), " list builds")
for {
builds, _, err := client.Builds.ListBuilds(ctx, &asc.ListBuildsQuery{
FilterApp: []string{appID},
FilterPreReleaseVersionPlatform: []string{string(platform)},
})
if err != nil {
return err
}
build := builds.Data[0]
if common.Contains(buildIDs, build.ID) || time.Since(build.Attributes.UploadedDate.Time) > 30*time.Minute {
log.Info(string(platform), " ", tag, " waiting for process")
time.Sleep(15 * time.Second)
continue
}
if *build.Attributes.ProcessingState != "VALID" {
log.Info(string(platform), " ", tag, " waiting for process: ", *build.Attributes.ProcessingState)
time.Sleep(15 * time.Second)
continue
}
log.Info(string(platform), " ", tag, " list localizations")
localizations, _, err := client.TestFlight.ListBetaBuildLocalizationsForBuild(ctx, build.ID, nil)
if err != nil {
return err
}
localization := common.Find(localizations.Data, func(it asc.BetaBuildLocalization) bool {
return *it.Attributes.Locale == "en-US"
})
if localization.ID == "" {
log.Fatal(string(platform), " ", tag, " no en-US localization found")
}
if localization.Attributes == nil || localization.Attributes.WhatsNew == nil || *localization.Attributes.WhatsNew == "" {
log.Info(string(platform), " ", tag, " update localization")
_, _, err = client.TestFlight.UpdateBetaBuildLocalization(ctx, localization.ID, common.Ptr(
F.ToString("sing-box ", tagVersion.String()),
))
if err != nil {
return err
}
}
log.Info(string(platform), " ", tag, " publish")
response, err := client.TestFlight.AddBuildsToBetaGroup(ctx, groupID, []string{build.ID})
if response != nil && (response.StatusCode == http.StatusUnprocessableEntity || response.StatusCode == http.StatusNotFound) {
log.Info("waiting for process")
time.Sleep(15 * time.Second)
continue
} else if err != nil {
return err
}
log.Info(string(platform), " ", tag, " list submissions")
betaSubmissions, _, err := client.TestFlight.ListBetaAppReviewSubmissions(ctx, &asc.ListBetaAppReviewSubmissionsQuery{
FilterBuild: []string{build.ID},
})
if err != nil {
return err
}
if len(betaSubmissions.Data) == 0 {
log.Info(string(platform), " ", tag, " create submission")
_, _, err = client.TestFlight.CreateBetaAppReviewSubmission(ctx, build.ID)
if err != nil {
if strings.Contains(err.Error(), "ANOTHER_BUILD_IN_REVIEW") {
log.Error(err)
break
}
return err
}
}
break
}
}
return nil
}
func cancelAppStore(ctx context.Context, platform string) error {
switch platform {
case "ios":
platform = string(asc.PlatformIOS)
case "macos":
platform = string(asc.PlatformMACOS)
case "tvos":
platform = string(asc.PlatformTVOS)
}
tag, err := build_shared.ReadTag()
if err != nil {
return err
}
client := createClient(time.Minute)
for {
log.Info(platform, " list versions")
versions, response, err := client.Apps.ListAppStoreVersionsForApp(ctx, appID, &asc.ListAppStoreVersionsQuery{
FilterPlatform: []string{string(platform)},
})
if isRetryable(response) {
continue
} else if err != nil {
return err
}
version := common.Find(versions.Data, func(it asc.AppStoreVersion) bool {
return *it.Attributes.VersionString == tag
})
if version.ID == "" {
return nil
}
log.Info(platform, " ", tag, " get submission")
submission, response, err := client.Submission.GetAppStoreVersionSubmissionForAppStoreVersion(ctx, version.ID, nil)
if response != nil && response.StatusCode == http.StatusNotFound {
return nil
}
if isRetryable(response) {
continue
} else if err != nil {
return err
}
log.Info(platform, " ", tag, " delete submission")
_, err = client.Submission.DeleteSubmission(ctx, submission.Data.ID)
if err != nil {
return err
}
return nil
}
}
func prepareAppStore(ctx context.Context) error {
tag, err := build_shared.ReadTag()
if err != nil {
return err
}
client := createClient(time.Minute)
for _, platform := range []asc.Platform{
asc.PlatformIOS,
asc.PlatformMACOS,
asc.PlatformTVOS,
} {
log.Info(string(platform), " list versions")
versions, _, err := client.Apps.ListAppStoreVersionsForApp(ctx, appID, &asc.ListAppStoreVersionsQuery{
FilterPlatform: []string{string(platform)},
})
if err != nil {
return err
}
version := common.Find(versions.Data, func(it asc.AppStoreVersion) bool {
return *it.Attributes.VersionString == tag
})
log.Info(string(platform), " ", tag, " list builds")
builds, _, err := client.Builds.ListBuilds(ctx, &asc.ListBuildsQuery{
FilterApp: []string{appID},
FilterPreReleaseVersionPlatform: []string{string(platform)},
})
if err != nil {
return err
}
if len(builds.Data) == 0 {
log.Fatal(platform, " ", tag, " no build found")
}
buildID := common.Ptr(builds.Data[0].ID)
if version.ID == "" {
log.Info(string(platform), " ", tag, " create version")
newVersion, _, err := client.Apps.CreateAppStoreVersion(ctx, asc.AppStoreVersionCreateRequestAttributes{
Platform: platform,
VersionString: tag,
}, appID, buildID)
if err != nil {
return err
}
version = newVersion.Data
} else {
log.Info(string(platform), " ", tag, " check build")
currentBuild, response, err := client.Apps.GetBuildIDForAppStoreVersion(ctx, version.ID)
if err != nil {
return err
}
if response.StatusCode != http.StatusOK || currentBuild.Data.ID != *buildID {
switch *version.Attributes.AppStoreState {
case asc.AppStoreVersionStatePrepareForSubmission,
asc.AppStoreVersionStateRejected,
asc.AppStoreVersionStateDeveloperRejected:
case asc.AppStoreVersionStateWaitingForReview,
asc.AppStoreVersionStateInReview,
asc.AppStoreVersionStatePendingDeveloperRelease:
submission, _, err := client.Submission.GetAppStoreVersionSubmissionForAppStoreVersion(ctx, version.ID, nil)
if err != nil {
return err
}
if submission != nil {
log.Info(string(platform), " ", tag, " delete submission")
_, err = client.Submission.DeleteSubmission(ctx, submission.Data.ID)
if err != nil {
return err
}
time.Sleep(5 * time.Second)
}
default:
log.Fatal(string(platform), " ", tag, " unknown state ", string(*version.Attributes.AppStoreState))
}
log.Info(string(platform), " ", tag, " update build")
response, err = client.Apps.UpdateBuildForAppStoreVersion(ctx, version.ID, buildID)
if err != nil {
return err
}
if response.StatusCode != http.StatusNoContent {
response.Write(os.Stderr)
log.Fatal(string(platform), " ", tag, " unexpected response: ", response.Status)
}
} else {
switch *version.Attributes.AppStoreState {
case asc.AppStoreVersionStatePrepareForSubmission,
asc.AppStoreVersionStateRejected,
asc.AppStoreVersionStateDeveloperRejected:
case asc.AppStoreVersionStateWaitingForReview,
asc.AppStoreVersionStateInReview,
asc.AppStoreVersionStatePendingDeveloperRelease:
continue
default:
log.Fatal(string(platform), " ", tag, " unknown state ", string(*version.Attributes.AppStoreState))
}
}
}
log.Info(string(platform), " ", tag, " list localization")
localizations, _, err := client.Apps.ListLocalizationsForAppStoreVersion(ctx, version.ID, nil)
if err != nil {
return err
}
localization := common.Find(localizations.Data, func(it asc.AppStoreVersionLocalization) bool {
return *it.Attributes.Locale == "en-US"
})
if localization.ID == "" {
log.Info(string(platform), " ", tag, " no en-US localization found")
}
if localization.Attributes == nil || localization.Attributes.WhatsNew == nil || *localization.Attributes.WhatsNew == "" {
log.Info(string(platform), " ", tag, " update localization")
_, _, err = client.Apps.UpdateAppStoreVersionLocalization(ctx, localization.ID, &asc.AppStoreVersionLocalizationUpdateRequestAttributes{
PromotionalText: common.Ptr("Yet another distribution for sing-box, the universal proxy platform."),
WhatsNew: common.Ptr(F.ToString("sing-box ", tag, ": Fixes and improvements.")),
})
if err != nil {
return err
}
}
log.Info(string(platform), " ", tag, " create submission")
fixSubmit:
for {
_, response, err := client.Submission.CreateSubmission(ctx, version.ID)
if err != nil {
switch response.StatusCode {
case http.StatusInternalServerError:
continue
default:
return err
}
}
switch response.StatusCode {
case http.StatusCreated:
break fixSubmit
default:
return err
}
}
}
return nil
}
func publishAppStore(ctx context.Context) error {
tag, err := build_shared.ReadTag()
if err != nil {
return err
}
client := createClient(time.Minute)
for _, platform := range []asc.Platform{
asc.PlatformIOS,
asc.PlatformMACOS,
asc.PlatformTVOS,
} {
log.Info(string(platform), " list versions")
versions, _, err := client.Apps.ListAppStoreVersionsForApp(ctx, appID, &asc.ListAppStoreVersionsQuery{
FilterPlatform: []string{string(platform)},
})
if err != nil {
return err
}
version := common.Find(versions.Data, func(it asc.AppStoreVersion) bool {
return *it.Attributes.VersionString == tag
})
switch *version.Attributes.AppStoreState {
case asc.AppStoreVersionStatePrepareForSubmission, asc.AppStoreVersionStateDeveloperRejected:
log.Fatal(string(platform), " ", tag, " not submitted")
case asc.AppStoreVersionStateWaitingForReview,
asc.AppStoreVersionStateInReview:
log.Warn(string(platform), " ", tag, " waiting for review")
continue
case asc.AppStoreVersionStatePendingDeveloperRelease:
default:
log.Fatal(string(platform), " ", tag, " unknown state ", string(*version.Attributes.AppStoreState))
}
_, _, err = client.Publishing.CreatePhasedRelease(ctx, common.Ptr(asc.PhasedReleaseStateComplete), version.ID)
if err != nil {
return err
}
}
return nil
}
func isRetryable(response *asc.Response) bool {
if response == nil {
return false
}
switch response.StatusCode {
case http.StatusInternalServerError, http.StatusUnprocessableEntity:
return true
default:
return false
}
}

View File

@ -10,23 +10,17 @@ import (
_ "github.com/sagernet/gomobile" _ "github.com/sagernet/gomobile"
"github.com/sagernet/sing-box/cmd/internal/build_shared" "github.com/sagernet/sing-box/cmd/internal/build_shared"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/shell"
) )
var ( var (
debugEnabled bool debugEnabled bool
target string target string
platform string
withTailscale bool
) )
func init() { func init() {
flag.BoolVar(&debugEnabled, "debug", false, "enable debug") flag.BoolVar(&debugEnabled, "debug", false, "enable debug")
flag.StringVar(&target, "target", "android", "target platform") flag.StringVar(&target, "target", "android", "target platform")
flag.StringVar(&platform, "platform", "", "specify platform")
flag.BoolVar(&withTailscale, "with-tailscale", false, "build tailscale for iOS and tvOS")
} }
func main() { func main() {
@ -37,8 +31,8 @@ func main() {
switch target { switch target {
case "android": case "android":
buildAndroid() buildAndroid()
case "apple": case "ios":
buildApple() buildiOS()
} }
} }
@ -46,9 +40,7 @@ var (
sharedFlags []string sharedFlags []string
debugFlags []string debugFlags []string
sharedTags []string sharedTags []string
macOSTags []string iosTags []string
memcTags []string
notMemcTags []string
debugTags []string debugTags []string
) )
@ -59,73 +51,42 @@ func init() {
if err != nil { if err != nil {
currentTag = "unknown" currentTag = "unknown"
} }
sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -s -w -buildid= -checklinkname=0") sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -s -w -buildid=")
debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -checklinkname=0") debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag)
sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_utls", "with_clash_api", "with_conntrack") sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_ech", "with_utls", "with_clash_api")
macOSTags = append(macOSTags, "with_dhcp") iosTags = append(iosTags, "with_dhcp", "with_low_memory", "with_conntrack")
memcTags = append(memcTags, "with_tailscale")
notMemcTags = append(notMemcTags, "with_low_memory")
debugTags = append(debugTags, "debug") debugTags = append(debugTags, "debug")
} }
func buildAndroid() { func buildAndroid() {
build_shared.FindSDK() build_shared.FindSDK()
var javaPath string
javaHome := os.Getenv("JAVA_HOME")
if javaHome == "" {
javaPath = "java"
} else {
javaPath = filepath.Join(javaHome, "bin", "java")
}
javaVersion, err := shell.Exec(javaPath, "--version").ReadOutput()
if err != nil {
log.Fatal(E.Cause(err, "check java version"))
}
if !strings.Contains(javaVersion, "openjdk 17") {
log.Fatal("java version should be openjdk 17")
}
var bindTarget string
if platform != "" {
bindTarget = platform
} else if debugEnabled {
bindTarget = "android/arm64"
} else {
bindTarget = "android"
}
args := []string{ args := []string{
"bind", "bind",
"-v", "-v",
"-target", bindTarget,
"-androidapi", "21", "-androidapi", "21",
"-javapkg=io.nekohasekai", "-javapkg=io.nekohasekai",
"-libname=box", "-libname=box",
} }
if !debugEnabled { if !debugEnabled {
// sharedFlags[3] = sharedFlags[3] + " -checklinkname=0"
args = append(args, sharedFlags...) args = append(args, sharedFlags...)
} else { } else {
// debugFlags[1] = debugFlags[1] + " -checklinkname=0"
args = append(args, debugFlags...) args = append(args, debugFlags...)
} }
tags := append(sharedTags, memcTags...) args = append(args, "-tags")
if debugEnabled { if !debugEnabled {
tags = append(tags, debugTags...) args = append(args, strings.Join(sharedTags, ","))
} else {
args = append(args, strings.Join(append(sharedTags, debugTags...), ","))
} }
args = append(args, "-tags", strings.Join(tags, ","))
args = append(args, "./experimental/libbox") args = append(args, "./experimental/libbox")
command := exec.Command(build_shared.GoBinPath+"/gomobile", args...) command := exec.Command(build_shared.GoBinPath+"/gomobile", args...)
command.Stdout = os.Stdout command.Stdout = os.Stdout
command.Stderr = os.Stderr command.Stderr = os.Stderr
err = command.Run() err := command.Run()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -142,44 +103,26 @@ func buildAndroid() {
} }
} }
func buildApple() { func buildiOS() {
var bindTarget string
if platform != "" {
bindTarget = platform
} else if debugEnabled {
bindTarget = "ios"
} else {
bindTarget = "ios,tvos,macos"
}
args := []string{ args := []string{
"bind", "bind",
"-v", "-v",
"-target", bindTarget, "-target", "ios,iossimulator,tvos,tvossimulator,macos",
"-libname=box", "-libname=box",
"-tags-not-macos=with_low_memory",
} }
if !withTailscale {
args = append(args, "-tags-macos="+strings.Join(append(macOSTags, memcTags...), ","))
} else {
args = append(args, "-tags-macos="+strings.Join(macOSTags, ","))
}
if !debugEnabled { if !debugEnabled {
args = append(args, sharedFlags...) args = append(args, sharedFlags...)
} else { } else {
args = append(args, debugFlags...) args = append(args, debugFlags...)
} }
tags := sharedTags tags := append(sharedTags, iosTags...)
if withTailscale { args = append(args, "-tags")
tags = append(tags, memcTags...) if !debugEnabled {
args = append(args, strings.Join(tags, ","))
} else {
args = append(args, strings.Join(append(tags, debugTags...), ","))
} }
if debugEnabled {
tags = append(tags, debugTags...)
}
args = append(args, "-tags", strings.Join(tags, ","))
args = append(args, "./experimental/libbox") args = append(args, "./experimental/libbox")
command := exec.Command(build_shared.GoBinPath+"/gomobile", args...) command := exec.Command(build_shared.GoBinPath+"/gomobile", args...)

View File

@ -11,7 +11,9 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/shell"
) )
var ( var (
@ -40,6 +42,14 @@ func FindSDK() {
log.Fatal("android NDK not found") log.Fatal("android NDK not found")
} }
javaVersion, err := shell.Exec("java", "--version").ReadOutput()
if err != nil {
log.Fatal(E.Cause(err, "check java version"))
}
if !strings.Contains(javaVersion, "openjdk 17") {
log.Fatal("java version should be openjdk 17")
}
os.Setenv("ANDROID_HOME", androidSDKPath) os.Setenv("ANDROID_HOME", androidSDKPath)
os.Setenv("ANDROID_SDK_HOME", androidSDKPath) os.Setenv("ANDROID_SDK_HOME", androidSDKPath)
os.Setenv("ANDROID_NDK_HOME", androidNDKPath) os.Setenv("ANDROID_NDK_HOME", androidNDKPath)
@ -48,16 +58,12 @@ func FindSDK() {
} }
func findNDK() bool { func findNDK() bool {
const fixedVersion = "28.0.13004108" const fixedVersion = "26.2.11394342"
const versionFile = "source.properties" const versionFile = "source.properties"
if fixedPath := filepath.Join(androidSDKPath, "ndk", fixedVersion); rw.IsFile(filepath.Join(fixedPath, versionFile)) { if fixedPath := filepath.Join(androidSDKPath, "ndk", fixedVersion); rw.IsFile(filepath.Join(fixedPath, versionFile)) {
androidNDKPath = fixedPath androidNDKPath = fixedPath
return true return true
} }
if ndkHomeEnv := os.Getenv("ANDROID_NDK_HOME"); rw.IsFile(filepath.Join(ndkHomeEnv, versionFile)) {
androidNDKPath = ndkHomeEnv
return true
}
ndkVersions, err := os.ReadDir(filepath.Join(androidSDKPath, "ndk")) ndkVersions, err := os.ReadDir(filepath.Join(androidSDKPath, "ndk"))
if err != nil { if err != nil {
return false return false

View File

@ -20,11 +20,6 @@ func ReadTag() (string, error) {
return version.String() + "-" + shortCommit, nil return version.String() + "-" + shortCommit, nil
} }
func ReadTagVersionRev() (badversion.Version, error) {
currentTagRev := common.Must1(shell.Exec("git", "describe", "--tags", "--abbrev=0").ReadOutput())
return badversion.Parse(currentTagRev[1:]), nil
}
func ReadTagVersion() (badversion.Version, error) { func ReadTagVersion() (badversion.Version, error) {
currentTag := common.Must1(shell.Exec("git", "describe", "--tags").ReadOutput()) currentTag := common.Must1(shell.Exec("git", "describe", "--tags").ReadOutput())
currentTagRev := common.Must1(shell.Exec("git", "describe", "--tags", "--abbrev=0").ReadOutput()) currentTagRev := common.Must1(shell.Exec("git", "describe", "--tags", "--abbrev=0").ReadOutput())

View File

@ -1,71 +1,21 @@
package main package main
import ( import (
"flag"
"os" "os"
"github.com/sagernet/sing-box/cmd/internal/build_shared" "github.com/sagernet/sing-box/cmd/internal/build_shared"
"github.com/sagernet/sing-box/common/badversion"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
) )
var (
flagRunInCI bool
flagRunNightly bool
)
func init() {
flag.BoolVar(&flagRunInCI, "ci", false, "Run in CI")
flag.BoolVar(&flagRunNightly, "nightly", false, "Run nightly")
}
func main() { func main() {
flag.Parse() currentTag, err := build_shared.ReadTag()
var ( if err != nil {
versionStr string log.Error(err)
err error _, err = os.Stdout.WriteString("unknown\n")
)
if flagRunNightly {
var version badversion.Version
version, err = build_shared.ReadTagVersion()
if err == nil {
versionStr = version.String()
}
} else { } else {
versionStr, err = build_shared.ReadTag() _, err = os.Stdout.WriteString(currentTag + "\n")
} }
if flagRunInCI { if err != nil {
if err != nil { log.Error(err)
log.Fatal(err)
}
err = setGitHubEnv("version", versionStr)
if err != nil {
log.Fatal(err)
}
} else {
if err != nil {
log.Error(err)
os.Stdout.WriteString("unknown\n")
} else {
os.Stdout.WriteString(versionStr + "\n")
}
} }
} }
func setGitHubEnv(name string, value string) error {
outputFile, err := os.OpenFile(os.Getenv("GITHUB_ENV"), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
return err
}
_, err = outputFile.WriteString(name + "=" + value + "\n")
if err != nil {
outputFile.Close()
return err
}
err = outputFile.Close()
if err != nil {
return err
}
os.Stderr.WriteString(name + "=" + value + "\n")
return nil
}

View File

@ -1,284 +0,0 @@
package main
import (
"context"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"strings"
"syscall"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/include"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/shell"
)
var iperf3Path string
func main() {
err := main0()
if err != nil {
log.Fatal(err)
}
}
func main0() error {
err := shell.Exec("sudo", "ls").Run()
if err != nil {
return err
}
results, err := runTests()
if err != nil {
return err
}
encoder := json.NewEncoder(os.Stdout)
encoder.SetIndent("", " ")
return encoder.Encode(results)
}
func runTests() ([]TestResult, error) {
boxPaths := []string{
os.ExpandEnv("$HOME/Downloads/sing-box-1.11.15-darwin-arm64/sing-box"),
//"/Users/sekai/Downloads/sing-box-1.11.15-linux-arm64/sing-box",
"./sing-box",
}
stacks := []string{
"gvisor",
"system",
}
mtus := []int{
1500,
4064,
// 16384,
// 32768,
// 49152,
65535,
}
flagList := [][]string{
{},
}
var results []TestResult
for _, boxPath := range boxPaths {
for _, stack := range stacks {
for _, mtu := range mtus {
if strings.HasPrefix(boxPath, ".") {
for _, flags := range flagList {
result, err := testOnce(boxPath, stack, mtu, false, flags)
if err != nil {
return nil, err
}
results = append(results, *result)
}
} else {
result, err := testOnce(boxPath, stack, mtu, false, nil)
if err != nil {
return nil, err
}
results = append(results, *result)
}
}
}
}
return results, nil
}
type TestResult struct {
BoxPath string `json:"box_path"`
Stack string `json:"stack"`
MTU int `json:"mtu"`
Flags []string `json:"flags"`
MultiThread bool `json:"multi_thread"`
UploadSpeed string `json:"upload_speed"`
DownloadSpeed string `json:"download_speed"`
}
func testOnce(boxPath string, stackName string, mtu int, multiThread bool, flags []string) (result *TestResult, err error) {
testAddress := netip.MustParseAddr("1.1.1.1")
testConfig := option.Options{
Inbounds: []option.Inbound{
{
Type: C.TypeTun,
Options: &option.TunInboundOptions{
Address: []netip.Prefix{netip.MustParsePrefix("172.18.0.1/30")},
AutoRoute: true,
MTU: uint32(mtu),
Stack: stackName,
RouteAddress: []netip.Prefix{netip.PrefixFrom(testAddress, testAddress.BitLen())},
},
},
},
Route: &option.RouteOptions{
Rules: []option.Rule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultRule{
RawDefaultRule: option.RawDefaultRule{
IPCIDR: []string{testAddress.String()},
},
RuleAction: option.RuleAction{
Action: C.RuleActionTypeRouteOptions,
RouteOptionsOptions: option.RouteOptionsActionOptions{
OverrideAddress: "127.0.0.1",
},
},
},
},
},
AutoDetectInterface: true,
},
}
ctx := include.Context(context.Background())
tempConfig, err := os.CreateTemp("", "tun-bench-*.json")
if err != nil {
return
}
defer os.Remove(tempConfig.Name())
encoder := json.NewEncoderContext(ctx, tempConfig)
encoder.SetIndent("", " ")
err = encoder.Encode(testConfig)
if err != nil {
return nil, E.Cause(err, "encode test config")
}
tempConfig.Close()
var sudoArgs []string
if len(flags) > 0 {
sudoArgs = append(sudoArgs, "env")
sudoArgs = append(sudoArgs, flags...)
}
sudoArgs = append(sudoArgs, boxPath, "run", "-c", tempConfig.Name())
boxProcess := shell.Exec("sudo", sudoArgs...)
boxProcess.Stdout = &stderrWriter{}
boxProcess.Stderr = io.Discard
err = boxProcess.Start()
if err != nil {
return
}
if C.IsDarwin {
iperf3Path, err = exec.LookPath("iperf3-darwin")
} else {
iperf3Path, err = exec.LookPath("iperf3")
}
if err != nil {
return
}
serverProcess := shell.Exec(iperf3Path, "-s")
serverProcess.Stdout = io.Discard
serverProcess.Stderr = io.Discard
err = serverProcess.Start()
if err != nil {
return nil, E.Cause(err, "start iperf3 server")
}
time.Sleep(time.Second)
args := []string{"-c", testAddress.String()}
if multiThread {
args = append(args, "-P", "10")
}
uploadProcess := shell.Exec(iperf3Path, args...)
output, err := uploadProcess.Read()
if err != nil {
boxProcess.Process.Signal(syscall.SIGKILL)
serverProcess.Process.Signal(syscall.SIGKILL)
println(output)
return
}
uploadResult := common.SubstringBeforeLast(output, "iperf Done.")
uploadResult = common.SubstringBeforeLast(uploadResult, "sender")
uploadResult = common.SubstringBeforeLast(uploadResult, "bits/sec")
uploadResult = common.SubstringAfterLast(uploadResult, "Bytes")
uploadResult = strings.ReplaceAll(uploadResult, " ", "")
result = &TestResult{
BoxPath: boxPath,
Stack: stackName,
MTU: mtu,
Flags: flags,
MultiThread: multiThread,
UploadSpeed: uploadResult,
}
downloadProcess := shell.Exec(iperf3Path, append(args, "-R")...)
output, err = downloadProcess.Read()
if err != nil {
boxProcess.Process.Signal(syscall.SIGKILL)
serverProcess.Process.Signal(syscall.SIGKILL)
println(output)
return
}
downloadResult := common.SubstringBeforeLast(output, "iperf Done.")
downloadResult = common.SubstringBeforeLast(downloadResult, "receiver")
downloadResult = common.SubstringBeforeLast(downloadResult, "bits/sec")
downloadResult = common.SubstringAfterLast(downloadResult, "Bytes")
downloadResult = strings.ReplaceAll(downloadResult, " ", "")
result.DownloadSpeed = downloadResult
printArgs := []any{boxPath, stackName, mtu, "upload", uploadResult, "download", downloadResult}
if len(flags) > 0 {
printArgs = append(printArgs, "flags", strings.Join(flags, " "))
}
if multiThread {
printArgs = append(printArgs, "(-P 10)")
}
fmt.Println(printArgs...)
err = boxProcess.Process.Signal(syscall.SIGTERM)
if err != nil {
return
}
err = serverProcess.Process.Signal(syscall.SIGTERM)
if err != nil {
return
}
boxDone := make(chan struct{})
go func() {
boxProcess.Cmd.Wait()
close(boxDone)
}()
serverDone := make(chan struct{})
go func() {
serverProcess.Process.Wait()
close(serverDone)
}()
select {
case <-boxDone:
case <-time.After(2 * time.Second):
boxProcess.Process.Kill()
case <-time.After(4 * time.Second):
println("box process did not close!")
os.Exit(1)
}
select {
case <-serverDone:
case <-time.After(2 * time.Second):
serverProcess.Process.Kill()
case <-time.After(4 * time.Second):
println("server process did not close!")
os.Exit(1)
}
return
}
type stderrWriter struct{}
func (w *stderrWriter) Write(p []byte) (n int, err error) {
return os.Stderr.Write(p)
}

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"flag"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -13,26 +12,9 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
) )
var (
flagRunInCI bool
flagRunNightly bool
)
func init() {
flag.BoolVar(&flagRunInCI, "ci", false, "Run in CI")
flag.BoolVar(&flagRunNightly, "nightly", false, "Run nightly")
}
func main() { func main() {
flag.Parse() newVersion := common.Must1(build_shared.ReadTagVersion())
newVersion := common.Must1(build_shared.ReadTag()) androidPath, err := filepath.Abs("../sing-box-for-android")
var androidPath string
if flagRunInCI {
androidPath = "clients/android"
} else {
androidPath = "../sing-box-for-android"
}
androidPath, err := filepath.Abs(androidPath)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -49,24 +31,22 @@ func main() {
for _, propPair := range propsList { for _, propPair := range propsList {
switch propPair[0] { switch propPair[0] {
case "VERSION_NAME": case "VERSION_NAME":
if propPair[1] != newVersion { if propPair[1] != newVersion.String() {
log.Info("updated version from ", propPair[1], " to ", newVersion)
versionUpdated = true versionUpdated = true
propPair[1] = newVersion propPair[1] = newVersion.String()
log.Info("updated version to ", newVersion.String())
} }
case "GO_VERSION": case "GO_VERSION":
if propPair[1] != runtime.Version() { if propPair[1] != runtime.Version() {
log.Info("updated Go version from ", propPair[1], " to ", runtime.Version())
goVersionUpdated = true goVersionUpdated = true
propPair[1] = runtime.Version() propPair[1] = runtime.Version()
log.Info("updated Go version to ", runtime.Version())
} }
} }
} }
if !(versionUpdated || goVersionUpdated) { if !(versionUpdated || goVersionUpdated) {
log.Info("version not changed") log.Info("version not changed")
return return
} else if flagRunInCI && !flagRunNightly {
log.Fatal("version changed, commit changes first.")
} }
for _, propPair := range propsList { for _, propPair := range propsList {
switch propPair[0] { switch propPair[0] {

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"flag"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -14,22 +13,9 @@ import (
"howett.net/plist" "howett.net/plist"
) )
var flagRunInCI bool
func init() {
flag.BoolVar(&flagRunInCI, "ci", false, "Run in CI")
}
func main() { func main() {
flag.Parse()
newVersion := common.Must1(build_shared.ReadTagVersion()) newVersion := common.Must1(build_shared.ReadTagVersion())
var applePath string applePath, err := filepath.Abs("../sing-box-for-apple")
if flagRunInCI {
applePath = "clients/apple"
} else {
applePath = "../sing-box-for-apple"
}
applePath, err := filepath.Abs(applePath)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -1,71 +0,0 @@
package main
import (
"encoding/csv"
"io"
"net/http"
"os"
"strings"
"github.com/sagernet/sing-box/log"
"golang.org/x/exp/slices"
)
func main() {
err := updateMozillaIncludedRootCAs()
if err != nil {
log.Error(err)
}
}
func updateMozillaIncludedRootCAs() error {
response, err := http.Get("https://ccadb.my.salesforce-sites.com/mozilla/IncludedCACertificateReportPEMCSV")
if err != nil {
return err
}
defer response.Body.Close()
reader := csv.NewReader(response.Body)
header, err := reader.Read()
if err != nil {
return err
}
geoIndex := slices.Index(header, "Geographic Focus")
nameIndex := slices.Index(header, "Common Name or Certificate Name")
certIndex := slices.Index(header, "PEM Info")
generated := strings.Builder{}
generated.WriteString(`// Code generated by 'make update_certificates'. DO NOT EDIT.
package certificate
import "crypto/x509"
var mozillaIncluded *x509.CertPool
func init() {
mozillaIncluded = x509.NewCertPool()
`)
for {
record, err := reader.Read()
if err == io.EOF {
break
} else if err != nil {
return err
}
if record[geoIndex] == "China" {
continue
}
generated.WriteString("\n // ")
generated.WriteString(record[nameIndex])
generated.WriteString("\n")
generated.WriteString(" mozillaIncluded.AppendCertsFromPEM([]byte(`")
cert := record[certIndex]
// Remove single quotes
cert = cert[1 : len(cert)-1]
generated.WriteString(cert)
generated.WriteString("`))\n")
}
generated.WriteString("}\n")
return os.WriteFile("common/certificate/mozilla.go", []byte(generated.String()), 0o644)
}

View File

@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/sagernet/sing-box"
"github.com/sagernet/sing-box/experimental/deprecated" "github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/include" "github.com/sagernet/sing-box/include"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
@ -67,5 +68,6 @@ func preRun(cmd *cobra.Command, args []string) {
if len(configPaths) == 0 && len(configDirectories) == 0 { if len(configPaths) == 0 && len(configDirectories) == 0 {
configPaths = append(configPaths, "config.json") configPaths = append(configPaths, "config.json")
} }
globalCtx = include.Context(service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))) globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
} }

View File

@ -9,6 +9,8 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var pqSignatureSchemesEnabled bool
var commandGenerateECHKeyPair = &cobra.Command{ var commandGenerateECHKeyPair = &cobra.Command{
Use: "ech-keypair <plain_server_name>", Use: "ech-keypair <plain_server_name>",
Short: "Generate TLS ECH key pair", Short: "Generate TLS ECH key pair",
@ -22,11 +24,12 @@ var commandGenerateECHKeyPair = &cobra.Command{
} }
func init() { func init() {
commandGenerateECHKeyPair.Flags().BoolVar(&pqSignatureSchemesEnabled, "pq-signature-schemes-enabled", false, "Enable PQ signature schemes")
commandGenerate.AddCommand(commandGenerateECHKeyPair) commandGenerate.AddCommand(commandGenerateECHKeyPair)
} }
func generateECHKeyPair(serverName string) error { func generateECHKeyPair(serverName string) error {
configPem, keyPem, err := tls.ECHKeygenDefault(serverName) configPem, keyPem, err := tls.ECHKeygenDefault(serverName, pqSignatureSchemesEnabled)
if err != nil { if err != nil {
return err return err
} }

View File

@ -30,7 +30,7 @@ func init() {
} }
func generateTLSKeyPair(serverName string) error { func generateTLSKeyPair(serverName string) error {
privateKeyPem, publicKeyPem, err := tls.GenerateCertificate(nil, nil, time.Now, serverName, time.Now().AddDate(0, flagGenerateTLSKeyPairMonths, 0)) privateKeyPem, publicKeyPem, err := tls.GenerateKeyPair(time.Now, serverName, time.Now().AddDate(0, flagGenerateTLSKeyPairMonths, 0))
if err != nil { if err != nil {
return err return err
} }

View File

@ -18,7 +18,7 @@ import (
) )
var commandMerge = &cobra.Command{ var commandMerge = &cobra.Command{
Use: "merge <output-path>", Use: "merge <output>",
Short: "Merge configurations", Short: "Merge configurations",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
err := merge(args[0]) err := merge(args[0])

View File

@ -6,10 +6,8 @@ import (
"strings" "strings"
"github.com/sagernet/sing-box/common/srs" "github.com/sagernet/sing-box/common/srs"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/json"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -71,7 +69,7 @@ func compileRuleSet(sourcePath string) error {
if err != nil { if err != nil {
return err return err
} }
err = srs.Write(outputFile, plainRuleSet.Options, downgradeRuleSetVersion(plainRuleSet.Version, plainRuleSet.Options)) err = srs.Write(outputFile, plainRuleSet.Options, plainRuleSet.Version)
if err != nil { if err != nil {
outputFile.Close() outputFile.Close()
os.Remove(outputPath) os.Remove(outputPath)
@ -80,18 +78,3 @@ func compileRuleSet(sourcePath string) error {
outputFile.Close() outputFile.Close()
return nil return nil
} }
func downgradeRuleSetVersion(version uint8, options option.PlainRuleSet) uint8 {
if version == C.RuleSetVersion4 && !rule.HasHeadlessRule(options.Rules, func(rule option.DefaultHeadlessRule) bool {
return rule.NetworkInterfaceAddress != nil && rule.NetworkInterfaceAddress.Size() > 0 ||
len(rule.DefaultInterfaceAddress) > 0
}) {
version = C.RuleSetVersion3
}
if version == C.RuleSetVersion3 && !rule.HasHeadlessRule(options.Rules, func(rule option.DefaultHeadlessRule) bool {
return len(rule.NetworkType) > 0 || rule.NetworkIsExpensive || rule.NetworkIsConstrained
}) {
version = C.RuleSetVersion2
}
return version
}

View File

@ -5,7 +5,7 @@ import (
"os" "os"
"strings" "strings"
"github.com/sagernet/sing-box/common/convertor/adguard" "github.com/sagernet/sing-box/cmd/sing-box/internal/convertor/adguard"
"github.com/sagernet/sing-box/common/srs" "github.com/sagernet/sing-box/common/srs"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
@ -54,7 +54,7 @@ func convertRuleSet(sourcePath string) error {
var rules []option.HeadlessRule var rules []option.HeadlessRule
switch flagRuleSetConvertType { switch flagRuleSetConvertType {
case "adguard": case "adguard":
rules, err = adguard.ToOptions(reader, log.StdLogger()) rules, err = adguard.Convert(reader)
case "": case "":
return E.New("source type is required") return E.New("source type is required")
default: default:

View File

@ -6,10 +6,7 @@ import (
"strings" "strings"
"github.com/sagernet/sing-box/common/srs" "github.com/sagernet/sing-box/common/srs"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/json"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -53,11 +50,6 @@ func decompileRuleSet(sourcePath string) error {
if err != nil { if err != nil {
return err return err
} }
if hasRule(ruleSet.Options.Rules, func(rule option.DefaultHeadlessRule) bool {
return len(rule.AdGuardDomain) > 0
}) {
return E.New("unable to decompile binary AdGuard rules to rule-set.")
}
var outputPath string var outputPath string
if flagRuleSetDecompileOutput == flagRuleSetDecompileDefaultOutput { if flagRuleSetDecompileOutput == flagRuleSetDecompileDefaultOutput {
if strings.HasSuffix(sourcePath, ".srs") { if strings.HasSuffix(sourcePath, ".srs") {
@ -83,19 +75,3 @@ func decompileRuleSet(sourcePath string) error {
outputFile.Close() outputFile.Close()
return nil return nil
} }
func hasRule(rules []option.HeadlessRule, cond func(rule option.DefaultHeadlessRule) bool) bool {
for _, rule := range rules {
switch rule.Type {
case C.RuleTypeDefault:
if cond(rule.DefaultOptions) {
return true
}
case C.RuleTypeLogical:
if hasRule(rule.LogicalOptions.Rules, cond) {
return true
}
}
}
return false
}

View File

@ -5,7 +5,6 @@ import (
"context" "context"
"io" "io"
"os" "os"
"path/filepath"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/srs" "github.com/sagernet/sing-box/common/srs"
@ -57,14 +56,6 @@ func ruleSetMatch(sourcePath string, domain string) error {
if err != nil { if err != nil {
return E.Cause(err, "read rule-set") return E.Cause(err, "read rule-set")
} }
if flagRuleSetMatchFormat == "" {
switch filepath.Ext(sourcePath) {
case ".json":
flagRuleSetMatchFormat = C.RuleSetFormatSource
case ".srs":
flagRuleSetMatchFormat = C.RuleSetFormatBinary
}
}
var ruleSet option.PlainRuleSetCompat var ruleSet option.PlainRuleSetCompat
switch flagRuleSetMatchFormat { switch flagRuleSetMatchFormat {
case C.RuleSetFormatSource: case C.RuleSetFormatSource:

View File

@ -1,162 +0,0 @@
package main
import (
"bytes"
"io"
"os"
"path/filepath"
"sort"
"strings"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
"github.com/sagernet/sing/common/rw"
"github.com/spf13/cobra"
)
var (
ruleSetPaths []string
ruleSetDirectories []string
)
var commandRuleSetMerge = &cobra.Command{
Use: "merge <output-path>",
Short: "Merge rule-set source files",
Run: func(cmd *cobra.Command, args []string) {
err := mergeRuleSet(args[0])
if err != nil {
log.Fatal(err)
}
},
Args: cobra.ExactArgs(1),
}
func init() {
commandRuleSetMerge.Flags().StringArrayVarP(&ruleSetPaths, "config", "c", nil, "set input rule-set file path")
commandRuleSetMerge.Flags().StringArrayVarP(&ruleSetDirectories, "config-directory", "C", nil, "set input rule-set directory path")
commandRuleSet.AddCommand(commandRuleSetMerge)
}
type RuleSetEntry struct {
content []byte
path string
options option.PlainRuleSetCompat
}
func readRuleSetAt(path string) (*RuleSetEntry, error) {
var (
configContent []byte
err error
)
if path == "stdin" {
configContent, err = io.ReadAll(os.Stdin)
} else {
configContent, err = os.ReadFile(path)
}
if err != nil {
return nil, E.Cause(err, "read config at ", path)
}
options, err := json.UnmarshalExtendedContext[option.PlainRuleSetCompat](globalCtx, configContent)
if err != nil {
return nil, E.Cause(err, "decode config at ", path)
}
return &RuleSetEntry{
content: configContent,
path: path,
options: options,
}, nil
}
func readRuleSet() ([]*RuleSetEntry, error) {
var optionsList []*RuleSetEntry
for _, path := range ruleSetPaths {
optionsEntry, err := readRuleSetAt(path)
if err != nil {
return nil, err
}
optionsList = append(optionsList, optionsEntry)
}
for _, directory := range ruleSetDirectories {
entries, err := os.ReadDir(directory)
if err != nil {
return nil, E.Cause(err, "read rule-set directory at ", directory)
}
for _, entry := range entries {
if !strings.HasSuffix(entry.Name(), ".json") || entry.IsDir() {
continue
}
optionsEntry, err := readRuleSetAt(filepath.Join(directory, entry.Name()))
if err != nil {
return nil, err
}
optionsList = append(optionsList, optionsEntry)
}
}
sort.Slice(optionsList, func(i, j int) bool {
return optionsList[i].path < optionsList[j].path
})
return optionsList, nil
}
func readRuleSetAndMerge() (option.PlainRuleSetCompat, error) {
optionsList, err := readRuleSet()
if err != nil {
return option.PlainRuleSetCompat{}, err
}
if len(optionsList) == 1 {
return optionsList[0].options, nil
}
var optionVersion uint8
for _, options := range optionsList {
if optionVersion < options.options.Version {
optionVersion = options.options.Version
}
}
var mergedMessage json.RawMessage
for _, options := range optionsList {
mergedMessage, err = badjson.MergeJSON(globalCtx, options.options.RawMessage, mergedMessage, false)
if err != nil {
return option.PlainRuleSetCompat{}, E.Cause(err, "merge config at ", options.path)
}
}
mergedOptions, err := json.UnmarshalExtendedContext[option.PlainRuleSetCompat](globalCtx, mergedMessage)
if err != nil {
return option.PlainRuleSetCompat{}, E.Cause(err, "unmarshal merged config")
}
mergedOptions.Version = optionVersion
return mergedOptions, nil
}
func mergeRuleSet(outputPath string) error {
mergedOptions, err := readRuleSetAndMerge()
if err != nil {
return err
}
buffer := new(bytes.Buffer)
encoder := json.NewEncoder(buffer)
encoder.SetIndent("", " ")
err = encoder.Encode(mergedOptions)
if err != nil {
return E.Cause(err, "encode config")
}
if existsContent, err := os.ReadFile(outputPath); err != nil {
if string(existsContent) == buffer.String() {
return nil
}
}
err = rw.MkdirParent(outputPath)
if err != nil {
return err
}
err = os.WriteFile(outputPath, buffer.Bytes(), 0o644)
if err != nil {
return err
}
outputPath, _ = filepath.Abs(outputPath)
os.Stderr.WriteString(outputPath + "\n")
return nil
}

View File

@ -61,15 +61,14 @@ func upgradeRuleSet(sourcePath string) error {
log.Info("already up-to-date") log.Info("already up-to-date")
return nil return nil
} }
plainRuleSetCompat.Options, err = plainRuleSetCompat.Upgrade() plainRuleSet, err := plainRuleSetCompat.Upgrade()
if err != nil { if err != nil {
return err return err
} }
plainRuleSetCompat.Version = C.RuleSetVersionCurrent
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
encoder := json.NewEncoder(buffer) encoder := json.NewEncoder(buffer)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
err = encoder.Encode(plainRuleSetCompat) err = encoder.Encode(plainRuleSet)
if err != nil { if err != nil {
return E.Cause(err, "encode config") return E.Cause(err, "encode config")
} }

View File

@ -30,7 +30,7 @@ func createPreStartedClient() (*box.Box, error) {
return nil, err return nil, err
} }
} }
instance, err := box.New(box.Options{Context: globalCtx, Options: options}) instance, err := box.New(box.Options{Options: options})
if err != nil { if err != nil {
return nil, E.Cause(err, "create service") return nil, E.Cause(err, "create service")
} }

View File

@ -21,7 +21,7 @@ func initializeHTTP3Client(instance *box.Box) error {
return err return err
} }
http3Client = &http.Client{ http3Client = &http.Client{
Transport: &http3.Transport{ Transport: &http3.RoundTripper{
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
destination := M.ParseSocksaddr(addr) destination := M.ParseSocksaddr(addr)
udpConn, dErr := dialer.DialContext(ctx, N.NetworkUDP, destination) udpConn, dErr := dialer.DialContext(ctx, N.NetworkUDP, destination)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"os" "os"
"github.com/sagernet/sing-box/common/settings"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@ -57,7 +58,7 @@ func syncTime() error {
return err return err
} }
if commandSyncTimeWrite { if commandSyncTimeWrite {
err = ntp.SetSystemTime(response.Time) err = settings.SetSystemTime(response.Time)
if err != nil { if err != nil {
return E.Cause(err, "write time to system") return E.Cause(err, "write time to system")
} }

View File

@ -2,7 +2,6 @@ package adguard
import ( import (
"bufio" "bufio"
"bytes"
"io" "io"
"net/netip" "net/netip"
"os" "os"
@ -10,10 +9,10 @@ import (
"strings" "strings"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
) )
@ -28,7 +27,7 @@ type agdguardRuleLine struct {
isImportant bool isImportant bool
} }
func ToOptions(reader io.Reader, logger logger.Logger) ([]option.HeadlessRule, error) { func Convert(reader io.Reader) ([]option.HeadlessRule, error) {
scanner := bufio.NewScanner(reader) scanner := bufio.NewScanner(reader)
var ( var (
ruleLines []agdguardRuleLine ruleLines []agdguardRuleLine
@ -37,10 +36,7 @@ func ToOptions(reader io.Reader, logger logger.Logger) ([]option.HeadlessRule, e
parseLine: parseLine:
for scanner.Scan() { for scanner.Scan() {
ruleLine := scanner.Text() ruleLine := scanner.Text()
if ruleLine == "" { if ruleLine == "" || ruleLine[0] == '!' || ruleLine[0] == '#' {
continue
}
if strings.HasPrefix(ruleLine, "!") || strings.HasPrefix(ruleLine, "#") {
continue continue
} }
originRuleLine := ruleLine originRuleLine := ruleLine
@ -96,7 +92,7 @@ parseLine:
} }
if !ignored { if !ignored {
ignoredLines++ ignoredLines++
logger.Debug("ignored unsupported rule with modifier: ", paramParts[0], ": ", originRuleLine) log.Debug("ignored unsupported rule with modifier: ", paramParts[0], ": ", ruleLine)
continue parseLine continue parseLine
} }
} }
@ -124,35 +120,27 @@ parseLine:
ruleLine = ruleLine[1 : len(ruleLine)-1] ruleLine = ruleLine[1 : len(ruleLine)-1]
if ignoreIPCIDRRegexp(ruleLine) { if ignoreIPCIDRRegexp(ruleLine) {
ignoredLines++ ignoredLines++
logger.Debug("ignored unsupported rule with IPCIDR regexp: ", originRuleLine) log.Debug("ignored unsupported rule with IPCIDR regexp: ", ruleLine)
continue continue
} }
isRegexp = true isRegexp = true
} else { } else {
if strings.Contains(ruleLine, "://") { if strings.Contains(ruleLine, "://") {
ruleLine = common.SubstringAfter(ruleLine, "://") ruleLine = common.SubstringAfter(ruleLine, "://")
isSuffix = true
} }
if strings.Contains(ruleLine, "/") { if strings.Contains(ruleLine, "/") {
ignoredLines++ ignoredLines++
logger.Debug("ignored unsupported rule with path: ", originRuleLine) log.Debug("ignored unsupported rule with path: ", ruleLine)
continue continue
} }
if strings.Contains(ruleLine, "?") || strings.Contains(ruleLine, "&") { if strings.Contains(ruleLine, "##") {
ignoredLines++ ignoredLines++
logger.Debug("ignored unsupported rule with query: ", originRuleLine) log.Debug("ignored unsupported rule with element hiding: ", ruleLine)
continue continue
} }
if strings.Contains(ruleLine, "[") || strings.Contains(ruleLine, "]") || if strings.Contains(ruleLine, "#$#") {
strings.Contains(ruleLine, "(") || strings.Contains(ruleLine, ")") ||
strings.Contains(ruleLine, "!") || strings.Contains(ruleLine, "#") {
ignoredLines++ ignoredLines++
logger.Debug("ignored unsupported cosmetic filter: ", originRuleLine) log.Debug("ignored unsupported rule with element hiding: ", ruleLine)
continue
}
if strings.Contains(ruleLine, "~") {
ignoredLines++
logger.Debug("ignored unsupported rule modifier: ", originRuleLine)
continue continue
} }
var domainCheck string var domainCheck string
@ -163,7 +151,7 @@ parseLine:
} }
if ruleLine == "" { if ruleLine == "" {
ignoredLines++ ignoredLines++
logger.Debug("ignored unsupported rule with empty domain", originRuleLine) log.Debug("ignored unsupported rule with empty domain", originRuleLine)
continue continue
} else { } else {
domainCheck = strings.ReplaceAll(domainCheck, "*", "x") domainCheck = strings.ReplaceAll(domainCheck, "*", "x")
@ -171,13 +159,13 @@ parseLine:
_, ipErr := parseADGuardIPCIDRLine(ruleLine) _, ipErr := parseADGuardIPCIDRLine(ruleLine)
if ipErr == nil { if ipErr == nil {
ignoredLines++ ignoredLines++
logger.Debug("ignored unsupported rule with IPCIDR: ", originRuleLine) log.Debug("ignored unsupported rule with IPCIDR: ", ruleLine)
continue continue
} }
if M.ParseSocksaddr(domainCheck).Port != 0 { if M.ParseSocksaddr(domainCheck).Port != 0 {
logger.Debug("ignored unsupported rule with port: ", originRuleLine) log.Debug("ignored unsupported rule with port: ", ruleLine)
} else { } else {
logger.Debug("ignored unsupported rule with invalid domain: ", originRuleLine) log.Debug("ignored unsupported rule with invalid domain: ", ruleLine)
} }
ignoredLines++ ignoredLines++
continue continue
@ -295,112 +283,10 @@ parseLine:
}, },
} }
} }
if ignoredLines > 0 { log.Info("parsed rules: ", len(ruleLines), "/", len(ruleLines)+ignoredLines)
logger.Info("parsed rules: ", len(ruleLines), "/", len(ruleLines)+ignoredLines)
}
return []option.HeadlessRule{currentRule}, nil return []option.HeadlessRule{currentRule}, nil
} }
var ErrInvalid = E.New("invalid binary AdGuard rule-set")
func FromOptions(rules []option.HeadlessRule) ([]byte, error) {
if len(rules) != 1 {
return nil, ErrInvalid
}
rule := rules[0]
var (
importantDomain []string
importantDomainRegex []string
importantExcludeDomain []string
importantExcludeDomainRegex []string
domain []string
domainRegex []string
excludeDomain []string
excludeDomainRegex []string
)
parse:
for {
switch rule.Type {
case C.RuleTypeLogical:
if !(len(rule.LogicalOptions.Rules) == 2 && rule.LogicalOptions.Rules[0].Type == C.RuleTypeDefault) {
return nil, ErrInvalid
}
if rule.LogicalOptions.Mode == C.LogicalTypeAnd && rule.LogicalOptions.Rules[0].DefaultOptions.Invert {
if len(importantExcludeDomain) == 0 && len(importantExcludeDomainRegex) == 0 {
importantExcludeDomain = rule.LogicalOptions.Rules[0].DefaultOptions.AdGuardDomain
importantExcludeDomainRegex = rule.LogicalOptions.Rules[0].DefaultOptions.DomainRegex
if len(importantExcludeDomain)+len(importantExcludeDomainRegex) == 0 {
return nil, ErrInvalid
}
} else {
excludeDomain = rule.LogicalOptions.Rules[0].DefaultOptions.AdGuardDomain
excludeDomainRegex = rule.LogicalOptions.Rules[0].DefaultOptions.DomainRegex
if len(excludeDomain)+len(excludeDomainRegex) == 0 {
return nil, ErrInvalid
}
}
} else if rule.LogicalOptions.Mode == C.LogicalTypeOr && !rule.LogicalOptions.Rules[0].DefaultOptions.Invert {
importantDomain = rule.LogicalOptions.Rules[0].DefaultOptions.AdGuardDomain
importantDomainRegex = rule.LogicalOptions.Rules[0].DefaultOptions.DomainRegex
if len(importantDomain)+len(importantDomainRegex) == 0 {
return nil, ErrInvalid
}
} else {
return nil, ErrInvalid
}
rule = rule.LogicalOptions.Rules[1]
case C.RuleTypeDefault:
domain = rule.DefaultOptions.AdGuardDomain
domainRegex = rule.DefaultOptions.DomainRegex
if len(domain)+len(domainRegex) == 0 {
return nil, ErrInvalid
}
break parse
}
}
var output bytes.Buffer
for _, ruleLine := range importantDomain {
output.WriteString(ruleLine)
output.WriteString("$important\n")
}
for _, ruleLine := range importantDomainRegex {
output.WriteString("/")
output.WriteString(ruleLine)
output.WriteString("/$important\n")
}
for _, ruleLine := range importantExcludeDomain {
output.WriteString("@@")
output.WriteString(ruleLine)
output.WriteString("$important\n")
}
for _, ruleLine := range importantExcludeDomainRegex {
output.WriteString("@@/")
output.WriteString(ruleLine)
output.WriteString("/$important\n")
}
for _, ruleLine := range domain {
output.WriteString(ruleLine)
output.WriteString("\n")
}
for _, ruleLine := range domainRegex {
output.WriteString("/")
output.WriteString(ruleLine)
output.WriteString("/\n")
}
for _, ruleLine := range excludeDomain {
output.WriteString("@@")
output.WriteString(ruleLine)
output.WriteString("\n")
}
for _, ruleLine := range excludeDomainRegex {
output.WriteString("@@/")
output.WriteString(ruleLine)
output.WriteString("/\n")
}
return output.Bytes(), nil
}
func ignoreIPCIDRRegexp(ruleLine string) bool { func ignoreIPCIDRRegexp(ruleLine string) bool {
if strings.HasPrefix(ruleLine, "(http?:\\/\\/)") { if strings.HasPrefix(ruleLine, "(http?:\\/\\/)") {
ruleLine = ruleLine[12:] ruleLine = ruleLine[12:]
@ -408,9 +294,11 @@ func ignoreIPCIDRRegexp(ruleLine string) bool {
ruleLine = ruleLine[13:] ruleLine = ruleLine[13:]
} else if strings.HasPrefix(ruleLine, "^") { } else if strings.HasPrefix(ruleLine, "^") {
ruleLine = ruleLine[1:] ruleLine = ruleLine[1:]
} else {
return false
} }
return common.Error(strconv.ParseUint(common.SubstringBefore(ruleLine, "\\."), 10, 8)) == nil || _, parseErr := strconv.ParseUint(common.SubstringBefore(ruleLine, "\\."), 10, 8)
common.Error(strconv.ParseUint(common.SubstringBefore(ruleLine, "."), 10, 8)) == nil return parseErr == nil
} }
func parseAdGuardHostLine(ruleLine string) (string, error) { func parseAdGuardHostLine(ruleLine string) (string, error) {
@ -454,5 +342,5 @@ func parseADGuardIPCIDRLine(ruleLine string) (netip.Prefix, error) {
for len(ruleParts) < 4 { for len(ruleParts) < 4 {
ruleParts = append(ruleParts, 0) ruleParts = append(ruleParts, 0)
} }
return netip.PrefixFrom(netip.AddrFrom4([4]byte(ruleParts)), bitLen), nil return netip.PrefixFrom(netip.AddrFrom4(*(*[4]byte)(ruleParts)), bitLen), nil
} }

View File

@ -7,15 +7,13 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing/common/logger"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestConverter(t *testing.T) { func TestConverter(t *testing.T) {
t.Parallel() t.Parallel()
ruleString := `||sagernet.org^$important rules, err := Convert(strings.NewReader(`
@@|sing-box.sagernet.org^$important
||example.org^ ||example.org^
|example.com^ |example.com^
example.net^ example.net^
@ -23,9 +21,10 @@ example.net^
||example.edu.tw^ ||example.edu.tw^
|example.gov |example.gov
example.arpa example.arpa
@@|sagernet.example.org^ @@|sagernet.example.org|
` ||sagernet.org^$important
rules, err := ToOptions(strings.NewReader(ruleString), logger.NOP()) @@|sing-box.sagernet.org^$important
`))
require.NoError(t, err) require.NoError(t, err)
require.Len(t, rules, 1) require.Len(t, rules, 1)
rule, err := rule.NewHeadlessRule(context.Background(), rules[0]) rule, err := rule.NewHeadlessRule(context.Background(), rules[0])
@ -76,18 +75,15 @@ example.arpa
Domain: domain, Domain: domain,
}), domain) }), domain)
} }
ruleFromOptions, err := FromOptions(rules)
require.NoError(t, err)
require.Equal(t, ruleString, string(ruleFromOptions))
} }
func TestHosts(t *testing.T) { func TestHosts(t *testing.T) {
t.Parallel() t.Parallel()
rules, err := ToOptions(strings.NewReader(` rules, err := Convert(strings.NewReader(`
127.0.0.1 localhost 127.0.0.1 localhost
::1 localhost #[IPv6] ::1 localhost #[IPv6]
0.0.0.0 google.com 0.0.0.0 google.com
`), logger.NOP()) `))
require.NoError(t, err) require.NoError(t, err)
require.Len(t, rules, 1) require.Len(t, rules, 1)
rule, err := rule.NewHeadlessRule(context.Background(), rules[0]) rule, err := rule.NewHeadlessRule(context.Background(), rules[0])
@ -114,10 +110,10 @@ func TestHosts(t *testing.T) {
func TestSimpleHosts(t *testing.T) { func TestSimpleHosts(t *testing.T) {
t.Parallel() t.Parallel()
rules, err := ToOptions(strings.NewReader(` rules, err := Convert(strings.NewReader(`
example.com example.com
www.example.org www.example.org
`), logger.NOP()) `))
require.NoError(t, err) require.NoError(t, err)
require.Len(t, rules, 1) require.Len(t, rules, 1)
rule, err := rule.NewHeadlessRule(context.Background(), rules[0]) rule, err := rule.NewHeadlessRule(context.Background(), rules[0])

View File

@ -1,169 +0,0 @@
//go:build go1.25 && !without_badtls
package badtls
import (
"bytes"
"os"
"reflect"
"sync/atomic"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/tls"
)
type RawConn struct {
pointer unsafe.Pointer
methods *Methods
IsClient *bool
IsHandshakeComplete *atomic.Bool
Vers *uint16
CipherSuite *uint16
RawInput *bytes.Buffer
Input *bytes.Reader
Hand *bytes.Buffer
CloseNotifySent *bool
CloseNotifyErr *error
In *RawHalfConn
Out *RawHalfConn
BytesSent *int64
PacketsSent *int64
ActiveCall *atomic.Int32
}
func NewRawConn(rawTLSConn tls.Conn) (*RawConn, error) {
var (
pointer unsafe.Pointer
methods *Methods
loaded bool
)
for _, tlsCreator := range methodRegistry {
pointer, methods, loaded = tlsCreator(rawTLSConn)
if loaded {
break
}
}
if !loaded {
return nil, os.ErrInvalid
}
conn := &RawConn{
pointer: pointer,
methods: methods,
}
rawConn := reflect.Indirect(reflect.ValueOf(rawTLSConn))
rawIsClient := rawConn.FieldByName("isClient")
if !rawIsClient.IsValid() || rawIsClient.Kind() != reflect.Bool {
return nil, E.New("invalid Conn.isClient")
}
conn.IsClient = (*bool)(unsafe.Pointer(rawIsClient.UnsafeAddr()))
rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
return nil, E.New("invalid Conn.isHandshakeComplete")
}
conn.IsHandshakeComplete = (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
rawVers := rawConn.FieldByName("vers")
if !rawVers.IsValid() || rawVers.Kind() != reflect.Uint16 {
return nil, E.New("invalid Conn.vers")
}
conn.Vers = (*uint16)(unsafe.Pointer(rawVers.UnsafeAddr()))
rawCipherSuite := rawConn.FieldByName("cipherSuite")
if !rawCipherSuite.IsValid() || rawCipherSuite.Kind() != reflect.Uint16 {
return nil, E.New("invalid Conn.cipherSuite")
}
conn.CipherSuite = (*uint16)(unsafe.Pointer(rawCipherSuite.UnsafeAddr()))
rawRawInput := rawConn.FieldByName("rawInput")
if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct {
return nil, E.New("invalid Conn.rawInput")
}
conn.RawInput = (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr()))
rawInput := rawConn.FieldByName("input")
if !rawInput.IsValid() || rawInput.Kind() != reflect.Struct {
return nil, E.New("invalid Conn.input")
}
conn.Input = (*bytes.Reader)(unsafe.Pointer(rawInput.UnsafeAddr()))
rawHand := rawConn.FieldByName("hand")
if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct {
return nil, E.New("invalid Conn.hand")
}
conn.Hand = (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
rawCloseNotifySent := rawConn.FieldByName("closeNotifySent")
if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool {
return nil, E.New("invalid Conn.closeNotifySent")
}
conn.CloseNotifySent = (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr()))
rawCloseNotifyErr := rawConn.FieldByName("closeNotifyErr")
if !rawCloseNotifyErr.IsValid() || rawCloseNotifyErr.Kind() != reflect.Interface {
return nil, E.New("invalid Conn.closeNotifyErr")
}
conn.CloseNotifyErr = (*error)(unsafe.Pointer(rawCloseNotifyErr.UnsafeAddr()))
rawIn := rawConn.FieldByName("in")
if !rawIn.IsValid() || rawIn.Kind() != reflect.Struct {
return nil, E.New("invalid Conn.in")
}
halfIn, err := NewRawHalfConn(rawIn, methods)
if err != nil {
return nil, E.Cause(err, "invalid Conn.in")
}
conn.In = halfIn
rawOut := rawConn.FieldByName("out")
if !rawOut.IsValid() || rawOut.Kind() != reflect.Struct {
return nil, E.New("invalid Conn.out")
}
halfOut, err := NewRawHalfConn(rawOut, methods)
if err != nil {
return nil, E.Cause(err, "invalid Conn.out")
}
conn.Out = halfOut
rawBytesSent := rawConn.FieldByName("bytesSent")
if !rawBytesSent.IsValid() || rawBytesSent.Kind() != reflect.Int64 {
return nil, E.New("invalid Conn.bytesSent")
}
conn.BytesSent = (*int64)(unsafe.Pointer(rawBytesSent.UnsafeAddr()))
rawPacketsSent := rawConn.FieldByName("packetsSent")
if !rawPacketsSent.IsValid() || rawPacketsSent.Kind() != reflect.Int64 {
return nil, E.New("invalid Conn.packetsSent")
}
conn.PacketsSent = (*int64)(unsafe.Pointer(rawPacketsSent.UnsafeAddr()))
rawActiveCall := rawConn.FieldByName("activeCall")
if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
return nil, E.New("invalid Conn.activeCall")
}
conn.ActiveCall = (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
return conn, nil
}
func (c *RawConn) ReadRecord() error {
return c.methods.readRecord(c.pointer)
}
func (c *RawConn) HandlePostHandshakeMessage() error {
return c.methods.handlePostHandshakeMessage(c.pointer)
}
func (c *RawConn) WriteRecordLocked(typ uint16, data []byte) (int, error) {
return c.methods.writeRecordLocked(c.pointer, typ, data)
}

View File

@ -1,121 +0,0 @@
//go:build go1.25 && !without_badtls
package badtls
import (
"hash"
"reflect"
"sync"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
)
type RawHalfConn struct {
pointer unsafe.Pointer
methods *Methods
*sync.Mutex
Err *error
Version *uint16
Cipher *any
Seq *[8]byte
ScratchBuf *[13]byte
TrafficSecret *[]byte
Mac *hash.Hash
RawKey *[]byte
RawIV *[]byte
RawMac *[]byte
}
func NewRawHalfConn(rawHalfConn reflect.Value, methods *Methods) (*RawHalfConn, error) {
halfConn := &RawHalfConn{
pointer: (unsafe.Pointer)(rawHalfConn.UnsafeAddr()),
methods: methods,
}
rawMutex := rawHalfConn.FieldByName("Mutex")
if !rawMutex.IsValid() || rawMutex.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid halfConn.Mutex")
}
halfConn.Mutex = (*sync.Mutex)(unsafe.Pointer(rawMutex.UnsafeAddr()))
rawErr := rawHalfConn.FieldByName("err")
if !rawErr.IsValid() || rawErr.Kind() != reflect.Interface {
return nil, E.New("badtls: invalid halfConn.err")
}
halfConn.Err = (*error)(unsafe.Pointer(rawErr.UnsafeAddr()))
rawVersion := rawHalfConn.FieldByName("version")
if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 {
return nil, E.New("badtls: invalid halfConn.version")
}
halfConn.Version = (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr()))
rawCipher := rawHalfConn.FieldByName("cipher")
if !rawCipher.IsValid() || rawCipher.Kind() != reflect.Interface {
return nil, E.New("badtls: invalid halfConn.cipher")
}
halfConn.Cipher = (*any)(unsafe.Pointer(rawCipher.UnsafeAddr()))
rawSeq := rawHalfConn.FieldByName("seq")
if !rawSeq.IsValid() || rawSeq.Kind() != reflect.Array || rawSeq.Len() != 8 || rawSeq.Type().Elem().Kind() != reflect.Uint8 {
return nil, E.New("badtls: invalid halfConn.seq")
}
halfConn.Seq = (*[8]byte)(unsafe.Pointer(rawSeq.UnsafeAddr()))
rawScratchBuf := rawHalfConn.FieldByName("scratchBuf")
if !rawScratchBuf.IsValid() || rawScratchBuf.Kind() != reflect.Array || rawScratchBuf.Len() != 13 || rawScratchBuf.Type().Elem().Kind() != reflect.Uint8 {
return nil, E.New("badtls: invalid halfConn.scratchBuf")
}
halfConn.ScratchBuf = (*[13]byte)(unsafe.Pointer(rawScratchBuf.UnsafeAddr()))
rawTrafficSecret := rawHalfConn.FieldByName("trafficSecret")
if !rawTrafficSecret.IsValid() || rawTrafficSecret.Kind() != reflect.Slice || rawTrafficSecret.Type().Elem().Kind() != reflect.Uint8 {
return nil, E.New("badtls: invalid halfConn.trafficSecret")
}
halfConn.TrafficSecret = (*[]byte)(unsafe.Pointer(rawTrafficSecret.UnsafeAddr()))
rawMac := rawHalfConn.FieldByName("mac")
if !rawMac.IsValid() || rawMac.Kind() != reflect.Interface {
return nil, E.New("badtls: invalid halfConn.mac")
}
halfConn.Mac = (*hash.Hash)(unsafe.Pointer(rawMac.UnsafeAddr()))
rawKey := rawHalfConn.FieldByName("rawKey")
if rawKey.IsValid() {
if /*!rawKey.IsValid() || */ rawKey.Kind() != reflect.Slice || rawKey.Type().Elem().Kind() != reflect.Uint8 {
return nil, E.New("badtls: invalid halfConn.rawKey")
}
halfConn.RawKey = (*[]byte)(unsafe.Pointer(rawKey.UnsafeAddr()))
rawIV := rawHalfConn.FieldByName("rawIV")
if !rawIV.IsValid() || rawIV.Kind() != reflect.Slice || rawIV.Type().Elem().Kind() != reflect.Uint8 {
return nil, E.New("badtls: invalid halfConn.rawIV")
}
halfConn.RawIV = (*[]byte)(unsafe.Pointer(rawIV.UnsafeAddr()))
rawMAC := rawHalfConn.FieldByName("rawMac")
if !rawMAC.IsValid() || rawMAC.Kind() != reflect.Slice || rawMAC.Type().Elem().Kind() != reflect.Uint8 {
return nil, E.New("badtls: invalid halfConn.rawMac")
}
halfConn.RawMac = (*[]byte)(unsafe.Pointer(rawMAC.UnsafeAddr()))
}
return halfConn, nil
}
func (hc *RawHalfConn) Decrypt(record []byte) ([]byte, uint8, error) {
return hc.methods.decrypt(hc.pointer, record)
}
func (hc *RawHalfConn) SetErrorLocked(err error) error {
return hc.methods.setErrorLocked(hc.pointer, err)
}
func (hc *RawHalfConn) SetTrafficSecret(suite unsafe.Pointer, level int, secret []byte) {
hc.methods.setTrafficSecret(hc.pointer, suite, level, secret)
}
func (hc *RawHalfConn) ExplicitNonceLen() int {
return hc.methods.explicitNonceLen(hc.pointer)
}

View File

@ -1,9 +1,18 @@
//go:build go1.25 && !without_badtls //go:build go1.21 && !without_badtls
package badtls package badtls
import ( import (
"bytes"
"context"
"net"
"os"
"reflect"
"sync"
"unsafe"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/tls" "github.com/sagernet/sing/common/tls"
) )
@ -12,18 +21,63 @@ var _ N.ReadWaiter = (*ReadWaitConn)(nil)
type ReadWaitConn struct { type ReadWaitConn struct {
tls.Conn tls.Conn
rawConn *RawConn halfAccess *sync.Mutex
readWaitOptions N.ReadWaitOptions rawInput *bytes.Buffer
input *bytes.Reader
hand *bytes.Buffer
readWaitOptions N.ReadWaitOptions
tlsReadRecord func() error
tlsHandlePostHandshakeMessage func() error
} }
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) { func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
rawConn, err := NewRawConn(conn) var (
if err != nil { loaded bool
return nil, err tlsReadRecord func() error
tlsHandlePostHandshakeMessage func() error
)
for _, tlsCreator := range tlsRegistry {
loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
if loaded {
break
}
} }
if !loaded {
return nil, os.ErrInvalid
}
rawConn := reflect.Indirect(reflect.ValueOf(conn))
rawHalfConn := rawConn.FieldByName("in")
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half conn")
}
rawHalfMutex := rawHalfConn.FieldByName("Mutex")
if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half mutex")
}
halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
rawRawInput := rawConn.FieldByName("rawInput")
if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid raw input")
}
rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr()))
rawInput0 := rawConn.FieldByName("input")
if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid input")
}
input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr()))
rawHand := rawConn.FieldByName("hand")
if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid hand")
}
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
return &ReadWaitConn{ return &ReadWaitConn{
Conn: conn, Conn: conn,
rawConn: rawConn, halfAccess: halfAccess,
rawInput: rawInput,
input: input,
hand: hand,
tlsReadRecord: tlsReadRecord,
tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
}, nil }, nil
} }
@ -33,36 +87,36 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy
} }
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
//err = c.HandshakeContext(context.Background()) err = c.HandshakeContext(context.Background())
//if err != nil { if err != nil {
// return return
//} }
c.rawConn.In.Lock() c.halfAccess.Lock()
defer c.rawConn.In.Unlock() defer c.halfAccess.Unlock()
for c.rawConn.Input.Len() == 0 { for c.input.Len() == 0 {
err = c.rawConn.ReadRecord() err = c.tlsReadRecord()
if err != nil { if err != nil {
return return
} }
for c.rawConn.Hand.Len() > 0 { for c.hand.Len() > 0 {
err = c.rawConn.HandlePostHandshakeMessage() err = c.tlsHandlePostHandshakeMessage()
if err != nil { if err != nil {
return return
} }
} }
} }
buffer = c.readWaitOptions.NewBuffer() buffer = c.readWaitOptions.NewBuffer()
n, err := c.rawConn.Input.Read(buffer.FreeBytes()) n, err := c.input.Read(buffer.FreeBytes())
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return return
} }
buffer.Truncate(n) buffer.Truncate(n)
if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.Input.Len() > 0 && if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
// recordType(c.RawInput.Bytes()[0]) == recordTypeAlert { // recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
c.rawConn.RawInput.Bytes()[0] == 21 { c.rawInput.Bytes()[0] == 21 {
_ = c.rawConn.ReadRecord() _ = c.tlsReadRecord()
// return n, err // will be io.EOF on closeNotify // return n, err // will be io.EOF on closeNotify
} }
@ -74,6 +128,24 @@ func (c *ReadWaitConn) Upstream() any {
return c.Conn return c.Conn
} }
func (c *ReadWaitConn) ReaderReplaceable() bool { var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
return true
func init() {
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
tlsConn, loaded := conn.(*tls.STDConn)
if !loaded {
return
}
return true, func() error {
return stdTLSReadRecord(tlsConn)
}, func() error {
return stdTLSHandlePostHandshakeMessage(tlsConn)
}
})
} }
//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
func stdTLSReadRecord(c *tls.STDConn) error
//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error

View File

@ -0,0 +1,31 @@
//go:build go1.21 && !without_badtls && with_ech
package badtls
import (
"net"
_ "unsafe"
"github.com/sagernet/cloudflare-tls"
"github.com/sagernet/sing/common"
)
func init() {
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
tlsConn, loaded := common.Cast[*tls.Conn](conn)
if !loaded {
return
}
return true, func() error {
return echReadRecord(tlsConn)
}, func() error {
return echHandlePostHandshakeMessage(tlsConn)
}
})
}
//go:linkname echReadRecord github.com/sagernet/cloudflare-tls.(*Conn).readRecord
func echReadRecord(c *tls.Conn) error
//go:linkname echHandlePostHandshakeMessage github.com/sagernet/cloudflare-tls.(*Conn).handlePostHandshakeMessage
func echHandlePostHandshakeMessage(c *tls.Conn) error

View File

@ -1,4 +1,4 @@
//go:build !go1.25 || without_badtls //go:build !go1.21 || without_badtls
package badtls package badtls

View File

@ -0,0 +1,31 @@
//go:build go1.21 && !without_badtls && with_utls
package badtls
import (
"net"
_ "unsafe"
"github.com/sagernet/sing/common"
"github.com/sagernet/utls"
)
func init() {
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
tlsConn, loaded := common.Cast[*tls.UConn](conn)
if !loaded {
return
}
return true, func() error {
return utlsReadRecord(tlsConn.Conn)
}, func() error {
return utlsHandlePostHandshakeMessage(tlsConn.Conn)
}
})
}
//go:linkname utlsReadRecord github.com/sagernet/utls.(*Conn).readRecord
func utlsReadRecord(c *tls.Conn) error
//go:linkname utlsHandlePostHandshakeMessage github.com/sagernet/utls.(*Conn).handlePostHandshakeMessage
func utlsHandlePostHandshakeMessage(c *tls.Conn) error

View File

@ -1,62 +0,0 @@
//go:build go1.25 && !without_badtls
package badtls
import (
"crypto/tls"
"net"
"unsafe"
)
type Methods struct {
readRecord func(c unsafe.Pointer) error
handlePostHandshakeMessage func(c unsafe.Pointer) error
writeRecordLocked func(c unsafe.Pointer, typ uint16, data []byte) (int, error)
setErrorLocked func(hc unsafe.Pointer, err error) error
decrypt func(hc unsafe.Pointer, record []byte) ([]byte, uint8, error)
setTrafficSecret func(hc unsafe.Pointer, suite unsafe.Pointer, level int, secret []byte)
explicitNonceLen func(hc unsafe.Pointer) int
}
var methodRegistry []func(conn net.Conn) (unsafe.Pointer, *Methods, bool)
func init() {
methodRegistry = append(methodRegistry, func(conn net.Conn) (unsafe.Pointer, *Methods, bool) {
tlsConn, loaded := conn.(*tls.Conn)
if !loaded {
return nil, nil, false
}
return unsafe.Pointer(tlsConn), &Methods{
readRecord: stdTLSReadRecord,
handlePostHandshakeMessage: stdTLSHandlePostHandshakeMessage,
writeRecordLocked: stdWriteRecordLocked,
setErrorLocked: stdSetErrorLocked,
decrypt: stdDecrypt,
setTrafficSecret: stdSetTrafficSecret,
explicitNonceLen: stdExplicitNonceLen,
}, true
})
}
//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
func stdTLSReadRecord(c unsafe.Pointer) error
//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
func stdTLSHandlePostHandshakeMessage(c unsafe.Pointer) error
//go:linkname stdWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked
func stdWriteRecordLocked(c unsafe.Pointer, typ uint16, data []byte) (int, error)
//go:linkname stdSetErrorLocked crypto/tls.(*halfConn).setErrorLocked
func stdSetErrorLocked(hc unsafe.Pointer, err error) error
//go:linkname stdDecrypt crypto/tls.(*halfConn).decrypt
func stdDecrypt(hc unsafe.Pointer, record []byte) ([]byte, uint8, error)
//go:linkname stdSetTrafficSecret crypto/tls.(*halfConn).setTrafficSecret
func stdSetTrafficSecret(hc unsafe.Pointer, suite unsafe.Pointer, level int, secret []byte)
//go:linkname stdExplicitNonceLen crypto/tls.(*halfConn).explicitNonceLen
func stdExplicitNonceLen(hc unsafe.Pointer) int

View File

@ -1,56 +0,0 @@
//go:build go1.25 && !without_badtls
package badtls
import (
"net"
"unsafe"
N "github.com/sagernet/sing/common/network"
"github.com/metacubex/utls"
)
func init() {
methodRegistry = append(methodRegistry, func(conn net.Conn) (unsafe.Pointer, *Methods, bool) {
var pointer unsafe.Pointer
if uConn, loaded := N.CastReader[*tls.Conn](conn); loaded {
pointer = unsafe.Pointer(uConn)
} else if uConn, loaded := N.CastReader[*tls.UConn](conn); loaded {
pointer = unsafe.Pointer(uConn.Conn)
} else {
return nil, nil, false
}
return pointer, &Methods{
readRecord: utlsReadRecord,
handlePostHandshakeMessage: utlsHandlePostHandshakeMessage,
writeRecordLocked: utlsWriteRecordLocked,
setErrorLocked: utlsSetErrorLocked,
decrypt: utlsDecrypt,
setTrafficSecret: utlsSetTrafficSecret,
explicitNonceLen: utlsExplicitNonceLen,
}, true
})
}
//go:linkname utlsReadRecord github.com/metacubex/utls.(*Conn).readRecord
func utlsReadRecord(c unsafe.Pointer) error
//go:linkname utlsHandlePostHandshakeMessage github.com/metacubex/utls.(*Conn).handlePostHandshakeMessage
func utlsHandlePostHandshakeMessage(c unsafe.Pointer) error
//go:linkname utlsWriteRecordLocked github.com/metacubex/utls.(*Conn).writeRecordLocked
func utlsWriteRecordLocked(hc unsafe.Pointer, typ uint16, data []byte) (int, error)
//go:linkname utlsSetErrorLocked github.com/metacubex/utls.(*halfConn).setErrorLocked
func utlsSetErrorLocked(hc unsafe.Pointer, err error) error
//go:linkname utlsDecrypt github.com/metacubex/utls.(*halfConn).decrypt
func utlsDecrypt(hc unsafe.Pointer, record []byte) ([]byte, uint8, error)
//go:linkname utlsSetTrafficSecret github.com/metacubex/utls.(*halfConn).setTrafficSecret
func utlsSetTrafficSecret(hc unsafe.Pointer, suite unsafe.Pointer, level int, secret []byte)
//go:linkname utlsExplicitNonceLen github.com/metacubex/utls.(*halfConn).explicitNonceLen
func utlsExplicitNonceLen(hc unsafe.Pointer) int

File diff suppressed because it is too large Load Diff

View File

@ -1,185 +0,0 @@
package certificate
import (
"context"
"crypto/x509"
"io/fs"
"os"
"path/filepath"
"strings"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
)
var _ adapter.CertificateStore = (*Store)(nil)
type Store struct {
systemPool *x509.CertPool
currentPool *x509.CertPool
certificate string
certificatePaths []string
certificateDirectoryPaths []string
watcher *fswatch.Watcher
}
func NewStore(ctx context.Context, logger logger.Logger, options option.CertificateOptions) (*Store, error) {
var systemPool *x509.CertPool
switch options.Store {
case C.CertificateStoreSystem, "":
systemPool = x509.NewCertPool()
platformInterface := service.FromContext[platform.Interface](ctx)
var systemValid bool
if platformInterface != nil {
for _, cert := range platformInterface.SystemCertificates() {
if systemPool.AppendCertsFromPEM([]byte(cert)) {
systemValid = true
}
}
}
if !systemValid {
certPool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
systemPool = certPool
}
case C.CertificateStoreMozilla:
systemPool = mozillaIncluded
case C.CertificateStoreNone:
systemPool = nil
default:
return nil, E.New("unknown certificate store: ", options.Store)
}
store := &Store{
systemPool: systemPool,
certificate: strings.Join(options.Certificate, "\n"),
certificatePaths: options.CertificatePath,
certificateDirectoryPaths: options.CertificateDirectoryPath,
}
var watchPaths []string
for _, target := range options.CertificatePath {
watchPaths = append(watchPaths, target)
}
for _, target := range options.CertificateDirectoryPath {
watchPaths = append(watchPaths, target)
}
if len(watchPaths) > 0 {
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: watchPaths,
Logger: logger,
Callback: func(_ string) {
err := store.update()
if err != nil {
logger.Error(E.Cause(err, "reload certificates"))
}
},
})
if err != nil {
return nil, E.Cause(err, "fswatch: create fsnotify watcher")
}
store.watcher = watcher
}
err := store.update()
if err != nil {
return nil, E.Cause(err, "initializing certificate store")
}
return store, nil
}
func (s *Store) Name() string {
return "certificate"
}
func (s *Store) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if s.watcher != nil {
return s.watcher.Start()
}
return nil
}
func (s *Store) Close() error {
if s.watcher != nil {
return s.watcher.Close()
}
return nil
}
func (s *Store) Pool() *x509.CertPool {
return s.currentPool
}
func (s *Store) update() error {
var currentPool *x509.CertPool
if s.systemPool == nil {
currentPool = x509.NewCertPool()
} else {
currentPool = s.systemPool.Clone()
}
if s.certificate != "" {
if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) {
return E.New("invalid certificate PEM strings")
}
}
for _, path := range s.certificatePaths {
pemContent, err := os.ReadFile(path)
if err != nil {
return err
}
if !currentPool.AppendCertsFromPEM(pemContent) {
return E.New("invalid certificate PEM file: ", path)
}
}
var firstErr error
for _, directoryPath := range s.certificateDirectoryPaths {
directoryEntries, err := readUniqueDirectoryEntries(directoryPath)
if err != nil {
if firstErr == nil && !os.IsNotExist(err) {
firstErr = E.Cause(err, "invalid certificate directory: ", directoryPath)
}
continue
}
for _, directoryEntry := range directoryEntries {
pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name()))
if err == nil {
currentPool.AppendCertsFromPEM(pemContent)
}
}
}
if firstErr != nil {
return firstErr
}
s.currentPool = currentPool
return nil
}
func readUniqueDirectoryEntries(dir string) ([]fs.DirEntry, error) {
files, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
uniq := files[:0]
for _, f := range files {
if !isSameDirSymlink(f, dir) {
uniq = append(uniq, f)
}
}
return uniq, nil
}
func isSameDirSymlink(f fs.DirEntry, dir string) bool {
if f.Type()&fs.ModeSymlink == 0 {
return false
}
target, err := os.Readlink(filepath.Join(dir, f.Name()))
return err == nil && !strings.Contains(target, "/")
}

View File

@ -2,24 +2,20 @@ package dialer
import ( import (
"context" "context"
"errors"
"net" "net"
"net/netip" "net/netip"
"syscall"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/conntrack" "github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/listener"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
) )
var ( var (
@ -28,36 +24,31 @@ var (
) )
type DefaultDialer struct { type DefaultDialer struct {
dialer4 tcpDialer dialer4 tcpDialer
dialer6 tcpDialer dialer6 tcpDialer
udpDialer4 net.Dialer udpDialer4 net.Dialer
udpDialer6 net.Dialer udpDialer6 net.Dialer
udpListener net.ListenConfig udpListener net.ListenConfig
udpAddr4 string udpAddr4 string
udpAddr6 string udpAddr6 string
netns string isWireGuardListener bool
networkManager adapter.NetworkManager networkManager adapter.NetworkManager
networkStrategy *C.NetworkStrategy networkStrategy C.NetworkStrategy
defaultNetworkStrategy bool networkType []C.InterfaceType
networkType []C.InterfaceType fallbackNetworkType []C.InterfaceType
fallbackNetworkType []C.InterfaceType networkFallbackDelay time.Duration
networkFallbackDelay time.Duration networkLastFallback atomic.TypedValue[time.Time]
networkLastFallback common.TypedValue[time.Time]
} }
func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDialer, error) { func NewDefault(networkManager adapter.NetworkManager, options option.DialerOptions) (*DefaultDialer, error) {
networkManager := service.FromContext[adapter.NetworkManager](ctx)
platformInterface := service.FromContext[platform.Interface](ctx)
var ( var (
dialer net.Dialer dialer net.Dialer
listener net.ListenConfig listener net.ListenConfig
interfaceFinder control.InterfaceFinder interfaceFinder control.InterfaceFinder
networkStrategy *C.NetworkStrategy networkStrategy C.NetworkStrategy
defaultNetworkStrategy bool networkType []C.InterfaceType
networkType []C.InterfaceType fallbackNetworkType []C.InterfaceType
fallbackNetworkType []C.InterfaceType networkFallbackDelay time.Duration
networkFallbackDelay time.Duration
) )
if networkManager != nil { if networkManager != nil {
interfaceFinder = networkManager.InterfaceFinder() interfaceFinder = networkManager.InterfaceFinder()
@ -65,51 +56,48 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
interfaceFinder = control.NewDefaultInterfaceFinder() interfaceFinder = control.NewDefaultInterfaceFinder()
} }
if options.BindInterface != "" { if options.BindInterface != "" {
if !(C.IsLinux || C.IsDarwin || C.IsWindows) {
return nil, E.New("`bind_interface` is only supported on Linux, macOS and Windows")
}
bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1) bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1)
dialer.Control = control.Append(dialer.Control, bindFunc) dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc)
} }
if options.RoutingMark > 0 { if options.RoutingMark > 0 {
if !C.IsLinux { dialer.Control = control.Append(dialer.Control, control.RoutingMark(uint32(options.RoutingMark)))
return nil, E.New("`routing_mark` is only supported on Linux") listener.Control = control.Append(listener.Control, control.RoutingMark(uint32(options.RoutingMark)))
}
dialer.Control = control.Append(dialer.Control, setMarkWrapper(networkManager, uint32(options.RoutingMark), false))
listener.Control = control.Append(listener.Control, setMarkWrapper(networkManager, uint32(options.RoutingMark), false))
} }
disableDefaultBind := options.BindInterface != "" || options.Inet4BindAddress != nil || options.Inet6BindAddress != nil
if disableDefaultBind || options.TCPFastOpen {
if options.NetworkStrategy != nil || len(options.NetworkType) > 0 && options.FallbackNetworkType == nil && options.FallbackDelay == 0 {
return nil, E.New("`network_strategy` is conflict with `bind_interface`, `inet4_bind_address`, `inet6_bind_address` and `tcp_fast_open`")
}
}
if networkManager != nil { if networkManager != nil {
autoRedirectOutputMark := networkManager.AutoRedirectOutputMark()
if autoRedirectOutputMark > 0 {
if options.RoutingMark > 0 {
return nil, E.New("`routing_mark` is conflict with `tun.auto_redirect` with `tun.route_[_exclude]_address_set")
}
dialer.Control = control.Append(dialer.Control, control.RoutingMark(autoRedirectOutputMark))
listener.Control = control.Append(listener.Control, control.RoutingMark(autoRedirectOutputMark))
}
}
if C.NetworkStrategy(options.NetworkStrategy) != C.NetworkStrategyDefault {
if options.BindInterface != "" || options.Inet4BindAddress != nil || options.Inet6BindAddress != nil {
return nil, E.New("`network_strategy` is conflict with `bind_interface`, `inet4_bind_address` and `inet6_bind_address`")
}
networkStrategy = C.NetworkStrategy(options.NetworkStrategy)
networkType = common.Map(options.NetworkType, option.InterfaceType.Build)
fallbackNetworkType = common.Map(options.FallbackNetworkType, option.InterfaceType.Build)
networkFallbackDelay = time.Duration(options.NetworkFallbackDelay)
if networkManager == nil || !networkManager.AutoDetectInterface() {
return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`")
}
}
if networkManager != nil && options.BindInterface == "" && options.Inet4BindAddress == nil && options.Inet6BindAddress == nil {
defaultOptions := networkManager.DefaultOptions() defaultOptions := networkManager.DefaultOptions()
if defaultOptions.BindInterface != "" { if defaultOptions.BindInterface != "" {
bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), defaultOptions.BindInterface, -1) bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), defaultOptions.BindInterface, -1)
dialer.Control = control.Append(dialer.Control, bindFunc) dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc)
} else if networkManager.AutoDetectInterface() && !disableDefaultBind { } else if networkManager.AutoDetectInterface() {
if platformInterface != nil { if defaultOptions.NetworkStrategy != C.NetworkStrategyDefault && C.NetworkStrategy(options.NetworkStrategy) == C.NetworkStrategyDefault {
networkStrategy = (*C.NetworkStrategy)(options.NetworkStrategy) networkStrategy = defaultOptions.NetworkStrategy
networkType = common.Map(options.NetworkType, option.InterfaceType.Build) networkType = defaultOptions.NetworkType
fallbackNetworkType = common.Map(options.FallbackNetworkType, option.InterfaceType.Build) fallbackNetworkType = defaultOptions.FallbackNetworkType
if networkStrategy == nil && len(networkType) == 0 && len(fallbackNetworkType) == 0 { networkFallbackDelay = defaultOptions.FallbackDelay
networkStrategy = defaultOptions.NetworkStrategy
networkType = defaultOptions.NetworkType
fallbackNetworkType = defaultOptions.FallbackNetworkType
}
networkFallbackDelay = time.Duration(options.FallbackDelay)
if networkFallbackDelay == 0 && defaultOptions.FallbackDelay != 0 {
networkFallbackDelay = defaultOptions.FallbackDelay
}
if networkStrategy == nil {
networkStrategy = common.Ptr(C.NetworkStrategyDefault)
defaultNetworkStrategy = true
}
bindFunc := networkManager.ProtectFunc() bindFunc := networkManager.ProtectFunc()
dialer.Control = control.Append(dialer.Control, bindFunc) dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc)
@ -119,15 +107,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
listener.Control = control.Append(listener.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc)
} }
} }
if options.RoutingMark == 0 && defaultOptions.RoutingMark != 0 {
dialer.Control = control.Append(dialer.Control, setMarkWrapper(networkManager, defaultOptions.RoutingMark, true))
listener.Control = control.Append(listener.Control, setMarkWrapper(networkManager, defaultOptions.RoutingMark, true))
}
}
if networkManager != nil {
markFunc := networkManager.AutoRedirectOutputMarkFunc()
dialer.Control = control.Append(dialer.Control, markFunc)
listener.Control = control.Append(listener.Control, markFunc)
} }
if options.ReuseAddr { if options.ReuseAddr {
listener.Control = control.Append(listener.Control, control.ReuseAddr()) listener.Control = control.Append(listener.Control, control.ReuseAddr())
@ -182,6 +161,14 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
} }
setMultiPathTCP(&dialer4) setMultiPathTCP(&dialer4)
} }
if options.IsWireGuardListener {
for _, controlFn := range WgControlFns {
listener.Control = control.Append(listener.Control, controlFn)
}
}
if networkStrategy != C.NetworkStrategyDefault && options.TCPFastOpen {
return nil, E.New("`tcp_fast_open` is conflict with `network_strategy` or `route.default_network_strategy`")
}
tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen) tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
if err != nil { if err != nil {
return nil, err return nil, err
@ -191,81 +178,51 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
return nil, err return nil, err
} }
return &DefaultDialer{ return &DefaultDialer{
dialer4: tcpDialer4, dialer4: tcpDialer4,
dialer6: tcpDialer6, dialer6: tcpDialer6,
udpDialer4: udpDialer4, udpDialer4: udpDialer4,
udpDialer6: udpDialer6, udpDialer6: udpDialer6,
udpListener: listener, udpListener: listener,
udpAddr4: udpAddr4, udpAddr4: udpAddr4,
udpAddr6: udpAddr6, udpAddr6: udpAddr6,
netns: options.NetNs, isWireGuardListener: options.IsWireGuardListener,
networkManager: networkManager, networkManager: networkManager,
networkStrategy: networkStrategy, networkStrategy: networkStrategy,
defaultNetworkStrategy: defaultNetworkStrategy, networkType: networkType,
networkType: networkType, fallbackNetworkType: fallbackNetworkType,
fallbackNetworkType: fallbackNetworkType, networkFallbackDelay: networkFallbackDelay,
networkFallbackDelay: networkFallbackDelay,
}, nil }, nil
} }
func setMarkWrapper(networkManager adapter.NetworkManager, mark uint32, isDefault bool) control.Func {
if networkManager == nil {
return control.RoutingMark(mark)
}
return func(network, address string, conn syscall.RawConn) error {
if networkManager.AutoRedirectOutputMark() != 0 {
if isDefault {
return E.New("`route.default_mark` is conflict with `tun.auto_redirect`")
} else {
return E.New("`routing_mark` is conflict with `tun.auto_redirect`")
}
}
return control.RoutingMark(mark)(network, address, conn)
}
}
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) { func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
if !address.IsValid() { if !address.IsValid() {
return nil, E.New("invalid address") return nil, E.New("invalid address")
} else if address.IsFqdn() {
return nil, E.New("domain not resolved")
} }
if d.networkStrategy == nil { if d.networkStrategy == C.NetworkStrategyDefault {
return trackConn(listener.ListenNetworkNamespace[net.Conn](d.netns, func() (net.Conn, error) { switch N.NetworkName(network) {
switch N.NetworkName(network) { case N.NetworkUDP:
case N.NetworkUDP:
if !address.IsIPv6() {
return d.udpDialer4.DialContext(ctx, network, address.String())
} else {
return d.udpDialer6.DialContext(ctx, network, address.String())
}
}
if !address.IsIPv6() { if !address.IsIPv6() {
return DialSlowContext(&d.dialer4, ctx, network, address) return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
} else { } else {
return DialSlowContext(&d.dialer6, ctx, network, address) return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
} }
})) }
if !address.IsIPv6() {
return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
} else {
return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
}
} else { } else {
return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay) return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
} }
} }
func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network string, address M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network string, address M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
if strategy == nil { if strategy == C.NetworkStrategyDefault {
strategy = d.networkStrategy
}
if strategy == nil {
return d.DialContext(ctx, network, address) return d.DialContext(ctx, network, address)
} }
if len(interfaceType) == 0 { if !d.networkManager.AutoDetectInterface() {
interfaceType = d.networkType return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`")
}
if len(fallbackInterfaceType) == 0 {
fallbackInterfaceType = d.fallbackNetworkType
}
if fallbackDelay == 0 {
fallbackDelay = d.networkFallbackDelay
} }
var dialer net.Dialer var dialer net.Dialer
if N.NetworkName(network) == N.NetworkTCP { if N.NetworkName(network) == N.NetworkTCP {
@ -273,25 +230,19 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin
} else { } else {
dialer = d.udpDialer4 dialer = d.udpDialer4
} }
fastFallback := time.Since(d.networkLastFallback.Load()) < C.TCPTimeout fastFallback := time.Now().Sub(d.networkLastFallback.Load()) < C.TCPTimeout
var ( var (
conn net.Conn conn net.Conn
isPrimary bool isPrimary bool
err error err error
) )
if !fastFallback { if !fastFallback {
conn, isPrimary, err = d.dialParallelInterface(ctx, dialer, network, address.String(), *strategy, interfaceType, fallbackInterfaceType, fallbackDelay) conn, isPrimary, err = d.dialParallelInterface(ctx, dialer, network, address.String(), strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
} else { } else {
conn, isPrimary, err = d.dialParallelInterfaceFastFallback(ctx, dialer, network, address.String(), *strategy, interfaceType, fallbackInterfaceType, fallbackDelay, d.networkLastFallback.Store) conn, isPrimary, err = d.dialParallelInterfaceFastFallback(ctx, dialer, network, address.String(), strategy, interfaceType, fallbackInterfaceType, fallbackDelay, d.networkLastFallback.Store)
} }
if err != nil { if err != nil {
// bind interface failed on legacy xiaomi systems return nil, err
if d.defaultNetworkStrategy && errors.Is(err, syscall.EPERM) {
d.networkStrategy = nil
return d.DialContext(ctx, network, address)
} else {
return nil, err
}
} }
if !fastFallback && !isPrimary { if !fastFallback && !isPrimary {
d.networkLastFallback.Store(time.Now()) d.networkLastFallback.Store(time.Now())
@ -300,74 +251,35 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin
} }
func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if d.networkStrategy == nil { if d.networkStrategy == C.NetworkStrategyDefault {
return trackPacketConn(listener.ListenNetworkNamespace[net.PacketConn](d.netns, func() (net.PacketConn, error) { if destination.IsIPv6() {
if destination.IsIPv6() { return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
return d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6) } else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() { return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
return d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4) } else {
} else { return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
return d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4) }
}
}))
} else { } else {
return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay) return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
} }
} }
func (d *DefaultDialer) DialerForICMPDestination(destination netip.Addr) net.Dialer { func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
if !destination.Is6() { if strategy == C.NetworkStrategyDefault {
return dialerFromTCPDialer(d.dialer6)
} else {
return dialerFromTCPDialer(d.dialer4)
}
}
func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
if strategy == nil {
strategy = d.networkStrategy
}
if strategy == nil {
return d.ListenPacket(ctx, destination) return d.ListenPacket(ctx, destination)
} }
if len(interfaceType) == 0 { if !d.networkManager.AutoDetectInterface() {
interfaceType = d.networkType return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`")
}
if len(fallbackInterfaceType) == 0 {
fallbackInterfaceType = d.fallbackNetworkType
}
if fallbackDelay == 0 {
fallbackDelay = d.networkFallbackDelay
} }
network := N.NetworkUDP network := N.NetworkUDP
if destination.IsIPv4() && !destination.Addr.IsUnspecified() { if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
network += "4" network += "4"
} }
packetConn, err := d.listenSerialInterfacePacket(ctx, d.udpListener, network, "", *strategy, interfaceType, fallbackInterfaceType, fallbackDelay) return trackPacketConn(d.listenSerialInterfacePacket(ctx, d.udpListener, network, "", strategy, interfaceType, fallbackInterfaceType, fallbackDelay))
if err != nil {
// bind interface failed on legacy xiaomi systems
if d.defaultNetworkStrategy && errors.Is(err, syscall.EPERM) {
d.networkStrategy = nil
return d.ListenPacket(ctx, destination)
} else {
return nil, err
}
}
return trackPacketConn(packetConn, nil)
} }
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) { func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
udpListener := d.udpListener return d.udpListener.ListenPacket(context.Background(), network, address)
udpListener.Control = control.Append(udpListener.Control, func(network, address string, conn syscall.RawConn) error {
for _, wgControlFn := range WgControlFns {
err := wgControlFn(network, address, conn)
if err != nil {
return err
}
}
return nil
})
return udpListener.ListenPacket(context.Background(), network, address)
} }
func trackConn(conn net.Conn, err error) (net.Conn, error) { func trackConn(conn net.Conn, err error) (net.Conn, error) {

View File

@ -18,7 +18,6 @@ func (d *DefaultDialer) dialParallelInterface(ctx context.Context, dialer net.Di
if len(primaryInterfaces)+len(fallbackInterfaces) == 0 { if len(primaryInterfaces)+len(fallbackInterfaces) == 0 {
return nil, false, E.New("no available network interface") return nil, false, E.New("no available network interface")
} }
defaultInterface := d.networkManager.InterfaceMonitor().DefaultInterface()
if fallbackDelay == 0 { if fallbackDelay == 0 {
fallbackDelay = N.DefaultFallbackDelay fallbackDelay = N.DefaultFallbackDelay
} }
@ -32,18 +31,16 @@ func (d *DefaultDialer) dialParallelInterface(ctx context.Context, dialer net.Di
results := make(chan dialResult) // unbuffered results := make(chan dialResult) // unbuffered
startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) { startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) {
perNetDialer := dialer perNetDialer := dialer
if defaultInterface == nil || iif.Index != defaultInterface.Index { perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index))
perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index))
}
conn, err := perNetDialer.DialContext(ctx, network, addr) conn, err := perNetDialer.DialContext(ctx, network, addr)
if err != nil { if err != nil {
select { select {
case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Index, ")"), primary: primary}: case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Name, ")"), primary: primary}:
case <-returned: case <-returned:
} }
} else { } else {
select { select {
case results <- dialResult{Conn: conn, primary: primary}: case results <- dialResult{Conn: conn}:
case <-returned: case <-returned:
conn.Close() conn.Close()
} }
@ -92,7 +89,6 @@ func (d *DefaultDialer) dialParallelInterfaceFastFallback(ctx context.Context, d
if len(primaryInterfaces)+len(fallbackInterfaces) == 0 { if len(primaryInterfaces)+len(fallbackInterfaces) == 0 {
return nil, false, E.New("no available network interface") return nil, false, E.New("no available network interface")
} }
defaultInterface := d.networkManager.InterfaceMonitor().DefaultInterface()
if fallbackDelay == 0 { if fallbackDelay == 0 {
fallbackDelay = N.DefaultFallbackDelay fallbackDelay = N.DefaultFallbackDelay
} }
@ -107,18 +103,16 @@ func (d *DefaultDialer) dialParallelInterfaceFastFallback(ctx context.Context, d
results := make(chan dialResult) // unbuffered results := make(chan dialResult) // unbuffered
startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) { startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) {
perNetDialer := dialer perNetDialer := dialer
if defaultInterface == nil || iif.Index != defaultInterface.Index { perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index))
perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index))
}
conn, err := perNetDialer.DialContext(ctx, network, addr) conn, err := perNetDialer.DialContext(ctx, network, addr)
if err != nil { if err != nil {
select { select {
case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Index, ")"), primary: primary}: case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Name, ")"), primary: primary}:
case <-returned: case <-returned:
} }
} else { } else {
select { select {
case results <- dialResult{Conn: conn, primary: primary}: case results <- dialResult{Conn: conn}:
case <-returned: case <-returned:
if primary && time.Since(startAt) <= fallbackDelay { if primary && time.Since(startAt) <= fallbackDelay {
resetFastFallback(time.Time{}) resetFastFallback(time.Time{})
@ -155,29 +149,24 @@ func (d *DefaultDialer) listenSerialInterfacePacket(ctx context.Context, listene
if len(primaryInterfaces)+len(fallbackInterfaces) == 0 { if len(primaryInterfaces)+len(fallbackInterfaces) == 0 {
return nil, E.New("no available network interface") return nil, E.New("no available network interface")
} }
defaultInterface := d.networkManager.InterfaceMonitor().DefaultInterface()
var errors []error var errors []error
for _, primaryInterface := range primaryInterfaces { for _, primaryInterface := range primaryInterfaces {
perNetListener := listener perNetListener := listener
if defaultInterface == nil || primaryInterface.Index != defaultInterface.Index { perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, primaryInterface.Name, primaryInterface.Index))
perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, primaryInterface.Name, primaryInterface.Index))
}
conn, err := perNetListener.ListenPacket(ctx, network, addr) conn, err := perNetListener.ListenPacket(ctx, network, addr)
if err == nil { if err == nil {
return conn, nil return conn, nil
} }
errors = append(errors, E.Cause(err, "listen ", primaryInterface.Name, " (", primaryInterface.Index, ")")) errors = append(errors, E.Cause(err, "listen ", primaryInterface.Name, " (", primaryInterface.Name, ")"))
} }
for _, fallbackInterface := range fallbackInterfaces { for _, fallbackInterface := range fallbackInterfaces {
perNetListener := listener perNetListener := listener
if defaultInterface == nil || fallbackInterface.Index != defaultInterface.Index { perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, fallbackInterface.Name, fallbackInterface.Index))
perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, fallbackInterface.Name, fallbackInterface.Index))
}
conn, err := perNetListener.ListenPacket(ctx, network, addr) conn, err := perNetListener.ListenPacket(ctx, network, addr)
if err == nil { if err == nil {
return conn, nil return conn, nil
} }
errors = append(errors, E.Cause(err, "listen ", fallbackInterface.Name, " (", fallbackInterface.Index, ")")) errors = append(errors, E.Cause(err, "listen ", fallbackInterface.Name, " (", fallbackInterface.Name, ")"))
} }
return nil, E.Errors(errors...) return nil, E.Errors(errors...)
} }
@ -188,57 +177,44 @@ func selectInterfaces(networkManager adapter.NetworkManager, strategy C.NetworkS
case C.NetworkStrategyDefault: case C.NetworkStrategyDefault:
if len(interfaceType) == 0 { if len(interfaceType) == 0 {
defaultIf := networkManager.InterfaceMonitor().DefaultInterface() defaultIf := networkManager.InterfaceMonitor().DefaultInterface()
if defaultIf != nil { for _, iif := range interfaces {
for _, iif := range interfaces { if iif.Index == defaultIf.Index {
if iif.Index == defaultIf.Index { primaryInterfaces = append(primaryInterfaces, iif)
primaryInterfaces = append(primaryInterfaces, iif) } else {
} fallbackInterfaces = append(fallbackInterfaces, iif)
} }
} else {
primaryInterfaces = interfaces
} }
} else { } else {
primaryInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool { primaryInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool {
return common.Contains(interfaceType, it.Type) return common.Contains(interfaceType, iif.Type)
}) })
} }
case C.NetworkStrategyHybrid: case C.NetworkStrategyHybrid:
if len(interfaceType) == 0 { if len(interfaceType) == 0 {
primaryInterfaces = interfaces primaryInterfaces = interfaces
} else { } else {
primaryInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool { primaryInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool {
return common.Contains(interfaceType, it.Type) return common.Contains(interfaceType, iif.Type)
}) })
} }
case C.NetworkStrategyFallback: case C.NetworkStrategyFallback:
if len(interfaceType) == 0 { if len(interfaceType) == 0 {
defaultIf := networkManager.InterfaceMonitor().DefaultInterface() defaultIf := networkManager.InterfaceMonitor().DefaultInterface()
if defaultIf != nil { for _, iif := range interfaces {
for _, iif := range interfaces { if iif.Index == defaultIf.Index {
if iif.Index == defaultIf.Index { primaryInterfaces = append(primaryInterfaces, iif)
primaryInterfaces = append(primaryInterfaces, iif) } else {
break fallbackInterfaces = append(fallbackInterfaces, iif)
}
} }
} else {
primaryInterfaces = interfaces
} }
} else { } else {
primaryInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool { primaryInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool {
return common.Contains(interfaceType, it.Type) return common.Contains(interfaceType, iif.Type)
})
}
if len(fallbackInterfaceType) == 0 {
fallbackInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool {
return !common.Any(primaryInterfaces, func(iif adapter.NetworkInterface) bool {
return it.Index == iif.Index
})
})
} else {
fallbackInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool {
return common.Contains(fallbackInterfaceType, iif.Type)
}) })
} }
fallbackInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool {
return common.Contains(fallbackInterfaceType, iif.Type)
})
} }
return primaryInterfaces, fallbackInterfaces return primaryInterfaces, fallbackInterfaces
} }

View File

@ -13,13 +13,7 @@ import (
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
func DialSerialNetwork(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { func DialSerialNetwork(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
if len(destinationAddresses) == 0 {
if !destination.IsIP() {
panic("invalid usage")
}
destinationAddresses = []netip.Addr{destination.Addr}
}
if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel { if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
return parallelDialer.DialParallelNetwork(ctx, network, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) return parallelDialer.DialParallelNetwork(ctx, network, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
} }
@ -44,14 +38,7 @@ func DialSerialNetwork(ctx context.Context, dialer N.Dialer, network string, des
return nil, E.Errors(errors...) return nil, E.Errors(errors...)
} }
func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
if len(destinationAddresses) == 0 {
if !destination.IsIP() {
panic("invalid usage")
}
destinationAddresses = []netip.Addr{destination.Addr}
}
if fallbackDelay == 0 { if fallbackDelay == 0 {
fallbackDelay = N.DefaultFallbackDelay fallbackDelay = N.DefaultFallbackDelay
} }
@ -129,13 +116,7 @@ func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, ne
} }
} }
func ListenSerialNetworkPacket(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { func ListenSerialNetworkPacket(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) {
if len(destinationAddresses) == 0 {
if !destination.IsIP() {
panic("invalid usage")
}
destinationAddresses = []netip.Addr{destination.Addr}
}
if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel { if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
return parallelDialer.ListenSerialNetworkPacket(ctx, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) return parallelDialer.ListenSerialNetworkPacket(ctx, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
} }

View File

@ -6,61 +6,37 @@ import (
"sync" "sync"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
type DirectDialer interface {
IsEmpty() bool
}
type DetourDialer struct { type DetourDialer struct {
outboundManager adapter.OutboundManager outboundManager adapter.OutboundManager
detour string detour string
legacyDNSDialer bool
dialer N.Dialer dialer N.Dialer
initOnce sync.Once initOnce sync.Once
initErr error initErr error
} }
func NewDetour(outboundManager adapter.OutboundManager, detour string, legacyDNSDialer bool) N.Dialer { func NewDetour(outboundManager adapter.OutboundManager, detour string) N.Dialer {
return &DetourDialer{ return &DetourDialer{outboundManager: outboundManager, detour: detour}
outboundManager: outboundManager,
detour: detour,
legacyDNSDialer: legacyDNSDialer,
}
} }
func InitializeDetour(dialer N.Dialer) error { func (d *DetourDialer) Start() error {
detourDialer, isDetour := common.Cast[*DetourDialer](dialer) _, err := d.Dialer()
if !isDetour { return err
return nil
}
return common.Error(detourDialer.Dialer())
} }
func (d *DetourDialer) Dialer() (N.Dialer, error) { func (d *DetourDialer) Dialer() (N.Dialer, error) {
d.initOnce.Do(d.init) d.initOnce.Do(func() {
return d.dialer, d.initErr var loaded bool
} d.dialer, loaded = d.outboundManager.Outbound(d.detour)
if !loaded {
func (d *DetourDialer) init() { d.initErr = E.New("outbound detour not found: ", d.detour)
dialer, loaded := d.outboundManager.Outbound(d.detour)
if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour)
return
}
if !d.legacyDNSDialer {
if directDialer, isDirect := dialer.(DirectDialer); isDirect {
if directDialer.IsEmpty() {
d.initErr = E.New("detour to an empty direct outbound makes no sense")
return
}
} }
} })
d.dialer = dialer return d.dialer, d.initErr
} }
func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {

View File

@ -8,140 +8,80 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
) )
type Options struct { func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) {
Context context.Context networkManager := service.FromContext[adapter.NetworkManager](ctx)
Options option.DialerOptions if options.IsWireGuardListener {
RemoteIsDomain bool return NewDefault(networkManager, options)
DirectResolver bool }
ResolverOnDetour bool
NewDialer bool
LegacyDNSDialer bool
DirectOutbound bool
}
// TODO: merge with NewWithOptions
func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) {
return NewWithOptions(Options{
Context: ctx,
Options: options,
RemoteIsDomain: remoteIsDomain,
})
}
func NewWithOptions(options Options) (N.Dialer, error) {
dialOptions := options.Options
var ( var (
dialer N.Dialer dialer N.Dialer
err error err error
) )
if dialOptions.Detour != "" { if options.Detour == "" {
outboundManager := service.FromContext[adapter.OutboundManager](options.Context) dialer, err = NewDefault(networkManager, options)
if outboundManager == nil {
return nil, E.New("missing outbound manager")
}
dialer = NewDetour(outboundManager, dialOptions.Detour, options.LegacyDNSDialer)
} else {
dialer, err = NewDefault(options.Context, dialOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else {
outboundManager := service.FromContext[adapter.OutboundManager](ctx)
if outboundManager == nil {
return nil, E.New("missing outbound manager")
}
dialer = NewDetour(outboundManager, options.Detour)
} }
if options.RemoteIsDomain && (dialOptions.Detour == "" || options.ResolverOnDetour || dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "") { if networkManager == nil {
networkManager := service.FromContext[adapter.NetworkManager](options.Context) return NewDefault(networkManager, options)
dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context) }
var defaultOptions adapter.NetworkOptions if options.Detour == "" {
if networkManager != nil { router := service.FromContext[adapter.Router](ctx)
defaultOptions = networkManager.DefaultOptions() if router != nil {
dialer = NewResolveDialer(
router,
dialer,
options.Detour == "" && !options.TCPFastOpen,
dns.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay))
} }
var (
server string
dnsQueryOptions adapter.DNSQueryOptions
resolveFallbackDelay time.Duration
)
if dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "" {
var transport adapter.DNSTransport
if !options.DirectResolver {
var loaded bool
transport, loaded = dnsTransport.Transport(dialOptions.DomainResolver.Server)
if !loaded {
return nil, E.New("domain resolver not found: " + dialOptions.DomainResolver.Server)
}
}
var strategy C.DomainStrategy
if dialOptions.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) {
strategy = C.DomainStrategy(dialOptions.DomainResolver.Strategy)
} else if
//nolint:staticcheck
dialOptions.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
//nolint:staticcheck
strategy = C.DomainStrategy(dialOptions.DomainStrategy)
deprecated.Report(options.Context, deprecated.OptionLegacyDomainStrategyOptions)
}
server = dialOptions.DomainResolver.Server
dnsQueryOptions = adapter.DNSQueryOptions{
Transport: transport,
Strategy: strategy,
DisableCache: dialOptions.DomainResolver.DisableCache,
RewriteTTL: dialOptions.DomainResolver.RewriteTTL,
ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
}
resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
} else if options.DirectResolver {
return nil, E.New("missing domain resolver for domain server address")
} else {
if defaultOptions.DomainResolver != "" {
dnsQueryOptions = defaultOptions.DomainResolveOptions
transport, loaded := dnsTransport.Transport(defaultOptions.DomainResolver)
if !loaded {
return nil, E.New("default domain resolver not found: " + defaultOptions.DomainResolver)
}
dnsQueryOptions.Transport = transport
resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
} else {
transports := dnsTransport.Transports()
if len(transports) < 2 {
dnsQueryOptions.Transport = dnsTransport.Default()
} else if options.NewDialer {
return nil, E.New("missing domain resolver for domain server address")
} else {
deprecated.Report(options.Context, deprecated.OptionMissingDomainResolver)
}
}
if
//nolint:staticcheck
dialOptions.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
//nolint:staticcheck
dnsQueryOptions.Strategy = C.DomainStrategy(dialOptions.DomainStrategy)
deprecated.Report(options.Context, deprecated.OptionLegacyDomainStrategyOptions)
}
}
dialer = NewResolveDialer(
options.Context,
dialer,
dialOptions.Detour == "" && !dialOptions.TCPFastOpen,
server,
dnsQueryOptions,
resolveFallbackDelay,
)
} }
return dialer, nil return dialer, nil
} }
func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInterfaceDialer, error) {
if options.Detour != "" {
return nil, E.New("`detour` is not supported in direct context")
}
networkManager := service.FromContext[adapter.NetworkManager](ctx)
if options.IsWireGuardListener {
return NewDefault(networkManager, options)
}
dialer, err := NewDefault(networkManager, options)
if err != nil {
return nil, err
}
return NewResolveParallelInterfaceDialer(
service.FromContext[adapter.Router](ctx),
dialer,
true,
dns.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay),
), nil
}
type ParallelInterfaceDialer interface { type ParallelInterfaceDialer interface {
N.Dialer N.Dialer
DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error)
ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error)
} }
type ParallelNetworkDialer interface { type ParallelNetworkDialer interface {
DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error)
ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error)
} }

View File

@ -3,17 +3,16 @@ package dialer
import ( import (
"context" "context"
"net" "net"
"sync" "net/netip"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
) )
var ( var (
@ -21,51 +20,21 @@ var (
_ ParallelInterfaceDialer = (*resolveParallelNetworkDialer)(nil) _ ParallelInterfaceDialer = (*resolveParallelNetworkDialer)(nil)
) )
type ResolveDialer interface {
N.Dialer
QueryOptions() adapter.DNSQueryOptions
}
type ParallelInterfaceResolveDialer interface {
ParallelInterfaceDialer
QueryOptions() adapter.DNSQueryOptions
}
type resolveDialer struct { type resolveDialer struct {
transport adapter.DNSTransportManager
router adapter.DNSRouter
dialer N.Dialer dialer N.Dialer
parallel bool parallel bool
server string router adapter.Router
initOnce sync.Once strategy dns.DomainStrategy
initErr error
queryOptions adapter.DNSQueryOptions
fallbackDelay time.Duration fallbackDelay time.Duration
} }
func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer { func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) N.Dialer {
if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel {
return &resolveParallelNetworkDialer{
resolveDialer{
transport: service.FromContext[adapter.DNSTransportManager](ctx),
router: service.FromContext[adapter.DNSRouter](ctx),
dialer: dialer,
parallel: parallel,
server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
},
parallelDialer,
}
}
return &resolveDialer{ return &resolveDialer{
transport: service.FromContext[adapter.DNSTransportManager](ctx), dialer,
router: service.FromContext[adapter.DNSRouter](ctx), parallel,
dialer: dialer, router,
parallel: parallel, strategy,
server: server, fallbackDelay,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
} }
} }
@ -74,53 +43,59 @@ type resolveParallelNetworkDialer struct {
dialer ParallelInterfaceDialer dialer ParallelInterfaceDialer
} }
func (d *resolveDialer) initialize() error { func NewResolveParallelInterfaceDialer(router adapter.Router, dialer ParallelInterfaceDialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer {
d.initOnce.Do(d.initServer) return &resolveParallelNetworkDialer{
return d.initErr resolveDialer{
} dialer,
parallel,
func (d *resolveDialer) initServer() { router,
if d.server == "" { strategy,
return fallbackDelay,
},
dialer,
} }
transport, loaded := d.transport.Transport(d.server)
if !loaded {
d.initErr = E.New("domain resolver not found: " + d.server)
return
}
d.queryOptions.Transport = transport
} }
func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions) metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
if d.parallel { if d.parallel {
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.queryOptions.Strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay) return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, d.fallbackDelay)
} else { } else {
return N.DialSerial(ctx, d.dialer, network, destination, addresses) return N.DialSerial(ctx, d.dialer, network, destination, addresses)
} }
} }
func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions) metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -131,24 +106,21 @@ func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
} }
func (d *resolveDialer) QueryOptions() adapter.DNSQueryOptions { func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
return d.queryOptions
}
func (d *resolveDialer) Upstream() any {
return d.dialer
}
func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions) metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -156,28 +128,30 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
fallbackDelay = d.fallbackDelay fallbackDelay = d.fallbackDelay
} }
if d.parallel { if d.parallel {
return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.queryOptions.Strategy == C.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
} else { } else {
return DialSerialNetwork(ctx, d.dialer, network, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) return DialSerialNetwork(ctx, d.dialer, network, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
} }
} }
func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions) metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
if fallbackDelay == 0 {
fallbackDelay = d.fallbackDelay
}
conn, destinationAddress, err := ListenSerialNetworkPacket(ctx, d.dialer, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) conn, destinationAddress, err := ListenSerialNetworkPacket(ctx, d.dialer, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
if err != nil { if err != nil {
return nil, err return nil, err
@ -185,10 +159,6 @@ func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.C
return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
} }
func (d *resolveParallelNetworkDialer) QueryOptions() adapter.DNSQueryOptions { func (d *resolveDialer) Upstream() any {
return d.queryOptions
}
func (d *resolveParallelNetworkDialer) Upstream() any {
return d.dialer return d.dialer
} }

View File

@ -7,27 +7,24 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
) )
type DefaultOutboundDialer struct { type DefaultOutboundDialer struct {
outbound adapter.OutboundManager outboundManager adapter.OutboundManager
} }
func NewDefaultOutbound(ctx context.Context) N.Dialer { func NewDefaultOutbound(outboundManager adapter.OutboundManager) N.Dialer {
return &DefaultOutboundDialer{ return &DefaultOutboundDialer{outboundManager: outboundManager}
outbound: service.FromContext[adapter.OutboundManager](ctx),
}
} }
func (d *DefaultOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *DefaultOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return d.outbound.Default().DialContext(ctx, network, destination) return d.outboundManager.Default().DialContext(ctx, network, destination)
} }
func (d *DefaultOutboundDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *DefaultOutboundDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return d.outbound.Default().ListenPacket(ctx, destination) return d.outboundManager.Default().ListenPacket(ctx, destination)
} }
func (d *DefaultOutboundDialer) Upstream() any { func (d *DefaultOutboundDialer) Upstream() any {
return d.outbound.Default() return d.outboundManager.Default()
} }

View File

@ -8,11 +8,11 @@ import (
"net" "net"
"os" "os"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -24,11 +24,9 @@ type slowOpenConn struct {
ctx context.Context ctx context.Context
network string network string
destination M.Socksaddr destination M.Socksaddr
conn atomic.Pointer[net.TCPConn] conn net.Conn
create chan struct{} create chan struct{}
done chan struct{}
access sync.Mutex access sync.Mutex
closeOnce sync.Once
err error err error
} }
@ -47,30 +45,26 @@ func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, des
network: network, network: network,
destination: destination, destination: destination,
create: make(chan struct{}), create: make(chan struct{}),
done: make(chan struct{}),
}, nil }, nil
} }
func (c *slowOpenConn) Read(b []byte) (n int, err error) { func (c *slowOpenConn) Read(b []byte) (n int, err error) {
conn := c.conn.Load() if c.conn == nil {
if conn != nil { select {
return conn.Read(b) case <-c.create:
} if c.err != nil {
select { return 0, c.err
case <-c.create: }
if c.err != nil { case <-c.ctx.Done():
return 0, c.err return 0, c.ctx.Err()
} }
return c.conn.Load().Read(b)
case <-c.done:
return 0, os.ErrClosed
} }
return c.conn.Read(b)
} }
func (c *slowOpenConn) Write(b []byte) (n int, err error) { func (c *slowOpenConn) Write(b []byte) (n int, err error) {
tcpConn := c.conn.Load() if c.conn != nil {
if tcpConn != nil { return c.conn.Write(b)
return tcpConn.Write(b)
} }
c.access.Lock() c.access.Lock()
defer c.access.Unlock() defer c.access.Unlock()
@ -79,16 +73,13 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
if c.err != nil { if c.err != nil {
return 0, c.err return 0, c.err
} }
return c.conn.Load().Write(b) return c.conn.Write(b)
case <-c.done:
return 0, os.ErrClosed
default: default:
} }
conn, err := c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b) c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b)
if err != nil { if err != nil {
c.err = err c.conn = nil
} else { c.err = E.Cause(err, "dial tcp fast open")
c.conn.Store(conn.(*net.TCPConn))
} }
n = len(b) n = len(b)
close(c.create) close(c.create)
@ -96,87 +87,74 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
} }
func (c *slowOpenConn) Close() error { func (c *slowOpenConn) Close() error {
c.closeOnce.Do(func() { return common.Close(c.conn)
close(c.done)
conn := c.conn.Load()
if conn != nil {
conn.Close()
}
})
return nil
} }
func (c *slowOpenConn) LocalAddr() net.Addr { func (c *slowOpenConn) LocalAddr() net.Addr {
conn := c.conn.Load() if c.conn == nil {
if conn == nil {
return M.Socksaddr{} return M.Socksaddr{}
} }
return conn.LocalAddr() return c.conn.LocalAddr()
} }
func (c *slowOpenConn) RemoteAddr() net.Addr { func (c *slowOpenConn) RemoteAddr() net.Addr {
conn := c.conn.Load() if c.conn == nil {
if conn == nil {
return M.Socksaddr{} return M.Socksaddr{}
} }
return conn.RemoteAddr() return c.conn.RemoteAddr()
} }
func (c *slowOpenConn) SetDeadline(t time.Time) error { func (c *slowOpenConn) SetDeadline(t time.Time) error {
conn := c.conn.Load() if c.conn == nil {
if conn == nil {
return os.ErrInvalid return os.ErrInvalid
} }
return conn.SetDeadline(t) return c.conn.SetDeadline(t)
} }
func (c *slowOpenConn) SetReadDeadline(t time.Time) error { func (c *slowOpenConn) SetReadDeadline(t time.Time) error {
conn := c.conn.Load() if c.conn == nil {
if conn == nil {
return os.ErrInvalid return os.ErrInvalid
} }
return conn.SetReadDeadline(t) return c.conn.SetReadDeadline(t)
} }
func (c *slowOpenConn) SetWriteDeadline(t time.Time) error { func (c *slowOpenConn) SetWriteDeadline(t time.Time) error {
conn := c.conn.Load() if c.conn == nil {
if conn == nil {
return os.ErrInvalid return os.ErrInvalid
} }
return conn.SetWriteDeadline(t) return c.conn.SetWriteDeadline(t)
} }
func (c *slowOpenConn) Upstream() any { func (c *slowOpenConn) Upstream() any {
return common.PtrOrNil(c.conn.Load()) return c.conn
} }
func (c *slowOpenConn) ReaderReplaceable() bool { func (c *slowOpenConn) ReaderReplaceable() bool {
return c.conn.Load() != nil return c.conn != nil
} }
func (c *slowOpenConn) WriterReplaceable() bool { func (c *slowOpenConn) WriterReplaceable() bool {
return c.conn.Load() != nil return c.conn != nil
} }
func (c *slowOpenConn) LazyHeadroom() bool { func (c *slowOpenConn) LazyHeadroom() bool {
return c.conn.Load() == nil return c.conn == nil
} }
func (c *slowOpenConn) NeedHandshake() bool { func (c *slowOpenConn) NeedHandshake() bool {
return c.conn.Load() == nil return c.conn == nil
} }
func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) { func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
conn := c.conn.Load() if c.conn == nil {
if conn == nil {
select { select {
case <-c.create: case <-c.create:
if c.err != nil { if c.err != nil {
return 0, c.err return 0, c.err
} }
case <-c.done: case <-c.ctx.Done():
return 0, c.err return 0, c.ctx.Err()
} }
} }
return bufio.Copy(w, c.conn.Load()) return bufio.Copy(w, c.conn)
} }

158
common/humanize/bytes.go Normal file
View File

@ -0,0 +1,158 @@
package humanize
import (
"fmt"
"math"
"strconv"
"strings"
"unicode"
)
// IEC Sizes.
// kibis of bits
const (
Byte = 1 << (iota * 10)
KiByte
MiByte
GiByte
TiByte
PiByte
EiByte
)
// SI Sizes.
const (
IByte = 1
KByte = IByte * 1000
MByte = KByte * 1000
GByte = MByte * 1000
TByte = GByte * 1000
PByte = TByte * 1000
EByte = PByte * 1000
)
var defaultSizeTable = map[string]uint64{
"b": Byte,
"kib": KiByte,
"kb": KByte,
"mib": MiByte,
"mb": MByte,
"gib": GiByte,
"gb": GByte,
"tib": TiByte,
"tb": TByte,
"pib": PiByte,
"pb": PByte,
"eib": EiByte,
"eb": EByte,
// Without suffix
"": Byte,
"ki": KiByte,
"k": KByte,
"mi": MiByte,
"m": MByte,
"gi": GiByte,
"g": GByte,
"ti": TiByte,
"t": TByte,
"pi": PiByte,
"p": PByte,
"ei": EiByte,
"e": EByte,
}
var memorysSizeTable = map[string]uint64{
"b": Byte,
"kb": KiByte,
"mb": MiByte,
"gb": GiByte,
"tb": TiByte,
"pb": PiByte,
"eb": EiByte,
"": Byte,
"k": KiByte,
"m": MiByte,
"g": GiByte,
"t": TiByte,
"p": PiByte,
"e": EiByte,
}
var (
defaultSizes = []string{"B", "kB", "MB", "GB", "TB", "PB", "EB"}
iSizes = []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"}
)
func Bytes(s uint64) string {
return humanateBytes(s, 1000, defaultSizes)
}
func MemoryBytes(s uint64) string {
return humanateBytes(s, 1024, defaultSizes)
}
func IBytes(s uint64) string {
return humanateBytes(s, 1024, iSizes)
}
func logn(n, b float64) float64 {
return math.Log(n) / math.Log(b)
}
func humanateBytes(s uint64, base float64, sizes []string) string {
if s < 10 {
return fmt.Sprintf("%d B", s)
}
e := math.Floor(logn(float64(s), base))
suffix := sizes[int(e)]
val := math.Floor(float64(s)/math.Pow(base, e)*10+0.5) / 10
f := "%.0f %s"
if val < 10 {
f = "%.1f %s"
}
return fmt.Sprintf(f, val, suffix)
}
func ParseBytes(s string) (uint64, error) {
return parseBytes0(s, defaultSizeTable)
}
func ParseMemoryBytes(s string) (uint64, error) {
return parseBytes0(s, memorysSizeTable)
}
func parseBytes0(s string, sizeTable map[string]uint64) (uint64, error) {
lastDigit := 0
hasComma := false
for _, r := range s {
if !(unicode.IsDigit(r) || r == '.' || r == ',') {
break
}
if r == ',' {
hasComma = true
}
lastDigit++
}
num := s[:lastDigit]
if hasComma {
num = strings.Replace(num, ",", "", -1)
}
f, err := strconv.ParseFloat(num, 64)
if err != nil {
return 0, err
}
extra := strings.ToLower(strings.TrimSpace(s[lastDigit:]))
if m, ok := sizeTable[extra]; ok {
f *= float64(m)
if f >= math.MaxUint64 {
return 0, fmt.Errorf("too large: %v", s)
}
return uint64(f), nil
}
return 0, fmt.Errorf("unhandled size name: %v", extra)
}

View File

@ -1,84 +0,0 @@
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"crypto/tls"
"io"
"net"
"os"
"syscall"
"github.com/sagernet/sing-box/common/badtls"
// C "github.com/sagernet/sing-box/constant"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
)
type Conn struct {
aTLS.Conn
conn net.Conn
rawConn *badtls.RawConn
rawSyscallConn syscall.RawConn
readWaitOptions N.ReadWaitOptions
kernelTx bool
kernelRx bool
tmp [16]byte
}
func NewConn(conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
syscallConn, isSyscallConn := N.CastReader[interface {
io.Reader
syscall.Conn
}](conn.NetConn())
if !isSyscallConn {
return nil, os.ErrInvalid
}
rawSyscallConn, err := syscallConn.SyscallConn()
if err != nil {
return nil, err
}
rawConn, err := badtls.NewRawConn(conn)
if err != nil {
return nil, err
}
if *rawConn.Vers != tls.VersionTLS13 {
return nil, os.ErrInvalid
}
for rawConn.RawInput.Len() > 0 {
err = rawConn.ReadRecord()
if err != nil {
return nil, err
}
for rawConn.Hand.Len() > 0 {
err = rawConn.HandlePostHandshakeMessage()
if err != nil {
return nil, E.Cause(err, "ktls: failed to handle post-handshake messages")
}
}
}
kConn := &Conn{
Conn: conn,
conn: conn.NetConn(),
rawConn: rawConn,
rawSyscallConn: rawSyscallConn,
}
err = kConn.setupKernel(txOffload, rxOffload)
if err != nil {
return nil, err
}
return kConn, nil
}
func (c *Conn) Upstream() any {
return c.conn
}
func (c *Conn) ReaderReplaceable() bool {
return c.kernelRx
}
func (c *Conn) WriterReplaceable() bool {
return c.kernelTx
}

View File

@ -1,80 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"crypto/tls"
"net"
)
const (
// alert level
alertLevelWarning = 1
alertLevelError = 2
)
const (
alertCloseNotify = 0
alertUnexpectedMessage = 10
alertBadRecordMAC = 20
alertDecryptionFailed = 21
alertRecordOverflow = 22
alertDecompressionFailure = 30
alertHandshakeFailure = 40
alertBadCertificate = 42
alertUnsupportedCertificate = 43
alertCertificateRevoked = 44
alertCertificateExpired = 45
alertCertificateUnknown = 46
alertIllegalParameter = 47
alertUnknownCA = 48
alertAccessDenied = 49
alertDecodeError = 50
alertDecryptError = 51
alertExportRestriction = 60
alertProtocolVersion = 70
alertInsufficientSecurity = 71
alertInternalError = 80
alertInappropriateFallback = 86
alertUserCanceled = 90
alertNoRenegotiation = 100
alertMissingExtension = 109
alertUnsupportedExtension = 110
alertCertificateUnobtainable = 111
alertUnrecognizedName = 112
alertBadCertificateStatusResponse = 113
alertBadCertificateHashValue = 114
alertUnknownPSKIdentity = 115
alertCertificateRequired = 116
alertNoApplicationProtocol = 120
alertECHRequired = 121
)
func (c *Conn) sendAlertLocked(err uint8) error {
switch err {
case alertNoRenegotiation, alertCloseNotify:
c.tmp[0] = alertLevelWarning
default:
c.tmp[0] = alertLevelError
}
c.tmp[1] = byte(err)
_, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
if err == alertCloseNotify {
// closeNotify is a special case in that it isn't an error.
return writeErr
}
return c.rawConn.Out.SetErrorLocked(&net.OpError{Op: "local error", Err: tls.AlertError(err)})
}
// sendAlert sends a TLS alert message.
func (c *Conn) sendAlert(err uint8) error {
c.rawConn.Out.Lock()
defer c.rawConn.Out.Unlock()
return c.sendAlertLocked(err)
}

View File

@ -1,326 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"crypto/tls"
"unsafe"
"github.com/sagernet/sing-box/common/badtls"
)
type kernelCryptoCipherType uint16
const (
TLS_CIPHER_AES_GCM_128 kernelCryptoCipherType = 51
TLS_CIPHER_AES_GCM_128_IV_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_AES_GCM_128_KEY_SIZE kernelCryptoCipherType = 16
TLS_CIPHER_AES_GCM_128_SALT_SIZE kernelCryptoCipherType = 4
TLS_CIPHER_AES_GCM_128_TAG_SIZE kernelCryptoCipherType = 16
TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_AES_GCM_256 kernelCryptoCipherType = 52
TLS_CIPHER_AES_GCM_256_IV_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_AES_GCM_256_KEY_SIZE kernelCryptoCipherType = 32
TLS_CIPHER_AES_GCM_256_SALT_SIZE kernelCryptoCipherType = 4
TLS_CIPHER_AES_GCM_256_TAG_SIZE kernelCryptoCipherType = 16
TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_AES_CCM_128 kernelCryptoCipherType = 53
TLS_CIPHER_AES_CCM_128_IV_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_AES_CCM_128_KEY_SIZE kernelCryptoCipherType = 16
TLS_CIPHER_AES_CCM_128_SALT_SIZE kernelCryptoCipherType = 4
TLS_CIPHER_AES_CCM_128_TAG_SIZE kernelCryptoCipherType = 16
TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_CHACHA20_POLY1305 kernelCryptoCipherType = 54
TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE kernelCryptoCipherType = 12
TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE kernelCryptoCipherType = 32
TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE kernelCryptoCipherType = 0
TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE kernelCryptoCipherType = 16
TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE kernelCryptoCipherType = 8
// TLS_CIPHER_SM4_GCM kernelCryptoCipherType = 55
// TLS_CIPHER_SM4_GCM_IV_SIZE kernelCryptoCipherType = 8
// TLS_CIPHER_SM4_GCM_KEY_SIZE kernelCryptoCipherType = 16
// TLS_CIPHER_SM4_GCM_SALT_SIZE kernelCryptoCipherType = 4
// TLS_CIPHER_SM4_GCM_TAG_SIZE kernelCryptoCipherType = 16
// TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE kernelCryptoCipherType = 8
// TLS_CIPHER_SM4_CCM kernelCryptoCipherType = 56
// TLS_CIPHER_SM4_CCM_IV_SIZE kernelCryptoCipherType = 8
// TLS_CIPHER_SM4_CCM_KEY_SIZE kernelCryptoCipherType = 16
// TLS_CIPHER_SM4_CCM_SALT_SIZE kernelCryptoCipherType = 4
// TLS_CIPHER_SM4_CCM_TAG_SIZE kernelCryptoCipherType = 16
// TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_ARIA_GCM_128 kernelCryptoCipherType = 57
TLS_CIPHER_ARIA_GCM_128_IV_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_ARIA_GCM_128_KEY_SIZE kernelCryptoCipherType = 16
TLS_CIPHER_ARIA_GCM_128_SALT_SIZE kernelCryptoCipherType = 4
TLS_CIPHER_ARIA_GCM_128_TAG_SIZE kernelCryptoCipherType = 16
TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_ARIA_GCM_256 kernelCryptoCipherType = 58
TLS_CIPHER_ARIA_GCM_256_IV_SIZE kernelCryptoCipherType = 8
TLS_CIPHER_ARIA_GCM_256_KEY_SIZE kernelCryptoCipherType = 32
TLS_CIPHER_ARIA_GCM_256_SALT_SIZE kernelCryptoCipherType = 4
TLS_CIPHER_ARIA_GCM_256_TAG_SIZE kernelCryptoCipherType = 16
TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE kernelCryptoCipherType = 8
)
type kernelCrypto interface {
String() string
}
type kernelCryptoInfo struct {
version uint16
cipher_type kernelCryptoCipherType
}
var _ kernelCrypto = &kernelCryptoAES128GCM{}
type kernelCryptoAES128GCM struct {
kernelCryptoInfo
iv [TLS_CIPHER_AES_GCM_128_IV_SIZE]byte
key [TLS_CIPHER_AES_GCM_128_KEY_SIZE]byte
salt [TLS_CIPHER_AES_GCM_128_SALT_SIZE]byte
rec_seq [TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE]byte
}
func (crypto *kernelCryptoAES128GCM) String() string {
crypto.cipher_type = TLS_CIPHER_AES_GCM_128
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
}
var _ kernelCrypto = &kernelCryptoAES256GCM{}
type kernelCryptoAES256GCM struct {
kernelCryptoInfo
iv [TLS_CIPHER_AES_GCM_256_IV_SIZE]byte
key [TLS_CIPHER_AES_GCM_256_KEY_SIZE]byte
salt [TLS_CIPHER_AES_GCM_256_SALT_SIZE]byte
rec_seq [TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE]byte
}
func (crypto *kernelCryptoAES256GCM) String() string {
crypto.cipher_type = TLS_CIPHER_AES_GCM_256
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
}
var _ kernelCrypto = &kernelCryptoAES128CCM{}
type kernelCryptoAES128CCM struct {
kernelCryptoInfo
iv [TLS_CIPHER_AES_CCM_128_IV_SIZE]byte
key [TLS_CIPHER_AES_CCM_128_KEY_SIZE]byte
salt [TLS_CIPHER_AES_CCM_128_SALT_SIZE]byte
rec_seq [TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE]byte
}
func (crypto *kernelCryptoAES128CCM) String() string {
crypto.cipher_type = TLS_CIPHER_AES_CCM_128
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
}
var _ kernelCrypto = &kernelCryptoChacha20Poly1035{}
type kernelCryptoChacha20Poly1035 struct {
kernelCryptoInfo
iv [TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE]byte
key [TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE]byte
salt [TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE]byte
rec_seq [TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE]byte
}
func (crypto *kernelCryptoChacha20Poly1035) String() string {
crypto.cipher_type = TLS_CIPHER_CHACHA20_POLY1305
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
}
// var _ kernelCrypto = &kernelCryptoSM4GCM{}
// type kernelCryptoSM4GCM struct {
// kernelCryptoInfo
// iv [TLS_CIPHER_SM4_GCM_IV_SIZE]byte
// key [TLS_CIPHER_SM4_GCM_KEY_SIZE]byte
// salt [TLS_CIPHER_SM4_GCM_SALT_SIZE]byte
// rec_seq [TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE]byte
// }
// func (crypto *kernelCryptoSM4GCM) String() string {
// crypto.cipher_type = TLS_CIPHER_SM4_GCM
// return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
// }
// var _ kernelCrypto = &kernelCryptoSM4CCM{}
// type kernelCryptoSM4CCM struct {
// kernelCryptoInfo
// iv [TLS_CIPHER_SM4_CCM_IV_SIZE]byte
// key [TLS_CIPHER_SM4_CCM_KEY_SIZE]byte
// salt [TLS_CIPHER_SM4_CCM_SALT_SIZE]byte
// rec_seq [TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE]byte
// }
// func (crypto *kernelCryptoSM4CCM) String() string {
// crypto.cipher_type = TLS_CIPHER_SM4_CCM
// return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
// }
var _ kernelCrypto = &kernelCryptoARIA128GCM{}
type kernelCryptoARIA128GCM struct {
kernelCryptoInfo
iv [TLS_CIPHER_ARIA_GCM_128_IV_SIZE]byte
key [TLS_CIPHER_ARIA_GCM_128_KEY_SIZE]byte
salt [TLS_CIPHER_ARIA_GCM_128_SALT_SIZE]byte
rec_seq [TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE]byte
}
func (crypto *kernelCryptoARIA128GCM) String() string {
crypto.cipher_type = TLS_CIPHER_ARIA_GCM_128
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
}
var _ kernelCrypto = &kernelCryptoARIA256GCM{}
type kernelCryptoARIA256GCM struct {
kernelCryptoInfo
iv [TLS_CIPHER_ARIA_GCM_256_IV_SIZE]byte
key [TLS_CIPHER_ARIA_GCM_256_KEY_SIZE]byte
salt [TLS_CIPHER_ARIA_GCM_256_SALT_SIZE]byte
rec_seq [TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE]byte
}
func (crypto *kernelCryptoARIA256GCM) String() string {
crypto.cipher_type = TLS_CIPHER_ARIA_GCM_256
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
}
func kernelCipher(kernel *Support, hc *badtls.RawHalfConn, cipherSuite uint16, isRX bool) kernelCrypto {
if !kernel.TLS {
return nil
}
switch *hc.Version {
case tls.VersionTLS12:
if isRX && !kernel.TLS_Version13_RX {
return nil
}
case tls.VersionTLS13:
if !kernel.TLS_Version13 {
return nil
}
if isRX && !kernel.TLS_Version13_RX {
return nil
}
default:
return nil
}
var key, iv []byte
if *hc.Version == tls.VersionTLS13 {
key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), *hc.TrafficSecret)
/*if isRX {
key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), keyLog.RemoteTrafficSecret)
} else {
key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), keyLog.TrafficSecret)
}*/
} else {
// csPtr := cipherSuiteByID(cipherSuite)
// keysFromMasterSecret(*hc.Version, csPtr, keyLog.Secret, keyLog.Random)
return nil
}
switch cipherSuite {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
crypto := new(kernelCryptoAES128GCM)
crypto.version = *hc.Version
copy(crypto.key[:], key)
copy(crypto.iv[:], iv[4:])
copy(crypto.salt[:], iv[:4])
crypto.rec_seq = *hc.Seq
return crypto
case tls.TLS_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
if !kernel.TLS_AES_256_GCM {
return nil
}
crypto := new(kernelCryptoAES256GCM)
crypto.version = *hc.Version
copy(crypto.key[:], key)
copy(crypto.iv[:], iv[4:])
copy(crypto.salt[:], iv[:4])
crypto.rec_seq = *hc.Seq
return crypto
//case tls.TLS_AES_128_CCM_SHA256, tls.TLS_RSA_WITH_AES_128_CCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_SHA256:
// if !kernel.TLS_AES_128_CCM {
// return nil
// }
//
// crypto := new(kernelCryptoAES128CCM)
//
// crypto.version = *hc.Version
// copy(crypto.key[:], key)
// copy(crypto.iv[:], iv[4:])
// copy(crypto.salt[:], iv[:4])
// crypto.rec_seq = *hc.Seq
//
// return crypto
case tls.TLS_CHACHA20_POLY1305_SHA256, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
if !kernel.TLS_CHACHA20_POLY1305 {
return nil
}
crypto := new(kernelCryptoChacha20Poly1035)
crypto.version = *hc.Version
copy(crypto.key[:], key)
copy(crypto.iv[:], iv)
crypto.rec_seq = *hc.Seq
return crypto
//case tls.TLS_RSA_WITH_ARIA_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256:
// if !kernel.TLS_ARIA_GCM {
// return nil
// }
//
// crypto := new(kernelCryptoARIA128GCM)
//
// crypto.version = *hc.Version
// copy(crypto.key[:], key)
// copy(crypto.iv[:], iv[4:])
// copy(crypto.salt[:], iv[:4])
// crypto.rec_seq = *hc.Seq
//
// return crypto
//case tls.TLS_RSA_WITH_ARIA_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384:
// if !kernel.TLS_ARIA_GCM {
// return nil
// }
//
// crypto := new(kernelCryptoARIA256GCM)
//
// crypto.version = *hc.Version
// copy(crypto.key[:], key)
// copy(crypto.iv[:], iv[4:])
// copy(crypto.salt[:], iv[:4])
// crypto.rec_seq = *hc.Seq
//
// return crypto
default:
return nil
}
}

View File

@ -1,67 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"fmt"
"net"
"time"
)
func (c *Conn) Close() error {
if !c.kernelTx {
return c.Conn.Close()
}
// Interlock with Conn.Write above.
var x int32
for {
x = c.rawConn.ActiveCall.Load()
if x&1 != 0 {
return net.ErrClosed
}
if c.rawConn.ActiveCall.CompareAndSwap(x, x|1) {
break
}
}
if x != 0 {
// io.Writer and io.Closer should not be used concurrently.
// If Close is called while a Write is currently in-flight,
// interpret that as a sign that this Close is really just
// being used to break the Write and/or clean up resources and
// avoid sending the alertCloseNotify, which may block
// waiting on handshakeMutex or the c.out mutex.
return c.conn.Close()
}
var alertErr error
if c.rawConn.IsHandshakeComplete.Load() {
if err := c.closeNotify(); err != nil {
alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
}
}
if err := c.conn.Close(); err != nil {
return err
}
return alertErr
}
func (c *Conn) closeNotify() error {
c.rawConn.Out.Lock()
defer c.rawConn.Out.Unlock()
if !*c.rawConn.CloseNotifySent {
// Set a Write Deadline to prevent possibly blocking forever.
c.SetWriteDeadline(time.Now().Add(time.Second * 5))
*c.rawConn.CloseNotifyErr = c.sendAlertLocked(alertCloseNotify)
*c.rawConn.CloseNotifySent = true
// Any subsequent writes will fail.
c.SetWriteDeadline(time.Now())
}
return *c.rawConn.CloseNotifyErr
}

View File

@ -1,24 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
const (
maxPlaintext = 16384 // maximum plaintext payload length
maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3
recordHeaderLen = 5 // record header length
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
maxHandshakeCertificateMsg = 262144 // maximum certificate message size (256 KiB)
maxUselessRecords = 16 // maximum number of consecutive non-advancing records
)
const (
recordTypeChangeCipherSpec = 20
recordTypeAlert = 21
recordTypeHandshake = 22
recordTypeApplicationData = 23
)

View File

@ -1,238 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"fmt"
"golang.org/x/crypto/cryptobyte"
)
// The marshalingFunction type is an adapter to allow the use of ordinary
// functions as cryptobyte.MarshalingValue.
type marshalingFunction func(b *cryptobyte.Builder) error
func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
return f(b)
}
// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
// the length of the sequence is not the value specified, it produces an error.
func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
if len(v) != n {
return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
}
b.AddBytes(v)
return nil
}))
}
// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
func addUint64(b *cryptobyte.Builder, v uint64) {
b.AddUint32(uint32(v >> 32))
b.AddUint32(uint32(v))
}
// readUint64 decodes a big-endian, 64-bit value into out and advances over it.
// It reports whether the read was successful.
func readUint64(s *cryptobyte.String, out *uint64) bool {
var hi, lo uint32
if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
return false
}
*out = uint64(hi)<<32 | uint64(lo)
return true
}
// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
}
// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
}
// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
}
type keyUpdateMsg struct {
updateRequested bool
}
func (m *keyUpdateMsg) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeKeyUpdate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
if m.updateRequested {
b.AddUint8(1)
} else {
b.AddUint8(0)
}
})
return b.Bytes()
}
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
s := cryptobyte.String(data)
var updateRequested uint8
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8(&updateRequested) || !s.Empty() {
return false
}
switch updateRequested {
case 0:
m.updateRequested = false
case 1:
m.updateRequested = true
default:
return false
}
return true
}
// TLS handshake message types.
const (
typeHelloRequest uint8 = 0
typeClientHello uint8 = 1
typeServerHello uint8 = 2
typeNewSessionTicket uint8 = 4
typeEndOfEarlyData uint8 = 5
typeEncryptedExtensions uint8 = 8
typeCertificate uint8 = 11
typeServerKeyExchange uint8 = 12
typeCertificateRequest uint8 = 13
typeServerHelloDone uint8 = 14
typeCertificateVerify uint8 = 15
typeClientKeyExchange uint8 = 16
typeFinished uint8 = 20
typeCertificateStatus uint8 = 22
typeKeyUpdate uint8 = 24
typeCompressedCertificate uint8 = 25
typeMessageHash uint8 = 254 // synthetic message
)
// TLS compression types.
const (
compressionNone uint8 = 0
)
// TLS extension numbers
const (
extensionServerName uint16 = 0
extensionStatusRequest uint16 = 5
extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7
extensionSupportedPoints uint16 = 11
extensionSignatureAlgorithms uint16 = 13
extensionALPN uint16 = 16
extensionSCT uint16 = 18
extensionPadding uint16 = 21
extensionExtendedMasterSecret uint16 = 23
extensionCompressCertificate uint16 = 27 // compress_certificate in TLS 1.3
extensionSessionTicket uint16 = 35
extensionPreSharedKey uint16 = 41
extensionEarlyData uint16 = 42
extensionSupportedVersions uint16 = 43
extensionCookie uint16 = 44
extensionPSKModes uint16 = 45
extensionCertificateAuthorities uint16 = 47
extensionSignatureAlgorithmsCert uint16 = 50
extensionKeyShare uint16 = 51
extensionQUICTransportParameters uint16 = 57
extensionALPS uint16 = 17513
extensionRenegotiationInfo uint16 = 0xff01
extensionECHOuterExtensions uint16 = 0xfd00
extensionEncryptedClientHello uint16 = 0xfe0d
)
type handshakeMessage interface {
marshal() ([]byte, error)
unmarshal([]byte) bool
}
type newSessionTicketMsgTLS13 struct {
lifetime uint32
ageAdd uint32
nonce []byte
label []byte
maxEarlyData uint32
}
func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeNewSessionTicket)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint32(m.lifetime)
b.AddUint32(m.ageAdd)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.nonce)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.label)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if m.maxEarlyData > 0 {
b.AddUint16(extensionEarlyData)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint32(m.maxEarlyData)
})
}
})
})
return b.Bytes()
}
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
*m = newSessionTicketMsgTLS13{}
s := cryptobyte.String(data)
var extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint32(&m.lifetime) ||
!s.ReadUint32(&m.ageAdd) ||
!readUint8LengthPrefixed(&s, &m.nonce) ||
!readUint16LengthPrefixed(&s, &m.label) ||
!s.ReadUint16LengthPrefixed(&extensions) ||
!s.Empty() {
return false
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionEarlyData:
if !extData.ReadUint32(&m.maxEarlyData) {
return false
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}

View File

@ -1,173 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"crypto/tls"
"errors"
"fmt"
"io"
"os"
)
// handlePostHandshakeMessage processes a handshake message arrived after the
// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
func (c *Conn) handlePostHandshakeMessage() error {
if *c.rawConn.Vers != tls.VersionTLS13 {
return errors.New("ktls: kernel does not support TLS 1.2 renegotiation")
}
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
//c.retryCount++
//if c.retryCount > maxUselessRecords {
// c.sendAlert(alertUnexpectedMessage)
// return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
//}
switch msg := msg.(type) {
case *newSessionTicketMsgTLS13:
// return errors.New("ktls: received new session ticket")
return nil
case *keyUpdateMsg:
return c.handleKeyUpdate(msg)
}
// The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
// as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
// unexpected_message alert here doesn't provide it with enough information to distinguish
// this condition from other unexpected messages. This is probably fine.
c.sendAlert(alertUnexpectedMessage)
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
}
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
//if c.quic != nil {
// c.sendAlert(alertUnexpectedMessage)
// return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
//}
cipherSuite := cipherSuiteTLS13ByID(*c.rawConn.CipherSuite)
if cipherSuite == nil {
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertInternalError))
}
newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.In.TrafficSecret)
c.rawConn.In.SetTrafficSecret(cipherSuite, 0 /*tls.QUICEncryptionLevelInitial*/, newSecret)
err := c.resetupRX()
if err != nil {
c.sendAlert(alertInternalError)
return c.rawConn.In.SetErrorLocked(fmt.Errorf("ktls: resetupRX failed: %w", err))
}
if keyUpdate.updateRequested {
c.rawConn.Out.Lock()
defer c.rawConn.Out.Unlock()
resetup, err := c.resetupTX()
if err != nil {
c.sendAlertLocked(alertInternalError)
return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err))
}
msg := &keyUpdateMsg{}
msgBytes, err := msg.marshal()
if err != nil {
return err
}
_, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
if err != nil {
// Surface the error at the next write.
c.rawConn.Out.SetErrorLocked(err)
return nil
}
newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.Out.TrafficSecret)
c.rawConn.Out.SetTrafficSecret(cipherSuite, 0 /*QUICEncryptionLevelInitial*/, newSecret)
err = resetup()
if err != nil {
return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err))
}
}
return nil
}
func (c *Conn) readHandshakeBytes(n int) error {
//if c.quic != nil {
// return c.quicReadHandshakeBytes(n)
//}
for c.rawConn.Hand.Len() < n {
if err := c.readRecord(); err != nil {
return err
}
}
return nil
}
func (c *Conn) readHandshake(transcript io.Writer) (any, error) {
if err := c.readHandshakeBytes(4); err != nil {
return nil, err
}
data := c.rawConn.Hand.Bytes()
maxHandshakeSize := maxHandshake
// hasVers indicates we're past the first message, forcing someone trying to
// make us just allocate a large buffer to at least do the initial part of
// the handshake first.
//if c.haveVers && data[0] == typeCertificate {
// Since certificate messages are likely to be the only messages that
// can be larger than maxHandshake, we use a special limit for just
// those messages.
//maxHandshakeSize = maxHandshakeCertificateMsg
//}
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshakeSize {
c.sendAlertLocked(alertInternalError)
return nil, c.rawConn.In.SetErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
}
if err := c.readHandshakeBytes(4 + n); err != nil {
return nil, err
}
data = c.rawConn.Hand.Next(4 + n)
return c.unmarshalHandshakeMessage(data, transcript)
}
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript io.Writer) (any, error) {
var m handshakeMessage
switch data[0] {
case typeNewSessionTicket:
if *c.rawConn.Vers == tls.VersionTLS13 {
m = new(newSessionTicketMsgTLS13)
} else {
return nil, os.ErrInvalid
}
case typeKeyUpdate:
m = new(keyUpdateMsg)
default:
return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// The handshake message unmarshalers
// expect to be able to keep references to data,
// so pass in a fresh copy that won't be overwritten.
data = append([]byte(nil), data...)
if !m.unmarshal(data) {
return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecodeError))
}
if transcript != nil {
transcript.Write(data)
}
return m, nil
}

View File

@ -1,311 +0,0 @@
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"crypto/tls"
"errors"
"io"
"os"
"strings"
"sync"
"syscall"
"unsafe"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/blang/semver/v4"
"golang.org/x/sys/unix"
)
// mod from https://gitlab.com/go-extension/tls
const (
TLS_TX = 1
TLS_RX = 2
TLS_TX_ZEROCOPY_RO = 3 // TX zerocopy (only sendfile now)
TLS_RX_EXPECT_NO_PAD = 4 // Attempt opportunistic zero-copy, TLS 1.3 only
TLS_SET_RECORD_TYPE = 1
TLS_GET_RECORD_TYPE = 2
)
type Support struct {
TLS, TLS_RX bool
TLS_Version13, TLS_Version13_RX bool
TLS_TX_ZEROCOPY bool
TLS_RX_NOPADDING bool
TLS_AES_256_GCM bool
TLS_AES_128_CCM bool
TLS_CHACHA20_POLY1305 bool
TLS_SM4 bool
TLS_ARIA_GCM bool
TLS_Version13_KeyUpdate bool
}
var KernelSupport = sync.OnceValues(func() (*Support, error) {
_, err := os.Stat("/sys/module/tls")
if err != nil {
return nil, E.New("ktls: kernel module tls not found")
}
var uname unix.Utsname
err = unix.Uname(&uname)
if err != nil {
return nil, err
}
kernelVersion, err := semver.Parse(strings.Trim(string(uname.Release[:]), "\x00"))
if err != nil {
return nil, err
}
kernelVersion.Pre = nil
kernelVersion.Build = nil
var support Support
switch {
case kernelVersion.GTE(semver.Version{Major: 6, Minor: 14}):
support.TLS_Version13_KeyUpdate = true
fallthrough
case kernelVersion.GTE(semver.Version{Major: 6, Minor: 1}):
support.TLS_ARIA_GCM = true
fallthrough
case kernelVersion.GTE(semver.Version{Major: 6}):
support.TLS_Version13_RX = true
support.TLS_RX_NOPADDING = true
fallthrough
case kernelVersion.GTE(semver.Version{Major: 5, Minor: 19}):
support.TLS_TX_ZEROCOPY = true
fallthrough
case kernelVersion.GTE(semver.Version{Major: 5, Minor: 16}):
support.TLS_SM4 = true
fallthrough
case kernelVersion.GTE(semver.Version{Major: 5, Minor: 11}):
support.TLS_CHACHA20_POLY1305 = true
fallthrough
case kernelVersion.GTE(semver.Version{Major: 5, Minor: 2}):
support.TLS_AES_128_CCM = true
fallthrough
case kernelVersion.GTE(semver.Version{Major: 5, Minor: 1}):
support.TLS_AES_256_GCM = true
support.TLS_Version13 = true
fallthrough
case kernelVersion.GTE(semver.Version{Major: 4, Minor: 17}):
support.TLS_RX = true
fallthrough
case kernelVersion.GTE(semver.Version{Major: 4, Minor: 13}):
support.TLS = true
}
return &support, nil
})
func (c *Conn) setupKernel(txOffload, rxOffload bool) error {
if !txOffload && !rxOffload {
return nil
}
support, err := KernelSupport()
if err != nil {
return err
}
if !support.TLS {
return nil
}
c.rawConn.Out.Lock()
defer c.rawConn.Out.Unlock()
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
return syscall.SetsockoptString(int(fd), unix.SOL_TCP, unix.TCP_ULP, "tls")
})
if err != nil {
return E.Cause(err, "initialize kernel TLS")
}
if rxOffload {
rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true)
if rxCrypto == nil {
return E.New("kTLS: unsupported cipher suite")
}
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String())
})
if err != nil {
return err
}
if /*config.KernelRXExpectNoPad &&*/ *c.rawConn.Vers >= tls.VersionTLS13 && support.TLS_RX_NOPADDING {
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_RX_EXPECT_NO_PAD, 1)
})
if err != nil {
return err
}
}
c.kernelRx = true
}
if txOffload {
txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false)
if txCrypto == nil {
return E.New("kTLS: unsupported cipher suite")
}
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String())
})
if err != nil {
return err
}
if support.TLS_TX_ZEROCOPY {
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_TX_ZEROCOPY_RO, 1)
})
if err != nil {
return err
}
}
c.kernelTx = true
}
return nil
}
func (c *Conn) resetupTX() (func() error, error) {
if !c.kernelTx {
return nil, nil
}
support, err := KernelSupport()
if err != nil {
return nil, err
}
if !support.TLS_Version13_KeyUpdate {
return nil, errors.New("ktls: kernel does not support rekey")
}
txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false)
if txCrypto == nil {
return nil, errors.New("ktls: set kernelCipher on unsupported tls session")
}
return func() error {
return control.Raw(c.rawSyscallConn, func(fd uintptr) error {
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String())
})
}, nil
}
func (c *Conn) resetupRX() error {
if !c.kernelRx {
return nil
}
support, err := KernelSupport()
if err != nil {
return err
}
if !support.TLS_Version13_KeyUpdate {
return errors.New("ktls: kernel does not support rekey")
}
rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true)
if rxCrypto == nil {
return errors.New("ktls: set kernelCipher on unsupported tls session")
}
return control.Raw(c.rawSyscallConn, func(fd uintptr) error {
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String())
})
}
func (c *Conn) readKernelRecord() (uint8, []byte, error) {
if c.rawConn.RawInput.Len() < maxPlaintext {
c.rawConn.RawInput.Grow(maxPlaintext - c.rawConn.RawInput.Len())
}
data := c.rawConn.RawInput.Bytes()[:maxPlaintext]
// cmsg for record type
buffer := make([]byte, unix.CmsgSpace(1))
cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0]))
cmsg.SetLen(unix.CmsgLen(1))
var iov unix.Iovec
iov.Base = &data[0]
iov.SetLen(len(data))
var msg unix.Msghdr
msg.Control = &buffer[0]
msg.Controllen = cmsg.Len
msg.Iov = &iov
msg.Iovlen = 1
var n int
var err error
er := c.rawSyscallConn.Read(func(fd uintptr) bool {
n, err = recvmsg(int(fd), &msg, 0)
return err != unix.EAGAIN
})
if er != nil {
return 0, nil, er
}
switch err {
case nil:
case syscall.EINVAL:
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertProtocolVersion))
case syscall.EMSGSIZE:
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow))
case syscall.EBADMSG:
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecryptError))
default:
return 0, nil, err
}
if n <= 0 {
return 0, nil, io.EOF
}
if cmsg.Level == unix.SOL_TLS && cmsg.Type == TLS_GET_RECORD_TYPE {
typ := buffer[unix.CmsgLen(0)]
return typ, data[:n], nil
}
return recordTypeApplicationData, data[:n], nil
}
func (c *Conn) writeKernelRecord(typ uint16, data []byte) (int, error) {
if typ == recordTypeApplicationData {
return c.conn.Write(data)
}
// cmsg for record type
buffer := make([]byte, unix.CmsgSpace(1))
cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0]))
cmsg.SetLen(unix.CmsgLen(1))
buffer[unix.CmsgLen(0)] = byte(typ)
cmsg.Level = unix.SOL_TLS
cmsg.Type = TLS_SET_RECORD_TYPE
var iov unix.Iovec
iov.Base = &data[0]
iov.SetLen(len(data))
var msg unix.Msghdr
msg.Control = &buffer[0]
msg.Controllen = cmsg.Len
msg.Iov = &iov
msg.Iovlen = 1
var n int
var err error
ew := c.rawSyscallConn.Write(func(fd uintptr) bool {
n, err = sendmsg(int(fd), &msg, 0)
return err != unix.EAGAIN
})
if ew != nil {
return 0, ew
}
return n, err
}
//go:linkname recvmsg golang.org/x/sys/unix.recvmsg
func recvmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error)
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
func sendmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error)

View File

@ -1,24 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
import "unsafe"
//go:linkname cipherSuiteByID github.com/metacubex/utls.cipherSuiteByID
func cipherSuiteByID(id uint16) unsafe.Pointer
//go:linkname keysFromMasterSecret github.com/metacubex/utls.keysFromMasterSecret
func keysFromMasterSecret(version uint16, suite unsafe.Pointer, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte)
//go:linkname cipherSuiteTLS13ByID github.com/metacubex/utls.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) unsafe.Pointer
//go:linkname nextTrafficSecret github.com/metacubex/utls.(*cipherSuiteTLS13).nextTrafficSecret
func nextTrafficSecret(cs unsafe.Pointer, trafficSecret []byte) []byte
//go:linkname trafficKey github.com/metacubex/utls.(*cipherSuiteTLS13).trafficKey
func trafficKey(cs unsafe.Pointer, trafficSecret []byte) (key, iv []byte)

View File

@ -1,292 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
)
func (c *Conn) Read(b []byte) (int, error) {
if !c.kernelRx {
return c.Conn.Read(b)
}
if len(b) == 0 {
// Put this after Handshake, in case people were calling
// Read(nil) for the side effect of the Handshake.
return 0, nil
}
c.rawConn.In.Lock()
defer c.rawConn.In.Unlock()
for c.rawConn.Input.Len() == 0 {
if err := c.readRecord(); err != nil {
return 0, err
}
for c.rawConn.Hand.Len() > 0 {
if err := c.handlePostHandshakeMessage(); err != nil {
return 0, err
}
}
}
n, _ := c.rawConn.Input.Read(b)
// If a close-notify alert is waiting, read it so that we can return (n,
// EOF) instead of (n, nil), to signal to the HTTP response reading
// goroutine that the connection is now closed. This eliminates a race
// where the HTTP response reading goroutine would otherwise not observe
// the EOF until its next read, by which time a client goroutine might
// have already tried to reuse the HTTP connection for a new request.
// See https://golang.org/cl/76400046 and https://golang.org/issue/3514
if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.RawInput.Len() > 0 &&
c.rawConn.RawInput.Bytes()[0] == recordTypeAlert {
if err := c.readRecord(); err != nil {
return n, err // will be io.EOF on closeNotify
}
}
return n, nil
}
func (c *Conn) readRecord() error {
if *c.rawConn.In.Err != nil {
return *c.rawConn.In.Err
}
typ, data, err := c.readRawRecord()
if err != nil {
return err
}
if len(data) > maxPlaintext {
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow))
}
// Application Data messages are always protected.
if c.rawConn.In.Cipher == nil && typ == recordTypeApplicationData {
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
//if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
// This is a state-advancing message: reset the retry count.
// c.retryCount = 0
//}
// Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
if *c.rawConn.Vers == tls.VersionTLS13 && typ != recordTypeHandshake && c.rawConn.Hand.Len() > 0 {
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
switch typ {
default:
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
case recordTypeAlert:
//if c.quic != nil {
// return c.rawConn.In.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
//}
if len(data) != 2 {
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if data[1] == alertCloseNotify {
return c.rawConn.In.SetErrorLocked(io.EOF)
}
if *c.rawConn.Vers == tls.VersionTLS13 {
// TLS 1.3 removed warning-level alerts except for alertUserCanceled
// (RFC 8446, § 6.1). Since at least one major implementation
// (https://bugs.openjdk.org/browse/JDK-8323517) misuses this alert,
// many TLS stacks now ignore it outright when seen in a TLS 1.3
// handshake (e.g. BoringSSL, NSS, Rustls).
if data[1] == alertUserCanceled {
// Like TLS 1.2 alertLevelWarning alerts, we drop the record and retry.
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
}
return c.rawConn.In.SetErrorLocked(&net.OpError{Op: "remote error", Err: tls.AlertError(data[1])})
}
switch data[0] {
case alertLevelWarning:
// Drop the record on the floor and retry.
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
case alertLevelError:
return c.rawConn.In.SetErrorLocked(&net.OpError{Op: "remote error", Err: tls.AlertError(data[1])})
default:
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
case recordTypeChangeCipherSpec:
if len(data) != 1 || data[0] != 1 {
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecodeError))
}
// Handshake messages are not allowed to fragment across the CCS.
if c.rawConn.Hand.Len() > 0 {
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// In TLS 1.3, change_cipher_spec records are ignored until the
// Finished. See RFC 8446, Appendix D.4. Note that according to Section
// 5, a server can send a ChangeCipherSpec before its ServerHello, when
// c.vers is still unset. That's not useful though and suspicious if the
// server then selects a lower protocol version, so don't allow that.
if *c.rawConn.Vers == tls.VersionTLS13 {
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
}
// if !expectChangeCipherSpec {
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
//}
//if err := c.rawConn.In.changeCipherSpec(); err != nil {
// return c.rawConn.In.setErrorLocked(c.sendAlert(err.(alert)))
//}
case recordTypeApplicationData:
// Some OpenSSL servers send empty records in order to randomize the
// CBC RawIV. Ignore a limited number of empty records.
if len(data) == 0 {
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
}
// Note that data is owned by c.rawInput, following the Next call above,
// to avoid copying the plaintext. This is safe because c.rawInput is
// not read from or written to until c.input is drained.
c.rawConn.Input.Reset(data)
case recordTypeHandshake:
if len(data) == 0 {
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
c.rawConn.Hand.Write(data)
}
return nil
}
//nolint:staticcheck
func (c *Conn) readRawRecord() (typ uint8, data []byte, err error) {
// Read from kernel.
if c.kernelRx {
return c.readKernelRecord()
}
// Read header, payload.
if err = c.readFromUntil(c.conn, recordHeaderLen); err != nil {
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
// is an error, but popular web sites seem to do this, so we accept it
// if and only if at the record boundary.
if err == io.ErrUnexpectedEOF && c.rawConn.RawInput.Len() == 0 {
err = io.EOF
}
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.rawConn.In.SetErrorLocked(err)
}
return
}
hdr := c.rawConn.RawInput.Bytes()[:recordHeaderLen]
typ = hdr[0]
vers := uint16(hdr[1])<<8 | uint16(hdr[2])
expectedVers := *c.rawConn.Vers
if expectedVers == tls.VersionTLS13 {
// All TLS 1.3 records are expected to have 0x0303 (1.2) after
// the initial hello (RFC 8446 Section 5.1).
expectedVers = tls.VersionTLS12
}
n := int(hdr[3])<<8 | int(hdr[4])
if /*c.haveVers && */ vers != expectedVers {
c.sendAlert(alertProtocolVersion)
msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(nil, msg))
return
}
//if !c.haveVers {
// // First message, be extra suspicious: this might not be a TLS
// // client. Bail out before reading a full 'body', if possible.
// // The current max version is 3.3 so if the version is >= 16.0,
// // it's probably not real.
// if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
// err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
// return
// }
//}
if *c.rawConn.Vers == tls.VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
c.sendAlert(alertRecordOverflow)
msg := fmt.Sprintf("oversized record received with length %d", n)
err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(nil, msg))
return
}
if err = c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.rawConn.In.SetErrorLocked(err)
}
return
}
// Process message.
record := c.rawConn.RawInput.Next(recordHeaderLen + n)
data, typ, err = c.rawConn.In.Decrypt(record)
if err != nil {
err = c.rawConn.In.SetErrorLocked(c.sendAlert(uint8(err.(tls.AlertError))))
return
}
return
}
// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like
// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3.
func (c *Conn) retryReadRecord( /*expectChangeCipherSpec bool*/ ) error {
//c.retryCount++
//if c.retryCount > maxUselessRecords {
// c.sendAlert(alertUnexpectedMessage)
// return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
//}
return c.readRecord( /*expectChangeCipherSpec*/ )
}
// atLeastReader reads from R, stopping with EOF once at least N bytes have been
// read. It is different from an io.LimitedReader in that it doesn't cut short
// the last Read call, and in that it considers an early EOF an error.
type atLeastReader struct {
R io.Reader
N int64
}
func (r *atLeastReader) Read(p []byte) (int, error) {
if r.N <= 0 {
return 0, io.EOF
}
n, err := r.R.Read(p)
r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
if r.N > 0 && err == io.EOF {
return n, io.ErrUnexpectedEOF
}
if r.N <= 0 && err == nil {
return n, io.EOF
}
return n, err
}
// readFromUntil reads from r into c.rawConn.RawInput until c.rawConn.RawInput contains
// at least n bytes or else returns an error.
func (c *Conn) readFromUntil(r io.Reader, n int) error {
if c.rawConn.RawInput.Len() >= n {
return nil
}
needs := n - c.rawConn.RawInput.Len()
// There might be extra input waiting on the wire. Make a best effort
// attempt to fetch it so that it can be used in (*Conn).Read to
// "predict" closeNotify alerts.
c.rawConn.RawInput.Grow(needs + bytes.MinRead)
_, err := c.rawConn.RawInput.ReadFrom(&atLeastReader{r, int64(needs)})
return err
}
func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err tls.RecordHeaderError) {
err.Msg = msg
err.Conn = conn
copy(err.RecordHeader[:], c.rawConn.RawInput.Bytes())
return err
}

View File

@ -1,41 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)
func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *Conn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
c.rawConn.In.Lock()
defer c.rawConn.In.Unlock()
for c.rawConn.Input.Len() == 0 {
err = c.readRecord()
if err != nil {
return
}
}
buffer = c.readWaitOptions.NewBuffer()
n, err := c.rawConn.Input.Read(buffer.FreeBytes())
if err != nil {
buffer.Release()
return
}
buffer.Truncate(n)
if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.Input.Len() > 0 &&
c.rawConn.RawInput.Bytes()[0] == recordTypeAlert {
_ = c.rawConn.ReadRecord()
}
c.readWaitOptions.PostReturn(buffer)
return
}

View File

@ -1,13 +0,0 @@
//go:build !linux || !go1.25 || without_badtls
package ktls
import (
"os"
aTLS "github.com/sagernet/sing/common/tls"
)
func NewConn(conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
return nil, os.ErrInvalid
}

View File

@ -1,154 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && go1.25 && !without_badtls
package ktls
import (
"crypto/cipher"
"crypto/tls"
"errors"
"net"
)
func (c *Conn) Write(b []byte) (int, error) {
if !c.kernelTx {
return c.Conn.Write(b)
}
// interlock with Close below
for {
x := c.rawConn.ActiveCall.Load()
if x&1 != 0 {
return 0, net.ErrClosed
}
if c.rawConn.ActiveCall.CompareAndSwap(x, x+2) {
break
}
}
defer c.rawConn.ActiveCall.Add(-2)
//if err := c.Conn.HandshakeContext(context.Background()); err != nil {
// return 0, err
//}
c.rawConn.Out.Lock()
defer c.rawConn.Out.Unlock()
if err := *c.rawConn.Out.Err; err != nil {
return 0, err
}
if !c.rawConn.IsHandshakeComplete.Load() {
return 0, tls.AlertError(alertInternalError)
}
if *c.rawConn.CloseNotifySent {
// return 0, errShutdown
return 0, errors.New("tls: protocol is shutdown")
}
// TLS 1.0 is susceptible to a chosen-plaintext
// attack when using block mode ciphers due to predictable IVs.
// This can be prevented by splitting each Application Data
// record into two records, effectively randomizing the RawIV.
//
// https://www.openssl.org/~bodo/tls-cbc.txt
// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
// https://www.imperialviolet.org/2012/01/15/beastfollowup.html
var m int
if len(b) > 1 && *c.rawConn.Vers == tls.VersionTLS10 {
if _, ok := (*c.rawConn.Out.Cipher).(cipher.BlockMode); ok {
n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
if err != nil {
return n, c.rawConn.Out.SetErrorLocked(err)
}
m, b = 1, b[1:]
}
}
n, err := c.writeRecordLocked(recordTypeApplicationData, b)
return n + m, c.rawConn.Out.SetErrorLocked(err)
}
func (c *Conn) writeRecordLocked(typ uint16, data []byte) (n int, err error) {
if !c.kernelTx {
return c.rawConn.WriteRecordLocked(typ, data)
}
/*for len(data) > 0 {
m := len(data)
if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
m = maxPayload
}
_, err = c.writeKernelRecord(typ, data[:m])
if err != nil {
return
}
n += m
data = data[m:]
}*/
return c.writeKernelRecord(typ, data)
}
const (
// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
// size (MSS). A constant is used, rather than querying the kernel for
// the actual MSS, to avoid complexity. The value here is the IPv6
// minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40
// bytes) and a TCP header with timestamps (32 bytes).
tcpMSSEstimate = 1208
// recordSizeBoostThreshold is the number of bytes of application data
// sent after which the TLS record size will be increased to the
// maximum.
recordSizeBoostThreshold = 128 * 1024
)
func (c *Conn) maxPayloadSizeForWrite(typ uint16) int {
if /*c.config.DynamicRecordSizingDisabled ||*/ typ != recordTypeApplicationData {
return maxPlaintext
}
if *c.rawConn.PacketsSent >= recordSizeBoostThreshold {
return maxPlaintext
}
// Subtract TLS overheads to get the maximum payload size.
payloadBytes := tcpMSSEstimate - recordHeaderLen - c.rawConn.Out.ExplicitNonceLen()
if rawCipher := *c.rawConn.Out.Cipher; rawCipher != nil {
switch ciph := rawCipher.(type) {
case cipher.Stream:
payloadBytes -= (*c.rawConn.Out.Mac).Size()
case cipher.AEAD:
payloadBytes -= ciph.Overhead()
/*case cbcMode:
blockSize := ciph.BlockSize()
// The payload must fit in a multiple of blockSize, with
// room for at least one padding byte.
payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
// The RawMac is appended before padding so affects the
// payload size directly.
payloadBytes -= c.out.mac.Size()*/
default:
panic("unknown cipher type")
}
}
if *c.rawConn.Vers == tls.VersionTLS13 {
payloadBytes-- // encrypted ContentType
}
// Allow packet growth in arithmetic progression up to max.
pkt := *c.rawConn.PacketsSent
*c.rawConn.PacketsSent++
if pkt > 1000 {
return maxPlaintext // avoid overflow in multiply below
}
n := payloadBytes * int(pkt+1)
if n > maxPlaintext {
n = maxPlaintext
}
return n
}

View File

@ -4,8 +4,6 @@ import (
"context" "context"
"net" "net"
"net/netip" "net/netip"
"runtime"
"strings"
"sync/atomic" "sync/atomic"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
@ -16,8 +14,6 @@ import (
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/vishvananda/netns"
) )
type Listener struct { type Listener struct {
@ -32,7 +28,6 @@ type Listener struct {
disablePacketOutput bool disablePacketOutput bool
setSystemProxy bool setSystemProxy bool
systemProxySOCKS bool systemProxySOCKS bool
tproxy bool
tcpListener net.Listener tcpListener net.Listener
systemProxy settings.SystemProxy systemProxy settings.SystemProxy
@ -55,7 +50,6 @@ type Options struct {
DisablePacketOutput bool DisablePacketOutput bool
SetSystemProxy bool SetSystemProxy bool
SystemProxySOCKS bool SystemProxySOCKS bool
TProxy bool
} }
func New( func New(
@ -73,7 +67,6 @@ func New(
disablePacketOutput: options.DisablePacketOutput, disablePacketOutput: options.DisablePacketOutput,
setSystemProxy: options.SetSystemProxy, setSystemProxy: options.SetSystemProxy,
systemProxySOCKS: options.SystemProxySOCKS, systemProxySOCKS: options.SystemProxySOCKS,
tproxy: options.TProxy,
} }
} }
@ -142,30 +135,3 @@ func (l *Listener) UDPConn() *net.UDPConn {
func (l *Listener) ListenOptions() option.ListenOptions { func (l *Listener) ListenOptions() option.ListenOptions {
return l.listenOptions return l.listenOptions
} }
func ListenNetworkNamespace[T any](nameOrPath string, block func() (T, error)) (T, error) {
if nameOrPath != "" {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
currentNs, err := netns.Get()
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "get current netns")
}
defer netns.Set(currentNs)
var targetNs netns.NsHandle
if strings.HasPrefix(nameOrPath, "/") {
targetNs, err = netns.GetFromPath(nameOrPath)
} else {
targetNs, err = netns.GetFromName(nameOrPath)
}
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "get netns ", nameOrPath)
}
defer targetNs.Close()
err = netns.Set(targetNs)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "set netns to ", nameOrPath)
}
}
return block()
}

View File

@ -3,39 +3,23 @@ package listener
import ( import (
"net" "net"
"net/netip" "net/netip"
"syscall"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/redir"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"github.com/metacubex/tfo-go" "github.com/metacubex/tfo-go"
) )
func (l *Listener) ListenTCP() (net.Listener, error) { func (l *Listener) ListenTCP() (net.Listener, error) {
//nolint:staticcheck
if l.listenOptions.ProxyProtocol || l.listenOptions.ProxyProtocolAcceptNoHeader {
return nil, E.New("Proxy Protocol is deprecated and removed in sing-box 1.6.0")
}
var err error var err error
bindAddr := M.SocksaddrFrom(l.listenOptions.Listen.Build(netip.AddrFrom4([4]byte{127, 0, 0, 1})), l.listenOptions.ListenPort) bindAddr := M.SocksaddrFrom(l.listenOptions.Listen.Build(netip.AddrFrom4([4]byte{127, 0, 0, 1})), l.listenOptions.ListenPort)
var tcpListener net.Listener
var listenConfig net.ListenConfig var listenConfig net.ListenConfig
if l.listenOptions.BindInterface != "" {
listenConfig.Control = control.Append(listenConfig.Control, control.BindToInterface(service.FromContext[adapter.NetworkManager](l.ctx).InterfaceFinder(), l.listenOptions.BindInterface, -1))
}
if l.listenOptions.RoutingMark != 0 {
listenConfig.Control = control.Append(listenConfig.Control, control.RoutingMark(uint32(l.listenOptions.RoutingMark)))
}
if l.listenOptions.ReuseAddr {
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
}
if l.listenOptions.TCPKeepAlive >= 0 { if l.listenOptions.TCPKeepAlive >= 0 {
keepIdle := time.Duration(l.listenOptions.TCPKeepAlive) keepIdle := time.Duration(l.listenOptions.TCPKeepAlive)
if keepIdle == 0 { if keepIdle == 0 {
@ -53,26 +37,20 @@ func (l *Listener) ListenTCP() (net.Listener, error) {
} }
setMultiPathTCP(&listenConfig) setMultiPathTCP(&listenConfig)
} }
if l.tproxy { if l.listenOptions.TCPFastOpen {
listenConfig.Control = control.Append(listenConfig.Control, func(network, address string, conn syscall.RawConn) error { var tfoConfig tfo.ListenConfig
return control.Raw(conn, func(fd uintptr) error { tfoConfig.ListenConfig = listenConfig
return redir.TProxy(fd, !M.ParseSocksaddr(address).IsIPv4(), false) tcpListener, err = tfoConfig.Listen(l.ctx, M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.String())
}) } else {
}) tcpListener, err = listenConfig.Listen(l.ctx, M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.String())
} }
tcpListener, err := ListenNetworkNamespace[net.Listener](l.listenOptions.NetNs, func() (net.Listener, error) { if err == nil {
if l.listenOptions.TCPFastOpen { l.logger.Info("tcp server started at ", tcpListener.Addr())
var tfoConfig tfo.ListenConfig }
tfoConfig.ListenConfig = listenConfig //nolint:staticcheck
return tfoConfig.Listen(l.ctx, M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.String()) if l.listenOptions.ProxyProtocol || l.listenOptions.ProxyProtocolAcceptNoHeader {
} else { return nil, E.New("Proxy Protocol is deprecated and removed in sing-box 1.6.0")
return listenConfig.Listen(l.ctx, M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.String())
}
})
if err != nil {
return nil, err
} }
l.logger.Info("tcp server started at ", tcpListener.Addr())
l.tcpListener = tcpListener l.tcpListener = tcpListener
return tcpListener, err return tcpListener, err
} }

View File

@ -1,34 +1,20 @@
package listener package listener
import ( import (
"context"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"syscall"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/redir"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
) )
func (l *Listener) ListenUDP() (net.PacketConn, error) { func (l *Listener) ListenUDP() (net.PacketConn, error) {
bindAddr := M.SocksaddrFrom(l.listenOptions.Listen.Build(netip.AddrFrom4([4]byte{127, 0, 0, 1})), l.listenOptions.ListenPort) bindAddr := M.SocksaddrFrom(l.listenOptions.Listen.Build(netip.AddrFrom4([4]byte{127, 0, 0, 1})), l.listenOptions.ListenPort)
var listenConfig net.ListenConfig var lc net.ListenConfig
if l.listenOptions.BindInterface != "" {
listenConfig.Control = control.Append(listenConfig.Control, control.BindToInterface(service.FromContext[adapter.NetworkManager](l.ctx).InterfaceFinder(), l.listenOptions.BindInterface, -1))
}
if l.listenOptions.RoutingMark != 0 {
listenConfig.Control = control.Append(listenConfig.Control, control.RoutingMark(uint32(l.listenOptions.RoutingMark)))
}
if l.listenOptions.ReuseAddr {
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
}
var udpFragment bool var udpFragment bool
if l.listenOptions.UDPFragment != nil { if l.listenOptions.UDPFragment != nil {
udpFragment = *l.listenOptions.UDPFragment udpFragment = *l.listenOptions.UDPFragment
@ -36,18 +22,9 @@ func (l *Listener) ListenUDP() (net.PacketConn, error) {
udpFragment = l.listenOptions.UDPFragmentDefault udpFragment = l.listenOptions.UDPFragmentDefault
} }
if !udpFragment { if !udpFragment {
listenConfig.Control = control.Append(listenConfig.Control, control.DisableUDPFragment()) lc.Control = control.Append(lc.Control, control.DisableUDPFragment())
} }
if l.tproxy { udpConn, err := lc.ListenPacket(l.ctx, M.NetworkFromNetAddr(N.NetworkUDP, bindAddr.Addr), bindAddr.String())
listenConfig.Control = control.Append(listenConfig.Control, func(network, address string, conn syscall.RawConn) error {
return control.Raw(conn, func(fd uintptr) error {
return redir.TProxy(fd, !M.ParseSocksaddr(address).IsIPv4(), true)
})
})
}
udpConn, err := ListenNetworkNamespace[net.PacketConn](l.listenOptions.NetNs, func() (net.PacketConn, error) {
return listenConfig.ListenPacket(l.ctx, M.NetworkFromNetAddr(N.NetworkUDP, bindAddr.Addr), bindAddr.String())
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -57,36 +34,6 @@ func (l *Listener) ListenUDP() (net.PacketConn, error) {
return udpConn, err return udpConn, err
} }
func (l *Listener) DialContext(dialer net.Dialer, ctx context.Context, network string, address string) (net.Conn, error) {
return ListenNetworkNamespace[net.Conn](l.listenOptions.NetNs, func() (net.Conn, error) {
if l.listenOptions.BindInterface != "" {
dialer.Control = control.Append(dialer.Control, control.BindToInterface(service.FromContext[adapter.NetworkManager](l.ctx).InterfaceFinder(), l.listenOptions.BindInterface, -1))
}
if l.listenOptions.RoutingMark != 0 {
dialer.Control = control.Append(dialer.Control, control.RoutingMark(uint32(l.listenOptions.RoutingMark)))
}
if l.listenOptions.ReuseAddr {
dialer.Control = control.Append(dialer.Control, control.ReuseAddr())
}
return dialer.DialContext(ctx, network, address)
})
}
func (l *Listener) ListenPacket(listenConfig net.ListenConfig, ctx context.Context, network string, address string) (net.PacketConn, error) {
return ListenNetworkNamespace[net.PacketConn](l.listenOptions.NetNs, func() (net.PacketConn, error) {
if l.listenOptions.BindInterface != "" {
listenConfig.Control = control.Append(listenConfig.Control, control.BindToInterface(service.FromContext[adapter.NetworkManager](l.ctx).InterfaceFinder(), l.listenOptions.BindInterface, -1))
}
if l.listenOptions.RoutingMark != 0 {
listenConfig.Control = control.Append(listenConfig.Control, control.RoutingMark(uint32(l.listenOptions.RoutingMark)))
}
if l.listenOptions.ReuseAddr {
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
}
return listenConfig.ListenPacket(ctx, network, address)
})
}
func (l *Listener) UDPAddr() M.Socksaddr { func (l *Listener) UDPAddr() M.Socksaddr {
return l.udpAddr return l.udpAddr
} }
@ -164,8 +111,9 @@ func (l *Listener) loopUDPOut() {
if l.shutdown.Load() && E.IsClosed(err) { if l.shutdown.Load() && E.IsClosed(err) {
return return
} }
l.udpConn.Close()
l.logger.Error("udp listener write back: ", destination, ": ", err) l.logger.Error("udp listener write back: ", destination, ": ", err)
continue return
} }
continue continue
case <-l.packetOutboundClosed: case <-l.packetOutboundClosed:

View File

@ -23,7 +23,6 @@ type Config struct {
} }
type Info struct { type Info struct {
ProcessID uint32
ProcessPath string ProcessPath string
PackageName string PackageName string
User string User string

View File

@ -76,8 +76,6 @@ func findProcessName(network string, ip netip.Addr, port int) (string, error) {
// rup8(sizeof(xtcpcb_n)) // rup8(sizeof(xtcpcb_n))
itemSize += 208 itemSize += 208
} }
var fallbackUDPProcess string
// skip the first xinpgen(24 bytes) block // skip the first xinpgen(24 bytes) block
for i := 24; i+itemSize <= len(buf); i += itemSize { for i := 24; i+itemSize <= len(buf); i += itemSize {
// offset of xinpcb_n and xsocket_n // offset of xinpcb_n and xsocket_n
@ -92,34 +90,24 @@ func findProcessName(network string, ip netip.Addr, port int) (string, error) {
flag := buf[inp+44] flag := buf[inp+44]
var srcIP netip.Addr var srcIP netip.Addr
srcIsIPv4 := false
switch { switch {
case flag&0x1 > 0 && isIPv4: case flag&0x1 > 0 && isIPv4:
// ipv4 // ipv4
srcIP = netip.AddrFrom4([4]byte(buf[inp+76 : inp+80])) srcIP = netip.AddrFrom4(*(*[4]byte)(buf[inp+76 : inp+80]))
srcIsIPv4 = true
case flag&0x2 > 0 && !isIPv4: case flag&0x2 > 0 && !isIPv4:
// ipv6 // ipv6
srcIP = netip.AddrFrom16([16]byte(buf[inp+64 : inp+80])) srcIP = netip.AddrFrom16(*(*[16]byte)(buf[inp+64 : inp+80]))
default: default:
continue continue
} }
if ip == srcIP { if ip != srcIP {
// xsocket_n.so_last_pid continue
pid := readNativeUint32(buf[so+68 : so+72])
return getExecPathFromPID(pid)
} }
// udp packet connection may be not equal with srcIP // xsocket_n.so_last_pid
if network == N.NetworkUDP && srcIP.IsUnspecified() && isIPv4 == srcIsIPv4 { pid := readNativeUint32(buf[so+68 : so+72])
pid := readNativeUint32(buf[so+68 : so+72]) return getExecPathFromPID(pid)
fallbackUDPProcess, _ = getExecPathFromPID(pid)
}
}
if network == N.NetworkUDP && len(fallbackUDPProcess) > 0 {
return fallbackUDPProcess, nil
} }
return "", ErrNotFound return "", ErrNotFound

View File

@ -2,11 +2,14 @@ package process
import ( import (
"context" "context"
"fmt"
"net/netip" "net/netip"
"os"
"syscall" "syscall"
"unsafe"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/winiphlpapi" N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
@ -23,39 +26,209 @@ func NewSearcher(_ Config) (Searcher, error) {
return &windowsSearcher{}, nil return &windowsSearcher{}, nil
} }
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable")
procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable")
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW")
)
func initWin32API() error { func initWin32API() error {
return winiphlpapi.LoadExtendedTable() err := modiphlpapi.Load()
if err != nil {
return E.Cause(err, "load iphlpapi.dll")
}
err = procGetExtendedTcpTable.Find()
if err != nil {
return E.Cause(err, "load iphlpapi::GetExtendedTcpTable")
}
err = procGetExtendedUdpTable.Find()
if err != nil {
return E.Cause(err, "load iphlpapi::GetExtendedUdpTable")
}
err = modkernel32.Load()
if err != nil {
return E.Cause(err, "load kernel32.dll")
}
err = procQueryFullProcessImageNameW.Find()
if err != nil {
return E.Cause(err, "load kernel32::QueryFullProcessImageNameW")
}
return nil
} }
func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) { func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
pid, err := winiphlpapi.FindPid(network, source) processName, err := findProcessName(network, source.Addr(), int(source.Port()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
path, err := getProcessPath(pid) return &Info{ProcessPath: processName, UserId: -1}, nil
if err != nil {
return &Info{ProcessID: pid, UserId: -1}, err
}
return &Info{ProcessID: pid, ProcessPath: path, UserId: -1}, nil
} }
func getProcessPath(pid uint32) (string, error) { func findProcessName(network string, ip netip.Addr, srcPort int) (string, error) {
family := windows.AF_INET
if ip.Is6() {
family = windows.AF_INET6
}
const (
tcpTablePidConn = 4
udpTablePid = 1
)
var class int
var fn uintptr
switch network {
case N.NetworkTCP:
fn = procGetExtendedTcpTable.Addr()
class = tcpTablePidConn
case N.NetworkUDP:
fn = procGetExtendedUdpTable.Addr()
class = udpTablePid
default:
return "", os.ErrInvalid
}
buf, err := getTransportTable(fn, family, class)
if err != nil {
return "", err
}
s := newSearcher(family == windows.AF_INET, network == N.NetworkTCP)
pid, err := s.Search(buf, ip, uint16(srcPort))
if err != nil {
return "", err
}
return getExecPathFromPID(pid)
}
type searcher struct {
itemSize int
port int
ip int
ipSize int
pid int
tcpState int
}
func (s *searcher) Search(b []byte, ip netip.Addr, port uint16) (uint32, error) {
n := int(readNativeUint32(b[:4]))
itemSize := s.itemSize
for i := 0; i < n; i++ {
row := b[4+itemSize*i : 4+itemSize*(i+1)]
if s.tcpState >= 0 {
tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4])
// MIB_TCP_STATE_ESTAB, only check established connections for TCP
if tcpState != 5 {
continue
}
}
// according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian.
// this field can be illustrated as follows depends on different machine endianess:
// little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB)
// big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB)
// so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32
srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4])))
if srcPort != port {
continue
}
srcIP, _ := netip.AddrFromSlice(row[s.ip : s.ip+s.ipSize])
// windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto
if ip != srcIP && (!srcIP.IsUnspecified() || s.tcpState != -1) {
continue
}
pid := readNativeUint32(row[s.pid : s.pid+4])
return pid, nil
}
return 0, ErrNotFound
}
func newSearcher(isV4, isTCP bool) *searcher {
var itemSize, port, ip, ipSize, pid int
tcpState := -1
switch {
case isV4 && isTCP:
// struct MIB_TCPROW_OWNER_PID
itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0
case isV4 && !isTCP:
// struct MIB_UDPROW_OWNER_PID
itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8
case !isV4 && isTCP:
// struct MIB_TCP6ROW_OWNER_PID
itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48
case !isV4 && !isTCP:
// struct MIB_UDP6ROW_OWNER_PID
itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24
}
return &searcher{
itemSize: itemSize,
port: port,
ip: ip,
ipSize: ipSize,
pid: pid,
tcpState: tcpState,
}
}
func getTransportTable(fn uintptr, family int, class int) ([]byte, error) {
for size, buf := uint32(8), make([]byte, 8); ; {
ptr := unsafe.Pointer(&buf[0])
err, _, _ := syscall.SyscallN(fn, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0)
switch err {
case 0:
return buf, nil
case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER):
buf = make([]byte, size)
default:
return nil, fmt.Errorf("syscall error: %d", err)
}
}
}
func readNativeUint32(b []byte) uint32 {
return *(*uint32)(unsafe.Pointer(&b[0]))
}
func getExecPathFromPID(pid uint32) (string, error) {
// kernel process starts with a colon in order to distinguish with normal processes
switch pid { switch pid {
case 0: case 0:
// reserved pid for system idle process
return ":System Idle Process", nil return ":System Idle Process", nil
case 4: case 4:
// reserved pid for windows kernel image
return ":System", nil return ":System", nil
} }
handle, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid) h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
if err != nil { if err != nil {
return "", err return "", err
} }
defer windows.CloseHandle(handle) defer windows.CloseHandle(h)
size := uint32(syscall.MAX_LONG_PATH)
buf := make([]uint16, syscall.MAX_LONG_PATH) buf := make([]uint16, syscall.MAX_LONG_PATH)
err = windows.QueryFullProcessImageName(handle, 0, &buf[0], &size) size := uint32(len(buf))
if err != nil { r1, _, err := syscall.SyscallN(
procQueryFullProcessImageNameW.Addr(),
uintptr(h),
uintptr(0),
uintptr(unsafe.Pointer(&buf[0])),
uintptr(unsafe.Pointer(&size)),
)
if r1 == 0 {
return "", err return "", err
} }
return windows.UTF16ToString(buf[:size]), nil return syscall.UTF16ToString(buf[:size]), nil
} }

View File

@ -12,7 +12,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func TProxy(fd uintptr, isIPv6 bool, isUDP bool) error { func TProxy(fd uintptr, isIPv6 bool) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
if err == nil { if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1) err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
@ -20,13 +20,11 @@ func TProxy(fd uintptr, isIPv6 bool, isUDP bool) error {
if err == nil && isIPv6 { if err == nil && isIPv6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1) err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
} }
if isUDP { if err == nil {
if err == nil { err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1) }
} if err == nil && isIPv6 {
if err == nil && isIPv6 { err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
}
} }
return err return err
} }

Some files were not shown because too many files have changed in this diff Show More