package sslh

import (
	"fmt"
	"strings"

	"sslh-multiplex-lab/internal/services"
)

// ProtocolRoute is one entry of sslh's "protocols" list: the probe used to
// recognise a connection and the backend it is forwarded to.
type ProtocolRoute struct {
	Name          string   // sslh protocol name ("tls", "ssh", "regex" or "anyprot" in this file)
	Host          string   // backend host
	Port          string   // backend port
	Probe         string   // probe type ("builtin" or "regex" here); omitted from output when empty
	SNIHostnames  []string // TLS SNI names to match; omitted from output when empty
	ALPNProtocols []string // TLS ALPN protocols to match; omitted from output when empty
	RegexPatterns []string // patterns for the regex probe; omitted from output when empty
	ProxyProtocol bool     // when true, emits "proxy_protocol: true"
	LogLevel      int      // emitted as log_level only when > 0
	Fork          bool     // when true, emits "fork: true"
}

// Config is the subset of sslh's configuration generated by this package.
type Config struct {
	Verbose    int
	Foreground bool
	Listen     []ListenAddress
	Protocols  []ProtocolRoute // sslh probes these in order
	Timeout    int             // probe timeout in seconds; timeout settings are emitted only when > 0
	OnTimeout  string          // name of the protocol entry used when the probe times out
}

// ListenAddress is one entry of sslh's "listen" list.
type ListenAddress struct {
	Host           string
	Port           string
	MaxConnections int // Limit concurrent connections per listen address (DoS protection); omitted when <= 0
}

// GenerateConfig builds an sslh configuration that listens on port 443 on
// all IPv4 and IPv6 addresses and routes connections to svcs using the
// routes produced by GenerateSNIRoutes.
//
// domain is used to build SNI hostnames for services that require SNI.
// serverIP is currently unused and kept only for API compatibility.
func GenerateConfig(svcs []services.Service, serverIP, domain string) (*Config, error) {
	// Limit concurrent connections per listen address to protect against DoS
	// attacks that exhaust file descriptors; 1000 leaves room for the system.
	// See: https://github.com/yrutschle/sslh/blob/master/doc/max_connections.md
	const maxConns = 1000

	protocols, err := GenerateSNIRoutes(svcs, domain)
	if err != nil {
		return nil, fmt.Errorf("failed to generate protocol routes: %w", err)
	}

	return &Config{
		Verbose:    2,
		Foreground: true,
		Listen: []ListenAddress{
			{Host: "0.0.0.0", Port: "443", MaxConnections: maxConns},
			{Host: "[::]", Port: "443", MaxConnections: maxConns},
		},
		Protocols: protocols,
		// Increased from 3 to 5 seconds to give the TLS handshake more time.
		Timeout: 5,
		// If probing times out, route to TLS (HTTPS) instead of anyprot
		// (port 445) so slow-starting HTTPS connections still work.
		// GenerateSNIRoutes always emits a catch-all "tls" route (no SNI, no
		// ALPN), so this name always resolves.
		OnTimeout: "tls",
	}, nil
}

// GenerateProtocolRoutes returns the protocol routes for svcs without any
// domain, i.e. GenerateSNIRoutes(svcs, "").
func GenerateProtocolRoutes(svcs []services.Service) ([]ProtocolRoute, error) {
	return GenerateSNIRoutes(svcs, "")
}

