Files
SneakyScan/app/web/services/scan_service.py

1034 lines
37 KiB
Python

"""
Scan service for managing scan operations and database integration.
This service handles the business logic for triggering scans, retrieving
scan results, and mapping scanner output to database models.
"""
import json
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session, joinedload
from web.models import (
Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel,
ScanCertificate, ScanTLSVersion, Site, ScanSiteAssociation
)
from web.utils.pagination import paginate, PaginatedResult
from web.utils.validators import validate_scan_status
logger = logging.getLogger(__name__)
class ScanService:
"""
Service for managing scan operations.
Handles scan lifecycle: triggering, status tracking, result storage,
and cleanup.
"""
def __init__(self, db_session: Session):
    """
    Bind the service to a SQLAlchemy session.

    The service's methods run their queries, adds, deletes and commits
    through the session stored here as ``self.db``.

    Args:
        db_session: SQLAlchemy database session to operate on
    """
    self.db = db_session
def trigger_scan(self, config_id: int,
                 triggered_by: str = 'manual', schedule_id: Optional[int] = None,
                 scheduler=None) -> int:
    """
    Create a Scan row for a config and optionally hand it to a scheduler.

    The row is committed with status='running' before any queuing is
    attempted, so the returned ID always refers to a persisted scan.

    Args:
        config_id: Primary key of the ScanConfig to run
        triggered_by: Origin of the request (manual, scheduled, api)
        schedule_id: Schedule ID when the scan comes from a schedule
        scheduler: Optional object exposing ``queue_scan(scan_id, config_id=...)``;
            when omitted the scan is created but not queued

    Returns:
        ID of the newly created scan

    Raises:
        ValueError: If no ScanConfig exists with ``config_id``
        Exception: Whatever ``scheduler.queue_scan`` raises; before re-raising,
            the scan is marked 'failed' with an error message and committed
    """
    from web.models import ScanConfig

    config = self.db.query(ScanConfig).filter_by(id=config_id).first()
    if not config:
        raise ValueError(f"Config with ID {config_id} not found")

    new_scan = Scan(
        timestamp=datetime.utcnow(),
        status='running',
        config_id=config_id,
        title=config.title,
        triggered_by=triggered_by,
        schedule_id=schedule_id,
        created_at=datetime.utcnow()
    )
    self.db.add(new_scan)
    self.db.commit()
    self.db.refresh(new_scan)
    logger.info(f"Scan {new_scan.id} triggered via {triggered_by} with config_id={config_id}")

    # Without a scheduler there is nothing to queue; the row stays 'running'.
    if not scheduler:
        logger.warning(f"Scan {new_scan.id} created but not queued (no scheduler provided)")
        return new_scan.id

    try:
        job_id = scheduler.queue_scan(new_scan.id, config_id=config_id)
        logger.info(f"Scan {new_scan.id} queued for background execution (job_id={job_id})")
    except Exception as exc:
        logger.error(f"Failed to queue scan {new_scan.id}: {str(exc)}")
        # Persist the failure so the scan is not left looking 'running'.
        new_scan.status = 'failed'
        new_scan.error_message = f"Failed to queue background job: {str(exc)}"
        self.db.commit()
        raise

    return new_scan.id
def get_scan(self, scan_id: int) -> Optional[Dict[str, Any]]:
    """
    Get scan details with all related data.

    Eagerly loads exactly the relationship path that ``_scan_to_dict``
    walks (sites -> IPs -> ports -> services -> certificates -> TLS
    versions), so serialization does not trigger per-row lazy loads.

    ``selectinload`` is used rather than several ``joinedload`` chains.
    Joining multiple independent collections into one SELECT multiplies
    their row counts (a cartesian product), and that becomes huge for
    scans with many ports, services and certificates. ``selectinload``
    issues one additional SELECT per relationship level instead.

    Args:
        scan_id: Scan ID to retrieve

    Returns:
        Dictionary with scan data including sites, IPs, ports, services, etc.
        Returns None if scan not found.
    """
    # Local import, matching the file's existing function-scope import style.
    from sqlalchemy.orm import selectinload

    scan = (
        self.db.query(Scan)
        .options(
            selectinload(Scan.sites)
            .selectinload(ScanSite.ips)
            .selectinload(ScanIP.ports)
            .selectinload(ScanPort.services)
            .selectinload(ScanServiceModel.certificates)
            .selectinload(ScanCertificate.tls_versions)
        )
        .filter(Scan.id == scan_id)
        .first()
    )
    if not scan:
        return None

    return self._scan_to_dict(scan)
