396 lines
13 KiB
Python
396 lines
13 KiB
Python
"""
|
|
Unit tests for ScanService class.
|
|
|
|
Tests scan lifecycle operations: trigger, get, list, delete, and database mapping.
|
|
"""
|
|
|
|
import json

import pytest

from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel
from web.services.scan_service import ScanService


class TestScanServiceTrigger:
    """Scan creation through ScanService.trigger_scan."""

    def test_trigger_scan_valid_config(self, db, sample_db_config):
        """A known config ID produces an int scan ID and a persisted 'running' row."""
        svc = ScanService(db)
        config_id = sample_db_config.id

        new_id = svc.trigger_scan(config_id=config_id, triggered_by='manual')

        # The service hands back the new scan's integer ID.
        assert new_id is not None
        assert isinstance(new_id, int)

        # The stored row reflects the arguments the scan was triggered with.
        row = db.query(Scan).filter(Scan.id == new_id).first()
        assert row is not None
        expected_fields = {
            'status': 'running',
            'title': 'Test Scan',
            'triggered_by': 'manual',
            'config_id': config_id,
        }
        for attr, want in expected_fields.items():
            assert getattr(row, attr) == want, attr

    def test_trigger_scan_invalid_config(self, db):
        """An unknown config ID is rejected with a 'not found' ValueError."""
        svc = ScanService(db)
        missing_config_id = 99999

        with pytest.raises(ValueError, match="not found"):
            svc.trigger_scan(config_id=missing_config_id)

    def test_trigger_scan_with_schedule(self, db, sample_db_config):
        """A scheduled trigger records its origin and the schedule ID."""
        svc = ScanService(db)
        trigger_kwargs = {
            'config_id': sample_db_config.id,
            'triggered_by': 'scheduled',
            'schedule_id': 42,
        }

        new_id = svc.trigger_scan(**trigger_kwargs)

        row = db.query(Scan).filter(Scan.id == new_id).first()
        assert row.triggered_by == 'scheduled'
        assert row.schedule_id == 42


class TestScanServiceGet:
    """Single-scan lookup through ScanService.get_scan."""

    def test_get_scan_not_found(self, db):
        """get_scan yields None for an ID with no matching scan."""
        svc = ScanService(db)
        assert svc.get_scan(999) is None

    def test_get_scan_found(self, db, sample_db_config):
        """A freshly triggered scan is returned as a dict carrying its core fields."""
        svc = ScanService(db)
        created_id = svc.trigger_scan(config_id=sample_db_config.id)

        scan_dict = svc.get_scan(created_id)

        assert scan_dict is not None
        summary = (scan_dict['id'], scan_dict['title'], scan_dict['status'])
        assert summary == (created_id, 'Test Scan', 'running')
        assert 'sites' in scan_dict


class TestScanServiceList:
    """Tests for listing scans via ScanService.list_scans."""

    def test_list_scans_empty(self, db):
        """An empty database yields no items, a zero total and zero pages."""
        service = ScanService(db)

        result = service.list_scans(page=1, per_page=20)

        assert result.total == 0
        assert len(result.items) == 0
        assert result.pages == 0

    def test_list_scans_with_data(self, db, sample_db_config):
        """All scans fit on one page when per_page exceeds the scan count."""
        service = ScanService(db)

        # Create 3 scans; the loop index itself is not needed.
        for _ in range(3):
            service.trigger_scan(config_id=sample_db_config.id, triggered_by='api')

        result = service.list_scans(page=1, per_page=20)

        assert result.total == 3
        assert len(result.items) == 3
        assert result.pages == 1

    def test_list_scans_pagination(self, db, sample_db_config):
        """Five scans at two per page span three pages with correct prev/next flags."""
        service = ScanService(db)

        for _ in range(5):
            service.trigger_scan(config_id=sample_db_config.id)

        # Page 1: a full page with nothing before it.
        result = service.list_scans(page=1, per_page=2)
        assert len(result.items) == 2
        assert result.total == 5
        assert result.pages == 3
        assert result.has_prev is False
        assert result.has_next is True

        # Page 2: a full page in the middle; the total is page-independent.
        result = service.list_scans(page=2, per_page=2)
        assert len(result.items) == 2
        assert result.total == 5
        assert result.has_prev is True
        assert result.has_next is True

        # Page 3 (last page): the single leftover scan, with nothing after it.
        result = service.list_scans(page=3, per_page=2)
        assert len(result.items) == 1
        assert result.has_prev is True
        assert result.has_next is False

    def test_list_scans_filter_by_status(self, db, sample_db_config):
        """status_filter restricts the listing to scans in that status."""
        service = ScanService(db)

        # Create two scans; one is flipped to 'completed' and the other stays 'running'.
        completed_id = service.trigger_scan(config_id=sample_db_config.id)
        service.trigger_scan(config_id=sample_db_config.id)

        scan = db.query(Scan).filter(Scan.id == completed_id).first()
        scan.status = 'completed'
        db.commit()

        result = service.list_scans(status_filter='running')
        assert result.total == 1

        result = service.list_scans(status_filter='completed')
        assert result.total == 1

    def test_list_scans_invalid_status_filter(self, db):
        """An unknown status filter is rejected with ValueError."""
        service = ScanService(db)

        with pytest.raises(ValueError, match="Invalid status"):
            service.list_scans(status_filter='invalid_status')