// GenerateSNIRoutes builds sslh's ordered protocol list for svcs. sslh
// probes routes in list order, so the order below is significant:
//
//  1. One "tls" route per SNI-required TLS service, matching
//     svc.GetFQDN(domain) (with no hostname when domain is empty).
//  2. The default TLS route: the non-SNI TLS service, preferring one named
//     "https", or 127.0.0.1:8444 when there is none. If that service
//     declares ALPN protocols, an ALPN-restricted route is emitted first,
//     followed by an unrestricted catch-all to the same backend.
//  3. SSH on 127.0.0.1:22 (after TLS so TLS is checked first).
//  4. One "regex" route per non-TLS service.
//  5. "anyprot" on 127.0.0.1:445 as the final fallback.
//
// Services named "ssh" are skipped; the fixed SSH route handles them.
// The returned error is currently always nil.
//
// NOTE(review): with an empty domain, SNI-required services get routes with
// no SNI hostname. sslh then treats them as plain TLS checks (restricted only
// by their ALPN list), so they can shadow the default TLS route. Confirm this
// is intended for callers that pass "" (e.g. GenerateProtocolRoutes).
func GenerateSNIRoutes(svcs []services.Service, domain string) ([]ProtocolRoute, error) {
	const backendHost = "127.0.0.1"

	var (
		sniServices   []services.Service
		regexServices []services.Service
		defaultTLS    *ProtocolRoute
	)

	for _, svc := range svcs {
		if svc.Name == "ssh" {
			continue
		}
		if svc.Protocol != "tls" {
			// "regex" and every other non-TLS protocol go through a regex probe.
			regexServices = append(regexServices, svc)
			continue
		}
		if svc.SNIRequired {
			sniServices = append(sniServices, svc)
			continue
		}
		// The first non-SNI TLS service becomes the default TLS backend, but a
		// service named "https" (the root domain) always replaces it.
		if defaultTLS == nil || svc.Name == "https" {
			defaultTLS = &ProtocolRoute{
				Name:  "tls",
				Host:  backendHost,
				Port:  fmt.Sprintf("%d", svc.BackendPort),
				Probe: "builtin",
			}
			if alpn, ok := svc.Config["alpn_protocols"].([]string); ok && len(alpn) > 0 {
				defaultTLS.ALPNProtocols = alpn
			}
		}
	}

	// Up to two default TLS routes, plus ssh and anyprot.
	routes := make([]ProtocolRoute, 0, len(sniServices)+len(regexServices)+4)

	// SNI-specific TLS routes come first so they take precedence when SNI matches.
	for _, svc := range sniServices {
		sniHostnames := []string{}
		if domain != "" {
			sniHostnames = []string{svc.GetFQDN(domain)}
		}
		route := ProtocolRoute{
			Name:         "tls",
			Host:         backendHost,
			Port:         fmt.Sprintf("%d", svc.BackendPort),
			Probe:        "builtin",
			SNIHostnames: sniHostnames,
		}
		if alpn, ok := svc.Config["alpn_protocols"].([]string); ok {
			route.ALPNProtocols = alpn
		}
		routes = append(routes, route)
	}

	// Per the SSLH docs, a TLS route with neither SNI nor ALPN just checks
	// whether the connection is TLS, so it must be the last TLS probe.
	if defaultTLS == nil {
		// No default TLS service: fall back to nginx on 8444 as a true catch-all.
		routes = append(routes, ProtocolRoute{
			Name:  "tls",
			Host:  backendHost,
			Port:  "8444",
			Probe: "builtin",
		})
	} else {
		if len(defaultTLS.ALPNProtocols) > 0 {
			alpnRoute := *defaultTLS
			alpnRoute.SNIHostnames = []string{} // no SNI restriction
			routes = append(routes, alpnRoute)
		}
		catchAll := *defaultTLS
		catchAll.ALPNProtocols = []string{} // no ALPN restriction
		catchAll.SNIHostnames = []string{}  // no SNI restriction
		routes = append(routes, catchAll)
	}

	routes = append(routes, ProtocolRoute{
		Name:  "ssh",
		Host:  backendHost,
		Port:  "22",
		Probe: "builtin",
		Fork:  true,
	})

	for _, svc := range regexServices {
		route := ProtocolRoute{
			Name:  "regex", // SSLH requires "regex" as the protocol name for regex probes
			Host:  backendHost,
			Port:  fmt.Sprintf("%d", svc.BackendPort),
			Probe: "regex",
		}
		if patterns, ok := svc.Config["regex_patterns"].([]string); ok {
			route.RegexPatterns = patterns
		}
		routes = append(routes, route)
	}

	routes = append(routes, ProtocolRoute{
		Name:  "anyprot",
		Host:  backendHost,
		Port:  "445",
		Probe: "builtin",
	})

	return routes, nil
}

