from __future__ import annotations

import base64
import json
import re
from typing import Any, Iterable
from urllib.parse import parse_qs, unquote, urlparse

import httpx
import yaml

from app.config import get_settings
from app.models import FetchResult, ProviderDocument, SourceConfig, SourceSnapshot
from app.services.cache import TTLCache
from app.services.headers import parse_subscription_userinfo

# Per-process TTL caches, layered: raw upstream responses, parsed provider
# documents, and fully-assembled snapshots (document + headers + quota).
_fetch_cache: TTLCache[FetchResult] = TTLCache()
_provider_cache: TTLCache[ProviderDocument] = TTLCache()
_snapshot_cache: TTLCache[SourceSnapshot] = TTLCache()


async def fetch_source(name: str, source: SourceConfig) -> FetchResult:
    """Download the raw subscription text for *source*, cached by source name."""
    settings = get_settings()
    ttl = source.cache_ttl_seconds or settings.cache_ttl_seconds
    hit = _fetch_cache.get(name)
    if hit is not None:
        return hit
    # Source-specific headers override the default User-Agent.
    request_headers = {"User-Agent": settings.default_user_agent, **source.headers}
    async with httpx.AsyncClient(
        timeout=settings.request_timeout_seconds, follow_redirects=True
    ) as client:
        response = await client.get(source.url, headers=request_headers)
        response.raise_for_status()
        result = FetchResult(text=response.text, headers=dict(response.headers))
    _fetch_cache.set(name, result, ttl)
    return result


async def build_provider_document(name: str, source: SourceConfig) -> ProviderDocument:
    """Fetch, parse, and transform *source* into a cached provider document."""
    settings = get_settings()
    ttl = source.cache_ttl_seconds or settings.cache_ttl_seconds
    cache_key = f"provider:{name}"
    hit = _provider_cache.get(cache_key)
    if hit is not None:
        return hit
    fetched = await fetch_source(name, source)
    parsed = parse_source_proxies(fetched.text, source.kind)
    parsed = transform_proxies(parsed, source, settings.max_proxy_name_length)
    document = ProviderDocument(proxies=parsed)
    _provider_cache.set(cache_key, document, ttl)
    return document


async def build_source_snapshot(name: str, source: SourceConfig) -> SourceSnapshot:
    """Assemble a cached snapshot: provider document plus upstream headers/quota."""
    settings = get_settings()
    ttl = source.cache_ttl_seconds or settings.cache_ttl_seconds
    cache_key = f"snapshot:{name}"
    hit = _snapshot_cache.get(cache_key)
    if hit is not None:
        return hit
    # fetch_source is cached, so the fetch inside build_provider_document
    # reuses the same response rather than hitting upstream twice.
    fetched = await fetch_source(name, source)
    document = await build_provider_document(name, source)
    snapshot = SourceSnapshot(
        name=name,
        display_name=source.display_name or name,
        document=document,
        headers=fetched.headers,
        quota=parse_subscription_userinfo(fetched.headers),
    )
    _snapshot_cache.set(cache_key, snapshot, ttl)
    return snapshot


async def build_source_snapshots(
    source_items: Iterable[tuple[str, SourceConfig]],
) -> list[SourceSnapshot]:
    """Build a snapshot for every ``(name, source)`` pair, sequentially."""
    return [await build_source_snapshot(name, source) for name, source in source_items]


async def build_merged_provider_document(
    source_items: Iterable[tuple[str, SourceConfig]],
) -> ProviderDocument:
    """Merge all sources into one document, de-duplicating proxy names with ' #N'."""
    snapshots = await build_source_snapshots(source_items)
    merged: list[dict[str, Any]] = []
    taken: set[str] = set()
    for snapshot in snapshots:
        for proxy in snapshot.document.proxies:
            entry = dict(proxy)
            label = str(entry.get("name", "")).strip()
            if not label:
                continue
            base = label
            suffix = 2
            while label in taken:
                label = f"{base} #{suffix}"
                suffix += 1
            entry["name"] = label
            taken.add(label)
            merged.append(entry)
    return ProviderDocument(proxies=merged)


async def get_first_quota(source_items: Iterable[tuple[str, SourceConfig]]):
    """Return the quota of the first configured source, or ``None`` when empty."""
    pairs = list(source_items)
    if not pairs:
        return None
    first_name, first_source = pairs[0]
    snapshot = await build_source_snapshot(first_name, first_source)
    return snapshot.quota