def list_scans(self, page: int = 1, per_page: int = 20,
               status_filter: Optional[str] = None) -> PaginatedResult:
    """
    Return one page of scans, newest first, optionally limited to a status.

    Args:
        page: 1-based page number
        per_page: Number of scans per page
        status_filter: When given, only scans with this status are listed

    Returns:
        PaginatedResult whose ``items`` are summary dictionaries
        (no sites/IPs/ports)

    Raises:
        ValueError: If ``status_filter`` is rejected by validate_scan_status
    """
    scans_query = self.db.query(Scan).order_by(Scan.timestamp.desc())

    if status_filter:
        ok, message = validate_scan_status(status_filter)
        if not ok:
            raise ValueError(message)
        scans_query = scans_query.filter(Scan.status == status_filter)

    page_result = paginate(scans_query, page=page, per_page=per_page)
    page_result.items = list(map(self._scan_to_summary_dict, page_result.items))
    return page_result
def delete_scan(self, scan_id: int) -> bool:
    """
    Delete a scan row and every artifact it points to on disk.

    The JSON, HTML and ZIP report files and the screenshot directory are
    removed first, on a best-effort basis. A missing path is logged as a
    warning, any other error is logged as an error, and neither stops the
    deletion. The Scan row is then deleted and committed. Related rows are
    expected to be removed by the ORM cascade configuration.

    Args:
        scan_id: ID of the scan to delete

    Returns:
        True once the record has been deleted and committed

    Raises:
        ValueError: If no scan has that ID
    """
    target = self.db.query(Scan).filter(Scan.id == scan_id).first()
    if not target:
        raise ValueError(f"Scan {scan_id} not found")

    logger.info(f"Deleting scan {scan_id}")

    for report_file in (target.json_path, target.html_path, target.zip_path):
        if not report_file:
            continue
        try:
            Path(report_file).unlink()
            logger.debug(f"Deleted file: {report_file}")
        except FileNotFoundError:
            logger.warning(f"File not found (already deleted?): {report_file}")
        except Exception as err:
            logger.error(f"Error deleting file {report_file}: {err}")

    shots_dir = target.screenshot_dir
    if shots_dir:
        try:
            shutil.rmtree(shots_dir)
            logger.debug(f"Deleted directory: {shots_dir}")
        except FileNotFoundError:
            logger.warning(f"Directory not found (already deleted?): {shots_dir}")
        except Exception as err:
            logger.error(f"Error deleting directory {shots_dir}: {err}")

    self.db.delete(target)
    self.db.commit()
    logger.info(f"Scan {scan_id} deleted successfully")
    return True
def get_scan_status(self, scan_id: int) -> Optional[Dict[str, Any]]:
    """
    Report the lifecycle state of a single scan.

    Args:
        scan_id: Scan ID

    Returns:
        None if the scan does not exist. Otherwise a dictionary with the
        scan's id, status, title, ISO timestamps, duration and trigger. It
        also has a human-readable 'progress' label for the running,
        completed and failed states, plus 'error_message' when failed.
    """
    record = self.db.query(Scan).filter(Scan.id == scan_id).first()
    if not record:
        return None

    def _iso(moment):
        return moment.isoformat() if moment else None

    info = {
        'scan_id': record.id,
        'status': record.status,
        'title': record.title,
        'timestamp': _iso(record.timestamp),
        'started_at': _iso(record.started_at),
        'completed_at': _iso(record.completed_at),
        'duration': record.duration,
        'triggered_by': record.triggered_by
    }

    progress_labels = {
        'running': 'In progress',
        'completed': 'Complete',
        'failed': 'Failed',
    }
    label = progress_labels.get(record.status)
    if label is not None:
        info['progress'] = label
    if record.status == 'failed':
        info['error_message'] = record.error_message

    return info
