""" Background scan job execution. This module handles the execution of scans in background threads, updating database status and handling errors. """ import json import logging import threading import traceback from datetime import datetime from pathlib import Path from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from src.scanner import SneakyScanner, ScanCancelledError from web.models import Scan, ScanProgress from web.services.scan_service import ScanService from web.services.alert_service import AlertService logger = logging.getLogger(__name__) # Registry for tracking running scanners (scan_id -> SneakyScanner instance) _running_scanners = {} _running_scanners_lock = threading.Lock() def get_running_scanner(scan_id: int): """Get a running scanner instance by scan ID.""" with _running_scanners_lock: return _running_scanners.get(scan_id) def stop_scan(scan_id: int, db_url: str) -> bool: """ Stop a running scan. Args: scan_id: ID of the scan to stop db_url: Database connection URL Returns: True if scan was cancelled, False if not found or already stopped """ logger.info(f"Attempting to stop scan {scan_id}") # Get the scanner instance scanner = get_running_scanner(scan_id) if not scanner: logger.warning(f"Scanner for scan {scan_id} not found in registry") return False # Cancel the scanner scanner.cancel() logger.info(f"Cancellation signal sent to scan {scan_id}") return True def create_progress_callback(scan_id: int, session): """ Create a progress callback function for updating scan progress in database. Args: scan_id: ID of the scan record session: Database session Returns: Callback function that accepts (phase, ip, data) """ ip_to_site = {} def progress_callback(phase: str, ip: str, data: dict): """Update scan progress in database.""" nonlocal ip_to_site try: # Get scan record scan = session.query(Scan).filter_by(id=scan_id).first() if not scan: return # Handle initialization phase if phase == 'init': scan.total_ips = data.get('total_ips', 0) scan.completed_ips = 0 scan.current_phase = 'ping' ip_to_site = data.get('ip_to_site', {}) # Create progress entries for all IPs for ip_addr, site_name in ip_to_site.items(): progress = ScanProgress( scan_id=scan_id, ip_address=ip_addr, site_name=site_name, phase='pending', status='pending' ) session.add(progress) session.commit() return # Update current phase if data.get('status') == 'starting': scan.current_phase = phase scan.completed_ips = 0 session.commit() return # Handle phase completion with results if data.get('status') == 'completed': results = data.get('results', {}) if phase == 'ping': # Update progress entries with ping results for ip_addr, ping_result in results.items(): progress = session.query(ScanProgress).filter_by( scan_id=scan_id, ip_address=ip_addr ).first() if progress: progress.ping_result = ping_result progress.phase = 'ping' progress.status = 'completed' scan.completed_ips = len(results) elif phase == 'tcp_scan': # Update progress entries with TCP/UDP port results for ip_addr, port_data in results.items(): progress = session.query(ScanProgress).filter_by( scan_id=scan_id, ip_address=ip_addr ).first() if progress: progress.tcp_ports = json.dumps(port_data.get('tcp_ports', [])) progress.udp_ports = json.dumps(port_data.get('udp_ports', [])) progress.phase = 'tcp_scan' progress.status = 'completed' scan.completed_ips = len(results) elif phase == 'service_detection': # Update progress entries with service detection results for ip_addr, services in results.items(): progress = 


def create_progress_callback(scan_id: int, session):
    """
    Create a progress callback function for updating scan progress in the database.

    Args:
        scan_id: ID of the scan record
        session: Database session

    Returns:
        Callback function that accepts (phase, ip, data)
    """
    ip_to_site = {}

    def progress_callback(phase: str, ip: str, data: dict):
        """Update scan progress in the database."""
        nonlocal ip_to_site
        try:
            # Get scan record
            scan = session.query(Scan).filter_by(id=scan_id).first()
            if not scan:
                return

            # Handle initialization phase
            if phase == 'init':
                scan.total_ips = data.get('total_ips', 0)
                scan.completed_ips = 0
                scan.current_phase = 'ping'
                ip_to_site = data.get('ip_to_site', {})

                # Create progress entries for all IPs
                for ip_addr, site_name in ip_to_site.items():
                    progress = ScanProgress(
                        scan_id=scan_id,
                        ip_address=ip_addr,
                        site_name=site_name,
                        phase='pending',
                        status='pending'
                    )
                    session.add(progress)

                session.commit()
                return

            # Update current phase
            if data.get('status') == 'starting':
                scan.current_phase = phase
                scan.completed_ips = 0
                session.commit()
                return

            # Handle phase completion with results
            if data.get('status') == 'completed':
                results = data.get('results', {})

                if phase == 'ping':
                    # Update progress entries with ping results
                    for ip_addr, ping_result in results.items():
                        progress = session.query(ScanProgress).filter_by(
                            scan_id=scan_id,
                            ip_address=ip_addr
                        ).first()
                        if progress:
                            progress.ping_result = ping_result
                            progress.phase = 'ping'
                            progress.status = 'completed'
                    scan.completed_ips = len(results)

                elif phase == 'tcp_scan':
                    # Update progress entries with TCP/UDP port results
                    for ip_addr, port_data in results.items():
                        progress = session.query(ScanProgress).filter_by(
                            scan_id=scan_id,
                            ip_address=ip_addr
                        ).first()
                        if progress:
                            progress.tcp_ports = json.dumps(port_data.get('tcp_ports', []))
                            progress.udp_ports = json.dumps(port_data.get('udp_ports', []))
                            progress.phase = 'tcp_scan'
                            progress.status = 'completed'
                    scan.completed_ips = len(results)

                elif phase == 'service_detection':
                    # Update progress entries with service detection results
                    for ip_addr, services in results.items():
                        progress = session.query(ScanProgress).filter_by(
                            scan_id=scan_id,
                            ip_address=ip_addr
                        ).first()
                        if progress:
                            # Simplify service data for storage
                            service_list = []
                            for svc in services:
                                service_list.append({
                                    'port': svc.get('port'),
                                    'service': svc.get('service', 'unknown'),
                                    'product': svc.get('product', ''),
                                    'version': svc.get('version', '')
                                })
                            progress.services = json.dumps(service_list)
                            progress.phase = 'service_detection'
                            progress.status = 'completed'
                    scan.completed_ips = len(results)

                elif phase == 'http_analysis':
                    # Mark HTTP analysis as complete
                    scan.current_phase = 'completed'
                    scan.completed_ips = scan.total_ips

                session.commit()

        except Exception as e:
            logger.error(f"Progress callback error for scan {scan_id}: {str(e)}")
            # Don't re-raise - we don't want to break the scan
            session.rollback()

    return progress_callback
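

# --- Callback payload shapes (inferred from progress_callback above) ---------
# The handlers above only read the keys shown here; the concrete values are
# illustrative guesses, not output captured from a real SneakyScanner run.
_EXAMPLE_PROGRESS_PAYLOADS = [
    # (phase, ip, data)
    ('init', '', {'total_ips': 2,
                  'ip_to_site': {'10.0.0.5': 'site-a', '10.0.0.6': 'site-b'}}),
    ('ping', '', {'status': 'starting'}),
    ('ping', '', {'status': 'completed',
                  'results': {'10.0.0.5': 'up', '10.0.0.6': 'down'}}),
    ('tcp_scan', '', {'status': 'completed',
                      'results': {'10.0.0.5': {'tcp_ports': [22, 443],
                                               'udp_ports': [161]}}}),
    ('service_detection', '', {'status': 'completed',
                               'results': {'10.0.0.5': [{'port': 22,
                                                         'service': 'ssh',
                                                         'product': 'OpenSSH',
                                                         'version': '9.6'}]}}),
]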


def execute_scan(scan_id: int, config_id: int, db_url: str = None):
    """
    Execute a scan in the background.

    This function is designed to run in a background thread via APScheduler.
    It creates its own database session to avoid conflicts with the main
    application thread.

    Args:
        scan_id: ID of the scan record in the database
        config_id: Database config ID
        db_url: Database connection URL

    Workflow:
        1. Create a new database session for this thread
        2. Update scan status to 'running'
        3. Execute the scanner
        4. Generate output files (JSON, HTML, ZIP)
        5. Save results to the database
        6. Update status to 'completed' or 'failed'
    """
    logger.info(f"Starting background scan execution: scan_id={scan_id}, config_id={config_id}")

    # Create a new database session for this thread
    engine = create_engine(db_url, echo=False)
    Session = sessionmaker(bind=engine)
    session = Session()

    try:
        # Get scan record
        scan = session.query(Scan).filter_by(id=scan_id).first()
        if not scan:
            logger.error(f"Scan {scan_id} not found in database")
            return

        # Update status to running (in case it wasn't already)
        scan.status = 'running'
        scan.started_at = datetime.utcnow()
        session.commit()

        logger.info(f"Scan {scan_id}: Initializing scanner with config_id={config_id}")

        # Initialize scanner with database config
        scanner = SneakyScanner(config_id=config_id)

        # Register scanner in the running registry
        with _running_scanners_lock:
            _running_scanners[scan_id] = scanner
        logger.debug(f"Scan {scan_id}: Registered in running scanners registry")

        # Create progress callback
        progress_callback = create_progress_callback(scan_id, session)

        # Execute scan with progress tracking
        logger.info(f"Scan {scan_id}: Running scanner...")
        start_time = datetime.utcnow()
        report, timestamp = scanner.scan(progress_callback=progress_callback)
        end_time = datetime.utcnow()
        scan_duration = (end_time - start_time).total_seconds()
        logger.info(f"Scan {scan_id}: Scanner completed in {scan_duration:.2f} seconds")

        # Transition to 'finalizing' status before output generation
        try:
            scan = session.query(Scan).filter_by(id=scan_id).first()
            if scan:
                scan.status = 'finalizing'
                scan.current_phase = 'generating_outputs'
                session.commit()
                logger.info(f"Scan {scan_id}: Status changed to 'finalizing'")
        except Exception as e:
            logger.error(f"Scan {scan_id}: Failed to update status to finalizing: {e}")
            session.rollback()

        # Generate output files (JSON, HTML, ZIP) with error handling
        output_paths = {}
        output_generation_failed = False
        try:
            logger.info(f"Scan {scan_id}: Generating output files...")
            output_paths = scanner.generate_outputs(report, timestamp)
        except Exception as e:
            output_generation_failed = True
            logger.error(f"Scan {scan_id}: Output generation failed: {str(e)}")
            logger.error(f"Scan {scan_id}: Traceback:\n{traceback.format_exc()}")

            # Still mark the scan as completed with a warning since the scan data is valid
            try:
                scan = session.query(Scan).filter_by(id=scan_id).first()
                if scan:
                    scan.status = 'completed'
                    scan.error_message = f"Scan completed but output file generation failed: {str(e)}"
                    scan.completed_at = datetime.utcnow()
                    if scan.started_at:
                        scan.duration = (datetime.utcnow() - scan.started_at).total_seconds()
                    session.commit()
                    logger.info(f"Scan {scan_id}: Marked as completed with output generation warning")
            except Exception as db_error:
                logger.error(f"Scan {scan_id}: Failed to update status after output error: {db_error}")

        # Save results to database (only if output generation succeeded)
        if not output_generation_failed:
            logger.info(f"Scan {scan_id}: Saving results to database...")
            scan_service = ScanService(session)
            scan_service._save_scan_to_db(report, scan_id, status='completed', output_paths=output_paths)

            # Evaluate alert rules
            logger.info(f"Scan {scan_id}: Evaluating alert rules...")
            try:
                alert_service = AlertService(session)
                alerts_triggered = alert_service.evaluate_alert_rules(scan_id)
                logger.info(f"Scan {scan_id}: {len(alerts_triggered)} alerts triggered")
            except Exception as e:
                # Don't fail the scan if alert evaluation fails
                logger.error(f"Scan {scan_id}: Alert evaluation failed: {str(e)}")
                logger.debug(f"Alert evaluation error details: {traceback.format_exc()}")

            logger.info(f"Scan {scan_id}: Completed successfully")

    except ScanCancelledError:
        # Scan was cancelled by the user
        logger.info(f"Scan {scan_id}: Cancelled by user")
        scan = session.query(Scan).filter_by(id=scan_id).first()
        if scan:
            scan.status = 'cancelled'
            scan.error_message = 'Scan cancelled by user'
            scan.completed_at = datetime.utcnow()
            if scan.started_at:
                scan.duration = (datetime.utcnow() - scan.started_at).total_seconds()
            session.commit()

    except FileNotFoundError as e:
        # Config file not found
        error_msg = f"Configuration file not found: {str(e)}"
        logger.error(f"Scan {scan_id}: {error_msg}")
        scan = session.query(Scan).filter_by(id=scan_id).first()
        if scan:
            scan.status = 'failed'
            scan.error_message = error_msg
            scan.completed_at = datetime.utcnow()
            session.commit()

    except Exception as e:
        # Any other error during scan execution
        error_msg = f"Scan execution failed: {str(e)}"
        logger.error(f"Scan {scan_id}: {error_msg}")
        logger.error(f"Scan {scan_id}: Traceback:\n{traceback.format_exc()}")
        try:
            scan = session.query(Scan).filter_by(id=scan_id).first()
            if scan:
                scan.status = 'failed'
                scan.error_message = error_msg
                scan.completed_at = datetime.utcnow()
                session.commit()
        except Exception as db_error:
            logger.error(f"Scan {scan_id}: Failed to update error status in database: {str(db_error)}")

    finally:
        # Unregister scanner from registry
        with _running_scanners_lock:
            if scan_id in _running_scanners:
                del _running_scanners[scan_id]
                logger.debug(f"Scan {scan_id}: Unregistered from running scanners registry")

        # Always close the session
        session.close()
        logger.info(f"Scan {scan_id}: Background job completed, session closed")
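

# --- Usage sketch (illustrative only) ----------------------------------------
# execute_scan() is written to run via APScheduler (see its docstring). The
# wiring below is an assumption about how a caller might submit the job; the
# application may configure its scheduler (executors, job stores) differently.
def _example_schedule_scan(scan_id: int, config_id: int, db_url: str):
    """Illustrative sketch: submit execute_scan() as a one-off background job."""
    from apscheduler.schedulers.background import BackgroundScheduler  # assumed dependency

    scheduler = BackgroundScheduler()
    scheduler.start()
    # With no trigger, APScheduler runs the job immediately in a worker thread.
    scheduler.add_job(
        execute_scan,
        args=[scan_id, config_id, db_url],
        id=f'scan_{scan_id}',
    )
    return scheduler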


def get_scan_status_from_db(scan_id: int, db_url: str) -> dict:
    """
    Helper function to get scan status directly from the database.

    Useful for monitoring background jobs without needing a Flask app context.

    Args:
        scan_id: Scan ID to check
        db_url: Database connection URL

    Returns:
        Dictionary with scan status information, or None if the scan is not found
    """
    engine = create_engine(db_url, echo=False)
    Session = sessionmaker(bind=engine)
    session = Session()

    try:
        scan = session.query(Scan).filter_by(id=scan_id).first()
        if not scan:
            return None

        return {
            'scan_id': scan.id,
            'status': scan.status,
            'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
            'duration': scan.duration,
            'error_message': scan.error_message
        }
    finally:
        session.close()
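

# --- Usage sketch (illustrative only) ----------------------------------------
# Polling a background scan from a standalone script or test using only the
# database URL. The terminal statuses listed here are the ones set elsewhere in
# this module; the poll interval is an arbitrary illustrative default.
def _example_wait_for_scan(scan_id: int, db_url: str, poll_seconds: float = 5.0):
    """Illustrative sketch: block until the scan reaches a terminal status."""
    import time

    terminal = {'completed', 'failed', 'cancelled'}
    while True:
        status = get_scan_status_from_db(scan_id, db_url)
        if status is None or status['status'] in terminal:
            return status
        time.sleep(poll_seconds)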