Files
SneakyScan/app/tests/test_scan_service.py

403 lines
14 KiB
Python

"""
Unit tests for ScanService class.
Tests scan lifecycle operations: trigger, get, list, delete, status, and database mapping.
"""
import json

import pytest

from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel
from web.services.scan_service import ScanService
class TestScanServiceTrigger:
    """Exercise ScanService.trigger_scan() for valid, rejected and scheduled runs."""

    @staticmethod
    def _load_scan(db, scan_id):
        """Return the Scan row whose primary key is ``scan_id`` (or None)."""
        return db.query(Scan).filter(Scan.id == scan_id).first()

    def test_trigger_scan_valid_config(self, test_db, sample_config_file):
        """A manual trigger with a valid config persists a 'running' Scan row."""
        svc = ScanService(test_db)
        new_id = svc.trigger_scan(sample_config_file, triggered_by='manual')

        # The call hands back the integer id of the newly created scan.
        assert new_id is not None
        assert isinstance(new_id, int)

        row = self._load_scan(test_db, new_id)
        assert row is not None
        assert row.status == 'running'
        assert row.title == 'Test Scan'
        assert row.triggered_by == 'manual'
        assert row.config_file == sample_config_file

    def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file):
        """A config the service rejects raises ValueError('Invalid config file...')."""
        svc = ScanService(test_db)
        with pytest.raises(ValueError, match="Invalid config file"):
            svc.trigger_scan(sample_invalid_config_file)

    def test_trigger_scan_nonexistent_file(self, test_db):
        """A config path that does not exist raises ValueError."""
        svc = ScanService(test_db)
        missing_path = '/nonexistent/config.yaml'
        with pytest.raises(ValueError, match="does not exist"):
            svc.trigger_scan(missing_path)

    def test_trigger_scan_with_schedule(self, test_db, sample_config_file):
        """A scheduled trigger records both its source and the schedule id."""
        svc = ScanService(test_db)
        new_id = svc.trigger_scan(
            sample_config_file,
            triggered_by='scheduled',
            schedule_id=42,
        )

        row = self._load_scan(test_db, new_id)
        assert row.triggered_by == 'scheduled'
        assert row.schedule_id == 42
class TestScanServiceGet:
    """Exercise ScanService.get_scan() lookups by id."""

    def test_get_scan_not_found(self, test_db):
        """An unknown scan id yields None."""
        assert ScanService(test_db).get_scan(999) is None

    def test_get_scan_found(self, test_db, sample_config_file):
        """A freshly triggered scan comes back with its id, title, status and sites."""
        svc = ScanService(test_db)
        created_id = svc.trigger_scan(sample_config_file)

        payload = svc.get_scan(created_id)

        assert payload is not None
        assert payload['id'] == created_id
        assert payload['title'] == 'Test Scan'
        assert payload['status'] == 'running'
        assert 'sites' in payload
class TestScanServiceList:
    """Tests for ScanService.list_scans(): paging metadata and status filtering."""

    def test_list_scans_empty(self, test_db):
        """An empty database yields no items and zero pages."""
        service = ScanService(test_db)
        result = service.list_scans(page=1, per_page=20)
        assert result.total == 0
        assert len(result.items) == 0
        assert result.pages == 0

    def test_list_scans_with_data(self, test_db, sample_config_file):
        """All scans land on one page when per_page exceeds the scan count."""
        service = ScanService(test_db)
        for _ in range(3):
            service.trigger_scan(sample_config_file, triggered_by='api')

        result = service.list_scans(page=1, per_page=20)
        assert result.total == 3
        assert len(result.items) == 3
        assert result.pages == 1

    def test_list_scans_pagination(self, test_db, sample_config_file):
        """Five scans at two per page split into pages of 2, 2 and 1."""
        service = ScanService(test_db)
        for _ in range(5):
            service.trigger_scan(sample_config_file)

        # Page 1: has a successor but no predecessor.
        result = service.list_scans(page=1, per_page=2)
        assert len(result.items) == 2
        assert result.total == 5
        assert result.pages == 3
        assert result.has_prev is False
        assert result.has_next is True

        # Page 2: a middle page has neighbours on both sides.
        result = service.list_scans(page=2, per_page=2)
        assert len(result.items) == 2
        assert result.has_prev is True
        assert result.has_next is True

        # Page 3 (last page): holds the single remaining scan.
        result = service.list_scans(page=3, per_page=2)
        assert len(result.items) == 1
        assert result.has_prev is True
        assert result.has_next is False

    def test_list_scans_filter_by_status(self, test_db, sample_config_file):
        """status_filter narrows the listing to scans in the requested state."""
        service = ScanService(test_db)
        completed_id = service.trigger_scan(sample_config_file)
        # Second scan is left in the 'running' state trigger_scan creates it with.
        service.trigger_scan(sample_config_file)

        # Flip the first scan to 'completed' directly in the database.
        scan = test_db.query(Scan).filter(Scan.id == completed_id).first()
        scan.status = 'completed'
        test_db.commit()

        # Without a filter both scans are listed, so the filtered counts
        # below prove the filter actually narrows the result.
        assert service.list_scans().total == 2

        result = service.list_scans(status_filter='running')
        assert result.total == 1

        result = service.list_scans(status_filter='completed')
        assert result.total == 1

    def test_list_scans_invalid_status_filter(self, test_db):
        """An unrecognised status filter is rejected with ValueError."""
        service = ScanService(test_db)
        with pytest.raises(ValueError, match="Invalid status"):
            service.list_scans(status_filter='invalid_status')
