320 lines
9.5 KiB
Python
320 lines
9.5 KiB
Python
"""
Tests for scan comparison functionality.

Tests scan comparison logic including port, service, and certificate comparisons,
as well as drift score calculation, plus the scan comparison and scan history
API endpoints.
"""
|
|
|
|
import pytest
|
|
from datetime import datetime
|
|
|
|
from web.models import Scan, ScanSite, ScanIP, ScanPort
|
|
from web.models import ScanService as ScanServiceModel, ScanCertificate
|
|
from web.services.scan_service import ScanService
|
|
|
|
|
|
def _create_completed_scan(db, config_file, port_specs, service_specs):
    """Trigger a scan and populate it with one site, one IP, ports and services.

    Args:
        db: Database session passed to ``ScanService`` and used to insert rows.
        config_file: Config file passed to ``ScanService.trigger_scan``.
        port_specs: Iterable of ``(port_number, expected)`` pairs; each becomes
            an open TCP ``ScanPort`` on IP 192.168.1.100.
        service_specs: Iterable of ``(port_number, version)`` pairs; each becomes
            an ``http``/``nginx`` service attached to the port with that number,
            which must also appear in ``port_specs``.

    Returns:
        The id of the triggered scan, after its status has been set to
        ``'completed'`` and all created rows have been committed.
    """
    scan_id = ScanService(db).trigger_scan(config_file, triggered_by='manual')

    scan = db.query(Scan).filter(Scan.id == scan_id).first()
    scan.status = 'completed'

    site = ScanSite(scan_id=scan.id, site_name='Test Site')
    db.add(site)
    db.flush()  # populate site.id before the IP row references it

    ip = ScanIP(
        scan_id=scan.id,
        site_id=site.id,
        ip_address='192.168.1.100',
        ping_expected=True,
        ping_actual=True,
    )
    db.add(ip)
    db.flush()  # populate ip.id before the port rows reference it

    ports_by_number = {}
    for number, expected in port_specs:
        port = ScanPort(
            scan_id=scan.id,
            ip_id=ip.id,
            port=number,
            protocol='tcp',
            state='open',
            expected=expected,
        )
        db.add(port)
        ports_by_number[number] = port
    db.flush()  # populate port ids before services reference them

    for number, version in service_specs:
        db.add(ScanServiceModel(
            scan_id=scan.id,
            port_id=ports_by_number[number].id,
            service_name='http',
            product='nginx',
            version=version,
        ))

    db.commit()
    return scan_id


# The two fixtures below live at module level rather than inside a test class:
# pytest only exposes class-level fixtures to tests of that same class, and
# several test classes in this module request them.
@pytest.fixture
def scan1_data(test_db, sample_config_file):
    """Create the baseline scan: open ports 80 and 443, nginx 1.18.0 on port 80."""
    return _create_completed_scan(
        test_db,
        sample_config_file,
        port_specs=[(80, True), (443, True)],
        service_specs=[(80, '1.18.0')],
    )


@pytest.fixture
def scan2_data(test_db, sample_config_file):
    """Create the modified scan: port 80 removed, 443 kept, unexpected 8080 added.

    The only service is nginx 1.20.0, attached to port 8080.
    """
    return _create_completed_scan(
        test_db,
        sample_config_file,
        port_specs=[(443, True), (8080, False)],
        service_specs=[(8080, '1.20.0')],
    )


