Complete the foundation for Phase 2 by implementing the service layer, utilities, and comprehensive test suite. This establishes the core business logic for scan management.

Service Layer:
- Add ScanService class with complete scan lifecycle management
  * trigger_scan() - Create scan record and prepare for execution
  * get_scan() - Retrieve scan with all related data (eager loading)
  * list_scans() - Paginated scan list with status filtering
  * delete_scan() - Remove scan from DB and delete all files
  * get_scan_status() - Poll current scan status and progress
  * _save_scan_to_db() - Persist scan results to database
  * _map_report_to_models() - Complex JSON-to-DB mapping logic

Database Mapping:
- Comprehensive mapping from scanner JSON output to normalized schema
- Handles nested relationships: sites → IPs → ports → services → certs → TLS
- Processes both TCP and UDP ports with expected/actual tracking
- Maps service detection results with HTTP/HTTPS information
- Stores SSL/TLS certificates with expiration tracking
- Records TLS version support and cipher suites
- Links screenshots to services

Utilities:
- Add pagination.py with PaginatedResult class
  * paginate() function for SQLAlchemy queries
  * validate_page_params() for input sanitization
  * Metadata: total, pages, has_prev, has_next, etc.
- Add validators.py with comprehensive validation functions
  * validate_config_file() - YAML structure and required fields
  * validate_scan_status() - Enum validation (running/completed/failed)
  * validate_scan_id() - Positive integer validation
  * validate_port() - Port range validation (1-65535)
  * validate_ip_address() - Basic IPv4 format validation
  * sanitize_filename() - Path traversal prevention

Database Migration:
- Add migration 002 for scan status index
- Optimizes queries filtering by scan status
- Timestamp index already exists from migration 001

Testing:
- Add pytest infrastructure with conftest.py
  * test_db fixture - Temporary SQLite database per test
  * sample_scan_report fixture - Realistic scanner output
  * sample_config_file fixture - Valid YAML config
  * sample_invalid_config_file fixture - For validation tests
- Add comprehensive test_scan_service.py (15 tests)
  * Test scan trigger with valid/invalid configs
  * Test scan retrieval (found/not found cases)
  * Test scan listing with pagination and filtering
  * Test scan deletion with cascade cleanup
  * Test scan status retrieval
  * Test database mapping from JSON to models
  * Test expected vs actual port flagging
  * Test certificate and TLS data mapping
  * Test full scan retrieval with all relationships
  * All tests passing

Files Added:
- web/services/__init__.py
- web/services/scan_service.py (545 lines)
- web/utils/pagination.py (153 lines)
- web/utils/validators.py (245 lines)
- migrations/versions/002_add_scan_indexes.py
- tests/__init__.py
- tests/conftest.py (142 lines)
- tests/test_scan_service.py (374 lines)

Next Steps (Step 2):
- Implement scan API endpoints in web/api/scans.py
- Add authentication decorators
- Integrate ScanService with API routes
- Test API endpoints with integration tests

Phase 2 Step 1 Complete ✓
285 lines
7.4 KiB
Python
285 lines
7.4 KiB
Python
"""
|
|
Input validation utilities for SneakyScanner web application.
|
|
|
|
Provides validation functions for API inputs, file paths, and data integrity.
|
|
"""
|
|
|
|
import ipaddress
import os
from pathlib import Path
from typing import Optional