class TestScanServiceDelete:
    """Exercise ScanService.delete_scan()."""

    def test_delete_scan_not_found(self, test_db):
        """Deleting an unknown id raises ValueError mentioning 'not found'."""
        svc = ScanService(test_db)
        with pytest.raises(ValueError, match="not found"):
            svc.delete_scan(999)

    def test_delete_scan_success(self, test_db, sample_config_file):
        """A successful delete returns True and removes the Scan row."""
        svc = ScanService(test_db)
        victim_id = svc.trigger_scan(sample_config_file)

        def lookup():
            # Fresh query each call so we observe the current table state.
            return test_db.query(Scan).filter(Scan.id == victim_id).first()

        assert lookup() is not None
        assert svc.delete_scan(victim_id) is True
        assert lookup() is None
class TestScanServiceStatus:
    """Tests for ScanService.get_scan_status()."""

    def test_get_scan_status_not_found(self, test_db):
        """An unknown scan id yields None."""
        service = ScanService(test_db)
        result = service.get_scan_status(999)
        assert result is None

    def test_get_scan_status_running(self, test_db, sample_config_file):
        """A freshly triggered scan reports 'running' / 'In progress'."""
        service = ScanService(test_db)
        scan_id = service.trigger_scan(sample_config_file)

        status = service.get_scan_status(scan_id)
        assert status is not None
        assert status['scan_id'] == scan_id
        assert status['status'] == 'running'
        assert status['progress'] == 'In progress'
        assert status['title'] == 'Test Scan'

    def test_get_scan_status_completed(self, test_db, sample_config_file):
        """A completed scan reports 'Complete' progress and its duration."""
        service = ScanService(test_db)
        scan_id = service.trigger_scan(sample_config_file)

        # Simulate completion by editing the Scan row directly.
        scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
        scan.status = 'completed'
        scan.duration = 125.5
        test_db.commit()

        status = service.get_scan_status(scan_id)
        # Guard first so a missing status fails as a clear assertion rather
        # than a TypeError from subscripting None.
        assert status is not None
        assert status['scan_id'] == scan_id
        assert status['status'] == 'completed'
        assert status['progress'] == 'Complete'
        assert status['duration'] == 125.5
