"""
Unit tests for scan comparison functionality.

Tests scan comparison logic including port, service, and certificate
comparisons, as well as drift score calculation.
"""

import pytest
from datetime import datetime

from web.models import Scan, ScanSite, ScanIP, ScanPort
from web.models import ScanService as ScanServiceModel, ScanCertificate
from web.services.scan_service import ScanService


def _create_completed_scan(test_db, sample_db_config, ports, service_port,
                           service_version):
    """Trigger a scan and populate it with one site, one IP, ports and a service.

    The scan is marked ``completed``. Every port is an open TCP port on
    192.168.1.100 under a site named 'Test Site'. A single 'http'/'nginx'
    service record is attached to the port numbered ``service_port``.

    Args:
        test_db: Database session used for all inserts.
        sample_db_config: Configuration passed to ``ScanService.trigger_scan``.
        ports: Sequence of ``(port_number, expected)`` pairs, inserted in order.
        service_port: Port number (must appear in ``ports``) that receives the
            service record.
        service_version: Version string stored on the service record.

    Returns:
        The id of the newly created scan.
    """
    service = ScanService(test_db)
    scan_id = service.trigger_scan(sample_db_config, triggered_by='manual')

    scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
    scan.status = 'completed'

    site = ScanSite(scan_id=scan.id, site_name='Test Site')
    test_db.add(site)
    test_db.flush()  # populate site.id for the IP's site_id

    ip = ScanIP(
        scan_id=scan.id,
        site_id=site.id,
        ip_address='192.168.1.100',
        ping_expected=True,
        ping_actual=True,
    )
    test_db.add(ip)
    test_db.flush()  # populate ip.id for the ports' ip_id

    ports_by_number = {}
    for number, expected in ports:
        port = ScanPort(
            scan_id=scan.id,
            ip_id=ip.id,
            port=number,
            protocol='tcp',
            state='open',
            expected=expected,
        )
        test_db.add(port)
        ports_by_number[number] = port
    test_db.flush()  # populate port ids for the service's port_id

    svc = ScanServiceModel(
        scan_id=scan.id,
        port_id=ports_by_number[service_port].id,
        service_name='http',
        product='nginx',
        version=service_version,
    )
    test_db.add(svc)
    test_db.commit()

    return scan_id


# These fixtures live at module level rather than inside a test class:
# pytest only exposes fixtures defined in a class to tests of that same
# class, and the API/chart test classes below also request them.
@pytest.fixture
def scan1_data(test_db, sample_db_config):
    """Baseline scan: ports 80 and 443 open, nginx 1.18.0 on port 80."""
    return _create_completed_scan(
        test_db,
        sample_db_config,
        ports=[(80, True), (443, True)],
        service_port=80,
        service_version='1.18.0',
    )


@pytest.fixture
def scan2_data(test_db, sample_db_config):
    """Drifted scan relative to ``scan1_data``.

    Port 80 is removed, 443 is kept, and unexpected port 8080 is added with
    nginx 1.20.0 on it. The service sits on a different port than scan1's,
    so it is not a like-for-like version bump of the scan1 service.
    """
    return _create_completed_scan(
        test_db,
        sample_db_config,
        ports=[(443, True), (8080, False)],
        service_port=8080,
        service_version='1.20.0',
    )


