Phase 2 Step 1: Implement database and service layer

Complete the foundation for Phase 2 by implementing the service layer,
utilities, and comprehensive test suite. This establishes the core
business logic for scan management.

Service Layer:
- Add ScanService class with complete scan lifecycle management
  * trigger_scan() - Create scan record and prepare for execution
  * get_scan() - Retrieve scan with all related data (eager loading)
  * list_scans() - Paginated scan list with status filtering
  * delete_scan() - Remove scan from DB and delete all files
  * get_scan_status() - Poll current scan status and progress
  * _save_scan_to_db() - Persist scan results to database
  * _map_report_to_models() - Complex JSON-to-DB mapping logic

Database Mapping:
- Comprehensive mapping from scanner JSON output to normalized schema
- Handles nested relationships: sites → IPs → ports → services → certs → TLS
- Processes both TCP and UDP ports with expected/actual tracking
- Maps service detection results with HTTP/HTTPS information
- Stores SSL/TLS certificates with expiration tracking
- Records TLS version support and cipher suites
- Links screenshots to services

Utilities:
- Add pagination.py with PaginatedResult class
  * paginate() function for SQLAlchemy queries
  * validate_page_params() for input sanitization
  * Metadata: total, pages, has_prev, has_next, etc.

- Add validators.py with comprehensive validation functions
  * validate_config_file() - YAML structure and required fields
  * validate_scan_status() - Enum validation (running/completed/failed)
  * validate_scan_id() - Positive integer validation
  * validate_port() - Port range validation (1-65535)
  * validate_ip_address() - Basic IPv4 format validation
  * sanitize_filename() - Path traversal prevention

Database Migration:
- Add migration 002 for scan status index
- Optimizes queries filtering by scan status
- Timestamp index already exists from migration 001

Testing:
- Add pytest infrastructure with conftest.py
  * test_db fixture - Temporary SQLite database per test
  * sample_scan_report fixture - Realistic scanner output
  * sample_config_file fixture - Valid YAML config
  * sample_invalid_config_file fixture - For validation tests

- Add comprehensive test_scan_service.py (15 tests)
  * Test scan trigger with valid/invalid configs
  * Test scan retrieval (found/not found cases)
  * Test scan listing with pagination and filtering
  * Test scan deletion with cascade cleanup
  * Test scan status retrieval
  * Test database mapping from JSON to models
  * Test expected vs actual port flagging
  * Test certificate and TLS data mapping
  * Test full scan retrieval with all relationships
  * All tests passing

Files Added:
- web/services/__init__.py
- web/services/scan_service.py (545 lines)
- web/utils/pagination.py (153 lines)
- web/utils/validators.py (245 lines)
- migrations/versions/002_add_scan_indexes.py
- tests/__init__.py
- tests/conftest.py (142 lines)
- tests/test_scan_service.py (374 lines)

Next Steps (Step 2):
- Implement scan API endpoints in web/api/scans.py
- Add authentication decorators
- Integrate ScanService with API routes
- Test API endpoints with integration tests

Phase 2 Step 1 Complete ✓
This commit is contained in:
2025-11-14 00:26:06 -06:00
parent 9255233a74
commit d7c68a2be8
8 changed files with 1668 additions and 0 deletions

View File