class TestScanServiceDelete:
    """Scan removal through ScanService.delete_scan."""

    def test_delete_scan_not_found(self, db):
        """Deleting an unknown scan ID raises a 'not found' ValueError."""
        svc = ScanService(db)

        with pytest.raises(ValueError, match="not found"):
            svc.delete_scan(999)

    def test_delete_scan_success(self, db, sample_db_config):
        """delete_scan returns True and the row no longer exists afterwards."""
        svc = ScanService(db)
        target_id = svc.trigger_scan(config_id=sample_db_config.id)

        def lookup():
            # Re-query each time so the check reflects the current DB state.
            return db.query(Scan).filter(Scan.id == target_id).first()

        assert lookup() is not None
        assert svc.delete_scan(target_id) is True
        assert lookup() is None


class TestScanServiceStatus:
    """Status reporting through ScanService.get_scan_status."""

    def test_get_scan_status_not_found(self, db):
        """get_scan_status yields None for an unknown scan ID."""
        svc = ScanService(db)
        assert svc.get_scan_status(999) is None

    def test_get_scan_status_running(self, db, sample_db_config):
        """A just-triggered scan reports 'running' with an 'In progress' label."""
        svc = ScanService(db)
        running_id = svc.trigger_scan(config_id=sample_db_config.id)

        info = svc.get_scan_status(running_id)

        assert info is not None
        expected_entries = {
            'scan_id': running_id,
            'status': 'running',
            'progress': 'In progress',
            'title': 'Test Scan',
        }
        for key, want in expected_entries.items():
            assert info[key] == want, key

    def test_get_scan_status_completed(self, db, sample_db_config):
        """A completed scan reports 'Complete' progress and its stored duration."""
        svc = ScanService(db)
        done_id = svc.trigger_scan(config_id=sample_db_config.id)

        # Simulate completion directly on the persisted row.
        row = db.query(Scan).filter(Scan.id == done_id).first()
        row.status, row.duration = 'completed', 125.5
        db.commit()

        info = svc.get_scan_status(done_id)

        assert info['status'] == 'completed'
        assert info['progress'] == 'Complete'
        assert info['duration'] == 125.5


