#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import os
import re
import subprocess
from typing import Dict, List, Tuple

# =======================
# Utils
# =======================

def run_cmd(cmd: List[str]) -> str:
    """
    Run *cmd* (an argv list, no shell) and return its stdout, stripped.

    Returns "" when the program is missing or cannot be executed, or when it
    exits with a non-zero status (its output is discarded in that case).
    Bytes that cannot be decoded are replaced instead of raising
    UnicodeDecodeError: device models, labels, etc. are not guaranteed to be
    valid text. stderr is not captured.
    """
    try:
        return subprocess.check_output(cmd, text=True, errors="replace").strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (tool not installed) as well as
        # PermissionError (path exists but cannot be executed).
        return ""

def resolve(path: str) -> str:
    """Return the canonical form of *path* (symlinks followed); fall back to *path* itself on any error."""
    try:
        canonical = os.path.realpath(path)
    except Exception:  # best effort: keep the caller's spelling
        canonical = path
    return canonical

def human_size_bytes(b: int) -> str:
    """
    Format a byte count with two decimals: in "TB" once it reaches 2**40
    bytes, otherwise in "GB" (2**30 bytes), even for small values.

    Returns "-" when *b* cannot be converted with int().
    """
    try:
        count = int(b)
    except Exception:
        return "-"
    tebi, gibi = 1 << 40, 1 << 30
    unit, divisor = ("TB", tebi) if count >= tebi else ("GB", gibi)
    return f"{count / divisor:.2f} {unit}"

def size_to_bytes(size_str: str) -> float:
    """
    Convert a human size such as "10T", "500G", "1.5TB", "4TiB" or "512M"
    into bytes (binary multiples: K = 1024, M = 1024**2, ... P = 1024**5).

    A trailing "B" or "iB" is optional, case does not matter and a bare
    number is taken as bytes.

    Raises ValueError when the numeric part cannot be parsed.
    """
    multipliers = {"K": 1 << 10, "M": 1 << 20, "G": 1 << 30, "T": 1 << 40, "P": 1 << 50}
    s = size_str.strip().upper()
    # Accept "TB"/"TiB"/"B" spellings: drop the byte suffix first.
    if s.endswith("IB"):
        s = s[:-2]
    elif s.endswith("B"):
        s = s[:-1]
    factor = 1
    if s and s[-1] in multipliers:
        factor = multipliers[s[-1]]
        s = s[:-1]
    return float(s.strip()) * factor

# =======================
# Data collectors
# =======================

def lsblk_tree() -> dict:
    """
    Return the parsed JSON of `lsblk -J -b` (sizes in bytes) limited to the
    NAME, TYPE, SIZE, MODEL, SERIAL, FSTYPE and MOUNTPOINT columns.

    If lsblk produces no output, an empty {"blockdevices": []} tree is
    returned instead.
    """
    columns = "NAME,TYPE,SIZE,MODEL,SERIAL,FSTYPE,MOUNTPOINT"
    raw_json = run_cmd(["lsblk", "-J", "-b", "-o", columns])
    if not raw_json:
        return {"blockdevices": []}
    return json.loads(raw_json)

def is_virtual_name(name: str) -> bool:
    """Return True when the kernel device *name* starts with a prefix this script excludes from the disk listing."""
    excluded_prefixes = ("loop", "ram", "zram", "dm-", "md", "zd", "sr")
    return any(name.startswith(prefix) for prefix in excluded_prefixes)

