scheduling and jobs, new dataclasses, and better UDP handling

This commit is contained in:
2025-10-17 16:49:30 -05:00
parent 9956667c8f
commit 41306801ae
13 changed files with 771 additions and 169 deletions

View File

@@ -1,41 +1,34 @@
#!/usr/bin/env python3
"""
port_checker.py
- expects `expected.json` in same dir (see format below)
- writes nmap XML to a temp file, parses, compares, prints a report
"""
import os
import json
import subprocess
import tempfile
from datetime import datetime
import xml.etree.ElementTree as ET
import logging
logging.basicConfig(level=logging.INFO)
# TODO:
# LOGGING
# TLS SCANNING
# TLS Version PROBE
# EMAIL
import time
from pathlib import Path
from typing import Dict, List, Set
from utils.scan_config_loader import ScanConfigRepository, ScanConfigFile
from utils.schedule_manager import ScanScheduler
from utils.scanner import nmap_scanner
from utils.models import HostResult
from reporting_jinja import write_html_report_jinja
EXPECTED_FILE = Path() / "data" / "expected.json"
from utils.settings import get_settings
from utils.common import get_common_utils
logger = logging.getLogger(__file__)
utils = get_common_utils()
settings = get_settings()
HTML_REPORT_FILE = Path() / "data" / "report.html"
def load_expected(path: Path) -> Dict[str, Dict[str, Set[int]]]:
with path.open() as fh:
arr = json.load(fh)
out = {}
for entry in arr:
ip = entry["ip"]
out[ip] = {
"expected_tcp": set(entry.get("expected_tcp", [])),
"expected_udp": set(entry.get("expected_udp", [])),
}
return out
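For reference, `load_expected` implies an `expected.json` shaped like this (a hypothetical sketch inferred from the parser above; the IPs and ports are illustrative):
# [
#   {"ip": "192.0.2.10", "expected_tcp": [22, 443], "expected_udp": [161]},
#   {"ip": "192.0.2.11", "expected_tcp": [80], "expected_udp": []}
# ]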
# def write_targets(expected: Dict[str, Dict[str, Set[int]]], path: Path) -> None:
#     path.write_text("\n".join(sorted(expected.keys())) + "\n")
def results_to_open_sets(
results: List[HostResult],
count_as_open: Set[str] = frozenset({"open", "open|filtered"})) -> Dict[str, Dict[str, Set[int]]]:
@@ -56,11 +49,14 @@ def results_to_open_sets(
# Build the "reports" dict (what the HTML renderer expects)
def build_reports(
scan_config: "ScanConfigFile",
discovered: Dict[str, Dict[str, Set[int]]],
) -> Dict[str, Dict[str, List[int]]]:
"""
Create the per-IP delta structure using expected ports from `scan_config.scan_targets`
and discovered ports from `discovered`.
Output format:
{
ip: {
"unexpected_tcp": [...],
@@ -69,15 +65,52 @@ def build_reports(
"missing_udp": [...]
}
}
Notes:
- If a host has no expected UDP ports in the config, `expected_udp` is empty here.
(This function reflects *expectations*, not what to scan. Your scan logic can still
choose 'top UDP ports' for those hosts.)
- The `discovered` dict is expected to use keys "tcp" / "udp" per host.
"""
# Build `expected` from scan_config.scan_targets
expected: Dict[str, Dict[str, Set[int]]] = {}
cfg_targets = getattr(scan_config, "scan_targets", []) or []
for t in cfg_targets:
# Works whether ScanTarget is a dataclass or a dict-like object
ip = getattr(t, "ip", None) if hasattr(t, "ip") else t.get("ip")
if not ip:
continue
raw_tcp = getattr(t, "expected_tcp", None) if hasattr(t, "expected_tcp") else t.get("expected_tcp", [])
raw_udp = getattr(t, "expected_udp", None) if hasattr(t, "expected_udp") else t.get("expected_udp", [])
exp_tcp = set(int(p) for p in (raw_tcp or []))
exp_udp = set(int(p) for p in (raw_udp or []))
expected[ip] = {
"expected_tcp": exp_tcp,
"expected_udp": exp_udp,
}
# Union of IPs present in either expectations or discoveries
all_ips = set(expected.keys()) | set(discovered.keys())
reports: Dict[str, Dict[str, List[int]]] = {}
for ip in sorted(all_ips):
# Expected sets (default to empty sets if not present)
exp_tcp = expected.get(ip, {}).get("expected_tcp", set())
exp_udp = expected.get(ip, {}).get("expected_udp", set())
# Discovered sets (default to empty sets if not present)
disc_tcp = discovered.get(ip, {}).get("tcp", set()) or set()
disc_udp = discovered.get(ip, {}).get("udp", set()) or set()
# Ensure sets in case caller provided lists
if not isinstance(disc_tcp, set):
disc_tcp = set(disc_tcp)
if not isinstance(disc_udp, set):
disc_udp = set(disc_udp)
reports[ip] = {
"unexpected_tcp": sorted(disc_tcp - exp_tcp),
@@ -85,24 +118,55 @@ def build_reports(
"unexpected_udp": sorted(disc_udp - exp_udp),
"missing_udp": sorted(exp_udp - disc_udp),
}
return reports
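A quick illustration of the delta logic (hypothetical values; a sketch, not part of the module):
# expected for 192.0.2.10:   tcp={22, 443}, udp={161}
# discovered for 192.0.2.10: tcp={22, 8080}, udp=set()
# -> {"192.0.2.10": {"unexpected_tcp": [8080], "missing_tcp": [443],
#                    "unexpected_udp": [], "missing_udp": [161]}}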
def run_repo_scan(scan_config:ScanConfigFile):
logger.info(f"Starting scan for {scan_config.name}")
logger.info("Options: udp=%s tls_sec=%s tls_exp=%s",
scan_config.scan_options.udp_scan,
scan_config.scan_options.tls_security_scan,
scan_config.scan_options.tls_exp_check)
logger.info("Targets: %d hosts", len(scan_config.scan_targets))
scanner = nmap_scanner(scan_config)
scan_results = scanner.scan_targets()
discovered_sets = results_to_open_sets(scan_results, count_as_open={"open", "open|filtered"})
reports = build_reports(scan_config, discovered_sets)
write_html_report_jinja(reports=reports, host_results=scan_results, out_path=HTML_REPORT_FILE, title="Compliance Report", only_issues=True)
scanner.cleanup()
def main():
logger.info(f"{settings.app.name} - v{settings.app.version_major}.{settings.app.version_minor} Started")
logger.info(f"Application Running Production flag set to: {settings.app.production}")
# timezone validation
if utils.TextUtils.is_valid_timezone(settings.app.timezone):
logger.info(f"Timezone set to {settings.app.timezone}")
app_timezone = settings.app.timezone
else:
logger.warning(f"The Timezone {settings.app.timezone} is invalid, Defaulting to UTC")
app_timezone = "America/Danmarkshavn" # UTC
# load / configure the scan repos
repo = ScanConfigRepository()
scan_configs = repo.load_all()
# if in prod - run the scheduler like normal
if settings.app.production:
sched = ScanScheduler(timezone=app_timezone)
sched.start()
jobs = sched.schedule_configs(scan_configs, run_scan_fn=run_repo_scan)
logger.info("Scheduled %d job(s).", jobs)
try:
while True:
time.sleep(3600)
except KeyboardInterrupt:
sched.shutdown()
else:
# run single scan in dev mode
run_repo_scan(scan_configs[0])
if __name__ == "__main__":
main()

View File

@@ -1,2 +1,5 @@
Jinja2==3.1.6
MarkupSafe==3.0.3
PyYAML>=5.3.1
APScheduler==3.11
requests>=2.32.5

app/utils/common.py Normal file
View File

@@ -0,0 +1,213 @@
import re
import os
import sys
import csv
import json
import logging
import zipfile
import functools
from pathlib import Path
from zoneinfo import ZoneInfo, available_timezones
logger = logging.getLogger(__file__)
try:
import requests
import yaml
except ModuleNotFoundError:
msg = (
"Required modules are not installed. "
"Can not continue with module / application loading.\n"
"Install it with: pip install -r requirements"
)
print(msg, file=sys.stderr)
logger.error(msg)
sys.exit(1)
# ---------- SINGLETON DECORATOR ----------
T = type("T", (), {})
def singleton_loader(func):
"""Decorator to ensure a singleton instance."""
cache = {}
@functools.wraps(func)
def wrapper(*args, **kwargs):
if func.__name__ not in cache:
cache[func.__name__] = func(*args, **kwargs)
return cache[func.__name__]
return wrapper
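A minimal behavior sketch (hypothetical factory; the decorator caches by function name, so repeated calls return the same object):
# @singleton_loader
# def get_demo():
#     return object()
#
# assert get_demo() is get_demo()  # second call hits the cache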
# ---------- UTILITY CLASSES ----------
class FileUtils:
"""File and directory utilities."""
@staticmethod
def ensure_directory(path):
"""Create the directory if it doesn't exist."""
dir_path = Path(path)
if not dir_path.exists():
dir_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Created directory: {dir_path}")
return True
return False
@staticmethod
def create_dir_if_not_exist(dir_to_create):
return FileUtils.ensure_directory(dir_to_create)
@staticmethod
def list_files_with_ext(directory="/tmp", ext="docx"):
"""List all files in a directory with a specific extension."""
return [f for f in os.listdir(directory) if f.endswith(ext)]
@staticmethod
def download_file(url, dest_path):
"""Download a file from a URL to a local path."""
response = requests.get(url, stream=True)
response.raise_for_status()
with open(dest_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logger.info(f"File Downloaded to: {dest_path} from {url}")
@staticmethod
def unzip_file(zip_path, extract_to="."):
"""Unzip a file to the given directory."""
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to)
logger.info(f"{zip_path} Extracted to: {extract_to}")
@staticmethod
def verify_file_exist(filepath: Path, exit_if_false=False):
"""Verify a file exists."""
if not filepath.exists():
if exit_if_false:
sys.stderr.write(f"[FATAL] File not found: {filepath}\n")
sys.exit(1)
return False
return True
@staticmethod
def read_yaml_file(full_file_path: Path):
"""Read a YAML file safely."""
if not FileUtils.verify_file_exist(full_file_path):
logger.error(f"Unable to read yaml - {full_file_path} does not exist")
return {}
try:
with open(full_file_path, 'r') as yfile:
return yaml.safe_load(yfile)
except Exception as e:
logger.error(f"Unable to read yaml due to: {e}")
return {}
@staticmethod
def delete_list_of_files(files_to_delete: list):
"""Delete multiple files safely."""
for file_path in files_to_delete:
try:
os.remove(file_path)
logger.info(f"Deleted {file_path}")
except FileNotFoundError:
logger.warning(f"File not found: {file_path}")
except PermissionError:
logger.warning(f"Permission denied: {file_path}")
except Exception as e:
logger.error(f"Error deleting {file_path}: {e}")
class TextUtils:
"""Text parsing and string utilities."""
@staticmethod
def extract_strings(data: bytes, min_length: int = 4):
"""Extract ASCII and UTF-16LE strings from binary data."""
ascii_re = re.compile(rb"[ -~]{%d,}" % min_length)
ascii_strings = [match.decode("ascii", errors="ignore") for match in ascii_re.findall(data)]
wide_re = re.compile(rb"(?:[ -~]\x00){%d,}" % min_length)
wide_strings = [match.decode("utf-16le", errors="ignore") for match in wide_re.findall(data)]
return ascii_strings + wide_strings
@staticmethod
def defang_url(url: str) -> str:
"""Defang a URL to prevent it from being clickable."""
return url.replace('.', '[.]').replace(':', '[:]')
@staticmethod
def load_dirty_json(json_text: str):
"""Load JSON, return None on error."""
try:
return json.loads(json_text)
except Exception as e:
logger.warning(f"Failed to parse JSON: {e}")
return None
@staticmethod
def is_valid_timezone(tz_str: str) -> bool:
"""
Check if a timezone string is a valid IANA timezone.
Example: 'America/Chicago', 'UTC', etc.
"""
try:
ZoneInfo(tz_str)
return True
except Exception:
return False
class DataUtils:
"""Data manipulation utilities (CSV, dict lists)."""
@staticmethod
def sort_dict_list(dict_list, key):
"""Sort a list of dictionaries by a given key."""
return sorted(dict_list, key=lambda x: x[key])
@staticmethod
def write_to_csv(data, headers, filename):
"""
Write a list of dictionaries to a CSV file with specified headers.
Nested dicts/lists are flattened for CSV output.
"""
if not data:
logger.warning("No data provided to write to CSV")
return
with open(filename, mode='w', newline='', encoding='utf-8') as file:
writer = csv.writer(file)
writer.writerow(headers)
key_mapping = list(data[0].keys())
for item in data:
row = []
for key in key_mapping:
item_value = item.get(key, "")
if isinstance(item_value, list):
entry = ", ".join(str(v) for v in item_value)
elif isinstance(item_value, dict):
entry = json.dumps(item_value)
else:
entry = str(item_value)
row.append(entry)
writer.writerow(row)
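For example, nested values are flattened like this (hypothetical data; a sketch of one call):
# DataUtils.write_to_csv(
#     data=[{"host": "192.0.2.10", "ports": [22, 443], "meta": {"env": "prod"}}],
#     headers=["host", "ports", "meta"],
#     filename="/tmp/report.csv",
# )
# -> row: 192.0.2.10, "22, 443", {"env": "prod"}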
# ---------- SINGLETON FACTORY ----------
@singleton_loader
def get_common_utils():
"""
Returns the singleton instance for common utilities.
Usage:
utils = get_common_utils()
utils.FileUtils.ensure_directory("/tmp/data")
utils.TextUtils.defang_url("http://example.com")
"""
# Aggregate all utility classes into one instance
class _CommonUtils:
FileUtils = FileUtils
TextUtils = TextUtils
DataUtils = DataUtils
return _CommonUtils()

View File

@@ -6,8 +6,10 @@ import os
import yaml
import logging
from apscheduler.triggers.cron import CronTrigger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@dataclass
@@ -25,6 +27,7 @@ class ScanOptions:
"""
Feature toggles that affect how scans are executed.
"""
cron: Optional[str] = None
udp_scan: bool = False
tls_security_scan: bool = True
tls_exp_check: bool = True
@@ -38,6 +41,8 @@ class Reporting:
report_name: str = "Scan Report"
report_filename: str = "report.html"
full_details: bool = False
email_to: List[str] = field(default_factory=list)
email_cc: List[str] = field(default_factory=list)
@dataclass
@@ -58,7 +63,7 @@ class ScanConfigRepository:
Search order for the config directory:
1) Explicit path argument to load_all()
2) Environment variable SCAN_TARGETS_DIR
3) Default: /app/data/scan_targets
"""
SUPPORTED_EXT = (".yaml", ".yml")
@@ -102,7 +107,7 @@ class ScanConfigRepository:
env = os.getenv("SCAN_TARGETS_DIR")
if env:
return Path(env)
return Path("/data/scan_targets")
return Path("/app/data/scan_targets")
@staticmethod
def _read_yaml(path: Path) -> Dict[str, Any]:
@@ -144,6 +149,7 @@ class ScanConfigRepository:
# Parse scan_options
so_raw = data.get("scan_options", {}) or {}
scan_options = ScanOptions(
cron=self._validate_cron_or_none(so_raw.get("cron")),
udp_scan=bool(so_raw.get("udp_scan", False)),
tls_security_scan=bool(so_raw.get("tls_security_scan", True)),
tls_exp_check=bool(so_raw.get("tls_exp_check", True)),
@@ -155,6 +161,8 @@ class ScanConfigRepository:
report_name=str(rep_raw.get("report_name", "Scan Report")),
report_filename=str(rep_raw.get("report_filename", "report.html")),
full_details=bool(rep_raw.get("full_details", False)),
email_to=self._as_str_list(rep_raw.get("email_to", []), "email_to"),
email_cc=self._as_str_list(rep_raw.get("email_cc", []), "email_cc"),
)
# Parse targets
@@ -179,6 +187,20 @@ class ScanConfigRepository:
scan_targets=targets,
)
@staticmethod
def _validate_cron_or_none(expr: Optional[str]) -> Optional[str]:
"""
Validate a standard 5-field crontab string via CronTrigger.from_crontab.
Return the original string if valid; None if empty/None.
Raise ValueError on invalid expressions.
"""
if not expr:
return None
expr = str(expr).strip()
# Validate now so we fail early on bad configs
CronTrigger.from_crontab(expr) # will raise on invalid
return expr
@staticmethod
def _validate_config(cfg: ScanConfigFile, source: str) -> None:
"""
@@ -192,7 +214,27 @@ class ScanConfigRepository:
if dups:
raise ValueError(f"{source}: duplicate IP(s) in scan_targets: {dups}")
# Optional helpers
@staticmethod
def _as_str_list(value: Any, field_name: str) -> List[str]:
"""
Accept a single string or a list of strings; return List[str].
"""
if value is None or value == []:
return []
if isinstance(value, str):
return [value.strip()] if value.strip() else []
if isinstance(value, (list, tuple)):
out: List[str] = []
for v in value:
if not isinstance(v, str):
raise TypeError(f"'{field_name}' must contain only strings.")
s = v.strip()
if s:
out.append(s)
return out
raise TypeError(f"'{field_name}' must be a string or a list of strings.")
# helpers
def list_configs(self) -> List[str]:
"""

View File

@@ -2,11 +2,12 @@ from __future__ import annotations
import os
import subprocess
import xml.etree.ElementTree as ET
import tempfile
from pathlib import Path
from typing import Iterable, List, Dict, Optional, Tuple
from typing import Iterable, List, Dict, Optional, Tuple, Union
from utils.models import HostResult, PortFinding
from utils.scan_config_loader import ScanConfigFile, ScanTarget
class nmap_scanner:
@@ -14,9 +15,11 @@ class nmap_scanner:
UDP_REPORT_PATH = Path() / "data" / "nmap-udp-results.xml"
NMAP_RESULTS_PATH = Path() / "data" / "nmap-results.xml"
def __init__(self, config:ScanConfigFile):
self.scan_config = config
self.targets = config.scan_targets
self.target_list = [t.ip for t in config.scan_targets]
self.scan_udp = config.scan_options.udp_scan
def scan_targets(self):
@@ -24,7 +27,7 @@ class nmap_scanner:
if self.scan_udp:
udp_results = self.run_nmap_udp()
all_results: List[HostResult] = self.merge_host_results(tcp_results,udp_results)
else:
all_results = tcp_results
@@ -35,7 +38,7 @@ class nmap_scanner:
Run a TCP SYN scan across all ports (0-65535) for the given targets and parse results.
Returns a list of HostResult objects.
"""
targets_list = self.target_list
if not targets_list:
return []
@@ -54,35 +57,135 @@ class nmap_scanner:
self._run_nmap(cmd)
return self.parse_nmap_xml(self.TCP_REPORT_PATH)
def run_nmap_udp(
self,
ports: Optional[Iterable[int]] = None,
min_rate: int = 500,
assume_up: bool = True,
) -> List[HostResult]:
"""
Run UDP scans.
Behavior:
- If `ports` is provided -> single nmap run against all targets using that port list.
- If `ports` is None ->
* For hosts with `expected_udp` defined and non-empty: scan only those ports.
* For hosts with no `expected_udp` (or empty): omit `-p` so nmap uses its default top UDP ports.
Hosts sharing the same explicit UDP port set are grouped into one nmap run.
Returns:
Merged List[HostResult] across all runs.
"""
targets_list = getattr(self, "target_list", [])
if not targets_list:
return []
# Optional logger (don't fail if not present)
logger = getattr(self, "logger", None)
def _log(msg: str) -> None:
if logger:
logger.info(msg)
else:
print(msg)
# Case 1: caller provided a global port list -> one run, all targets
if ports:
# Explicit port set
port_list = sorted({int(p) for p in ports})
port_str = ",".join(str(p) for p in port_list)
with tempfile.NamedTemporaryFile(prefix="nmap_udp_", suffix=".xml", delete=False) as tmp:
report_path = tmp.name
cmd = [
"nmap",
"-sU",
"-T3",
"--min-rate", str(min_rate),
"-oX", str(report_path),
]
if assume_up:
cmd.append("-Pn")
cmd.extend(["-p", port_str])
cmd.extend(targets_list)
_log(f"UDP scan (global ports): {port_str} on {len(targets_list)} host(s)")
self._run_nmap(cmd)
results = self.parse_nmap_xml(report_path)
try:
os.remove(report_path)
except OSError:
pass
return results
# Case 2: per-host behavior using self.scan_config.scan_targets
# Build per-IP port tuple (empty tuple => use nmap's default top UDP ports)
ip_to_ports: Dict[str, Tuple[int, ...]] = {}
# Prefer the IPs present in self.target_list (order/selection comes from there)
# Map from ScanConfigFile / ScanTarget
cfg_targets = getattr(getattr(self, "scan_config", None), "scan_targets", []) or []
# Build quick lookup from config
conf_map: Dict[str, List[int]] = {}
for t in cfg_targets:
# Support either dataclass (attrs) or dict-like
ip = getattr(t, "ip", None) if hasattr(t, "ip") else t.get("ip")
if not ip:
continue
raw_udp = getattr(t, "expected_udp", None) if hasattr(t, "expected_udp") else t.get("expected_udp", [])
conf_map[ip] = list(raw_udp or [])
for ip in targets_list:
raw = conf_map.get(ip, [])
if raw:
ip_to_ports[ip] = tuple(sorted(int(p) for p in raw))
else:
ip_to_ports[ip] = () # empty => use nmap defaults (top UDP ports)
# Group hosts by identical port tuple
groups: Dict[Tuple[int, ...], List[str]] = {}
for ip, port_tuple in ip_to_ports.items():
groups.setdefault(port_tuple, []).append(ip)
all_result_sets: List[List[HostResult]] = []
for port_tuple, ips in groups.items():
# Per-group report path
with tempfile.NamedTemporaryFile(prefix="nmap_udp_", suffix=".xml", delete=False) as tmp:
report_path = tmp.name
cmd = [
"nmap",
"-sU",
"-T3",
"--min-rate", str(min_rate),
"-oX", str(report_path),
]
if assume_up:
cmd.append("-Pn")
if port_tuple:
# explicit per-group ports
port_str = ",".join(str(p) for p in port_tuple)
cmd.extend(["-p", port_str])
_log(f"UDP scan (explicit ports {port_str}) on {len(ips)} host(s): {', '.join(ips)}")
else:
# no -p -> nmap defaults to its top UDP ports
_log(f"UDP scan (nmap top UDP ports) on {len(ips)} host(s): {', '.join(ips)}")
cmd.extend(ips)
self._run_nmap(cmd)
result = self.parse_nmap_xml(report_path)
all_result_sets.append(result)
try:
os.remove(report_path)
except OSError:
pass
if not all_result_sets:
return []
# Merge per-run results into final list
return self.merge_host_results(*all_result_sets)
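To make the grouping concrete (hypothetical hosts; a sketch of the intermediate state):
# ip_to_ports = {"192.0.2.10": (53, 161), "192.0.2.11": (53, 161), "192.0.2.12": ()}
# groups      = {(53, 161): ["192.0.2.10", "192.0.2.11"], (): ["192.0.2.12"]}
# -> two nmap runs: one with "-p 53,161", one relying on nmap's default top UDP ports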
def merge_host_results(self, *result_sets: List[HostResult]) -> List[HostResult]:
"""
@@ -174,6 +277,6 @@ class nmap_scanner:
self.TCP_REPORT_PATH.unlink()
if self.UDP_REPORT_PATH.exists():
self.UDP_REPORT_PATH.unlink()
# if self.NMAP_RESULTS_PATH.exists:
# self.NMAP_RESULTS_PATH.unlink()

View File

@@ -0,0 +1,79 @@
# scheduler_manager.py
from __future__ import annotations
import logging
from dataclasses import asdict
from typing import Callable, List, Optional
from zoneinfo import ZoneInfo
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from utils.scan_config_loader import ScanConfigFile
logger = logging.getLogger(__name__)
class ScanScheduler:
"""
Owns an APScheduler and schedules one job per ScanConfigFile that has scan_options.cron set.
"""
def __init__(self, timezone: str = "America/Chicago") -> None:
self.tz = ZoneInfo(timezone)
self.scheduler = BackgroundScheduler(timezone=self.tz)
def start(self) -> None:
"""
Start the underlying scheduler thread.
"""
if not self.scheduler.running:
self.scheduler.start()
logger.info("APScheduler started (tz=%s).", self.tz)
def shutdown(self) -> None:
"""
Gracefully stop the scheduler.
"""
if self.scheduler.running:
self.scheduler.shutdown(wait=False)
logger.info("APScheduler stopped.")
def schedule_configs(
self,
configs: List[ScanConfigFile],
run_scan_fn: Callable[[ScanConfigFile], None],
replace_existing: bool = True,
) -> int:
"""
Create/replace jobs for all configs with a valid cron.
Returns number of scheduled jobs.
"""
count = 0
for cfg in configs:
cron = (cfg.scan_options.cron or "").strip() if cfg.scan_options else ""
if not cron:
logger.info("Skipping schedule (no cron): %s", cfg.name)
continue
job_id = f"scan::{cfg.name}"
trigger = CronTrigger.from_crontab(cron, timezone=self.tz)
self.scheduler.add_job(
func=run_scan_fn,
trigger=trigger,
id=job_id,
args=[cfg],
max_instances=1,
replace_existing=replace_existing,
misfire_grace_time=300,
coalesce=True,
)
logger.info("Scheduled '%s' with cron '%s' (next run: %s)",
cfg.name, cron, self._next_run_time(job_id))
count += 1
return count
def _next_run_time(self, job_id: str):
j = self.scheduler.get_job(job_id)
if j and hasattr(j, "next_run_time"):
return j.next_run_time.isoformat() if j.next_run_time else None
return None
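A minimal usage sketch (assumes a `run_scan` callable and a loaded `configs` list, mirroring main() in port_checker.py):
# sched = ScanScheduler(timezone="UTC")
# sched.start()
# n = sched.schedule_configs(configs, run_scan_fn=run_scan)
# ... keep the main thread alive; call sched.shutdown() to stop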

View File

@@ -35,30 +35,21 @@ except ModuleNotFoundError:
logger.error(msg)
exit()
DEFAULT_SETTINGS_FILE = Path.cwd() / "config" /"settings.yaml"
DEFAULT_SETTINGS_FILE = Path.cwd() / "data" /"settings.yaml"
# ---------- CONFIG DATA CLASSES ----------
@dataclass
class DatabaseConfig:
host: str = "localhost"
port: int = 5432
username: str = "root"
password: str = ""
@dataclass
class AppConfig:
name: str = "MyApp"
version_major: int = 1
version_minor: int = 0
name: str = "Mass Scan"
version_major: int = 0
version_minor: int = 1
production: bool = False
enabled: bool = True
token_expiry: int = 3600
timezone: str = "America/Chicago"
@dataclass
class Settings:
database: DatabaseConfig = field(default_factory=DatabaseConfig)
app: AppConfig = field(default_factory=AppConfig)
@classmethod