class TestScanServiceDatabaseMapping:
    """Tests for mapping scan reports to database models."""

    def test_save_scan_to_db(self, db, sample_db_config, sample_scan_report):
        """Saving a full report updates the scan and creates sites, IPs, ports and services."""
        service = ScanService(db)

        scan_id = service.trigger_scan(config_id=sample_db_config.id)

        service._save_scan_to_db(sample_scan_report, scan_id, status='completed')

        # Scan row updated with the final status and duration.
        scan = db.query(Scan).filter(Scan.id == scan_id).first()
        assert scan.status == 'completed'
        assert scan.duration == 125.5

        # Sites
        sites = db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all()
        assert len(sites) == 1
        assert sites[0].site_name == 'Test Site'

        # IPs
        ips = db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all()
        assert len(ips) == 1
        assert ips[0].ip_address == '192.168.1.10'
        assert ips[0].ping_expected is True
        assert ips[0].ping_actual is True

        # Ports (TCP: 22, 80, 443, 8080 | UDP: 53)
        ports = db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all()
        assert len(ports) == 5  # 4 TCP + 1 UDP

        tcp_ports = [p for p in ports if p.protocol == 'tcp']
        assert len(tcp_ports) == 4
        assert sorted(p.port for p in tcp_ports) == [22, 80, 443, 8080]

        udp_ports = [p for p in ports if p.protocol == 'udp']
        assert len(udp_ports) == 1
        assert udp_ports[0].port == 53

        # Services
        services = db.query(ScanServiceModel).filter(
            ScanServiceModel.scan_id == scan_id
        ).all()
        assert len(services) == 4  # SSH, HTTP (80), HTTPS, HTTP (8080)

        https_service = next(
            (s for s in services if s.service_name == 'https'), None
        )
        assert https_service is not None
        assert https_service.product == 'nginx'
        assert https_service.version == '1.24.0'
        assert https_service.http_protocol == 'https'
        assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png'

    def test_map_port_expected_vs_actual(self, db, sample_db_config, sample_scan_report):
        """Test that expected vs actual ports are correctly flagged."""
        service = ScanService(db)

        scan_id = service.trigger_scan(config_id=sample_db_config.id)
        service._save_scan_to_db(sample_scan_report, scan_id)

        tcp_ports = db.query(ScanPort).filter(
            ScanPort.scan_id == scan_id,
            ScanPort.protocol == 'tcp'
        ).all()

        # Pin the persisted port set first. Without this, an empty result
        # would let the flag-checking loop below pass without asserting anything.
        assert sorted(p.port for p in tcp_ports) == [22, 80, 443, 8080]

        # Ports 22, 80, 443 were expected; 8080 was not.
        expected_ports = {22, 80, 443}
        for port in tcp_ports:
            if port.port in expected_ports:
                assert port.expected is True, f"Port {port.port} should be expected"
            else:
                assert port.expected is False, f"Port {port.port} should not be expected"

    def test_map_certificate_and_tls(self, db, sample_db_config, sample_scan_report):
        """Test that certificate and TLS data are correctly mapped."""
        service = ScanService(db)

        scan_id = service.trigger_scan(config_id=sample_db_config.id)
        service._save_scan_to_db(sample_scan_report, scan_id)

        https_service = db.query(ScanServiceModel).filter(
            ScanServiceModel.scan_id == scan_id,
            ScanServiceModel.service_name == 'https'
        ).first()
        assert https_service is not None

        # Certificate
        assert len(https_service.certificates) == 1
        cert = https_service.certificates[0]
        assert cert.subject == 'CN=example.com'
        assert cert.issuer == "CN=Let's Encrypt Authority"
        assert cert.days_until_expiry == 365
        assert cert.is_self_signed is False

        # SANs are stored as a JSON-encoded string on the certificate row.
        sans = json.loads(cert.sans)
        assert 'example.com' in sans
        assert 'www.example.com' in sans

        # TLS versions
        assert len(cert.tls_versions) == 2

        tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None)
        assert tls_12 is not None
        assert tls_12.supported is True

        tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None)
        assert tls_13 is not None
        assert tls_13.supported is True

    def test_get_scan_with_full_details(self, db, sample_db_config, sample_scan_report):
        """Test retrieving scan with all nested relationships."""
        service = ScanService(db)

        scan_id = service.trigger_scan(config_id=sample_db_config.id)
        service._save_scan_to_db(sample_scan_report, scan_id)

        result = service.get_scan(scan_id)

        assert result is not None
        assert len(result['sites']) == 1

        site = result['sites'][0]
        assert site['name'] == 'Test Site'
        assert len(site['ips']) == 1

        ip = site['ips'][0]
        assert ip['address'] == '192.168.1.10'
        assert len(ip['ports']) == 5  # 4 TCP + 1 UDP

        https_port = next(
            (p for p in ip['ports'] if p['port'] == 443), None
        )
        assert https_port is not None
        assert len(https_port['services']) == 1

        service_data = https_port['services'][0]
        assert service_data['service_name'] == 'https'
        assert 'certificates' in service_data
        assert len(service_data['certificates']) == 1

        cert = service_data['certificates'][0]
        assert cert['subject'] == 'CN=example.com'
        assert 'tls_versions' in cert
        assert len(cert['tls_versions']) == 2