Complete the foundation for Phase 2 by implementing the service layer, utilities, and comprehensive test suite. This establishes the core business logic for scan management.

Service Layer:
- Add ScanService class with complete scan lifecycle management
  * trigger_scan() - Create scan record and prepare for execution
  * get_scan() - Retrieve scan with all related data (eager loading)
  * list_scans() - Paginated scan list with status filtering
  * delete_scan() - Remove scan from DB and delete all files
  * get_scan_status() - Poll current scan status and progress
  * _save_scan_to_db() - Persist scan results to database
  * _map_report_to_models() - Complex JSON-to-DB mapping logic

Database Mapping:
- Comprehensive mapping from scanner JSON output to normalized schema
- Handles nested relationships: sites → IPs → ports → services → certs → TLS
- Processes both TCP and UDP ports with expected/actual tracking
- Maps service detection results with HTTP/HTTPS information
- Stores SSL/TLS certificates with expiration tracking
- Records TLS version support and cipher suites
- Links screenshots to services

Utilities:
- Add pagination.py with PaginatedResult class
  * paginate() function for SQLAlchemy queries
  * validate_page_params() for input sanitization
  * Metadata: total, pages, has_prev, has_next, etc.
- Add validators.py with comprehensive validation functions
  * validate_config_file() - YAML structure and required fields
  * validate_scan_status() - Enum validation (running/completed/failed)
  * validate_scan_id() - Positive integer validation
  * validate_port() - Port range validation (1-65535)
  * validate_ip_address() - Basic IPv4 format validation
  * sanitize_filename() - Path traversal prevention

Database Migration:
- Add migration 002 for scan status index
- Optimizes queries filtering by scan status
- Timestamp index already exists from migration 001

Testing:
- Add pytest infrastructure with conftest.py
  * test_db fixture - Temporary SQLite database per test
  * sample_scan_report fixture - Realistic scanner output
  * sample_config_file fixture - Valid YAML config
  * sample_invalid_config_file fixture - For validation tests
- Add comprehensive test_scan_service.py (15 tests)
  * Test scan trigger with valid/invalid configs
  * Test scan retrieval (found/not found cases)
  * Test scan listing with pagination and filtering
  * Test scan deletion with cascade cleanup
  * Test scan status retrieval
  * Test database mapping from JSON to models
  * Test expected vs actual port flagging
  * Test certificate and TLS data mapping
  * Test full scan retrieval with all relationships
  * All tests passing

Files Added:
- web/services/__init__.py
- web/services/scan_service.py (545 lines)
- web/utils/pagination.py (153 lines)
- web/utils/validators.py (245 lines)
- migrations/versions/002_add_scan_indexes.py
- tests/__init__.py
- tests/conftest.py (142 lines)
- tests/test_scan_service.py (374 lines)

Next Steps (Step 2):
- Implement scan API endpoints in web/api/scans.py
- Add authentication decorators
- Integrate ScanService with API routes
- Test API endpoints with integration tests

Phase 2 Step 1 Complete ✓
403 lines
14 KiB
Python
403 lines
14 KiB
Python
"""
Unit tests for ScanService class.

Tests scan lifecycle operations: trigger, get, list, delete, status, and database mapping.
"""
|
|
|
|
import json

import pytest