def cleanup_orphaned_scans(self) -> int:
    """
    Clean up orphaned scans that are stuck in 'running' status.

    Intended to be called on application startup, to handle scans that
    were running when the system crashed or was restarted. Every scan in
    'running' status is marked 'failed', gets an explanatory error message
    and ``completed_at`` set, and, when ``started_at`` is known, has
    ``duration`` (seconds) set to ``completed_at - started_at``. All
    changes are committed together.

    Returns:
        Number of orphaned scans cleaned up
    """
    orphaned_scans = self.db.query(Scan).filter(Scan.status == 'running').all()
    if not orphaned_scans:
        logger.info("No orphaned scans found")
        return 0

    count = len(orphaned_scans)
    logger.warning(f"Found {count} orphaned scan(s) in 'running' status, marking as failed")

    # Read the clock once: separate utcnow() calls for completed_at and for
    # the duration would make duration disagree with completed_at - started_at.
    now = datetime.utcnow()
    for scan in orphaned_scans:
        scan.status = 'failed'
        scan.completed_at = now
        scan.error_message = (
            "Scan was interrupted by system shutdown or crash. "
            "The scan was running but did not complete normally."
        )
        if scan.started_at:
            scan.duration = (now - scan.started_at).total_seconds()

        logger.info(
            f"Marked orphaned scan {scan.id} as failed "
            f"(started: {scan.started_at.isoformat() if scan.started_at else 'unknown'})"
        )

    self.db.commit()
    logger.info(f"Cleaned up {count} orphaned scan(s)")
    return count
def _save_scan_to_db(self, report: Dict[str, Any], scan_id: int,
                     status: str = 'completed') -> None:
    """
    Write a scanner report into the database and finalize its Scan row.

    Sets status, duration (from ``report['scan_duration']``) and
    completed_at on the existing Scan, builds every child row through
    _map_report_to_models, and then commits once.

    Args:
        report: Report dictionary produced by the scanner
        scan_id: ID of the Scan row to finalize
        status: Final status to record (completed or failed)

    Raises:
        ValueError: If no Scan row has ``scan_id``
    """
    target = self.db.query(Scan).filter(Scan.id == scan_id).first()
    if not target:
        raise ValueError(f"Scan {scan_id} not found")

    target.status = status
    target.duration = report.get('scan_duration')
    target.completed_at = datetime.utcnow()

    self._map_report_to_models(report, target)
    self.db.commit()

    logger.info(f"Scan {scan_id} saved to database with status '{status}'")