@@ -0,0 +1,28 @@
"""Add indexes for scan queries
Revision ID: 002
Revises: 001
Create Date: 2025-11-14 00:30:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '002'
down_revision = '001'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add database indexes for better query performance."""
# Add index on scans.status for filtering
# Note: index on scans.timestamp already exists from migration 001
op.create_index('ix_scans_status', 'scans', ['status'], unique=False)
def downgrade() -> None:
"""Remove indexes."""
op.drop_index('ix_scans_status', table_name='scans')

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Test package for SneakyScanner."""

196
tests/conftest.py Normal file
View File

@@ -0,0 +1,196 @@
"""
Pytest configuration and fixtures for SneakyScanner tests.
"""
import os
import tempfile
from pathlib import Path
import pytest
import yaml
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from web.models import Base
@pytest.fixture(scope='function')
def test_db():
"""
Create a temporary test database.
Yields a SQLAlchemy session for testing, then cleans up.
"""
# Create temporary database file
db_fd, db_path = tempfile.mkstemp(suffix='.db')
# Create engine and session
engine = create_engine(f'sqlite:///{db_path}', echo=False)
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()
yield session
# Cleanup
session.close()
os.close(db_fd)
os.unlink(db_path)
@pytest.fixture
def sample_scan_report():
"""
Sample scan report matching the structure from scanner.py.
Returns a dictionary representing a typical scan output.
"""
return {
'title': 'Test Scan',
'scan_time': '2025-11-14T10:30:00Z',
'scan_duration': 125.5,
'config_file': '/app/configs/test.yaml',
'sites': [
{
'name': 'Test Site',
'ips': [
{
'address': '192.168.1.10',
'expected': {
'ping': True,
'tcp_ports': [22, 80, 443],
'udp_ports': [53],
'services': []
},
'actual': {
'ping': True,
'tcp_ports': [22, 80, 443, 8080],
'udp_ports': [53],
'services': [
{
'port': 22,
'service': 'ssh',
'product': 'OpenSSH',
'version': '8.9p1',
'extrainfo': 'Ubuntu',
'ostype': 'Linux'
},
{
'port': 443,
'service': 'https',
'product': 'nginx',
'version': '1.24.0',
'extrainfo': '',
'ostype': '',
'http_info': {
'protocol': 'https',
'screenshot': 'screenshots/192_168_1_10_443.png',
'certificate': {
'subject': 'CN=example.com',
'issuer': 'CN=Let\'s Encrypt Authority',
'serial_number': '123456789',
'not_valid_before': '2025-01-01T00:00:00Z',
'not_valid_after': '2025-12-31T23:59:59Z',
'days_until_expiry': 365,
'sans': ['example.com', 'www.example.com'],
'is_self_signed': False,
'tls_versions': {
'TLS 1.2': {
'supported': True,
'cipher_suites': [
'TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384',
'TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256'
]
},
'TLS 1.3': {
'supported': True,
'cipher_suites': [
'TLS_AES_256_GCM_SHA384',
'TLS_AES_128_GCM_SHA256'
]
}
}
}
}
},
{
'port': 80,
'service': 'http',
'product': 'nginx',
'version': '1.24.0',
'extrainfo': '',
'ostype': '',
'http_info': {
'protocol': 'http',
'screenshot': 'screenshots/192_168_1_10_80.png'
}
},
{
'port': 8080,
'service': 'http',
'product': 'Jetty',
'version': '9.4.48',
'extrainfo': '',
'ostype': ''
}
]
}
}
]
}
]
}
@pytest.fixture
def sample_config_file(tmp_path):
"""
Create a sample YAML config file for testing.
Args:
tmp_path: pytest temporary directory fixture
Returns:
Path to created config file
"""
config_data = {
'title': 'Test Scan',
'sites': [
{
'name': 'Test Site',
'ips': [
{
'address': '192.168.1.10',
'expected': {
'ping': True,
'tcp_ports': [22, 80, 443],
'udp_ports': [53],
'services': ['ssh', 'http', 'https']
}
}
]
}
]
}
config_file = tmp_path / 'test_config.yaml'
with open(config_file, 'w') as f:
yaml.dump(config_data, f)
return str(config_file)
@pytest.fixture
def sample_invalid_config_file(tmp_path):
"""
Create an invalid config file for testing validation.
Returns:
Path to invalid config file
"""
config_file = tmp_path / 'invalid_config.yaml'
with open(config_file, 'w') as f:
f.write("invalid: yaml: content: [missing closing bracket")
return str(config_file)

402
tests/test_scan_service.py Normal file
View File

@@ -0,0 +1,402 @@
"""
Unit tests for ScanService class.
Tests scan lifecycle operations: trigger, get, list, delete, and database mapping.
"""
import pytest
from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel
from web.services.scan_service import ScanService
class TestScanServiceTrigger:
"""Tests for triggering scans."""
def test_trigger_scan_valid_config(self, test_db, sample_config_file):
"""Test triggering a scan with valid config file."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')
# Verify scan created
assert scan_id is not None
assert isinstance(scan_id, int)
# Verify scan in database
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
assert scan is not None
assert scan.status == 'running'
assert scan.title == 'Test Scan'
assert scan.triggered_by == 'manual'
assert scan.config_file == sample_config_file
def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file):
"""Test triggering a scan with invalid config file."""
service = ScanService(test_db)
with pytest.raises(ValueError, match="Invalid config file"):
service.trigger_scan(sample_invalid_config_file)
def test_trigger_scan_nonexistent_file(self, test_db):
"""Test triggering a scan with nonexistent config file."""
service = ScanService(test_db)
with pytest.raises(ValueError, match="does not exist"):
service.trigger_scan('/nonexistent/config.yaml')
def test_trigger_scan_with_schedule(self, test_db, sample_config_file):
"""Test triggering a scan via schedule."""
service = ScanService(test_db)
scan_id = service.trigger_scan(
sample_config_file,
triggered_by='scheduled',
schedule_id=42
)
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
assert scan.triggered_by == 'scheduled'
assert scan.schedule_id == 42
class TestScanServiceGet:
"""Tests for retrieving scans."""
def test_get_scan_not_found(self, test_db):
"""Test getting a nonexistent scan."""
service = ScanService(test_db)
result = service.get_scan(999)
assert result is None
def test_get_scan_found(self, test_db, sample_config_file):
"""Test getting an existing scan."""
service = ScanService(test_db)
# Create a scan
scan_id = service.trigger_scan(sample_config_file)
# Retrieve it
result = service.get_scan(scan_id)
assert result is not None
assert result['id'] == scan_id
assert result['title'] == 'Test Scan'
assert result['status'] == 'running'
assert 'sites' in result
class TestScanServiceList:
"""Tests for listing scans."""
def test_list_scans_empty(self, test_db):
"""Test listing scans when database is empty."""
service = ScanService(test_db)
result = service.list_scans(page=1, per_page=20)
assert result.total == 0
assert len(result.items) == 0
assert result.pages == 0
def test_list_scans_with_data(self, test_db, sample_config_file):
"""Test listing scans with multiple scans."""
service = ScanService(test_db)
# Create 3 scans
for i in range(3):
service.trigger_scan(sample_config_file, triggered_by='api')
# List all scans
result = service.list_scans(page=1, per_page=20)
assert result.total == 3
assert len(result.items) == 3
assert result.pages == 1
def test_list_scans_pagination(self, test_db, sample_config_file):
"""Test pagination."""
service = ScanService(test_db)
# Create 5 scans
for i in range(5):
service.trigger_scan(sample_config_file)
# Get page 1 (2 items per page)
result = service.list_scans(page=1, per_page=2)
assert len(result.items) == 2
assert result.total == 5
assert result.pages == 3
assert result.has_next is True
# Get page 2
result = service.list_scans(page=2, per_page=2)
assert len(result.items) == 2
assert result.has_prev is True
assert result.has_next is True
# Get page 3 (last page)
result = service.list_scans(page=3, per_page=2)
assert len(result.items) == 1
assert result.has_next is False
def test_list_scans_filter_by_status(self, test_db, sample_config_file):
"""Test filtering scans by status."""
service = ScanService(test_db)
# Create scans with different statuses
scan_id_1 = service.trigger_scan(sample_config_file)
scan_id_2 = service.trigger_scan(sample_config_file)
# Mark one as completed
scan = test_db.query(Scan).filter(Scan.id == scan_id_1).first()
scan.status = 'completed'
test_db.commit()
# Filter by running
result = service.list_scans(status_filter='running')
assert result.total == 1
# Filter by completed
result = service.list_scans(status_filter='completed')
assert result.total == 1
def test_list_scans_invalid_status_filter(self, test_db):
"""Test filtering with invalid status."""
service = ScanService(test_db)
with pytest.raises(ValueError, match="Invalid status"):
service.list_scans(status_filter='invalid_status')
class TestScanServiceDelete:
"""Tests for deleting scans."""
def test_delete_scan_not_found(self, test_db):
"""Test deleting a nonexistent scan."""
service = ScanService(test_db)
with pytest.raises(ValueError, match="not found"):
service.delete_scan(999)
def test_delete_scan_success(self, test_db, sample_config_file):
"""Test successful scan deletion."""
service = ScanService(test_db)
# Create a scan
scan_id = service.trigger_scan(sample_config_file)
# Verify it exists
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is not None
# Delete it
result = service.delete_scan(scan_id)
assert result is True
# Verify it's gone
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is None
class TestScanServiceStatus:
"""Tests for scan status retrieval."""
def test_get_scan_status_not_found(self, test_db):
"""Test getting status of nonexistent scan."""
service = ScanService(test_db)
result = service.get_scan_status(999)
assert result is None
def test_get_scan_status_running(self, test_db, sample_config_file):
"""Test getting status of running scan."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file)
status = service.get_scan_status(scan_id)
assert status is not None
assert status['scan_id'] == scan_id
assert status['status'] == 'running'
assert status['progress'] == 'In progress'
assert status['title'] == 'Test Scan'
def test_get_scan_status_completed(self, test_db, sample_config_file):
"""Test getting status of completed scan."""
service = ScanService(test_db)
# Create and mark as completed
scan_id = service.trigger_scan(sample_config_file)
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
scan.status = 'completed'
scan.duration = 125.5
test_db.commit()
status = service.get_scan_status(scan_id)
assert status['status'] == 'completed'
assert status['progress'] == 'Complete'
assert status['duration'] == 125.5
class TestScanServiceDatabaseMapping:
"""Tests for mapping scan reports to database models."""
def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report):
"""Test saving a complete scan report to database."""
service = ScanService(test_db)
# Create a scan
scan_id = service.trigger_scan(sample_config_file)
# Save report to database
service._save_scan_to_db(sample_scan_report, scan_id, status='completed')
# Verify scan updated
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
assert scan.status == 'completed'
assert scan.duration == 125.5
# Verify sites created
sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all()
assert len(sites) == 1
assert sites[0].site_name == 'Test Site'
# Verify IPs created
ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all()
assert len(ips) == 1
assert ips[0].ip_address == '192.168.1.10'
assert ips[0].ping_expected is True
assert ips[0].ping_actual is True
# Verify ports created (TCP: 22, 80, 443, 8080 | UDP: 53)
ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all()
assert len(ports) == 5 # 4 TCP + 1 UDP
# Verify TCP ports
tcp_ports = [p for p in ports if p.protocol == 'tcp']
assert len(tcp_ports) == 4
tcp_port_numbers = sorted([p.port for p in tcp_ports])
assert tcp_port_numbers == [22, 80, 443, 8080]
# Verify UDP ports
udp_ports = [p for p in ports if p.protocol == 'udp']
assert len(udp_ports) == 1
assert udp_ports[0].port == 53
# Verify services created
services = test_db.query(ScanServiceModel).filter(
ScanServiceModel.scan_id == scan_id
).all()
assert len(services) == 4 # SSH, HTTP (80), HTTPS, HTTP (8080)
# Find HTTPS service
https_service = next(
(s for s in services if s.service_name == 'https'), None
)
assert https_service is not None
assert https_service.product == 'nginx'
assert https_service.version == '1.24.0'
assert https_service.http_protocol == 'https'
assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png'
def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report):
"""Test that expected vs actual ports are correctly flagged."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file)
service._save_scan_to_db(sample_scan_report, scan_id)
# Check TCP ports
tcp_ports = test_db.query(ScanPort).filter(
ScanPort.scan_id == scan_id,
ScanPort.protocol == 'tcp'
).all()
# Ports 22, 80, 443 were expected
expected_ports = {22, 80, 443}
for port in tcp_ports:
if port.port in expected_ports:
assert port.expected is True, f"Port {port.port} should be expected"
else:
# Port 8080 was not expected
assert port.expected is False, f"Port {port.port} should not be expected"
def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report):
"""Test that certificate and TLS data are correctly mapped."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file)
service._save_scan_to_db(sample_scan_report, scan_id)
# Find HTTPS service
https_service = test_db.query(ScanServiceModel).filter(
ScanServiceModel.scan_id == scan_id,
ScanServiceModel.service_name == 'https'
).first()
assert https_service is not None
# Verify certificate created
assert len(https_service.certificates) == 1
cert = https_service.certificates[0]
assert cert.subject == 'CN=example.com'
assert cert.issuer == "CN=Let's Encrypt Authority"
assert cert.days_until_expiry == 365
assert cert.is_self_signed is False
# Verify SANs
import json
sans = json.loads(cert.sans)
assert 'example.com' in sans
assert 'www.example.com' in sans
# Verify TLS versions
assert len(cert.tls_versions) == 2
tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None)
assert tls_12 is not None
assert tls_12.supported is True
tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None)
assert tls_13 is not None
assert tls_13.supported is True
def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report):
"""Test retrieving scan with all nested relationships."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file)
service._save_scan_to_db(sample_scan_report, scan_id)
# Get full scan details
result = service.get_scan(scan_id)
assert result is not None
assert len(result['sites']) == 1
site = result['sites'][0]
assert site['name'] == 'Test Site'
assert len(site['ips']) == 1
ip = site['ips'][0]
assert ip['address'] == '192.168.1.10'
assert len(ip['ports']) == 5 # 4 TCP + 1 UDP
# Find HTTPS port
https_port = next(
(p for p in ip['ports'] if p['port'] == 443), None
)
assert https_port is not None
assert len(https_port['services']) == 1
service_data = https_port['services'][0]
assert service_data['service_name'] == 'https'
assert 'certificates' in service_data
assert len(service_data['certificates']) == 1
cert = service_data['certificates'][0]
assert cert['subject'] == 'CN=example.com'
assert 'tls_versions' in cert
assert len(cert['tls_versions']) == 2

