#!/usr/bin/env python3

import re
import sys
import argparse
import socket
from collections import defaultdict
from typing import List, Tuple

CADDYFILE = "/etc/caddy/Caddyfile"

def test_port(port: str, timeout: float = 0.5) -> bool:
    """Teste si un port est ouvert/répond"""
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(timeout)
        result = sock.connect_ex(('localhost', int(port)))
        sock.close()
        return result == 0
    except:
        return False

def parse_caddyfile(filepath: str) -> List[Tuple[str, str]]:
    """Parse le Caddyfile et extrait les mappings domaine:port"""
    with open(filepath, 'r') as f:
        content = f.read()
    
    mappings = []
    matchers = {}
    
    # Méthode 1: Parser les blocs simples (domaine { ... reverse_proxy localhost:PORT ... })
    # Pattern pour capturer les blocs simples
    simple_pattern = r'([a-zA-Z0-9][-a-zA-Z0-9.]*\.[a-zA-Z]{2,})\s*\{([^}]*)\}'
    
    for match in re.finditer(simple_pattern, content, re.MULTILINE):
        domain = match.group(1)
        block_content = match.group(2)
        
        # Chercher reverse_proxy dans le bloc
        port_matches = re.findall(r'reverse_proxy\s+.*?localhost:(\d+)', block_content)
        for port in port_matches:
            mappings.append((domain, port))
    
    # Méthode 2: Parser les wildcards avec matchers
    # Extraire d'abord tous les matchers
    matcher_pattern = r'@(\w+)\s+host\s+([^\n]+)'
    for match in re.finditer(matcher_pattern, content):
        matcher_name = match.group(1)
        domains_str = match.group(2).strip()
        # Nettoyer et extraire les domaines
        domains = re.findall(r'([a-zA-Z0-9][-a-zA-Z0-9.]*\.[a-zA-Z]{2,})', domains_str)
        matchers[matcher_name] = domains
    
    # Ensuite, trouver les reverse_proxy associés aux matchers
    handle_pattern = r'handle\s+@(\w+)\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}'
    
    for match in re.finditer(handle_pattern, content, re.DOTALL):
        matcher_name = match.group(1)
        handle_content = match.group(2)
        
        if matcher_name in matchers:
            port_matches = re.findall(r'reverse_proxy\s+.*?localhost:(\d+)', handle_content)
            for port in port_matches:
                for domain in matchers[matcher_name]:
                    mappings.append((domain, port))
    
    return sorted(set(mappings))

def get_status_icon(port: str, test_ports: bool = False) -> str:
    """Retourne l'icône de statut du port"""
    if not test_ports:
        return ""
    
    is_open = test_port(port)
    return "✅" if is_open else "❌"

def display_by_domain(mappings: List[Tuple[str, str]], group: bool = False, test_ports: bool = False) -> int:
    """Affiche les mappings triés par domaine. Retourne le nombre de ports inactifs."""
    print("\n╔════════════════════════════════════════════════════════════════╗")
    print("║         Mappings triés par DOMAINE → Port Local                ║")
    print("╚════════════════════════════════════════════════════════════════╝")
    print()
    
    if test_ports:
        print("🔍 Test de connectivité en cours...")
        print()
    
    max_len = max(len(d) for d, _ in mappings) if mappings else 0
    
    online_count = 0
    offline_count = 0
    offline_ports = set()
    
    for domain, port in sorted(mappings, key=lambda x: x[0]):
        status = get_status_icon(port, test_ports)
        
        if test_ports:
            if "✅" in status:
                online_count += 1
            else:
                offline_count += 1
                offline_ports.add(port)
        
        if test_ports:
            print(f"{status} {domain:<{max_len}} → localhost:{port}")
        else:
            print(f"{domain:<{max_len}} → localhost:{port}")
    
    print()
    print("─" * 64)
    print(f"Total: {len(mappings)} mappings", end="")
    
    if test_ports:
        print(f" (✅ {online_count} actifs, ❌ {offline_count} inactifs)")
    else:
        print()
    
    return len(offline_ports) if test_ports else 0

