291 lines
7.7 KiB
Python
291 lines
7.7 KiB
Python
"""
Input validation utilities for SneakyScanner web application.

Provides validation functions for API inputs, file paths, and data integrity.
"""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import yaml
|
|
|
|
|
|
def validate_config_file(
    file_path: str,
    configs_dir: str = '/app/configs',
) -> tuple[bool, Optional[str]]:
    """
    Validate that a configuration file exists and is valid YAML.

    A relative path (e.g. a bare filename) is resolved against
    ``configs_dir``. The file must exist, be a regular file, have a
    ``.yaml``/``.yml`` extension, parse as a YAML mapping, and contain a
    ``title`` key plus a non-empty ``sites`` list.

    Args:
        file_path: Path to configuration file (absolute, or relative to
            ``configs_dir``)
        configs_dir: Directory that relative paths are resolved against.
            Defaults to ``/app/configs``.

    Returns:
        Tuple of (is_valid, error_message)
        If valid, returns (True, None)
        If invalid, returns (False, error_message)

    Examples:
        >>> validate_config_file('/app/configs/example.yaml')
        (True, None)
        >>> validate_config_file('example.yaml')
        (True, None)
        >>> validate_config_file('/nonexistent.yaml')
        (False, 'File does not exist: /nonexistent.yaml')
    """
    # Check if path is provided
    if not file_path:
        return False, 'Config file path is required'

    # Bare filenames / relative paths live under the configs directory.
    # os.path.isabs (rather than startswith('/')) recognises absolute paths
    # on every platform; on POSIX the two are equivalent.
    if not os.path.isabs(file_path):
        file_path = os.path.join(configs_dir, file_path)

    path = Path(file_path)

    if not path.exists():
        return False, f'File does not exist: {file_path}'

    if not path.is_file():
        return False, f'Path is not a file: {file_path}'

    if path.suffix.lower() not in ('.yaml', '.yml'):
        return False, f'File must be YAML (.yaml or .yml): {file_path}'

    # Only the I/O and parsing can raise; keep the try body to just that.
    # safe_load never constructs arbitrary Python objects from the file.
    try:
        with open(path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except yaml.YAMLError as e:
        return False, f'Invalid YAML syntax: {e}'
    except Exception as e:
        # Boundary of a validator: report any read failure (permissions,
        # decoding errors, ...) instead of propagating it to the caller.
        return False, f'Error reading config file: {e}'

    # Basic structure validation: the document must be a mapping.
    if not isinstance(config, dict):
        return False, 'Config file must contain a YAML dictionary'

    # Required top-level keys
    if 'title' not in config:
        return False, 'Config file missing required "title" field'

    if 'sites' not in config:
        return False, 'Config file missing required "sites" field'

    # Validate sites structure
    if not isinstance(config['sites'], list):
        return False, '"sites" must be a list'

    if not config['sites']:
        return False, '"sites" list cannot be empty'

    return True, None
|
|
|
|
|
|
def validate_scan_status(status: str) -> tuple[bool, Optional[str]]:
    """
    Check that ``status`` is one of the recognised scan states.

    The comparison is exact (case-sensitive) against ``running``,
    ``completed`` and ``failed``.

    Args:
        status: Candidate status value

    Returns:
        ``(True, None)`` when accepted; otherwise ``(False, message)``,
        where the message names the rejected value and lists the allowed ones.

    Examples:
        >>> validate_scan_status('running')
        (True, None)
        >>> validate_scan_status('invalid')
        (False, 'Invalid status: invalid. Must be one of: running, completed, failed')
    """
    allowed = ('running', 'completed', 'failed')

    if status in allowed:
        return True, None

    choices = ', '.join(allowed)
    return False, f'Invalid status: {status}. Must be one of: {choices}'
|
|
|
|
|
|
def validate_triggered_by(triggered_by: str) -> tuple[bool, Optional[str]]:
    """
    Check that ``triggered_by`` names a known scan trigger source.

    Accepted values (exact, case-sensitive): ``manual``, ``scheduled``, ``api``.

    Args:
        triggered_by: Source that triggered the scan

    Returns:
        ``(True, None)`` when accepted; otherwise ``(False, message)``,
        where the message names the rejected value and lists the allowed ones.

    Examples:
        >>> validate_triggered_by('manual')
        (True, None)
        >>> validate_triggered_by('api')
        (True, None)
    """
    sources = ('manual', 'scheduled', 'api')

    if triggered_by in sources:
        return True, None

    choices = ', '.join(sources)
    return False, f'Invalid triggered_by: {triggered_by}. Must be one of: {choices}'
|
|
|
|
|
|
def validate_scan_id(scan_id: object) -> tuple[bool, Optional[str]]:
    """
    Validate scan ID is a positive integer.

    Accepts ints, integer strings (e.g. ``'42'``) and integral floats
    (e.g. ``42.0``). Rejects bools, non-integral floats, infinities/NaN and
    anything ``int()`` cannot convert.

    Args:
        scan_id: Scan ID to validate

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_scan_id(42)
        (True, None)
        >>> validate_scan_id('42')
        (True, None)
        >>> validate_scan_id(-1)
        (False, 'Scan ID must be a positive integer')
        >>> validate_scan_id(1.5)
        (False, 'Invalid scan ID: 1.5')
    """
    # bool is an int subclass, so int(True) == 1 would silently pass.
    if isinstance(scan_id, bool):
        return False, f'Invalid scan ID: {scan_id}'

    # int() truncates floats (int(1.5) == 1), which would map to a different
    # scan; only integral floats are acceptable. Also rejects inf/NaN.
    if isinstance(scan_id, float) and not scan_id.is_integer():
        return False, f'Invalid scan ID: {scan_id}'

    try:
        scan_id_int = int(scan_id)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: e.g. int(Decimal('Infinity')).
        return False, f'Invalid scan ID: {scan_id}'

    if scan_id_int <= 0:
        return False, 'Scan ID must be a positive integer'

    return True, None
|
|
|
|
|
|
def validate_file_path(file_path: str, must_exist: bool = True) -> tuple[bool, Optional[str]]:
    """
    Check a file path for presence, traversal sequences and (optionally) existence.

    Any path containing the substring ``..`` is rejected outright. This is a
    deliberately conservative guard, so names such as ``a..b`` are refused too.

    Args:
        file_path: Path to validate
        must_exist: When True, the path must also refer to an existing
            regular file; when False, only the textual checks run.

    Returns:
        ``(True, None)`` on success, otherwise ``(False, message)``.

    Examples:
        >>> validate_file_path('/app/output/scan.json', must_exist=False)
        (True, None)
        >>> validate_file_path('', must_exist=False)
        (False, 'File path is required')
    """
    if not file_path:
        return False, 'File path is required'

    if '..' in file_path:
        return False, 'Path traversal not allowed'

    # Format-only validation stops here.
    if not must_exist:
        return True, None

    candidate = Path(file_path)
    if not candidate.exists():
        return False, f'File does not exist: {file_path}'
    if not candidate.is_file():
        return False, f'Path is not a file: {file_path}'
    return True, None
|
|
|
|
|
|
def sanitize_filename(filename: str) -> str:
    """
    Sanitize a filename by removing/replacing unsafe characters.

    Processing steps:
      1. Drop any directory components (``os.path.basename``).
      2. Replace ``..`` sequences, path separators, spaces, the characters
         ``: * ? " < > |`` and ASCII control characters (0x00-0x1F, 0x7F)
         with ``_``.
      3. Strip leading/trailing ``_`` and ``.``.
      4. Fall back to ``'unnamed'`` if nothing is left.

    Args:
        filename: Original filename

    Returns:
        Sanitized filename safe for filesystem

    Examples:
        >>> sanitize_filename('my scan.json')
        'my_scan.json'
        >>> sanitize_filename('../../etc/passwd')
        'passwd'
        >>> sanitize_filename('..')
        'unnamed'
    """
    # Remove path components
    filename = os.path.basename(filename)

    # '..' is multi-character, so it cannot go through translate(). The
    # single characters below are neither '.' nor replaced by '.', so the
    # order of these two steps does not affect the result.
    filename = filename.replace('..', '_')

    # Control characters are included: a NUL byte makes open() raise
    # ValueError, and newlines/tabs corrupt logs and HTTP headers.
    unsafe_chars = '/\\ :*?"<>|' + ''.join(chr(code) for code in range(0x20)) + '\x7f'
    filename = filename.translate(str.maketrans(unsafe_chars, '_' * len(unsafe_chars)))

    # Remove leading/trailing underscores and dots (no hidden/dot-only names)
    filename = filename.strip('_.')

    # Ensure filename is not empty
    return filename or 'unnamed'
|
|
|
|
|
|
def validate_port(port: object) -> tuple[bool, Optional[str]]:
    """
    Validate port number.

    Accepts ints, integer strings (e.g. ``'8080'``) and integral floats in
    the range 1-65535. Rejects bools, non-integral floats, infinities/NaN
    and anything ``int()`` cannot convert.

    Args:
        port: Port number to validate

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_port(443)
        (True, None)
        >>> validate_port(70000)
        (False, 'Port must be between 1 and 65535')
        >>> validate_port(80.5)
        (False, 'Invalid port: 80.5')
    """
    # bool is an int subclass (int(True) == 1), and int() silently truncates
    # floats (int(80.5) == 80); neither is a valid port spelling. The float
    # check also rejects inf/NaN, which would otherwise crash int().
    if isinstance(port, bool) or (isinstance(port, float) and not port.is_integer()):
        return False, f'Invalid port: {port}'

    try:
        port_int = int(port)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: e.g. int(Decimal('Infinity')).
        return False, f'Invalid port: {port}'

    if not 1 <= port_int <= 65535:
        return False, 'Port must be between 1 and 65535'

    return True, None
|
|
|
|
|
|
def validate_ip_address(ip: str) -> tuple[bool, Optional[str]]:
    """
    Validate IPv4 address format (basic validation).

    The address must be a string of exactly four dot-separated octets, each
    made only of ASCII digits and with a value of 0-255. Leading zeros are
    tolerated (e.g. ``'010.0.0.1'``).

    Args:
        ip: IP address string

    Returns:
        Tuple of (is_valid, error_message)

    Examples:
        >>> validate_ip_address('192.168.1.1')
        (True, None)
        >>> validate_ip_address('256.1.1.1')
        (False, 'Invalid IP address format')
        >>> validate_ip_address('+1.2.3.4')
        (False, 'Invalid IP address format')
    """
    if not ip:
        return False, 'IP address is required'

    # Non-string input would otherwise crash on .split().
    if not isinstance(ip, str):
        return False, 'Invalid IP address format'

    octets = ip.split('.')
    if len(octets) != 4:
        return False, 'Invalid IP address format'

    for octet in octets:
        # int() alone would accept '+1', ' 1', '1_0' and non-ASCII digits;
        # require plain ASCII digits (this also rejects empty octets).
        if not (octet.isascii() and octet.isdigit()):
            return False, 'Invalid IP address format'
        if int(octet) > 255:
            return False, 'Invalid IP address format'

    return True, None
|