10
web/services/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
"""
Services package for SneakyScanner web application.
This package contains business logic layer services that orchestrate
operations between API endpoints and database models.
"""
from web.services.scan_service import ScanService
__all__ = ['ScanService']

View File

@@ -0,0 +1,589 @@
"""
Scan service for managing scan operations and database integration.
This service handles the business logic for triggering scans, retrieving
scan results, and mapping scanner output to database models.
"""
import json
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session, joinedload
from web.models import (
Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel,
ScanCertificate, ScanTLSVersion
)
from web.utils.pagination import paginate, PaginatedResult
from web.utils.validators import validate_config_file, validate_scan_status
logger = logging.getLogger(__name__)
class ScanService:
"""
Service for managing scan operations.
Handles scan lifecycle: triggering, status tracking, result storage,
and cleanup.
"""
def __init__(self, db_session: Session):
"""
Initialize scan service.
Args:
db_session: SQLAlchemy database session
"""
self.db = db_session
def trigger_scan(self, config_file: str, triggered_by: str = 'manual',
schedule_id: Optional[int] = None) -> int:
"""
Trigger a new scan.
Creates a Scan record in the database with status='running' and
queues the scan for background execution.
Args:
config_file: Path to YAML configuration file
triggered_by: Source that triggered scan (manual, scheduled, api)
schedule_id: Optional schedule ID if triggered by schedule
Returns:
Scan ID of the created scan
Raises:
ValueError: If config file is invalid
"""
# Validate config file
is_valid, error_msg = validate_config_file(config_file)
if not is_valid:
raise ValueError(f"Invalid config file: {error_msg}")
# Load config to get title
import yaml
with open(config_file, 'r') as f:
config = yaml.safe_load(f)
# Create scan record
scan = Scan(
timestamp=datetime.utcnow(),
status='running',
config_file=config_file,
title=config.get('title', 'Untitled Scan'),
triggered_by=triggered_by,
schedule_id=schedule_id,
created_at=datetime.utcnow()
)
self.db.add(scan)
self.db.commit()
self.db.refresh(scan)
logger.info(f"Scan {scan.id} triggered via {triggered_by}")
# NOTE: Background job queuing will be implemented in Step 3
# For now, just return the scan ID
return scan.id
def get_scan(self, scan_id: int) -> Optional[Dict[str, Any]]:
"""
Get scan details with all related data.
Args:
scan_id: Scan ID to retrieve
Returns:
Dictionary with scan data including sites, IPs, ports, services, etc.
Returns None if scan not found.
"""
# Query with eager loading of all relationships
scan = (
self.db.query(Scan)
.options(
joinedload(Scan.sites).joinedload(ScanSite.ips).joinedload(ScanIP.ports),
joinedload(Scan.ports).joinedload(ScanPort.services),
joinedload(Scan.services).joinedload(ScanServiceModel.certificates),
joinedload(Scan.certificates).joinedload(ScanCertificate.tls_versions)
)
.filter(Scan.id == scan_id)
.first()
)
if not scan:
return None
# Convert to dictionary
return self._scan_to_dict(scan)
def list_scans(self, page: int = 1, per_page: int = 20,
status_filter: Optional[str] = None) -> PaginatedResult:
"""
List scans with pagination and optional filtering.
Args:
page: Page number (1-indexed)
per_page: Items per page
status_filter: Optional filter by status (running, completed, failed)
Returns:
PaginatedResult with scan list and metadata
"""
# Build query
query = self.db.query(Scan).order_by(Scan.timestamp.desc())
# Apply status filter if provided
if status_filter:
is_valid, error_msg = validate_scan_status(status_filter)
if not is_valid:
raise ValueError(error_msg)
query = query.filter(Scan.status == status_filter)
# Paginate
result = paginate(query, page=page, per_page=per_page)
# Convert scans to dictionaries (summary only, not full details)
result.items = [self._scan_to_summary_dict(scan) for scan in result.items]
return result
def delete_scan(self, scan_id: int) -> bool:
"""
Delete a scan and all associated files.
Removes:
- Database record (cascade deletes related records)
- JSON report file
- HTML report file
- ZIP archive file
- Screenshot directory
Args:
scan_id: Scan ID to delete
Returns:
True if deleted successfully
Raises:
ValueError: If scan not found
"""
scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
raise ValueError(f"Scan {scan_id} not found")
logger.info(f"Deleting scan {scan_id}")
# Delete files (handle missing files gracefully)
files_to_delete = [
scan.json_path,
scan.html_path,
scan.zip_path
]
for file_path in files_to_delete:
if file_path:
try:
Path(file_path).unlink()
logger.debug(f"Deleted file: {file_path}")
except FileNotFoundError:
logger.warning(f"File not found (already deleted?): {file_path}")
except Exception as e:
logger.error(f"Error deleting file {file_path}: {e}")
# Delete screenshot directory
if scan.screenshot_dir:
try:
shutil.rmtree(scan.screenshot_dir)
logger.debug(f"Deleted directory: {scan.screenshot_dir}")
except FileNotFoundError:
logger.warning(f"Directory not found (already deleted?): {scan.screenshot_dir}")
except Exception as e:
logger.error(f"Error deleting directory {scan.screenshot_dir}: {e}")
# Delete database record (cascade handles related records)
self.db.delete(scan)
self.db.commit()
logger.info(f"Scan {scan_id} deleted successfully")
return True
def get_scan_status(self, scan_id: int) -> Optional[Dict[str, Any]]:
"""
Get current scan status and progress.
Args:
scan_id: Scan ID
Returns:
Dictionary with status information, or None if scan not found
"""
scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
return None
status_info = {
'scan_id': scan.id,
'status': scan.status,
'title': scan.title,
'started_at': scan.timestamp.isoformat() if scan.timestamp else None,
'duration': scan.duration,
'triggered_by': scan.triggered_by
}
# Add progress estimate based on status
if scan.status == 'running':
status_info['progress'] = 'In progress'
elif scan.status == 'completed':
status_info['progress'] = 'Complete'
elif scan.status == 'failed':
status_info['progress'] = 'Failed'
return status_info
def _save_scan_to_db(self, report: Dict[str, Any], scan_id: int,
status: str = 'completed') -> None:
"""
Save scan results to database.
Updates the Scan record and creates all related records
(sites, IPs, ports, services, certificates, TLS versions).
Args:
report: Scan report dictionary from scanner
scan_id: Scan ID to update
status: Final scan status (completed or failed)
"""
scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
raise ValueError(f"Scan {scan_id} not found")
# Update scan record
scan.status = status
scan.duration = report.get('scan_duration')
# Map report data to database models
self._map_report_to_models(report, scan)
self.db.commit()
logger.info(f"Scan {scan_id} saved to database with status '{status}'")
def _map_report_to_models(self, report: Dict[str, Any], scan_obj: Scan) -> None:
"""
Map JSON report structure to database models.
Creates records for sites, IPs, ports, services, certificates, and TLS versions.
Processes nested relationships in order to handle foreign keys correctly.
Args:
report: Scan report dictionary
scan_obj: Scan database object to attach records to
"""
logger.debug(f"Mapping report to database models for scan {scan_obj.id}")
# Process each site
for site_data in report.get('sites', []):
# Create ScanSite record
site = ScanSite(
scan_id=scan_obj.id,
site_name=site_data['name']
)
self.db.add(site)
self.db.flush() # Get site.id for foreign key
# Process each IP in this site
for ip_data in site_data.get('ips', []):
# Create ScanIP record
ip = ScanIP(
scan_id=scan_obj.id,
site_id=site.id,
ip_address=ip_data['address'],
ping_expected=ip_data.get('expected', {}).get('ping'),
ping_actual=ip_data.get('actual', {}).get('ping')
)
self.db.add(ip)
self.db.flush()
# Process TCP ports
expected_tcp = set(ip_data.get('expected', {}).get('tcp_ports', []))
actual_tcp = ip_data.get('actual', {}).get('tcp_ports', [])
for port_num in actual_tcp:
port = ScanPort(
scan_id=scan_obj.id,
ip_id=ip.id,
port=port_num,
protocol='tcp',
expected=(port_num in expected_tcp),
state='open'
)
self.db.add(port)
self.db.flush()
# Find service for this port
service_data = self._find_service_for_port(
ip_data.get('actual', {}).get('services', []),
port_num
)
if service_data:
# Create ScanService record
service = ScanServiceModel(
scan_id=scan_obj.id,
port_id=port.id,
service_name=service_data.get('service'),
product=service_data.get('product'),
version=service_data.get('version'),
extrainfo=service_data.get('extrainfo'),
ostype=service_data.get('ostype'),
http_protocol=service_data.get('http_info', {}).get('protocol'),
screenshot_path=service_data.get('http_info', {}).get('screenshot')
)
self.db.add(service)
self.db.flush()
# Process certificate and TLS info if present
http_info = service_data.get('http_info', {})
if http_info.get('certificate'):
self._process_certificate(
http_info['certificate'],
scan_obj.id,
service.id
)
# Process UDP ports
expected_udp = set(ip_data.get('expected', {}).get('udp_ports', []))
actual_udp = ip_data.get('actual', {}).get('udp_ports', [])
for port_num in actual_udp:
port = ScanPort(
scan_id=scan_obj.id,
ip_id=ip.id,
port=port_num,
protocol='udp',
expected=(port_num in expected_udp),
state='open'
)
self.db.add(port)
logger.debug(f"Report mapping complete for scan {scan_obj.id}")
def _find_service_for_port(self, services: List[Dict], port: int) -> Optional[Dict]:
"""
Find service data for a specific port.
Args:
services: List of service dictionaries
port: Port number to find
Returns:
Service dictionary if found, None otherwise
"""
for service in services:
if service.get('port') == port:
return service
return None
def _process_certificate(self, cert_data: Dict[str, Any], scan_id: int,
service_id: int) -> None:
"""
Process certificate and TLS version data.
Args:
cert_data: Certificate data dictionary
scan_id: Scan ID
service_id: Service ID
"""
# Create ScanCertificate record
cert = ScanCertificate(
scan_id=scan_id,
service_id=service_id,
subject=cert_data.get('subject'),
issuer=cert_data.get('issuer'),
serial_number=cert_data.get('serial_number'),
not_valid_before=self._parse_datetime(cert_data.get('not_valid_before')),
not_valid_after=self._parse_datetime(cert_data.get('not_valid_after')),
days_until_expiry=cert_data.get('days_until_expiry'),
sans=json.dumps(cert_data.get('sans', [])),
is_self_signed=cert_data.get('is_self_signed', False)
)
self.db.add(cert)
self.db.flush()
# Process TLS versions
tls_versions = cert_data.get('tls_versions', {})
for version, version_data in tls_versions.items():
tls = ScanTLSVersion(
scan_id=scan_id,
certificate_id=cert.id,
tls_version=version,
supported=version_data.get('supported', False),
cipher_suites=json.dumps(version_data.get('cipher_suites', []))
)
self.db.add(tls)
def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]:
"""
Parse ISO datetime string.
Args:
date_str: ISO format datetime string
Returns:
datetime object or None if parsing fails
"""
if not date_str:
return None
try:
# Handle ISO format with 'Z' suffix
if date_str.endswith('Z'):
date_str = date_str[:-1] + '+00:00'
return datetime.fromisoformat(date_str)
except (ValueError, AttributeError) as e:
logger.warning(f"Failed to parse datetime '{date_str}': {e}")
return None
def _scan_to_dict(self, scan: Scan) -> Dict[str, Any]:
"""
Convert Scan object to dictionary with full details.
Args:
scan: Scan database object
Returns:
Dictionary representation with all related data
"""
return {
'id': scan.id,
'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
'duration': scan.duration,
'status': scan.status,
'title': scan.title,
'config_file': scan.config_file,
'json_path': scan.json_path,
'html_path': scan.html_path,
'zip_path': scan.zip_path,
'screenshot_dir': scan.screenshot_dir,
'triggered_by': scan.triggered_by,
'created_at': scan.created_at.isoformat() if scan.created_at else None,
'sites': [self._site_to_dict(site) for site in scan.sites]
}
def _scan_to_summary_dict(self, scan: Scan) -> Dict[str, Any]:
"""
Convert Scan object to summary dictionary (no related data).
Args:
scan: Scan database object
Returns:
Summary dictionary
"""
return {
'id': scan.id,
'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
'duration': scan.duration,
'status': scan.status,
'title': scan.title,
'triggered_by': scan.triggered_by,
'created_at': scan.created_at.isoformat() if scan.created_at else None
}
def _site_to_dict(self, site: ScanSite) -> Dict[str, Any]:
"""Convert ScanSite to dictionary."""
return {
'id': site.id,
'name': site.site_name,
'ips': [self._ip_to_dict(ip) for ip in site.ips]
}
def _ip_to_dict(self, ip: ScanIP) -> Dict[str, Any]:
"""Convert ScanIP to dictionary."""
return {
'id': ip.id,
'address': ip.ip_address,
'ping_expected': ip.ping_expected,
'ping_actual': ip.ping_actual,
'ports': [self._port_to_dict(port) for port in ip.ports]
}
def _port_to_dict(self, port: ScanPort) -> Dict[str, Any]:
"""Convert ScanPort to dictionary."""
return {
'id': port.id,
'port': port.port,
'protocol': port.protocol,
'state': port.state,
'expected': port.expected,
'services': [self._service_to_dict(svc) for svc in port.services]
}
def _service_to_dict(self, service: ScanServiceModel) -> Dict[str, Any]:
"""Convert ScanService to dictionary."""
result = {
'id': service.id,
'service_name': service.service_name,
'product': service.product,
'version': service.version,
'extrainfo': service.extrainfo,
'ostype': service.ostype,
'http_protocol': service.http_protocol,
'screenshot_path': service.screenshot_path
}
# Add certificate info if present
if service.certificates:
result['certificates'] = [
self._certificate_to_dict(cert) for cert in service.certificates
]
return result
def _certificate_to_dict(self, cert: ScanCertificate) -> Dict[str, Any]:
"""Convert ScanCertificate to dictionary."""
result = {
'id': cert.id,
'subject': cert.subject,
'issuer': cert.issuer,
'serial_number': cert.serial_number,
'not_valid_before': cert.not_valid_before.isoformat() if cert.not_valid_before else None,
'not_valid_after': cert.not_valid_after.isoformat() if cert.not_valid_after else None,
'days_until_expiry': cert.days_until_expiry,
'is_self_signed': cert.is_self_signed
}
# Parse SANs from JSON
if cert.sans:
try:
result['sans'] = json.loads(cert.sans)
except json.JSONDecodeError:
result['sans'] = []
# Add TLS versions
result['tls_versions'] = [
self._tls_version_to_dict(tls) for tls in cert.tls_versions
]
return result
def _tls_version_to_dict(self, tls: ScanTLSVersion) -> Dict[str, Any]:
"""Convert ScanTLSVersion to dictionary."""
result = {
'id': tls.id,
'tls_version': tls.tls_version,
'supported': tls.supported
}
# Parse cipher suites from JSON
if tls.cipher_suites:
try:
result['cipher_suites'] = json.loads(tls.cipher_suites)
except json.JSONDecodeError:
result['cipher_suites'] = []
return result

