# Source file: mass-scan2/app/utils/scan_config_loader.py
# (viewer metadata: 253 lines, 8.7 KiB, Python)

from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import os
import yaml
import logging
from apscheduler.triggers.cron import CronTrigger
logger = logging.getLogger(__name__)
@dataclass
class ScanTarget:
    """
    One IP and its expected ports.

    Attributes:
        ip: Target address, stored as the string given in the config.
        expected_tcp: Expected TCP port numbers for this target (default: empty).
        expected_udp: Expected UDP port numbers for this target (default: empty).
    """

    ip: str
    # default_factory gives every instance its own list (no shared mutable default).
    expected_tcp: List[int] = field(default_factory=list)
    expected_udp: List[int] = field(default_factory=list)
@dataclass
class ScanOptions:
    """
    Feature toggles that affect how scans are executed.

    Attributes:
        cron: Optional crontab expression for this scan set; None when unset.
        udp_scan: Toggle for UDP scanning (default: off).
        tls_security_scan: Toggle for the TLS security scan (default: on).
        tls_exp_check: Toggle for the TLS expiry check (default: on).
            NOTE(review): presumably certificate expiration -- confirm with the scanner code.
    """

    cron: Optional[str] = None
    udp_scan: bool = False
    tls_security_scan: bool = True
    tls_exp_check: bool = True
@dataclass
class Reporting:
    """
    Output/report preferences for this config file.

    Attributes:
        report_name: Human-readable report title.
        report_filename: File name used for the generated report.
        full_details: Toggle for the detailed report variant (default: off).
        dark_mode: Toggle for the dark report theme (default: on).
        email_to: Primary recipient addresses (default: empty).
        email_cc: Carbon-copy recipient addresses (default: empty).
    """

    report_name: str = "Scan Report"
    report_filename: str = "report.html"
    full_details: bool = False
    dark_mode: bool = True
    # default_factory gives every instance its own list (no shared mutable default).
    email_to: List[str] = field(default_factory=list)
    email_cc: List[str] = field(default_factory=list)
@dataclass
class ScanConfigFile:
    """
    Full configuration for a single logical scan "set" (e.g., DMZ, WAN).

    Attributes:
        name: Name of the scan set (default: "Unnamed").
        scan_options: Execution toggles; a fresh ScanOptions() when omitted.
        reporting: Report preferences; a fresh Reporting() when omitted.
        scan_targets: Targets belonging to this set (default: empty).
    """

    name: str = "Unnamed"
    # default_factory builds a new object per instance so defaults are never shared.
    scan_options: ScanOptions = field(default_factory=ScanOptions)
    reporting: Reporting = field(default_factory=Reporting)
    scan_targets: List[ScanTarget] = field(default_factory=list)
