- Replace subprocess.run() with Popen for cancellable processes
- Add cancel() method to SneakyScanner with process termination
- Track running scanners in registry for stop signal delivery
- Handle ScanCancelledError to set scan status to 'cancelled'
- Add POST /api/scans/<id>/stop endpoint
- Add 'cancelled' as valid scan status
- Add Stop button to scans list and detail views
- Show cancelled status with warning badge in UI
349 lines
12 KiB
Python
349 lines
12 KiB
Python
"""
|
|
Background scan job execution.
|
|
|
|
This module handles the execution of scans in background threads,
|
|
updating database status and handling errors.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import threading
|
|
import traceback
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from src.scanner import SneakyScanner, ScanCancelledError
|
|
from web.models import Scan, ScanProgress
|
|
from web.services.scan_service import ScanService
|
|
from web.services.alert_service import AlertService
|
|
|
|
logger = logging.getLogger(__name__)

# Scanners that are executing right now, keyed by scan ID
# (scan_id -> SneakyScanner instance). execute_scan() adds and removes
# entries; get_running_scanner() reads them.
_running_scanners: dict = {}
# Held around every read and write of _running_scanners
_running_scanners_lock = threading.Lock()
|
|
|
|
|
|
def get_running_scanner(scan_id: int):
    """
    Look up the scanner currently registered for a scan.

    The registry is read while holding its lock.

    Args:
        scan_id: ID of the scan whose scanner is wanted

    Returns:
        The registered scanner instance, or None if the scan has no entry
    """
    with _running_scanners_lock:
        scanner = _running_scanners.get(scan_id)
    return scanner
|
|
|
|
|
|
def stop_scan(scan_id: int, db_url: str) -> bool:
    """
    Ask the scanner that is running a scan to cancel itself.

    Only the cancellation signal is delivered here; this function does not
    touch the database itself.

    Args:
        scan_id: ID of the scan to stop
        db_url: Database connection URL (accepted but not used by this function)

    Returns:
        True if a cancellation signal was sent, False if no scanner is
        registered for the scan
    """
    logger.info(f"Attempting to stop scan {scan_id}")

    scanner = get_running_scanner(scan_id)
    if scanner:
        # Signal the scanner to stop
        scanner.cancel()
        logger.info(f"Cancellation signal sent to scan {scan_id}")
        return True

    logger.warning(f"Scanner for scan {scan_id} not found in registry")
    return False
|
|
|
|
|
|
def create_progress_callback(scan_id: int, session):
    """
    Build the callback the scanner invokes to persist its progress.

    The returned closure remembers the IP -> site mapping announced by the
    'init' phase and writes every update for the scan through ``session``.
    Any exception raised while handling an update is logged and the session
    is rolled back, so a progress failure never propagates into the scan.

    Args:
        scan_id: ID of the scan record
        session: Database session

    Returns:
        Callback function that accepts (phase, ip, data)
    """
    ip_to_site = {}

    def lookup_progress(ip_addr):
        """Fetch this scan's ScanProgress row for one IP (None when absent)."""
        query = session.query(ScanProgress)
        return query.filter_by(scan_id=scan_id, ip_address=ip_addr).first()

    def finish_phase(row, phase_name):
        """Stamp a progress row as completed for the given phase."""
        row.phase = phase_name
        row.status = 'completed'

    def progress_callback(phase: str, ip: str, data: dict):
        """Persist one progress event reported by the scanner."""
        nonlocal ip_to_site

        try:
            scan = session.query(Scan).filter_by(id=scan_id).first()
            if not scan:
                return

            # 'init' announces the targets: reset counters and seed one
            # pending progress row per IP
            if phase == 'init':
                scan.total_ips = data.get('total_ips', 0)
                scan.completed_ips = 0
                scan.current_phase = 'ping'
                ip_to_site = data.get('ip_to_site', {})

                for ip_addr, site_name in ip_to_site.items():
                    session.add(ScanProgress(
                        scan_id=scan_id,
                        ip_address=ip_addr,
                        site_name=site_name,
                        phase='pending',
                        status='pending',
                    ))

                session.commit()
                return

            status = data.get('status')

            # A phase is beginning: record it and restart the counter
            if status == 'starting':
                scan.current_phase = phase
                scan.completed_ips = 0
                session.commit()
                return

            # Everything below only applies to a finished phase
            if status != 'completed':
                return

            results = data.get('results', {})

            if phase == 'ping':
                for ip_addr, ping_result in results.items():
                    row = lookup_progress(ip_addr)
                    if not row:
                        continue
                    row.ping_result = ping_result
                    finish_phase(row, 'ping')

                scan.completed_ips = len(results)

            elif phase == 'tcp_scan':
                # This phase reports both the TCP and the UDP ports of each IP
                for ip_addr, port_data in results.items():
                    row = lookup_progress(ip_addr)
                    if not row:
                        continue
                    row.tcp_ports = json.dumps(port_data.get('tcp_ports', []))
                    row.udp_ports = json.dumps(port_data.get('udp_ports', []))
                    finish_phase(row, 'tcp_scan')

                scan.completed_ips = len(results)

            elif phase == 'service_detection':
                for ip_addr, services in results.items():
                    row = lookup_progress(ip_addr)
                    if not row:
                        continue
                    # Store only a trimmed-down view of each service
                    trimmed = [
                        {
                            'port': svc.get('port'),
                            'service': svc.get('service', 'unknown'),
                            'product': svc.get('product', ''),
                            'version': svc.get('version', ''),
                        }
                        for svc in services
                    ]
                    row.services = json.dumps(trimmed)
                    finish_phase(row, 'service_detection')

                scan.completed_ips = len(results)

            elif phase == 'http_analysis':
                # Last phase: the whole scan is now accounted for
                scan.current_phase = 'completed'
                scan.completed_ips = scan.total_ips

            session.commit()

        except Exception as e:
            logger.error(f"Progress callback error for scan {scan_id}: {str(e)}")
            # Swallow the error on purpose - progress bookkeeping must not
            # break the scan - but leave the session usable for the next call
            session.rollback()

    return progress_callback
