Phase 2 Step 1: Implement database and service layer

Complete the foundation for Phase 2 by implementing the service layer,
utilities, and a comprehensive test suite. This establishes the core
business logic for scan management.

Service Layer:
- Add ScanService class with complete scan lifecycle management
  * trigger_scan() - Create scan record and prepare for execution
  * get_scan() - Retrieve scan with all related data (eager loading)
  * list_scans() - Paginated scan list with status filtering
  * delete_scan() - Remove scan from DB and delete all files
  * get_scan_status() - Poll current scan status and progress
  * _save_scan_to_db() - Persist scan results to database
  * _map_report_to_models() - Complex JSON-to-DB mapping logic

Database Mapping:
- Comprehensive mapping from scanner JSON output to normalized schema
- Handles nested relationships: sites → IPs → ports → services → certs → TLS
- Processes both TCP and UDP ports with expected/actual tracking
- Maps service detection results with HTTP/HTTPS information
- Stores SSL/TLS certificates with expiration tracking
- Records TLS version support and cipher suites
- Links screenshots to services
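
The nested walk described above could be sketched roughly as follows. This is an illustration only: the report keys (`sites`, `ips`, `expected_ports`, `actual_ports`) and the row shape are assumptions, not the scanner's real JSON schema or the code in scan_service.py.

```python
from typing import Any, Dict, List


def flatten_report(report: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten a nested report into one row per (site, ip, port).

    All key names here are hypothetical, chosen only for illustration.
    """
    rows: List[Dict[str, Any]] = []
    for site in report.get("sites", []):
        for ip in site.get("ips", []):
            expected = set(ip.get("expected_ports", []))
            for port in ip.get("actual_ports", []):
                rows.append({
                    "site": site["name"],
                    "ip": ip["address"],
                    "port": port,
                    # True when the open port was listed as expected
                    "expected": port in expected,
                })
    return rows


sample = {
    "sites": [{
        "name": "lab",
        "ips": [{
            "address": "10.0.0.5",
            "expected_ports": [22, 443],
            "actual_ports": [22, 8080],
        }],
    }]
}
rows = flatten_report(sample)
```

In this sample, port 22 is flagged as expected and port 8080 as unexpected. A real mapper would create ORM objects rather than dictionaries.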

Utilities:
- Add pagination.py with PaginatedResult class
  * paginate() function for SQLAlchemy queries
  * validate_page_params() for input sanitization
  * Metadata: total, pages, has_prev, has_next, etc.
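
The pagination metadata comes from plain integer arithmetic on total, page, and per_page. The sketch below mirrors that arithmetic without SQLAlchemy; `page_metadata` is a hypothetical helper name that does not exist in pagination.py.

```python
from typing import Dict, Union


def page_metadata(total: int, page: int, per_page: int) -> Dict[str, Union[int, bool, None]]:
    """Compute pagination metadata for a 1-indexed page."""
    # Ceiling division: number of pages needed to hold `total` items
    pages = (total + per_page - 1) // per_page if per_page > 0 else 0
    has_prev = page > 1
    has_next = page < pages
    return {
        "total": total,
        "pages": pages,
        # Number of rows to skip before the first item of this page
        "offset": (page - 1) * per_page,
        "has_prev": has_prev,
        "has_next": has_next,
        "prev_page": page - 1 if has_prev else None,
        "next_page": page + 1 if has_next else None,
    }


meta = page_metadata(total=45, page=2, per_page=20)
```

With 45 items and 20 per page there are 3 pages, so page 2 has both a previous and a next page and starts at offset 20.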

- Add validators.py with comprehensive validation functions
  * validate_config_file() - YAML structure and required fields
  * validate_scan_status() - Enum validation (running/completed/failed)
  * validate_triggered_by() - Source validation (manual/scheduled/api)
  * validate_scan_id() - Positive integer validation
  * validate_file_path() - Path traversal check, optional existence check
  * sanitize_filename() - Path traversal prevention
  * validate_port() - Port range validation (1-65535)
  * validate_ip_address() - Basic IPv4 format validation
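
The IPv4 check is described as basic: it splits on dots and converts each part with int(), and int() also accepts strings such as ' 1' and '+1'. If stricter parsing is ever wanted, one possible alternative (not what this commit implements) is the standard library's ipaddress module. `validate_ipv4_strict` is a hypothetical name that keeps the same (is_valid, error_message) return convention.

```python
import ipaddress
from typing import Optional, Tuple


def validate_ipv4_strict(ip: object) -> Tuple[bool, Optional[str]]:
    """Validate a dotted-quad IPv4 string using the stdlib parser."""
    if not isinstance(ip, str) or not ip:
        return False, 'IP address is required'
    try:
        # Raises AddressValueError (a ValueError subclass) on bad input
        ipaddress.IPv4Address(ip)
    except ValueError:
        return False, 'Invalid IP address format'
    return True, None


ok = validate_ipv4_strict('192.168.1.1')
out_of_range = validate_ipv4_strict('256.1.1.1')
padded = validate_ipv4_strict(' 1.2.3.4')
```

The trade-off is an extra import in exchange for rejecting whitespace-padded and sign-prefixed octets.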

Database Migration:
- Add migration 002 for scan status index
- Optimizes queries filtering by scan status
- Timestamp index already exists from migration 001
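
The effect of a status index can be seen with a small stand-alone SQLite experiment. This is a sketch only: the table layout and the index name `ix_scans_status` are assumptions, and the migration file itself is not shown in this diff.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical, simplified table; the real schema comes from migration 001
conn.execute(
    "CREATE TABLE scans (id INTEGER PRIMARY KEY, status TEXT, timestamp TEXT)"
)
# The kind of index a status migration would add
conn.execute("CREATE INDEX ix_scans_status ON scans (status)")

# Ask SQLite how it would execute a status-filtered query
plan = conn.execute(
    "EXPLAIN QUERY PLAN SELECT id FROM scans WHERE status = ?",
    ("running",),
).fetchall()
# The fourth column of each plan row is the human-readable detail
plan_text = " ".join(row[3] for row in plan)
conn.close()
```

Here `plan_text` should mention the index by name, which indicates an index search rather than a full table scan.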

Testing:
- Add pytest infrastructure with conftest.py
  * test_db fixture - Temporary SQLite database per test
  * sample_scan_report fixture - Realistic scanner output
  * sample_config_file fixture - Valid YAML config
  * sample_invalid_config_file fixture - For validation tests

- Add comprehensive test_scan_service.py (15 tests)
  * Test scan trigger with valid/invalid configs
  * Test scan retrieval (found/not found cases)
  * Test scan listing with pagination and filtering
  * Test scan deletion with cascade cleanup
  * Test scan status retrieval
  * Test database mapping from JSON to models
  * Test expected vs actual port flagging
  * Test certificate and TLS data mapping
  * Test full scan retrieval with all relationships
  * All tests passing

Files Added:
- web/services/__init__.py
- web/services/scan_service.py (545 lines)
- web/utils/pagination.py (158 lines)
- web/utils/validators.py (284 lines)
- migrations/versions/002_add_scan_indexes.py
- tests/__init__.py
- tests/conftest.py (142 lines)
- tests/test_scan_service.py (374 lines)

Next Steps (Step 2):
- Implement scan API endpoints in web/api/scans.py
- Add authentication decorators
- Integrate ScanService with API routes
- Test API endpoints with integration tests

Phase 2 Step 1 Complete ✓
2025-11-14 00:26:06 -06:00
parent 9255233a74
commit d7c68a2be8
8 changed files with 1668 additions and 0 deletions

158
web/utils/pagination.py Normal file

@@ -0,0 +1,158 @@
"""
Pagination utilities for SneakyScanner web application.
Provides helper functions for paginating SQLAlchemy queries.
"""

from typing import Any, Dict, List, Optional

from sqlalchemy.orm import Query


class PaginatedResult:
    """Container for paginated query results."""

    def __init__(self, items: List[Any], total: int, page: int, per_page: int):
        """
        Initialize paginated result.

        Args:
            items: List of items for current page
            total: Total number of items across all pages
            page: Current page number (1-indexed)
            per_page: Number of items per page
        """
        self.items = items
        self.total = total
        self.page = page
        self.per_page = per_page

    @property
    def pages(self) -> int:
        """Calculate total number of pages."""
        if self.per_page == 0:
            return 0
        # Ceiling division without floating point
        return (self.total + self.per_page - 1) // self.per_page

    @property
    def has_prev(self) -> bool:
        """Check if there is a previous page."""
        return self.page > 1

    @property
    def has_next(self) -> bool:
        """Check if there is a next page."""
        return self.page < self.pages

    @property
    def prev_page(self) -> Optional[int]:
        """Get previous page number, or None on the first page."""
        return self.page - 1 if self.has_prev else None

    @property
    def next_page(self) -> Optional[int]:
        """Get next page number, or None on the last page."""
        return self.page + 1 if self.has_next else None

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary for API responses.

        Returns:
            Dictionary with pagination metadata and items
        """
        return {
            'items': self.items,
            'total': self.total,
            'page': self.page,
            'per_page': self.per_page,
            'pages': self.pages,
            'has_prev': self.has_prev,
            'has_next': self.has_next,
            'prev_page': self.prev_page,
            'next_page': self.next_page,
        }


def paginate(query: Query, page: int = 1, per_page: int = 20,
             max_per_page: int = 100) -> PaginatedResult:
    """
    Paginate a SQLAlchemy query.

    Args:
        query: SQLAlchemy query to paginate
        page: Page number (1-indexed, default: 1)
        per_page: Items per page (default: 20)
        max_per_page: Maximum items per page (default: 100)

    Returns:
        PaginatedResult with items and pagination metadata

    Examples:
        >>> from web.models import Scan
        >>> query = db.query(Scan).order_by(Scan.timestamp.desc())
        >>> result = paginate(query, page=1, per_page=20)
        >>> scans = result.items
        >>> total_pages = result.pages
    """
    # Validate and sanitize parameters
    page = max(1, page)  # Page must be at least 1
    per_page = max(1, min(per_page, max_per_page))  # Clamp per_page

    # Get total count
    total = query.count()

    # Calculate offset
    offset = (page - 1) * per_page

    # Execute query with limit and offset
    items = query.limit(per_page).offset(offset).all()

    return PaginatedResult(
        items=items,
        total=total,
        page=page,
        per_page=per_page
    )


def validate_page_params(page: Any, per_page: Any,
                         max_per_page: int = 100) -> tuple[int, int]:
    """
    Validate and sanitize pagination parameters.

    Args:
        page: Page number (any type, will be converted to int)
        per_page: Items per page (any type, will be converted to int)
        max_per_page: Maximum items per page (default: 100)

    Returns:
        Tuple of (validated_page, validated_per_page)

    Examples:
        >>> validate_page_params('2', '50')
        (2, 50)
        >>> validate_page_params(-1, 200)
        (1, 100)
        >>> validate_page_params(None, None)
        (1, 20)
    """
    # Default values
    default_page = 1
    default_per_page = 20

    # Convert to int, use default if invalid
    try:
        page = int(page) if page is not None else default_page
    except (ValueError, TypeError):
        page = default_page

    try:
        per_page = int(per_page) if per_page is not None else default_per_page
    except (ValueError, TypeError):
        per_page = default_per_page

    # Validate ranges
    page = max(1, page)
    per_page = max(1, min(per_page, max_per_page))

    return page, per_page

284
web/utils/validators.py Normal file

@@ -0,0 +1,284 @@
"""
Input validation utilities for SneakyScanner web application.
Provides validation functions for API inputs, file paths, and data integrity.
"""

import os
from pathlib import Path
from typing import Any, Optional

import yaml


def validate_config_file(file_path: str) -> tuple[bool, Optional[str]]:
    """
    Validate that a configuration file exists and is valid YAML.

    Args:
        file_path: Path to configuration file

    Returns:
        Tuple of (is_valid, error_message)
        If valid, returns (True, None)
        If invalid, returns (False, error_message)

    Examples:
        >>> validate_config_file('/app/configs/example.yaml')
        (True, None)
        >>> validate_config_file('/nonexistent.yaml')
        (False, 'File does not exist: /nonexistent.yaml')
    """
    # Check if path is provided
    if not file_path:
        return False, 'Config file path is required'

    # Convert to Path object
    path = Path(file_path)

    # Check if file exists
    if not path.exists():
        return False, f'File does not exist: {file_path}'

    # Check if it's a file (not directory)
    if not path.is_file():
        return False, f'Path is not a file: {file_path}'

    # Check file extension
    if path.suffix.lower() not in ['.yaml', '.yml']:
        return False, f'File must be YAML (.yaml or .yml): {file_path}'

    # Try to parse as YAML
    try:
        with open(path, 'r') as f:
            config = yaml.safe_load(f)

        # Check if it's a dictionary (basic structure validation)
        if not isinstance(config, dict):
            return False, 'Config file must contain a YAML dictionary'

        # Check for required top-level keys
        if 'title' not in config:
            return False, 'Config file missing required "title" field'
        if 'sites' not in config:
            return False, 'Config file missing required "sites" field'

        # Validate sites structure
        if not isinstance(config['sites'], list):
            return False, '"sites" must be a list'
        if len(config['sites']) == 0:
            return False, '"sites" list cannot be empty'
    except yaml.YAMLError as e:
        return False, f'Invalid YAML syntax: {str(e)}'
    except Exception as e:
        return False, f'Error reading config file: {str(e)}'

    return True, None


def validate_scan_status(status: str) -> tuple[bool, Optional[str]]:
    """
    Validate scan status value.

    Args:
        status: Status string to validate

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_scan_status('running')
        (True, None)
        >>> validate_scan_status('invalid')
        (False, 'Invalid status: invalid. Must be one of: running, completed, failed')
    """
    valid_statuses = ['running', 'completed', 'failed']

    if status not in valid_statuses:
        return False, f'Invalid status: {status}. Must be one of: {", ".join(valid_statuses)}'

    return True, None


def validate_triggered_by(triggered_by: str) -> tuple[bool, Optional[str]]:
    """
    Validate triggered_by value.

    Args:
        triggered_by: Source that triggered the scan

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_triggered_by('manual')
        (True, None)
        >>> validate_triggered_by('api')
        (True, None)
    """
    valid_sources = ['manual', 'scheduled', 'api']

    if triggered_by not in valid_sources:
        return False, f'Invalid triggered_by: {triggered_by}. Must be one of: {", ".join(valid_sources)}'

    return True, None


def validate_scan_id(scan_id: Any) -> tuple[bool, Optional[str]]:
    """
    Validate scan ID is a positive integer.

    Args:
        scan_id: Scan ID to validate

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_scan_id(42)
        (True, None)
        >>> validate_scan_id('42')
        (True, None)
        >>> validate_scan_id(-1)
        (False, 'Scan ID must be a positive integer')
    """
    try:
        scan_id_int = int(scan_id)
        if scan_id_int <= 0:
            return False, 'Scan ID must be a positive integer'
    except (ValueError, TypeError):
        return False, f'Invalid scan ID: {scan_id}'

    return True, None


def validate_file_path(file_path: str, must_exist: bool = True) -> tuple[bool, Optional[str]]:
    """
    Validate a file path.

    Args:
        file_path: Path to validate
        must_exist: If True, file must exist. If False, only validate format.

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_file_path('/app/output/scan.json', must_exist=False)
        (True, None)
        >>> validate_file_path('', must_exist=False)
        (False, 'File path is required')
    """
    if not file_path:
        return False, 'File path is required'

    # Check for path traversal attempts
    if '..' in file_path:
        return False, 'Path traversal not allowed'

    if must_exist:
        path = Path(file_path)
        if not path.exists():
            return False, f'File does not exist: {file_path}'
        if not path.is_file():
            return False, f'Path is not a file: {file_path}'

    return True, None


def sanitize_filename(filename: str) -> str:
    """
    Sanitize a filename by removing/replacing unsafe characters.

    Args:
        filename: Original filename

    Returns:
        Sanitized filename safe for filesystem

    Examples:
        >>> sanitize_filename('my scan.json')
        'my_scan.json'
        >>> sanitize_filename('../../etc/passwd')
        'passwd'
    """
    # Remove path components (only the final component is kept)
    filename = os.path.basename(filename)

    # Replace unsafe characters with underscore
    unsafe_chars = ['/', '\\', '..', ' ', ':', '*', '?', '"', '<', '>', '|']
    for char in unsafe_chars:
        filename = filename.replace(char, '_')

    # Remove leading/trailing underscores and dots
    filename = filename.strip('_.')

    # Ensure filename is not empty
    if not filename:
        filename = 'unnamed'

    return filename


def validate_port(port: Any) -> tuple[bool, Optional[str]]:
    """
    Validate port number.

    Args:
        port: Port number to validate

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_port(443)
        (True, None)
        >>> validate_port(70000)
        (False, 'Port must be between 1 and 65535')
    """
    try:
        port_int = int(port)
        if port_int < 1 or port_int > 65535:
            return False, 'Port must be between 1 and 65535'
    except (ValueError, TypeError):
        return False, f'Invalid port: {port}'

    return True, None


def validate_ip_address(ip: str) -> tuple[bool, Optional[str]]:
    """
    Validate IPv4 address format (basic validation).

    Args:
        ip: IP address string

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_ip_address('192.168.1.1')
        (True, None)
        >>> validate_ip_address('256.1.1.1')
        (False, 'Invalid IP address format')
    """
    if not ip:
        return False, 'IP address is required'

    # Basic IPv4 validation
    parts = ip.split('.')
    if len(parts) != 4:
        return False, 'Invalid IP address format'

    try:
        for part in parts:
            num = int(part)
            if num < 0 or num > 255:
                return False, 'Invalid IP address format'
    except (ValueError, TypeError):
        return False, 'Invalid IP address format'

    return True, None
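
validate_file_path() rejects any path containing the substring '..', which also rejects harmless names such as 'report..v2.json'. A complementary approach, sketched here under the assumption that all files live under one known base directory, is to resolve the path and check containment. `is_within_directory` is a hypothetical helper and is not part of validators.py.

```python
import tempfile
from pathlib import Path


def is_within_directory(base_dir: str, candidate: str) -> bool:
    """Return True if candidate resolves to a location inside base_dir."""
    base = Path(base_dir).resolve()
    # Joining with an absolute candidate discards base, so it fails the check
    target = (base / candidate).resolve()
    return target == base or base in target.parents


with tempfile.TemporaryDirectory() as tmp:
    inside = is_within_directory(tmp, "output/scan.json")
    escaped = is_within_directory(tmp, "../../etc/passwd")
    dotted = is_within_directory(tmp, "report..v2.json")
```

Resolving first means the check works on the normalized path, so traversal sequences are caught by where they lead instead of how they are spelled.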