from __future__ import annotations

import logging
import os
from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import yaml

logger = logging.getLogger(__name__)
# NOTE(review): configuring the root logger at import time is a module-level
# side effect. It is kept because existing callers may rely on it; consider
# moving it to the application entry point.
logging.basicConfig(level=logging.INFO)

# Spellings accepted for boolean toggles that arrive as (quoted) strings.
_TRUE_TOKENS = frozenset({"true", "yes", "on", "1"})
_FALSE_TOKENS = frozenset({"false", "no", "off", "0"})

# Inclusive bounds used when validating expected port numbers.
_MIN_PORT = 0
_MAX_PORT = 65535


@dataclass
class ScanTarget:
    """
    One IP and its expected ports.

    :ivar ip: Target address exactly as written in the config file.
    :ivar expected_tcp: TCP ports listed as expected for this target.
    :ivar expected_udp: UDP ports listed as expected for this target.
    """

    ip: str
    expected_tcp: List[int] = field(default_factory=list)
    expected_udp: List[int] = field(default_factory=list)


@dataclass
class ScanOptions:
    """
    Feature toggles that affect how scans are executed.
    """

    udp_scan: bool = False
    tls_security_scan: bool = True
    tls_exp_check: bool = True


@dataclass
class Reporting:
    """
    Output/report preferences for this config file.
    """

    report_name: str = "Scan Report"
    report_filename: str = "report.html"
    full_details: bool = False


@dataclass
class ScanConfigFile:
    """
    Full configuration for a single logical scan "set" (e.g., DMZ, WAN).
    """

    name: str = "Unnamed"
    scan_options: ScanOptions = field(default_factory=ScanOptions)
    reporting: Reporting = field(default_factory=Reporting)
    scan_targets: List[ScanTarget] = field(default_factory=list)