def parse_clash_yaml_proxies(text: str) -> list[dict[str, Any]]:
    """Extract the ``proxies`` list from a Clash-format YAML document.

    Raises ValueError when the document is not a mapping or lacks a list
    ``proxies`` field; entries missing a name or type are silently dropped.
    """
    data = yaml.safe_load(text)
    if not isinstance(data, dict):
        raise ValueError("Upstream YAML must be a mapping with a top-level 'proxies' field")
    proxies = data.get("proxies")
    if not isinstance(proxies, list):
        raise ValueError("Upstream YAML must contain a list field named 'proxies'")
    return [
        item
        for item in proxies
        if isinstance(item, dict) and item.get("name") and item.get("type")
    ]
def parse_source_proxies(text: str, source_kind: str) -> list[dict[str, Any]]:
    """Parse upstream subscription *text* with parsers chosen by *source_kind*.

    Tries each candidate parser in order and returns the first non-empty
    result. Raises ValueError carrying every parser's failure message when
    all candidates fail (or the kind is unknown).
    """
    parsers: dict[str, list] = {
        "auto": [parse_clash_yaml_proxies, parse_base64_uri_proxies, parse_uri_text_proxies],
        "clash_yaml": [parse_clash_yaml_proxies],
        "base64_uri": [parse_base64_uri_proxies, parse_uri_text_proxies],
        "uri": [parse_uri_text_proxies, parse_base64_uri_proxies],
    }
    errors: list[str] = []
    for parser in parsers.get(source_kind, []):
        try:
            proxies = parser(text)
            if proxies:
                return proxies
        except Exception as exc:  # noqa: BLE001 - collect every parser failure for the message
            errors.append(f"{parser.__name__}: {exc}")
    detail = "; ".join(errors) if errors else f"unsupported source kind: {source_kind}"
    raise ValueError(f"Failed to parse upstream subscription: {detail}")


def parse_base64_uri_proxies(text: str) -> list[dict[str, Any]]:
    """Decode a base64-wrapped subscription, then parse it as proxy URIs."""
    decoded = decode_base64_subscription(text)
    return parse_uri_text_proxies(decoded)


def parse_uri_text_proxies(text: str) -> list[dict[str, Any]]:
    """Parse newline-separated proxy URIs (anytls/trojan/vless/ss/vmess).

    Blank lines and ``#`` comments are skipped; unknown schemes are collected
    and reported only if nothing at all could be parsed.
    """
    scheme_parsers = {
        "anytls": parse_anytls_uri,
        "trojan": parse_trojan_uri,
        "vless": parse_vless_uri,
        "ss": parse_ss_uri,
        "vmess": parse_vmess_uri,
    }
    proxies: list[dict[str, Any]] = []
    unsupported: set[str] = set()
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "://" not in line:
            continue
        scheme = line.split("://", 1)[0].lower()
        handler = scheme_parsers.get(scheme)
        if handler is None:
            unsupported.add(scheme)
            continue
        proxies.append(handler(line))
    if not proxies:
        detail = (
            f"unsupported URI schemes: {', '.join(sorted(unsupported))}"
            if unsupported
            else "no proxy URIs found"
        )
        raise ValueError(f"Base64 subscription parsing failed: {detail}")
    return proxies


def decode_base64_subscription(text: str) -> str:
    """Decode a (possibly URL-safe) base64 subscription payload into UTF-8 text.

    Whitespace is stripped and missing ``=`` padding restored. Payloads using
    the URL-safe alphabet are decoded with ``altchars=b"-_"``: previously,
    ``b64decode(..., validate=False)`` silently *discarded* ``-``/``_``
    characters, corrupting URL-safe payloads instead of decoding them.

    Raises ValueError for empty input or undecodable content.
    """
    compact = "".join(text.strip().split())
    if not compact:
        raise ValueError("Base64 subscription is empty")
    padded = compact + ("=" * (-len(compact) % 4))
    # URL-safe alphabet is common in subscription links; pick it when seen.
    altchars = b"-_" if ("-" in compact or "_" in compact) else None
    try:
        return base64.b64decode(padded, altchars=altchars, validate=False).decode("utf-8")
    except Exception as exc:  # noqa: BLE001
        raise ValueError("Upstream content is not valid base64 subscription text") from exc


