310 lines
8.9 KiB
Go
310 lines
8.9 KiB
Go
package sslh
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"sslh-multiplex-lab/internal/services"
|
|
)
|
|
|
|
// ProtocolRoute describes one entry of sslh's "protocols" list: the probe
// sslh runs on a new connection and the backend it forwards to on a match.
// When rendered by ToLibConfig, an empty Probe, empty matcher lists, a false
// ProxyProtocol/Fork and a non-positive LogLevel are left out of the entry.
type ProtocolRoute struct {
	// Name is the sslh protocol name (e.g. "tls", "ssh", "regex", "anyprot"),
	// Host/Port is the backend address, and Probe is the probe kind
	// (e.g. "builtin" or "regex").
	Name, Host, Port, Probe string

	// Matchers: SNIHostnames and ALPNProtocols restrict "tls" routes,
	// RegexPatterns feeds "regex" routes.
	SNIHostnames, ALPNProtocols, RegexPatterns []string

	// Per-protocol proxy_protocol, log_level and fork settings.
	ProxyProtocol bool
	LogLevel      int
	Fork          bool
}
|
|
|
|
type Config struct {
|
|
Verbose int
|
|
Foreground bool
|
|
Listen []ListenAddress
|
|
Protocols []ProtocolRoute
|
|
Timeout int
|
|
OnTimeout string
|
|
}
|
|
|
|
// ListenAddress is one entry of sslh's "listen" list.
type ListenAddress struct {
	// Host and Port are the bind address, e.g. "0.0.0.0" or "[::]" and "443".
	Host, Port string

	// MaxConnections limits concurrent connections per listen address
	// (DoS protection); ToLibConfig omits max_connections when it is <= 0.
	MaxConnections int
}
|
|
|
|
func GenerateConfig(svcs []services.Service, serverIP, domain string) (*Config, error) {
|
|
// Set max_connections per listen address to protect against DoS attacks
|
|
// This limits concurrent connections to prevent file descriptor exhaustion
|
|
// Recommended: 1000 connections per listen address (leaves room for system)
|
|
// See: https://github.com/yrutschle/sslh/blob/master/doc/max_connections.md
|
|
maxConns := 1000
|
|
listen := []ListenAddress{
|
|
{Host: "0.0.0.0", Port: "443", MaxConnections: maxConns},
|
|
{Host: "[::]", Port: "443", MaxConnections: maxConns},
|
|
}
|
|
|
|
protocols, err := GenerateSNIRoutes(svcs, domain)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate protocol routes: %w", err)
|
|
}
|
|
|
|
// Find the default TLS route for on_timeout
|
|
// If SSLH times out during protocol detection, route to TLS (HTTPS) instead of anyprot
|
|
// This ensures HTTPS connections that are slow to start still work
|
|
onTimeout := "tls"
|
|
for _, proto := range protocols {
|
|
if proto.Name == "tls" && len(proto.SNIHostnames) == 0 {
|
|
// Found the default TLS route (catch-all, no SNI restriction)
|
|
onTimeout = "tls"
|
|
break
|
|
}
|
|
}
|
|
|
|
return &Config{
|
|
Verbose: 2,
|
|
Foreground: true,
|
|
Listen: listen,
|
|
Protocols: protocols,
|
|
Timeout: 5, // Increased from 3 to 5 seconds to give TLS handshake more time
|
|
OnTimeout: onTimeout, // Route to TLS on timeout, not anyprot (port 445)
|
|
}, nil
|
|
}
|
|
|
|
func GenerateProtocolRoutes(svcs []services.Service) ([]ProtocolRoute, error) {
|
|
return GenerateSNIRoutes(svcs, "")
|
|
}
|
|
|
|
// GenerateSNIRoutes builds the ordered sslh protocol list for svcs. sslh
// tries probes in list order, so the result is laid out as:
//
//  1. one "tls" route per TLS service with SNIRequired set, restricted to
//     svc.GetFQDN(domain) when domain is non-empty (empty SNI list otherwise),
//     carrying the service's alpn_protocols if configured;
//  2. the default TLS route(s) for the non-SNI TLS service. The "https"
//     service is preferred; otherwise the first such service is used. When
//     it has ALPN protocols, an ALPN-restricted copy comes first. It is
//     always followed by an unrestricted catch-all. If there is no such
//     service, a catch-all to 127.0.0.1:8444 is used;
//  3. the builtin "ssh" route to 127.0.0.1:22 (forking), placed after TLS
//     so TLS routes take precedence;
//  4. one "regex" route per non-TLS service, with its regex_patterns;
//  5. a final "anyprot" route to 127.0.0.1:445.
//
// Services named "ssh" are ignored; SSH always goes to port 22. All
// backends are on 127.0.0.1. The returned error is currently always nil.
//
// NOTE(review): with an empty domain, SNI-required services produce "tls"
// routes without SNI restriction that sit ahead of the default route. The
// first one without ALPN will therefore match every TLS client — confirm
// this is intended for GenerateProtocolRoutes callers.
func GenerateSNIRoutes(svcs []services.Service, domain string) ([]ProtocolRoute, error) {
	var (
		sniServices   []services.Service // TLS services routed by SNI hostname
		otherServices []services.Service // everything that is not TLS
		fallback      *ProtocolRoute     // default TLS route, if a non-SNI TLS service exists
	)

	for _, svc := range svcs {
		if svc.Name == "ssh" {
			continue
		}
		if svc.Protocol != "tls" {
			otherServices = append(otherServices, svc)
			continue
		}
		if svc.SNIRequired {
			sniServices = append(sniServices, svc)
			continue
		}
		// The first non-SNI TLS service claims the default slot, but the
		// "https" service (root domain) always takes it over.
		if fallback != nil && svc.Name != "https" {
			continue
		}
		def := loopbackRoute("tls", fmt.Sprintf("%d", svc.BackendPort), "builtin")
		if alpn, ok := svc.Config["alpn_protocols"].([]string); ok && len(alpn) > 0 {
			def.ALPNProtocols = alpn
		}
		fallback = &def
	}

	var routes []ProtocolRoute

	// SNI-specific routes first so they win whenever the hostname matches.
	for _, svc := range sniServices {
		hostnames := []string{}
		if domain != "" {
			hostnames = []string{svc.GetFQDN(domain)}
		}
		sni := loopbackRoute("tls", fmt.Sprintf("%d", svc.BackendPort), "builtin")
		sni.SNIHostnames = hostnames
		if alpn, ok := svc.Config["alpn_protocols"].([]string); ok {
			sni.ALPNProtocols = alpn
		}
		routes = append(routes, sni)
	}

	// sslh docs: a TLS probe with neither SNI nor ALPN matches any TLS
	// connection, so it must be the last TLS probe.
	if fallback == nil {
		routes = append(routes, loopbackRoute("tls", "8444", "builtin"))
	} else {
		if len(fallback.ALPNProtocols) > 0 {
			withALPN := *fallback
			withALPN.SNIHostnames = []string{}
			routes = append(routes, withALPN)
		}
		catchAll := *fallback
		catchAll.ALPNProtocols = []string{}
		catchAll.SNIHostnames = []string{}
		routes = append(routes, catchAll)
	}

	ssh := loopbackRoute("ssh", "22", "builtin")
	ssh.Fork = true
	routes = append(routes, ssh)

	for _, svc := range otherServices {
		// sslh requires the protocol name "regex" for regex probes.
		rx := loopbackRoute("regex", fmt.Sprintf("%d", svc.BackendPort), "regex")
		if patterns, ok := svc.Config["regex_patterns"].([]string); ok {
			rx.RegexPatterns = patterns
		}
		routes = append(routes, rx)
	}

	routes = append(routes, loopbackRoute("anyprot", "445", "builtin"))
	return routes, nil
}

