scheduling and jobs, new dataclasses, and better UDP handling
156
app/main.py
@@ -1,41 +1,34 @@
#!/usr/bin/env python3
"""
port_checker.py
- expects `expected.json` in same dir (see format below)
- writes nmap XML to a temp file, parses, compares, prints a report
"""
import os
import json
import subprocess
import tempfile
from datetime import datetime
import xml.etree.ElementTree as ET
import logging
logging.basicConfig(level=logging.INFO)

# TODO:
# LOGGING
# TLS SCANNING
# TLS Version PROBE
# EMAIL

import time
from pathlib import Path
from typing import Dict, List, Set


from utils.scan_config_loader import ScanConfigRepository, ScanConfigFile
from utils.schedule_manager import ScanScheduler
from utils.scanner import nmap_scanner
from utils.models import HostResult
from reporting_jinja import write_html_report_jinja

EXPECTED_FILE = Path() / "data" / "expected.json"
from reporting_jinja import write_html_report_jinja
from utils.settings import get_settings
from utils.common import get_common_utils

logger = logging.getLogger(__file__)

utils = get_common_utils()
settings = get_settings()

HTML_REPORT_FILE = Path() / "data" / "report.html"

def load_expected(path: Path) -> Dict[str, Dict[str, Set[int]]]:
    with path.open() as fh:
        arr = json.load(fh)
    out = {}
    for entry in arr:
        ip = entry["ip"]
        out[ip] = {
            "expected_tcp": set(entry.get("expected_tcp", [])),
            "expected_udp": set(entry.get("expected_udp", [])),
        }
    return out
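
# A minimal sketch of the expected.json layout that load_expected() parses
# (hypothetical addresses and ports; the real file lives in data/):
#
# [
#   {"ip": "192.0.2.10", "expected_tcp": [22, 443], "expected_udp": [161]},
#   {"ip": "192.0.2.11", "expected_tcp": [80]}
# ]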

# def write_targets(expected: Dict[str, Dict[str, Set[int]]], path: Path) -> None:
#     path.write_text("\n".join(sorted(expected.keys())) + "\n")
#
def results_to_open_sets(
    results: List[HostResult],
    count_as_open: Set[str] = frozenset({"open", "open|filtered"})) -> Dict[str, Dict[str, Set[int]]]:
@@ -56,11 +49,14 @@ def results_to_open_sets(

# Build the "reports" dict (what the HTML renderer expects)
def build_reports(
    expected: Dict[str, Dict[str, Set[int]]],
    scan_config: "ScanConfigFile",
    discovered: Dict[str, Dict[str, Set[int]]],
) -> Dict[str, Dict[str, List[int]]]:
    """
    Create the per-IP delta structure:
    Create the per-IP delta structure using expected ports from `scan_config.scan_targets`
    and discovered ports from `discovered`.

    Output format:
    {
        ip: {
            "unexpected_tcp": [...],
@@ -69,15 +65,52 @@ def build_reports(
            "missing_udp": [...]
        }
    }

    Notes:
    - If a host has no expected UDP ports in the config, `expected_udp` is empty here.
      (This function reflects *expectations*, not what to scan. Your scan logic can still
      choose 'top UDP ports' for those hosts.)
    - The `discovered` dict is expected to use keys "tcp" / "udp" per host.
    """
    reports: Dict[str, Dict[str, List[int]]] = {}
    # Build `expected` from scan_config.scan_targets
    expected: Dict[str, Dict[str, Set[int]]] = {}
    cfg_targets = getattr(scan_config, "scan_targets", []) or []

    for t in cfg_targets:
        # Works whether ScanTarget is a dataclass or a dict-like object
        ip = getattr(t, "ip", None) if hasattr(t, "ip") else t.get("ip")
        if not ip:
            continue

        raw_tcp = getattr(t, "expected_tcp", None) if hasattr(t, "expected_tcp") else t.get("expected_tcp", [])
        raw_udp = getattr(t, "expected_udp", None) if hasattr(t, "expected_udp") else t.get("expected_udp", [])

        exp_tcp = set(int(p) for p in (raw_tcp or []))
        exp_udp = set(int(p) for p in (raw_udp or []))

        expected[ip] = {
            "expected_tcp": exp_tcp,
            "expected_udp": exp_udp,
        }

    # Union of IPs present in either expectations or discoveries
    all_ips = set(expected.keys()) | set(discovered.keys())

    reports: Dict[str, Dict[str, List[int]]] = {}
    for ip in sorted(all_ips):
        # Expected sets (default to empty sets if not present)
        exp_tcp = expected.get(ip, {}).get("expected_tcp", set())
        exp_udp = expected.get(ip, {}).get("expected_udp", set())
        disc_tcp = discovered.get(ip, {}).get("tcp", set())
        disc_udp = discovered.get(ip, {}).get("udp", set())

        # Discovered sets (default to empty sets if not present)
        disc_tcp = discovered.get(ip, {}).get("tcp", set()) or set()
        disc_udp = discovered.get(ip, {}).get("udp", set()) or set()

        # Ensure sets in case caller provided lists
        if not isinstance(disc_tcp, set):
            disc_tcp = set(disc_tcp)
        if not isinstance(disc_udp, set):
            disc_udp = set(disc_udp)

        reports[ip] = {
            "unexpected_tcp": sorted(disc_tcp - exp_tcp),
@@ -85,24 +118,55 @@ def build_reports(
            "unexpected_udp": sorted(disc_udp - exp_udp),
            "missing_udp": sorted(exp_udp - disc_udp),
        }

    return reports
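
# Worked example (hypothetical values): with expected_tcp={22, 443} and a
# discovered tcp set of {22, 8080}, the delta for that host comes out as
#   unexpected_tcp=[8080]  (open but not expected)
#   missing_tcp=[443]      (expected but not seen open)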

def main():

    # repo = ScanConfigRepository()

    if not EXPECTED_FILE.exists():
        print("Expected File not found")
        return

    expected = load_expected(EXPECTED_FILE)
    targets = sorted(expected.keys())
    scanner = nmap_scanner(targets)
def run_repo_scan(scan_config: ScanConfigFile):
    logger.info(f"Starting scan for {scan_config.name}")
    logger.info("Options: udp=%s tls_sec=%s tls_exp=%s",
                scan_config.scan_options.udp_scan,
                scan_config.scan_options.tls_security_scan,
                scan_config.scan_options.tls_exp_check)
    logger.info("Targets: %d hosts", len(scan_config.scan_targets))
    scanner = nmap_scanner(scan_config)
    scan_results = scanner.scan_targets()
    discovered_sets = results_to_open_sets(scan_results, count_as_open={"open", "open|filtered"})
    reports = build_reports(expected, discovered_sets)
    reports = build_reports(scan_config, discovered_sets)
    write_html_report_jinja(
        reports=reports,
        host_results=scan_results,
        out_path=HTML_REPORT_FILE,
        title="Compliance Report",
        only_issues=True,
    )
    scanner.cleanup()

def main():
    logger.info(f"{settings.app.name} - v{settings.app.version_major}.{settings.app.version_minor} Started")
    logger.info(f"Production flag set to: {settings.app.production}")

    # timezone validation
    if utils.TextUtils.is_valid_timezone(settings.app.timezone):
        logger.info(f"Timezone set to {settings.app.timezone}")
        app_timezone = settings.app.timezone
    else:
        logger.warning(f"Timezone {settings.app.timezone} is invalid; defaulting to UTC")
        app_timezone = "America/Danmarkshavn"  # UTC

    # load / configure the scan repos
    repo = ScanConfigRepository()
    scan_configs = repo.load_all()

    # if in prod - run the scheduler like normal
    if settings.app.production:
        sched = ScanScheduler(timezone=app_timezone)
        sched.start()

        jobs = sched.schedule_configs(scan_configs, run_scan_fn=run_repo_scan)
        logger.info("Scheduled %d job(s).", jobs)

        try:
            while True:
                time.sleep(3600)
        except KeyboardInterrupt:
            sched.shutdown()
    else:
        # run single scan in dev mode
        run_repo_scan(scan_configs[0])

if __name__ == "__main__":
    main()

@@ -1,2 +1,5 @@
Jinja2==3.1.6
MarkupSafe==3.0.3
PyYAML>=5.3.1
APScheduler==3.11
requests>=2.32.5
213
app/utils/common.py
Normal file
@@ -0,0 +1,213 @@
import re
import os
import sys
import csv
import json
import logging
import zipfile
import functools
from pathlib import Path
from zoneinfo import ZoneInfo, available_timezones

logger = logging.getLogger(__file__)

try:
    import requests
    import yaml
except ModuleNotFoundError:
    msg = (
        "Required modules are not installed. "
        "Cannot continue with module / application loading.\n"
        "Install them with: pip install -r requirements.txt"
    )
    print(msg, file=sys.stderr)
    logger.error(msg)
    sys.exit(1)


# ---------- SINGLETON DECORATOR ----------
def singleton_loader(func):
    """Decorator to ensure a singleton instance."""
    cache = {}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if func.__name__ not in cache:
            cache[func.__name__] = func(*args, **kwargs)
        return cache[func.__name__]

    return wrapper
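
# Usage sketch: the cache is keyed by function name, so repeated calls return
# the same object (this is how get_common_utils() below stays a singleton).
#
# @singleton_loader
# def get_db():
#     return object()  # hypothetical expensive setup
#
# assert get_db() is get_db()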


# ---------- UTILITY CLASSES ----------
class FileUtils:
    """File and directory utilities."""

    @staticmethod
    def ensure_directory(path):
        """Create the directory if it doesn't exist."""
        dir_path = Path(path)
        if not dir_path.exists():
            dir_path.mkdir(parents=True, exist_ok=True)
            logger.info(f"Created directory: {dir_path}")
            return True
        return False

    @staticmethod
    def create_dir_if_not_exist(dir_to_create):
        return FileUtils.ensure_directory(dir_to_create)

    @staticmethod
    def list_files_with_ext(directory="/tmp", ext="docx"):
        """List all files in a directory with a specific extension."""
        return [f for f in os.listdir(directory) if f.endswith(ext)]

    @staticmethod
    def download_file(url, dest_path):
        """Download a file from a URL to a local path."""
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(dest_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        logger.info(f"File Downloaded to: {dest_path} from {url}")

    @staticmethod
    def unzip_file(zip_path, extract_to="."):
        """Unzip a file to the given directory."""
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        logger.info(f"{zip_path} Extracted to: {extract_to}")

    @staticmethod
    def verify_file_exist(filepath: Path, exit_if_false=False):
        """Verify a file exists."""
        if not filepath.exists():
            if exit_if_false:
                sys.stderr.write(f"[FATAL] File not found: {filepath}\n")
                sys.exit(1)
            return False
        return True

    @staticmethod
    def read_yaml_file(full_file_path: Path):
        """Read a YAML file safely."""
        if not FileUtils.verify_file_exist(full_file_path):
            logger.error(f"Unable to read yaml - {full_file_path} does not exist")
            return {}
        try:
            with open(full_file_path, 'r') as yfile:
                return yaml.safe_load(yfile)
        except Exception as e:
            logger.error(f"Unable to read yaml due to: {e}")
            return {}

    @staticmethod
    def delete_list_of_files(files_to_delete: list):
        """Delete multiple files safely."""
        for file_path in files_to_delete:
            try:
                os.remove(file_path)
                logger.info(f"Deleted {file_path}")
            except FileNotFoundError:
                logger.warning(f"File not found: {file_path}")
            except PermissionError:
                logger.warning(f"Permission denied: {file_path}")
            except Exception as e:
                logger.error(f"Error deleting {file_path}: {e}")

class TextUtils:
    """Text parsing and string utilities."""

    @staticmethod
    def extract_strings(data: bytes, min_length: int = 4):
        """Extract ASCII and UTF-16LE strings from binary data."""
        ascii_re = re.compile(rb"[ -~]{%d,}" % min_length)
        ascii_strings = [match.decode("ascii", errors="ignore") for match in ascii_re.findall(data)]

        wide_re = re.compile(rb"(?:[ -~]\x00){%d,}" % min_length)
        wide_strings = [match.decode("utf-16le", errors="ignore") for match in wide_re.findall(data)]

        return ascii_strings + wide_strings
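
    # Quick sketch of what extract_strings() yields (hypothetical input):
    #   TextUtils.extract_strings(b"\x00\x01hello\x00\x02")
    #   -> ["hello"], plus any UTF-16LE runs of >= min_length printable chars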

    @staticmethod
    def defang_url(url: str) -> str:
        """Defang a URL to prevent it from being clickable."""
        return url.replace('.', '[.]').replace(':', '[:]')

    @staticmethod
    def load_dirty_json(json_text: str):
        """Load JSON, return None on error."""
        try:
            return json.loads(json_text)
        except Exception as e:
            logger.warning(f"Failed to parse JSON: {e}")
            return None

    @staticmethod
    def is_valid_timezone(tz_str: str) -> bool:
        """
        Check if a timezone string is a valid IANA timezone.
        Example: 'America/Chicago', 'UTC', etc.
        """
        try:
            ZoneInfo(tz_str)
            return True
        except Exception:
            return False

class DataUtils:
    """Data manipulation utilities (CSV, dict lists)."""

    @staticmethod
    def sort_dict_list(dict_list, key):
        """Sort a list of dictionaries by a given key."""
        return sorted(dict_list, key=lambda x: x[key])

    @staticmethod
    def write_to_csv(data, headers, filename):
        """
        Write a list of dictionaries to a CSV file with specified headers.
        Nested dicts/lists are flattened for CSV output.
        """
        if not data:
            logger.warning("No data provided to write to CSV")
            return

        with open(filename, mode='w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(headers)

            # NOTE: rows follow the key order of the first dict, so `headers`
            # should be supplied in that same order.
            key_mapping = list(data[0].keys())
            for item in data:
                row = []
                for key in key_mapping:
                    item_value = item.get(key, "")
                    if isinstance(item_value, list):
                        entry = ", ".join(str(v) for v in item_value)
                    elif isinstance(item_value, dict):
                        entry = json.dumps(item_value)
                    else:
                        entry = str(item_value)
                    row.append(entry)
                writer.writerow(row)
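
    # Usage sketch (hypothetical data): headers must mirror the key order of
    # the first dict, since rows are emitted in that order.
    #
    # DataUtils.write_to_csv(
    #     [{"ip": "192.0.2.10", "ports": [22, 443]}],
    #     headers=["ip", "ports"],
    #     filename="out.csv",
    # )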


# ---------- SINGLETON FACTORY ----------
@singleton_loader
def get_common_utils():
    """
    Returns the singleton instance for common utilities.
    Usage:
        utils = get_common_utils()
        utils.FileUtils.ensure_directory("/tmp/data")
        utils.TextUtils.defang_url("http://example.com")
    """
    # Aggregate all utility classes into one instance
    class _CommonUtils:
        FileUtils = FileUtils
        TextUtils = TextUtils
        DataUtils = DataUtils

    return _CommonUtils()
@@ -6,8 +6,10 @@ import os
import yaml
import logging

from apscheduler.triggers.cron import CronTrigger

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)



@dataclass
@@ -25,6 +27,7 @@ class ScanOptions:
    """
    Feature toggles that affect how scans are executed.
    """
    cron: Optional[str] = None
    udp_scan: bool = False
    tls_security_scan: bool = True
    tls_exp_check: bool = True
@@ -38,6 +41,8 @@ class Reporting:
    report_name: str = "Scan Report"
    report_filename: str = "report.html"
    full_details: bool = False
    email_to: List[str] = field(default_factory=list)
    email_cc: List[str] = field(default_factory=list)


@dataclass
@@ -58,7 +63,7 @@ class ScanConfigRepository:
    Search order for the config directory:
      1) Explicit path argument to load_all()
      2) Environment variable SCAN_TARGETS_DIR
      3) Default: /data/scan_targets
      3) Default: /app/data/scan_targets
    """

    SUPPORTED_EXT = (".yaml", ".yml")
@@ -102,7 +107,7 @@ class ScanConfigRepository:
        env = os.getenv("SCAN_TARGETS_DIR")
        if env:
            return Path(env)
        return Path("/data/scan_targets")
        return Path("/app/data/scan_targets")

    @staticmethod
    def _read_yaml(path: Path) -> Dict[str, Any]:
@@ -144,6 +149,7 @@ class ScanConfigRepository:
        # Parse scan_options
        so_raw = data.get("scan_options", {}) or {}
        scan_options = ScanOptions(
            cron=self._validate_cron_or_none(so_raw.get("cron")),
            udp_scan=bool(so_raw.get("udp_scan", False)),
            tls_security_scan=bool(so_raw.get("tls_security_scan", True)),
            tls_exp_check=bool(so_raw.get("tls_exp_check", True)),
@@ -155,6 +161,8 @@ class ScanConfigRepository:
            report_name=str(rep_raw.get("report_name", "Scan Report")),
            report_filename=str(rep_raw.get("report_filename", "report.html")),
            full_details=bool(rep_raw.get("full_details", False)),
            email_to=self._as_str_list(rep_raw.get("email_to", []), "email_to"),
            email_cc=self._as_str_list(rep_raw.get("email_cc", []), "email_cc"),
        )

        # Parse targets
@@ -179,6 +187,20 @@ class ScanConfigRepository:
            scan_targets=targets,
        )

    @staticmethod
    def _validate_cron_or_none(expr: Optional[str]) -> Optional[str]:
        """
        Validate a standard 5-field crontab string via CronTrigger.from_crontab.
        Return the original string if valid; None if empty/None.
        Raise ValueError on invalid expressions.
        """
        if not expr:
            return None
        expr = str(expr).strip()
        # Validate now so we fail early on bad configs
        CronTrigger.from_crontab(expr)  # will raise on invalid
        return expr
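
    # Illustrative inputs (hypothetical): "0 2 * * *" validates (daily at 02:00),
    # "" / None come back as None, and "99 * * * *" raises ValueError.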

    @staticmethod
    def _validate_config(cfg: ScanConfigFile, source: str) -> None:
        """
@@ -192,7 +214,27 @@ class ScanConfigRepository:
        if dups:
            raise ValueError(f"{source}: duplicate IP(s) in scan_targets: {dups}")

    # Optional helpers
    @staticmethod
    def _as_str_list(value: Any, field_name: str) -> List[str]:
        """
        Accept a single string or a list of strings; return List[str].
        """
        if value is None or value == []:
            return []
        if isinstance(value, str):
            return [value.strip()] if value.strip() else []
        if isinstance(value, (list, tuple)):
            out: List[str] = []
            for v in value:
                if not isinstance(v, str):
                    raise TypeError(f"'{field_name}' must contain only strings.")
                s = v.strip()
                if s:
                    out.append(s)
            return out
        raise TypeError(f"'{field_name}' must be a string or a list of strings.")

    # helpers

    def list_configs(self) -> List[str]:
        """

@@ -2,11 +2,12 @@ from __future__ import annotations
import os
import subprocess
import xml.etree.ElementTree as ET
import tempfile
from pathlib import Path
from typing import Iterable, List, Dict, Optional, Tuple
from typing import Iterable, List, Dict, Optional, Tuple, Union

from utils.models import HostResult, PortFinding

from utils.scan_config_loader import ScanConfigFile, ScanTarget

class nmap_scanner:

@@ -14,9 +15,11 @@ class nmap_scanner:
    UDP_REPORT_PATH = Path() / "data" / "nmap-udp-results.xml"
    NMAP_RESULTS_PATH = Path() / "data" / "nmap-results.xml"

    def __init__(self, targets: Iterable[str], scan_udp=False):
        self.targets = list(targets)
        self.scan_udp = scan_udp
    def __init__(self, config: ScanConfigFile):
        self.scan_config = config
        self.targets = config.scan_targets
        self.target_list = [t.ip for t in config.scan_targets]
        self.scan_udp = config.scan_options.udp_scan

    def scan_targets(self):
@@ -24,7 +27,7 @@ class nmap_scanner:

        if self.scan_udp:
            udp_results = self.run_nmap_udp()
            all_results = List[HostResult] = self.merge_host_results(tcp_results,udp_results)
            all_results: List[HostResult] = self.merge_host_results(tcp_results, udp_results)
        else:
            all_results = tcp_results

@@ -35,7 +38,7 @@ class nmap_scanner:
        Run a TCP SYN scan across all ports (0-65535) for the given targets and parse results.
        Returns a list of HostResult objects.
        """
        targets_list = self.targets
        targets_list = self.target_list
        if not targets_list:
            return []

@@ -54,35 +57,135 @@ class nmap_scanner:
        self._run_nmap(cmd)
        return self.parse_nmap_xml(self.TCP_REPORT_PATH)

    def run_nmap_udp(self, ports: Optional[Iterable[int]] = None, min_rate: int = 500, assume_up: bool = True) -> List[HostResult]:
    def run_nmap_udp(
        self,
        ports: Optional[Iterable[int]] = None,
        min_rate: int = 500,
        assume_up: bool = True,
    ) -> List[HostResult]:
        """
        Run a UDP scan for the provided ports (recommended to keep this list small).
        If 'ports' is None, nmap defaults to its "top" UDP ports; full -p- UDP is very slow.
        Run UDP scans.

        Behavior:
        - If `ports` is provided -> single nmap run against all targets using that port list.
        - If `ports` is None ->
            * For hosts with `expected_udp` defined and non-empty: scan only those ports.
            * For hosts with no `expected_udp` (or empty): omit `-p` so nmap uses its default top UDP ports.
              Hosts sharing the same explicit UDP port set are grouped into one nmap run.
        Returns:
            Merged List[HostResult] across all runs.
        """
        targets_list = self.targets
        targets_list = getattr(self, "target_list", [])
        if not targets_list:
            return []

        cmd = [
            "nmap",
            "-sU",  # UDP scan
            "-T3",  # less aggressive timing by default for UDP
            "--min-rate", str(min_rate),
            "-oX", str(self.UDP_REPORT_PATH),
        ]
        if assume_up:
            cmd.append("-Pn")
        # Optional logger (don't fail if not present)
        logger = getattr(self, "logger", None)
        def _log(msg: str) -> None:
            if logger:
                logger.info(msg)
            else:
                print(msg)

        # Case 1: caller provided a global port list -> one run, all targets
        if ports:
            # Explicit port set
            port_list = sorted(set(int(p) for p in ports))
            port_list = sorted({int(p) for p in ports})
            port_str = ",".join(str(p) for p in port_list)

            with tempfile.NamedTemporaryFile(prefix="nmap_udp_", suffix=".xml", delete=False) as tmp:
                report_path = tmp.name

            cmd = [
                "nmap",
                "-sU",
                "-T3",
                "--min-rate", str(min_rate),
                "-oX", str(report_path),
            ]
            if assume_up:
                cmd.append("-Pn")
            cmd.extend(["-p", port_str])
            cmd.extend(targets_list)

            cmd.extend(targets_list)
            _log(f"UDP scan (global ports): {port_str} on {len(targets_list)} host(s)")
            self._run_nmap(cmd)
            results = self.parse_nmap_xml(report_path)
            try:
                os.remove(report_path)
            except OSError:
                pass
            return results

        self._run_nmap(cmd)
        return self.parse_nmap_xml(self.UDP_REPORT_PATH)
        # Case 2: per-host behavior using self.scan_config.scan_targets
        # Build per-IP port tuple (empty tuple => use nmap's default top UDP ports)
        ip_to_ports: Dict[str, Tuple[int, ...]] = {}

        # Prefer the IPs present in self.target_list (order/selection comes from there)
        # Map from ScanConfigFile / ScanTarget
        cfg_targets = getattr(getattr(self, "scan_config", None), "scan_targets", []) or []
        # Build quick lookup from config
        conf_map: Dict[str, List[int]] = {}
        for t in cfg_targets:
            # Support either dataclass (attrs) or dict-like
            ip = getattr(t, "ip", None) if hasattr(t, "ip") else t.get("ip")
            if not ip:
                continue
            raw_udp = getattr(t, "expected_udp", None) if hasattr(t, "expected_udp") else t.get("expected_udp", [])
            conf_map[ip] = list(raw_udp or [])

        for ip in targets_list:
            raw = conf_map.get(ip, [])
            if raw:
                ip_to_ports[ip] = tuple(sorted(int(p) for p in raw))
            else:
                ip_to_ports[ip] = ()  # empty => use nmap defaults (top UDP ports)

        # Group hosts by identical port tuple
        groups: Dict[Tuple[int, ...], List[str]] = {}
        for ip, port_tuple in ip_to_ports.items():
            groups.setdefault(port_tuple, []).append(ip)

        all_result_sets: List[List[HostResult]] = []

        for port_tuple, ips in groups.items():
            # Per-group report path
            with tempfile.NamedTemporaryFile(prefix="nmap_udp_", suffix=".xml", delete=False) as tmp:
                report_path = tmp.name

            cmd = [
                "nmap",
                "-sU",
                "-T3",
                "--min-rate", str(min_rate),
                "-oX", str(report_path),
            ]
            if assume_up:
                cmd.append("-Pn")

            if port_tuple:
                # explicit per-group ports
                port_str = ",".join(str(p) for p in port_tuple)
                cmd.extend(["-p", port_str])
                _log(f"UDP scan (explicit ports {port_str}) on {len(ips)} host(s): {', '.join(ips)}")
            else:
                # no -p -> nmap defaults to its top UDP ports
                _log(f"UDP scan (nmap top UDP ports) on {len(ips)} host(s): {', '.join(ips)}")

            cmd.extend(ips)

            self._run_nmap(cmd)
            result = self.parse_nmap_xml(report_path)
            all_result_sets.append(result)
            try:
                os.remove(report_path)
            except OSError:
                pass

        if not all_result_sets:
            return []

        # Merge per-run results into final list
        return self.merge_host_results(*all_result_sets)
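
    # Grouping sketch (hypothetical config): hosts A and B with expected_udp=[53, 161]
    # and host C with none collapse into two runs:
    #   {(53, 161): [A, B], (): [C]}   # () -> no -p flag, nmap's top UDP ports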

    def merge_host_results(self, *result_sets: List[HostResult]) -> List[HostResult]:
        """
@@ -174,6 +277,6 @@ class nmap_scanner:
            self.TCP_REPORT_PATH.unlink()
        if self.UDP_REPORT_PATH.exists():
            self.UDP_REPORT_PATH.unlink()
        if self.NMAP_RESULTS_PATH.exists:
            self.NMAP_RESULTS_PATH.unlink()
        # if self.NMAP_RESULTS_PATH.exists:
        #     self.NMAP_RESULTS_PATH.unlink()

79
app/utils/schedule_manager.py
Normal file
@@ -0,0 +1,79 @@
# schedule_manager.py
from __future__ import annotations
import logging
from dataclasses import asdict
from typing import Callable, List, Optional
from zoneinfo import ZoneInfo

from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger

from utils.scan_config_loader import ScanConfigFile

logger = logging.getLogger(__name__)

class ScanScheduler:
    """
    Owns an APScheduler and schedules one job per ScanConfigFile that has scan_options.cron set.
    """

    def __init__(self, timezone: str = "America/Chicago") -> None:
        self.tz = ZoneInfo(timezone)
        self.scheduler = BackgroundScheduler(timezone=self.tz)

    def start(self) -> None:
        """
        Start the underlying scheduler thread.
        """
        if not self.scheduler.running:
            self.scheduler.start()
            logger.info("APScheduler started (tz=%s).", self.tz)

    def shutdown(self) -> None:
        """
        Gracefully stop the scheduler.
        """
        if self.scheduler.running:
            self.scheduler.shutdown(wait=False)
            logger.info("APScheduler stopped.")

    def schedule_configs(
        self,
        configs: List[ScanConfigFile],
        run_scan_fn: Callable[[ScanConfigFile], None],
        replace_existing: bool = True,
    ) -> int:
        """
        Create/replace jobs for all configs with a valid cron.
        Returns number of scheduled jobs.
        """
        count = 0
        for cfg in configs:
            cron = (cfg.scan_options.cron or "").strip() if cfg.scan_options else ""
            if not cron:
                logger.info("Skipping schedule (no cron): %s", cfg.name)
                continue

            job_id = f"scan::{cfg.name}"
            trigger = CronTrigger.from_crontab(cron, timezone=self.tz)

            self.scheduler.add_job(
                func=run_scan_fn,
                trigger=trigger,
                id=job_id,
                args=[cfg],
                max_instances=1,
                replace_existing=replace_existing,
                misfire_grace_time=300,
                coalesce=True,
            )
            logger.info("Scheduled '%s' with cron '%s' (next run: %s)",
                        cfg.name, cron, self._next_run_time(job_id))
            count += 1
        return count

    def _next_run_time(self, job_id: str):
        j = self.scheduler.get_job(job_id)
        if j and hasattr(j, "next_run_time"):
            return j.next_run_time.isoformat() if j.next_run_time else None
        return None
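
# Usage sketch (mirrors main.py; `configs` and `run_scan` are hypothetical):
#
# sched = ScanScheduler(timezone="UTC")
# sched.start()
# sched.schedule_configs(configs, run_scan_fn=run_scan)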

@@ -35,30 +35,21 @@ except ModuleNotFoundError:
    logger.error(msg)
    exit()

DEFAULT_SETTINGS_FILE = Path.cwd() / "config" / "settings.yaml"
DEFAULT_SETTINGS_FILE = Path.cwd() / "data" / "settings.yaml"

# ---------- CONFIG DATA CLASSES ----------
@dataclass
class DatabaseConfig:
    host: str = "localhost"
    port: int = 5432
    username: str = "root"
    password: str = ""


@dataclass
class AppConfig:
    name: str = "MyApp"
    version_major: int = 1
    version_minor: int = 0
    name: str = "Mass Scan"
    version_major: int = 0
    version_minor: int = 1
    production: bool = False
    enabled: bool = True
    token_expiry: int = 3600
    timezone: str = "America/Chicago"


@dataclass
class Settings:
    database: DatabaseConfig = field(default_factory=DatabaseConfig)
    app: AppConfig = field(default_factory=AppConfig)

    @classmethod