def parse_anytls_uri(uri: str) -> dict[str, Any]:
    """Convert an ``anytls://`` URI into a Clash-style proxy mapping.

    Raises ValueError when server, port, or password is missing.
    """
    parsed = urlparse(uri)
    server = parsed.hostname
    port = parsed.port
    password = unquote(parsed.username or "")
    if not server or not port or not password:
        raise ValueError("Invalid anytls URI: missing server, port, or password")
    params = parse_qs(parsed.query, keep_blank_values=False)
    proxy: dict[str, Any] = {
        "name": unquote(parsed.fragment or f"{server}:{port}"),
        "type": "anytls",
        "server": server,
        "port": port,
        "password": password,
        "udp": True,
    }
    sni = _first_param(params, "sni", "serverName", "servername", "peer")
    if sni:
        proxy["sni"] = sni
    fingerprint = _first_param(params, "fp", "fingerprint", "client-fingerprint", "clientFingerprint")
    if fingerprint:
        proxy["client-fingerprint"] = fingerprint
    insecure = _first_param(params, "insecure", "allowInsecure", "skip-cert-verify")
    if insecure is not None:
        proxy["skip-cert-verify"] = insecure.lower() in {"1", "true", "yes"}
    udp = _first_param(params, "udp")
    if udp is not None:
        proxy["udp"] = udp.lower() in {"1", "true", "yes"}
    alpn = _first_param(params, "alpn")
    if alpn:
        proxy["alpn"] = [item.strip() for item in alpn.split(",") if item.strip()]
    # Idle-session tuning knobs: accept both kebab-case and camelCase spellings.
    for uri_key, proxy_key in (
        ("idle-session-check-interval", "idle-session-check-interval"),
        ("idleSessionCheckInterval", "idle-session-check-interval"),
        ("idle-session-timeout", "idle-session-timeout"),
        ("idleSessionTimeout", "idle-session-timeout"),
        ("min-idle-session", "min-idle-session"),
        ("minIdleSession", "min-idle-session"),
    ):
        value = _first_param(params, uri_key)
        if value is None:
            continue
        try:
            proxy[proxy_key] = int(value)
        except ValueError:
            continue  # ignore non-numeric values rather than failing the whole URI
    return proxy


def parse_trojan_uri(uri: str) -> dict[str, Any]:
    """Convert a ``trojan://`` URI into a Clash-style proxy mapping.

    Raises ValueError when server, port, or password is missing.
    """
    parsed = urlparse(uri)
    server = parsed.hostname
    port = parsed.port
    password = unquote(parsed.username or "")
    if not server or not port or not password:
        raise ValueError("Invalid trojan URI: missing server, port, or password")
    params = parse_qs(parsed.query, keep_blank_values=False)
    proxy: dict[str, Any] = {
        "name": unquote(parsed.fragment or f"{server}:{port}"),
        "type": "trojan",
        "server": server,
        "port": port,
        "password": password,
        "udp": True,
    }
    _apply_tls_like_params(proxy, params, default_sni=server)
    network = _first_param(params, "type", "network")
    if network:
        proxy["network"] = network
    ws_path = _first_param(params, "path")
    if ws_path and proxy.get("network") == "ws":
        # Build ws-opts fully before attaching so a stray 'host' param can
        # never trigger a KeyError on a missing "ws-opts" entry.
        ws_opts: dict[str, Any] = {"path": ws_path}
        host = _first_param(params, "host", "Host")
        if host:
            ws_opts["headers"] = {"Host": host}
        proxy["ws-opts"] = ws_opts
    return proxy


