from __future__ import annotations

import logging
import os
from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

import yaml
from apscheduler.triggers.cron import CronTrigger

logger = logging.getLogger(__name__)

# String spellings accepted for boolean flags when the YAML value is quoted.
_TRUE_TOKENS = frozenset({"true", "yes", "on", "1"})
_FALSE_TOKENS = frozenset({"false", "no", "off", "0"})

# Inclusive bounds applied to expected_tcp / expected_udp entries.
_PORT_MIN = 1
_PORT_MAX = 65535


@dataclass
class ScanTarget:
    """
    One IP and its expected ports.
    """

    ip: str
    expected_tcp: List[int] = field(default_factory=list)
    expected_udp: List[int] = field(default_factory=list)


@dataclass
class ScanOptions:
    """
    Feature toggles that affect how scans are executed.
    """

    cron: Optional[str] = None
    udp_scan: bool = False
    tls_security_scan: bool = True
    tls_exp_check: bool = True


@dataclass
class Reporting:
    """
    Output/report preferences for this config file.
    """

    report_name: str = "Scan Report"
    report_filename: str = "report.html"
    full_details: bool = False
    dark_mode: bool = True
    email_to: List[str] = field(default_factory=list)
    email_cc: List[str] = field(default_factory=list)


@dataclass
class ScanConfigFile:
    """
    Full configuration for a single logical scan "set" (e.g., DMZ, WAN).
    """

    name: str = "Unnamed"
    scan_options: ScanOptions = field(default_factory=ScanOptions)
    reporting: Reporting = field(default_factory=Reporting)
    scan_targets: List[ScanTarget] = field(default_factory=list)