import yaml
|
|
|
|
|
|
def validate_config_file(file_path: str) -> tuple[bool, Optional[str]]:
    """
    Validate that a configuration file exists and is valid YAML.

    Checks, in order: a path was supplied, it exists, it is a regular file,
    it has a ``.yaml``/``.yml`` extension (case-insensitive), it parses as
    YAML, the top level is a mapping, the mapping has a ``title`` key, and
    it has a ``sites`` key whose value is a non-empty list.

    Args:
        file_path: Path to configuration file (``str`` or ``os.PathLike``).

    Returns:
        Tuple of (is_valid, error_message)
        If valid, returns (True, None)
        If invalid, returns (False, error_message)

    Examples:
        >>> validate_config_file('/app/configs/example.yaml')
        (True, None)
        >>> validate_config_file('/nonexistent.yaml')
        (False, 'File does not exist: /nonexistent.yaml')
    """
    # Check if path is provided
    if not file_path:
        return False, 'Config file path is required'

    path = Path(file_path)

    if not path.exists():
        return False, f'File does not exist: {file_path}'

    # Reject directories and other non-regular files
    if not path.is_file():
        return False, f'Path is not a file: {file_path}'

    if path.suffix.lower() not in ('.yaml', '.yml'):
        return False, f'File must be YAML (.yaml or .yml): {file_path}'

    # Only the I/O + parse step can fail, so only it is guarded.
    # An explicit UTF-8 encoding avoids depending on the platform's
    # locale-dependent default text encoding.
    try:
        with open(path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except yaml.YAMLError as e:
        return False, f'Invalid YAML syntax: {e}'
    except Exception as e:
        # Deliberately broad: this validator reports problems (permission
        # errors, undecodable bytes, ...) instead of raising them.
        return False, f'Error reading config file: {e}'

    # An empty file parses to None; lists/scalars are also rejected here.
    if not isinstance(config, dict):
        return False, 'Config file must contain a YAML dictionary'

    if 'title' not in config:
        return False, 'Config file missing required "title" field'

    if 'sites' not in config:
        return False, 'Config file missing required "sites" field'

    if not isinstance(config['sites'], list):
        return False, '"sites" must be a list'

    if len(config['sites']) == 0:
        return False, '"sites" list cannot be empty'

    return True, None


def validate_scan_status(status: str) -> tuple[bool, Optional[str]]:
    """
    Check that ``status`` is one of the known scan states.

    Args:
        status: Candidate status value.

    Returns:
        ``(True, None)`` when the value is recognised, otherwise
        ``(False, message)`` where the message lists the accepted values.

    Examples:
        >>> validate_scan_status('running')
        (True, None)
        >>> validate_scan_status('invalid')
        (False, 'Invalid status: invalid. Must be one of: running, completed, failed')
    """
    allowed = ('running', 'completed', 'failed')
    if status in allowed:
        return True, None
    choices = ', '.join(allowed)
    return False, f'Invalid status: {status}. Must be one of: {choices}'


def validate_triggered_by(triggered_by: str) -> tuple[bool, Optional[str]]:
    """
    Ensure ``triggered_by`` names a recognised scan trigger source.

    Args:
        triggered_by: Where the scan request originated.

    Returns:
        ``(True, None)`` for an accepted source, otherwise
        ``(False, message)`` listing the accepted sources.

    Examples:
        >>> validate_triggered_by('manual')
        (True, None)
        >>> validate_triggered_by('api')
        (True, None)
    """
    sources = ('manual', 'scheduled', 'api')
    if triggered_by in sources:
        return True, None
    listing = ', '.join(sources)
    return False, f'Invalid triggered_by: {triggered_by}. Must be one of: {listing}'


def validate_scan_id(scan_id: object) -> tuple[bool, Optional[str]]:
    """
    Validate scan ID is a positive integer.

    Accepts ints and anything ``int()`` converts cleanly (e.g. the string
    ``'42'`` from a URL or query parameter). Rejects ``bool`` values and
    non-integral floats, which ``int()`` would otherwise silently coerce
    (``True`` -> 1, ``3.7`` -> 3). Infinite or overflowing inputs are
    reported as invalid instead of raising ``OverflowError``.

    Args:
        scan_id: Scan ID to validate

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_scan_id(42)
        (True, None)
        >>> validate_scan_id('42')
        (True, None)
        >>> validate_scan_id(-1)
        (False, 'Scan ID must be a positive integer')
    """
    # bool is an int subclass, and int() truncates floats; neither is a
    # meaningful identifier. float('inf'/'nan').is_integer() is False too.
    if isinstance(scan_id, bool) or (
        isinstance(scan_id, float) and not scan_id.is_integer()
    ):
        return False, f'Invalid scan ID: {scan_id}'

    try:
        scan_id_int = int(scan_id)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: e.g. int(Decimal('Infinity'))
        return False, f'Invalid scan ID: {scan_id}'

    if scan_id_int <= 0:
        return False, 'Scan ID must be a positive integer'

    return True, None


def validate_file_path(file_path: str, must_exist: bool = True) -> tuple[bool, Optional[str]]:
    """
    Validate a file path.

    Rejects empty paths and any path with a ``..`` component, using either
    ``/`` or ``\\`` as the separator. Names that merely contain two dots,
    such as ``scan..json``, are allowed.

    NOTE: absolute paths are not rejected. Callers that must confine a path
    to a base directory need an additional containment check.

    Args:
        file_path: Path to validate (``str``, ``bytes`` or ``os.PathLike``).
        must_exist: If True, file must exist. If False, only validate format.

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_file_path('/app/output/scan.json', must_exist=False)
        (True, None)
        >>> validate_file_path('', must_exist=False)
        (False, 'File path is required')
    """
    if not file_path:
        return False, 'File path is required'

    # os.fsdecode normalises str/bytes/PathLike to str, so Path objects no
    # longer trip a TypeError on the traversal check.
    path_str = os.fsdecode(file_path)

    # Compare whole components so 'a/../b' is caught but 'scan..json' is not.
    components = path_str.replace('\\', '/').split('/')
    if '..' in components:
        return False, 'Path traversal not allowed'

    if must_exist:
        path = Path(path_str)
        if not path.exists():
            return False, f'File does not exist: {file_path}'
        if not path.is_file():
            return False, f'Path is not a file: {file_path}'

    return True, None


def sanitize_filename(filename: str) -> str:
    """
    Sanitize a filename by removing/replacing unsafe characters.

    Steps:
      1. Keep only the final path component (``os.path.basename``; on POSIX
         this splits on ``/`` only).
      2. Replace each of ``/ \\ .. <space> : * ? " < > |`` and every ASCII
         control character (``\\x00``-``\\x1f``, ``\\x7f``) with ``_``.
      3. Strip leading/trailing ``_`` and ``.``.
      4. Fall back to ``'unnamed'`` if nothing is left.

    Args:
        filename: Original filename

    Returns:
        Sanitized filename safe for filesystem

    Examples:
        >>> sanitize_filename('my scan.json')
        'my_scan.json'
        >>> sanitize_filename('../../etc/passwd')
        'passwd'
    """
    # Remove path components
    filename = os.path.basename(filename)

    # Single characters that are unsafe in filenames on common platforms,
    # plus ASCII control characters (an embedded NUL makes open() fail).
    unsafe = '/\\ :*?"<>|' + ''.join(map(chr, range(32))) + '\x7f'
    filename = filename.translate({ord(ch): '_' for ch in unsafe})

    # '..' is multi-character, so it cannot go through translate(). The
    # replacements above never add or remove '.', so the order is irrelevant.
    filename = filename.replace('..', '_')

    # Remove leading/trailing underscores and dots
    filename = filename.strip('_.')

    return filename or 'unnamed'


def validate_port(port: object) -> tuple[bool, Optional[str]]:
    """
    Validate port number.

    Accepts ints and anything ``int()`` converts cleanly (e.g. ``'443'``)
    in the range 1-65535. Rejects ``bool`` values and non-integral floats,
    which ``int()`` would silently coerce (``True`` -> 1, ``80.5`` -> 80).
    Infinite or overflowing inputs are reported as invalid instead of
    raising ``OverflowError``.

    Args:
        port: Port number to validate

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_port(443)
        (True, None)
        >>> validate_port(70000)
        (False, 'Port must be between 1 and 65535')
    """
    # bool is an int subclass and int() truncates floats; neither is a
    # meaningful port. float('inf'/'nan').is_integer() is False too.
    if isinstance(port, bool) or (isinstance(port, float) and not port.is_integer()):
        return False, f'Invalid port: {port}'

    try:
        port_int = int(port)
    except (ValueError, TypeError, OverflowError):
        return False, f'Invalid port: {port}'

    if not 1 <= port_int <= 65535:
        return False, 'Port must be between 1 and 65535'

    return True, None


def validate_ip_address(ip: str) -> tuple[bool, Optional[str]]:
    """
    Validate IPv4 dotted-quad address format.

    Parsing is delegated to ``ipaddress.IPv4Address``, which requires exactly
    four ASCII-decimal octets in 0-255. Unlike calling ``int()`` on each
    octet, it rejects surrounding whitespace, signs (``'+1'``), digit
    separators (``'1_0'``) and non-ASCII digits. On Python 3.9.5+ it also
    rejects zero-padded octets (``'010'``), which some parsers read as octal.
    Non-string input is reported as invalid instead of raising.

    Args:
        ip: IP address string

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_ip_address('192.168.1.1')
        (True, None)
        >>> validate_ip_address('256.1.1.1')
        (False, 'Invalid IP address format')
    """
    import ipaddress  # local import keeps this block self-contained

    if not ip:
        return False, 'IP address is required'

    # IPv4Address also accepts ints and packed bytes; this validator is for
    # textual addresses only.
    if not isinstance(ip, str):
        return False, 'Invalid IP address format'

    try:
        ipaddress.IPv4Address(ip)
    except ValueError:  # AddressValueError subclasses ValueError
        return False, 'Invalid IP address format'

    return True, None