class TestScanComparison:
    """Tests for scan comparison methods."""

    def test_compare_scans_basic(self, test_db, scan1_data, scan2_data):
        """Test basic scan comparison returns every section and the scan ids."""
        service = ScanService(test_db)

        result = service.compare_scans(scan1_data, scan2_data)

        assert result is not None
        assert 'scan1' in result
        assert 'scan2' in result
        assert 'ports' in result
        assert 'services' in result
        assert 'certificates' in result
        assert 'drift_score' in result

        # Verify scan metadata
        assert result['scan1']['id'] == scan1_data
        assert result['scan2']['id'] == scan2_data

    def test_compare_scans_not_found(self, test_db):
        """Test comparison with nonexistent scans returns None."""
        service = ScanService(test_db)

        result = service.compare_scans(999, 998)

        assert result is None

    def test_compare_ports(self, test_db, scan1_data, scan2_data):
        """Test port comparison logic."""
        service = ScanService(test_db)

        result = service.compare_scans(scan1_data, scan2_data)

        # Scan1 has ports 80, 443
        # Scan2 has ports 443, 8080
        # Expected: added=[8080], removed=[80], unchanged=[443]
        ports = result['ports']
        assert len(ports['added']) == 1
        assert len(ports['removed']) == 1
        assert len(ports['unchanged']) == 1

        assert ports['added'][0]['port'] == 8080
        assert ports['removed'][0]['port'] == 80
        assert ports['unchanged'][0]['port'] == 443

    def test_compare_services(self, test_db, scan1_data, scan2_data):
        """Test service comparison logic."""
        service = ScanService(test_db)

        result = service.compare_scans(scan1_data, scan2_data)

        services = result['services']

        # Scan1 has nginx 1.18.0 on port 80
        # Scan2 has nginx 1.20.0 on port 8080
        # These are on different ports, so they should be added/removed, not
        # changed. (The previous ``>= 0`` checks could never fail.)
        assert len(services['added']) == 1
        assert len(services['removed']) == 1

    def test_drift_score_calculation(self, test_db, scan1_data, scan2_data):
        """Test drift score calculation."""
        service = ScanService(test_db)

        result = service.compare_scans(scan1_data, scan2_data)

        drift_score = result['drift_score']

        # Drift score should be between 0.0 and 1.0
        assert 0.0 <= drift_score <= 1.0

        # Since we have changes (1 port added, 1 removed), drift should be > 0
        assert drift_score > 0.0

    def test_compare_identical_scans(self, test_db, scan1_data):
        """Test comparing a scan with itself (should have zero drift)."""
        service = ScanService(test_db)

        result = service.compare_scans(scan1_data, scan1_data)

        # Comparing scan with itself should have zero drift
        assert result['drift_score'] == 0.0
        assert len(result['ports']['added']) == 0
        assert len(result['ports']['removed']) == 0
|
|
|
|
|
class TestScanComparisonAPI:
    """Exercise the ``/api/scans/<id>/compare/<id>`` HTTP endpoint."""

    def test_compare_scans_api(self, client, auth_headers, scan1_data, scan2_data):
        """An authenticated comparison of two scans returns 200 with the diff sections."""
        url = f'/api/scans/{scan1_data}/compare/{scan2_data}'
        response = client.get(url, headers=auth_headers)

        assert response.status_code == 200

        payload = response.get_json()
        for section in ('scan1', 'scan2', 'ports', 'services', 'drift_score'):
            assert section in payload

    def test_compare_scans_api_not_found(self, client, auth_headers):
        """Comparing scan ids that do not exist yields 404 with an error body."""
        response = client.get('/api/scans/999/compare/998', headers=auth_headers)

        assert response.status_code == 404
        assert 'error' in response.get_json()

    def test_compare_scans_api_requires_auth(self, client, scan1_data, scan2_data):
        """A request sent without auth headers is rejected with 401."""
        unauthenticated = client.get(f'/api/scans/{scan1_data}/compare/{scan2_data}')

        assert unauthenticated.status_code == 401
|
|
|
|
|
|
class TestHistoricalChartAPI:
    """Exercise the ``/api/stats/scan-history/<id>`` HTTP endpoint."""

    def test_scan_history_api(self, client, auth_headers, scan1_data):
        """History for an existing scan returns 200 and the chart fields."""
        url = f'/api/stats/scan-history/{scan1_data}'
        response = client.get(url, headers=auth_headers)

        assert response.status_code == 200

        payload = response.get_json()
        for field in ('scans', 'labels', 'port_counts', 'config_file'):
            assert field in payload

        # The scan created by the fixture must be part of the history
        assert len(payload['scans']) >= 1

    def test_scan_history_api_not_found(self, client, auth_headers):
        """History for an unknown scan id yields 404 with an error body."""
        response = client.get('/api/stats/scan-history/999', headers=auth_headers)

        assert response.status_code == 404
        assert 'error' in response.get_json()

    def test_scan_history_api_limit(self, client, auth_headers, scan1_data):
        """The ``limit`` query parameter caps the number of returned scans."""
        url = f'/api/stats/scan-history/{scan1_data}?limit=5'
        response = client.get(url, headers=auth_headers)

        assert response.status_code == 200

        history = response.get_json()['scans']
        assert len(history) <= 5

    def test_scan_history_api_requires_auth(self, client, scan1_data):
        """A request sent without auth headers is rejected with 401."""
        unauthenticated = client.get(f'/api/stats/scan-history/{scan1_data}')

        assert unauthenticated.status_code == 401
|