def smart_info(dev: str) -> Tuple[str, str, int]:
    """
    Return (model, serial, capacity_bytes) for *dev*, parsed from
    ``smartctl -i``.

    A field that is not found comes back as '-' (model/serial) or 0
    (capacity). The same applies when smartctl is unavailable.

    Model precedence: "Device Model" (ATA) / "Model Number" (NVMe) first,
    then "Product" (SCSI), and only then "Model Family". smartctl's ATA
    output lists "Model Family" before "Device Model". A first-match rule
    would therefore report the family rather than the exact model.
    """
    # NOTE(review): run_cmd() returns "" whenever the command exits non-zero,
    # so identity info printed by a smartctl run that ends in an error is
    # dropped here — confirm this is acceptable.
    out = run_cmd(["smartctl", "-i", dev])
    # Lower rank = more specific source for the model field.
    model_keys = {"device model": 0, "model number": 0, "product": 1, "model family": 2}
    model, serial, capacity = "-", "-", 0
    model_rank = len(model_keys)  # worse than any real key
    for line in out.splitlines():
        key, sep, value = line.partition(":")
        if not sep:
            continue
        key = key.strip().lower()
        value = value.strip()
        if key in model_keys:
            if value and model_keys[key] < model_rank:
                model, model_rank = value, model_keys[key]
        elif key == "serial number":
            if value and serial == "-":
                serial = value
        elif key == "user capacity":
            # e.g. "18,000,207,937,536 bytes [18.0 TB]". The thousands
            # separator may vary (',', '.', ' ', ...), so keep only the digits
            # that precede the bracketed human-readable size.
            digits = re.sub(r"\D", "", value.split("[", 1)[0])
            if digits:
                capacity = int(digits)
    return model, serial, capacity

def collect_zfs_map() -> Dict[str, str]:
    """
    Map each real (symlink-resolved) leaf vdev path to its ZFS pool name.

    Parses ``zpool status -P``, which prints full vdev paths:
      - a "pool: <name>" line starts a new pool;
      - a row whose first field is an absolute path and whose second field
        is a vdev state is a device of that pool.

    Returns {} when zpool is unavailable or prints nothing.
    """
    # Requiring one of these tokens in the STATE column keeps out other
    # lines that start with "/" (presumably e.g. file paths that zpool lists
    # under "errors:"). AVAIL/INUSE cover hot spares.
    vdev_states = {"ONLINE", "DEGRADED", "FAULTED", "OFFLINE",
                   "UNAVAIL", "REMOVED", "AVAIL", "INUSE"}
    zfs: Dict[str, str] = {}
    current_pool = ""  # no device row is recorded before the first "pool:" line
    for raw_line in run_cmd(["zpool", "status", "-P"]).splitlines():
        line = raw_line.strip()
        if line.lower().startswith("pool:"):
            current_pool = line.split(":", 1)[1].strip()
            continue
        fields = line.split()
        # Both conditions must hold: path in column 1 AND state in column 2.
        # A bare "... or 'DEGRADED' in line" would also match e.g. the pool's
        # own "state: DEGRADED" line.
        if (current_pool and len(fields) >= 2
                and fields[0].startswith("/") and fields[1] in vdev_states):
            zfs[resolve(fields[0])] = current_pool
    return zfs

def _unescape_findmnt(value: str) -> str:
    """Decode the ``\\xHH`` escapes that ``findmnt -r`` uses for spaces and other special characters."""
    return re.sub(r"\\x([0-9A-Fa-f]{2})",
                  lambda m: chr(int(m.group(1), 16)), value)


def _mergerfs_mounts() -> List[Tuple[str, str]]:
    """Return (source, target) for every mounted fuse.mergerfs filesystem."""
    out = run_cmd(["findmnt", "-t", "fuse.mergerfs", "-rn", "-o", "SOURCE,TARGET"])
    mounts: List[Tuple[str, str]] = []
    if out:
        for line in out.splitlines():
            fields = line.split(None, 1)
            if len(fields) == 2:
                mounts.append((_unescape_findmnt(fields[0]),
                               _unescape_findmnt(fields[1].strip())))
        return mounts
    # Fallback when findmnt printed nothing: parse `mount` lines of the form
    # "SRC1:SRC2 on /target type fuse.mergerfs (...)". Match on the fs type
    # only; a bare "mergerfs" substring also matches unrelated mounts whose
    # path merely contains that word.
    for line in run_cmd(["mount"]).splitlines():
        if "type fuse.mergerfs" not in line:
            continue
        source, on_sep, rest = line.partition(" on ")
        target, type_sep, _ = rest.rpartition(" type ")
        if on_sep and type_sep:
            mounts.append((source, target))
    return mounts