def _map_report_to_models(self, report: Dict[str, Any], scan_obj: Scan) -> None:
    """
    Map JSON report structure to database models.

    For every entry in ``report['sites']`` this creates a ScanSite. For each
    of the site's IPs it creates a ScanIP and one ScanPort per TCP/UDP port
    listed under ``actual``. For TCP ports with a matching entry in
    ``actual.services`` it also creates a ScanService, plus certificate and
    TLS rows via _process_certificate. When a Site row with the same name
    exists, a ScanSiteAssociation links the scan to it (once per
    scan/site pair).

    Parent rows are flushed before their children are created so the
    generated primary keys are available for foreign keys. Nothing is
    committed here; the caller commits.

    Nested sections (``sites``, ``ips``, ``expected``, ``actual``,
    ``services``, ``http_info`` and the port lists) are read with ``or``
    fallbacks. A key that is present with a JSON ``null`` value is then
    treated as empty instead of raising AttributeError/TypeError.

    Args:
        report: Scan report dictionary from the scanner
        scan_obj: Scan database object to attach records to
    """
    logger.debug(f"Mapping report to database models for scan {scan_obj.id}")

    for site_data in report.get('sites') or []:
        site = ScanSite(
            scan_id=scan_obj.id,
            site_name=site_data['name']
        )
        self.db.add(site)
        self.db.flush()  # site.id is needed for ScanIP.site_id

        # Link the scan to the reusable Site definition with the same name.
        master_site = (
            self.db.query(Site)
            .filter(Site.name == site_data['name'])
            .first()
        )
        if master_site:
            existing_assoc = (
                self.db.query(ScanSiteAssociation)
                .filter(
                    ScanSiteAssociation.scan_id == scan_obj.id,
                    ScanSiteAssociation.site_id == master_site.id
                )
                .first()
            )
            if not existing_assoc:
                assoc = ScanSiteAssociation(
                    scan_id=scan_obj.id,
                    site_id=master_site.id,
                    created_at=datetime.utcnow()
                )
                self.db.add(assoc)
                logger.debug(f"Created association between scan {scan_obj.id} and site '{master_site.name}' (id={master_site.id})")

        for ip_data in site_data.get('ips') or []:
            expected = ip_data.get('expected') or {}
            actual = ip_data.get('actual') or {}

            ip = ScanIP(
                scan_id=scan_obj.id,
                site_id=site.id,
                ip_address=ip_data['address'],
                ping_expected=expected.get('ping'),
                ping_actual=actual.get('ping')
            )
            self.db.add(ip)
            self.db.flush()  # ip.id is needed for ScanPort.ip_id

            # TCP ports: may carry service / certificate details
            expected_tcp = set(expected.get('tcp_ports') or [])
            services = actual.get('services') or []
            for port_num in actual.get('tcp_ports') or []:
                port = ScanPort(
                    scan_id=scan_obj.id,
                    ip_id=ip.id,
                    port=port_num,
                    protocol='tcp',
                    expected=(port_num in expected_tcp),
                    state='open'
                )
                self.db.add(port)
                self.db.flush()  # port.id is needed for ScanService.port_id

                service_data = self._find_service_for_port(services, port_num)
                if not service_data:
                    continue

                http_info = service_data.get('http_info') or {}
                service = ScanServiceModel(
                    scan_id=scan_obj.id,
                    port_id=port.id,
                    service_name=service_data.get('service'),
                    product=service_data.get('product'),
                    version=service_data.get('version'),
                    extrainfo=service_data.get('extrainfo'),
                    ostype=service_data.get('ostype'),
                    http_protocol=http_info.get('protocol'),
                    screenshot_path=http_info.get('screenshot')
                )
                self.db.add(service)
                self.db.flush()  # service.id is needed for the certificate row

                if http_info.get('certificate'):
                    self._process_certificate(
                        http_info['certificate'],
                        scan_obj.id,
                        service.id
                    )

            # UDP ports: recorded without service details
            expected_udp = set(expected.get('udp_ports') or [])
            for port_num in actual.get('udp_ports') or []:
                self.db.add(ScanPort(
                    scan_id=scan_obj.id,
                    ip_id=ip.id,
                    port=port_num,
                    protocol='udp',
                    expected=(port_num in expected_udp),
                    state='open'
                ))

    logger.debug(f"Report mapping complete for scan {scan_obj.id}")
def _find_service_for_port(self, services: List[Dict], port: int) -> Optional[Dict]:
    """
    Look up the service entry reported for a port.

    Args:
        services: Service dictionaries, each expected to carry a 'port' key
        port: Port number to match

    Returns:
        The first entry whose 'port' equals ``port``, or None if none does
    """
    return next((entry for entry in services if entry.get('port') == port), None)
def _process_certificate(self, cert_data: Dict[str, Any], scan_id: int,
                         service_id: int) -> None:
    """
    Persist one certificate and its per-TLS-version results.

    The ScanCertificate row is flushed before the ScanTLSVersion rows are
    built, because those rows reference its generated ID. SAN and
    cipher-suite lists are stored as JSON text; validity dates go through
    _parse_datetime.

    Args:
        cert_data: Certificate dictionary taken from a service's http_info
        scan_id: Owning scan ID
        service_id: ID of the ScanService the certificate was seen on
    """
    to_dt = self._parse_datetime
    certificate = ScanCertificate(
        scan_id=scan_id,
        service_id=service_id,
        subject=cert_data.get('subject'),
        issuer=cert_data.get('issuer'),
        serial_number=cert_data.get('serial_number'),
        not_valid_before=to_dt(cert_data.get('not_valid_before')),
        not_valid_after=to_dt(cert_data.get('not_valid_after')),
        days_until_expiry=cert_data.get('days_until_expiry'),
        sans=json.dumps(cert_data.get('sans', [])),
        is_self_signed=cert_data.get('is_self_signed', False)
    )
    self.db.add(certificate)
    self.db.flush()

    for tls_name, tls_info in cert_data.get('tls_versions', {}).items():
        self.db.add(ScanTLSVersion(
            scan_id=scan_id,
            certificate_id=certificate.id,
            tls_version=tls_name,
            supported=tls_info.get('supported', False),
            cipher_suites=json.dumps(tls_info.get('cipher_suites', []))
        ))