def display_by_port(mappings: List[Tuple[str, str]], group: bool = False, test_ports: bool = False) -> int:
    """Affiche les mappings triés par port. Retourne le nombre de ports inactifs."""
    print("\n╔════════════════════════════════════════════════════════════════╗")
    print("║         Mappings triés par PORT → Domaines                     ║")
    print("╚════════════════════════════════════════════════════════════════╝")
    print()
    
    if test_ports:
        print("🔍 Test de connectivité en cours...")
        print()
    
    # Grouper par port
    by_port = defaultdict(list)
    for domain, port in mappings:
        by_port[port].append(domain)
    
    online_ports = 0
    offline_ports = 0
    
    if not group:
        # Affichage ligne par ligne
        sorted_mappings = sorted(mappings, key=lambda x: (int(x[1]), x[0]))
        tested_ports = set()
        
        for domain, port in sorted_mappings:
            status = get_status_icon(port, test_ports)
            
            if test_ports and port not in tested_ports:
                tested_ports.add(port)
                if "✅" in status:
                    online_ports += 1
                else:
                    offline_ports += 1
            
            if test_ports:
                print(f"{status} localhost:{port:<5} → {domain}")
            else:
                print(f"localhost:{port:<5} → {domain}")
    else:
        # Affichage regroupé par port
        for port in sorted(by_port.keys(), key=int):
            domains = sorted(by_port[port])
            count = len(domains)
            
            status = get_status_icon(port, test_ports)
            
            if test_ports:
                if "✅" in status:
                    online_ports += 1
                else:
                    offline_ports += 1
                
                print(f"\n{status} 🔌 localhost:{port} ({count} domaine{'s' if count > 1 else ''})")
            else:
                print(f"\n🔌 localhost:{port} ({count} domaine{'s' if count > 1 else ''})")
            
            print("   " + "─" * 60)
            
            for domain in domains:
                print(f"   • {domain}")
    
    print()
    print("─" * 64)
    print(f"Total: {len(mappings)} mappings", end="")
    
    if test_ports and group:
        print(f" sur {len(by_port)} ports (✅ {online_ports} actifs, ❌ {offline_ports} inactifs)")
    else:
        print()
    
    return offline_ports if test_ports else 0

def display_statistics(mappings: List[Tuple[str, str]], test_ports: bool = False) -> int:
    """Affiche des statistiques détaillées. Retourne le nombre de ports inactifs."""
    print("\n╔════════════════════════════════════════════════════════════════╗")
    print("║                      STATISTIQUES                              ║")
    print("╚════════════════════════════════════════════════════════════════╝")
    print()
    
    if test_ports:
        print("🔍 Test de connectivité en cours...")
        print()
    
    by_port = defaultdict(list)
    for domain, port in mappings:
        by_port[port].append(domain)
    
    # Test des ports si demandé
    port_status = {}
    if test_ports:
        for port in by_port.keys():
            port_status[port] = test_port(port)
    
    online_ports = sum(1 for status in port_status.values() if status) if test_ports else 0
    offline_ports = len(port_status) - online_ports if test_ports else 0
    
    print(f"📊 Nombre total de mappings : {len(mappings)}")
    print(f"🔌 Nombre de ports utilisés : {len(by_port)}")
    
    if test_ports:
        print(f"   ├─ ✅ Actifs : {online_ports}")
        print(f"   └─ ❌ Inactifs : {offline_ports}")
    
    print(f"🌐 Nombre de domaines uniques : {len(set(d for d, _ in mappings))}")
    print()
    
    print("Répartition par port :")
    print("─" * 64)
    
    # Tableau avec barres
    max_count = max(len(domains) for domains in by_port.values()) if by_port else 1
    
    for port in sorted(by_port.keys(), key=int):
        count = len(by_port[port])
        bar_length = int((count / max_count) * 30)
        bar = "█" * bar_length
        
        status = ""
        if test_ports:
            status = "✅ " if port_status.get(port, False) else "❌ "
        
        print(f"  {status}localhost:{port:<5} │ {bar} {count}")
    
    print()
    
    # Top 5 des ports les plus utilisés
    top_ports = sorted(by_port.items(), key=lambda x: len(x[1]), reverse=True)[:5]
    
    print("🏆 Top 5 des ports les plus utilisés :")
    print("─" * 64)
    for i, (port, domains) in enumerate(top_ports, 1):
        status = ""
        if test_ports:
            status = "✅ " if port_status.get(port, False) else "❌ "
        
        print(f"  {i}. {status}localhost:{port} → {len(domains)} domaine(s)")
        for domain in sorted(domains)[:3]:
            print(f"     • {domain}")
        if len(domains) > 3:
            print(f"     ... et {len(domains) - 3} autre(s)")
        print()
    
    return offline_ports if test_ports else 0