class ScanConfigRepository:
    """
    Loads and validates *.yaml scan configuration files from a directory.

    Search order for the config directory:
      1) Explicit path argument to load_all()
      2) Environment variable SCAN_TARGETS_DIR
      3) Default: /data/scan_targets
    """

    SUPPORTED_EXT = (".yaml", ".yml")

    def __init__(self) -> None:
        self._loaded: List[ScanConfigFile] = []

    def load_all(self, directory: Optional[Path] = None) -> List[ScanConfigFile]:
        """
        Load all YAML configs from the given directory and return them.

        Files are processed in sorted path order. A file that cannot be read,
        parsed or validated is logged and skipped; the remaining files are
        still loaded.

        :param directory: Optional explicit directory path.
        :return: The successfully loaded configs. This is a copy, so mutating
            it does not change what list_configs()/get_by_name() see.
        :raises OSError: If the resolved directory cannot be listed
            (e.g. it does not exist).
        """
        root = self._resolve_directory(directory)
        logger.info("Loading scan configs from: %s", root)

        # Only regular files: a sub-directory that happens to be called
        # "something.yaml" is not a config file.
        files = sorted(
            p
            for p in root.iterdir()
            if p.is_file() and p.suffix.lower() in self.SUPPORTED_EXT
        )
        logger.info("Found %d config file(s).", len(files))

        configs: List[ScanConfigFile] = []
        for fpath in files:
            try:
                data = self._read_yaml(fpath)
                cfg = self._parse_config(data, default_name=fpath.stem)
                self._validate_config(cfg, source=str(fpath))
            except Exception as exc:
                # Deliberate best-effort boundary: one broken file must not
                # prevent the other configs from loading.
                logger.error("Failed to load %s: %s", fpath, exc)
            else:
                configs.append(cfg)
                logger.info(
                    "Loaded config: %s (%s targets)", cfg.name, len(cfg.scan_targets)
                )

        self._loaded = configs
        return list(configs)

    def _resolve_directory(self, directory: Optional[Path]) -> Path:
        """
        Decide which directory to load from.

        :param directory: Explicit directory; a plain string is accepted and
            converted to a Path. A falsy value means "not given".
        :return: The explicit directory, else SCAN_TARGETS_DIR, else the
            built-in default.
        """
        if directory:
            return Path(directory)
        env = os.getenv("SCAN_TARGETS_DIR")
        if env:
            return Path(env)
        return Path("/data/scan_targets")

    @staticmethod
    def _read_yaml(path: Path) -> Dict[str, Any]:
        """
        Safely read YAML file into a Python dict.

        :param path: File to read (decoded as UTF-8).
        :return: The top-level mapping; an empty file yields an empty dict.
        :raises ValueError: If the top-level YAML node is not a mapping.
        """
        with path.open("r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
        if not isinstance(data, dict):
            raise ValueError("Top-level YAML must be a mapping (dict).")
        return data

    @staticmethod
    def _as_int_list(value: Any, field_name: str) -> List[int]:
        """
        Coerce a sequence to a list of ints; raise if invalid.

        :param value: None, or a list/tuple of integer-like items. Integral
            floats (80.0) and numeric strings ("80") are converted.
        :param field_name: Name used in error messages.
        :return: The converted integers; an empty list for None or no items.
        :raises TypeError: If value is not a list/tuple, or an item is a
            boolean, a non-integral float, or not convertible to int.
        """
        if value is None:
            return []
        if not isinstance(value, (list, tuple)):
            raise TypeError(f"'{field_name}' must be a list of integers.")

        out: List[int] = []
        for v in value:
            if isinstance(v, bool):
                # Avoid True/False being treated as 1/0
                raise TypeError(f"'{field_name}' must contain integers, not booleans.")
            if isinstance(v, float) and not v.is_integer():
                # int() would silently truncate 80.9 to 80; reject instead.
                raise TypeError(f"'{field_name}' contains a non-integer: {v!r}")
            try:
                out.append(int(v))
            except (TypeError, ValueError, OverflowError) as exc:
                raise TypeError(f"'{field_name}' contains a non-integer: {v!r}") from exc
        return out

    @staticmethod
    def _as_bool(value: Any, field_name: str, default: bool) -> bool:
        """
        Interpret a YAML scalar as a boolean toggle.

        :param value: Raw value. None means "not set".
        :param field_name: Name used in error messages.
        :param default: Result when value is None.
        :return: The boolean meaning of value.
        :raises TypeError: If value is not a bool, 0/1, or a recognised
            true/false string.
        """
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        if isinstance(value, int) and value in (0, 1):
            return bool(value)
        if isinstance(value, str):
            # bool("false") is True, so quoted toggles need explicit parsing.
            token = value.strip().lower()
            if token in _TRUE_TOKENS:
                return True
            if token in _FALSE_TOKENS:
                return False
        raise TypeError(f"'{field_name}' must be a boolean, got {value!r}.")

    @staticmethod
    def _as_str(value: Any, default: str) -> str:
        """
        Return value as a string, or default when value is None.

        Prevents an explicit YAML null from becoming the literal text "None".
        """
        return default if value is None else str(value)

    @staticmethod
    def _section(data: Dict[str, Any], key: str) -> Dict[str, Any]:
        """
        Fetch an optional mapping section from the raw config.

        :return: The section, or an empty dict when it is missing or empty.
        :raises TypeError: If the section is present but not a mapping.
        """
        raw = data.get(key) or {}
        if not isinstance(raw, dict):
            raise TypeError(f"'{key}' must be a mapping (dict).")
        return raw

    def _parse_config(self, data: Dict[str, Any], default_name: str) -> ScanConfigFile:
        """
        Convert a raw dict (from YAML) into a ScanConfigFile.

        :param data: Top-level mapping as returned by _read_yaml().
        :param default_name: Name to use when the file has no "name" value.
        :return: The parsed configuration (semantic checks are done separately
            by _validate_config()).
        :raises TypeError: If a section, target or field has the wrong type.
        :raises ValueError: If a target has no usable "ip" value.
        """
        name = self._as_str(data.get("name"), default_name)

        # Parse scan_options
        so_raw = self._section(data, "scan_options")
        scan_options = ScanOptions(
            udp_scan=self._as_bool(so_raw.get("udp_scan"), "scan_options.udp_scan", False),
            tls_security_scan=self._as_bool(
                so_raw.get("tls_security_scan"), "scan_options.tls_security_scan", True
            ),
            tls_exp_check=self._as_bool(
                so_raw.get("tls_exp_check"), "scan_options.tls_exp_check", True
            ),
        )

        # Parse reporting
        rep_raw = self._section(data, "reporting")
        reporting = Reporting(
            report_name=self._as_str(rep_raw.get("report_name"), "Scan Report"),
            report_filename=self._as_str(rep_raw.get("report_filename"), "report.html"),
            full_details=self._as_bool(
                rep_raw.get("full_details"), "reporting.full_details", False
            ),
        )

        # Parse targets
        targets_raw = data.get("scan_targets", []) or []
        if not isinstance(targets_raw, list):
            raise TypeError("'scan_targets' must be a list.")

        targets: List[ScanTarget] = []
        # Positions in error messages are 1-based (first entry is [1]).
        for idx, item in enumerate(targets_raw, start=1):
            if not isinstance(item, dict):
                raise TypeError(f"scan_targets[{idx}] must be a mapping (dict).")
            ip = item.get("ip")
            if not ip or not isinstance(ip, str):
                raise ValueError(f"scan_targets[{idx}].ip must be a non-empty string.")
            expected_tcp = self._as_int_list(
                item.get("expected_tcp"), f"scan_targets[{idx}].expected_tcp"
            )
            expected_udp = self._as_int_list(
                item.get("expected_udp"), f"scan_targets[{idx}].expected_udp"
            )
            targets.append(
                ScanTarget(ip=ip, expected_tcp=expected_tcp, expected_udp=expected_udp)
            )

        return ScanConfigFile(
            name=name,
            scan_options=scan_options,
            reporting=reporting,
            scan_targets=targets,
        )

    @staticmethod
    def _validate_config(cfg: ScanConfigFile, source: str) -> None:
        """
        Lightweight semantic checks.

        :param cfg: Parsed configuration to check.
        :param source: Label (typically the file path) prefixed to errors.
        :raises ValueError: If an IP appears more than once in scan_targets,
            or an expected port is outside 0-65535.
        """
        # Disallow duplicate IPs within a single file
        counts = Counter(t.ip for t in cfg.scan_targets)
        dups = [ip for ip, count in counts.items() if count > 1]
        if dups:
            raise ValueError(f"{source}: duplicate IP(s) in scan_targets: {dups}")

        # Ports outside the 16-bit range can never match a scan result.
        for t in cfg.scan_targets:
            for label, ports in (
                ("expected_tcp", t.expected_tcp),
                ("expected_udp", t.expected_udp),
            ):
                bad = [p for p in ports if not _MIN_PORT <= p <= _MAX_PORT]
                if bad:
                    raise ValueError(
                        f"{source}: {t.ip}: {label} port(s) out of range "
                        f"{_MIN_PORT}-{_MAX_PORT}: {bad}"
                    )

    # Optional helpers
    def list_configs(self) -> List[str]:
        """
        Return names of loaded configs for UI selection.
        """
        return [c.name for c in self._loaded]

    def get_by_name(self, name: str) -> Optional[ScanConfigFile]:
        """
        Fetch a loaded config by its name.

        :param name: Exact config name to look for.
        :return: The first loaded config with that name, or None.
        """
        return next((c for c in self._loaded if c.name == name), None)