def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]:
    """
    Convert an ISO-8601 string into a datetime.

    A trailing 'Z' is rewritten to '+00:00' first, because
    ``datetime.fromisoformat`` only accepts 'Z' from Python 3.11 onward.

    Args:
        date_str: ISO format datetime string (may be empty or None)

    Returns:
        The parsed datetime, or None for empty input or when parsing
        fails (a warning is logged in the latter case)
    """
    if not date_str:
        return None

    text = date_str
    try:
        if text.endswith('Z'):
            text = text[:-1] + '+00:00'
        return datetime.fromisoformat(text)
    except (ValueError, AttributeError) as exc:
        logger.warning(f"Failed to parse datetime '{text}': {exc}")
        return None
def _scan_to_dict(self, scan: Scan) -> Dict[str, Any]:
    """
    Serialize a Scan with its full site/IP/port/service tree.

    Datetime columns become ISO strings (None when unset), and nested
    sites are serialized through _site_to_dict.

    Args:
        scan: Scan database object

    Returns:
        Detailed dictionary including report paths and 'sites'
    """
    def to_iso(moment):
        return moment.isoformat() if moment else None

    return {
        'id': scan.id,
        'timestamp': to_iso(scan.timestamp),
        'duration': scan.duration,
        'status': scan.status,
        'title': scan.title,
        'config_id': scan.config_id,
        'json_path': scan.json_path,
        'html_path': scan.html_path,
        'zip_path': scan.zip_path,
        'screenshot_dir': scan.screenshot_dir,
        'triggered_by': scan.triggered_by,
        'created_at': to_iso(scan.created_at),
        'sites': list(map(self._site_to_dict, scan.sites))
    }
def _scan_to_summary_dict(self, scan: Scan) -> Dict[str, Any]:
    """
    Serialize only a Scan's own columns, for list views.

    Args:
        scan: Scan database object

    Returns:
        Dictionary of id, timestamps (ISO strings or None), duration,
        status, title, config_id and triggered_by; no related rows
    """
    stamp = scan.timestamp
    created = scan.created_at
    return dict(
        id=scan.id,
        timestamp=stamp.isoformat() if stamp else None,
        duration=scan.duration,
        status=scan.status,
        title=scan.title,
        config_id=scan.config_id,
        triggered_by=scan.triggered_by,
        created_at=created.isoformat() if created else None,
    )
def _site_to_dict(self, site: ScanSite) -> Dict[str, Any]:
    """Serialize a ScanSite and its IPs (via _ip_to_dict)."""
    return dict(
        id=site.id,
        name=site.site_name,
        ips=list(map(self._ip_to_dict, site.ips)),
    )
def _ip_to_dict(self, ip: ScanIP) -> Dict[str, Any]:
    """Serialize a ScanIP, its ping results and its ports (via _port_to_dict)."""
    return dict(
        id=ip.id,
        address=ip.ip_address,
        ping_expected=ip.ping_expected,
        ping_actual=ip.ping_actual,
        ports=list(map(self._port_to_dict, ip.ports)),
    )
def _port_to_dict(self, port: ScanPort) -> Dict[str, Any]:
    """Serialize a ScanPort and its services (via _service_to_dict)."""
    return dict(
        id=port.id,
        port=port.port,
        protocol=port.protocol,
        state=port.state,
        expected=port.expected,
        services=list(map(self._service_to_dict, port.services)),
    )
def _service_to_dict(self, service: ScanServiceModel) -> Dict[str, Any]:
    """
    Serialize a ScanService row.

    A 'certificates' key is added only when the service has at least one
    certificate attached; each one goes through _certificate_to_dict.
    """
    data = dict(
        id=service.id,
        service_name=service.service_name,
        product=service.product,
        version=service.version,
        extrainfo=service.extrainfo,
        ostype=service.ostype,
        http_protocol=service.http_protocol,
        screenshot_path=service.screenshot_path,
    )
    attached = service.certificates
    if attached:
        data['certificates'] = list(map(self._certificate_to_dict, attached))
    return data