class ScanConfigRepository:
    """
    Loads and validates *.yaml scan configuration files from a directory.

    Search order for the config directory:
      1) Explicit path argument to load_all()
      2) Environment variable SCAN_TARGETS_DIR
      3) Default: /app/data/scan_targets
    """

    SUPPORTED_EXT = (".yaml", ".yml")

    def __init__(self) -> None:
        self._loaded: List[ScanConfigFile] = []

    def load_all(self, directory: Optional[Path] = None) -> List[ScanConfigFile]:
        """
        Load all YAML configs from the given directory and return them.

        Files that fail to read, parse or validate are logged and skipped.
        A missing or non-directory root is logged and yields an empty list.
        The result also replaces the set used by list_configs() and
        get_by_name().

        :param directory: Optional explicit directory path.
        :return: The successfully loaded configs, ordered by file path.
        """
        root = self._resolve_directory(directory)
        logger.info("Loading scan configs from: %s", root)

        if not root.is_dir():
            # Same fail-open policy as for individual files below.
            logger.error("Scan config directory not found or not a directory: %s", root)
            self._loaded = []
            return self._loaded

        # is_file() keeps a sub-directory named e.g. "old.yaml" out of the list.
        files = sorted(
            p
            for p in root.iterdir()
            if p.is_file() and p.suffix.lower() in self.SUPPORTED_EXT
        )
        logger.info("Found %d config file(s).", len(files))

        configs: List[ScanConfigFile] = []
        for fpath in files:
            try:
                data = self._read_yaml(fpath)
                cfg = self._parse_config(data, default_name=fpath.stem)
                self._validate_config(cfg, source=str(fpath))
            except Exception as exc:
                # Fail-open: one broken file must not block the others.
                logger.error("Failed to load %s: %s", fpath, exc)
            else:
                configs.append(cfg)
                logger.info("Loaded config: %s (%s targets)", cfg.name, len(cfg.scan_targets))

        self._loaded = configs
        return configs

    def _resolve_directory(self, directory: Optional[Path]) -> Path:
        """
        Decide which directory to load from.

        :param directory: Explicit directory, or None/empty to fall back.
        :return: The explicit directory if given, else SCAN_TARGETS_DIR if
            set and non-empty, else /app/data/scan_targets.
        """
        if directory:
            # Path() also accepts a plain string passed by a caller.
            return Path(directory)
        env = os.getenv("SCAN_TARGETS_DIR")
        if env:
            return Path(env)
        return Path("/app/data/scan_targets")

    @staticmethod
    def _read_yaml(path: Path) -> Dict[str, Any]:
        """
        Safely read a YAML file into a Python dict.

        :param path: File to read (decoded as UTF-8).
        :return: The top-level mapping; {} for an empty document.
        :raises ValueError: If the top-level YAML node is not a mapping.
        """
        with path.open("r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if data is None:
            # Empty file / document containing only comments.
            return {}
        if not isinstance(data, dict):
            raise ValueError("Top-level YAML must be a mapping (dict).")
        return data

    @staticmethod
    def _as_int_list(value: Any, field_name: str) -> List[int]:
        """
        Coerce a sequence to a list of ints; raise if invalid.

        :param value: None, or a list/tuple of int-like items.
        :param field_name: Name used in error messages.
        :return: The items converted with int(); [] for None or empty input.
        :raises TypeError: If value is not a list/tuple, or an item is a
            boolean, a non-integral float, or not convertible to int.
        """
        if value is None:
            return []
        if not isinstance(value, (list, tuple)):
            raise TypeError(f"'{field_name}' must be a list of integers.")

        out: List[int] = []
        for v in value:
            if isinstance(v, bool):
                # Avoid True/False being treated as 1/0
                raise TypeError(f"'{field_name}' must contain integers, not booleans.")
            if isinstance(v, float) and not v.is_integer():
                # int() would silently truncate 80.9 to 80.
                raise TypeError(f"'{field_name}' contains a non-integer: {v!r}")
            try:
                out.append(int(v))
            except (TypeError, ValueError, OverflowError) as exc:
                raise TypeError(f"'{field_name}' contains a non-integer: {v!r}") from exc
        return out

    @staticmethod
    def _as_bool(value: Any, field_name: str, default: bool) -> bool:
        """
        Interpret a YAML scalar as a boolean flag.

        :param value: Raw value from the YAML mapping.
        :param field_name: Name used in error messages.
        :param default: Result when value is None (key missing or null).
        :return: The boolean meaning of value.
        :raises ValueError: If value is a string that is not one of
            true/false/yes/no/on/off/1/0 (case-insensitive).
        :raises TypeError: If value is neither bool, int, str nor None.
        """
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        if isinstance(value, int):
            return bool(value)
        if isinstance(value, str):
            # bool("false") is True, so quoted flags must be parsed explicitly.
            token = value.strip().lower()
            if token in _TRUE_TOKENS:
                return True
            if token in _FALSE_TOKENS:
                return False
            raise ValueError(f"'{field_name}' is not a recognised boolean: {value!r}")
        raise TypeError(f"'{field_name}' must be a boolean.")

    @staticmethod
    def _str_or_default(value: Any, default: str) -> str:
        """
        Return str(value), or default when value is None or blank.

        :param value: Raw value from the YAML mapping.
        :param default: Fallback for None / empty / whitespace-only values.
        :return: The unmodified string form of value, or default.
        """
        if value is None:
            # str(None) would produce the literal text "None".
            return default
        text = str(value)
        return text if text.strip() else default

    @staticmethod
    def _as_mapping(value: Any, field_name: str) -> Dict[str, Any]:
        """
        Return a config section as a dict.

        :param value: Raw section value from the YAML mapping.
        :param field_name: Name used in error messages.
        :return: value itself if it is a dict; {} for any falsy value.
        :raises TypeError: If value is truthy but not a dict.
        """
        if not value:
            return {}
        if not isinstance(value, dict):
            raise TypeError(f"'{field_name}' must be a mapping (dict).")
        return value

    def _parse_config(self, data: Dict[str, Any], default_name: str) -> ScanConfigFile:
        """
        Convert a raw dict (from YAML) into a ScanConfigFile.

        Missing, null or blank scalar values fall back to the defaults
        declared on the ScanOptions / Reporting dataclasses.

        :param data: Top-level mapping read from the YAML file.
        :param default_name: Name to use when the file does not set one.
        :return: The populated configuration object.
        :raises TypeError: If a section, list or flag has the wrong type.
        :raises ValueError: If a target has no usable ip, a flag string is
            unrecognised, or the cron expression is invalid.
        """
        name = self._str_or_default(data.get("name"), default_name)

        # Parse scan_options; defaults come from the dataclass so the two
        # cannot drift apart.
        so_raw = self._as_mapping(data.get("scan_options"), "scan_options")
        so_defaults = ScanOptions()
        scan_options = ScanOptions(
            cron=self._validate_cron_or_none(so_raw.get("cron")),
            udp_scan=self._as_bool(so_raw.get("udp_scan"), "udp_scan", so_defaults.udp_scan),
            tls_security_scan=self._as_bool(
                so_raw.get("tls_security_scan"),
                "tls_security_scan",
                so_defaults.tls_security_scan,
            ),
            tls_exp_check=self._as_bool(
                so_raw.get("tls_exp_check"), "tls_exp_check", so_defaults.tls_exp_check
            ),
        )

        # Parse reporting
        rep_raw = self._as_mapping(data.get("reporting"), "reporting")
        rep_defaults = Reporting()
        reporting = Reporting(
            report_name=self._str_or_default(
                rep_raw.get("report_name"), rep_defaults.report_name
            ),
            report_filename=self._str_or_default(
                rep_raw.get("report_filename"), rep_defaults.report_filename
            ),
            full_details=self._as_bool(
                rep_raw.get("full_details"), "full_details", rep_defaults.full_details
            ),
            dark_mode=self._as_bool(
                rep_raw.get("dark_mode"), "dark_mode", rep_defaults.dark_mode
            ),
            email_to=self._as_str_list(rep_raw.get("email_to", []), "email_to"),
            email_cc=self._as_str_list(rep_raw.get("email_cc", []), "email_cc"),
        )

        # Parse targets
        targets_raw = data.get("scan_targets", []) or []
        if not isinstance(targets_raw, list):
            raise TypeError("'scan_targets' must be a list.")

        targets: List[ScanTarget] = []
        for idx, item in enumerate(targets_raw, start=1):
            if not isinstance(item, dict):
                raise TypeError(f"scan_targets[{idx}] must be a mapping (dict).")

            ip = item.get("ip")
            if not isinstance(ip, str) or not ip.strip():
                raise ValueError(f"scan_targets[{idx}].ip must be a non-empty string.")

            targets.append(
                ScanTarget(
                    # Stripped so " 10.0.0.1" and "10.0.0.1" count as duplicates.
                    ip=ip.strip(),
                    expected_tcp=self._as_int_list(item.get("expected_tcp", []), "expected_tcp"),
                    expected_udp=self._as_int_list(item.get("expected_udp", []), "expected_udp"),
                )
            )

        return ScanConfigFile(
            name=name,
            scan_options=scan_options,
            reporting=reporting,
            scan_targets=targets,
        )

    @staticmethod
    def _validate_cron_or_none(expr: Optional[str]) -> Optional[str]:
        """
        Validate a standard 5-field crontab string via CronTrigger.from_crontab.

        :param expr: Raw cron value from the config.
        :return: The stripped expression if valid; None if empty,
            whitespace-only or None.
        :raises ValueError: If the expression is rejected by CronTrigger.
        """
        if not expr:
            return None
        expr = str(expr).strip()
        if not expr:
            return None
        # Validate now so we fail early on bad configs
        try:
            CronTrigger.from_crontab(expr)
        except ValueError as exc:
            # Name the field so the log line from load_all() is actionable.
            raise ValueError(f"'cron' is not a valid crontab expression {expr!r}: {exc}") from exc
        return expr

    @staticmethod
    def _validate_config(cfg: ScanConfigFile, source: str) -> None:
        """
        Lightweight semantic checks.

        :param cfg: Parsed configuration to check.
        :param source: Label (the file path) prefixed to error messages.
        :raises ValueError: If an IP appears more than once in scan_targets,
            or an expected port lies outside 1-65535.
        """
        # Disallow duplicate IPs within a single file
        counts = Counter(t.ip for t in cfg.scan_targets)
        dups = [ip for ip, count in counts.items() if count > 1]
        if dups:
            raise ValueError(f"{source}: duplicate IP(s) in scan_targets: {dups}")

        # Reject port numbers outside the valid range
        for target in cfg.scan_targets:
            for label, ports in (
                ("expected_tcp", target.expected_tcp),
                ("expected_udp", target.expected_udp),
            ):
                bad = [p for p in ports if not _PORT_MIN <= p <= _PORT_MAX]
                if bad:
                    raise ValueError(
                        f"{source}: {target.ip}: {label} port(s) outside "
                        f"{_PORT_MIN}-{_PORT_MAX}: {bad}"
                    )

    @staticmethod
    def _as_str_list(value: Any, field_name: str) -> List[str]:
        """
        Accept a single string or a list of strings; return List[str].

        Entries are stripped of surrounding whitespace and blank entries are
        dropped.

        :param value: None, a string, or a list/tuple of strings.
        :param field_name: Name used in error messages.
        :return: The cleaned list; [] for None, empty or blank input.
        :raises TypeError: If value, or any item in it, is not a string.
        """
        if value is None or value == []:
            return []
        if isinstance(value, str):
            value = [value]
        elif not isinstance(value, (list, tuple)):
            raise TypeError(f"'{field_name}' must be a string or a list of strings.")

        out: List[str] = []
        for v in value:
            if not isinstance(v, str):
                raise TypeError(f"'{field_name}' must contain only strings.")
            s = v.strip()
            if s:
                out.append(s)
        return out

    # helpers
    def list_configs(self) -> List[str]:
        """
        Return names of loaded configs for UI selection.
        """
        return [c.name for c in self._loaded]

    def get_by_name(self, name: str) -> Optional[ScanConfigFile]:
        """
        Fetch a loaded config by its name.

        :param name: Config name to look up (exact match).
        :return: The first loaded config with that name, or None.
        """
        for c in self._loaded:
            if c.name == name:
                return c
        return None