// loopbackRoute returns a route named name that forwards to 127.0.0.1:port
// using the given sslh probe; all other fields are left at their zero value.
func loopbackRoute(name, port, probe string) ProtocolRoute {
	return ProtocolRoute{Name: name, Host: "127.0.0.1", Port: port, Probe: probe}
}
|
|
|
|
// ToLibConfig renders c as an sslh configuration file in libconfig syntax.
//
// The output contains, in order:
//   - the verbose and foreground settings;
//   - the listen list;
//   - the protocols list;
//   - when Timeout > 0, the timeout setting followed by on_timeout (if
//     OnTimeout is non-empty).
//
// Per-protocol probe, sni_hostnames, alpn_protocols, regex_patterns,
// proxy_protocol, log_level and fork entries are written only when set. A
// "tls" route with neither SNI hostnames nor ALPN protocols is therefore an
// unrestricted (catch-all) TLS match.
//
// Every string value is written as a libconfig string literal with
// backslashes and double quotes escaped. Values such as regex patterns thus
// reach sslh exactly as stored in c, and a stray quote cannot break the file.
func (c *Config) ToLibConfig() string {
	var sb strings.Builder

	fmt.Fprintf(&sb, "verbose: %d;\n", c.Verbose)
	fmt.Fprintf(&sb, "foreground: %v;\n\n", c.Foreground)

	sb.WriteString("listen:\n(\n")
	for i, addr := range c.Listen {
		fmt.Fprintf(&sb, " { host: %s; port: %s; ", libconfigQuote(addr.Host), libconfigQuote(addr.Port))
		if addr.MaxConnections > 0 {
			fmt.Fprintf(&sb, "max_connections: %d; ", addr.MaxConnections)
		}
		sb.WriteString("}")
		if i < len(c.Listen)-1 {
			sb.WriteString(",")
		}
		sb.WriteString("\n")
	}
	sb.WriteString(");\n\n")

	sb.WriteString("protocols:\n(\n")
	for i, proto := range c.Protocols {
		sb.WriteString(" {\n")
		fmt.Fprintf(&sb, " name: %s;\n", libconfigQuote(proto.Name))
		fmt.Fprintf(&sb, " host: %s;\n", libconfigQuote(proto.Host))
		fmt.Fprintf(&sb, " port: %s;\n", libconfigQuote(proto.Port))
		if proto.Probe != "" {
			fmt.Fprintf(&sb, " probe: %s;\n", libconfigQuote(proto.Probe))
		}
		// Empty lists are omitted: a TLS route with neither SNI nor ALPN
		// matches any TLS connection.
		writeLibconfigList(&sb, "sni_hostnames", proto.SNIHostnames)
		writeLibconfigList(&sb, "alpn_protocols", proto.ALPNProtocols)
		writeLibconfigList(&sb, "regex_patterns", proto.RegexPatterns)
		if proto.ProxyProtocol {
			sb.WriteString(" proxy_protocol: true;\n")
		}
		if proto.LogLevel > 0 {
			fmt.Fprintf(&sb, " log_level: %d;\n", proto.LogLevel)
		}
		if proto.Fork {
			sb.WriteString(" fork: true;\n")
		}
		sb.WriteString(" }")
		if i < len(c.Protocols)-1 {
			sb.WriteString(",")
		}
		sb.WriteString("\n")
	}
	sb.WriteString(");\n")

	if c.Timeout > 0 {
		fmt.Fprintf(&sb, "\ntimeout: %d;\n", c.Timeout)
		if c.OnTimeout != "" {
			// sslh's on_timeout setting is a plain string holding the name of
			// the protocol to fall back to, not a group.
			fmt.Fprintf(&sb, "on_timeout: %s;\n", libconfigQuote(c.OnTimeout))
		}
	}

	return sb.String()
}

// libconfigEscaper escapes the characters that are special inside a
// libconfig double-quoted string: backslash and double quote.
var libconfigEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// libconfigQuote returns s as a double-quoted libconfig string literal.
func libconfigQuote(s string) string {
	return `"` + libconfigEscaper.Replace(s) + `"`
}

// writeLibconfigList writes ` key: ["a", "b"];` plus a newline to sb, or
// nothing at all when values is empty.
func writeLibconfigList(sb *strings.Builder, key string, values []string) {
	if len(values) == 0 {
		return
	}
	fmt.Fprintf(sb, " %s: [", key)
	for i, v := range values {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(libconfigQuote(v))
	}
	sb.WriteString("];\n")
}
|