feat(proxy): implement raw TCP/TLS write with HTTP CONNECT tunnel

Rewrite hyper_client with a two-tier strategy for header case preservation:

Primary path (raw write):
- Peek raw TCP bytes in server.rs to capture OriginalHeaderCases before
  hyper lowercases them
- Build raw HTTP/1.1 request bytes with exact original header name casing
- Write directly to TLS stream, then use WriteFilter to let hyper parse
  the response while discarding its duplicate request writes
- Support HTTP CONNECT tunneling through upstream proxies, so header case
  is preserved even when a proxy (Clash, V2Ray) is configured

Fallback path (hyper-util Client):
- Used when OriginalHeaderCases is empty or raw write fails
- Configured with title_case_headers(true) for best-effort casing

TLS improvements:
- Load native system certificates alongside webpki roots so proxy MITM
  CAs (installed in system keychain) are trusted through CONNECT tunnels

Key types added:
- OriginalHeaderCases: maps lowercase name → original wire-casing bytes
- WriteFilter<S>: AsyncRead+AsyncWrite wrapper that discards writes
- connect_via_proxy(): HTTP CONNECT tunnel establishment
- ExtensionDebugMarker: diagnostic marker for extension chain debugging
This commit is contained in:
YoVinchen
2026-03-28 12:17:07 +08:00
parent f4e960253e
commit 3e4c87278f
+418 -19
View File
@@ -3,6 +3,7 @@
//! Uses hyper directly (instead of reqwest) to support:
//! - `preserve_header_case(true)` — keeps original header name casing
//! - Header order preservation via `HeaderCaseMap` extension transfer
//! - `title_case_headers(true)` — fallback if HeaderCaseMap is absent
//!
//! Falls back to reqwest when an upstream proxy (HTTP/SOCKS5) is configured,
//! since hyper-util's legacy client doesn't natively support proxy tunneling.
@@ -15,6 +16,60 @@ use hyper_rustls::HttpsConnectorBuilder;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use std::sync::OnceLock;
/// Debug marker inserted into extensions to verify they survive the chain.
/// If this marker is found in hyper_client but HeaderCaseMap is not used,
/// the issue is in hyper's encoder, not in extension passing.
///
/// Purely diagnostic: a zero-sized type that carries no data — its mere
/// presence in `http::Extensions` is the signal.
#[derive(Clone, Debug)]
pub(crate) struct ExtensionDebugMarker;
/// Our own header case map: maps lowercase header name → original wire-casing bytes.
///
/// This is a backup mechanism independent of hyper's internal `HeaderCaseMap` (which is
/// `pub(crate)` and cannot be directly inspected or constructed from outside hyper).
///
/// Populated in `server.rs` by peeking at raw TCP bytes before hyper parses them.
/// Used in `send_request` to manually write headers with original casing when hyper's
/// own mechanism fails.
#[derive(Clone, Debug, Default)]
pub(crate) struct OriginalHeaderCases {
    /// Ordered list of (lowercase_name, original_wire_bytes) pairs.
    /// Multiple entries with the same name are allowed (for repeated headers).
    /// Invariant: the first element of each pair is ASCII-lowercased at
    /// construction (`from_raw_bytes`); lookups rely on this.
    pub cases: Vec<(String, Vec<u8>)>,
}
impl OriginalHeaderCases {
    /// Parse raw HTTP request bytes (from TcpStream::peek) to extract original header casings.
    ///
    /// Partial parses are fine — we only want whatever header names are visible in the
    /// peeked prefix. A hard parse error simply yields an empty map, in which case the
    /// caller falls back to the hyper-util client path.
    pub fn from_raw_bytes(buf: &[u8]) -> Self {
        let mut headers_buf = [httparse::EMPTY_HEADER; 128];
        let mut req = httparse::Request::new(&mut headers_buf);
        // We don't care if parsing is partial — we just want the header names we can get
        let _ = req.parse(buf);
        // Pre-size: after parse(), req.headers is at most the parsed header count.
        let mut cases = Vec::with_capacity(req.headers.len());
        for header in req.headers.iter() {
            // An empty name marks an unfilled slot; nothing useful past it.
            if header.name.is_empty() {
                break;
            }
            cases.push((
                header.name.to_ascii_lowercase(),
                header.name.as_bytes().to_vec(),
            ));
        }
        Self { cases }
    }