class TestScanServiceDatabaseMapping:
    """Tests for mapping scan reports to database models."""

    def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report):
        """Saving a full report updates the Scan and creates all child rows."""
        service = ScanService(test_db)
        scan_id = service.trigger_scan(sample_config_file)

        # Persist the fixture report against the freshly created scan.
        service._save_scan_to_db(sample_scan_report, scan_id, status='completed')

        # Scan row updated from the report.
        scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
        assert scan is not None
        assert scan.status == 'completed'
        assert scan.duration == 125.5

        # Sites created.
        sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all()
        assert len(sites) == 1
        assert sites[0].site_name == 'Test Site'

        # IPs created, with ping expectations mapped.
        ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all()
        assert len(ips) == 1
        assert ips[0].ip_address == '192.168.1.10'
        assert ips[0].ping_expected is True
        assert ips[0].ping_actual is True

        # Ports created (TCP: 22, 80, 443, 8080 | UDP: 53).
        ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all()
        assert len(ports) == 5  # 4 TCP + 1 UDP

        tcp_ports = [p for p in ports if p.protocol == 'tcp']
        assert len(tcp_ports) == 4
        assert sorted(p.port for p in tcp_ports) == [22, 80, 443, 8080]

        udp_ports = [p for p in ports if p.protocol == 'udp']
        assert len(udp_ports) == 1
        assert udp_ports[0].port == 53

        # Services created.
        services = test_db.query(ScanServiceModel).filter(
            ScanServiceModel.scan_id == scan_id
        ).all()
        assert len(services) == 4  # SSH, HTTP (80), HTTPS, HTTP (8080)

        https_service = next(
            (s for s in services if s.service_name == 'https'), None
        )
        assert https_service is not None
        assert https_service.product == 'nginx'
        assert https_service.version == '1.24.0'
        assert https_service.http_protocol == 'https'
        assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png'

    def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report):
        """Ports listed in the config are flagged expected; extra ones are not."""
        service = ScanService(test_db)
        scan_id = service.trigger_scan(sample_config_file)
        service._save_scan_to_db(sample_scan_report, scan_id)

        tcp_ports = test_db.query(ScanPort).filter(
            ScanPort.scan_id == scan_id,
            ScanPort.protocol == 'tcp'
        ).all()

        # Pin the port set first; otherwise an empty result would make the
        # loop below pass without checking anything.
        assert sorted(p.port for p in tcp_ports) == [22, 80, 443, 8080]

        # Ports 22, 80, 443 were expected; 8080 was not.
        expected_ports = {22, 80, 443}
        for port in tcp_ports:
            if port.port in expected_ports:
                assert port.expected is True, f"Port {port.port} should be expected"
            else:
                assert port.expected is False, f"Port {port.port} should not be expected"

    def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report):
        """Certificate fields, SANs and TLS versions are mapped onto the HTTPS service."""
        service = ScanService(test_db)
        scan_id = service.trigger_scan(sample_config_file)
        service._save_scan_to_db(sample_scan_report, scan_id)

        https_service = test_db.query(ScanServiceModel).filter(
            ScanServiceModel.scan_id == scan_id,
            ScanServiceModel.service_name == 'https'
        ).first()
        assert https_service is not None

        # Certificate created.
        assert len(https_service.certificates) == 1
        cert = https_service.certificates[0]
        assert cert.subject == 'CN=example.com'
        assert cert.issuer == "CN=Let's Encrypt Authority"
        assert cert.days_until_expiry == 365
        assert cert.is_self_signed is False

        # SANs are stored as a JSON-encoded string.
        sans = json.loads(cert.sans)
        assert 'example.com' in sans
        assert 'www.example.com' in sans

        # TLS versions.
        assert len(cert.tls_versions) == 2
        tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None)
        assert tls_12 is not None
        assert tls_12.supported is True
        tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None)
        assert tls_13 is not None
        assert tls_13.supported is True

    def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report):
        """get_scan() returns the full nested site/ip/port/service/cert tree."""
        service = ScanService(test_db)
        scan_id = service.trigger_scan(sample_config_file)
        service._save_scan_to_db(sample_scan_report, scan_id)

        result = service.get_scan(scan_id)
        assert result is not None
        assert len(result['sites']) == 1

        site = result['sites'][0]
        assert site['name'] == 'Test Site'
        assert len(site['ips']) == 1

        ip = site['ips'][0]
        assert ip['address'] == '192.168.1.10'
        assert len(ip['ports']) == 5  # 4 TCP + 1 UDP

        https_port = next(
            (p for p in ip['ports'] if p['port'] == 443), None
        )
        assert https_port is not None
        assert len(https_port['services']) == 1

        service_data = https_port['services'][0]
        assert service_data['service_name'] == 'https'
        assert 'certificates' in service_data
        assert len(service_data['certificates']) == 1

        cert = service_data['certificates'][0]
        assert cert['subject'] == 'CN=example.com'
        assert 'tls_versions' in cert
        assert len(cert['tls_versions']) == 2