def parse_vless_uri(uri: str) -> dict[str, Any]:
    """Convert a ``vless://`` URI (TLS or REALITY security) into a proxy mapping.

    Raises ValueError when server, port, or uuid is missing.
    """
    parsed = urlparse(uri)
    server = parsed.hostname
    port = parsed.port
    uuid = unquote(parsed.username or "")
    if not server or not port or not uuid:
        raise ValueError("Invalid vless URI: missing server, port, or uuid")
    params = parse_qs(parsed.query, keep_blank_values=False)
    proxy: dict[str, Any] = {
        "name": unquote(parsed.fragment or f"{server}:{port}"),
        "type": "vless",
        "server": server,
        "port": port,
        "uuid": uuid,
        "udp": True,
    }
    network = _first_param(params, "type", "network") or "tcp"
    proxy["network"] = network
    flow = _first_param(params, "flow")
    if flow:
        proxy["flow"] = flow
    if (_first_param(params, "security") or "").lower() == "reality":
        public_key = _first_param(params, "pbk", "public-key")
        if public_key:
            # short-id only makes sense alongside a public key; guarding here
            # avoids a KeyError when 'sid' is present but 'pbk' is not.
            reality_opts: dict[str, Any] = {"public-key": public_key}
            short_id = _first_param(params, "sid", "short-id")
            if short_id:
                reality_opts["short-id"] = short_id
            proxy["reality-opts"] = reality_opts
    else:
        _apply_tls_like_params(proxy, params, default_sni=server)
    if network == "ws":
        ws_opts: dict[str, Any] = {"path": _first_param(params, "path") or "/"}
        host = _first_param(params, "host", "Host")
        if host:
            ws_opts["headers"] = {"Host": host}
        proxy["ws-opts"] = ws_opts
    return proxy


def parse_ss_uri(uri: str) -> dict[str, Any]:
    """Convert an ``ss://`` URI into a Clash-style proxy mapping.

    Supports both the modern form (base64 or percent-encoded userinfo before
    ``@``) and the legacy form where the whole body is
    ``base64("cipher:password@server:port")``. Raises ValueError on any
    malformed section.
    """
    rest = uri[len("ss://"):]
    fragment = ""
    if "#" in rest:
        rest, fragment = rest.split("#", 1)
    query = ""
    if "?" in rest:
        rest, query = rest.split("?", 1)
    if "@" not in rest:
        # Legacy form: entire body is base64("cipher:password@server:port").
        decoded = decode_base64_subscription(rest)
        if "@" not in decoded:
            raise ValueError("Invalid ss URI: missing server section")
        userinfo, server_part = decoded.rsplit("@", 1)
    else:
        userinfo, server_part = rest.rsplit("@", 1)
        # Modern form: userinfo may be base64("cipher:password") or plain
        # percent-encoded "cipher:password".
        try:
            userinfo = decode_base64_subscription(userinfo)
        except ValueError:
            userinfo = unquote(userinfo)
    if ":" not in userinfo or ":" not in server_part:
        raise ValueError("Invalid ss URI: malformed credentials or server")
    cipher, password = userinfo.split(":", 1)
    server, port_text = server_part.rsplit(":", 1)
    try:
        port = int(port_text)
    except ValueError as exc:
        # Surface a descriptive message instead of a bare int() error.
        raise ValueError("Invalid ss URI: malformed credentials or server") from exc
    params = parse_qs(query, keep_blank_values=False)
    proxy: dict[str, Any] = {
        "name": unquote(fragment or f"{server}:{port_text}"),
        "type": "ss",
        "server": server.strip("[]"),  # drop IPv6 brackets
        "port": port,
        "cipher": cipher,
        "password": password,
        "udp": True,
    }
    plugin = _first_param(params, "plugin")
    if plugin:
        # "plugin=name;key1=val1;key2=val2" -> plugin name + options mapping.
        plugin_name, _, opts_text = plugin.partition(";")
        proxy["plugin"] = plugin_name
        plugin_opts: dict[str, Any] = {}
        for item in opts_text.split(";") if opts_text else []:
            key, sep, value = item.partition("=")
            if sep:
                plugin_opts[key] = value
        if plugin_opts:
            proxy["plugin-opts"] = plugin_opts
    return proxy