    /// Look up the original casing for a header name (case-insensitive).
    /// Returns an iterator over all original casings for that name, in the
    /// order they appeared on the wire.
    pub fn get_all<'a>(&'a self, name: &'a str) -> impl Iterator<Item = &'a [u8]> + 'a {
        // Stored keys are ASCII-lowercased at construction, so a case-insensitive
        // compare is equivalent to the old `*k == name.to_ascii_lowercase()` —
        // without allocating a lowercased copy of `name` on every call.
        self.cases
            .iter()
            .filter(move |(k, _)| k.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_slice())
    }
}
type HyperClient = Client<
hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
http_body_util::Full<Bytes>,
@@ -32,6 +87,7 @@ fn global_hyper_client() -> &'static HyperClient {
Client::builder(TokioExecutor::new())
.http1_preserve_header_case(true)
.http1_title_case_headers(true)
.build(connector)
})
}
@@ -129,11 +185,21 @@ impl ProxyResponse {
}
}
/// Send an HTTP request via the global hyper client (with header-case preservation).
/// Send an HTTP request with header-case preservation.
///
/// `original_extensions` should carry the `HeaderCaseMap` populated by the
/// server-side hyper parser (via `preserve_header_case(true)`).
/// The hyper client will read it back and serialise headers with the original casing.
/// Uses a two-tier strategy:
/// 1. Primary: raw HTTP/1.1 write via TLS stream with exact original header casing
/// (from `OriginalHeaderCases` captured by peek in server.rs), then hand off to
/// hyper for response parsing.
/// 2. Fallback: hyper-util Client with `title_case_headers(true)` when raw write
/// isn't feasible (e.g., missing original cases).
///
/// The caller is expected to include `Host` in the supplied `headers` at the
/// correct position.
///
/// `proxy_url`: optional upstream HTTP proxy URL (e.g. `http://127.0.0.1:7890`).
/// When set, the raw write path uses HTTP CONNECT tunneling through the proxy,
/// so header-case preservation works even when an upstream proxy is configured.
pub async fn send_request(
uri: http::Uri,
method: http::Method,
@@ -141,35 +207,368 @@ pub async fn send_request(
original_extensions: http::Extensions,
body: Vec<u8>,
timeout: std::time::Duration,
proxy_url: Option<&str>,
) -> Result<ProxyResponse, ProxyError> {
// Extract our own OriginalHeaderCases if available
let original_cases = original_extensions.get::<OriginalHeaderCases>().cloned();
let has_cases = original_cases
.as_ref()
.map(|c| !c.cases.is_empty())
.unwrap_or(false);
log::debug!(
"[HyperClient] Sending request: uri={uri}, header_count={}, \
has_host={}, has_original_cases={has_cases}, proxy={:?}",
headers.len(),
headers.contains_key(http::header::HOST),
proxy_url,
);
if has_cases {
// Primary path: use raw write + hyper handshake for exact header casing
let result = tokio::time::timeout(
timeout,
send_raw_request(
&uri,
&method,
&headers,
original_cases.as_ref().unwrap(),
&body,
proxy_url,
),
)
.await
.map_err(|_| ProxyError::Timeout(format!("请求超时: {}s", timeout.as_secs())))?;
match result {
Ok(resp) => return Ok(resp),
Err(e) => {
log::warn!("[HyperClient] Raw write failed, falling back to hyper-util: {e}");
// Fall through to hyper-util Client
}
}
}
// Fallback: hyper-util Client (title-case headers, no proxy support)
let mut req = http::Request::builder()
.method(method)
.uri(&uri)
.body(http_body_util::Full::new(Bytes::from(body)))
.map_err(|e| ProxyError::ForwardFailed(format!("Failed to build request: {e}")))?;
// Set headers (order is preserved by http::HeaderMap insertion order)
*req.headers_mut() = headers;
// Transfer extensions from the incoming request — this carries the internal
// `HeaderCaseMap` that tells the hyper client how to case each header name.
// Debug: check extension count before transfer
log::debug!("[HyperClient] Transferring extensions to outgoing request (uri={uri})");
*req.extensions_mut() = original_extensions;
let client = global_hyper_client();
let resp = tokio::time::timeout(timeout, client.request(req))
.await
.map_err(|_| ProxyError::Timeout(format!("请求超时: {}s", timeout.as_secs())))?
.map_err(|e| {
let msg = e.to_string();
if msg.contains("connect") {
ProxyError::ForwardFailed(format!("连接失败: {e}"))
} else {
ProxyError::ForwardFailed(e.to_string())
}
})?;
.map_err(|e| ProxyError::ForwardFailed(format!("上游请求失败: {e}")))?;
Ok(ProxyResponse::Hyper(resp))
}
/// Send request via raw TCP/TLS with exact original header casing.
///
/// When `proxy_url` is provided, establishes an HTTP CONNECT tunnel through
/// the proxy first, then performs TLS + raw write through the tunnel.
/// This preserves header casing even when an upstream proxy is configured.
async fn send_raw_request(
    uri: &http::Uri,
    method: &http::Method,
    headers: &http::HeaderMap,
    original_cases: &OriginalHeaderCases,
    body: &[u8],
    proxy_url: Option<&str>,
) -> Result<ProxyResponse, ProxyError> {
    let scheme = uri.scheme_str().unwrap_or("https");
    let host = uri
        .host()
        .ok_or_else(|| ProxyError::ForwardFailed("URI has no host".into()))?;
    let port = uri
        .port_u16()
        .unwrap_or(if scheme == "https" { 443 } else { 80 });
    let path_and_query = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");

    // Build raw HTTP request bytes with the original header casing.
    let raw = build_raw_request(method, path_and_query, headers, original_cases, body);

    // Establish TCP connection — either direct or through HTTP CONNECT proxy
    let tcp = if let Some(proxy) = proxy_url {
        connect_via_proxy(proxy, host, port).await?
    } else {
        tokio::net::TcpStream::connect((host, port))
            .await
            .map_err(|e| ProxyError::ForwardFailed(format!("TCP connect failed: {e}")))?
    };

    if scheme == "https" {
        // Wrap the (possibly tunneled) TCP stream in TLS before the raw write.
        let tls_connector = global_tls_connector();
        let server_name = rustls::pki_types::ServerName::try_from(host.to_string())
            .map_err(|e| ProxyError::ForwardFailed(format!("Invalid server name: {e}")))?;
        let tls_stream = tls_connector
            .connect(server_name, tcp)
            .await
            .map_err(|e| ProxyError::ForwardFailed(format!("TLS handshake failed: {e}")))?;
        write_then_parse(tls_stream, &raw, method.clone()).await
    } else {
        write_then_parse(tcp, &raw, method.clone()).await
    }
}

/// Write the pre-built raw request to `stream`, then hand the stream to hyper
/// (wrapped in `WriteFilter` so hyper's own request bytes are discarded) to
/// parse the response. Shared tail of the HTTPS and plain-HTTP paths.
async fn write_then_parse<S>(
    mut stream: S,
    raw: &[u8],
    method: http::Method,
) -> Result<ProxyResponse, ProxyError>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    use tokio::io::AsyncWriteExt;
    stream
        .write_all(raw)
        .await
        .map_err(|e| ProxyError::ForwardFailed(format!("Write failed: {e}")))?;
    do_hyper_response(WriteFilter::new(stream), method).await
}
/// Establish a TCP connection through an HTTP CONNECT proxy tunnel.
///
/// 1. Connect TCP to the proxy server
/// 2. Send `CONNECT host:port HTTP/1.1`
/// 3. Read the proxy's 200 response
/// 4. Return the tunneled TCP stream (ready for TLS handshake + raw write)
///
/// NOTE(review): proxy authentication (userinfo in the URL) is not supported;
/// credentials, if present, are silently ignored — confirm this is acceptable.
async fn connect_via_proxy(
    proxy_url: &str,
    target_host: &str,
    target_port: u16,
) -> Result<tokio::net::TcpStream, ProxyError> {
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

    let parsed = url::Url::parse(proxy_url)
        .map_err(|e| ProxyError::ForwardFailed(format!("Invalid proxy URL: {e}")))?;
    let proxy_host = parsed
        .host_str()
        .ok_or_else(|| ProxyError::ForwardFailed("Proxy URL has no host".into()))?;
    let proxy_port = parsed
        .port()
        .unwrap_or(if parsed.scheme() == "https" { 443 } else { 80 });

    // Connect to the proxy
    let mut tcp = tokio::net::TcpStream::connect((proxy_host, proxy_port))
        .await
        .map_err(|e| ProxyError::ForwardFailed(format!("Proxy TCP connect failed: {e}")))?;

    // Send CONNECT request
    let connect_req = format!(
        "CONNECT {target_host}:{target_port} HTTP/1.1\r\n\
         Host: {target_host}:{target_port}\r\n\
         \r\n"
    );
    tcp.write_all(connect_req.as_bytes())
        .await
        .map_err(|e| ProxyError::ForwardFailed(format!("CONNECT write failed: {e}")))?;

    // Read the proxy's response status line
    let mut reader = BufReader::new(&mut tcp);
    let mut status_line = String::new();
    reader
        .read_line(&mut status_line)
        .await
        .map_err(|e| ProxyError::ForwardFailed(format!("CONNECT read failed: {e}")))?;

    // Require status code 200 as the second whitespace-separated token of the
    // status line (any HTTP version). The previous `contains(" 200 ")` check
    // wrongly rejected a valid minimal "HTTP/1.1 200\r\n" line whose empty
    // reason phrase leaves no trailing space after the code.
    if status_line.split_whitespace().nth(1) != Some("200") {
        return Err(ProxyError::ForwardFailed(format!(
            "Proxy CONNECT rejected: {}",
            status_line.trim()
        )));
    }

    // Drain remaining response headers (until empty line). `read_line` returns
    // Ok(0) at EOF leaving `line` empty, so the loop also terminates if the
    // proxy closes the connection prematurely.
    loop {
        let mut line = String::new();
        reader
            .read_line(&mut line)
            .await
            .map_err(|e| ProxyError::ForwardFailed(format!("CONNECT header read: {e}")))?;
        if line.trim().is_empty() {
            break;
        }
    }

    // BufReader might have buffered data; drop it to get raw tcp back.
    // Since CONNECT response is headers-only (no body), and the peer will not
    // speak until we send through the tunnel, the buffer should be empty here.
    drop(reader);

    log::debug!(
        "[HyperClient] CONNECT tunnel established via {proxy_host}:{proxy_port} -> {target_host}:{target_port}"
    );
    Ok(tcp)
}
/// Lazily-initialized TLS connector for raw connections.
///
/// The trust store combines the bundled webpki (Mozilla) roots with whatever
/// the operating system's certificate store provides, so proxy MITM CAs
/// (e.g. Clash, mitmproxy) installed in the system keychain are trusted
/// through the CONNECT tunnel.
fn global_tls_connector() -> &'static tokio_rustls::TlsConnector {
    static CONNECTOR: OnceLock<tokio_rustls::TlsConnector> = OnceLock::new();
    CONNECTOR.get_or_init(|| {
        // Start from the bundled Mozilla/webpki baseline…
        let mut roots = rustls::RootCertStore::empty();
        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        // …then layer on the OS store (includes user-installed proxy CAs);
        // unparsable certificates are skipped rather than failing startup.
        let native = rustls_native_certs::load_native_certs();
        let (added, _errors) = roots.add_parsable_certificates(native.certs);
        log::debug!("[HyperClient] TLS root store: webpki + {added} native certs");
        let cfg = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();
        tokio_rustls::TlsConnector::from(std::sync::Arc::new(cfg))
    })
}
/// Build raw HTTP/1.1 request bytes with original header casing.
///
/// Header order follows `headers` (`http::HeaderMap` preserves insertion order);
/// each name is written with the wire casing recorded in `original_cases`,
/// falling back to the lowercase name for headers added by the proxy itself.
/// The caller is responsible for including `Host` in `headers`.
fn build_raw_request(
    method: &http::Method,
    path_and_query: &str,
    headers: &http::HeaderMap,
    original_cases: &OriginalHeaderCases,
    body: &[u8],
) -> Vec<u8> {
    let mut raw = Vec::with_capacity(4096 + body.len());

    // Request line
    raw.extend_from_slice(method.as_str().as_bytes());
    raw.extend_from_slice(b" ");
    raw.extend_from_slice(path_and_query.as_bytes());
    raw.extend_from_slice(b" HTTP/1.1\r\n");

    // Headers with original casing
    for name in headers.keys() {
        let name_str = name.as_str();
        // Repeated headers consume original casings in wire order.
        let mut case_iter = original_cases.get_all(name_str);
        for value in headers.get_all(name) {
            if let Some(orig_name_bytes) = case_iter.next() {
                raw.extend_from_slice(orig_name_bytes);
            } else {
                // Header not in original request (added by proxy) — use lowercase
                raw.extend_from_slice(name_str.as_bytes());
            }
            raw.extend_from_slice(b": ");
            raw.extend_from_slice(value.as_bytes());
            raw.extend_from_slice(b"\r\n");
        }
    }

    // Add Content-Length only when no explicit framing header is present.
    // RFC 9112 §6.2 forbids sending both Content-Length and Transfer-Encoding
    // in the same message, so skip it when the latter was forwarded.
    if !headers.contains_key(http::header::CONTENT_LENGTH)
        && !headers.contains_key(http::header::TRANSFER_ENCODING)
    {
        raw.extend_from_slice(b"Content-Length: ");
        raw.extend_from_slice(body.len().to_string().as_bytes());
        raw.extend_from_slice(b"\r\n");
    }

    // End of headers + body
    raw.extend_from_slice(b"\r\n");
    raw.extend_from_slice(body);
    raw
}
/// Use hyper's low-level client to parse the response on a stream where we've
/// already written the request.
///
/// `WriteFilter` discards any writes from hyper (it would try to send its own
/// request encoding), while passing reads through transparently.
///
/// `method` should match the method of the raw request that was already
/// written — presumably so hyper frames the response body correctly
/// (e.g. HEAD) — TODO confirm against hyper's http1 response semantics.
async fn do_hyper_response<S>(
    stream: WriteFilter<S>,
    method: http::Method,
) -> Result<ProxyResponse, ProxyError>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    // Adapt the tokio stream to hyper's IO traits.
    let io = hyper_util::rt::TokioIo::new(stream);
    // preserve_header_case(true) keeps the upstream's response header casing too.
    let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
        .preserve_header_case(true)
        .handshake::<_, http_body_util::Full<Bytes>>(io)
        .await
        .map_err(|e| ProxyError::ForwardFailed(format!("Handshake failed: {e}")))?;
    // Spawn the connection driver (reads responses from the stream).
    // It must run concurrently with send_request below, or nothing is polled.
    tokio::spawn(async move {
        if let Err(e) = conn.await {
            log::debug!("[HyperClient] raw conn driver error: {e}");
        }
    });
    // Send a dummy request through hyper — hyper will encode this and try to write it,
    // but WriteFilter discards all writes. Hyper will then read the response from the stream.
    let dummy_req = http::Request::builder()
        .method(method)
        .uri("/")
        .body(http_body_util::Full::new(Bytes::new()))
        .map_err(|e| ProxyError::ForwardFailed(format!("Build dummy request: {e}")))?;
    let resp = sender
        .send_request(dummy_req)
        .await
        .map_err(|e| ProxyError::ForwardFailed(format!("Response parse failed: {e}")))?;
    Ok(ProxyResponse::Hyper(resp))
}
/// A stream wrapper that discards all writes but passes reads through.
///
/// This lets hyper's connection driver think it sent a request (its encoded bytes
/// go to /dev/null), while correctly parsing the response that the upstream server
/// sends in reply to our raw-written request.
struct WriteFilter<S> {
    // The wrapped transport (TCP or TLS stream); only its read side is used.
    inner: S,
}

impl<S> WriteFilter<S> {
    /// Wrap `inner`, silencing its write side.
    fn new(inner: S) -> Self {
        Self { inner }
    }
}
impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for WriteFilter<S> {
    /// Reads are fully transparent: delegate straight to the wrapped stream.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // WriteFilter<S> is Unpin (S: Unpin), so we can re-pin the inner stream.
        std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}
// Note: no `AsyncWrite` bound on `S` — every write operation is a no-op,
// so the wrapped stream's write side is never touched.
impl<S: Unpin> tokio::io::AsyncWrite for WriteFilter<S> {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        // Discard all writes — pretend they succeeded
        // (report the whole buffer as consumed so hyper never retries).
        std::task::Poll::Ready(Ok(buf.len()))
    }
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // Nothing is ever buffered, so flushing is trivially complete.
        std::task::Poll::Ready(Ok(()))
    }
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // No-op: the underlying stream is closed by dropping it, not via
        // shutdown. NOTE(review): this means no TLS close_notify is sent —
        // confirm upstream servers tolerate the abrupt close.
        std::task::Poll::Ready(Ok(()))
    }
}