158
web/utils/pagination.py Normal file
View File

@@ -0,0 +1,158 @@
"""
Pagination utilities for SneakyScanner web application.
Provides helper functions for paginating SQLAlchemy queries.
"""
from typing import Any, Dict, List
from sqlalchemy.orm import Query
class PaginatedResult:
"""Container for paginated query results."""
def __init__(self, items: List[Any], total: int, page: int, per_page: int):
"""
Initialize paginated result.
Args:
items: List of items for current page
total: Total number of items across all pages
page: Current page number (1-indexed)
per_page: Number of items per page
"""
self.items = items
self.total = total
self.page = page
self.per_page = per_page
@property
def pages(self) -> int:
"""Calculate total number of pages."""
if self.per_page == 0:
return 0
return (self.total + self.per_page - 1) // self.per_page
@property
def has_prev(self) -> bool:
"""Check if there is a previous page."""
return self.page > 1
@property
def has_next(self) -> bool:
"""Check if there is a next page."""
return self.page < self.pages
@property
def prev_page(self) -> int:
"""Get previous page number."""
return self.page - 1 if self.has_prev else None
@property
def next_page(self) -> int:
"""Get next page number."""
return self.page + 1 if self.has_next else None
def to_dict(self) -> Dict[str, Any]:
"""
Convert to dictionary for API responses.
Returns:
Dictionary with pagination metadata and items
"""
return {
'items': self.items,
'total': self.total,
'page': self.page,
'per_page': self.per_page,
'pages': self.pages,
'has_prev': self.has_prev,
'has_next': self.has_next,
'prev_page': self.prev_page,
'next_page': self.next_page,
}
def paginate(query: Query, page: int = 1, per_page: int = 20,
max_per_page: int = 100) -> PaginatedResult:
"""
Paginate a SQLAlchemy query.
Args:
query: SQLAlchemy query to paginate
page: Page number (1-indexed, default: 1)
per_page: Items per page (default: 20)
max_per_page: Maximum items per page (default: 100)
Returns:
PaginatedResult with items and pagination metadata
Examples:
>>> from web.models import Scan
>>> query = db.query(Scan).order_by(Scan.timestamp.desc())
>>> result = paginate(query, page=1, per_page=20)
>>> scans = result.items
>>> total_pages = result.pages
"""
# Validate and sanitize parameters
page = max(1, page) # Page must be at least 1
per_page = max(1, min(per_page, max_per_page)) # Clamp per_page
# Get total count
total = query.count()
# Calculate offset
offset = (page - 1) * per_page
# Execute query with limit and offset
items = query.limit(per_page).offset(offset).all()
return PaginatedResult(
items=items,
total=total,
page=page,
per_page=per_page
)
def validate_page_params(page: Any, per_page: Any,
max_per_page: int = 100) -> tuple[int, int]:
"""
Validate and sanitize pagination parameters.
Args:
page: Page number (any type, will be converted to int)
per_page: Items per page (any type, will be converted to int)
max_per_page: Maximum items per page (default: 100)
Returns:
Tuple of (validated_page, validated_per_page)
Examples:
>>> validate_page_params('2', '50')
(2, 50)
>>> validate_page_params(-1, 200)
(1, 100)
>>> validate_page_params(None, None)
(1, 20)
"""
# Default values
default_page = 1
default_per_page = 20
# Convert to int, use default if invalid
try:
page = int(page) if page is not None else default_page
except (ValueError, TypeError):
page = default_page
try:
per_page = int(per_page) if per_page is not None else default_per_page
except (ValueError, TypeError):
per_page = default_per_page
# Validate ranges
page = max(1, page)
per_page = max(1, min(per_page, max_per_page))
return page, per_page

