Initial code commit
This commit is contained in:
48
pkg/utils/debug_log.go
Normal file
48
pkg/utils/debug_log.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const debugLogPath = "/Users/gmapple/Documents/Projects/labs/sslh-multiplex-lab/.cursor/debug.log"
|
||||
|
||||
type DebugLogEntry struct {
|
||||
ID string `json:"id"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Location string `json:"location"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
SessionID string `json:"sessionId"`
|
||||
RunID string `json:"runId"`
|
||||
HypothesisID string `json:"hypothesisId"`
|
||||
}
|
||||
|
||||
func DebugLog(location, message string, data map[string]interface{}, hypothesisID string) {
|
||||
entry := DebugLogEntry{
|
||||
ID: fmt.Sprintf("log_%d", time.Now().UnixNano()),
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
Location: location,
|
||||
Message: message,
|
||||
Data: data,
|
||||
SessionID: "debug-session",
|
||||
RunID: "run1",
|
||||
HypothesisID: hypothesisID,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(debugLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
f.Write(jsonData)
|
||||
f.WriteString("\n")
|
||||
}
|
||||
131
pkg/utils/dns.go
Normal file
131
pkg/utils/dns.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func ValidateDNSPropagation(hostname, expectedIP string, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
backoff := 2 * time.Second
|
||||
maxBackoff := 10 * time.Second
|
||||
|
||||
dnsServers := []string{"1.1.1.1:53", "8.8.8.8:53", "1.0.0.1:53"}
|
||||
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
return fmt.Errorf("timeout waiting for DNS propagation for %s", hostname)
|
||||
}
|
||||
|
||||
resolved := false
|
||||
|
||||
// Try public DNS servers first
|
||||
for _, dnsServer := range dnsServers {
|
||||
ips, err := queryDNS(dnsServer, hostname)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
if ip == expectedIP {
|
||||
resolved = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if resolved {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to system resolver if public DNS servers failed
|
||||
if !resolved {
|
||||
ips, err := net.LookupIP(hostname)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
if ip.String() == expectedIP {
|
||||
resolved = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resolved {
|
||||
return nil
|
||||
}
|
||||
|
||||
time.Sleep(backoff)
|
||||
if backoff < maxBackoff {
|
||||
backoff += 1 * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func queryDNS(server, hostname string) ([]string, error) {
|
||||
client := dns.Client{Timeout: 5 * time.Second}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(hostname), dns.TypeA)
|
||||
|
||||
resp, _, err := client.Exchange(msg, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("DNS query failed with Rcode: %d", resp.Rcode)
|
||||
}
|
||||
|
||||
var ips []string
|
||||
for _, answer := range resp.Answer {
|
||||
if a, ok := answer.(*dns.A); ok {
|
||||
ips = append(ips, a.A.String())
|
||||
}
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func ResolveHostname(hostname string) ([]string, error) {
|
||||
ips, err := net.LookupIP(hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ipStrings []string
|
||||
for _, ip := range ips {
|
||||
if ip.To4() != nil {
|
||||
ipStrings = append(ipStrings, ip.String())
|
||||
}
|
||||
}
|
||||
|
||||
return ipStrings, nil
|
||||
}
|
||||
|
||||
func ValidateDNSPropagationWithResolver(hostname, expectedIP, resolver string, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
backoff := 2 * time.Second
|
||||
maxBackoff := 10 * time.Second
|
||||
|
||||
dnsServer := resolver + ":53"
|
||||
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
return fmt.Errorf("timeout waiting for DNS propagation for %s on resolver %s", hostname, resolver)
|
||||
}
|
||||
|
||||
ips, err := queryDNS(dnsServer, hostname)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
if ip == expectedIP {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(backoff)
|
||||
if backoff < maxBackoff {
|
||||
backoff += 1 * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
44
pkg/utils/ip.go
Normal file
44
pkg/utils/ip.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetPublicIP() (string, error) {
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
services := []string{
|
||||
"https://api.ipify.org",
|
||||
"https://icanhazip.com",
|
||||
"https://ifconfig.me/ip",
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
resp, err := client.Get(service)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := string(body)
|
||||
if ip != "" {
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("failed to determine public IP address from any service")
|
||||
}
|
||||
126
pkg/utils/progress.go
Normal file
126
pkg/utils/progress.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProgressBar struct {
|
||||
total int
|
||||
current int
|
||||
width int
|
||||
startTime time.Time
|
||||
label string
|
||||
}
|
||||
|
||||
func NewProgressBar(total int, label string) *ProgressBar {
|
||||
return &ProgressBar{
|
||||
total: total,
|
||||
current: 0,
|
||||
width: 40,
|
||||
startTime: time.Now(),
|
||||
label: label,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProgressBar) Update(current int) {
|
||||
p.current = current
|
||||
p.render()
|
||||
}
|
||||
|
||||
func (p *ProgressBar) Increment() {
|
||||
p.current++
|
||||
p.render()
|
||||
}
|
||||
|
||||
func (p *ProgressBar) render() {
|
||||
if p.total == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
percent := float64(p.current) / float64(p.total)
|
||||
filled := int(percent * float64(p.width))
|
||||
empty := p.width - filled
|
||||
|
||||
bar := ""
|
||||
for i := 0; i < filled; i++ {
|
||||
bar += "█"
|
||||
}
|
||||
for i := 0; i < empty; i++ {
|
||||
bar += "░"
|
||||
}
|
||||
|
||||
elapsed := time.Since(p.startTime)
|
||||
var eta time.Duration
|
||||
if p.current > 0 && elapsed.Seconds() > 0 {
|
||||
rate := float64(p.current) / elapsed.Seconds()
|
||||
if rate > 0 {
|
||||
remaining := float64(p.total-p.current) / rate
|
||||
eta = time.Duration(remaining) * time.Second
|
||||
if eta < 0 {
|
||||
eta = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if eta > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\r%s [%s] %d/%d (%.1f%%) ETA: %s", p.label, bar, p.current, p.total, percent*100, eta.Round(time.Second))
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\r%s [%s] %d/%d (%.1f%%)", p.label, bar, p.current, p.total, percent*100)
|
||||
}
|
||||
if p.current >= p.total {
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProgressBar) Finish() {
|
||||
p.current = p.total
|
||||
p.render()
|
||||
}
|
||||
|
||||
type Spinner struct {
|
||||
chars []string
|
||||
index int
|
||||
label string
|
||||
stopChan chan bool
|
||||
doneChan chan bool
|
||||
}
|
||||
|
||||
func NewSpinner(label string) *Spinner {
|
||||
return &Spinner{
|
||||
chars: []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"},
|
||||
index: 0,
|
||||
label: label,
|
||||
stopChan: make(chan bool),
|
||||
doneChan: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Spinner) Start() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
fmt.Fprintf(os.Stderr, "\r%s %s\n", s.chars[s.index], s.label)
|
||||
s.doneChan <- true
|
||||
return
|
||||
case <-ticker.C:
|
||||
fmt.Fprintf(os.Stderr, "\r%s %s", s.chars[s.index], s.label)
|
||||
s.index = (s.index + 1) % len(s.chars)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Spinner) Stop() {
|
||||
s.stopChan <- true
|
||||
<-s.doneChan
|
||||
}
|
||||
|
||||
func (s *Spinner) StopWithMessage(message string) {
|
||||
s.Stop()
|
||||
fmt.Fprintf(os.Stderr, "%s\n", message)
|
||||
}
|
||||
53
pkg/utils/retry.go
Normal file
53
pkg/utils/retry.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func RetryWithBackoff(maxAttempts int, initialDelay time.Duration, fn func() error) error {
|
||||
var lastErr error
|
||||
delay := initialDelay
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
err := fn()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
if attempt < maxAttempts {
|
||||
time.Sleep(delay)
|
||||
delay *= 2
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("max attempts (%d) reached, last error: %w", maxAttempts, lastErr)
|
||||
}
|
||||
|
||||
func RetryWithExponentialBackoff(maxAttempts int, initialDelay, maxDelay time.Duration, fn func() error) error {
|
||||
var lastErr error
|
||||
delay := initialDelay
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
err := fn()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
if attempt < maxAttempts {
|
||||
time.Sleep(delay)
|
||||
if delay < maxDelay {
|
||||
delay *= 2
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("max attempts (%d) reached, last error: %w", maxAttempts, lastErr)
|
||||
}
|
||||
40
pkg/utils/ssh.go
Normal file
40
pkg/utils/ssh.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func RemoveSSHKnownHost(host string) error {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
|
||||
knownHostsPath := filepath.Join(homeDir, ".ssh", "known_hosts")
|
||||
|
||||
if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.Command("ssh-keygen", "-R", host)
|
||||
cmd.Stdout = os.Stderr
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to remove host key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSSHUser() (string, error) {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current user: %w", err)
|
||||
}
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
34
pkg/utils/validation.go
Normal file
34
pkg/utils/validation.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
func ValidateIP(ip string) error {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return fmt.Errorf("invalid IP address: %s", ip)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateDomain(domain string) error {
|
||||
if domain == "" {
|
||||
return fmt.Errorf("domain cannot be empty")
|
||||
}
|
||||
|
||||
domainRegex := regexp.MustCompile(`^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`)
|
||||
if !domainRegex.MatchString(domain) {
|
||||
return fmt.Errorf("invalid domain format: %s", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidatePort(port int) error {
|
||||
if port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %d (must be between 1 and 65535)", port)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user