1014 lines
36 KiB
Python
1014 lines
36 KiB
Python
"""
|
|
Scan service for managing scan operations and database integration.
|
|
|
|
This service handles the business logic for triggering scans, retrieving
|
|
scan results, and mapping scanner output to database models.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from sqlalchemy.orm import Session, joinedload
|
|
|
|
from web.models import (
|
|
Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel,
|
|
ScanCertificate, ScanTLSVersion
|
|
)
|
|
from web.utils.pagination import paginate, PaginatedResult
|
|
from web.utils.validators import validate_config_file, validate_scan_status
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ScanService:
|
|
"""
|
|
Service for managing scan operations.
|
|
|
|
Handles scan lifecycle: triggering, status tracking, result storage,
|
|
and cleanup.
|
|
"""
|
|
|
|
def __init__(self, db_session: Session):
|
|
"""
|
|
Initialize scan service.
|
|
|
|
Args:
|
|
db_session: SQLAlchemy database session
|
|
"""
|
|
self.db = db_session
|
|
|
|
    def trigger_scan(self, config_file: str, triggered_by: str = 'manual',
                     schedule_id: Optional[int] = None, scheduler=None) -> int:
        """
        Trigger a new scan.

        Creates a Scan record in the database with status='running' and,
        when a scheduler is supplied, queues the scan for background
        execution. If queuing fails the scan row is first marked 'failed'
        (so it is not left dangling in 'running') and the error re-raised.

        Args:
            config_file: Path to YAML configuration file; a bare filename is
                resolved under /app/configs/
            triggered_by: Source that triggered scan (manual, scheduled, api)
            schedule_id: Optional schedule ID if triggered by schedule
            scheduler: Optional SchedulerService instance for queuing background jobs

        Returns:
            Scan ID of the created scan

        Raises:
            ValueError: If config file is invalid
            Exception: Re-raised from scheduler.queue_scan after marking the
                scan failed
        """
        # Validate config file before touching the database
        is_valid, error_msg = validate_config_file(config_file)
        if not is_valid:
            raise ValueError(f"Invalid config file: {error_msg}")

        # Convert config_file to full path if it's just a filename
        if not config_file.startswith('/'):
            config_path = f'/app/configs/{config_file}'
        else:
            config_path = config_file

        # Load config to get title (function-scope import keeps yaml off the
        # module import path)
        import yaml
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        # Create scan record; committed immediately so the ID exists before
        # the background job is queued
        scan = Scan(
            timestamp=datetime.utcnow(),
            status='running',
            config_file=config_file,
            title=config.get('title', 'Untitled Scan'),
            triggered_by=triggered_by,
            schedule_id=schedule_id,
            created_at=datetime.utcnow()
        )

        self.db.add(scan)
        self.db.commit()
        self.db.refresh(scan)

        logger.info(f"Scan {scan.id} triggered via {triggered_by}")

        # Queue background job if scheduler provided
        if scheduler:
            try:
                job_id = scheduler.queue_scan(scan.id, config_file)
                logger.info(f"Scan {scan.id} queued for background execution (job_id={job_id})")
            except Exception as e:
                logger.error(f"Failed to queue scan {scan.id}: {str(e)}")
                # Mark scan as failed if job queuing fails, then propagate
                scan.status = 'failed'
                scan.error_message = f"Failed to queue background job: {str(e)}"
                self.db.commit()
                raise
        else:
            # Caller is expected to run the scan some other way
            logger.warning(f"Scan {scan.id} created but not queued (no scheduler provided)")

        return scan.id
|
|
|
|
def get_scan(self, scan_id: int) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
Get scan details with all related data.
|
|
|
|
Args:
|
|
scan_id: Scan ID to retrieve
|
|
|
|
Returns:
|
|
Dictionary with scan data including sites, IPs, ports, services, etc.
|
|
Returns None if scan not found.
|
|
"""
|
|
# Query with eager loading of all relationships
|
|
scan = (
|
|
self.db.query(Scan)
|
|
.options(
|
|
joinedload(Scan.sites).joinedload(ScanSite.ips).joinedload(ScanIP.ports),
|
|
joinedload(Scan.ports).joinedload(ScanPort.services),
|
|
joinedload(Scan.services).joinedload(ScanServiceModel.certificates),
|
|
joinedload(Scan.certificates).joinedload(ScanCertificate.tls_versions)
|
|
)
|
|
.filter(Scan.id == scan_id)
|
|
.first()
|
|
)
|
|
|
|
if not scan:
|
|
return None
|
|
|
|
# Convert to dictionary
|
|
return self._scan_to_dict(scan)
|
|
|
|
def list_scans(self, page: int = 1, per_page: int = 20,
|
|
status_filter: Optional[str] = None) -> PaginatedResult:
|
|
"""
|
|
List scans with pagination and optional filtering.
|
|
|
|
Args:
|
|
page: Page number (1-indexed)
|
|
per_page: Items per page
|
|
status_filter: Optional filter by status (running, completed, failed)
|
|
|
|
Returns:
|
|
PaginatedResult with scan list and metadata
|
|
"""
|
|
# Build query
|
|
query = self.db.query(Scan).order_by(Scan.timestamp.desc())
|
|
|
|
# Apply status filter if provided
|
|
if status_filter:
|
|
is_valid, error_msg = validate_scan_status(status_filter)
|
|
if not is_valid:
|
|
raise ValueError(error_msg)
|
|
query = query.filter(Scan.status == status_filter)
|
|
|
|
# Paginate
|
|
result = paginate(query, page=page, per_page=per_page)
|
|
|
|
# Convert scans to dictionaries (summary only, not full details)
|
|
result.items = [self._scan_to_summary_dict(scan) for scan in result.items]
|
|
|
|
return result
|
|
|
|
    def delete_scan(self, scan_id: int) -> bool:
        """
        Delete a scan and all associated files.

        Removes:
        - Database record (cascade deletes related records)
        - JSON report file
        - HTML report file
        - ZIP archive file
        - Screenshot directory

        File-system failures are logged but do not abort deletion: the
        database row is removed even if some artifacts could not be deleted.

        Args:
            scan_id: Scan ID to delete

        Returns:
            True if deleted successfully

        Raises:
            ValueError: If scan not found
        """
        scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            raise ValueError(f"Scan {scan_id} not found")

        logger.info(f"Deleting scan {scan_id}")

        # Delete files (handle missing files gracefully)
        files_to_delete = [
            scan.json_path,
            scan.html_path,
            scan.zip_path
        ]

        for file_path in files_to_delete:
            # Paths may be empty/None (e.g. scan never produced reports)
            if file_path:
                try:
                    Path(file_path).unlink()
                    logger.debug(f"Deleted file: {file_path}")
                except FileNotFoundError:
                    logger.warning(f"File not found (already deleted?): {file_path}")
                except Exception as e:
                    # Best-effort: keep going so the DB row still gets removed
                    logger.error(f"Error deleting file {file_path}: {e}")

        # Delete screenshot directory
        if scan.screenshot_dir:
            try:
                shutil.rmtree(scan.screenshot_dir)
                logger.debug(f"Deleted directory: {scan.screenshot_dir}")
            except FileNotFoundError:
                logger.warning(f"Directory not found (already deleted?): {scan.screenshot_dir}")
            except Exception as e:
                logger.error(f"Error deleting directory {scan.screenshot_dir}: {e}")

        # Delete database record (cascade handles related records)
        self.db.delete(scan)
        self.db.commit()

        logger.info(f"Scan {scan_id} deleted successfully")
        return True
