Files
sslh-multiplex-lab/internal/sslh/config.go
2026-01-29 00:03:02 +00:00

310 lines
8.9 KiB
Go

package sslh
import (
"fmt"
"strings"
"sslh-multiplex-lab/internal/services"
)
// ProtocolRoute is one entry of sslh's "protocols" list: a protocol name,
// the backend address it forwards to, and the optional matching and
// behavior settings that ToLibConfig writes for it.
type ProtocolRoute struct {
	// Name is the sslh protocol name (this package uses "tls", "ssh",
	// "regex" and "anyprot"). Host and Port are the backend address.
	// Probe is written as "probe" only when non-empty.
	Name, Host, Port, Probe string

	// Match restrictions, each written only when non-empty: SNI hostnames
	// and ALPN protocols for TLS routes, regex patterns for "regex" routes.
	SNIHostnames, ALPNProtocols, RegexPatterns []string

	// ProxyProtocol and Fork are written only when true. LogLevel is
	// written only when greater than zero, so 0 means "not set".
	ProxyProtocol bool
	LogLevel      int
	Fork          bool
}
// Config is the in-memory form of an sslh configuration; ToLibConfig
// renders it as libconfig text.
type Config struct {
// Verbose is written as "verbose: <n>;".
Verbose int
// Foreground is written as "foreground: true|false;".
Foreground bool
// Listen lists the addresses sslh accepts connections on.
Listen []ListenAddress
// Protocols is written in order; sslh probes protocols in list order.
Protocols []ProtocolRoute
// Timeout is the protocol-detection timeout (GenerateConfig treats it as
// seconds). It is written only when > 0.
Timeout int
// OnTimeout names the protocol route used when detection times out.
// It is written only when Timeout > 0 and OnTimeout is non-empty.
OnTimeout string
}
// ListenAddress is one entry of sslh's "listen" list.
type ListenAddress struct {
	// Host and Port are written verbatim, each wrapped in double quotes.
	Host, Port string
	// MaxConnections limits concurrent connections on this address (DoS
	// protection). ToLibConfig emits max_connections only when it is > 0.
	MaxConnections int
}
// GenerateConfig builds the complete sslh configuration for svcs.
//
// The result has these properties:
//   - It listens on port 443 on both "0.0.0.0" and "[::]", each capped at
//     1000 concurrent connections.
//   - It uses the protocol list produced by GenerateSNIRoutes(svcs, domain).
//   - It runs with verbose 2 in the foreground.
//   - Connections whose protocol is not identified within the timeout go to
//     the "tls" route.
//
// serverIP is not used by this function. An error is returned only if
// GenerateSNIRoutes fails.
//
// NOTE(review): confirm that sslh's config parser accepts the bracketed IPv6
// literal "[::]" as a listen host; getaddrinfo-style resolvers usually expect
// a bare "::".
func GenerateConfig(svcs []services.Service, serverIP, domain string) (*Config, error) {
	const (
		// Per-listener cap on concurrent connections. It protects against DoS
		// and file-descriptor exhaustion while leaving headroom for the system.
		// See: https://github.com/yrutschle/sslh/blob/master/doc/max_connections.md
		maxConns = 1000
		// Protocol-detection timeout in seconds. It was raised from 3 to 5 to
		// give slow TLS handshakes more time.
		detectTimeout = 5
		// Route used when detection times out. GenerateSNIRoutes always emits
		// a "tls" route without an SNI restriction, so this name always exists
		// in the generated list. It is preferred over anyprot (port 445) so
		// that slow-starting HTTPS clients still reach the TLS backend.
		timeoutRoute = "tls"
	)

	protocols, err := GenerateSNIRoutes(svcs, domain)
	if err != nil {
		return nil, fmt.Errorf("failed to generate protocol routes: %w", err)
	}

	return &Config{
		Verbose:    2,
		Foreground: true,
		Listen: []ListenAddress{
			{Host: "0.0.0.0", Port: "443", MaxConnections: maxConns},
			{Host: "[::]", Port: "443", MaxConnections: maxConns},
		},
		Protocols: protocols,
		Timeout:   detectTimeout,
		OnTimeout: timeoutRoute,
	}, nil
}
// GenerateProtocolRoutes returns the sslh protocol list for svcs without a
// domain, i.e. exactly what GenerateSNIRoutes(svcs, "") returns.
func GenerateProtocolRoutes(svcs []services.Service) ([]ProtocolRoute, error) {
	const noDomain = ""
	routes, err := GenerateSNIRoutes(svcs, noDomain)
	return routes, err
}
// GenerateSNIRoutes builds the ordered sslh protocol list for svcs.
//
// sslh probes protocols in list order, so the result is laid out as:
//  1. One "tls" route per SNI-required TLS service. Each route is restricted
//     to the service's FQDN under domain (when domain is non-empty) and to
//     the service's "alpn_protocols" config entry (when present).
//  2. The default TLS route(s) for the first non-SNI TLS service, with a
//     service named "https" always taking over. When that service has ALPN
//     protocols, a copy restricted to them comes first; an unrestricted
//     catch-all copy always follows. Without such a service, a catch-all
//     route to 127.0.0.1:8444 is used instead.
//  3. The built-in SSH route to 127.0.0.1:22. Services named "ssh" are
//     skipped; this fixed route always stands in for them.
//  4. One "regex" route per service whose protocol is not "tls", carrying
//     its "regex_patterns" config entry.
//  5. A final "anyprot" route to 127.0.0.1:445.
//
// All backends are on 127.0.0.1. The returned error is currently always nil.
//
// NOTE(review): when domain is empty, SNI-required services get no
// sni_hostnames. The first such route without ALPN therefore becomes a TLS
// catch-all that shadows every later TLS route. Confirm that callers of the
// domain-less path (GenerateProtocolRoutes) expect this.
func GenerateSNIRoutes(svcs []services.Service, domain string) ([]ProtocolRoute, error) {
	const loopback = "127.0.0.1"

	var (
		sniServices   []services.Service // TLS services routed by SNI hostname
		regexServices []services.Service // every non-TLS service
		defaultTLS    *ProtocolRoute     // catch-all TLS backend, if a service provides one
	)
	for _, svc := range svcs {
		if svc.Name == "ssh" {
			// SSH is always served by the fixed built-in route added below.
			continue
		}
		if svc.Protocol != "tls" {
			// "regex" and any unrecognized protocol are both routed by regex probe.
			regexServices = append(regexServices, svc)
			continue
		}
		if svc.SNIRequired {
			sniServices = append(sniServices, svc)
			continue
		}
		// The first non-SNI TLS service becomes the default, but a service
		// named "https" (the root domain) always replaces it.
		if defaultTLS == nil || svc.Name == "https" {
			defaultTLS = &ProtocolRoute{
				Name:  "tls",
				Host:  loopback,
				Port:  fmt.Sprintf("%d", svc.BackendPort),
				Probe: "builtin",
				// SNIHostnames stays nil here; it is set per copy below.
			}
			if alpn, ok := svc.Config["alpn_protocols"].([]string); ok && len(alpn) > 0 {
				defaultTLS.ALPNProtocols = alpn
			}
		}
	}

	// SNI routes + up to 2 default TLS + ssh + regex routes + anyprot.
	routes := make([]ProtocolRoute, 0, len(sniServices)+len(regexServices)+4)

	// 1. SNI-specific TLS routes come first so that a matching SNI wins.
	for _, svc := range sniServices {
		// A non-nil empty slice (not nil) when there is no domain; this
		// preserves the historical shape of the returned routes.
		hostnames := []string{}
		if domain != "" {
			hostnames = []string{svc.GetFQDN(domain)}
		}
		route := ProtocolRoute{
			Name:         "tls",
			Host:         loopback,
			Port:         fmt.Sprintf("%d", svc.BackendPort),
			Probe:        "builtin",
			SNIHostnames: hostnames,
			// LogLevel is left at 0, which ToLibConfig never emits.
		}
		if alpn, ok := svc.Config["alpn_protocols"].([]string); ok {
			route.ALPNProtocols = alpn
		}
		routes = append(routes, route)
	}

	// 2. Default TLS route(s). A TLS route with neither SNI nor ALPN set only
	// checks that the connection is TLS, so it must be the last TLS route.
	if defaultTLS == nil {
		// No non-SNI TLS service: fall back to a catch-all on 127.0.0.1:8444.
		routes = append(routes, ProtocolRoute{
			Name:  "tls",
			Host:  loopback,
			Port:  "8444",
			Probe: "builtin",
		})
	} else {
		if len(defaultTLS.ALPNProtocols) > 0 {
			withALPN := *defaultTLS
			withALPN.SNIHostnames = []string{} // no SNI restriction
			routes = append(routes, withALPN)
		}
		catchAll := *defaultTLS
		catchAll.ALPNProtocols = []string{} // no ALPN restriction
		catchAll.SNIHostnames = []string{}  // no SNI restriction
		routes = append(routes, catchAll)
	}

	// 3. SSH goes after all TLS routes so that the TLS routes are tried first.
	routes = append(routes, ProtocolRoute{
		Name:  "ssh",
		Host:  loopback,
		Port:  "22",
		Probe: "builtin",
		Fork:  true,
	})

	// 4. Regex-probed services. sslh requires "regex" as the protocol name
	// for regex probes.
	for _, svc := range regexServices {
		route := ProtocolRoute{
			Name:  "regex",
			Host:  loopback,
			Port:  fmt.Sprintf("%d", svc.BackendPort),
			Probe: "regex",
		}
		if patterns, ok := svc.Config["regex_patterns"].([]string); ok {
			route.RegexPatterns = patterns
		}
		routes = append(routes, route)
	}

	// 5. sslh's match-anything probe, placed last as the final fallback.
	routes = append(routes, ProtocolRoute{
		Name:  "anyprot",
		Host:  loopback,
		Port:  "445",
		Probe: "builtin",
	})
	return routes, nil
}
// ToLibConfig renders c in the libconfig syntax read by sslh.
//
// The output contains, in order:
//   - the verbose and foreground settings;
//   - the listen list;
//   - the protocols list, in the order of c.Protocols;
//   - the timeout and optional on_timeout entries, only when Timeout > 0.
//
// Optional keys are emitted only when set:
//   - max_connections when > 0;
//   - probe when non-empty;
//   - sni_hostnames, alpn_protocols and regex_patterns when non-empty;
//   - proxy_protocol and fork when true;
//   - log_level when > 0.
//
// NOTE(review): values are wrapped in double quotes verbatim, without
// escaping. A value containing '"' or '\' yields invalid or altered libconfig
// text; confirm that inputs never contain them.
// NOTE(review): LogLevel 0 cannot be emitted (it means "unset"), so sslh's own
// default log level applies in that case.
// NOTE(review): on_timeout is written as a group ({ name: ...; }). Confirm
// against sslh's configuration schema which form it expects.
func (c *Config) ToLibConfig() string {
	var sb strings.Builder

	// writeList emits ` <key>: ["a", "b"];` for a non-empty list and nothing
	// for an empty or nil one.
	writeList := func(key string, items []string) {
		if len(items) == 0 {
			return
		}
		fmt.Fprintf(&sb, " %s: [", key)
		for idx, item := range items {
			if idx != 0 {
				sb.WriteString(", ")
			}
			sb.WriteByte('"')
			sb.WriteString(item)
			sb.WriteByte('"')
		}
		sb.WriteString("];\n")
	}

	fmt.Fprintf(&sb, "verbose: %d;\n", c.Verbose)
	fmt.Fprintf(&sb, "foreground: %t;\n\n", c.Foreground)

	sb.WriteString("listen:\n(\n")
	lastListen := len(c.Listen) - 1
	for idx, addr := range c.Listen {
		sep := ""
		if idx != lastListen {
			sep = ","
		}
		if addr.MaxConnections <= 0 {
			fmt.Fprintf(&sb, " { host: \"%s\"; port: \"%s\"; }%s\n", addr.Host, addr.Port, sep)
			continue
		}
		fmt.Fprintf(&sb, " { host: \"%s\"; port: \"%s\"; max_connections: %d; }%s\n",
			addr.Host, addr.Port, addr.MaxConnections, sep)
	}
	sb.WriteString(");\n\n")

	sb.WriteString("protocols:\n(\n")
	lastProto := len(c.Protocols) - 1
	for idx, p := range c.Protocols {
		sb.WriteString(" {\n")
		fmt.Fprintf(&sb, " name: \"%s\";\n", p.Name)
		fmt.Fprintf(&sb, " host: \"%s\";\n", p.Host)
		fmt.Fprintf(&sb, " port: \"%s\";\n", p.Port)
		if p.Probe != "" {
			fmt.Fprintf(&sb, " probe: \"%s\";\n", p.Probe)
		}
		// A TLS route with both sni_hostnames and alpn_protocols omitted
		// matches any TLS connection.
		writeList("sni_hostnames", p.SNIHostnames)
		writeList("alpn_protocols", p.ALPNProtocols)
		writeList("regex_patterns", p.RegexPatterns)
		if p.ProxyProtocol {
			sb.WriteString(" proxy_protocol: true;\n")
		}
		if p.LogLevel > 0 {
			fmt.Fprintf(&sb, " log_level: %d;\n", p.LogLevel)
		}
		if p.Fork {
			sb.WriteString(" fork: true;\n")
		}
		if idx == lastProto {
			sb.WriteString(" }\n")
		} else {
			sb.WriteString(" },\n")
		}
	}
	sb.WriteString(");\n")

	if c.Timeout <= 0 {
		return sb.String()
	}
	fmt.Fprintf(&sb, "\ntimeout: %d;\n", c.Timeout)
	if c.OnTimeout != "" {
		fmt.Fprintf(&sb, "on_timeout: { name: \"%s\"; };\n", c.OnTimeout)
	}
	return sb.String()
}