def _mergerfs_branches(source: str, target: str) -> List[str]:
    """
    Return the absolute branch paths of the mergerfs mount at *target*.

    Preferred source: the xattrs of mergerfs' runtime control file
    ``<target>/.mergerfs`` (``user.mergerfs.branches``, or the older
    ``user.mergerfs.srcmounts``).

    Fallback: the mount SOURCE (mergerfs' ``fsname``). That name is not
    guaranteed to list full branch paths, so non-absolute entries are
    ignored.
    """
    import glob  # only needed for wildcard branches

    raw = ""
    getxattr = getattr(os, "getxattr", None)  # os.getxattr exists on Linux only
    if getxattr is not None:
        control_file = os.path.join(target, ".mergerfs")
        for attr in ("user.mergerfs.branches", "user.mergerfs.srcmounts"):
            try:
                raw = getxattr(control_file, attr).decode("utf-8", "replace").strip("\x00 \n")
            except OSError:
                continue
            if raw:
                break
    if not raw:
        raw = source

    branches: List[str] = []
    for entry in raw.split(":"):
        # Drop a branch mode suffix such as "=RW", "=RO" or "=NC".
        path = re.sub(r"=(?:RW|RO|NC)\b.*$", "", entry.strip())
        if not path.startswith("/"):
            continue
        if any(ch in path for ch in "*?["):
            branches.extend(sorted(glob.glob(path)))
        else:
            branches.append(path)
    return branches


def _device_for_path(path: str) -> str:
    """
    Return the SOURCE of the filesystem that contains *path* ('' if none).

    ``findmnt -T`` walks up from *path* to its mountpoint. Branches that are
    sub-directories of a mounted disk are therefore resolved too.
    """
    out = run_cmd(["findmnt", "-rn", "-o", "SOURCE", "-T", path])
    first = out.splitlines()[0] if out else ""
    # btrfs subvolumes are reported as "/dev/sdX1[/subvol]".
    return _unescape_findmnt(first.split("[", 1)[0])


def collect_mergerfs_sources() -> Dict[str, str]:
    """
    Map each real (resolved) block device backing a mergerfs branch to the
    target directory of that mergerfs mount.

      1. List fuse.mergerfs mounts (findmnt, falling back to `mount`).
      2. Read each mount's branch list (control-file xattrs, else fsname).
      3. Map each branch to the device of the filesystem containing it.

    Returns {} when no mergerfs mount is found.
    """
    mapping: Dict[str, str] = {}
    for source, target in _mergerfs_mounts():
        for branch in _mergerfs_branches(source, target):
            dev = _device_for_path(branch)
            if dev.startswith("/dev/"):
                mapping[resolve(dev)] = target
    return mapping

# =======================
# Usage detection
# =======================

def detect_usage_for_node(dev: str, mountpoint: str,
                          zfs_map: Dict[str, str],
                          mergerfs_map: Dict[str, str]) -> Tuple[str, str]:
    """
    Classify one device node (e.g. /dev/sdX or /dev/sdX1).

    The node path is resolved and then checked, in this order:
      1. in *zfs_map*      -> ("zfs:<pool>", "-")
      2. in *mergerfs_map* -> ("mergerfs", <mergerfs target>)
      3. a real mountpoint -> ("mounted", <mountpoint>)
      4. otherwise         -> ("unused", "-")

    A mountpoint of "" or "-" counts as "not mounted".
    """
    real_dev = resolve(dev)
    if real_dev in zfs_map:
        usage = (f"zfs:{zfs_map[real_dev]}", "-")
    elif real_dev in mergerfs_map:
        usage = ("mergerfs", mergerfs_map[real_dev])
    elif mountpoint and mountpoint != "-":
        usage = ("mounted", mountpoint)
    else:
        usage = ("unused", "-")
    return usage