def display_combined(mappings: List[Tuple[str, str]], test_ports: bool = False) -> int:
    """Affiche une vue combinée avec les deux tris. Retourne le nombre de ports inactifs."""
    by_port = defaultdict(list)
    for domain, port in mappings:
        by_port[port].append(domain)
    
    print("\n╔════════════════════════════════════════════════════════════════╗")
    print("║              Vue combinée : Port ⟷ Domaines                   ║")
    print("╚════════════════════════════════════════════════════════════════╝")
    print()
    
    if test_ports:
        print("🔍 Test de connectivité en cours...")
        print()
    
    online_count = 0
    offline_count = 0
    
    for port in sorted(by_port.keys(), key=int):
        domains = sorted(by_port[port])
        count = len(domains)
        
        status = get_status_icon(port, test_ports)
        
        if test_ports:
            if "✅" in status:
                online_count += 1
            else:
                offline_count += 1
        
        # En-tête du port
        print(f"\n{'='*64}")
        if test_ports:
            print(f"{status} 🔌 localhost:{port:<5} │ {count} domaine{'s' if count > 1 else ''}")
        else:
            print(f"🔌 localhost:{port:<5} │ {count} domaine{'s' if count > 1 else ''}")
        print(f"{'='*64}")
        
        # Liste des domaines en colonnes si multiple
        if count == 1:
            print(f"   → {domains[0]}")
        else:
            # Affichage en 2 colonnes si > 4 domaines
            if count > 4:
                mid = (count + 1) // 2
                col1 = domains[:mid]
                col2 = domains[mid:]
                
                max_len = max(len(d) for d in col1) if col1 else 0
                
                for i in range(mid):
                    left = col1[i] if i < len(col1) else ""
                    right = col2[i] if i < len(col2) else ""
                    print(f"   • {left:<{max_len}}    • {right}")
            else:
                for domain in domains:
                    print(f"   • {domain}")
    
    print(f"\n{'='*64}")
    if test_ports:
        print(f"Total: {len(mappings)} mappings sur {len(by_port)} ports (✅ {online_count} actifs, ❌ {offline_count} inactifs)\n")
    else:
        print(f"Total: {len(mappings)} mappings sur {len(by_port)} ports\n")
    
    return offline_count if test_ports else 0

def main():
    parser = argparse.ArgumentParser(
        description='Liste les mappings Caddy domaine:port avec test de connectivité',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Exemples d'utilisation:
  %(prog)s                          # Affichage par défaut (par domaine)
  %(prog)s --test                   # Affichage avec test de connectivité
  %(prog)s --sort port --test       # Tri par port avec test
  %(prog)s --sort port --group --test  # Tri groupé avec test
  %(prog)s --stats --test           # Statistiques avec test
  %(prog)s --combined --test        # Vue combinée avec test
  %(prog)s --file /path/Caddyfile   # Spécifier un fichier

Notes:
  - Sans paramètres : affiche l'aide + mappings triés par domaine
  - --test : teste la connectivité de chaque port (✅ actif / ❌ inactif)
  - Les options --sort, --group, etc. remplacent l'affichage par défaut
  - Code de sortie : 0 si tous les ports sont actifs, 1 si des ports sont inactifs (avec --test)

Codes de sortie:
  0 : Succès (tous les ports actifs avec --test, ou pas de test)
  1 : Des ports sont inactifs (avec --test)
  2 : Erreur (fichier non trouvé, etc.)
        """
    )
    
    parser.add_argument(
        '--sort', '-s',
        choices=['domain', 'port'],
        default=None,
        help='Trier par domaine ou par port'
    )
    
    parser.add_argument(
        '--group', '-g',
        action='store_true',
        help='Regrouper les domaines multiples (avec --sort port)'
    )
    
    parser.add_argument(
        '--test', '-t',
        action='store_true',
        help='Tester la connectivité des ports (✅/❌) - exit 1 si des ports sont inactifs'
    )
    
    parser.add_argument(
        '--stats',
        action='store_true',
        help='Afficher les statistiques détaillées'
    )
    
    parser.add_argument(
        '--combined', '-c',
        action='store_true',
        help='Afficher une vue combinée port/domaines'
    )
    
    parser.add_argument(
        '--file', '-f',
        default=CADDYFILE,
        help=f'Chemin vers le Caddyfile (défaut: {CADDYFILE})'
    )
    
    # Si aucun argument n'est fourni, afficher l'aide puis les mappings par défaut
    show_default = len(sys.argv) == 1
    
    args = parser.parse_args()
    
    try:
        mappings = parse_caddyfile(args.file)
        
        if not mappings:
            print("❌ Aucun mapping trouvé")
            sys.exit(2)
        
        offline_count = 0
        
        # Si aucun argument, afficher l'aide puis le mapping par défaut
        if show_default:
            parser.print_help()
            print("\n" + "="*64)
            print("📋 Affichage par défaut (tri par domaine)")
            print("="*64)
            offline_count = display_by_domain(mappings, False, False)
        # Sinon, afficher selon les options
        elif args.combined:
            offline_count = display_combined(mappings, args.test)
        elif args.stats:
            offline_count = display_statistics(mappings, args.test)
        elif args.sort == 'port':
            offline_count = display_by_port(mappings, args.group, args.test)
        elif args.sort == 'domain':
            offline_count = display_by_domain(mappings, args.group, args.test)
        else:
            # Par défaut si --sort n'est pas spécifié mais d'autres options le sont
            offline_count = display_by_domain(mappings, args.group, args.test)
        
        # Exit avec code d'erreur si des ports sont inactifs ET que --test est activé
        if args.test and offline_count > 0:
            print(f"\n⚠️  Attention : {offline_count} port(s) inactif(s) détecté(s)")
            sys.exit(1)
        
        sys.exit(0)
            
    except FileNotFoundError:
        print(f"❌ Fichier non trouvé : {args.file}")
        sys.exit(2)
    except Exception as e:
        print(f"❌ Erreur : {e}")
        import traceback
        traceback.print_exc()
        sys.exit(2)

if __name__ == '__main__':
    main()