from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel
from web.services.scan_service import ScanService
|
|
|
|
|
|
class TestScanServiceTrigger:
    """Tests for triggering scans through ``ScanService.trigger_scan``."""

    def test_trigger_scan_valid_config(self, test_db, sample_config_file):
        """Test triggering a scan with valid config file.

        Checks both the returned id and the ``Scan`` row stored under it.
        """
        service = ScanService(test_db)

        scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')

        # Verify scan created
        assert scan_id is not None
        assert isinstance(scan_id, int)

        # Verify scan in database
        scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
        assert scan is not None
        assert scan.status == 'running'
        # NOTE(review): the title presumably comes from the sample config
        # fixture -- confirm against conftest.py.
        assert scan.title == 'Test Scan'
        assert scan.triggered_by == 'manual'
        assert scan.config_file == sample_config_file

    def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file):
        """Test triggering a scan with invalid config file.

        Expects a ``ValueError`` whose message matches "Invalid config file".
        """
        service = ScanService(test_db)

        with pytest.raises(ValueError, match="Invalid config file"):
            service.trigger_scan(sample_invalid_config_file)

    def test_trigger_scan_nonexistent_file(self, test_db):
        """Test triggering a scan with nonexistent config file.

        Expects a ``ValueError`` whose message matches "does not exist".
        """
        service = ScanService(test_db)

        with pytest.raises(ValueError, match="does not exist"):
            service.trigger_scan('/nonexistent/config.yaml')

    def test_trigger_scan_with_schedule(self, test_db, sample_config_file):
        """Test triggering a scan via schedule.

        Checks that ``triggered_by`` and ``schedule_id`` are stored on the row.
        """
        service = ScanService(test_db)

        scan_id = service.trigger_scan(
            sample_config_file,
            triggered_by='scheduled',
            schedule_id=42,
        )

        scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
        # Guard first so a missing row fails as a clear assertion instead of
        # an AttributeError on None.
        assert scan is not None
        assert scan.triggered_by == 'scheduled'
        assert scan.schedule_id == 42
|
|
|
|
|
|
class TestScanServiceGet:
    """Retrieval tests for ``ScanService.get_scan``."""

    def test_get_scan_not_found(self, test_db):
        """Looking up an id that was never created yields ``None``."""
        assert ScanService(test_db).get_scan(999) is None

    def test_get_scan_found(self, test_db, sample_config_file):
        """A triggered scan can be fetched back by the id it was given."""
        svc = ScanService(test_db)

        # Create the scan, then read it back through the service.
        new_id = svc.trigger_scan(sample_config_file)
        fetched = svc.get_scan(new_id)

        assert fetched is not None

        # Scalar fields expected on the returned mapping.
        expected_fields = (
            ('id', new_id),
            ('title', 'Test Scan'),
            ('status', 'running'),
        )
        for key, wanted in expected_fields:
            assert fetched[key] == wanted

        assert 'sites' in fetched
|
|
|
|
|
|
class TestScanServiceList:
    """Tests for listing scans: pagination metadata and status filtering."""

    def test_list_scans_empty(self, test_db):
        """Test listing scans when database is empty."""
        service = ScanService(test_db)

        result = service.list_scans(page=1, per_page=20)

        assert result.total == 0
        assert len(result.items) == 0
        assert result.pages == 0

    def test_list_scans_with_data(self, test_db, sample_config_file):
        """Test listing scans with multiple scans."""
        service = ScanService(test_db)

        # Create 3 scans
        for _ in range(3):
            service.trigger_scan(sample_config_file, triggered_by='api')

        # List all scans
        result = service.list_scans(page=1, per_page=20)

        assert result.total == 3
        assert len(result.items) == 3
        assert result.pages == 1

    def test_list_scans_pagination(self, test_db, sample_config_file):
        """Test pagination.

        Five scans at two per page should give pages of 2, 2 and 1 items.
        """
        service = ScanService(test_db)

        # Create 5 scans
        for _ in range(5):
            service.trigger_scan(sample_config_file)

        # Get page 1 (2 items per page)
        result = service.list_scans(page=1, per_page=2)
        assert len(result.items) == 2
        assert result.total == 5
        assert result.pages == 3
        assert result.has_next is True

        # Get page 2
        result = service.list_scans(page=2, per_page=2)
        assert len(result.items) == 2
        assert result.has_prev is True
        assert result.has_next is True

        # Get page 3 (last page)
        result = service.list_scans(page=3, per_page=2)
        assert len(result.items) == 1
        assert result.has_next is False

    def test_list_scans_filter_by_status(self, test_db, sample_config_file):
        """Test filtering scans by status.

        Two scans are created and one is switched to 'completed', so each
        status filter should match exactly one scan.
        """
        service = ScanService(test_db)

        # Create scans with different statuses; only the first id is needed
        # later, the second scan simply stays in its initial state.
        scan_id_1 = service.trigger_scan(sample_config_file)
        service.trigger_scan(sample_config_file)

        # Mark one as completed
        scan = test_db.query(Scan).filter(Scan.id == scan_id_1).first()
        scan.status = 'completed'
        test_db.commit()

        # Filter by running
        result = service.list_scans(status_filter='running')
        assert result.total == 1

        # Filter by completed
        result = service.list_scans(status_filter='completed')
        assert result.total == 1

    def test_list_scans_invalid_status_filter(self, test_db):
        """Test filtering with invalid status.

        Expects a ``ValueError`` whose message matches "Invalid status".
        """
        service = ScanService(test_db)

        with pytest.raises(ValueError, match="Invalid status"):
            service.list_scans(status_filter='invalid_status')