def parse_vmess_uri(uri: str) -> dict[str, Any]:
    """Convert a ``vmess://`` URI (base64-encoded JSON payload) into a proxy mapping.

    Raises ValueError for undecodable base64, invalid JSON, or a payload
    missing ``add``/``port``/``id``.
    """
    raw = uri[len("vmess://"):]
    decoded = decode_base64_subscription(raw)
    try:
        data = json.loads(decoded)
    except json.JSONDecodeError as exc:
        raise ValueError("Invalid vmess URI JSON payload") from exc
    server = str(data.get("add") or "").strip()
    try:
        port = int(str(data.get("port") or "0"))
    except ValueError as exc:
        # Non-numeric port: report it as a missing/invalid field rather than
        # leaking a bare int() error.
        raise ValueError("Invalid vmess URI: missing add, port, or id") from exc
    uuid = str(data.get("id") or "").strip()
    if not server or not port or not uuid:
        raise ValueError("Invalid vmess URI: missing add, port, or id")
    network = str(data.get("net") or "tcp").strip() or "tcp"
    proxy: dict[str, Any] = {
        "name": str(data.get("ps") or f"{server}:{port}"),
        "type": "vmess",
        "server": server,
        "port": port,
        "uuid": uuid,
        "alterId": int(str(data.get("aid") or "0")),
        "cipher": "auto",
        "udp": True,
        "network": network,
    }
    if str(data.get("tls") or "").lower() == "tls":
        proxy["tls"] = True
        sni = str(data.get("sni") or data.get("host") or "").strip()
        if sni:
            proxy["servername"] = sni
        if str(data.get("allowInsecure") or "").lower() in {"1", "true"}:
            proxy["skip-cert-verify"] = True
    if network == "ws":
        proxy["ws-opts"] = {"path": str(data.get("path") or "/")}
        host = str(data.get("host") or "").strip()
        if host:
            proxy["ws-opts"]["headers"] = {"Host": host}
    return proxy


def _apply_tls_like_params(
    proxy: dict[str, Any], params: dict[str, list[str]], *, default_sni: str | None = None
) -> None:
    """Populate TLS-related proxy fields from query params, in place.

    Sets ``tls`` when a TLS-ish security value or TLS-only param is present;
    fills sni/servername, skip-cert-verify, alpn, and client fingerprint.
    """
    security = (_first_param(params, "security") or "").lower()
    if security in {"tls", "xtls"} or any(
        key in params for key in ("sni", "peer", "allowInsecure", "insecure")
    ):
        proxy["tls"] = True
    sni = _first_param(params, "sni", "peer", "serverName", "servername") or default_sni
    if sni:
        # Clash cores variously read 'sni' or 'servername'; set both.
        proxy["sni"] = sni
        proxy["servername"] = sni
    insecure = _first_param(params, "insecure", "allowInsecure", "skip-cert-verify")
    if insecure is not None:
        proxy["skip-cert-verify"] = insecure.lower() in {"1", "true", "yes"}
    alpn = _first_param(params, "alpn")
    if alpn:
        proxy["alpn"] = [item.strip() for item in alpn.split(",") if item.strip()]
    fp = _first_param(params, "fp", "fingerprint", "client-fingerprint", "clientFingerprint")
    if fp:
        proxy["client-fingerprint"] = fp


def _first_param(params: dict[str, list[str]], *keys: str) -> str | None:
    """Return the first present, non-empty query value among *keys*, else None.

    NOTE(review): parse_qs has already percent-decoded the values, so this
    unquote is a second decoding pass; kept for compatibility with
    double-encoded subscription links — confirm before removing.
    """
    for key in keys:
        values = params.get(key)
        if values:
            return unquote(values[0])
    return None


def transform_proxies(
    proxies: list[dict[str, Any]], source: SourceConfig, max_proxy_name_length: int
) -> list[dict[str, Any]]:
    """Filter by include/exclude regex, apply prefix/suffix, truncate, de-dup names.

    Returns new dicts; input proxies are not mutated.
    """
    include = re.compile(source.include_regex) if source.include_regex else None
    exclude = re.compile(source.exclude_regex) if source.exclude_regex else None
    transformed: list[dict[str, Any]] = []
    seen: dict[str, int] = {}
    for proxy in proxies:
        name = str(proxy.get("name", "")).strip()
        if not name:
            continue
        if include and not include.search(name):
            continue
        if exclude and exclude.search(name):
            continue
        new_proxy = dict(proxy)
        new_name = f"{source.prefix}{name}{source.suffix}".strip()
        if len(new_name) > max_proxy_name_length:
            new_name = new_name[:max_proxy_name_length].rstrip()
        count = seen.get(new_name, 0) + 1
        seen[new_name] = count
        if count > 1:
            # De-dup suffix may push the name slightly past the cap; the
            # truncated base stays stable so counters keep matching.
            new_name = f"{new_name} #{count}"
        new_proxy["name"] = new_name
        transformed.append(new_proxy)
    return transformed


def dump_provider_yaml(document: ProviderDocument) -> str:
    """Serialize *document* as a Clash provider YAML string (``proxies:`` root)."""
    return yaml.safe_dump(
        {"proxies": document.proxies},
        allow_unicode=True,
        sort_keys=False,
        default_flow_style=False,
    )