def _certificate_to_dict(self, cert: ScanCertificate) -> Dict[str, Any]:
    """
    Serialize a ScanCertificate together with its TLS version rows.

    Validity dates become ISO strings (None when unset). 'sans' is decoded
    from its JSON column and is present only when that column is non-empty;
    undecodable JSON yields an empty list. 'tls_versions' is always present.
    """
    start = cert.not_valid_before
    end = cert.not_valid_after
    payload = {
        'id': cert.id,
        'subject': cert.subject,
        'issuer': cert.issuer,
        'serial_number': cert.serial_number,
        'not_valid_before': start.isoformat() if start else None,
        'not_valid_after': end.isoformat() if end else None,
        'days_until_expiry': cert.days_until_expiry,
        'is_self_signed': cert.is_self_signed
    }

    raw_sans = cert.sans
    if raw_sans:
        try:
            payload['sans'] = json.loads(raw_sans)
        except json.JSONDecodeError:
            payload['sans'] = []

    payload['tls_versions'] = list(map(self._tls_version_to_dict, cert.tls_versions))
    return payload
def _tls_version_to_dict(self, tls: ScanTLSVersion) -> Dict[str, Any]:
    """
    Serialize a ScanTLSVersion row.

    'cipher_suites' is decoded from its JSON column and is present only when
    that column is non-empty; undecodable JSON yields an empty list.
    """
    entry = dict(id=tls.id, tls_version=tls.tls_version, supported=tls.supported)

    raw_suites = tls.cipher_suites
    if raw_suites:
        try:
            entry['cipher_suites'] = json.loads(raw_suites)
        except json.JSONDecodeError:
            entry['cipher_suites'] = []
    return entry
def compare_scans(self, scan1_id: int, scan2_id: int) -> Optional[Dict[str, Any]]:
    """
    Diff two scans' ports, services and certificates.

    Both scans are loaded with get_scan. Their ports, services and
    certificates are flattened into keyed maps and compared with the
    _compare_* helpers, and the results are summarized in a drift score.
    Anything present only in the second scan counts as "added"; anything
    present only in the first counts as "removed".

    Args:
        scan1_id: ID of the baseline (older) scan
        scan2_id: ID of the scan compared against it (newer)

    Returns:
        None if either scan does not exist, otherwise:
        {
            'scan1': {...},                # id, timestamp, title, status, config_id
            'scan2': {...},                # same fields for the second scan
            'same_config': bool,           # both share the same non-null config_id
            'config_warning': str | None,  # explanation when configs differ
            'ports': {...},                # result of _compare_ports
            'services': {...},             # result of _compare_services
            'certificates': {...},         # result of _compare_certificates
            'drift_score': float           # 0.0 identical .. 1.0 fully different
        }
    """
    first = self.get_scan(scan1_id)
    second = self.get_scan(scan2_id)
    if not (first and second):
        return None

    config1 = first.get('config_id')
    config2 = second.get('config_id')
    same_config = config1 is not None and config1 == config2

    config_warning = None
    if not same_config:
        config_warning = (
            f"These scans use different configurations. "
            f"Scan #{scan1_id} used config_id={config1 or 'unknown'} and "
            f"Scan #{scan2_id} used config_id={config2 or 'unknown'}. "
            f"The comparison may show all changes as additions/removals if the scans "
            f"cover different IP ranges or infrastructure."
        )

    ports = self._compare_ports(
        self._extract_ports_from_scan(first),
        self._extract_ports_from_scan(second),
    )
    services = self._compare_services(
        self._extract_services_from_scan(first),
        self._extract_services_from_scan(second),
    )
    certificates = self._compare_certificates(
        self._extract_certificates_from_scan(first),
        self._extract_certificates_from_scan(second),
    )
    drift = self._calculate_drift_score(ports, services, certificates)

    def brief(scan, config_id):
        return {
            'id': scan['id'],
            'timestamp': scan['timestamp'],
            'title': scan['title'],
            'status': scan['status'],
            'config_id': config_id
        }

    return {
        'scan1': brief(first, config1),
        'scan2': brief(second, config2),
        'same_config': same_config,
        'config_warning': config_warning,
        'ports': ports,
        'services': services,
        'certificates': certificates,
        'drift_score': drift
    }
def _extract_ports_from_scan(self, scan: Dict[str, Any]) -> Dict[str, Any]:
    """
    Flatten every port of a scan dictionary into a keyed lookup.

    Args:
        scan: Scan dictionary in the shape returned by get_scan()

    Returns:
        Dictionary mapping "ip:port:protocol" to a dict with ip, port,
        protocol, state (defaulting to 'unknown') and expected
    """
    hosts = (host for site in scan.get('sites', []) for host in site.get('ips', []))
    flattened = {}
    for host in hosts:
        address = host['address']
        for entry in host.get('ports', []):
            number, proto = entry['port'], entry['protocol']
            flattened[f"{address}:{number}:{proto}"] = {
                'ip': address,
                'port': number,
                'protocol': proto,
                'state': entry.get('state', 'unknown'),
                'expected': entry.get('expected'),
            }
    return flattened