|
|
|
|
|
|
class TestScanServiceDelete:
    """Deletion tests for ``ScanService.delete_scan``."""

    def test_delete_scan_not_found(self, test_db):
        """Deleting an id that was never created raises ``ValueError``."""
        svc = ScanService(test_db)

        with pytest.raises(ValueError, match="not found"):
            svc.delete_scan(999)

    def test_delete_scan_success(self, test_db, sample_config_file):
        """A triggered scan is removed from the database by ``delete_scan``."""
        svc = ScanService(test_db)
        new_id = svc.trigger_scan(sample_config_file)

        def stored_row():
            # Fresh lookup each time so the before/after checks hit the DB.
            return test_db.query(Scan).filter(Scan.id == new_id).first()

        # The row must be present before deletion ...
        assert stored_row() is not None

        # ... the call reports success ...
        outcome = svc.delete_scan(new_id)
        assert outcome is True

        # ... and the row is gone afterwards.
        assert stored_row() is None
|
|
|
|
|
|
class TestScanServiceStatus:
    """Tests for scan status retrieval via ``ScanService.get_scan_status``."""

    def test_get_scan_status_not_found(self, test_db):
        """Test getting status of nonexistent scan."""
        service = ScanService(test_db)

        result = service.get_scan_status(999)
        assert result is None

    def test_get_scan_status_running(self, test_db, sample_config_file):
        """Test getting status of running scan."""
        service = ScanService(test_db)

        scan_id = service.trigger_scan(sample_config_file)
        status = service.get_scan_status(scan_id)

        assert status is not None
        assert status['scan_id'] == scan_id
        assert status['status'] == 'running'
        assert status['progress'] == 'In progress'
        assert status['title'] == 'Test Scan'

    def test_get_scan_status_completed(self, test_db, sample_config_file):
        """Test getting status of completed scan.

        The scan row is updated directly to 'completed' with a duration
        before the status is read back through the service.
        """
        service = ScanService(test_db)

        # Create and mark as completed
        scan_id = service.trigger_scan(sample_config_file)
        scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
        scan.status = 'completed'
        scan.duration = 125.5
        test_db.commit()

        status = service.get_scan_status(scan_id)

        # Same guard as the running-scan test: fail with a clear assertion
        # rather than a TypeError from subscripting None.
        assert status is not None
        assert status['status'] == 'completed'
        assert status['progress'] == 'Complete'
        assert status['duration'] == 125.5