def detect_usage_unique(disk: dict,
                        zfs_map: Dict[str, str],
                        mergerfs_map: Dict[str, str]) -> Tuple[str, str]:
    """
    Summarise how a whole disk is used.

    Used by ``--unique`` mode and by the usage filters. It inspects the disk
    node and every node nested under it in the lsblk tree (partitions and
    their own children).

    Priority: ZFS > MergerFS > Mounted > Unused. The disk node itself counts
    like any other node, e.g. a disk formatted without a partition table
    carries its own mountpoint. Among nodes of equal priority, the first
    one visited wins: the disk, then its children depth-first in lsblk
    order.

    Returns (USAGE_TYPE, MOUNT) as produced by detect_usage_for_node.
    """
    # NOTE(review): nested nodes are addressed as /dev/<lsblk name>, which
    # may not be the real path for device-mapper children; those are then
    # only matched through their mountpoint.

    def priority(usage_type: str) -> int:
        # Lower is stronger; anything unrecognised ranks with "unused".
        if usage_type.startswith("zfs"):
            return 0
        return {"mergerfs": 1, "mounted": 2}.get(usage_type, 3)

    def walk(node: dict):
        yield node
        for child in node.get("children") or []:
            yield from walk(child)

    best = ("unused", "-")
    best_priority = priority(best[0])
    for node in walk(disk):
        usage = detect_usage_for_node(f"/dev/{node['name']}",
                                      node.get("mountpoint") or "-",
                                      zfs_map, mergerfs_map)
        rank = priority(usage[0])
        if rank == 0:
            return usage  # ZFS has the highest priority: stop here
        if rank < best_priority:  # strict: keep the first node of a given priority
            best, best_priority = usage, rank
    return best

# =======================
# Printing
# =======================

def print_header(mode_unique: bool):
    """
    Print the column header of the disk table.

    The FSTYPE column only exists in grouped mode (*mode_unique* False).
    Widths match the row formats of print_disk_unique/print_disk_grouped.
    """
    columns = [("DEVICE", 12), ("LABEL", 12), ("MODEL", 25)]
    if not mode_unique:
        columns.append(("FSTYPE", 10))
    columns += [("SIZE", 10), ("SERIAL", 20), ("USAGE_TYPE", 12), ("MOUNT", 20)]
    print(" ".join(f"{title:{width}}" for title, width in columns))

def print_disk_unique(disk: dict,
                      zfs_map: Dict[str, str],
                      mergerfs_map: Dict[str, str],
                      debug: bool = False):
    """
    Print the single summary row of *disk* for ``--unique`` mode.

    - MODEL and SERIAL come from lsblk; if either is '-', smartctl is
      queried to fill the missing one(s).
    - LABEL comes from blkid.
    - USAGE_TYPE/MOUNT come from detect_usage_unique.

    Cells are truncated/padded to the widths of print_header(True).
    *debug* is accepted but not used.
    """
    node = f"/dev/{disk['name']}"
    model = disk.get("model") or "-"
    serial = disk.get("serial") or "-"
    if "-" in (model, serial):
        # lsblk left a gap: only the fields still at '-' take SMART's value.
        smart_model, smart_serial, _ = smart_info(node)
        if model == "-":
            model = smart_model
        if serial == "-":
            serial = smart_serial

    size_text = human_size_bytes(int(disk.get("size") or 0))
    label = run_cmd(["blkid", "-s", "LABEL", "-o", "value", node]) or "-"
    usage, mount = detect_usage_unique(disk, zfs_map, mergerfs_map)

    cells = (
        f"{node:12}",
        f"{label[:12]:12}",
        f"{model[:25]:25}",
        f"{size_text:10}",
        f"{serial[:20]:20}",
        f"{usage[:12]:12}",
        f"{mount[:20]:20}",
    )
    print(" ".join(cells))