def _extract_services_from_scan(self, scan: Dict[str, Any]) -> Dict[str, Any]:
    """
    Flatten the detected service of every port into a keyed lookup.

    Only the first service listed on a port is used; ports without
    services are skipped.

    Args:
        scan: Scan dictionary in the shape returned by get_scan()

    Returns:
        Dictionary mapping "ip:port:protocol" to ip, port, protocol,
        service_name, product, version and extrainfo
    """
    found = {}
    for site in scan.get('sites', []):
        for host in site.get('ips', []):
            address = host['address']
            for entry in host.get('ports', []):
                number = entry['port']
                proto = entry['protocol']
                detected = entry.get('services', [])
                if not detected:
                    continue
                first = detected[0]
                found[f"{address}:{number}:{proto}"] = {
                    'ip': address,
                    'port': number,
                    'protocol': proto,
                    'service_name': first.get('service_name'),
                    'product': first.get('product'),
                    'version': first.get('version'),
                    'extrainfo': first.get('extrainfo'),
                }
    return found
def _extract_certificates_from_scan(self, scan: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract certificate information from a scan.

    Walks sites -> IPs -> ports -> services -> certificates of a scan
    dictionary. Entries are keyed by "ip:port" only (no protocol). If
    several certificates are found for the same address and port, the
    last one encountered is kept.

    Args:
        scan: Scan dictionary in the shape returned by get_scan()

    Returns:
        Dictionary mapping "ip:port" to certificate details (ip, port,
        subject, issuer, not_valid_after, days_until_expiry, is_self_signed)
    """
    certificates = {}
    for site in scan.get('sites', []):
        for ip_data in site.get('ips', []):
            ip_addr = ip_data['address']
            for port_data in ip_data.get('ports', []):
                port_num = port_data['port']
                key = f"{ip_addr}:{port_num}"  # same key for every cert on this port
                for svc in port_data.get('services', []):
                    for cert in svc.get('certificates') or []:
                        certificates[key] = {
                            'ip': ip_addr,
                            'port': port_num,
                            'subject': cert.get('subject'),
                            'issuer': cert.get('issuer'),
                            'not_valid_after': cert.get('not_valid_after'),
                            'days_until_expiry': cert.get('days_until_expiry'),
                            'is_self_signed': cert.get('is_self_signed')
                        }
    return certificates
def _compare_ports(self, ports1: Dict, ports2: Dict) -> Dict[str, List]:
    """
    Split two port maps into added, removed and unchanged entries.

    Args:
        ports1: Port map of the baseline scan (keyed "ip:port:protocol")
        ports2: Port map of the newer scan

    Returns:
        {'added': [...], 'removed': [...], 'unchanged': [...]}; each list is
        ordered by sorted key, and unchanged entries are taken from ports2
    """
    before = ports1.keys()
    after = ports2.keys()

    def pick(source, keys):
        return [source[k] for k in sorted(keys)]

    return {
        'added': pick(ports2, after - before),
        'removed': pick(ports1, before - after),
        'unchanged': pick(ports2, before & after),
    }
def _compare_services(self, services1: Dict, services2: Dict) -> Dict[str, List]:
    """
    Diff two service maps keyed by "ip:port:protocol".

    A service present in both maps counts as changed when its
    service_name, product or version differs (extrainfo is ignored).

    Args:
        services1: Service map of the baseline scan
        services2: Service map of the newer scan

    Returns:
        {'added': [...], 'removed': [...], 'changed': [...]}, ordered by
        sorted key. Changed entries carry ip/port/protocol plus 'old' and
        'new' snapshots.
    """
    old_keys = set(services1)
    new_keys = set(services2)
    snapshot_fields = ('service_name', 'product', 'version')
    compared_fields = ('product', 'version', 'service_name')

    changed = []
    for key in sorted(old_keys & new_keys):
        old, new = services1[key], services2[key]
        if not any(old.get(f) != new.get(f) for f in compared_fields):
            continue
        changed.append({
            'ip': new['ip'],
            'port': new['port'],
            'protocol': new['protocol'],
            'old': {f: old.get(f) for f in snapshot_fields},
            'new': {f: new.get(f) for f in snapshot_fields},
        })

    return {
        'added': [services2[k] for k in sorted(new_keys - old_keys)],
        'removed': [services1[k] for k in sorted(old_keys - new_keys)],
        'changed': changed,
    }
def _compare_certificates(self, certs1: Dict, certs2: Dict) -> Dict[str, List]:
    """
    Compare certificates between two scans.

    Certificates are matched by their "ip:port" key. A matched pair counts
    as changed when subject, issuer or not_valid_after differ.

    Args:
        certs1: Certificate map of the baseline scan (keyed "ip:port")
        certs2: Certificate map of the newer scan

    Returns:
        Dictionary with 'added', 'removed' and 'changed' lists, plus an
        'unchanged' list (entries taken from certs2, like _compare_ports).
        _calculate_drift_score needs that list to know how many
        certificates were compared in total. All lists are ordered by
        sorted key.
    """
    keys1 = set(certs1)
    keys2 = set(certs2)

    changed = []
    unchanged = []
    for key in sorted(keys1 & keys2):
        cert1 = certs1[key]
        cert2 = certs2[key]
        if (cert1.get('subject') != cert2.get('subject') or
                cert1.get('issuer') != cert2.get('issuer') or
                cert1.get('not_valid_after') != cert2.get('not_valid_after')):
            changed.append({
                'ip': cert2['ip'],
                'port': cert2['port'],
                'old': {
                    'subject': cert1.get('subject'),
                    'issuer': cert1.get('issuer'),
                    'not_valid_after': cert1.get('not_valid_after'),
                    'days_until_expiry': cert1.get('days_until_expiry')
                },
                'new': {
                    'subject': cert2.get('subject'),
                    'issuer': cert2.get('issuer'),
                    'not_valid_after': cert2.get('not_valid_after'),
                    'days_until_expiry': cert2.get('days_until_expiry')
                }
            })
        else:
            unchanged.append(cert2)

    return {
        'added': [certs2[k] for k in sorted(keys2 - keys1)],
        'removed': [certs1[k] for k in sorted(keys1 - keys2)],
        'changed': changed,
        'unchanged': unchanged
    }

def _calculate_drift_score(self, ports_comp: Dict, services_comp: Dict,
                           certs_comp: Dict) -> float:
    """
    Calculate drift score based on comparison results.

    Each category contributes the fraction of its items that differ
    between the two scans. The fractions are weighted 50% for ports, 30%
    for services and 20% for certificates.

    Args:
        ports_comp: Result of _compare_ports (added/removed/unchanged)
        services_comp: Result of _compare_services (added/removed/changed)
        certs_comp: Result of _compare_certificates. Its 'unchanged' list
            is used for the certificate total; when the key is absent, only
            differing certificates are counted.

    Returns:
        Float between 0.0 (identical) and 1.0 (completely different),
        rounded to 3 decimals
    """
    # Ports: the comparison lists every port key seen in either scan.
    changed_ports = len(ports_comp['added']) + len(ports_comp['removed'])
    total_ports = changed_ports + len(ports_comp['unchanged'])

    # Services: no 'unchanged' list is reported, so unchanged services are
    # approximated as the unchanged ports minus the services that changed.
    changed_services = (
        len(services_comp['added']) +
        len(services_comp['removed']) +
        len(services_comp['changed'])
    )
    total_services = (
        changed_services +
        max(0, len(ports_comp['unchanged']) - len(services_comp['changed']))
    )

    # Certificates: identical certificates must be part of the denominator.
    # Dividing differing certificates by themselves makes cert drift
    # always exactly 0.0 or 1.0.
    changed_certs = (
        len(certs_comp['added']) +
        len(certs_comp['removed']) +
        len(certs_comp['changed'])
    )
    total_certs = changed_certs + len(certs_comp.get('unchanged', []))

    port_drift = changed_ports / max(total_ports, 1)
    service_drift = changed_services / max(total_services, 1)
    cert_drift = changed_certs / max(total_certs, 1)

    drift_score = (port_drift * 0.5) + (service_drift * 0.3) + (cert_drift * 0.2)
    return round(min(drift_score, 1.0), 3)  # cap at 1.0, 3 decimals