|
|
|
|
|
|
class TestScanServiceDatabaseMapping:
    """Tests for mapping scan reports to database models.

    Every test triggers a scan and then stores ``sample_scan_report`` with
    ``ScanService._save_scan_to_db`` before inspecting the resulting rows.
    """

    def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report):
        """Test saving a complete scan report to database.

        Checks the scan row plus the site, IP, port and service rows
        created from the report.
        """
        service = ScanService(test_db)

        # Create a scan
        scan_id = service.trigger_scan(sample_config_file)

        # Save report to database
        service._save_scan_to_db(sample_scan_report, scan_id, status='completed')

        # Verify scan updated
        scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
        assert scan.status == 'completed'
        assert scan.duration == 125.5

        # Verify sites created
        sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all()
        assert len(sites) == 1
        assert sites[0].site_name == 'Test Site'

        # Verify IPs created
        ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all()
        assert len(ips) == 1
        assert ips[0].ip_address == '192.168.1.10'
        assert ips[0].ping_expected is True
        assert ips[0].ping_actual is True

        # Verify ports created (TCP: 22, 80, 443, 8080 | UDP: 53)
        ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all()
        assert len(ports) == 5  # 4 TCP + 1 UDP

        # Verify TCP ports
        tcp_ports = [p for p in ports if p.protocol == 'tcp']
        assert len(tcp_ports) == 4
        tcp_port_numbers = sorted(p.port for p in tcp_ports)
        assert tcp_port_numbers == [22, 80, 443, 8080]

        # Verify UDP ports
        udp_ports = [p for p in ports if p.protocol == 'udp']
        assert len(udp_ports) == 1
        assert udp_ports[0].port == 53

        # Verify services created
        services = test_db.query(ScanServiceModel).filter(
            ScanServiceModel.scan_id == scan_id
        ).all()
        assert len(services) == 4  # SSH, HTTP (80), HTTPS, HTTP (8080)

        # Find HTTPS service
        https_service = next(
            (s for s in services if s.service_name == 'https'), None
        )
        assert https_service is not None
        assert https_service.product == 'nginx'
        assert https_service.version == '1.24.0'
        assert https_service.http_protocol == 'https'
        assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png'

    def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report):
        """Test that expected vs actual ports are correctly flagged."""
        service = ScanService(test_db)

        scan_id = service.trigger_scan(sample_config_file)
        service._save_scan_to_db(sample_scan_report, scan_id)

        # Check TCP ports
        tcp_ports = test_db.query(ScanPort).filter(
            ScanPort.scan_id == scan_id,
            ScanPort.protocol == 'tcp'
        ).all()

        # Without this check the loop below passes vacuously when no TCP
        # port rows were stored. The port set matches the one asserted in
        # test_save_scan_to_db for the same report fixture.
        assert {p.port for p in tcp_ports} == {22, 80, 443, 8080}

        # Ports 22, 80, 443 were expected
        expected_ports = {22, 80, 443}
        for port in tcp_ports:
            if port.port in expected_ports:
                assert port.expected is True, f"Port {port.port} should be expected"
            else:
                # Port 8080 was not expected
                assert port.expected is False, f"Port {port.port} should not be expected"

    def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report):
        """Test that certificate and TLS data are correctly mapped."""
        service = ScanService(test_db)

        scan_id = service.trigger_scan(sample_config_file)
        service._save_scan_to_db(sample_scan_report, scan_id)

        # Find HTTPS service
        https_service = test_db.query(ScanServiceModel).filter(
            ScanServiceModel.scan_id == scan_id,
            ScanServiceModel.service_name == 'https'
        ).first()

        assert https_service is not None

        # Verify certificate created
        assert len(https_service.certificates) == 1
        cert = https_service.certificates[0]

        assert cert.subject == 'CN=example.com'
        assert cert.issuer == "CN=Let's Encrypt Authority"
        assert cert.days_until_expiry == 365
        assert cert.is_self_signed is False

        # Verify SANs; cert.sans is decoded as JSON before membership checks
        sans = json.loads(cert.sans)
        assert 'example.com' in sans
        assert 'www.example.com' in sans

        # Verify TLS versions
        assert len(cert.tls_versions) == 2

        tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None)
        assert tls_12 is not None
        assert tls_12.supported is True

        tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None)
        assert tls_13 is not None
        assert tls_13.supported is True

    def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report):
        """Test retrieving scan with all nested relationships.

        Walks the mapping returned by ``get_scan`` from sites through IPs,
        ports and services down to certificates and TLS versions.
        """
        service = ScanService(test_db)

        scan_id = service.trigger_scan(sample_config_file)
        service._save_scan_to_db(sample_scan_report, scan_id)

        # Get full scan details
        result = service.get_scan(scan_id)

        assert result is not None
        assert len(result['sites']) == 1

        site = result['sites'][0]
        assert site['name'] == 'Test Site'
        assert len(site['ips']) == 1

        ip = site['ips'][0]
        assert ip['address'] == '192.168.1.10'
        assert len(ip['ports']) == 5  # 4 TCP + 1 UDP

        # Find HTTPS port
        https_port = next(
            (p for p in ip['ports'] if p['port'] == 443), None
        )
        assert https_port is not None
        assert len(https_port['services']) == 1

        service_data = https_port['services'][0]
        assert service_data['service_name'] == 'https'
        assert 'certificates' in service_data
        assert len(service_data['certificates']) == 1

        cert = service_data['certificates'][0]
        assert cert['subject'] == 'CN=example.com'
        assert 'tls_versions' in cert
        assert len(cert['tls_versions']) == 2
|