// ToLibConfig renders c in sslh's libconfig file syntax.
//
// Every string value is written as an escaped libconfig string literal. That
// way hostnames and regex patterns reach sslh exactly as stored in c, including
// any quotes or backslashes. Empty SNI/ALPN/regex lists are omitted; a TLS
// route with neither SNI nor ALPN is therefore a catch-all for any TLS
// connection. timeout/on_timeout are written only when c.Timeout > 0.
func (c *Config) ToLibConfig() string {
	var sb strings.Builder

	fmt.Fprintf(&sb, "verbose: %d;\n", c.Verbose)
	fmt.Fprintf(&sb, "foreground: %v;\n\n", c.Foreground)

	sb.WriteString("listen:\n(\n")
	for i, addr := range c.Listen {
		fmt.Fprintf(&sb, " { host: %s; port: %s;", quoteLibconfigString(addr.Host), quoteLibconfigString(addr.Port))
		if addr.MaxConnections > 0 {
			fmt.Fprintf(&sb, " max_connections: %d;", addr.MaxConnections)
		}
		sb.WriteString(" }")
		if i < len(c.Listen)-1 {
			sb.WriteString(",")
		}
		sb.WriteString("\n")
	}
	sb.WriteString(");\n\n")

	sb.WriteString("protocols:\n(\n")
	for i, proto := range c.Protocols {
		sb.WriteString(" {\n")
		fmt.Fprintf(&sb, " name: %s;\n", quoteLibconfigString(proto.Name))
		fmt.Fprintf(&sb, " host: %s;\n", quoteLibconfigString(proto.Host))
		fmt.Fprintf(&sb, " port: %s;\n", quoteLibconfigString(proto.Port))
		if proto.Probe != "" {
			fmt.Fprintf(&sb, " probe: %s;\n", quoteLibconfigString(proto.Probe))
		}
		writeLibconfigList(&sb, "sni_hostnames", proto.SNIHostnames)
		writeLibconfigList(&sb, "alpn_protocols", proto.ALPNProtocols)
		writeLibconfigList(&sb, "regex_patterns", proto.RegexPatterns)
		if proto.ProxyProtocol {
			sb.WriteString(" proxy_protocol: true;\n")
		}
		if proto.LogLevel > 0 {
			fmt.Fprintf(&sb, " log_level: %d;\n", proto.LogLevel)
		}
		if proto.Fork {
			sb.WriteString(" fork: true;\n")
		}
		sb.WriteString(" }")
		if i < len(c.Protocols)-1 {
			sb.WriteString(",")
		}
		sb.WriteString("\n")
	}
	sb.WriteString(");\n")

	if c.Timeout > 0 {
		fmt.Fprintf(&sb, "\ntimeout: %d;\n", c.Timeout)
		if c.OnTimeout != "" {
			// sslh's on_timeout is a plain string naming a protocol entry,
			// not a group.
			fmt.Fprintf(&sb, "on_timeout: %s;\n", quoteLibconfigString(c.OnTimeout))
		}
	}

	return sb.String()
}

// writeLibconfigList writes ` key: ["a", "b"];` plus a newline to sb, or
// nothing at all when values is empty.
func writeLibconfigList(sb *strings.Builder, key string, values []string) {
	if len(values) == 0 {
		return
	}
	quoted := make([]string, len(values))
	for i, v := range values {
		quoted[i] = quoteLibconfigString(v)
	}
	fmt.Fprintf(sb, " %s: [%s];\n", key, strings.Join(quoted, ", "))
}

// libconfigEscaper escapes the characters that libconfig interprets inside
// double-quoted strings, so the parsed value equals the original Go string.
var libconfigEscaper = strings.NewReplacer(
	`\`, `\\`,
	`"`, `\"`,
	"\n", `\n`,
	"\r", `\r`,
	"\t", `\t`,
)

// quoteLibconfigString returns s as a double-quoted, escaped libconfig
// string literal.
func quoteLibconfigString(s string) string {
	return `"` + libconfigEscaper.Replace(s) + `"`
}