|
|
|
|
|
|
def _record_scan_outcome(session, scan_id: int, status: str, error_message: str,
                         include_duration: bool = False) -> None:
    """Persist a terminal non-success status for a scan; never raises."""
    try:
        # The error being handled may have left the session inside a failed
        # transaction; it must be rolled back before it can be used again.
        session.rollback()

        scan = session.query(Scan).filter_by(id=scan_id).first()
        if not scan:
            return

        finished_at = datetime.utcnow()
        scan.status = status
        scan.error_message = error_message
        scan.completed_at = finished_at
        if include_duration and scan.started_at:
            scan.duration = (finished_at - scan.started_at).total_seconds()
        session.commit()
    except Exception as db_error:
        logger.error(f"Scan {scan_id}: Failed to update error status in database: {str(db_error)}")


def execute_scan(scan_id: int, config_id: int, db_url: str = None):
    """
    Execute a scan in the background.

    This function is designed to run in a background thread via APScheduler.
    It creates its own database engine and session to avoid conflicts with
    the main application thread, and releases both when the job ends.

    Args:
        scan_id: ID of the scan record in database
        config_id: Database config ID
        db_url: Database connection URL. It is handed directly to
            create_engine(), so despite the None default a real URL is
            expected.

    Workflow:
        1. Create new database session for this thread
        2. Update scan status to 'running'
        3. Execute scanner (registered in the running-scanner registry so a
           stop request can reach it)
        4. Generate output files (JSON, HTML, ZIP)
        5. Save results to database
        6. Update status to 'completed', 'cancelled' or 'failed'

    Errors are logged and recorded on the scan row rather than raised:
    ScanCancelledError yields 'cancelled'; FileNotFoundError and any other
    exception yield 'failed'. Recording the outcome first rolls the session
    back and is itself guarded, so a database problem while handling an
    error cannot escape the job.
    """
    logger.info(f"Starting background scan execution: scan_id={scan_id}, config_id={config_id}")

    # Create new database session for this thread
    engine = create_engine(db_url, echo=False)
    Session = sessionmaker(bind=engine)
    session = Session()

    try:
        # Get scan record
        scan = session.query(Scan).filter_by(id=scan_id).first()
        if not scan:
            logger.error(f"Scan {scan_id} not found in database")
            return

        # Update status to running (in case it wasn't already)
        scan.status = 'running'
        scan.started_at = datetime.utcnow()
        session.commit()

        logger.info(f"Scan {scan_id}: Initializing scanner with config_id={config_id}")

        # Initialize scanner with database config
        scanner = SneakyScanner(config_id=config_id)

        # Register scanner in the running registry
        with _running_scanners_lock:
            _running_scanners[scan_id] = scanner
            logger.debug(f"Scan {scan_id}: Registered in running scanners registry")

        # Create progress callback
        progress_callback = create_progress_callback(scan_id, session)

        # Execute scan with progress tracking
        logger.info(f"Scan {scan_id}: Running scanner...")
        start_time = datetime.utcnow()
        report, timestamp = scanner.scan(progress_callback=progress_callback)
        end_time = datetime.utcnow()

        scan_duration = (end_time - start_time).total_seconds()
        logger.info(f"Scan {scan_id}: Scanner completed in {scan_duration:.2f} seconds")

        # Generate output files (JSON, HTML, ZIP)
        logger.info(f"Scan {scan_id}: Generating output files...")
        output_paths = scanner.generate_outputs(report, timestamp)

        # Save results to database
        logger.info(f"Scan {scan_id}: Saving results to database...")
        scan_service = ScanService(session)
        scan_service._save_scan_to_db(report, scan_id, status='completed', output_paths=output_paths)

        # Evaluate alert rules
        logger.info(f"Scan {scan_id}: Evaluating alert rules...")
        try:
            alert_service = AlertService(session)
            alerts_triggered = alert_service.evaluate_alert_rules(scan_id)
            logger.info(f"Scan {scan_id}: {len(alerts_triggered)} alerts triggered")
        except Exception as e:
            # Don't fail the scan if alert evaluation fails
            logger.error(f"Scan {scan_id}: Alert evaluation failed: {str(e)}")
            logger.debug(f"Alert evaluation error details: {traceback.format_exc()}")

        logger.info(f"Scan {scan_id}: Completed successfully")

    except ScanCancelledError:
        # Scan was cancelled by user
        logger.info(f"Scan {scan_id}: Cancelled by user")
        _record_scan_outcome(
            session, scan_id, 'cancelled', 'Scan cancelled by user',
            include_duration=True,
        )

    except FileNotFoundError as e:
        # Config file not found
        error_msg = f"Configuration file not found: {str(e)}"
        logger.error(f"Scan {scan_id}: {error_msg}")
        _record_scan_outcome(session, scan_id, 'failed', error_msg)

    except Exception as e:
        # Any other error during scan execution
        error_msg = f"Scan execution failed: {str(e)}"
        logger.error(f"Scan {scan_id}: {error_msg}")
        logger.error(f"Scan {scan_id}: Traceback:\n{traceback.format_exc()}")
        _record_scan_outcome(session, scan_id, 'failed', error_msg)

    finally:
        # Unregister scanner from registry
        with _running_scanners_lock:
            if scan_id in _running_scanners:
                del _running_scanners[scan_id]
                logger.debug(f"Scan {scan_id}: Unregistered from running scanners registry")

        # Always close the session, then release the connection pool owned
        # by the engine created for this job
        try:
            session.close()
        finally:
            engine.dispose()
        logger.info(f"Scan {scan_id}: Background job completed, session closed")
|
|
|
|
|
|
def get_scan_status_from_db(scan_id: int, db_url: str) -> dict:
    """
    Helper function to get scan status directly from database.

    Useful for monitoring background jobs without needing Flask app context.
    A dedicated engine and session are created for the call and both are
    released before returning, so repeated polling does not accumulate
    open connection pools.

    Args:
        scan_id: Scan ID to check
        db_url: Database connection URL

    Returns:
        Dictionary with the keys 'scan_id', 'status', 'timestamp' (ISO 8601
        string or None), 'duration' and 'error_message'; None if no scan
        with that ID exists
    """
    engine = create_engine(db_url, echo=False)
    Session = sessionmaker(bind=engine)
    session = Session()

    try:
        scan = session.query(Scan).filter_by(id=scan_id).first()
        if not scan:
            return None

        return {
            'scan_id': scan.id,
            'status': scan.status,
            'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
            'duration': scan.duration,
            'error_message': scan.error_message
        }
    finally:
        # Close the session first, then dispose of the engine it was bound to
        try:
            session.close()
        finally:
            engine.dispose()
|