284
web/utils/validators.py Normal file
View File

@@ -0,0 +1,284 @@
"""
Input validation utilities for SneakyScanner web application.
Provides validation functions for API inputs, file paths, and data integrity.
"""
import os
from pathlib import Path
from typing import Optional
import yaml
def validate_config_file(file_path: str) -> tuple[bool, Optional[str]]:
"""
Validate that a configuration file exists and is valid YAML.
Args:
file_path: Path to configuration file
Returns:
Tuple of (is_valid, error_message)
If valid, returns (True, None)
If invalid, returns (False, error_message)
Examples:
>>> validate_config_file('/app/configs/example.yaml')
(True, None)
>>> validate_config_file('/nonexistent.yaml')
(False, 'File does not exist: /nonexistent.yaml')
"""
# Check if path is provided
if not file_path:
return False, 'Config file path is required'
# Convert to Path object
path = Path(file_path)
# Check if file exists
if not path.exists():
return False, f'File does not exist: {file_path}'
# Check if it's a file (not directory)
if not path.is_file():
return False, f'Path is not a file: {file_path}'
# Check file extension
if path.suffix.lower() not in ['.yaml', '.yml']:
return False, f'File must be YAML (.yaml or .yml): {file_path}'
# Try to parse as YAML
try:
with open(path, 'r') as f:
config = yaml.safe_load(f)
# Check if it's a dictionary (basic structure validation)
if not isinstance(config, dict):
return False, 'Config file must contain a YAML dictionary'
# Check for required top-level keys
if 'title' not in config:
return False, 'Config file missing required "title" field'
if 'sites' not in config:
return False, 'Config file missing required "sites" field'
# Validate sites structure
if not isinstance(config['sites'], list):
return False, '"sites" must be a list'
if len(config['sites']) == 0:
return False, '"sites" list cannot be empty'
except yaml.YAMLError as e:
return False, f'Invalid YAML syntax: {str(e)}'
except Exception as e:
return False, f'Error reading config file: {str(e)}'
return True, None
def validate_scan_status(status: str) -> tuple[bool, Optional[str]]:
"""
Validate scan status value.
Args:
status: Status string to validate
Returns:
Tuple of (is_valid, error_message)
Examples:
>>> validate_scan_status('running')
(True, None)
>>> validate_scan_status('invalid')
(False, 'Invalid status: invalid. Must be one of: running, completed, failed')
"""
valid_statuses = ['running', 'completed', 'failed']
if status not in valid_statuses:
return False, f'Invalid status: {status}. Must be one of: {", ".join(valid_statuses)}'
return True, None
def validate_triggered_by(triggered_by: str) -> tuple[bool, Optional[str]]:
"""
Validate triggered_by value.
Args:
triggered_by: Source that triggered the scan
Returns:
Tuple of (is_valid, error_message)
Examples:
>>> validate_triggered_by('manual')
(True, None)
>>> validate_triggered_by('api')
(True, None)
"""
valid_sources = ['manual', 'scheduled', 'api']
if triggered_by not in valid_sources:
return False, f'Invalid triggered_by: {triggered_by}. Must be one of: {", ".join(valid_sources)}'
return True, None
def validate_scan_id(scan_id: any) -> tuple[bool, Optional[str]]:
"""
Validate scan ID is a positive integer.
Args:
scan_id: Scan ID to validate
Returns:
Tuple of (is_valid, error_message)
Examples:
>>> validate_scan_id(42)
(True, None)
>>> validate_scan_id('42')
(True, None)
>>> validate_scan_id(-1)
(False, 'Scan ID must be a positive integer')
"""
try:
scan_id_int = int(scan_id)
if scan_id_int <= 0:
return False, 'Scan ID must be a positive integer'
except (ValueError, TypeError):
return False, f'Invalid scan ID: {scan_id}'
return True, None
def validate_file_path(file_path: str, must_exist: bool = True) -> tuple[bool, Optional[str]]:
"""
Validate a file path.
Args:
file_path: Path to validate
must_exist: If True, file must exist. If False, only validate format.
Returns:
Tuple of (is_valid, error_message)
Examples:
>>> validate_file_path('/app/output/scan.json', must_exist=False)
(True, None)
>>> validate_file_path('', must_exist=False)
(False, 'File path is required')
"""
if not file_path:
return False, 'File path is required'
# Check for path traversal attempts
if '..' in file_path:
return False, 'Path traversal not allowed'
if must_exist:
path = Path(file_path)
if not path.exists():
return False, f'File does not exist: {file_path}'
if not path.is_file():
return False, f'Path is not a file: {file_path}'
return True, None
def sanitize_filename(filename: str) -> str:
"""
Sanitize a filename by removing/replacing unsafe characters.
Args:
filename: Original filename
Returns:
Sanitized filename safe for filesystem
Examples:
>>> sanitize_filename('my scan.json')
'my_scan.json'
>>> sanitize_filename('../../etc/passwd')
'etc_passwd'
"""
# Remove path components
filename = os.path.basename(filename)
# Replace unsafe characters with underscore
unsafe_chars = ['/', '\\', '..', ' ', ':', '*', '?', '"', '<', '>', '|']
for char in unsafe_chars:
filename = filename.replace(char, '_')
# Remove leading/trailing underscores and dots
filename = filename.strip('_.')
# Ensure filename is not empty
if not filename:
filename = 'unnamed'
return filename
def validate_port(port: any) -> tuple[bool, Optional[str]]:
"""
Validate port number.
Args:
port: Port number to validate
Returns:
Tuple of (is_valid, error_message)
Examples:
>>> validate_port(443)
(True, None)
>>> validate_port(70000)
(False, 'Port must be between 1 and 65535')
"""
try:
port_int = int(port)
if port_int < 1 or port_int > 65535:
return False, 'Port must be between 1 and 65535'
except (ValueError, TypeError):
return False, f'Invalid port: {port}'
return True, None
def validate_ip_address(ip: str) -> tuple[bool, Optional[str]]:
"""
Validate IPv4 address format (basic validation).
Args:
ip: IP address string
Returns:
Tuple of (is_valid, error_message)
Examples:
>>> validate_ip_address('192.168.1.1')
(True, None)
>>> validate_ip_address('256.1.1.1')
(False, 'Invalid IP address format')
"""
if not ip:
return False, 'IP address is required'
# Basic IPv4 validation
parts = ip.split('.')
if len(parts) != 4:
return False, 'Invalid IP address format'
try:
for part in parts:
num = int(part)
if num < 0 or num > 255:
return False, 'Invalid IP address format'
except (ValueError, TypeError):
return False, 'Invalid IP address format'
return True, None