class ScanConfigRepository:
    """
    Loads and validates *.yaml / *.yml scan configuration files from a directory.

    Search order for the config directory:
      1) Explicit path argument to load_all()
      2) Environment variable SCAN_TARGETS_DIR
      3) Default: /app/data/scan_targets
    """

    SUPPORTED_EXT = (".yaml", ".yml")

    def __init__(self) -> None:
        # Configs produced by the most recent load_all() call.
        self._loaded: List[ScanConfigFile] = []

    def load_all(self, directory: Optional[Path] = None) -> List[ScanConfigFile]:
        """
        Load all YAML configs from the given directory and return them.

        Files that fail to read, parse, or validate are logged and skipped;
        the remaining configs are still returned.

        :param directory: Optional explicit directory path.
        :return: Successfully loaded configs, ordered by file path.
        :raises OSError: Propagated from Path.iterdir() when the directory
            cannot be listed (e.g. it does not exist).
        """
        root = self._resolve_directory(directory)
        logger.info("Loading scan configs from: %s", root)

        # Only regular files: a sub-directory named "*.yaml" is not a config.
        files = sorted(
            p
            for p in root.iterdir()
            if p.is_file() and p.suffix.lower() in self.SUPPORTED_EXT
        )
        logger.info("Found %d config file(s).", len(files))

        configs: List[ScanConfigFile] = []
        for fpath in files:
            try:
                data = self._read_yaml(fpath)
                cfg = self._parse_config(data, default_name=fpath.stem)
                self._validate_config(cfg, source=str(fpath))
            except Exception as exc:
                # Deliberately fail-open: one bad file must not block the others.
                logger.error("Failed to load %s: %s", fpath, exc)
                continue
            configs.append(cfg)
            logger.info("Loaded config: %s (%s targets)", cfg.name, len(cfg.scan_targets))

        self._loaded = configs
        return configs

    def _resolve_directory(self, directory: Optional[Path]) -> Path:
        """
        Decide which directory to load from.

        :param directory: Explicit directory; wins when truthy. A plain string
            is accepted and converted to a Path.
        :return: The explicit directory, else $SCAN_TARGETS_DIR, else the default.
        """
        if directory:
            # Coerce so callers passing a str still get a Path back.
            return Path(directory)
        env = os.getenv("SCAN_TARGETS_DIR")
        if env:
            return Path(env)
        return Path("/app/data/scan_targets")

    @staticmethod
    def _read_yaml(path: Path) -> Dict[str, Any]:
        """
        Safely read YAML file into a Python dict.

        An empty document yields an empty dict.

        :raises ValueError: If the top-level YAML node is not a mapping.
        """
        with path.open("r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
        if not isinstance(data, dict):
            raise ValueError("Top-level YAML must be a mapping (dict).")
        return data

    @staticmethod
    def _as_int_list(value: Any, field_name: str) -> List[int]:
        """
        Coerce a sequence to a list of ints; raise if invalid.

        Accepts ints, integral floats (22.0) and numeric strings ("443").
        Booleans and fractional floats are rejected instead of being
        silently converted.

        :param value: None, or a list/tuple of int-like items.
        :param field_name: Field name used in error messages.
        :raises TypeError: If value is not a list/tuple or holds a non-integer.
        """
        if value is None:
            return []
        if not isinstance(value, (list, tuple)):
            raise TypeError(f"'{field_name}' must be a list of integers.")

        out: List[int] = []
        for v in value:
            if isinstance(v, bool):
                # Avoid True/False being treated as 1/0
                raise TypeError(f"'{field_name}' must contain integers, not booleans.")
            if isinstance(v, float) and not v.is_integer():
                # int(80.9) would silently truncate to 80; treat it as an error.
                raise TypeError(f"'{field_name}' contains a non-integer: {v!r}")
            try:
                out.append(int(v))
            except Exception as exc:
                raise TypeError(f"'{field_name}' contains a non-integer: {v!r}") from exc
        return out

    @staticmethod
    def _as_bool(value: Any, field_name: str, default: bool) -> bool:
        """
        Interpret a config value as a boolean.

        Plain bool() is unsuitable here because bool("false") is True.

        :param value: Raw value; None means "not set".
        :param field_name: Field name used in error messages.
        :param default: Result when value is None.
        :raises TypeError: If the value is not a recognisable boolean.
        """
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        if isinstance(value, int) and value in (0, 1):
            return bool(value)
        if isinstance(value, str):
            token = value.strip().lower()
            if token in ("true", "yes", "on", "1"):
                return True
            if token in ("false", "no", "off", "0"):
                return False
        raise TypeError(f"'{field_name}' must be a boolean, got {value!r}.")

    @staticmethod
    def _as_str(value: Any, default: str) -> str:
        """
        Return str(value), or default when value is None.

        Prevents an explicit YAML null from turning into the text "None".
        """
        return default if value is None else str(value)

    @staticmethod
    def _as_mapping(value: Any, field_name: str) -> Dict[str, Any]:
        """
        Return a config section as a dict; a missing/empty section becomes {}.

        :raises TypeError: If the section is present but not a mapping.
        """
        if not value:
            return {}
        if not isinstance(value, dict):
            raise TypeError(f"'{field_name}' must be a mapping (dict).")
        return value

    def _parse_config(self, data: Dict[str, Any], default_name: str) -> ScanConfigFile:
        """
        Convert a raw dict (from YAML) into a ScanConfigFile.

        Omitted or null options fall back to the defaults declared on the
        ScanOptions / Reporting dataclasses, so those stay the single source
        of truth.

        :param data: Top-level mapping read from the YAML file.
        :param default_name: Name used when the file has no "name" key.
        :raises TypeError: On wrongly typed sections or values.
        :raises ValueError: On a missing/blank target IP or invalid cron expression.
        """
        name = self._as_str(data.get("name"), default_name)

        # Parse scan_options
        so_raw = self._as_mapping(data.get("scan_options"), "scan_options")
        so_defaults = ScanOptions()
        scan_options = ScanOptions(
            cron=self._validate_cron_or_none(so_raw.get("cron")),
            udp_scan=self._as_bool(so_raw.get("udp_scan"), "udp_scan", so_defaults.udp_scan),
            tls_security_scan=self._as_bool(
                so_raw.get("tls_security_scan"), "tls_security_scan", so_defaults.tls_security_scan
            ),
            tls_exp_check=self._as_bool(
                so_raw.get("tls_exp_check"), "tls_exp_check", so_defaults.tls_exp_check
            ),
        )

        # Parse reporting
        rep_raw = self._as_mapping(data.get("reporting"), "reporting")
        rep_defaults = Reporting()
        reporting = Reporting(
            report_name=self._as_str(rep_raw.get("report_name"), rep_defaults.report_name),
            report_filename=self._as_str(rep_raw.get("report_filename"), rep_defaults.report_filename),
            full_details=self._as_bool(rep_raw.get("full_details"), "full_details", rep_defaults.full_details),
            # Default comes from Reporting.dark_mode so parser and dataclass agree.
            dark_mode=self._as_bool(rep_raw.get("dark_mode"), "dark_mode", rep_defaults.dark_mode),
            email_to=self._as_str_list(rep_raw.get("email_to", []), "email_to"),
            email_cc=self._as_str_list(rep_raw.get("email_cc", []), "email_cc"),
        )

        # Parse targets
        targets_raw = data.get("scan_targets", []) or []
        if not isinstance(targets_raw, list):
            raise TypeError("'scan_targets' must be a list.")

        targets: List[ScanTarget] = []
        # Entries are numbered from 1 in error messages.
        for idx, item in enumerate(targets_raw, start=1):
            if not isinstance(item, dict):
                raise TypeError(f"scan_targets[{idx}] must be a mapping (dict).")
            ip = item.get("ip")
            if not isinstance(ip, str) or not ip.strip():
                raise ValueError(f"scan_targets[{idx}].ip must be a non-empty string.")
            targets.append(
                ScanTarget(
                    # Strip so "1.2.3.4 " and "1.2.3.4" count as the same target.
                    ip=ip.strip(),
                    expected_tcp=self._as_int_list(item.get("expected_tcp", []), "expected_tcp"),
                    expected_udp=self._as_int_list(item.get("expected_udp", []), "expected_udp"),
                )
            )

        return ScanConfigFile(
            name=name,
            scan_options=scan_options,
            reporting=reporting,
            scan_targets=targets,
        )

    @staticmethod
    def _validate_cron_or_none(expr: Optional[str]) -> Optional[str]:
        """
        Validate a standard 5-field crontab string via CronTrigger.from_crontab.

        Return the stripped expression if valid; None if empty, blank, or None.
        Raise ValueError on invalid expressions.
        """
        if not expr:
            return None
        text = str(expr).strip()
        if not text:
            # Whitespace-only counts as "not configured", not as a bad expression.
            return None
        # Validate now so we fail early on bad configs
        CronTrigger.from_crontab(text)  # will raise on invalid
        return text

    @staticmethod
    def _validate_config(cfg: ScanConfigFile, source: str) -> None:
        """
        Lightweight semantic checks.

        Currently: disallow duplicate IPs within a single file.

        :param cfg: Parsed configuration to check.
        :param source: Origin (file path) used as a prefix in error messages.
        :raises ValueError: If any IP appears more than once in scan_targets.
        """
        seen: Dict[str, int] = {}
        for t in cfg.scan_targets:
            seen[t.ip] = seen.get(t.ip, 0) + 1
        # Dict order keeps duplicates listed in order of first appearance.
        dups = [ip for ip, count in seen.items() if count > 1]
        if dups:
            raise ValueError(f"{source}: duplicate IP(s) in scan_targets: {dups}")

    @staticmethod
    def _as_str_list(value: Any, field_name: str) -> List[str]:
        """
        Accept a single string or a list of strings; return List[str].

        Items are stripped and blank items are dropped.

        :raises TypeError: If value is neither a string nor a list/tuple of strings.
        """
        if value is None or value == []:
            return []
        if isinstance(value, str):
            single = value.strip()
            return [single] if single else []
        if isinstance(value, (list, tuple)):
            out: List[str] = []
            for v in value:
                if not isinstance(v, str):
                    raise TypeError(f"'{field_name}' must contain only strings.")
                s = v.strip()
                if s:
                    out.append(s)
            return out
        raise TypeError(f"'{field_name}' must be a string or a list of strings.")

    # helpers
    def list_configs(self) -> List[str]:
        """
        Return names of loaded configs for UI selection.
        """
        return [c.name for c in self._loaded]

    def get_by_name(self, name: str) -> Optional[ScanConfigFile]:
        """
        Fetch a loaded config by its name.

        Returns the first match in load order, or None when nothing matches.
        """
        for c in self._loaded:
            if c.name == name:
                return c
        return None