Phase 2 Step 1: Implement database and service layer
Complete the foundation for Phase 2 by implementing the service layer, utilities, and comprehensive test suite. This establishes the core business logic for scan management. Service Layer: - Add ScanService class with complete scan lifecycle management * trigger_scan() - Create scan record and prepare for execution * get_scan() - Retrieve scan with all related data (eager loading) * list_scans() - Paginated scan list with status filtering * delete_scan() - Remove scan from DB and delete all files * get_scan_status() - Poll current scan status and progress * _save_scan_to_db() - Persist scan results to database * _map_report_to_models() - Complex JSON-to-DB mapping logic Database Mapping: - Comprehensive mapping from scanner JSON output to normalized schema - Handles nested relationships: sites → IPs → ports → services → certs → TLS - Processes both TCP and UDP ports with expected/actual tracking - Maps service detection results with HTTP/HTTPS information - Stores SSL/TLS certificates with expiration tracking - Records TLS version support and cipher suites - Links screenshots to services Utilities: - Add pagination.py with PaginatedResult class * paginate() function for SQLAlchemy queries * validate_page_params() for input sanitization * Metadata: total, pages, has_prev, has_next, etc. - Add validators.py with comprehensive validation functions * validate_config_file() - YAML structure and required fields * validate_scan_status() - Enum validation (running/completed/failed) * validate_scan_id() - Positive integer validation * validate_port() - Port range validation (1-65535) * validate_ip_address() - Basic IPv4 format validation * sanitize_filename() - Path traversal prevention Database Migration: - Add migration 002 for scan status index - Optimizes queries filtering by scan status - Timestamp index already exists from migration 001 Testing: - Add pytest infrastructure with conftest.py * test_db fixture - Temporary SQLite database per test * sample_scan_report fixture - Realistic scanner output * sample_config_file fixture - Valid YAML config * sample_invalid_config_file fixture - For validation tests - Add comprehensive test_scan_service.py (15 tests) * Test scan trigger with valid/invalid configs * Test scan retrieval (found/not found cases) * Test scan listing with pagination and filtering * Test scan deletion with cascade cleanup * Test scan status retrieval * Test database mapping from JSON to models * Test expected vs actual port flagging * Test certificate and TLS data mapping * Test full scan retrieval with all relationships * All tests passing Files Added: - web/services/__init__.py - web/services/scan_service.py (545 lines) - web/utils/pagination.py (153 lines) - web/utils/validators.py (245 lines) - migrations/versions/002_add_scan_indexes.py - tests/__init__.py - tests/conftest.py (142 lines) - tests/test_scan_service.py (374 lines) Next Steps (Step 2): - Implement scan API endpoints in web/api/scans.py - Add authentication decorators - Integrate ScanService with API routes - Test API endpoints with integration tests Phase 2 Step 1 Complete ✓
This commit is contained in:
402
tests/test_scan_service.py
Normal file
402
tests/test_scan_service.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""
|
||||
Unit tests for ScanService class.
|
||||
|
||||
Tests scan lifecycle operations: trigger, get, list, delete, and database mapping.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel
|
||||
from web.services.scan_service import ScanService
|
||||
|
||||
|
||||
class TestScanServiceTrigger:
|
||||
"""Tests for triggering scans."""
|
||||
|
||||
def test_trigger_scan_valid_config(self, test_db, sample_config_file):
|
||||
"""Test triggering a scan with valid config file."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')
|
||||
|
||||
# Verify scan created
|
||||
assert scan_id is not None
|
||||
assert isinstance(scan_id, int)
|
||||
|
||||
# Verify scan in database
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
assert scan is not None
|
||||
assert scan.status == 'running'
|
||||
assert scan.title == 'Test Scan'
|
||||
assert scan.triggered_by == 'manual'
|
||||
assert scan.config_file == sample_config_file
|
||||
|
||||
def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file):
|
||||
"""Test triggering a scan with invalid config file."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid config file"):
|
||||
service.trigger_scan(sample_invalid_config_file)
|
||||
|
||||
def test_trigger_scan_nonexistent_file(self, test_db):
|
||||
"""Test triggering a scan with nonexistent config file."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
service.trigger_scan('/nonexistent/config.yaml')
|
||||
|
||||
def test_trigger_scan_with_schedule(self, test_db, sample_config_file):
|
||||
"""Test triggering a scan via schedule."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(
|
||||
sample_config_file,
|
||||
triggered_by='scheduled',
|
||||
schedule_id=42
|
||||
)
|
||||
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
assert scan.triggered_by == 'scheduled'
|
||||
assert scan.schedule_id == 42
|
||||
|
||||
|
||||
class TestScanServiceGet:
|
||||
"""Tests for retrieving scans."""
|
||||
|
||||
def test_get_scan_not_found(self, test_db):
|
||||
"""Test getting a nonexistent scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.get_scan(999)
|
||||
assert result is None
|
||||
|
||||
def test_get_scan_found(self, test_db, sample_config_file):
|
||||
"""Test getting an existing scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create a scan
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
|
||||
# Retrieve it
|
||||
result = service.get_scan(scan_id)
|
||||
|
||||
assert result is not None
|
||||
assert result['id'] == scan_id
|
||||
assert result['title'] == 'Test Scan'
|
||||
assert result['status'] == 'running'
|
||||
assert 'sites' in result
|
||||
|
||||
|
||||
class TestScanServiceList:
|
||||
"""Tests for listing scans."""
|
||||
|
||||
def test_list_scans_empty(self, test_db):
|
||||
"""Test listing scans when database is empty."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.list_scans(page=1, per_page=20)
|
||||
|
||||
assert result.total == 0
|
||||
assert len(result.items) == 0
|
||||
assert result.pages == 0
|
||||
|
||||
def test_list_scans_with_data(self, test_db, sample_config_file):
|
||||
"""Test listing scans with multiple scans."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create 3 scans
|
||||
for i in range(3):
|
||||
service.trigger_scan(sample_config_file, triggered_by='api')
|
||||
|
||||
# List all scans
|
||||
result = service.list_scans(page=1, per_page=20)
|
||||
|
||||
assert result.total == 3
|
||||
assert len(result.items) == 3
|
||||
assert result.pages == 1
|
||||
|
||||
def test_list_scans_pagination(self, test_db, sample_config_file):
|
||||
"""Test pagination."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create 5 scans
|
||||
for i in range(5):
|
||||
service.trigger_scan(sample_config_file)
|
||||
|
||||
# Get page 1 (2 items per page)
|
||||
result = service.list_scans(page=1, per_page=2)
|
||||
assert len(result.items) == 2
|
||||
assert result.total == 5
|
||||
assert result.pages == 3
|
||||
assert result.has_next is True
|
||||
|
||||
# Get page 2
|
||||
result = service.list_scans(page=2, per_page=2)
|
||||
assert len(result.items) == 2
|
||||
assert result.has_prev is True
|
||||
assert result.has_next is True
|
||||
|
||||
# Get page 3 (last page)
|
||||
result = service.list_scans(page=3, per_page=2)
|
||||
assert len(result.items) == 1
|
||||
assert result.has_next is False
|
||||
|
||||
def test_list_scans_filter_by_status(self, test_db, sample_config_file):
|
||||
"""Test filtering scans by status."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create scans with different statuses
|
||||
scan_id_1 = service.trigger_scan(sample_config_file)
|
||||
scan_id_2 = service.trigger_scan(sample_config_file)
|
||||
|
||||
# Mark one as completed
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id_1).first()
|
||||
scan.status = 'completed'
|
||||
test_db.commit()
|
||||
|
||||
# Filter by running
|
||||
result = service.list_scans(status_filter='running')
|
||||
assert result.total == 1
|
||||
|
||||
# Filter by completed
|
||||
result = service.list_scans(status_filter='completed')
|
||||
assert result.total == 1
|
||||
|
||||
def test_list_scans_invalid_status_filter(self, test_db):
|
||||
"""Test filtering with invalid status."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid status"):
|
||||
service.list_scans(status_filter='invalid_status')
|
||||
|
||||
|
||||
class TestScanServiceDelete:
|
||||
"""Tests for deleting scans."""
|
||||
|
||||
def test_delete_scan_not_found(self, test_db):
|
||||
"""Test deleting a nonexistent scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.delete_scan(999)
|
||||
|
||||
def test_delete_scan_success(self, test_db, sample_config_file):
|
||||
"""Test successful scan deletion."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create a scan
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
|
||||
# Verify it exists
|
||||
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is not None
|
||||
|
||||
# Delete it
|
||||
result = service.delete_scan(scan_id)
|
||||
assert result is True
|
||||
|
||||
# Verify it's gone
|
||||
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is None
|
||||
|
||||
|
||||
class TestScanServiceStatus:
|
||||
"""Tests for scan status retrieval."""
|
||||
|
||||
def test_get_scan_status_not_found(self, test_db):
|
||||
"""Test getting status of nonexistent scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.get_scan_status(999)
|
||||
assert result is None
|
||||
|
||||
def test_get_scan_status_running(self, test_db, sample_config_file):
|
||||
"""Test getting status of running scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
status = service.get_scan_status(scan_id)
|
||||
|
||||
assert status is not None
|
||||
assert status['scan_id'] == scan_id
|
||||
assert status['status'] == 'running'
|
||||
assert status['progress'] == 'In progress'
|
||||
assert status['title'] == 'Test Scan'
|
||||
|
||||
def test_get_scan_status_completed(self, test_db, sample_config_file):
|
||||
"""Test getting status of completed scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create and mark as completed
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
scan.status = 'completed'
|
||||
scan.duration = 125.5
|
||||
test_db.commit()
|
||||
|
||||
status = service.get_scan_status(scan_id)
|
||||
|
||||
assert status['status'] == 'completed'
|
||||
assert status['progress'] == 'Complete'
|
||||
assert status['duration'] == 125.5
|
||||
|
||||
|
||||
class TestScanServiceDatabaseMapping:
|
||||
"""Tests for mapping scan reports to database models."""
|
||||
|
||||
def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report):
|
||||
"""Test saving a complete scan report to database."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create a scan
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
|
||||
# Save report to database
|
||||
service._save_scan_to_db(sample_scan_report, scan_id, status='completed')
|
||||
|
||||
# Verify scan updated
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
assert scan.status == 'completed'
|
||||
assert scan.duration == 125.5
|
||||
|
||||
# Verify sites created
|
||||
sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all()
|
||||
assert len(sites) == 1
|
||||
assert sites[0].site_name == 'Test Site'
|
||||
|
||||
# Verify IPs created
|
||||
ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all()
|
||||
assert len(ips) == 1
|
||||
assert ips[0].ip_address == '192.168.1.10'
|
||||
assert ips[0].ping_expected is True
|
||||
assert ips[0].ping_actual is True
|
||||
|
||||
# Verify ports created (TCP: 22, 80, 443, 8080 | UDP: 53)
|
||||
ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all()
|
||||
assert len(ports) == 5 # 4 TCP + 1 UDP
|
||||
|
||||
# Verify TCP ports
|
||||
tcp_ports = [p for p in ports if p.protocol == 'tcp']
|
||||
assert len(tcp_ports) == 4
|
||||
tcp_port_numbers = sorted([p.port for p in tcp_ports])
|
||||
assert tcp_port_numbers == [22, 80, 443, 8080]
|
||||
|
||||
# Verify UDP ports
|
||||
udp_ports = [p for p in ports if p.protocol == 'udp']
|
||||
assert len(udp_ports) == 1
|
||||
assert udp_ports[0].port == 53
|
||||
|
||||
# Verify services created
|
||||
services = test_db.query(ScanServiceModel).filter(
|
||||
ScanServiceModel.scan_id == scan_id
|
||||
).all()
|
||||
assert len(services) == 4 # SSH, HTTP (80), HTTPS, HTTP (8080)
|
||||
|
||||
# Find HTTPS service
|
||||
https_service = next(
|
||||
(s for s in services if s.service_name == 'https'), None
|
||||
)
|
||||
assert https_service is not None
|
||||
assert https_service.product == 'nginx'
|
||||
assert https_service.version == '1.24.0'
|
||||
assert https_service.http_protocol == 'https'
|
||||
assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png'
|
||||
|
||||
def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report):
|
||||
"""Test that expected vs actual ports are correctly flagged."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
service._save_scan_to_db(sample_scan_report, scan_id)
|
||||
|
||||
# Check TCP ports
|
||||
tcp_ports = test_db.query(ScanPort).filter(
|
||||
ScanPort.scan_id == scan_id,
|
||||
ScanPort.protocol == 'tcp'
|
||||
).all()
|
||||
|
||||
# Ports 22, 80, 443 were expected
|
||||
expected_ports = {22, 80, 443}
|
||||
for port in tcp_ports:
|
||||
if port.port in expected_ports:
|
||||
assert port.expected is True, f"Port {port.port} should be expected"
|
||||
else:
|
||||
# Port 8080 was not expected
|
||||
assert port.expected is False, f"Port {port.port} should not be expected"
|
||||
|
||||
def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report):
|
||||
"""Test that certificate and TLS data are correctly mapped."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
service._save_scan_to_db(sample_scan_report, scan_id)
|
||||
|
||||
# Find HTTPS service
|
||||
https_service = test_db.query(ScanServiceModel).filter(
|
||||
ScanServiceModel.scan_id == scan_id,
|
||||
ScanServiceModel.service_name == 'https'
|
||||
).first()
|
||||
|
||||
assert https_service is not None
|
||||
|
||||
# Verify certificate created
|
||||
assert len(https_service.certificates) == 1
|
||||
cert = https_service.certificates[0]
|
||||
|
||||
assert cert.subject == 'CN=example.com'
|
||||
assert cert.issuer == "CN=Let's Encrypt Authority"
|
||||
assert cert.days_until_expiry == 365
|
||||
assert cert.is_self_signed is False
|
||||
|
||||
# Verify SANs
|
||||
import json
|
||||
sans = json.loads(cert.sans)
|
||||
assert 'example.com' in sans
|
||||
assert 'www.example.com' in sans
|
||||
|
||||
# Verify TLS versions
|
||||
assert len(cert.tls_versions) == 2
|
||||
|
||||
tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None)
|
||||
assert tls_12 is not None
|
||||
assert tls_12.supported is True
|
||||
|
||||
tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None)
|
||||
assert tls_13 is not None
|
||||
assert tls_13.supported is True
|
||||
|
||||
def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report):
|
||||
"""Test retrieving scan with all nested relationships."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
service._save_scan_to_db(sample_scan_report, scan_id)
|
||||
|
||||
# Get full scan details
|
||||
result = service.get_scan(scan_id)
|
||||
|
||||
assert result is not None
|
||||
assert len(result['sites']) == 1
|
||||
|
||||
site = result['sites'][0]
|
||||
assert site['name'] == 'Test Site'
|
||||
assert len(site['ips']) == 1
|
||||
|
||||
ip = site['ips'][0]
|
||||
assert ip['address'] == '192.168.1.10'
|
||||
assert len(ip['ports']) == 5 # 4 TCP + 1 UDP
|
||||
|
||||
# Find HTTPS port
|
||||
https_port = next(
|
||||
(p for p in ip['ports'] if p['port'] == 443), None
|
||||
)
|
||||
assert https_port is not None
|
||||
assert len(https_port['services']) == 1
|
||||
|
||||
service_data = https_port['services'][0]
|
||||
assert service_data['service_name'] == 'https'
|
||||
assert 'certificates' in service_data
|
||||
assert len(service_data['certificates']) == 1
|
||||
|
||||
cert = service_data['certificates'][0]
|
||||
assert cert['subject'] == 'CN=example.com'
|
||||
assert 'tls_versions' in cert
|
||||
assert len(cert['tls_versions']) == 2
|
||||
Reference in New Issue
Block a user