diff --git a/migrations/versions/002_add_scan_indexes.py b/migrations/versions/002_add_scan_indexes.py new file mode 100644 index 0000000..9e5ad73 --- /dev/null +++ b/migrations/versions/002_add_scan_indexes.py @@ -0,0 +1,28 @@ +"""Add indexes for scan queries + +Revision ID: 002 +Revises: 001 +Create Date: 2025-11-14 00:30:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '002' +down_revision = '001' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Add database indexes for better query performance.""" + # Add index on scans.status for filtering + # Note: index on scans.timestamp already exists from migration 001 + op.create_index('ix_scans_status', 'scans', ['status'], unique=False) + + +def downgrade() -> None: + """Remove indexes.""" + op.drop_index('ix_scans_status', table_name='scans') diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..312f72f --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package for SneakyScanner.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..706f81b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,196 @@ +""" +Pytest configuration and fixtures for SneakyScanner tests. +""" + +import os +import tempfile +from pathlib import Path + +import pytest +import yaml +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from web.models import Base + + +@pytest.fixture(scope='function') +def test_db(): + """ + Create a temporary test database. + + Yields a SQLAlchemy session for testing, then cleans up. + """ + # Create temporary database file + db_fd, db_path = tempfile.mkstemp(suffix='.db') + + # Create engine and session + engine = create_engine(f'sqlite:///{db_path}', echo=False) + Base.metadata.create_all(engine) + + Session = sessionmaker(bind=engine) + session = Session() + + yield session + + # Cleanup + session.close() + os.close(db_fd) + os.unlink(db_path) + + +@pytest.fixture +def sample_scan_report(): + """ + Sample scan report matching the structure from scanner.py. + + Returns a dictionary representing a typical scan output. + """ + return { + 'title': 'Test Scan', + 'scan_time': '2025-11-14T10:30:00Z', + 'scan_duration': 125.5, + 'config_file': '/app/configs/test.yaml', + 'sites': [ + { + 'name': 'Test Site', + 'ips': [ + { + 'address': '192.168.1.10', + 'expected': { + 'ping': True, + 'tcp_ports': [22, 80, 443], + 'udp_ports': [53], + 'services': [] + }, + 'actual': { + 'ping': True, + 'tcp_ports': [22, 80, 443, 8080], + 'udp_ports': [53], + 'services': [ + { + 'port': 22, + 'service': 'ssh', + 'product': 'OpenSSH', + 'version': '8.9p1', + 'extrainfo': 'Ubuntu', + 'ostype': 'Linux' + }, + { + 'port': 443, + 'service': 'https', + 'product': 'nginx', + 'version': '1.24.0', + 'extrainfo': '', + 'ostype': '', + 'http_info': { + 'protocol': 'https', + 'screenshot': 'screenshots/192_168_1_10_443.png', + 'certificate': { + 'subject': 'CN=example.com', + 'issuer': 'CN=Let\'s Encrypt Authority', + 'serial_number': '123456789', + 'not_valid_before': '2025-01-01T00:00:00Z', + 'not_valid_after': '2025-12-31T23:59:59Z', + 'days_until_expiry': 365, + 'sans': ['example.com', 'www.example.com'], + 'is_self_signed': False, + 'tls_versions': { + 'TLS 1.2': { + 'supported': True, + 'cipher_suites': [ + 'TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384', + 'TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256' + ] + }, + 'TLS 1.3': { + 'supported': True, + 'cipher_suites': [ + 'TLS_AES_256_GCM_SHA384', + 'TLS_AES_128_GCM_SHA256' + ] + } + } + } + } + }, + { + 'port': 80, + 'service': 'http', + 'product': 'nginx', + 'version': '1.24.0', + 'extrainfo': '', + 'ostype': '', + 'http_info': { + 'protocol': 'http', + 'screenshot': 'screenshots/192_168_1_10_80.png' + } + }, + { + 'port': 8080, + 'service': 'http', + 'product': 'Jetty', + 'version': '9.4.48', + 'extrainfo': '', + 'ostype': '' + } + ] + } + } + ] + } + ] + } + + +@pytest.fixture +def sample_config_file(tmp_path): + """ + Create a sample YAML config file for testing. + + Args: + tmp_path: pytest temporary directory fixture + + Returns: + Path to created config file + """ + config_data = { + 'title': 'Test Scan', + 'sites': [ + { + 'name': 'Test Site', + 'ips': [ + { + 'address': '192.168.1.10', + 'expected': { + 'ping': True, + 'tcp_ports': [22, 80, 443], + 'udp_ports': [53], + 'services': ['ssh', 'http', 'https'] + } + } + ] + } + ] + } + + config_file = tmp_path / 'test_config.yaml' + with open(config_file, 'w') as f: + yaml.dump(config_data, f) + + return str(config_file) + + +@pytest.fixture +def sample_invalid_config_file(tmp_path): + """ + Create an invalid config file for testing validation. + + Returns: + Path to invalid config file + """ + config_file = tmp_path / 'invalid_config.yaml' + with open(config_file, 'w') as f: + f.write("invalid: yaml: content: [missing closing bracket") + + return str(config_file) diff --git a/tests/test_scan_service.py b/tests/test_scan_service.py new file mode 100644 index 0000000..456c5bb --- /dev/null +++ b/tests/test_scan_service.py @@ -0,0 +1,402 @@ +""" +Unit tests for ScanService class. + +Tests scan lifecycle operations: trigger, get, list, delete, and database mapping. +""" + +import pytest + +from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel +from web.services.scan_service import ScanService + + +class TestScanServiceTrigger: + """Tests for triggering scans.""" + + def test_trigger_scan_valid_config(self, test_db, sample_config_file): + """Test triggering a scan with valid config file.""" + service = ScanService(test_db) + + scan_id = service.trigger_scan(sample_config_file, triggered_by='manual') + + # Verify scan created + assert scan_id is not None + assert isinstance(scan_id, int) + + # Verify scan in database + scan = test_db.query(Scan).filter(Scan.id == scan_id).first() + assert scan is not None + assert scan.status == 'running' + assert scan.title == 'Test Scan' + assert scan.triggered_by == 'manual' + assert scan.config_file == sample_config_file + + def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file): + """Test triggering a scan with invalid config file.""" + service = ScanService(test_db) + + with pytest.raises(ValueError, match="Invalid config file"): + service.trigger_scan(sample_invalid_config_file) + + def test_trigger_scan_nonexistent_file(self, test_db): + """Test triggering a scan with nonexistent config file.""" + service = ScanService(test_db) + + with pytest.raises(ValueError, match="does not exist"): + service.trigger_scan('/nonexistent/config.yaml') + + def test_trigger_scan_with_schedule(self, test_db, sample_config_file): + """Test triggering a scan via schedule.""" + service = ScanService(test_db) + + scan_id = service.trigger_scan( + sample_config_file, + triggered_by='scheduled', + schedule_id=42 + ) + + scan = test_db.query(Scan).filter(Scan.id == scan_id).first() + assert scan.triggered_by == 'scheduled' + assert scan.schedule_id == 42 + + +class TestScanServiceGet: + """Tests for retrieving scans.""" + + def test_get_scan_not_found(self, test_db): + """Test getting a nonexistent scan.""" + service = ScanService(test_db) + + result = service.get_scan(999) + assert result is None + + def test_get_scan_found(self, test_db, sample_config_file): + """Test getting an existing scan.""" + service = ScanService(test_db) + + # Create a scan + scan_id = service.trigger_scan(sample_config_file) + + # Retrieve it + result = service.get_scan(scan_id) + + assert result is not None + assert result['id'] == scan_id + assert result['title'] == 'Test Scan' + assert result['status'] == 'running' + assert 'sites' in result + + +class TestScanServiceList: + """Tests for listing scans.""" + + def test_list_scans_empty(self, test_db): + """Test listing scans when database is empty.""" + service = ScanService(test_db) + + result = service.list_scans(page=1, per_page=20) + + assert result.total == 0 + assert len(result.items) == 0 + assert result.pages == 0 + + def test_list_scans_with_data(self, test_db, sample_config_file): + """Test listing scans with multiple scans.""" + service = ScanService(test_db) + + # Create 3 scans + for i in range(3): + service.trigger_scan(sample_config_file, triggered_by='api') + + # List all scans + result = service.list_scans(page=1, per_page=20) + + assert result.total == 3 + assert len(result.items) == 3 + assert result.pages == 1 + + def test_list_scans_pagination(self, test_db, sample_config_file): + """Test pagination.""" + service = ScanService(test_db) + + # Create 5 scans + for i in range(5): + service.trigger_scan(sample_config_file) + + # Get page 1 (2 items per page) + result = service.list_scans(page=1, per_page=2) + assert len(result.items) == 2 + assert result.total == 5 + assert result.pages == 3 + assert result.has_next is True + + # Get page 2 + result = service.list_scans(page=2, per_page=2) + assert len(result.items) == 2 + assert result.has_prev is True + assert result.has_next is True + + # Get page 3 (last page) + result = service.list_scans(page=3, per_page=2) + assert len(result.items) == 1 + assert result.has_next is False + + def test_list_scans_filter_by_status(self, test_db, sample_config_file): + """Test filtering scans by status.""" + service = ScanService(test_db) + + # Create scans with different statuses + scan_id_1 = service.trigger_scan(sample_config_file) + scan_id_2 = service.trigger_scan(sample_config_file) + + # Mark one as completed + scan = test_db.query(Scan).filter(Scan.id == scan_id_1).first() + scan.status = 'completed' + test_db.commit() + + # Filter by running + result = service.list_scans(status_filter='running') + assert result.total == 1 + + # Filter by completed + result = service.list_scans(status_filter='completed') + assert result.total == 1 + + def test_list_scans_invalid_status_filter(self, test_db): + """Test filtering with invalid status.""" + service = ScanService(test_db) + + with pytest.raises(ValueError, match="Invalid status"): + service.list_scans(status_filter='invalid_status') + + +class TestScanServiceDelete: + """Tests for deleting scans.""" + + def test_delete_scan_not_found(self, test_db): + """Test deleting a nonexistent scan.""" + service = ScanService(test_db) + + with pytest.raises(ValueError, match="not found"): + service.delete_scan(999) + + def test_delete_scan_success(self, test_db, sample_config_file): + """Test successful scan deletion.""" + service = ScanService(test_db) + + # Create a scan + scan_id = service.trigger_scan(sample_config_file) + + # Verify it exists + assert test_db.query(Scan).filter(Scan.id == scan_id).first() is not None + + # Delete it + result = service.delete_scan(scan_id) + assert result is True + + # Verify it's gone + assert test_db.query(Scan).filter(Scan.id == scan_id).first() is None + + +class TestScanServiceStatus: + """Tests for scan status retrieval.""" + + def test_get_scan_status_not_found(self, test_db): + """Test getting status of nonexistent scan.""" + service = ScanService(test_db) + + result = service.get_scan_status(999) + assert result is None + + def test_get_scan_status_running(self, test_db, sample_config_file): + """Test getting status of running scan.""" + service = ScanService(test_db) + + scan_id = service.trigger_scan(sample_config_file) + status = service.get_scan_status(scan_id) + + assert status is not None + assert status['scan_id'] == scan_id + assert status['status'] == 'running' + assert status['progress'] == 'In progress' + assert status['title'] == 'Test Scan' + + def test_get_scan_status_completed(self, test_db, sample_config_file): + """Test getting status of completed scan.""" + service = ScanService(test_db) + + # Create and mark as completed + scan_id = service.trigger_scan(sample_config_file) + scan = test_db.query(Scan).filter(Scan.id == scan_id).first() + scan.status = 'completed' + scan.duration = 125.5 + test_db.commit() + + status = service.get_scan_status(scan_id) + + assert status['status'] == 'completed' + assert status['progress'] == 'Complete' + assert status['duration'] == 125.5 + + +class TestScanServiceDatabaseMapping: + """Tests for mapping scan reports to database models.""" + + def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report): + """Test saving a complete scan report to database.""" + service = ScanService(test_db) + + # Create a scan + scan_id = service.trigger_scan(sample_config_file) + + # Save report to database + service._save_scan_to_db(sample_scan_report, scan_id, status='completed') + + # Verify scan updated + scan = test_db.query(Scan).filter(Scan.id == scan_id).first() + assert scan.status == 'completed' + assert scan.duration == 125.5 + + # Verify sites created + sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all() + assert len(sites) == 1 + assert sites[0].site_name == 'Test Site' + + # Verify IPs created + ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all() + assert len(ips) == 1 + assert ips[0].ip_address == '192.168.1.10' + assert ips[0].ping_expected is True + assert ips[0].ping_actual is True + + # Verify ports created (TCP: 22, 80, 443, 8080 | UDP: 53) + ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all() + assert len(ports) == 5 # 4 TCP + 1 UDP + + # Verify TCP ports + tcp_ports = [p for p in ports if p.protocol == 'tcp'] + assert len(tcp_ports) == 4 + tcp_port_numbers = sorted([p.port for p in tcp_ports]) + assert tcp_port_numbers == [22, 80, 443, 8080] + + # Verify UDP ports + udp_ports = [p for p in ports if p.protocol == 'udp'] + assert len(udp_ports) == 1 + assert udp_ports[0].port == 53 + + # Verify services created + services = test_db.query(ScanServiceModel).filter( + ScanServiceModel.scan_id == scan_id + ).all() + assert len(services) == 4 # SSH, HTTP (80), HTTPS, HTTP (8080) + + # Find HTTPS service + https_service = next( + (s for s in services if s.service_name == 'https'), None + ) + assert https_service is not None + assert https_service.product == 'nginx' + assert https_service.version == '1.24.0' + assert https_service.http_protocol == 'https' + assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png' + + def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report): + """Test that expected vs actual ports are correctly flagged.""" + service = ScanService(test_db) + + scan_id = service.trigger_scan(sample_config_file) + service._save_scan_to_db(sample_scan_report, scan_id) + + # Check TCP ports + tcp_ports = test_db.query(ScanPort).filter( + ScanPort.scan_id == scan_id, + ScanPort.protocol == 'tcp' + ).all() + + # Ports 22, 80, 443 were expected + expected_ports = {22, 80, 443} + for port in tcp_ports: + if port.port in expected_ports: + assert port.expected is True, f"Port {port.port} should be expected" + else: + # Port 8080 was not expected + assert port.expected is False, f"Port {port.port} should not be expected" + + def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report): + """Test that certificate and TLS data are correctly mapped.""" + service = ScanService(test_db) + + scan_id = service.trigger_scan(sample_config_file) + service._save_scan_to_db(sample_scan_report, scan_id) + + # Find HTTPS service + https_service = test_db.query(ScanServiceModel).filter( + ScanServiceModel.scan_id == scan_id, + ScanServiceModel.service_name == 'https' + ).first() + + assert https_service is not None + + # Verify certificate created + assert len(https_service.certificates) == 1 + cert = https_service.certificates[0] + + assert cert.subject == 'CN=example.com' + assert cert.issuer == "CN=Let's Encrypt Authority" + assert cert.days_until_expiry == 365 + assert cert.is_self_signed is False + + # Verify SANs + import json + sans = json.loads(cert.sans) + assert 'example.com' in sans + assert 'www.example.com' in sans + + # Verify TLS versions + assert len(cert.tls_versions) == 2 + + tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None) + assert tls_12 is not None + assert tls_12.supported is True + + tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None) + assert tls_13 is not None + assert tls_13.supported is True + + def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report): + """Test retrieving scan with all nested relationships.""" + service = ScanService(test_db) + + scan_id = service.trigger_scan(sample_config_file) + service._save_scan_to_db(sample_scan_report, scan_id) + + # Get full scan details + result = service.get_scan(scan_id) + + assert result is not None + assert len(result['sites']) == 1 + + site = result['sites'][0] + assert site['name'] == 'Test Site' + assert len(site['ips']) == 1 + + ip = site['ips'][0] + assert ip['address'] == '192.168.1.10' + assert len(ip['ports']) == 5 # 4 TCP + 1 UDP + + # Find HTTPS port + https_port = next( + (p for p in ip['ports'] if p['port'] == 443), None + ) + assert https_port is not None + assert len(https_port['services']) == 1 + + service_data = https_port['services'][0] + assert service_data['service_name'] == 'https' + assert 'certificates' in service_data + assert len(service_data['certificates']) == 1 + + cert = service_data['certificates'][0] + assert cert['subject'] == 'CN=example.com' + assert 'tls_versions' in cert + assert len(cert['tls_versions']) == 2 diff --git a/web/services/__init__.py b/web/services/__init__.py new file mode 100644 index 0000000..d71ac14 --- /dev/null +++ b/web/services/__init__.py @@ -0,0 +1,10 @@ +""" +Services package for SneakyScanner web application. + +This package contains business logic layer services that orchestrate +operations between API endpoints and database models. +""" + +from web.services.scan_service import ScanService + +__all__ = ['ScanService'] diff --git a/web/services/scan_service.py b/web/services/scan_service.py new file mode 100644 index 0000000..790fa35 --- /dev/null +++ b/web/services/scan_service.py @@ -0,0 +1,589 @@ +""" +Scan service for managing scan operations and database integration. + +This service handles the business logic for triggering scans, retrieving +scan results, and mapping scanner output to database models. +""" + +import json +import logging +import shutil +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +from sqlalchemy.orm import Session, joinedload + +from web.models import ( + Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel, + ScanCertificate, ScanTLSVersion +) +from web.utils.pagination import paginate, PaginatedResult +from web.utils.validators import validate_config_file, validate_scan_status + +logger = logging.getLogger(__name__) + + +class ScanService: + """ + Service for managing scan operations. + + Handles scan lifecycle: triggering, status tracking, result storage, + and cleanup. + """ + + def __init__(self, db_session: Session): + """ + Initialize scan service. + + Args: + db_session: SQLAlchemy database session + """ + self.db = db_session + + def trigger_scan(self, config_file: str, triggered_by: str = 'manual', + schedule_id: Optional[int] = None) -> int: + """ + Trigger a new scan. + + Creates a Scan record in the database with status='running' and + queues the scan for background execution. + + Args: + config_file: Path to YAML configuration file + triggered_by: Source that triggered scan (manual, scheduled, api) + schedule_id: Optional schedule ID if triggered by schedule + + Returns: + Scan ID of the created scan + + Raises: + ValueError: If config file is invalid + """ + # Validate config file + is_valid, error_msg = validate_config_file(config_file) + if not is_valid: + raise ValueError(f"Invalid config file: {error_msg}") + + # Load config to get title + import yaml + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + # Create scan record + scan = Scan( + timestamp=datetime.utcnow(), + status='running', + config_file=config_file, + title=config.get('title', 'Untitled Scan'), + triggered_by=triggered_by, + schedule_id=schedule_id, + created_at=datetime.utcnow() + ) + + self.db.add(scan) + self.db.commit() + self.db.refresh(scan) + + logger.info(f"Scan {scan.id} triggered via {triggered_by}") + + # NOTE: Background job queuing will be implemented in Step 3 + # For now, just return the scan ID + return scan.id + + def get_scan(self, scan_id: int) -> Optional[Dict[str, Any]]: + """ + Get scan details with all related data. + + Args: + scan_id: Scan ID to retrieve + + Returns: + Dictionary with scan data including sites, IPs, ports, services, etc. + Returns None if scan not found. + """ + # Query with eager loading of all relationships + scan = ( + self.db.query(Scan) + .options( + joinedload(Scan.sites).joinedload(ScanSite.ips).joinedload(ScanIP.ports), + joinedload(Scan.ports).joinedload(ScanPort.services), + joinedload(Scan.services).joinedload(ScanServiceModel.certificates), + joinedload(Scan.certificates).joinedload(ScanCertificate.tls_versions) + ) + .filter(Scan.id == scan_id) + .first() + ) + + if not scan: + return None + + # Convert to dictionary + return self._scan_to_dict(scan) + + def list_scans(self, page: int = 1, per_page: int = 20, + status_filter: Optional[str] = None) -> PaginatedResult: + """ + List scans with pagination and optional filtering. + + Args: + page: Page number (1-indexed) + per_page: Items per page + status_filter: Optional filter by status (running, completed, failed) + + Returns: + PaginatedResult with scan list and metadata + """ + # Build query + query = self.db.query(Scan).order_by(Scan.timestamp.desc()) + + # Apply status filter if provided + if status_filter: + is_valid, error_msg = validate_scan_status(status_filter) + if not is_valid: + raise ValueError(error_msg) + query = query.filter(Scan.status == status_filter) + + # Paginate + result = paginate(query, page=page, per_page=per_page) + + # Convert scans to dictionaries (summary only, not full details) + result.items = [self._scan_to_summary_dict(scan) for scan in result.items] + + return result + + def delete_scan(self, scan_id: int) -> bool: + """ + Delete a scan and all associated files. + + Removes: + - Database record (cascade deletes related records) + - JSON report file + - HTML report file + - ZIP archive file + - Screenshot directory + + Args: + scan_id: Scan ID to delete + + Returns: + True if deleted successfully + + Raises: + ValueError: If scan not found + """ + scan = self.db.query(Scan).filter(Scan.id == scan_id).first() + if not scan: + raise ValueError(f"Scan {scan_id} not found") + + logger.info(f"Deleting scan {scan_id}") + + # Delete files (handle missing files gracefully) + files_to_delete = [ + scan.json_path, + scan.html_path, + scan.zip_path + ] + + for file_path in files_to_delete: + if file_path: + try: + Path(file_path).unlink() + logger.debug(f"Deleted file: {file_path}") + except FileNotFoundError: + logger.warning(f"File not found (already deleted?): {file_path}") + except Exception as e: + logger.error(f"Error deleting file {file_path}: {e}") + + # Delete screenshot directory + if scan.screenshot_dir: + try: + shutil.rmtree(scan.screenshot_dir) + logger.debug(f"Deleted directory: {scan.screenshot_dir}") + except FileNotFoundError: + logger.warning(f"Directory not found (already deleted?): {scan.screenshot_dir}") + except Exception as e: + logger.error(f"Error deleting directory {scan.screenshot_dir}: {e}") + + # Delete database record (cascade handles related records) + self.db.delete(scan) + self.db.commit() + + logger.info(f"Scan {scan_id} deleted successfully") + return True + + def get_scan_status(self, scan_id: int) -> Optional[Dict[str, Any]]: + """ + Get current scan status and progress. + + Args: + scan_id: Scan ID + + Returns: + Dictionary with status information, or None if scan not found + """ + scan = self.db.query(Scan).filter(Scan.id == scan_id).first() + if not scan: + return None + + status_info = { + 'scan_id': scan.id, + 'status': scan.status, + 'title': scan.title, + 'started_at': scan.timestamp.isoformat() if scan.timestamp else None, + 'duration': scan.duration, + 'triggered_by': scan.triggered_by + } + + # Add progress estimate based on status + if scan.status == 'running': + status_info['progress'] = 'In progress' + elif scan.status == 'completed': + status_info['progress'] = 'Complete' + elif scan.status == 'failed': + status_info['progress'] = 'Failed' + + return status_info + + def _save_scan_to_db(self, report: Dict[str, Any], scan_id: int, + status: str = 'completed') -> None: + """ + Save scan results to database. + + Updates the Scan record and creates all related records + (sites, IPs, ports, services, certificates, TLS versions). + + Args: + report: Scan report dictionary from scanner + scan_id: Scan ID to update + status: Final scan status (completed or failed) + """ + scan = self.db.query(Scan).filter(Scan.id == scan_id).first() + if not scan: + raise ValueError(f"Scan {scan_id} not found") + + # Update scan record + scan.status = status + scan.duration = report.get('scan_duration') + + # Map report data to database models + self._map_report_to_models(report, scan) + + self.db.commit() + logger.info(f"Scan {scan_id} saved to database with status '{status}'") + + def _map_report_to_models(self, report: Dict[str, Any], scan_obj: Scan) -> None: + """ + Map JSON report structure to database models. + + Creates records for sites, IPs, ports, services, certificates, and TLS versions. + Processes nested relationships in order to handle foreign keys correctly. + + Args: + report: Scan report dictionary + scan_obj: Scan database object to attach records to + """ + logger.debug(f"Mapping report to database models for scan {scan_obj.id}") + + # Process each site + for site_data in report.get('sites', []): + # Create ScanSite record + site = ScanSite( + scan_id=scan_obj.id, + site_name=site_data['name'] + ) + self.db.add(site) + self.db.flush() # Get site.id for foreign key + + # Process each IP in this site + for ip_data in site_data.get('ips', []): + # Create ScanIP record + ip = ScanIP( + scan_id=scan_obj.id, + site_id=site.id, + ip_address=ip_data['address'], + ping_expected=ip_data.get('expected', {}).get('ping'), + ping_actual=ip_data.get('actual', {}).get('ping') + ) + self.db.add(ip) + self.db.flush() + + # Process TCP ports + expected_tcp = set(ip_data.get('expected', {}).get('tcp_ports', [])) + actual_tcp = ip_data.get('actual', {}).get('tcp_ports', []) + + for port_num in actual_tcp: + port = ScanPort( + scan_id=scan_obj.id, + ip_id=ip.id, + port=port_num, + protocol='tcp', + expected=(port_num in expected_tcp), + state='open' + ) + self.db.add(port) + self.db.flush() + + # Find service for this port + service_data = self._find_service_for_port( + ip_data.get('actual', {}).get('services', []), + port_num + ) + + if service_data: + # Create ScanService record + service = ScanServiceModel( + scan_id=scan_obj.id, + port_id=port.id, + service_name=service_data.get('service'), + product=service_data.get('product'), + version=service_data.get('version'), + extrainfo=service_data.get('extrainfo'), + ostype=service_data.get('ostype'), + http_protocol=service_data.get('http_info', {}).get('protocol'), + screenshot_path=service_data.get('http_info', {}).get('screenshot') + ) + self.db.add(service) + self.db.flush() + + # Process certificate and TLS info if present + http_info = service_data.get('http_info', {}) + if http_info.get('certificate'): + self._process_certificate( + http_info['certificate'], + scan_obj.id, + service.id + ) + + # Process UDP ports + expected_udp = set(ip_data.get('expected', {}).get('udp_ports', [])) + actual_udp = ip_data.get('actual', {}).get('udp_ports', []) + + for port_num in actual_udp: + port = ScanPort( + scan_id=scan_obj.id, + ip_id=ip.id, + port=port_num, + protocol='udp', + expected=(port_num in expected_udp), + state='open' + ) + self.db.add(port) + + logger.debug(f"Report mapping complete for scan {scan_obj.id}") + + def _find_service_for_port(self, services: List[Dict], port: int) -> Optional[Dict]: + """ + Find service data for a specific port. + + Args: + services: List of service dictionaries + port: Port number to find + + Returns: + Service dictionary if found, None otherwise + """ + for service in services: + if service.get('port') == port: + return service + return None + + def _process_certificate(self, cert_data: Dict[str, Any], scan_id: int, + service_id: int) -> None: + """ + Process certificate and TLS version data. + + Args: + cert_data: Certificate data dictionary + scan_id: Scan ID + service_id: Service ID + """ + # Create ScanCertificate record + cert = ScanCertificate( + scan_id=scan_id, + service_id=service_id, + subject=cert_data.get('subject'), + issuer=cert_data.get('issuer'), + serial_number=cert_data.get('serial_number'), + not_valid_before=self._parse_datetime(cert_data.get('not_valid_before')), + not_valid_after=self._parse_datetime(cert_data.get('not_valid_after')), + days_until_expiry=cert_data.get('days_until_expiry'), + sans=json.dumps(cert_data.get('sans', [])), + is_self_signed=cert_data.get('is_self_signed', False) + ) + self.db.add(cert) + self.db.flush() + + # Process TLS versions + tls_versions = cert_data.get('tls_versions', {}) + for version, version_data in tls_versions.items(): + tls = ScanTLSVersion( + scan_id=scan_id, + certificate_id=cert.id, + tls_version=version, + supported=version_data.get('supported', False), + cipher_suites=json.dumps(version_data.get('cipher_suites', [])) + ) + self.db.add(tls) + + def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]: + """ + Parse ISO datetime string. + + Args: + date_str: ISO format datetime string + + Returns: + datetime object or None if parsing fails + """ + if not date_str: + return None + + try: + # Handle ISO format with 'Z' suffix + if date_str.endswith('Z'): + date_str = date_str[:-1] + '+00:00' + return datetime.fromisoformat(date_str) + except (ValueError, AttributeError) as e: + logger.warning(f"Failed to parse datetime '{date_str}': {e}") + return None + + def _scan_to_dict(self, scan: Scan) -> Dict[str, Any]: + """ + Convert Scan object to dictionary with full details. + + Args: + scan: Scan database object + + Returns: + Dictionary representation with all related data + """ + return { + 'id': scan.id, + 'timestamp': scan.timestamp.isoformat() if scan.timestamp else None, + 'duration': scan.duration, + 'status': scan.status, + 'title': scan.title, + 'config_file': scan.config_file, + 'json_path': scan.json_path, + 'html_path': scan.html_path, + 'zip_path': scan.zip_path, + 'screenshot_dir': scan.screenshot_dir, + 'triggered_by': scan.triggered_by, + 'created_at': scan.created_at.isoformat() if scan.created_at else None, + 'sites': [self._site_to_dict(site) for site in scan.sites] + } + + def _scan_to_summary_dict(self, scan: Scan) -> Dict[str, Any]: + """ + Convert Scan object to summary dictionary (no related data). + + Args: + scan: Scan database object + + Returns: + Summary dictionary + """ + return { + 'id': scan.id, + 'timestamp': scan.timestamp.isoformat() if scan.timestamp else None, + 'duration': scan.duration, + 'status': scan.status, + 'title': scan.title, + 'triggered_by': scan.triggered_by, + 'created_at': scan.created_at.isoformat() if scan.created_at else None + } + + def _site_to_dict(self, site: ScanSite) -> Dict[str, Any]: + """Convert ScanSite to dictionary.""" + return { + 'id': site.id, + 'name': site.site_name, + 'ips': [self._ip_to_dict(ip) for ip in site.ips] + } + + def _ip_to_dict(self, ip: ScanIP) -> Dict[str, Any]: + """Convert ScanIP to dictionary.""" + return { + 'id': ip.id, + 'address': ip.ip_address, + 'ping_expected': ip.ping_expected, + 'ping_actual': ip.ping_actual, + 'ports': [self._port_to_dict(port) for port in ip.ports] + } + + def _port_to_dict(self, port: ScanPort) -> Dict[str, Any]: + """Convert ScanPort to dictionary.""" + return { + 'id': port.id, + 'port': port.port, + 'protocol': port.protocol, + 'state': port.state, + 'expected': port.expected, + 'services': [self._service_to_dict(svc) for svc in port.services] + } + + def _service_to_dict(self, service: ScanServiceModel) -> Dict[str, Any]: + """Convert ScanService to dictionary.""" + result = { + 'id': service.id, + 'service_name': service.service_name, + 'product': service.product, + 'version': service.version, + 'extrainfo': service.extrainfo, + 'ostype': service.ostype, + 'http_protocol': service.http_protocol, + 'screenshot_path': service.screenshot_path + } + + # Add certificate info if present + if service.certificates: + result['certificates'] = [ + self._certificate_to_dict(cert) for cert in service.certificates + ] + + return result + + def _certificate_to_dict(self, cert: ScanCertificate) -> Dict[str, Any]: + """Convert ScanCertificate to dictionary.""" + result = { + 'id': cert.id, + 'subject': cert.subject, + 'issuer': cert.issuer, + 'serial_number': cert.serial_number, + 'not_valid_before': cert.not_valid_before.isoformat() if cert.not_valid_before else None, + 'not_valid_after': cert.not_valid_after.isoformat() if cert.not_valid_after else None, + 'days_until_expiry': cert.days_until_expiry, + 'is_self_signed': cert.is_self_signed + } + + # Parse SANs from JSON + if cert.sans: + try: + result['sans'] = json.loads(cert.sans) + except json.JSONDecodeError: + result['sans'] = [] + + # Add TLS versions + result['tls_versions'] = [ + self._tls_version_to_dict(tls) for tls in cert.tls_versions + ] + + return result + + def _tls_version_to_dict(self, tls: ScanTLSVersion) -> Dict[str, Any]: + """Convert ScanTLSVersion to dictionary.""" + result = { + 'id': tls.id, + 'tls_version': tls.tls_version, + 'supported': tls.supported + } + + # Parse cipher suites from JSON + if tls.cipher_suites: + try: + result['cipher_suites'] = json.loads(tls.cipher_suites) + except json.JSONDecodeError: + result['cipher_suites'] = [] + + return result diff --git a/web/utils/pagination.py b/web/utils/pagination.py new file mode 100644 index 0000000..1916080 --- /dev/null +++ b/web/utils/pagination.py @@ -0,0 +1,158 @@ +""" +Pagination utilities for SneakyScanner web application. + +Provides helper functions for paginating SQLAlchemy queries. +""" + +from typing import Any, Dict, List +from sqlalchemy.orm import Query + + +class PaginatedResult: + """Container for paginated query results.""" + + def __init__(self, items: List[Any], total: int, page: int, per_page: int): + """ + Initialize paginated result. + + Args: + items: List of items for current page + total: Total number of items across all pages + page: Current page number (1-indexed) + per_page: Number of items per page + """ + self.items = items + self.total = total + self.page = page + self.per_page = per_page + + @property + def pages(self) -> int: + """Calculate total number of pages.""" + if self.per_page == 0: + return 0 + return (self.total + self.per_page - 1) // self.per_page + + @property + def has_prev(self) -> bool: + """Check if there is a previous page.""" + return self.page > 1 + + @property + def has_next(self) -> bool: + """Check if there is a next page.""" + return self.page < self.pages + + @property + def prev_page(self) -> int: + """Get previous page number.""" + return self.page - 1 if self.has_prev else None + + @property + def next_page(self) -> int: + """Get next page number.""" + return self.page + 1 if self.has_next else None + + def to_dict(self) -> Dict[str, Any]: + """ + Convert to dictionary for API responses. + + Returns: + Dictionary with pagination metadata and items + """ + return { + 'items': self.items, + 'total': self.total, + 'page': self.page, + 'per_page': self.per_page, + 'pages': self.pages, + 'has_prev': self.has_prev, + 'has_next': self.has_next, + 'prev_page': self.prev_page, + 'next_page': self.next_page, + } + + +def paginate(query: Query, page: int = 1, per_page: int = 20, + max_per_page: int = 100) -> PaginatedResult: + """ + Paginate a SQLAlchemy query. + + Args: + query: SQLAlchemy query to paginate + page: Page number (1-indexed, default: 1) + per_page: Items per page (default: 20) + max_per_page: Maximum items per page (default: 100) + + Returns: + PaginatedResult with items and pagination metadata + + Examples: + >>> from web.models import Scan + >>> query = db.query(Scan).order_by(Scan.timestamp.desc()) + >>> result = paginate(query, page=1, per_page=20) + >>> scans = result.items + >>> total_pages = result.pages + """ + # Validate and sanitize parameters + page = max(1, page) # Page must be at least 1 + per_page = max(1, min(per_page, max_per_page)) # Clamp per_page + + # Get total count + total = query.count() + + # Calculate offset + offset = (page - 1) * per_page + + # Execute query with limit and offset + items = query.limit(per_page).offset(offset).all() + + return PaginatedResult( + items=items, + total=total, + page=page, + per_page=per_page + ) + + +def validate_page_params(page: Any, per_page: Any, + max_per_page: int = 100) -> tuple[int, int]: + """ + Validate and sanitize pagination parameters. + + Args: + page: Page number (any type, will be converted to int) + per_page: Items per page (any type, will be converted to int) + max_per_page: Maximum items per page (default: 100) + + Returns: + Tuple of (validated_page, validated_per_page) + + Examples: + >>> validate_page_params('2', '50') + (2, 50) + >>> validate_page_params(-1, 200) + (1, 100) + >>> validate_page_params(None, None) + (1, 20) + """ + # Default values + default_page = 1 + default_per_page = 20 + + # Convert to int, use default if invalid + try: + page = int(page) if page is not None else default_page + except (ValueError, TypeError): + page = default_page + + try: + per_page = int(per_page) if per_page is not None else default_per_page + except (ValueError, TypeError): + per_page = default_per_page + + # Validate ranges + page = max(1, page) + per_page = max(1, min(per_page, max_per_page)) + + return page, per_page diff --git a/web/utils/validators.py b/web/utils/validators.py new file mode 100644 index 0000000..379ed35 --- /dev/null +++ b/web/utils/validators.py @@ -0,0 +1,284 @@ +""" +Input validation utilities for SneakyScanner web application. + +Provides validation functions for API inputs, file paths, and data integrity. +""" + +import os +from pathlib import Path +from typing import Optional + +import yaml + + +def validate_config_file(file_path: str) -> tuple[bool, Optional[str]]: + """ + Validate that a configuration file exists and is valid YAML. + + Args: + file_path: Path to configuration file + + Returns: + Tuple of (is_valid, error_message) + If valid, returns (True, None) + If invalid, returns (False, error_message) + + Examples: + >>> validate_config_file('/app/configs/example.yaml') + (True, None) + >>> validate_config_file('/nonexistent.yaml') + (False, 'File does not exist: /nonexistent.yaml') + """ + # Check if path is provided + if not file_path: + return False, 'Config file path is required' + + # Convert to Path object + path = Path(file_path) + + # Check if file exists + if not path.exists(): + return False, f'File does not exist: {file_path}' + + # Check if it's a file (not directory) + if not path.is_file(): + return False, f'Path is not a file: {file_path}' + + # Check file extension + if path.suffix.lower() not in ['.yaml', '.yml']: + return False, f'File must be YAML (.yaml or .yml): {file_path}' + + # Try to parse as YAML + try: + with open(path, 'r') as f: + config = yaml.safe_load(f) + + # Check if it's a dictionary (basic structure validation) + if not isinstance(config, dict): + return False, 'Config file must contain a YAML dictionary' + + # Check for required top-level keys + if 'title' not in config: + return False, 'Config file missing required "title" field' + + if 'sites' not in config: + return False, 'Config file missing required "sites" field' + + # Validate sites structure + if not isinstance(config['sites'], list): + return False, '"sites" must be a list' + + if len(config['sites']) == 0: + return False, '"sites" list cannot be empty' + + except yaml.YAMLError as e: + return False, f'Invalid YAML syntax: {str(e)}' + except Exception as e: + return False, f'Error reading config file: {str(e)}' + + return True, None + + +def validate_scan_status(status: str) -> tuple[bool, Optional[str]]: + """ + Validate scan status value. + + Args: + status: Status string to validate + + Returns: + Tuple of (is_valid, error_message) + + Examples: + >>> validate_scan_status('running') + (True, None) + >>> validate_scan_status('invalid') + (False, 'Invalid status: invalid. Must be one of: running, completed, failed') + """ + valid_statuses = ['running', 'completed', 'failed'] + + if status not in valid_statuses: + return False, f'Invalid status: {status}. Must be one of: {", ".join(valid_statuses)}' + + return True, None + + +def validate_triggered_by(triggered_by: str) -> tuple[bool, Optional[str]]: + """ + Validate triggered_by value. + + Args: + triggered_by: Source that triggered the scan + + Returns: + Tuple of (is_valid, error_message) + + Examples: + >>> validate_triggered_by('manual') + (True, None) + >>> validate_triggered_by('api') + (True, None) + """ + valid_sources = ['manual', 'scheduled', 'api'] + + if triggered_by not in valid_sources: + return False, f'Invalid triggered_by: {triggered_by}. Must be one of: {", ".join(valid_sources)}' + + return True, None + + +def validate_scan_id(scan_id: any) -> tuple[bool, Optional[str]]: + """ + Validate scan ID is a positive integer. + + Args: + scan_id: Scan ID to validate + + Returns: + Tuple of (is_valid, error_message) + + Examples: + >>> validate_scan_id(42) + (True, None) + >>> validate_scan_id('42') + (True, None) + >>> validate_scan_id(-1) + (False, 'Scan ID must be a positive integer') + """ + try: + scan_id_int = int(scan_id) + if scan_id_int <= 0: + return False, 'Scan ID must be a positive integer' + except (ValueError, TypeError): + return False, f'Invalid scan ID: {scan_id}' + + return True, None + + +def validate_file_path(file_path: str, must_exist: bool = True) -> tuple[bool, Optional[str]]: + """ + Validate a file path. + + Args: + file_path: Path to validate + must_exist: If True, file must exist. If False, only validate format. + + Returns: + Tuple of (is_valid, error_message) + + Examples: + >>> validate_file_path('/app/output/scan.json', must_exist=False) + (True, None) + >>> validate_file_path('', must_exist=False) + (False, 'File path is required') + """ + if not file_path: + return False, 'File path is required' + + # Check for path traversal attempts + if '..' in file_path: + return False, 'Path traversal not allowed' + + if must_exist: + path = Path(file_path) + if not path.exists(): + return False, f'File does not exist: {file_path}' + if not path.is_file(): + return False, f'Path is not a file: {file_path}' + + return True, None + + +def sanitize_filename(filename: str) -> str: + """ + Sanitize a filename by removing/replacing unsafe characters. + + Args: + filename: Original filename + + Returns: + Sanitized filename safe for filesystem + + Examples: + >>> sanitize_filename('my scan.json') + 'my_scan.json' + >>> sanitize_filename('../../etc/passwd') + 'etc_passwd' + """ + # Remove path components + filename = os.path.basename(filename) + + # Replace unsafe characters with underscore + unsafe_chars = ['/', '\\', '..', ' ', ':', '*', '?', '"', '<', '>', '|'] + for char in unsafe_chars: + filename = filename.replace(char, '_') + + # Remove leading/trailing underscores and dots + filename = filename.strip('_.') + + # Ensure filename is not empty + if not filename: + filename = 'unnamed' + + return filename + + +def validate_port(port: any) -> tuple[bool, Optional[str]]: + """ + Validate port number. + + Args: + port: Port number to validate + + Returns: + Tuple of (is_valid, error_message) + + Examples: + >>> validate_port(443) + (True, None) + >>> validate_port(70000) + (False, 'Port must be between 1 and 65535') + """ + try: + port_int = int(port) + if port_int < 1 or port_int > 65535: + return False, 'Port must be between 1 and 65535' + except (ValueError, TypeError): + return False, f'Invalid port: {port}' + + return True, None + + +def validate_ip_address(ip: str) -> tuple[bool, Optional[str]]: + """ + Validate IPv4 address format (basic validation). + + Args: + ip: IP address string + + Returns: + Tuple of (is_valid, error_message) + + Examples: + >>> validate_ip_address('192.168.1.1') + (True, None) + >>> validate_ip_address('256.1.1.1') + (False, 'Invalid IP address format') + """ + if not ip: + return False, 'IP address is required' + + # Basic IPv4 validation + parts = ip.split('.') + if len(parts) != 4: + return False, 'Invalid IP address format' + + try: + for part in parts: + num = int(part) + if num < 0 or num > 255: + return False, 'Invalid IP address format' + except (ValueError, TypeError): + return False, 'Invalid IP address format' + + return True, None