class TestScanComparison:
    """Tests for scan comparison methods."""

    def test_compare_scans_basic(self, test_db, scan1_data, scan2_data):
        """Test basic scan comparison."""
        service = ScanService(test_db)
        result = service.compare_scans(scan1_data, scan2_data)

        assert result is not None
        assert 'scan1' in result
        assert 'scan2' in result
        assert 'ports' in result
        assert 'services' in result
        assert 'certificates' in result
        assert 'drift_score' in result

        # Verify scan metadata
        assert result['scan1']['id'] == scan1_data
        assert result['scan2']['id'] == scan2_data

    def test_compare_scans_not_found(self, test_db):
        """Test comparison with nonexistent scan."""
        service = ScanService(test_db)
        result = service.compare_scans(999, 998)

        assert result is None

    def test_compare_ports(self, test_db, scan1_data, scan2_data):
        """Test port comparison logic."""
        service = ScanService(test_db)
        result = service.compare_scans(scan1_data, scan2_data)

        # Scan1 has ports 80, 443
        # Scan2 has ports 443, 8080
        # Expected: added=[8080], removed=[80], unchanged=[443]
        ports = result['ports']

        assert len(ports['added']) == 1
        assert len(ports['removed']) == 1
        assert len(ports['unchanged']) == 1

        # Check added port
        added_port = ports['added'][0]
        assert added_port['port'] == 8080

        # Check removed port
        removed_port = ports['removed'][0]
        assert removed_port['port'] == 80

        # Check unchanged port
        unchanged_port = ports['unchanged'][0]
        assert unchanged_port['port'] == 443

    def test_compare_services(self, test_db, scan1_data, scan2_data):
        """Test that service differences between scans are reported."""
        service = ScanService(test_db)
        result = service.compare_scans(scan1_data, scan2_data)

        services = result['services']

        # Scan1 has nginx 1.18.0 on port 80; scan2 has nginx 1.20.0 on
        # port 8080. Because they are on different ports they are expected
        # to show up as added/removed rather than changed. Either way, the
        # difference must be reported in at least one bucket. (The previous
        # `len(...) >= 0` checks were always true and verified nothing.)
        detected = (
            len(services['added'])
            + len(services['removed'])
            + len(services.get('changed', []))
        )
        assert detected > 0

    def test_drift_score_calculation(self, test_db, scan1_data, scan2_data):
        """Test drift score calculation."""
        service = ScanService(test_db)
        result = service.compare_scans(scan1_data, scan2_data)

        drift_score = result['drift_score']

        # Drift score should be between 0.0 and 1.0
        assert 0.0 <= drift_score <= 1.0

        # Since we have changes (1 port added, 1 removed), drift should be > 0
        assert drift_score > 0.0

    def test_compare_identical_scans(self, test_db, scan1_data):
        """Test comparing a scan with itself (should have zero drift)."""
        service = ScanService(test_db)
        result = service.compare_scans(scan1_data, scan1_data)

        # Comparing scan with itself should have zero drift
        assert result['drift_score'] == 0.0
        assert len(result['ports']['added']) == 0
        assert len(result['ports']['removed']) == 0
        # Both of scan1's ports (80 and 443) are present on each side.
        assert len(result['ports']['unchanged']) == 2
        assert len(result['services']['added']) == 0
        assert len(result['services']['removed']) == 0


class TestScanComparisonAPI:
    """Tests for scan comparison API endpoint."""

    def test_compare_scans_api(self, client, auth_headers, scan1_data, scan2_data):
        """Test scan comparison API endpoint."""
        response = client.get(
            f'/api/scans/{scan1_data}/compare/{scan2_data}',
            headers=auth_headers
        )

        assert response.status_code == 200
        data = response.get_json()

        assert 'scan1' in data
        assert 'scan2' in data
        assert 'ports' in data
        assert 'services' in data
        assert 'drift_score' in data

    def test_compare_scans_api_not_found(self, client, auth_headers):
        """Test comparison API with nonexistent scans."""
        response = client.get(
            '/api/scans/999/compare/998',
            headers=auth_headers
        )

        assert response.status_code == 404
        data = response.get_json()
        assert 'error' in data

    def test_compare_scans_api_requires_auth(self, client, scan1_data, scan2_data):
        """Test that comparison API requires authentication."""
        response = client.get(f'/api/scans/{scan1_data}/compare/{scan2_data}')

        assert response.status_code == 401


class TestHistoricalChartAPI:
    """Tests for historical scan chart API endpoint."""

    def test_scan_history_api(self, client, auth_headers, scan1_data):
        """Test scan history API endpoint."""
        response = client.get(
            f'/api/stats/scan-history/{scan1_data}',
            headers=auth_headers
        )

        assert response.status_code == 200
        data = response.get_json()

        assert 'scans' in data
        assert 'labels' in data
        assert 'port_counts' in data
        assert 'config_file' in data

        # Should include at least the scan we created
        assert len(data['scans']) >= 1

    def test_scan_history_api_not_found(self, client, auth_headers):
        """Test history API with nonexistent scan."""
        response = client.get(
            '/api/stats/scan-history/999',
            headers=auth_headers
        )

        assert response.status_code == 404
        data = response.get_json()
        assert 'error' in data

    def test_scan_history_api_limit(self, client, auth_headers, scan1_data):
        """Test scan history API with limit parameter."""
        response = client.get(
            f'/api/stats/scan-history/{scan1_data}?limit=5',
            headers=auth_headers
        )

        assert response.status_code == 200
        data = response.get_json()

        # Should respect limit
        assert len(data['scans']) <= 5

    def test_scan_history_api_requires_auth(self, client, scan1_data):
        """Test that history API requires authentication."""
        response = client.get(f'/api/stats/scan-history/{scan1_data}')

        assert response.status_code == 401