def print_disk_grouped(disk: dict,
                       zfs_map: Dict[str, str],
                       mergerfs_map: Dict[str, str],
                       debug: bool = False):
    """
    Print one row for *disk*, then one indented row per direct child
    (partition) — the default "grouped" mode.

    - Disk row: MODEL/SERIAL come from lsblk, completed by smartctl when
      '-', and FSTYPE is shown as '-'.
    - Partition rows: '-' for MODEL and SERIAL.
    - LABEL comes from blkid; USAGE_TYPE/MOUNT from detect_usage_for_node.

    *debug* is accepted but not used.
    """
    def blkid_label(node: str) -> str:
        return run_cmd(["blkid", "-s", "LABEL", "-o", "value", node]) or "-"

    disk_dev = f"/dev/{disk['name']}"
    model = disk.get("model") or "-"
    serial = disk.get("serial") or "-"
    if "-" in (model, serial):
        # lsblk left a gap: only the fields still at '-' take SMART's value.
        smart_model, smart_serial, _ = smart_info(disk_dev)
        if model == "-":
            model = smart_model
        if serial == "-":
            serial = smart_serial

    disk_label = blkid_label(disk_dev)
    disk_size = human_size_bytes(int(disk.get("size") or 0))
    disk_usage, disk_mount = detect_usage_for_node(
        disk_dev, disk.get("mountpoint") or "-", zfs_map, mergerfs_map)

    disk_row = [
        f"{disk_dev:12}", f"{disk_label[:12]:12}", f"{model[:25]:25}", f"{'-':10}",
        f"{disk_size:10}", f"{serial[:20]:20}", f"{disk_usage[:12]:12}", f"{disk_mount[:20]:20}",
    ]
    print(" ".join(disk_row))

    for part in disk.get("children") or []:
        part_dev = f"/dev/{part['name']}"
        part_label = blkid_label(part_dev)
        fstype = part.get("fstype") or "-"
        part_size = human_size_bytes(int(part.get("size") or 0))
        usage, mount = detect_usage_for_node(
            part_dev, part.get("mountpoint") or "-", zfs_map, mergerfs_map)
        part_row = [
            f"{part_dev:10}", f"{part_label[:12]:12}", f"{'-':25}", f"{fstype[:10]:10}",
            f"{part_size:10}", f"{'-':20}", f"{usage[:12]:12}", f"{mount[:20]:20}",
        ]
        print("  " + " ".join(part_row))

# =======================
# Filters
# =======================

def passes_filters(disk: dict,
                   model: str,
                   serial: str,
                   size_min_bytes: float,
                   only_zfs: bool,
                   only_mergerfs: bool,
                   only_unused: bool,
                   zfs_map: Dict[str, str],
                   mergerfs_map: Dict[str, str]) -> bool:
    """
    Return True when *disk* satisfies every active filter.

    - model: regex searched case-insensitively in the lsblk model, then, if
      that fails, in the smartctl model. Empty string = no filter.
    - serial: regex searched (case-sensitive) in the lsblk serial, then in
      the smartctl serial. Empty string = no filter.
    - size_min_bytes: disks smaller than this are rejected (0 = no filter).
      A lsblk size that float() cannot parse skips this check.
    - only_zfs / only_mergerfs / only_unused: the whole-disk usage from
      detect_usage_unique must be "zfs:*", "mergerfs" or "unused".

    smartctl runs at most once per disk, and only when a lsblk value does
    not match.
    """
    dev = f"/dev/{disk['name']}"
    smart = None  # (model, serial, capacity) from smart_info(), fetched lazily

    if model and not re.search(model, disk.get("model") or "-", re.IGNORECASE):
        smart = smart_info(dev)
        if not re.search(model, smart[0] or "", re.IGNORECASE):
            return False

    if serial and not re.search(serial, disk.get("serial") or "-"):
        if smart is None:
            smart = smart_info(dev)
        if not re.search(serial, smart[1] or ""):
            return False

    try:
        if size_min_bytes and float(disk.get("size") or 0) < size_min_bytes:
            return False
    except (TypeError, ValueError):
        pass  # unparsable lsblk size: do not filter on it

    usage_type, _ = detect_usage_unique(disk, zfs_map, mergerfs_map)
    if only_zfs and not usage_type.startswith("zfs"):
        return False
    if only_mergerfs and usage_type != "mergerfs":
        return False
    if only_unused and usage_type != "unused":
        return False
    return True

# =======================
# CLI
# =======================

