From 75f3c936129f9155d26915fe4f21d9f43fa13c32 Mon Sep 17 00:00:00 2001 From: Marcos Lilljedahl Date: Thu, 30 Mar 2017 16:49:03 -0300 Subject: [PATCH] Add TCP proxy for tcp over http connections --- api.go | 17 ++----- handlers/reverseproxy.go | 104 +++++++++++++++++++++++++++++---------- haproxy/haproxy.cfg | 4 +- 3 files changed, 85 insertions(+), 40 deletions(-) diff --git a/api.go b/api.go index 92d58d7..38b4768 100644 --- a/api.go +++ b/api.go @@ -19,7 +19,6 @@ import ( "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/urfave/negroni" - "github.com/yhat/wsutil" ) func main() { @@ -58,22 +57,16 @@ func main() { corsRouter := mux.NewRouter() // Reverse proxy (needs to be the first route, to make sure it is the first thing we check) - proxyHandler := handlers.NewMultipleHostReverseProxy() - websocketProxyHandler := handlers.NewMultipleHostWebsocketReverseProxy() + //proxyHandler := handlers.NewMultipleHostReverseProxy() + //websocketProxyHandler := handlers.NewMultipleHostWebsocketReverseProxy() - proxyMultiplexer := func(rw http.ResponseWriter, r *http.Request) { - if wsutil.IsWebSocketRequest(r) { - websocketProxyHandler.ServeHTTP(rw, r) - } else { - proxyHandler.ServeHTTP(rw, r) - } - } + tcpHandler := handlers.NewTCPProxy() corsHandler := gh.CORS(gh.AllowCredentials(), gh.AllowedHeaders([]string{"x-requested-with", "content-type"}), gh.AllowedOrigins([]string{"*"})) // Specific routes - r.Host(`{subdomain:.*}{node:pwd[0-9]{1,3}_[0-9]{1,3}_[0-9]{1,3}_[0-9]{1,3}}-{port:[0-9]*}.{tld:.*}`).HandlerFunc(proxyMultiplexer) - r.Host(`{subdomain:.*}{node:pwd[0-9]{1,3}_[0-9]{1,3}_[0-9]{1,3}_[0-9]{1,3}}.{tld:.*}`).HandlerFunc(proxyMultiplexer) + r.Host(`{subdomain:.*}{node:pwd[0-9]{1,3}_[0-9]{1,3}_[0-9]{1,3}_[0-9]{1,3}}-{port:[0-9]*}.{tld:.*}`).Handler(tcpHandler) + r.Host(`{subdomain:.*}{node:pwd[0-9]{1,3}_[0-9]{1,3}_[0-9]{1,3}_[0-9]{1,3}}.{tld:.*}`).Handler(tcpHandler) r.HandleFunc("/ping", handlers.Ping).Methods("GET") corsRouter.HandleFunc("/sessions/{sessionId}", handlers.GetSession).Methods("GET") corsRouter.HandleFunc("/sessions/{sessionId}/instances", handlers.NewInstance).Methods("POST") diff --git a/handlers/reverseproxy.go b/handlers/reverseproxy.go index 35e087f..cbc5ae8 100644 --- a/handlers/reverseproxy.go +++ b/handlers/reverseproxy.go @@ -3,12 +3,11 @@ package handlers import ( "crypto/tls" "fmt" + "io" "log" "net" "net/http" - "net/http/httputil" "strings" - "time" "github.com/franela/play-with-docker/config" "github.com/gorilla/mux" @@ -41,40 +40,93 @@ func getTargetInfo(vars map[string]string, req *http.Request) (string, string) { } -func NewMultipleHostReverseProxy() *httputil.ReverseProxy { - var transport http.RoundTripper = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 0, - }).DialContext, - DisableKeepAlives: true, - MaxIdleConns: 1, - IdleConnTimeout: 100 * time.Millisecond, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, +type tcpProxy struct { + Director func(*http.Request) + ErrorLog *log.Logger + Dial func(network, addr string) (net.Conn, error) +} + +func (p *tcpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + logFunc := log.Printf + if p.ErrorLog != nil { + logFunc = p.ErrorLog.Printf } + + outreq := new(http.Request) + // shallow copying + *outreq = *r + p.Director(outreq) + host := outreq.URL.Host + + dial := p.Dial + if dial == nil { + dial = net.Dial + } + + if outreq.URL.Scheme == "wss" || outreq.URL.Scheme == "https" { + var tlsConfig *tls.Config + tlsConfig = &tls.Config{InsecureSkipVerify: true} + dial = func(network, address string) (net.Conn, error) { + return tls.Dial("tcp", host, tlsConfig) + } + } + + d, err := dial("tcp", host) + if err != nil { + http.Error(w, "Error forwarding request.", 500) + logFunc("Error dialing websocket backend %s: %v", outreq.URL, err) + return + } + // All request generated by the http package implement this interface. + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "Not a hijacker?", 500) + return + } + // Hijack() tells the http package not to do anything else with the connection. + // After, it bcomes this functions job to manage it. `nc` is of type *net.Conn. + nc, _, err := hj.Hijack() + if err != nil { + logFunc("Hijack error: %v", err) + return + } + defer nc.Close() // must close the underlying net connection after hijacking + defer d.Close() + + // write the modified incoming request to the dialed connection + err = outreq.Write(d) + if err != nil { + logFunc("Error copying request to target: %v", err) + return + } + errc := make(chan error, 2) + cp := func(dst io.Writer, src io.Reader) { + _, err := io.Copy(dst, src) + errc <- err + } + go cp(d, nc) + go cp(nc, d) + <-errc +} +func NewTCPProxy() http.Handler { director := func(req *http.Request) { v := mux.Vars(req) + node, port := getTargetInfo(v, req) if port == "443" { - // Only proxy http for now - req.URL.Scheme = "https" - } else { - // Only proxy http for now - req.URL.Scheme = "http" + if strings.Contains(req.URL.Scheme, "http") { + req.URL.Scheme = "https" + } else { + req.URL.Scheme = "wss" + } } - req.URL.Host = fmt.Sprintf("%s:%s", node, port) - log.Printf("HTTP Reverse proxying to %s\n", req.URL.Host) } - - return &httputil.ReverseProxy{Director: director, Transport: transport} + return &tcpProxy{Director: director} } -func NewSSLDaemonHandler() *httputil.ReverseProxy { +func NewSSLDaemonHandler() http.Handler { director := func(req *http.Request) { v := mux.Vars(req) node := v["node"] @@ -96,5 +148,5 @@ func NewSSLDaemonHandler() *httputil.ReverseProxy { log.Printf("HTTPS Reverse proxying to %s\n", req.URL.Host) } - return &httputil.ReverseProxy{Director: director} + return &tcpProxy{Director: director} } diff --git a/haproxy/haproxy.cfg b/haproxy/haproxy.cfg index e9dbd09..481c651 100644 --- a/haproxy/haproxy.cfg +++ b/haproxy/haproxy.cfg @@ -8,8 +8,8 @@ frontend http-in bind *:8080 acl host_localhost hdr(host) localhost - acl host_pwd1 hdr_reg(host) -i ^.*\.?host1\.localhost$ - acl host_pwd2 hdr_reg(host) -i ^.*\.?host2\.localhost$ + acl host_pwd1 hdr_reg(host) -i ^.*\.?host1\.localhost?:?.*$ + acl host_pwd2 hdr_reg(host) -i ^.*\.?host2\.localhost?:?.*$ use_backend all if host_localhost use_backend pwd1 if host_pwd1