""" Unit tests for ScanService class. Tests scan lifecycle operations: trigger, get, list, delete, and database mapping. """ import pytest from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel from web.services.scan_service import ScanService class TestScanServiceTrigger: """Tests for triggering scans.""" def test_trigger_scan_valid_config(self, test_db, sample_config_file): """Test triggering a scan with valid config file.""" service = ScanService(test_db) scan_id = service.trigger_scan(sample_config_file, triggered_by='manual') # Verify scan created assert scan_id is not None assert isinstance(scan_id, int) # Verify scan in database scan = test_db.query(Scan).filter(Scan.id == scan_id).first() assert scan is not None assert scan.status == 'running' assert scan.title == 'Test Scan' assert scan.triggered_by == 'manual' assert scan.config_file == sample_config_file def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file): """Test triggering a scan with invalid config file.""" service = ScanService(test_db) with pytest.raises(ValueError, match="Invalid config file"): service.trigger_scan(sample_invalid_config_file) def test_trigger_scan_nonexistent_file(self, test_db): """Test triggering a scan with nonexistent config file.""" service = ScanService(test_db) with pytest.raises(ValueError, match="does not exist"): service.trigger_scan('/nonexistent/config.yaml') def test_trigger_scan_with_schedule(self, test_db, sample_config_file): """Test triggering a scan via schedule.""" service = ScanService(test_db) scan_id = service.trigger_scan( sample_config_file, triggered_by='scheduled', schedule_id=42 ) scan = test_db.query(Scan).filter(Scan.id == scan_id).first() assert scan.triggered_by == 'scheduled' assert scan.schedule_id == 42 class TestScanServiceGet: """Tests for retrieving scans.""" def test_get_scan_not_found(self, test_db): """Test getting a nonexistent scan.""" service = ScanService(test_db) result = service.get_scan(999) assert result is None def test_get_scan_found(self, test_db, sample_config_file): """Test getting an existing scan.""" service = ScanService(test_db) # Create a scan scan_id = service.trigger_scan(sample_config_file) # Retrieve it result = service.get_scan(scan_id) assert result is not None assert result['id'] == scan_id assert result['title'] == 'Test Scan' assert result['status'] == 'running' assert 'sites' in result class TestScanServiceList: """Tests for listing scans.""" def test_list_scans_empty(self, test_db): """Test listing scans when database is empty.""" service = ScanService(test_db) result = service.list_scans(page=1, per_page=20) assert result.total == 0 assert len(result.items) == 0 assert result.pages == 0 def test_list_scans_with_data(self, test_db, sample_config_file): """Test listing scans with multiple scans.""" service = ScanService(test_db) # Create 3 scans for i in range(3): service.trigger_scan(sample_config_file, triggered_by='api') # List all scans result = service.list_scans(page=1, per_page=20) assert result.total == 3 assert len(result.items) == 3 assert result.pages == 1 def test_list_scans_pagination(self, test_db, sample_config_file): """Test pagination.""" service = ScanService(test_db) # Create 5 scans for i in range(5): service.trigger_scan(sample_config_file) # Get page 1 (2 items per page) result = service.list_scans(page=1, per_page=2) assert len(result.items) == 2 assert result.total == 5 assert result.pages == 3 assert result.has_next is True # Get page 2 result = service.list_scans(page=2, per_page=2) assert len(result.items) == 2 assert result.has_prev is True assert result.has_next is True # Get page 3 (last page) result = service.list_scans(page=3, per_page=2) assert len(result.items) == 1 assert result.has_next is False def test_list_scans_filter_by_status(self, test_db, sample_config_file): """Test filtering scans by status.""" service = ScanService(test_db) # Create scans with different statuses scan_id_1 = service.trigger_scan(sample_config_file) scan_id_2 = service.trigger_scan(sample_config_file) # Mark one as completed scan = test_db.query(Scan).filter(Scan.id == scan_id_1).first() scan.status = 'completed' test_db.commit() # Filter by running result = service.list_scans(status_filter='running') assert result.total == 1 # Filter by completed result = service.list_scans(status_filter='completed') assert result.total == 1 def test_list_scans_invalid_status_filter(self, test_db): """Test filtering with invalid status.""" service = ScanService(test_db) with pytest.raises(ValueError, match="Invalid status"): service.list_scans(status_filter='invalid_status') class TestScanServiceDelete: """Tests for deleting scans.""" def test_delete_scan_not_found(self, test_db): """Test deleting a nonexistent scan.""" service = ScanService(test_db) with pytest.raises(ValueError, match="not found"): service.delete_scan(999) def test_delete_scan_success(self, test_db, sample_config_file): """Test successful scan deletion.""" service = ScanService(test_db) # Create a scan scan_id = service.trigger_scan(sample_config_file) # Verify it exists assert test_db.query(Scan).filter(Scan.id == scan_id).first() is not None # Delete it result = service.delete_scan(scan_id) assert result is True # Verify it's gone assert test_db.query(Scan).filter(Scan.id == scan_id).first() is None class TestScanServiceStatus: """Tests for scan status retrieval.""" def test_get_scan_status_not_found(self, test_db): """Test getting status of nonexistent scan.""" service = ScanService(test_db) result = service.get_scan_status(999) assert result is None def test_get_scan_status_running(self, test_db, sample_config_file): """Test getting status of running scan.""" service = ScanService(test_db) scan_id = service.trigger_scan(sample_config_file) status = service.get_scan_status(scan_id) assert status is not None assert status['scan_id'] == scan_id assert status['status'] == 'running' assert status['progress'] == 'In progress' assert status['title'] == 'Test Scan' def test_get_scan_status_completed(self, test_db, sample_config_file): """Test getting status of completed scan.""" service = ScanService(test_db) # Create and mark as completed scan_id = service.trigger_scan(sample_config_file) scan = test_db.query(Scan).filter(Scan.id == scan_id).first() scan.status = 'completed' scan.duration = 125.5 test_db.commit() status = service.get_scan_status(scan_id) assert status['status'] == 'completed' assert status['progress'] == 'Complete' assert status['duration'] == 125.5 class TestScanServiceDatabaseMapping: """Tests for mapping scan reports to database models.""" def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report): """Test saving a complete scan report to database.""" service = ScanService(test_db) # Create a scan scan_id = service.trigger_scan(sample_config_file) # Save report to database service._save_scan_to_db(sample_scan_report, scan_id, status='completed') # Verify scan updated scan = test_db.query(Scan).filter(Scan.id == scan_id).first() assert scan.status == 'completed' assert scan.duration == 125.5 # Verify sites created sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all() assert len(sites) == 1 assert sites[0].site_name == 'Test Site' # Verify IPs created ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all() assert len(ips) == 1 assert ips[0].ip_address == '192.168.1.10' assert ips[0].ping_expected is True assert ips[0].ping_actual is True # Verify ports created (TCP: 22, 80, 443, 8080 | UDP: 53) ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all() assert len(ports) == 5 # 4 TCP + 1 UDP # Verify TCP ports tcp_ports = [p for p in ports if p.protocol == 'tcp'] assert len(tcp_ports) == 4 tcp_port_numbers = sorted([p.port for p in tcp_ports]) assert tcp_port_numbers == [22, 80, 443, 8080] # Verify UDP ports udp_ports = [p for p in ports if p.protocol == 'udp'] assert len(udp_ports) == 1 assert udp_ports[0].port == 53 # Verify services created services = test_db.query(ScanServiceModel).filter( ScanServiceModel.scan_id == scan_id ).all() assert len(services) == 4 # SSH, HTTP (80), HTTPS, HTTP (8080) # Find HTTPS service https_service = next( (s for s in services if s.service_name == 'https'), None ) assert https_service is not None assert https_service.product == 'nginx' assert https_service.version == '1.24.0' assert https_service.http_protocol == 'https' assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png' def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report): """Test that expected vs actual ports are correctly flagged.""" service = ScanService(test_db) scan_id = service.trigger_scan(sample_config_file) service._save_scan_to_db(sample_scan_report, scan_id) # Check TCP ports tcp_ports = test_db.query(ScanPort).filter( ScanPort.scan_id == scan_id, ScanPort.protocol == 'tcp' ).all() # Ports 22, 80, 443 were expected expected_ports = {22, 80, 443} for port in tcp_ports: if port.port in expected_ports: assert port.expected is True, f"Port {port.port} should be expected" else: # Port 8080 was not expected assert port.expected is False, f"Port {port.port} should not be expected" def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report): """Test that certificate and TLS data are correctly mapped.""" service = ScanService(test_db) scan_id = service.trigger_scan(sample_config_file) service._save_scan_to_db(sample_scan_report, scan_id) # Find HTTPS service https_service = test_db.query(ScanServiceModel).filter( ScanServiceModel.scan_id == scan_id, ScanServiceModel.service_name == 'https' ).first() assert https_service is not None # Verify certificate created assert len(https_service.certificates) == 1 cert = https_service.certificates[0] assert cert.subject == 'CN=example.com' assert cert.issuer == "CN=Let's Encrypt Authority" assert cert.days_until_expiry == 365 assert cert.is_self_signed is False # Verify SANs import json sans = json.loads(cert.sans) assert 'example.com' in sans assert 'www.example.com' in sans # Verify TLS versions assert len(cert.tls_versions) == 2 tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None) assert tls_12 is not None assert tls_12.supported is True tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None) assert tls_13 is not None assert tls_13.supported is True def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report): """Test retrieving scan with all nested relationships.""" service = ScanService(test_db) scan_id = service.trigger_scan(sample_config_file) service._save_scan_to_db(sample_scan_report, scan_id) # Get full scan details result = service.get_scan(scan_id) assert result is not None assert len(result['sites']) == 1 site = result['sites'][0] assert site['name'] == 'Test Site' assert len(site['ips']) == 1 ip = site['ips'][0] assert ip['address'] == '192.168.1.10' assert len(ip['ports']) == 5 # 4 TCP + 1 UDP # Find HTTPS port https_port = next( (p for p in ip['ports'] if p['port'] == 443), None ) assert https_port is not None assert len(https_port['services']) == 1 service_data = https_port['services'][0] assert service_data['service_name'] == 'https' assert 'certificates' in service_data assert len(service_data['certificates']) == 1 cert = service_data['certificates'][0] assert cert['subject'] == 'CN=example.com' assert 'tls_versions' in cert assert len(cert['tls_versions']) == 2