|
|
|
|
def get_scan_status(self, scan_id: int) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
Get current scan status and progress.
|
|
|
|
Args:
|
|
scan_id: Scan ID
|
|
|
|
Returns:
|
|
Dictionary with status information, or None if scan not found
|
|
"""
|
|
scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
|
|
if not scan:
|
|
return None
|
|
|
|
status_info = {
|
|
'scan_id': scan.id,
|
|
'status': scan.status,
|
|
'title': scan.title,
|
|
'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
|
|
'started_at': scan.started_at.isoformat() if scan.started_at else None,
|
|
'completed_at': scan.completed_at.isoformat() if scan.completed_at else None,
|
|
'duration': scan.duration,
|
|
'triggered_by': scan.triggered_by
|
|
}
|
|
|
|
# Add progress estimate based on status
|
|
if scan.status == 'running':
|
|
status_info['progress'] = 'In progress'
|
|
elif scan.status == 'completed':
|
|
status_info['progress'] = 'Complete'
|
|
elif scan.status == 'failed':
|
|
status_info['progress'] = 'Failed'
|
|
status_info['error_message'] = scan.error_message
|
|
|
|
return status_info
|
|
|
|
def cleanup_orphaned_scans(self) -> int:
|
|
"""
|
|
Clean up orphaned scans that are stuck in 'running' status.
|
|
|
|
This should be called on application startup to handle scans that
|
|
were running when the system crashed or was restarted.
|
|
|
|
Scans in 'running' status are marked as 'failed' with an appropriate
|
|
error message indicating they were orphaned.
|
|
|
|
Returns:
|
|
Number of orphaned scans cleaned up
|
|
"""
|
|
# Find all scans with status='running'
|
|
orphaned_scans = self.db.query(Scan).filter(Scan.status == 'running').all()
|
|
|
|
if not orphaned_scans:
|
|
logger.info("No orphaned scans found")
|
|
return 0
|
|
|
|
count = len(orphaned_scans)
|
|
logger.warning(f"Found {count} orphaned scan(s) in 'running' status, marking as failed")
|
|
|
|
# Mark each orphaned scan as failed
|
|
for scan in orphaned_scans:
|
|
scan.status = 'failed'
|
|
scan.completed_at = datetime.utcnow()
|
|
scan.error_message = (
|
|
"Scan was interrupted by system shutdown or crash. "
|
|
"The scan was running but did not complete normally."
|
|
)
|
|
|
|
# Calculate duration if we have a started_at time
|
|
if scan.started_at:
|
|
duration = (datetime.utcnow() - scan.started_at).total_seconds()
|
|
scan.duration = duration
|
|
|
|
logger.info(
|
|
f"Marked orphaned scan {scan.id} as failed "
|
|
f"(started: {scan.started_at.isoformat() if scan.started_at else 'unknown'})"
|
|
)
|
|
|
|
self.db.commit()
|
|
logger.info(f"Cleaned up {count} orphaned scan(s)")
|
|
|
|
return count
|
|
|
|
    def _save_scan_to_db(self, report: Dict[str, Any], scan_id: int,
                         status: str = 'completed') -> None:
        """
        Save scan results to database.

        Updates the Scan record (status, duration, completion time) and
        creates all related records (sites, IPs, ports, services,
        certificates, TLS versions). The mapping step uses flush() only;
        everything is committed in a single transaction at the end.

        Args:
            report: Scan report dictionary from scanner
            scan_id: Scan ID to update
            status: Final scan status (completed or failed)

        Raises:
            ValueError: If the scan row does not exist
        """
        scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            raise ValueError(f"Scan {scan_id} not found")

        # Update scan record
        scan.status = status
        scan.duration = report.get('scan_duration')
        scan.completed_at = datetime.utcnow()

        # Map report data to database models (adds/flushes child rows)
        self._map_report_to_models(report, scan)

        self.db.commit()
        logger.info(f"Scan {scan_id} saved to database with status '{status}'")
|
|
|
|
    def _map_report_to_models(self, report: Dict[str, Any], scan_obj: Scan) -> None:
        """
        Map JSON report structure to database models.

        Creates records for sites, IPs, ports, services, certificates, and
        TLS versions. Parent rows are flush()ed (not committed) before their
        children so autogenerated primary keys are available for foreign
        keys; the caller owns the final commit.

        Args:
            report: Scan report dictionary
            scan_obj: Scan database object to attach records to
        """
        logger.debug(f"Mapping report to database models for scan {scan_obj.id}")

        # Process each site
        for site_data in report.get('sites', []):
            # Create ScanSite record
            site = ScanSite(
                scan_id=scan_obj.id,
                site_name=site_data['name']
            )
            self.db.add(site)
            self.db.flush()  # Get site.id for foreign key

            # Process each IP in this site
            for ip_data in site_data.get('ips', []):
                # Create ScanIP record
                ip = ScanIP(
                    scan_id=scan_obj.id,
                    site_id=site.id,
                    ip_address=ip_data['address'],
                    ping_expected=ip_data.get('expected', {}).get('ping'),
                    ping_actual=ip_data.get('actual', {}).get('ping')
                )
                self.db.add(ip)
                self.db.flush()

                # TCP: every observed open port gets a row; 'expected' records
                # whether the config predicted it
                expected_tcp = set(ip_data.get('expected', {}).get('tcp_ports', []))
                actual_tcp = ip_data.get('actual', {}).get('tcp_ports', [])

                for port_num in actual_tcp:
                    port = ScanPort(
                        scan_id=scan_obj.id,
                        ip_id=ip.id,
                        port=port_num,
                        protocol='tcp',
                        expected=(port_num in expected_tcp),
                        state='open'
                    )
                    self.db.add(port)
                    self.db.flush()

                    # Find service detection result for this port, if any
                    service_data = self._find_service_for_port(
                        ip_data.get('actual', {}).get('services', []),
                        port_num
                    )

                    if service_data:
                        # Create ScanService record
                        service = ScanServiceModel(
                            scan_id=scan_obj.id,
                            port_id=port.id,
                            service_name=service_data.get('service'),
                            product=service_data.get('product'),
                            version=service_data.get('version'),
                            extrainfo=service_data.get('extrainfo'),
                            ostype=service_data.get('ostype'),
                            http_protocol=service_data.get('http_info', {}).get('protocol'),
                            screenshot_path=service_data.get('http_info', {}).get('screenshot')
                        )
                        self.db.add(service)
                        self.db.flush()

                        # Certificate / TLS info only present for HTTP(S) services
                        http_info = service_data.get('http_info', {})
                        if http_info.get('certificate'):
                            self._process_certificate(
                                http_info['certificate'],
                                scan_obj.id,
                                service.id
                            )

                # UDP ports: note no service/certificate rows are created for UDP
                expected_udp = set(ip_data.get('expected', {}).get('udp_ports', []))
                actual_udp = ip_data.get('actual', {}).get('udp_ports', [])

                for port_num in actual_udp:
                    port = ScanPort(
                        scan_id=scan_obj.id,
                        ip_id=ip.id,
                        port=port_num,
                        protocol='udp',
                        expected=(port_num in expected_udp),
                        state='open'
                    )
                    self.db.add(port)

        logger.debug(f"Report mapping complete for scan {scan_obj.id}")
|
|
|
|
def _find_service_for_port(self, services: List[Dict], port: int) -> Optional[Dict]:
|
|
"""
|
|
Find service data for a specific port.
|
|
|
|
Args:
|
|
services: List of service dictionaries
|
|
port: Port number to find
|
|
|
|
Returns:
|
|
Service dictionary if found, None otherwise
|
|
"""
|
|
for service in services:
|
|
if service.get('port') == port:
|
|
return service
|
|
return None
|
|
|
|
def _process_certificate(self, cert_data: Dict[str, Any], scan_id: int,
|
|
service_id: int) -> None:
|
|
"""
|
|
Process certificate and TLS version data.
|
|
|
|
Args:
|
|
cert_data: Certificate data dictionary
|
|
scan_id: Scan ID
|
|
service_id: Service ID
|
|
"""
|
|
# Create ScanCertificate record
|
|
cert = ScanCertificate(
|
|
scan_id=scan_id,
|
|
service_id=service_id,
|
|
subject=cert_data.get('subject'),
|
|
issuer=cert_data.get('issuer'),
|
|
serial_number=cert_data.get('serial_number'),
|
|
not_valid_before=self._parse_datetime(cert_data.get('not_valid_before')),
|
|
not_valid_after=self._parse_datetime(cert_data.get('not_valid_after')),
|
|
days_until_expiry=cert_data.get('days_until_expiry'),
|
|
sans=json.dumps(cert_data.get('sans', [])),
|
|
is_self_signed=cert_data.get('is_self_signed', False)
|
|
)
|
|
self.db.add(cert)
|
|
self.db.flush()
|
|
|
|
# Process TLS versions
|
|
tls_versions = cert_data.get('tls_versions', {})
|
|
for version, version_data in tls_versions.items():
|
|
tls = ScanTLSVersion(
|
|
scan_id=scan_id,
|
|
certificate_id=cert.id,
|
|
tls_version=version,
|
|
supported=version_data.get('supported', False),
|
|
cipher_suites=json.dumps(version_data.get('cipher_suites', []))
|
|
)
|
|
self.db.add(tls)
|
|
|
|
def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]:
|
|
"""
|
|
Parse ISO datetime string.
|
|
|
|
Args:
|
|
date_str: ISO format datetime string
|
|
|
|
Returns:
|
|
datetime object or None if parsing fails
|
|
"""
|
|
if not date_str:
|
|
return None
|
|
|
|
try:
|
|
# Handle ISO format with 'Z' suffix
|
|
if date_str.endswith('Z'):
|
|
date_str = date_str[:-1] + '+00:00'
|
|
return datetime.fromisoformat(date_str)
|
|
except (ValueError, AttributeError) as e:
|
|
logger.warning(f"Failed to parse datetime '{date_str}': {e}")
|
|
return None
|
|
|
|
def _scan_to_dict(self, scan: Scan) -> Dict[str, Any]:
|
|
"""
|
|
Convert Scan object to dictionary with full details.
|
|
|
|
Args:
|
|
scan: Scan database object
|
|
|
|
Returns:
|
|
Dictionary representation with all related data
|
|
"""
|
|
return {
|
|
'id': scan.id,
|
|
'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
|
|
'duration': scan.duration,
|
|
'status': scan.status,
|
|
'title': scan.title,
|
|
'config_file': scan.config_file,
|
|
'json_path': scan.json_path,
|
|
'html_path': scan.html_path,
|
|
'zip_path': scan.zip_path,
|
|
'screenshot_dir': scan.screenshot_dir,
|
|
'triggered_by': scan.triggered_by,
|
|
'created_at': scan.created_at.isoformat() if scan.created_at else None,
|
|
'sites': [self._site_to_dict(site) for site in scan.sites]
|
|
}
|
|
|
|
def _scan_to_summary_dict(self, scan: Scan) -> Dict[str, Any]:
|
|
"""
|
|
Convert Scan object to summary dictionary (no related data).
|
|
|
|
Args:
|
|
scan: Scan database object
|
|
|
|
Returns:
|
|
Summary dictionary
|
|
"""
|
|
return {
|
|
'id': scan.id,
|
|
'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
|
|
'duration': scan.duration,
|
|
'status': scan.status,
|
|
'title': scan.title,
|
|
'config_file': scan.config_file,
|
|
'triggered_by': scan.triggered_by,
|
|
'created_at': scan.created_at.isoformat() if scan.created_at else None
|
|
}
|
|
|
|
def _site_to_dict(self, site: ScanSite) -> Dict[str, Any]:
|
|
"""Convert ScanSite to dictionary."""
|
|
return {
|
|
'id': site.id,
|
|
'name': site.site_name,
|
|
'ips': [self._ip_to_dict(ip) for ip in site.ips]
|
|
}
|
|
|
|
def _ip_to_dict(self, ip: ScanIP) -> Dict[str, Any]:
|
|
"""Convert ScanIP to dictionary."""
|
|
return {
|
|
'id': ip.id,
|
|
'address': ip.ip_address,
|
|
'ping_expected': ip.ping_expected,
|
|
'ping_actual': ip.ping_actual,
|
|
'ports': [self._port_to_dict(port) for port in ip.ports]
|
|
}
|
|
|
|
def _port_to_dict(self, port: ScanPort) -> Dict[str, Any]:
|
|
"""Convert ScanPort to dictionary."""
|
|
return {
|
|
'id': port.id,
|
|
'port': port.port,
|
|
'protocol': port.protocol,
|
|
'state': port.state,
|
|
'expected': port.expected,
|
|
'services': [self._service_to_dict(svc) for svc in port.services]
|
|
}
|
|
|
|
def _service_to_dict(self, service: ScanServiceModel) -> Dict[str, Any]:
|
|
"""Convert ScanService to dictionary."""
|
|
result = {
|
|
'id': service.id,
|
|
'service_name': service.service_name,
|
|
'product': service.product,
|
|
'version': service.version,
|
|
'extrainfo': service.extrainfo,
|
|
'ostype': service.ostype,
|
|
'http_protocol': service.http_protocol,
|
|
'screenshot_path': service.screenshot_path
|
|
}
|
|
|
|
# Add certificate info if present
|
|
if service.certificates:
|
|
result['certificates'] = [
|
|
self._certificate_to_dict(cert) for cert in service.certificates
|
|
]
|
|
|
|
return result
|
|
|
|
def _certificate_to_dict(self, cert: ScanCertificate) -> Dict[str, Any]:
|
|
"""Convert ScanCertificate to dictionary."""
|
|
result = {
|
|
'id': cert.id,
|
|
'subject': cert.subject,
|
|
'issuer': cert.issuer,
|
|
'serial_number': cert.serial_number,
|
|
'not_valid_before': cert.not_valid_before.isoformat() if cert.not_valid_before else None,
|
|
'not_valid_after': cert.not_valid_after.isoformat() if cert.not_valid_after else None,
|
|
'days_until_expiry': cert.days_until_expiry,
|
|
'is_self_signed': cert.is_self_signed
|
|
}
|
|
|
|
# Parse SANs from JSON
|
|
if cert.sans:
|
|
try:
|
|
result['sans'] = json.loads(cert.sans)
|
|
except json.JSONDecodeError:
|
|
result['sans'] = []
|
|
|
|
# Add TLS versions
|
|
result['tls_versions'] = [
|
|
self._tls_version_to_dict(tls) for tls in cert.tls_versions
|
|
]
|
|
|
|
return result
|
|
|
|
def _tls_version_to_dict(self, tls: ScanTLSVersion) -> Dict[str, Any]:
|
|
"""Convert ScanTLSVersion to dictionary."""
|
|
result = {
|
|
'id': tls.id,
|
|
'tls_version': tls.tls_version,
|
|
'supported': tls.supported
|
|
}
|
|
|
|
# Parse cipher suites from JSON
|
|
if tls.cipher_suites:
|
|
try:
|
|
result['cipher_suites'] = json.loads(tls.cipher_suites)
|
|
except json.JSONDecodeError:
|
|
result['cipher_suites'] = []
|
|
|
|
return result
|
|
|
|
    def compare_scans(self, scan1_id: int, scan2_id: int) -> Optional[Dict[str, Any]]:
        """
        Compare two scans and return the differences.

        Compares ports, services, and certificates between two scans,
        highlighting added, removed, and changed items.

        Args:
            scan1_id: ID of the first (older) scan
            scan2_id: ID of the second (newer) scan

        Returns:
            Dictionary with comparison results, or None if either scan not found
            {
                'scan1': {...},               # Scan 1 summary
                'scan2': {...},               # Scan 2 summary
                'same_config': bool,          # Whether both scans used the same config
                'config_warning': str | None, # Warning message if configs differ
                'ports': {'added': [...], 'removed': [...], 'unchanged': [...]},
                'services': {'added': [...], 'removed': [...], 'changed': [...]},
                'certificates': {'added': [...], 'removed': [...], 'changed': [...]},
                'drift_score': 0.0-1.0
            }
        """
        # Get both scans as fully serialized dictionaries
        scan1 = self.get_scan(scan1_id)
        scan2 = self.get_scan(scan2_id)

        if not scan1 or not scan2:
            return None

        # Check if scans use the same configuration. Note: two empty config
        # values still count as "different" (the config1 != '' guard).
        config1 = scan1.get('config_file', '')
        config2 = scan2.get('config_file', '')
        same_config = (config1 == config2) and (config1 != '')

        # Generate warning message if configs differ
        config_warning = None
        if not same_config:
            config_warning = (
                f"These scans use different configurations. "
                f"Scan #{scan1_id} used '{config1 or 'unknown'}' and "
                f"Scan #{scan2_id} used '{config2 or 'unknown'}'. "
                f"The comparison may show all changes as additions/removals if the scans "
                f"cover different IP ranges or infrastructure."
            )

        # Flatten each scan into keyed maps, then diff category by category
        ports1 = self._extract_ports_from_scan(scan1)
        ports2 = self._extract_ports_from_scan(scan2)

        ports_comparison = self._compare_ports(ports1, ports2)

        services1 = self._extract_services_from_scan(scan1)
        services2 = self._extract_services_from_scan(scan2)

        services_comparison = self._compare_services(services1, services2)

        certs1 = self._extract_certificates_from_scan(scan1)
        certs2 = self._extract_certificates_from_scan(scan2)

        certificates_comparison = self._compare_certificates(certs1, certs2)

        # Calculate drift score (0.0 = identical, 1.0 = completely different)
        drift_score = self._calculate_drift_score(
            ports_comparison,
            services_comparison,
            certificates_comparison
        )

        return {
            'scan1': {
                'id': scan1['id'],
                'timestamp': scan1['timestamp'],
                'title': scan1['title'],
                'status': scan1['status'],
                'config_file': config1
            },
            'scan2': {
                'id': scan2['id'],
                'timestamp': scan2['timestamp'],
                'title': scan2['title'],
                'status': scan2['status'],
                'config_file': config2
            },
            'same_config': same_config,
            'config_warning': config_warning,
            'ports': ports_comparison,
            'services': services_comparison,
            'certificates': certificates_comparison,
            'drift_score': drift_score
        }
|
|
|
|
def _extract_ports_from_scan(self, scan: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""
|
|
Extract port information from a scan.
|
|
|
|
Returns:
|
|
Dictionary mapping "ip:port:protocol" to port details
|
|
"""
|
|
ports = {}
|
|
for site in scan.get('sites', []):
|
|
for ip_data in site.get('ips', []):
|
|
ip_addr = ip_data['address']
|
|
for port_data in ip_data.get('ports', []):
|
|
key = f"{ip_addr}:{port_data['port']}:{port_data['protocol']}"
|
|
ports[key] = {
|
|
'ip': ip_addr,
|
|
'port': port_data['port'],
|
|
'protocol': port_data['protocol'],
|
|
'state': port_data.get('state', 'unknown'),
|
|
'expected': port_data.get('expected')
|
|
}
|
|
return ports
|
|
|
|
def _extract_services_from_scan(self, scan: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""
|
|
Extract service information from a scan.
|
|
|
|
Returns:
|
|
Dictionary mapping "ip:port:protocol" to service details
|
|
"""
|
|
services = {}
|
|
for site in scan.get('sites', []):
|
|
for ip_data in site.get('ips', []):
|
|
ip_addr = ip_data['address']
|
|
for port_data in ip_data.get('ports', []):
|
|
port_num = port_data['port']
|
|
protocol = port_data['protocol']
|
|
key = f"{ip_addr}:{port_num}:{protocol}"
|
|
|
|
# Get first service (usually only one per port)
|
|
port_services = port_data.get('services', [])
|
|
if port_services:
|
|
svc = port_services[0]
|
|
services[key] = {
|
|
'ip': ip_addr,
|
|
'port': port_num,
|
|
'protocol': protocol,
|
|
'service_name': svc.get('service_name'),
|
|
'product': svc.get('product'),
|
|
'version': svc.get('version'),
|
|
'extrainfo': svc.get('extrainfo')
|
|
}
|
|
return services
|
|
|
|
def _extract_certificates_from_scan(self, scan: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""
|
|
Extract certificate information from a scan.
|
|
|
|
Returns:
|
|
Dictionary mapping "ip:port" to certificate details
|
|
"""
|
|
certificates = {}
|
|
for site in scan.get('sites', []):
|
|
for ip_data in site.get('ips', []):
|
|
ip_addr = ip_data['address']
|
|
for port_data in ip_data.get('ports', []):
|
|
port_num = port_data['port']
|
|
protocol = port_data['protocol']
|
|
|
|
# Get certificates from services
|
|
for svc in port_data.get('services', []):
|
|
if svc.get('certificates'):
|
|
for cert in svc['certificates']:
|
|
key = f"{ip_addr}:{port_num}"
|
|
certificates[key] = {
|
|
'ip': ip_addr,
|
|
'port': port_num,
|
|
'subject': cert.get('subject'),
|
|
'issuer': cert.get('issuer'),
|
|
'not_valid_after': cert.get('not_valid_after'),
|
|
'days_until_expiry': cert.get('days_until_expiry'),
|
|
'is_self_signed': cert.get('is_self_signed')
|
|
}
|
|
return certificates
|
|
|
|
def _compare_ports(self, ports1: Dict, ports2: Dict) -> Dict[str, List]:
|
|
"""
|
|
Compare port sets between two scans.
|
|
|
|
Returns:
|
|
Dictionary with added, removed, and unchanged ports
|
|
"""
|
|
keys1 = set(ports1.keys())
|
|
keys2 = set(ports2.keys())
|
|
|
|
added_keys = keys2 - keys1
|
|
removed_keys = keys1 - keys2
|
|
unchanged_keys = keys1 & keys2
|
|
|
|
return {
|
|
'added': [ports2[k] for k in sorted(added_keys)],
|
|
'removed': [ports1[k] for k in sorted(removed_keys)],
|
|
'unchanged': [ports2[k] for k in sorted(unchanged_keys)]
|
|
}
|
|
|
|
def _compare_services(self, services1: Dict, services2: Dict) -> Dict[str, List]:
|
|
"""
|
|
Compare services between two scans.
|
|
|
|
Returns:
|
|
Dictionary with added, removed, and changed services
|
|
"""
|
|
keys1 = set(services1.keys())
|
|
keys2 = set(services2.keys())
|
|
|
|
added_keys = keys2 - keys1
|
|
removed_keys = keys1 - keys2
|
|
common_keys = keys1 & keys2
|
|
|
|
# Find changed services (same port, different version/product)
|
|
changed = []
|
|
for key in sorted(common_keys):
|
|
svc1 = services1[key]
|
|
svc2 = services2[key]
|
|
|
|
# Check if service details changed
|
|
if (svc1.get('product') != svc2.get('product') or
|
|
svc1.get('version') != svc2.get('version') or
|
|
svc1.get('service_name') != svc2.get('service_name')):
|
|
changed.append({
|
|
'ip': svc2['ip'],
|
|
'port': svc2['port'],
|
|
'protocol': svc2['protocol'],
|
|
'old': {
|
|
'service_name': svc1.get('service_name'),
|
|
'product': svc1.get('product'),
|
|
'version': svc1.get('version')
|
|
},
|
|
'new': {
|
|
'service_name': svc2.get('service_name'),
|
|
'product': svc2.get('product'),
|
|
'version': svc2.get('version')
|
|
}
|
|
})
|
|
|
|
return {
|
|
'added': [services2[k] for k in sorted(added_keys)],
|
|
'removed': [services1[k] for k in sorted(removed_keys)],
|
|
'changed': changed
|
|
}
|
|
|
|
def _compare_certificates(self, certs1: Dict, certs2: Dict) -> Dict[str, List]:
|
|
"""
|
|
Compare certificates between two scans.
|
|
|
|
Returns:
|
|
Dictionary with added, removed, and changed certificates
|
|
"""
|
|
keys1 = set(certs1.keys())
|
|
keys2 = set(certs2.keys())
|
|
|
|
added_keys = keys2 - keys1
|
|
removed_keys = keys1 - keys2
|
|
common_keys = keys1 & keys2
|
|
|
|
# Find changed certificates (same IP:port, different cert details)
|
|
changed = []
|
|
for key in sorted(common_keys):
|
|
cert1 = certs1[key]
|
|
cert2 = certs2[key]
|
|
|
|
# Check if certificate changed
|
|
if (cert1.get('subject') != cert2.get('subject') or
|
|
cert1.get('issuer') != cert2.get('issuer') or
|
|
cert1.get('not_valid_after') != cert2.get('not_valid_after')):
|
|
changed.append({
|
|
'ip': cert2['ip'],
|
|
'port': cert2['port'],
|
|
'old': {
|
|
'subject': cert1.get('subject'),
|
|
'issuer': cert1.get('issuer'),
|
|
'not_valid_after': cert1.get('not_valid_after'),
|
|
'days_until_expiry': cert1.get('days_until_expiry')
|
|
},
|
|
'new': {
|
|
'subject': cert2.get('subject'),
|
|
'issuer': cert2.get('issuer'),
|
|
'not_valid_after': cert2.get('not_valid_after'),
|
|
'days_until_expiry': cert2.get('days_until_expiry')
|
|
}
|
|
})
|
|
|
|
return {
|
|
'added': [certs2[k] for k in sorted(added_keys)],
|
|
'removed': [certs1[k] for k in sorted(removed_keys)],
|
|
'changed': changed
|
|
}
|
|
|
|
    def _calculate_drift_score(self, ports_comp: Dict, services_comp: Dict,
                               certs_comp: Dict) -> float:
        """
        Calculate drift score based on comparison results.

        Weighted average of per-category change ratios:
        ports 50%, services 30%, certificates 20%.

        Returns:
            Float between 0.0 (identical) and 1.0 (completely different),
            rounded to 3 decimals
        """
        # Count total items in both scans
        total_ports = (
            len(ports_comp['added']) +
            len(ports_comp['removed']) +
            len(ports_comp['unchanged'])
        )

        # NOTE(review): the service comparison carries no 'unchanged' list,
        # so unchanged ports minus changed services is used as a stand-in for
        # unchanged services — an approximation; verify against intent.
        total_services = (
            len(services_comp['added']) +
            len(services_comp['removed']) +
            len(services_comp['changed']) +
            max(0, len(ports_comp['unchanged']) - len(services_comp['changed']))
        )

        # Count changed items
        changed_ports = len(ports_comp['added']) + len(ports_comp['removed'])
        changed_services = (
            len(services_comp['added']) +
            len(services_comp['removed']) +
            len(services_comp['changed'])
        )
        changed_certs = (
            len(certs_comp['added']) +
            len(certs_comp['removed']) +
            len(certs_comp['changed'])
        )

        # Calculate weighted drift score
        # Ports have 50% weight, services 30%, certificates 20%
        port_drift = changed_ports / max(total_ports, 1)
        service_drift = changed_services / max(total_services, 1)
        # NOTE(review): numerator and denominator are the same expression
        # here (certs_comp has no 'unchanged' count to normalize against),
        # so cert_drift is exactly 1.0 whenever any certificate was
        # added/removed/changed and 0.0 otherwise. Flagging rather than
        # changing behavior.
        cert_drift = changed_certs / max(len(certs_comp['added']) + len(certs_comp['removed']) + len(certs_comp['changed']), 1)

        drift_score = (port_drift * 0.5) + (service_drift * 0.3) + (cert_drift * 0.2)

        return round(min(drift_score, 1.0), 3)  # Cap at 1.0 and round to 3 decimals
|