def build_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser (help texts in French).

    Options:
      - boolean flags: --unique/-u, --only-zfs, --only-mergerfs,
        --only-unused, --debug;
      - string options (None when absent): --size-min, --model, --serial.

    The epilog lists usage examples and is printed verbatim
    (RawTextHelpFormatter).
    """
    examples = r"""
Exemples :
  - Lister tous les disques (grouped par défaut) :
      hdds_list.py
  - Une seule ligne par disque :
      hdds_list.py --unique
  - Seulement les disques ZFS :
      hdds_list.py --only-zfs
  - Seulement mergerfs :
      hdds_list.py --only-mergerfs
  - Seulement les disques non utilisés :
      hdds_list.py --only-unused
  - Filtrer par taille minimale (>= 10 To) :
      hdds_list.py --size-min 10T
  - Filtrer par modèle (regex, insensible à la casse) :
      hdds_list.py --model Samsung
  - Filtrer par numéro de série (regex) :
      hdds_list.py --serial 6TG0797D
"""
    parser = argparse.ArgumentParser(
        description="Lister les disques physiques (SMART, ZFS, mergerfs) avec modes grouped/unique.",
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=examples,
    )
    flag_options = (
        (("--unique", "-u"), "Une seule ligne par disque (pas de partitions)"),
        (("--only-zfs",), "Afficher uniquement les disques utilisés par ZFS"),
        (("--only-mergerfs",), "Afficher uniquement les disques utilisés par mergerfs"),
        (("--only-unused",), "Afficher uniquement les disques non utilisés"),
    )
    for flags, help_text in flag_options:
        parser.add_argument(*flags, action="store_true", help=help_text)
    value_options = (
        ("--size-min", "Filtrer les disques >= taille (ex: 10T, 500G)"),
        ("--model", "Filtrer par modèle (regex, insensible à la casse)"),
        ("--serial", "Filtrer par numéro de série (regex)"),
    )
    for flag, help_text in value_options:
        parser.add_argument(flag, help=help_text)
    parser.add_argument("--debug", action="store_true",
                        help="Logs debug (résolutions mergerfs/ZFS)")
    return parser

# =======================
# Main
# =======================

def main():
    """
    CLI entry point.

    Lists physical disks from lsblk and annotates them with ZFS/mergerfs
    usage. They are printed in grouped (default) or --unique mode, keeping
    only the disks that pass the command-line filters.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Validate user input before doing any work. A bad size or regex then
    # gives a clean usage error (parser.error exits with status 2) instead of
    # a traceback in the middle of the listing.
    size_min_b = 0
    if args.size_min:
        try:
            size_min_b = size_to_bytes(args.size_min)
        except ValueError:
            parser.error(f"--size-min invalide : {args.size_min!r} (ex: 10T, 500G)")
    for option, pattern in (("--model", args.model), ("--serial", args.serial)):
        if pattern:
            try:
                re.compile(pattern)
            except re.error as err:
                parser.error(f"regex invalide pour {option} : {err}")

    # Top-level "disk" nodes only, minus the names is_virtual_name excludes.
    disks = [
        d for d in lsblk_tree().get("blockdevices", [])
        if d.get("type") == "disk" and not is_virtual_name(d.get("name", ""))
    ]

    zfs_map = collect_zfs_map()                # device_real -> pool
    mergerfs_map = collect_mergerfs_sources()  # device_real -> mergerfs_target

    if args.debug:
        print("DEBUG ZFS MAP:")
        for k, v in zfs_map.items():
            print(f"  {k} -> {v}")
        print("DEBUG MERGERFS MAP:")
        for k, v in mergerfs_map.items():
            print(f"  {k} -> {v}")
        print()

    mode = "unique" if args.unique else "grouped"
    print(f"📦 Listing all disks ({mode} mode)\n")
    print_header(args.unique)

    print_disk = print_disk_unique if args.unique else print_disk_grouped
    for d in disks:
        if not passes_filters(d, args.model or "", args.serial or "",
                              size_min_b, args.only_zfs, args.only_mergerfs,
                              args.only_unused, zfs_map, mergerfs_map):
            continue
        print_disk(d, zfs_map, mergerfs_map, debug=args.debug)

if __name__ == "__main__":
    main()