SneakyScan/web/services/scan_service.py
Phillip Tarrant d7c68a2be8 Phase 2 Step 1: Implement database and service layer
Complete the foundation for Phase 2 by implementing the service layer,
utilities, and comprehensive test suite. This establishes the core
business logic for scan management.

Service Layer:
- Add ScanService class with complete scan lifecycle management
  * trigger_scan() - Create scan record and prepare for execution
  * get_scan() - Retrieve scan with all related data (eager loading)
  * list_scans() - Paginated scan list with status filtering
  * delete_scan() - Remove scan from DB and delete all files
  * get_scan_status() - Poll current scan status and progress
  * _save_scan_to_db() - Persist scan results to database
  * _map_report_to_models() - Complex JSON-to-DB mapping logic
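The "handle missing files gracefully" part of delete_scan() can be tried out on its own. The function below is a hypothetical standalone sketch: remove_scan_artifacts is not part of ScanService. It mirrors the try/except FileNotFoundError pattern used in delete_scan(), but it returns the removed paths instead of logging them, which makes it easy to test. Catching FileNotFoundError instead of checking exists() first avoids a race if something else deletes the file in between.

```python
import shutil
from pathlib import Path
from typing import Iterable, List, Optional


def remove_scan_artifacts(file_paths: Iterable[Optional[str]],
                          screenshot_dir: Optional[str]) -> List[str]:
    """Delete report files and a screenshot directory, skipping missing ones.

    Hypothetical helper; returns the paths that were actually removed.
    """
    removed: List[str] = []
    for file_path in file_paths:
        if not file_path:  # e.g. a report that was never generated (None)
            continue
        try:
            Path(file_path).unlink()
            removed.append(file_path)
        except FileNotFoundError:
            pass  # already gone; treat as success
    if screenshot_dir:
        try:
            shutil.rmtree(screenshot_dir)
            removed.append(screenshot_dir)
        except FileNotFoundError:
            pass  # directory already gone
    return removed
```

Unlike delete_scan(), this sketch lets other OS errors, such as permission errors, propagate instead of logging and continuing.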

Database Mapping:
- Comprehensive mapping from scanner JSON output to normalized schema
- Handles nested relationships: sites → IPs → ports → services → certs → TLS
- Processes both TCP and UDP ports with expected/actual tracking
- Maps service detection results with HTTP/HTTPS information
- Stores SSL/TLS certificates with expiration tracking
- Records TLS version support and cipher suites
- Links screenshots to services
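The nested mapping can be illustrated without a database. The sketch below flattens one IP entry of the scanner JSON into per-port rows. It reuses the key names that _map_report_to_models() reads: expected/actual, tcp_ports, udp_ports, services, http_info, certificate, and not_valid_after. The function names and the flat row layout are hypothetical; the real service writes ScanPort, ScanService, and ScanCertificate rows instead.

```python
from datetime import datetime
from typing import Any, Dict, List, Optional


def parse_iso(date_str: Optional[str]) -> Optional[datetime]:
    """Parse an ISO 8601 string, accepting a trailing 'Z' as UTC."""
    if not date_str:
        return None
    if date_str.endswith('Z'):
        # datetime.fromisoformat() only accepts 'Z' from Python 3.11 on
        date_str = date_str[:-1] + '+00:00'
    try:
        return datetime.fromisoformat(date_str)
    except ValueError:
        return None


def flatten_ip_ports(ip_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn one IP entry of the scanner JSON into flat per-port rows."""
    expected = ip_data.get('expected', {})
    actual = ip_data.get('actual', {})

    # Index services by port; setdefault keeps the first match,
    # like _find_service_for_port()
    services: Dict[Any, Dict[str, Any]] = {}
    for svc in actual.get('services', []):
        services.setdefault(svc.get('port'), svc)

    rows = []
    for protocol in ('tcp', 'udp'):
        expected_ports = set(expected.get(f'{protocol}_ports', []))
        for port in actual.get(f'{protocol}_ports', []):
            row = {
                'address': ip_data['address'],
                'port': port,
                'protocol': protocol,
                'expected': port in expected_ports,  # expected vs actual flag
                'service': None,
                'cert_expires': None,
            }
            # ScanService only looks up services for TCP ports
            svc = services.get(port) if protocol == 'tcp' else None
            if svc:
                http_info = svc.get('http_info') or {}
                cert = http_info.get('certificate') or {}
                row['service'] = svc.get('service')
                row['cert_expires'] = parse_iso(cert.get('not_valid_after'))
            rows.append(row)
    return rows
```

An open port that is missing from the expected list ends up with `expected=False`. That flag is what makes unexpected services easy to query once they are stored.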

Utilities:
- Add pagination.py with PaginatedResult class
  * paginate() function for SQLAlchemy queries
  * validate_page_params() for input sanitization
  * Metadata: total, pages, has_prev, has_next, etc.

- Add validators.py with comprehensive validation functions
  * validate_config_file() - YAML structure and required fields
  * validate_scan_status() - Enum validation (running/completed/failed)
  * validate_scan_id() - Positive integer validation
  * validate_port() - Port range validation (1-65535)
  * validate_ip_address() - Basic IPv4 format validation
  * sanitize_filename() - Path traversal prevention
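pagination.py and validators.py are not shown in this file, so the code below is only a hedged sketch of what the listed helpers might look like. PageSketch, paginate_sequence, and clamp_page_params are hypothetical stand-ins for PaginatedResult, paginate(), and validate_page_params(), and the max_per_page=100 cap is an assumption. The validators follow the (is_valid, error_msg) return convention that scan_service.py uses for validate_config_file() and validate_scan_status(). validate_ip_address() here relies on the standard-library ipaddress module, which may be stricter than the "basic" check described above.

```python
import ipaddress
import math
import re
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple


# --- pagination sketch (hypothetical; not the real pagination.py) ---

def clamp_page_params(page: int, per_page: int,
                      max_per_page: int = 100) -> Tuple[int, int]:
    """Force page >= 1 and 1 <= per_page <= max_per_page."""
    page = max(int(page), 1)
    per_page = min(max(int(per_page), 1), max_per_page)
    return page, per_page


@dataclass
class PageSketch:
    items: List[Any]  # plain attribute, so callers can reassign it
    total: int
    page: int
    per_page: int

    @property
    def pages(self) -> int:
        # Ceiling division: 45 items at 20 per page -> 3 pages
        return math.ceil(self.total / self.per_page) if self.total else 0

    @property
    def has_prev(self) -> bool:
        return self.page > 1

    @property
    def has_next(self) -> bool:
        return self.page < self.pages


def paginate_sequence(rows: Sequence[Any], page: int, per_page: int) -> PageSketch:
    page, per_page = clamp_page_params(page, per_page)
    offset = (page - 1) * per_page
    # A SQLAlchemy version would use query.count() and
    # query.offset(offset).limit(per_page).all() instead of slicing
    return PageSketch(list(rows[offset:offset + per_page]), len(rows), page, per_page)


# --- validator sketches, returning (is_valid, error_msg) ---

def validate_port(port: Any) -> Tuple[bool, Optional[str]]:
    # bool is a subclass of int, so reject it explicitly
    if isinstance(port, bool) or not isinstance(port, int):
        return False, "Port must be an integer"
    if not 1 <= port <= 65535:
        return False, f"Port {port} out of range (1-65535)"
    return True, None


def validate_ip_address(address: str) -> Tuple[bool, Optional[str]]:
    try:
        ipaddress.IPv4Address(address)  # raises a ValueError subclass
    except ValueError:
        return False, f"Invalid IPv4 address: {address}"
    return True, None


def sanitize_filename(name: str) -> str:
    # Keep only the last path component, dropping '../' and absolute prefixes
    base = name.replace('\\', '/').split('/')[-1]
    # Replace anything outside a conservative character set
    base = re.sub(r'[^A-Za-z0-9._-]', '_', base)
    # Reject names that are empty or only dots ('.', '..')
    return base if base.strip('.') else 'unnamed'
```

`items` is a plain dataclass field, so the page can be converted in place. list_scans() does the same thing when it replaces ORM objects with summary dictionaries.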

Database Migration:
- Add migration 002 for scan status index
- Optimizes queries filtering by scan status
- Timestamp index already exists from migration 001
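To see why an index on status helps `list_scans(status_filter=...)`, the sketch below uses Python's built-in sqlite3 module and `EXPLAIN QUERY PLAN` (the test suite also uses SQLite). The stripped-down scans table and the index name ix_scans_status are assumptions, not the schema from migrations 001/002.

```python
import sqlite3


def status_query_plan(with_index: bool) -> str:
    """Return SQLite's plan for a status-filtered query on a toy scans table."""
    conn = sqlite3.connect(':memory:')
    conn.execute(
        "CREATE TABLE scans (id INTEGER PRIMARY KEY, timestamp TEXT, status TEXT)"
    )
    if with_index:
        # Hypothetical index name; migration 002 may name it differently
        conn.execute("CREATE INDEX ix_scans_status ON scans (status)")
    rows = conn.execute(
        "EXPLAIN QUERY PLAN SELECT id FROM scans WHERE status = ?", ('running',)
    ).fetchall()
    conn.close()
    # The last column of each plan row is a human-readable description
    return ' | '.join(row[-1] for row in rows)


print(status_query_plan(False))  # typically a full-table SCAN
print(status_query_plan(True))   # typically a SEARCH using ix_scans_status
```

The exact plan wording varies between SQLite versions. The useful signal is a switch from SCAN (read every row) to SEARCH (index lookup).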

Testing:
- Add pytest infrastructure with conftest.py
  * test_db fixture - Temporary SQLite database per test
  * sample_scan_report fixture - Realistic scanner output
  * sample_config_file fixture - Valid YAML config
  * sample_invalid_config_file fixture - For validation tests

- Add comprehensive test_scan_service.py (15 tests)
  * Test scan trigger with valid/invalid configs
  * Test scan retrieval (found/not found cases)
  * Test scan listing with pagination and filtering
  * Test scan deletion with cascade cleanup
  * Test scan status retrieval
  * Test database mapping from JSON to models
  * Test expected vs actual port flagging
  * Test certificate and TLS data mapping
  * Test full scan retrieval with all relationships
  * All tests passing
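The test_db fixture described above creates a temporary SQLite database per test. A minimal stdlib-only version of that pattern might look like the code below. temporary_sqlite_db is a hypothetical name, and the real fixture in conftest.py (not shown) presumably yields a SQLAlchemy session rather than a raw sqlite3 connection.

```python
import os
import sqlite3
import tempfile
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def temporary_sqlite_db() -> Iterator[sqlite3.Connection]:
    """Yield a connection to a throwaway SQLite file, removed afterwards."""
    fd, path = tempfile.mkstemp(suffix='.db')
    os.close(fd)  # sqlite3 opens the (empty) file itself
    conn = sqlite3.connect(path)
    try:
        yield conn
    finally:
        # Runs even if the test body raises, so no database files leak
        conn.close()
        os.remove(path)
```

In pytest, the same generator body works as a yield fixture if you decorate it with `@pytest.fixture` instead of `@contextmanager`.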

Files Added:
- web/services/__init__.py
- web/services/scan_service.py (545 lines)
- web/utils/pagination.py (153 lines)
- web/utils/validators.py (245 lines)
- migrations/versions/002_add_scan_indexes.py
- tests/__init__.py
- tests/conftest.py (142 lines)
- tests/test_scan_service.py (374 lines)

Next Steps (Step 2):
- Implement scan API endpoints in web/api/scans.py
- Add authentication decorators
- Integrate ScanService with API routes
- Test API endpoints with integration tests

Phase 2 Step 1 Complete ✓
2025-11-14 00:26:06 -06:00

590 lines
20 KiB
Python

"""
Scan service for managing scan operations and database integration.

This service handles the business logic for triggering scans, retrieving
scan results, and mapping scanner output to database models.
"""

import json
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from sqlalchemy.orm import Session, joinedload

from web.models import (
    Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel,
    ScanCertificate, ScanTLSVersion
)
from web.utils.pagination import paginate, PaginatedResult
from web.utils.validators import validate_config_file, validate_scan_status

logger = logging.getLogger(__name__)


class ScanService:
    """
    Service for managing scan operations.

    Handles the scan lifecycle: triggering, status tracking, result storage,
    and cleanup.
    """

    def __init__(self, db_session: Session):
        """
        Initialize scan service.

        Args:
            db_session: SQLAlchemy database session
        """
        self.db = db_session

    def trigger_scan(self, config_file: str, triggered_by: str = 'manual',
                     schedule_id: Optional[int] = None) -> int:
        """
        Trigger a new scan.

        Creates a Scan record in the database with status='running'.
        Queuing the scan for background execution is not implemented yet
        (planned for Step 3), so the record stays 'running' until
        _save_scan_to_db() is called.

        Args:
            config_file: Path to YAML configuration file
            triggered_by: Source that triggered the scan (manual, scheduled, api)
            schedule_id: Optional schedule ID if triggered by a schedule

        Returns:
            ID of the created scan

        Raises:
            ValueError: If the config file is invalid
        """
        # Validate config file
        is_valid, error_msg = validate_config_file(config_file)
        if not is_valid:
            raise ValueError(f"Invalid config file: {error_msg}")

        # Load config to get the title
        import yaml
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)

        # Create scan record
        scan = Scan(
            timestamp=datetime.utcnow(),
            status='running',
            config_file=config_file,
            title=config.get('title', 'Untitled Scan'),
            triggered_by=triggered_by,
            schedule_id=schedule_id,
            created_at=datetime.utcnow()
        )
        self.db.add(scan)
        self.db.commit()
        self.db.refresh(scan)

        logger.info(f"Scan {scan.id} triggered via {triggered_by}")

        # NOTE: Background job queuing will be implemented in Step 3.
        # For now, just return the scan ID.
        return scan.id

    def get_scan(self, scan_id: int) -> Optional[Dict[str, Any]]:
        """
        Get scan details with all related data.

        Args:
            scan_id: Scan ID to retrieve

        Returns:
            Dictionary with scan data including sites, IPs, ports, services, etc.
            Returns None if scan not found.
        """
        # Query with eager loading of all relationships
        scan = (
            self.db.query(Scan)
            .options(
                joinedload(Scan.sites).joinedload(ScanSite.ips).joinedload(ScanIP.ports),
                joinedload(Scan.ports).joinedload(ScanPort.services),
                joinedload(Scan.services).joinedload(ScanServiceModel.certificates),
                joinedload(Scan.certificates).joinedload(ScanCertificate.tls_versions)
            )
            .filter(Scan.id == scan_id)
            .first()
        )

        if not scan:
            return None

        # Convert to dictionary
        return self._scan_to_dict(scan)

    def list_scans(self, page: int = 1, per_page: int = 20,
                   status_filter: Optional[str] = None) -> PaginatedResult:
        """
        List scans with pagination and optional filtering.

        Args:
            page: Page number (1-indexed)
            per_page: Items per page
            status_filter: Optional filter by status (running, completed, failed)

        Returns:
            PaginatedResult with scan list and metadata

        Raises:
            ValueError: If status_filter is not a valid status
        """
        # Build query, newest scans first
        query = self.db.query(Scan).order_by(Scan.timestamp.desc())

        # Apply status filter if provided
        if status_filter:
            is_valid, error_msg = validate_scan_status(status_filter)
            if not is_valid:
                raise ValueError(error_msg)
            query = query.filter(Scan.status == status_filter)

        # Paginate
        result = paginate(query, page=page, per_page=per_page)

        # Convert scans to dictionaries (summary only, not full details)
        result.items = [self._scan_to_summary_dict(scan) for scan in result.items]

        return result

    def delete_scan(self, scan_id: int) -> bool:
        """
        Delete a scan and all associated files.

        Removes:
        - Database record (cascade deletes related records)
        - JSON report file
        - HTML report file
        - ZIP archive file
        - Screenshot directory

        Args:
            scan_id: Scan ID to delete

        Returns:
            True if deleted successfully

        Raises:
            ValueError: If scan not found
        """
        scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            raise ValueError(f"Scan {scan_id} not found")

        logger.info(f"Deleting scan {scan_id}")

        # Delete files (handle missing files gracefully)
        files_to_delete = [
            scan.json_path,
            scan.html_path,
            scan.zip_path
        ]

        for file_path in files_to_delete:
            if file_path:
                try:
                    Path(file_path).unlink()
                    logger.debug(f"Deleted file: {file_path}")
                except FileNotFoundError:
                    logger.warning(f"File not found (already deleted?): {file_path}")
                except Exception as e:
                    logger.error(f"Error deleting file {file_path}: {e}")

        # Delete screenshot directory
        if scan.screenshot_dir:
            try:
                shutil.rmtree(scan.screenshot_dir)
                logger.debug(f"Deleted directory: {scan.screenshot_dir}")
            except FileNotFoundError:
                logger.warning(f"Directory not found (already deleted?): {scan.screenshot_dir}")
            except Exception as e:
                logger.error(f"Error deleting directory {scan.screenshot_dir}: {e}")

        # Delete database record (cascade handles related records)
        self.db.delete(scan)
        self.db.commit()

        logger.info(f"Scan {scan_id} deleted successfully")
        return True

    def get_scan_status(self, scan_id: int) -> Optional[Dict[str, Any]]:
        """
        Get current scan status and progress.

        Args:
            scan_id: Scan ID

        Returns:
            Dictionary with status information, or None if scan not found
        """
        scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            return None

        status_info = {
            'scan_id': scan.id,
            'status': scan.status,
            'title': scan.title,
            'started_at': scan.timestamp.isoformat() if scan.timestamp else None,
            'duration': scan.duration,
            'triggered_by': scan.triggered_by
        }

        # Add a human-readable progress label based on status
        if scan.status == 'running':
            status_info['progress'] = 'In progress'
        elif scan.status == 'completed':
            status_info['progress'] = 'Complete'
        elif scan.status == 'failed':
            status_info['progress'] = 'Failed'

        return status_info

    def _save_scan_to_db(self, report: Dict[str, Any], scan_id: int,
                         status: str = 'completed') -> None:
        """
        Save scan results to database.

        Updates the Scan record and creates all related records
        (sites, IPs, ports, services, certificates, TLS versions).

        Args:
            report: Scan report dictionary from scanner
            scan_id: Scan ID to update
            status: Final scan status (completed or failed)

        Raises:
            ValueError: If scan not found
        """
        scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            raise ValueError(f"Scan {scan_id} not found")

        # Update scan record
        scan.status = status
        scan.duration = report.get('scan_duration')

        # Map report data to database models
        self._map_report_to_models(report, scan)

        self.db.commit()
        logger.info(f"Scan {scan_id} saved to database with status '{status}'")

    def _map_report_to_models(self, report: Dict[str, Any], scan_obj: Scan) -> None:
        """
        Map JSON report structure to database models.

        Creates records for sites, IPs, ports, services, certificates, and TLS versions.
        Each parent record is flushed before its children are created, so its
        primary key is available for their foreign keys.

        Args:
            report: Scan report dictionary
            scan_obj: Scan database object to attach records to
        """
        logger.debug(f"Mapping report to database models for scan {scan_obj.id}")

        # Process each site
        for site_data in report.get('sites', []):
            # Create ScanSite record
            site = ScanSite(
                scan_id=scan_obj.id,
                site_name=site_data['name']
            )
            self.db.add(site)
            self.db.flush()  # Get site.id for foreign key

            # Process each IP in this site
            for ip_data in site_data.get('ips', []):
                # Create ScanIP record
                ip = ScanIP(
                    scan_id=scan_obj.id,
                    site_id=site.id,
                    ip_address=ip_data['address'],
                    ping_expected=ip_data.get('expected', {}).get('ping'),
                    ping_actual=ip_data.get('actual', {}).get('ping')
                )
                self.db.add(ip)
                self.db.flush()

                # Process TCP ports
                expected_tcp = set(ip_data.get('expected', {}).get('tcp_ports', []))
                actual_tcp = ip_data.get('actual', {}).get('tcp_ports', [])

                for port_num in actual_tcp:
                    port = ScanPort(
                        scan_id=scan_obj.id,
                        ip_id=ip.id,
                        port=port_num,
                        protocol='tcp',
                        expected=(port_num in expected_tcp),
                        state='open'
                    )
                    self.db.add(port)
                    self.db.flush()

                    # Find service for this port
                    service_data = self._find_service_for_port(
                        ip_data.get('actual', {}).get('services', []),
                        port_num
                    )

                    if service_data:
                        # http_info may be absent or null for non-HTTP services
                        http_info = service_data.get('http_info') or {}

                        # Create ScanService record
                        service = ScanServiceModel(
                            scan_id=scan_obj.id,
                            port_id=port.id,
                            service_name=service_data.get('service'),
                            product=service_data.get('product'),
                            version=service_data.get('version'),
                            extrainfo=service_data.get('extrainfo'),
                            ostype=service_data.get('ostype'),
                            http_protocol=http_info.get('protocol'),
                            screenshot_path=http_info.get('screenshot')
                        )
                        self.db.add(service)
                        self.db.flush()

                        # Process certificate and TLS info if present
                        if http_info.get('certificate'):
                            self._process_certificate(
                                http_info['certificate'],
                                scan_obj.id,
                                service.id
                            )

                # Process UDP ports (no service records are created for UDP)
                expected_udp = set(ip_data.get('expected', {}).get('udp_ports', []))
                actual_udp = ip_data.get('actual', {}).get('udp_ports', [])

                for port_num in actual_udp:
                    port = ScanPort(
                        scan_id=scan_obj.id,
                        ip_id=ip.id,
                        port=port_num,
                        protocol='udp',
                        expected=(port_num in expected_udp),
                        state='open'
                    )
                    self.db.add(port)

        logger.debug(f"Report mapping complete for scan {scan_obj.id}")

    def _find_service_for_port(self, services: List[Dict], port: int) -> Optional[Dict]:
        """
        Find service data for a specific port.

        Args:
            services: List of service dictionaries
            port: Port number to find

        Returns:
            Service dictionary if found, None otherwise
        """
        for service in services:
            if service.get('port') == port:
                return service
        return None

    def _process_certificate(self, cert_data: Dict[str, Any], scan_id: int,
                             service_id: int) -> None:
        """
        Process certificate and TLS version data.

        Args:
            cert_data: Certificate data dictionary
            scan_id: Scan ID
            service_id: Service ID
        """
        # Create ScanCertificate record
        cert = ScanCertificate(
            scan_id=scan_id,
            service_id=service_id,
            subject=cert_data.get('subject'),
            issuer=cert_data.get('issuer'),
            serial_number=cert_data.get('serial_number'),
            not_valid_before=self._parse_datetime(cert_data.get('not_valid_before')),
            not_valid_after=self._parse_datetime(cert_data.get('not_valid_after')),
            days_until_expiry=cert_data.get('days_until_expiry'),
            sans=json.dumps(cert_data.get('sans', [])),
            is_self_signed=cert_data.get('is_self_signed', False)
        )
        self.db.add(cert)
        self.db.flush()

        # Process TLS versions
        tls_versions = cert_data.get('tls_versions', {})
        for version, version_data in tls_versions.items():
            tls = ScanTLSVersion(
                scan_id=scan_id,
                certificate_id=cert.id,
                tls_version=version,
                supported=version_data.get('supported', False),
                cipher_suites=json.dumps(version_data.get('cipher_suites', []))
            )
            self.db.add(tls)

    def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]:
        """
        Parse an ISO 8601 datetime string.

        Args:
            date_str: ISO format datetime string

        Returns:
            datetime object, or None if the input is empty or parsing fails
        """
        if not date_str:
            return None

        try:
            # datetime.fromisoformat() accepts a 'Z' suffix only from Python 3.11,
            # so normalize it to an explicit UTC offset
            if date_str.endswith('Z'):
                date_str = date_str[:-1] + '+00:00'
            return datetime.fromisoformat(date_str)
        except (ValueError, AttributeError) as e:
            logger.warning(f"Failed to parse datetime '{date_str}': {e}")
            return None

    def _scan_to_dict(self, scan: Scan) -> Dict[str, Any]:
        """
        Convert Scan object to dictionary with full details.

        Args:
            scan: Scan database object

        Returns:
            Dictionary representation with all related data
        """
        return {
            'id': scan.id,
            'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
            'duration': scan.duration,
            'status': scan.status,
            'title': scan.title,
            'config_file': scan.config_file,
            'json_path': scan.json_path,
            'html_path': scan.html_path,
            'zip_path': scan.zip_path,
            'screenshot_dir': scan.screenshot_dir,
            'triggered_by': scan.triggered_by,
            'created_at': scan.created_at.isoformat() if scan.created_at else None,
            'sites': [self._site_to_dict(site) for site in scan.sites]
        }

    def _scan_to_summary_dict(self, scan: Scan) -> Dict[str, Any]:
        """
        Convert Scan object to summary dictionary (no related data).

        Args:
            scan: Scan database object

        Returns:
            Summary dictionary
        """
        return {
            'id': scan.id,
            'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
            'duration': scan.duration,
            'status': scan.status,
            'title': scan.title,
            'triggered_by': scan.triggered_by,
            'created_at': scan.created_at.isoformat() if scan.created_at else None
        }

    def _site_to_dict(self, site: ScanSite) -> Dict[str, Any]:
        """Convert ScanSite to dictionary."""
        return {
            'id': site.id,
            'name': site.site_name,
            'ips': [self._ip_to_dict(ip) for ip in site.ips]
        }

    def _ip_to_dict(self, ip: ScanIP) -> Dict[str, Any]:
        """Convert ScanIP to dictionary."""
        return {
            'id': ip.id,
            'address': ip.ip_address,
            'ping_expected': ip.ping_expected,
            'ping_actual': ip.ping_actual,
            'ports': [self._port_to_dict(port) for port in ip.ports]
        }

    def _port_to_dict(self, port: ScanPort) -> Dict[str, Any]:
        """Convert ScanPort to dictionary."""
        return {
            'id': port.id,
            'port': port.port,
            'protocol': port.protocol,
            'state': port.state,
            'expected': port.expected,
            'services': [self._service_to_dict(svc) for svc in port.services]
        }

    def _service_to_dict(self, service: ScanServiceModel) -> Dict[str, Any]:
        """Convert ScanService to dictionary."""
        result = {
            'id': service.id,
            'service_name': service.service_name,
            'product': service.product,
            'version': service.version,
            'extrainfo': service.extrainfo,
            'ostype': service.ostype,
            'http_protocol': service.http_protocol,
            'screenshot_path': service.screenshot_path
        }

        # Add certificate info if present
        if service.certificates:
            result['certificates'] = [
                self._certificate_to_dict(cert) for cert in service.certificates
            ]

        return result

    def _certificate_to_dict(self, cert: ScanCertificate) -> Dict[str, Any]:
        """Convert ScanCertificate to dictionary."""
        result = {
            'id': cert.id,
            'subject': cert.subject,
            'issuer': cert.issuer,
            'serial_number': cert.serial_number,
            'not_valid_before': cert.not_valid_before.isoformat() if cert.not_valid_before else None,
            'not_valid_after': cert.not_valid_after.isoformat() if cert.not_valid_after else None,
            'days_until_expiry': cert.days_until_expiry,
            'is_self_signed': cert.is_self_signed
        }

        # Parse SANs from JSON
        if cert.sans:
            try:
                result['sans'] = json.loads(cert.sans)
            except json.JSONDecodeError:
                result['sans'] = []

        # Add TLS versions
        result['tls_versions'] = [
            self._tls_version_to_dict(tls) for tls in cert.tls_versions
        ]

        return result

    def _tls_version_to_dict(self, tls: ScanTLSVersion) -> Dict[str, Any]:
        """Convert ScanTLSVersion to dictionary."""
        result = {
            'id': tls.id,
            'tls_version': tls.tls_version,
            'supported': tls.supported
        }

        # Parse cipher suites from JSON
        if tls.cipher_suites:
            try:
                result['cipher_suites'] = json.loads(tls.cipher_suites)
            except json.JSONDecodeError:
                result['cipher_suites'] = []

        return result