Phase 2 Step 1: Implement database and service layer
Complete the foundation for Phase 2 by implementing the service layer, utilities, and comprehensive test suite. This establishes the core business logic for scan management. Service Layer: - Add ScanService class with complete scan lifecycle management * trigger_scan() - Create scan record and prepare for execution * get_scan() - Retrieve scan with all related data (eager loading) * list_scans() - Paginated scan list with status filtering * delete_scan() - Remove scan from DB and delete all files * get_scan_status() - Poll current scan status and progress * _save_scan_to_db() - Persist scan results to database * _map_report_to_models() - Complex JSON-to-DB mapping logic Database Mapping: - Comprehensive mapping from scanner JSON output to normalized schema - Handles nested relationships: sites → IPs → ports → services → certs → TLS - Processes both TCP and UDP ports with expected/actual tracking - Maps service detection results with HTTP/HTTPS information - Stores SSL/TLS certificates with expiration tracking - Records TLS version support and cipher suites - Links screenshots to services Utilities: - Add pagination.py with PaginatedResult class * paginate() function for SQLAlchemy queries * validate_page_params() for input sanitization * Metadata: total, pages, has_prev, has_next, etc. - Add validators.py with comprehensive validation functions * validate_config_file() - YAML structure and required fields * validate_scan_status() - Enum validation (running/completed/failed) * validate_scan_id() - Positive integer validation * validate_port() - Port range validation (1-65535) * validate_ip_address() - Basic IPv4 format validation * sanitize_filename() - Path traversal prevention Database Migration: - Add migration 002 for scan status index - Optimizes queries filtering by scan status - Timestamp index already exists from migration 001 Testing: - Add pytest infrastructure with conftest.py * test_db fixture - Temporary SQLite database per test * sample_scan_report fixture - Realistic scanner output * sample_config_file fixture - Valid YAML config * sample_invalid_config_file fixture - For validation tests - Add comprehensive test_scan_service.py (15 tests) * Test scan trigger with valid/invalid configs * Test scan retrieval (found/not found cases) * Test scan listing with pagination and filtering * Test scan deletion with cascade cleanup * Test scan status retrieval * Test database mapping from JSON to models * Test expected vs actual port flagging * Test certificate and TLS data mapping * Test full scan retrieval with all relationships * All tests passing Files Added: - web/services/__init__.py - web/services/scan_service.py (545 lines) - web/utils/pagination.py (153 lines) - web/utils/validators.py (245 lines) - migrations/versions/002_add_scan_indexes.py - tests/__init__.py - tests/conftest.py (142 lines) - tests/test_scan_service.py (374 lines) Next Steps (Step 2): - Implement scan API endpoints in web/api/scans.py - Add authentication decorators - Integrate ScanService with API routes - Test API endpoints with integration tests Phase 2 Step 1 Complete ✓
This commit is contained in:
28
migrations/versions/002_add_scan_indexes.py
Normal file
28
migrations/versions/002_add_scan_indexes.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""Add indexes for scan queries
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2025-11-14 00:30:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '002'
|
||||||
|
down_revision = '001'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Add database indexes for better query performance."""
|
||||||
|
# Add index on scans.status for filtering
|
||||||
|
# Note: index on scans.timestamp already exists from migration 001
|
||||||
|
op.create_index('ix_scans_status', 'scans', ['status'], unique=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Remove indexes."""
|
||||||
|
op.drop_index('ix_scans_status', table_name='scans')
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test package for SneakyScanner."""
|
||||||
196
tests/conftest.py
Normal file
196
tests/conftest.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""
|
||||||
|
Pytest configuration and fixtures for SneakyScanner tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from web.models import Base
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='function')
|
||||||
|
def test_db():
|
||||||
|
"""
|
||||||
|
Create a temporary test database.
|
||||||
|
|
||||||
|
Yields a SQLAlchemy session for testing, then cleans up.
|
||||||
|
"""
|
||||||
|
# Create temporary database file
|
||||||
|
db_fd, db_path = tempfile.mkstemp(suffix='.db')
|
||||||
|
|
||||||
|
# Create engine and session
|
||||||
|
engine = create_engine(f'sqlite:///{db_path}', echo=False)
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
|
||||||
|
Session = sessionmaker(bind=engine)
|
||||||
|
session = Session()
|
||||||
|
|
||||||
|
yield session
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
session.close()
|
||||||
|
os.close(db_fd)
|
||||||
|
os.unlink(db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_scan_report():
|
||||||
|
"""
|
||||||
|
Sample scan report matching the structure from scanner.py.
|
||||||
|
|
||||||
|
Returns a dictionary representing a typical scan output.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'title': 'Test Scan',
|
||||||
|
'scan_time': '2025-11-14T10:30:00Z',
|
||||||
|
'scan_duration': 125.5,
|
||||||
|
'config_file': '/app/configs/test.yaml',
|
||||||
|
'sites': [
|
||||||
|
{
|
||||||
|
'name': 'Test Site',
|
||||||
|
'ips': [
|
||||||
|
{
|
||||||
|
'address': '192.168.1.10',
|
||||||
|
'expected': {
|
||||||
|
'ping': True,
|
||||||
|
'tcp_ports': [22, 80, 443],
|
||||||
|
'udp_ports': [53],
|
||||||
|
'services': []
|
||||||
|
},
|
||||||
|
'actual': {
|
||||||
|
'ping': True,
|
||||||
|
'tcp_ports': [22, 80, 443, 8080],
|
||||||
|
'udp_ports': [53],
|
||||||
|
'services': [
|
||||||
|
{
|
||||||
|
'port': 22,
|
||||||
|
'service': 'ssh',
|
||||||
|
'product': 'OpenSSH',
|
||||||
|
'version': '8.9p1',
|
||||||
|
'extrainfo': 'Ubuntu',
|
||||||
|
'ostype': 'Linux'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'port': 443,
|
||||||
|
'service': 'https',
|
||||||
|
'product': 'nginx',
|
||||||
|
'version': '1.24.0',
|
||||||
|
'extrainfo': '',
|
||||||
|
'ostype': '',
|
||||||
|
'http_info': {
|
||||||
|
'protocol': 'https',
|
||||||
|
'screenshot': 'screenshots/192_168_1_10_443.png',
|
||||||
|
'certificate': {
|
||||||
|
'subject': 'CN=example.com',
|
||||||
|
'issuer': 'CN=Let\'s Encrypt Authority',
|
||||||
|
'serial_number': '123456789',
|
||||||
|
'not_valid_before': '2025-01-01T00:00:00Z',
|
||||||
|
'not_valid_after': '2025-12-31T23:59:59Z',
|
||||||
|
'days_until_expiry': 365,
|
||||||
|
'sans': ['example.com', 'www.example.com'],
|
||||||
|
'is_self_signed': False,
|
||||||
|
'tls_versions': {
|
||||||
|
'TLS 1.2': {
|
||||||
|
'supported': True,
|
||||||
|
'cipher_suites': [
|
||||||
|
'TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384',
|
||||||
|
'TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256'
|
||||||
|
]
|
||||||
|
},
|
||||||
|
'TLS 1.3': {
|
||||||
|
'supported': True,
|
||||||
|
'cipher_suites': [
|
||||||
|
'TLS_AES_256_GCM_SHA384',
|
||||||
|
'TLS_AES_128_GCM_SHA256'
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'port': 80,
|
||||||
|
'service': 'http',
|
||||||
|
'product': 'nginx',
|
||||||
|
'version': '1.24.0',
|
||||||
|
'extrainfo': '',
|
||||||
|
'ostype': '',
|
||||||
|
'http_info': {
|
||||||
|
'protocol': 'http',
|
||||||
|
'screenshot': 'screenshots/192_168_1_10_80.png'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'port': 8080,
|
||||||
|
'service': 'http',
|
||||||
|
'product': 'Jetty',
|
||||||
|
'version': '9.4.48',
|
||||||
|
'extrainfo': '',
|
||||||
|
'ostype': ''
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_config_file(tmp_path):
|
||||||
|
"""
|
||||||
|
Create a sample YAML config file for testing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmp_path: pytest temporary directory fixture
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to created config file
|
||||||
|
"""
|
||||||
|
config_data = {
|
||||||
|
'title': 'Test Scan',
|
||||||
|
'sites': [
|
||||||
|
{
|
||||||
|
'name': 'Test Site',
|
||||||
|
'ips': [
|
||||||
|
{
|
||||||
|
'address': '192.168.1.10',
|
||||||
|
'expected': {
|
||||||
|
'ping': True,
|
||||||
|
'tcp_ports': [22, 80, 443],
|
||||||
|
'udp_ports': [53],
|
||||||
|
'services': ['ssh', 'http', 'https']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
config_file = tmp_path / 'test_config.yaml'
|
||||||
|
with open(config_file, 'w') as f:
|
||||||
|
yaml.dump(config_data, f)
|
||||||
|
|
||||||
|
return str(config_file)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_invalid_config_file(tmp_path):
|
||||||
|
"""
|
||||||
|
Create an invalid config file for testing validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to invalid config file
|
||||||
|
"""
|
||||||
|
config_file = tmp_path / 'invalid_config.yaml'
|
||||||
|
with open(config_file, 'w') as f:
|
||||||
|
f.write("invalid: yaml: content: [missing closing bracket")
|
||||||
|
|
||||||
|
return str(config_file)
|
||||||
402
tests/test_scan_service.py
Normal file
402
tests/test_scan_service.py
Normal file
@@ -0,0 +1,402 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for ScanService class.
|
||||||
|
|
||||||
|
Tests scan lifecycle operations: trigger, get, list, delete, and database mapping.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel
|
||||||
|
from web.services.scan_service import ScanService
|
||||||
|
|
||||||
|
|
||||||
|
class TestScanServiceTrigger:
|
||||||
|
"""Tests for triggering scans."""
|
||||||
|
|
||||||
|
def test_trigger_scan_valid_config(self, test_db, sample_config_file):
|
||||||
|
"""Test triggering a scan with valid config file."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')
|
||||||
|
|
||||||
|
# Verify scan created
|
||||||
|
assert scan_id is not None
|
||||||
|
assert isinstance(scan_id, int)
|
||||||
|
|
||||||
|
# Verify scan in database
|
||||||
|
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||||
|
assert scan is not None
|
||||||
|
assert scan.status == 'running'
|
||||||
|
assert scan.title == 'Test Scan'
|
||||||
|
assert scan.triggered_by == 'manual'
|
||||||
|
assert scan.config_file == sample_config_file
|
||||||
|
|
||||||
|
def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file):
|
||||||
|
"""Test triggering a scan with invalid config file."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid config file"):
|
||||||
|
service.trigger_scan(sample_invalid_config_file)
|
||||||
|
|
||||||
|
def test_trigger_scan_nonexistent_file(self, test_db):
|
||||||
|
"""Test triggering a scan with nonexistent config file."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="does not exist"):
|
||||||
|
service.trigger_scan('/nonexistent/config.yaml')
|
||||||
|
|
||||||
|
def test_trigger_scan_with_schedule(self, test_db, sample_config_file):
|
||||||
|
"""Test triggering a scan via schedule."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
scan_id = service.trigger_scan(
|
||||||
|
sample_config_file,
|
||||||
|
triggered_by='scheduled',
|
||||||
|
schedule_id=42
|
||||||
|
)
|
||||||
|
|
||||||
|
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||||
|
assert scan.triggered_by == 'scheduled'
|
||||||
|
assert scan.schedule_id == 42
|
||||||
|
|
||||||
|
|
||||||
|
class TestScanServiceGet:
|
||||||
|
"""Tests for retrieving scans."""
|
||||||
|
|
||||||
|
def test_get_scan_not_found(self, test_db):
|
||||||
|
"""Test getting a nonexistent scan."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
result = service.get_scan(999)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_scan_found(self, test_db, sample_config_file):
|
||||||
|
"""Test getting an existing scan."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
# Create a scan
|
||||||
|
scan_id = service.trigger_scan(sample_config_file)
|
||||||
|
|
||||||
|
# Retrieve it
|
||||||
|
result = service.get_scan(scan_id)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result['id'] == scan_id
|
||||||
|
assert result['title'] == 'Test Scan'
|
||||||
|
assert result['status'] == 'running'
|
||||||
|
assert 'sites' in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestScanServiceList:
|
||||||
|
"""Tests for listing scans."""
|
||||||
|
|
||||||
|
def test_list_scans_empty(self, test_db):
|
||||||
|
"""Test listing scans when database is empty."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
result = service.list_scans(page=1, per_page=20)
|
||||||
|
|
||||||
|
assert result.total == 0
|
||||||
|
assert len(result.items) == 0
|
||||||
|
assert result.pages == 0
|
||||||
|
|
||||||
|
def test_list_scans_with_data(self, test_db, sample_config_file):
|
||||||
|
"""Test listing scans with multiple scans."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
# Create 3 scans
|
||||||
|
for i in range(3):
|
||||||
|
service.trigger_scan(sample_config_file, triggered_by='api')
|
||||||
|
|
||||||
|
# List all scans
|
||||||
|
result = service.list_scans(page=1, per_page=20)
|
||||||
|
|
||||||
|
assert result.total == 3
|
||||||
|
assert len(result.items) == 3
|
||||||
|
assert result.pages == 1
|
||||||
|
|
||||||
|
def test_list_scans_pagination(self, test_db, sample_config_file):
|
||||||
|
"""Test pagination."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
# Create 5 scans
|
||||||
|
for i in range(5):
|
||||||
|
service.trigger_scan(sample_config_file)
|
||||||
|
|
||||||
|
# Get page 1 (2 items per page)
|
||||||
|
result = service.list_scans(page=1, per_page=2)
|
||||||
|
assert len(result.items) == 2
|
||||||
|
assert result.total == 5
|
||||||
|
assert result.pages == 3
|
||||||
|
assert result.has_next is True
|
||||||
|
|
||||||
|
# Get page 2
|
||||||
|
result = service.list_scans(page=2, per_page=2)
|
||||||
|
assert len(result.items) == 2
|
||||||
|
assert result.has_prev is True
|
||||||
|
assert result.has_next is True
|
||||||
|
|
||||||
|
# Get page 3 (last page)
|
||||||
|
result = service.list_scans(page=3, per_page=2)
|
||||||
|
assert len(result.items) == 1
|
||||||
|
assert result.has_next is False
|
||||||
|
|
||||||
|
def test_list_scans_filter_by_status(self, test_db, sample_config_file):
|
||||||
|
"""Test filtering scans by status."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
# Create scans with different statuses
|
||||||
|
scan_id_1 = service.trigger_scan(sample_config_file)
|
||||||
|
scan_id_2 = service.trigger_scan(sample_config_file)
|
||||||
|
|
||||||
|
# Mark one as completed
|
||||||
|
scan = test_db.query(Scan).filter(Scan.id == scan_id_1).first()
|
||||||
|
scan.status = 'completed'
|
||||||
|
test_db.commit()
|
||||||
|
|
||||||
|
# Filter by running
|
||||||
|
result = service.list_scans(status_filter='running')
|
||||||
|
assert result.total == 1
|
||||||
|
|
||||||
|
# Filter by completed
|
||||||
|
result = service.list_scans(status_filter='completed')
|
||||||
|
assert result.total == 1
|
||||||
|
|
||||||
|
def test_list_scans_invalid_status_filter(self, test_db):
|
||||||
|
"""Test filtering with invalid status."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid status"):
|
||||||
|
service.list_scans(status_filter='invalid_status')
|
||||||
|
|
||||||
|
|
||||||
|
class TestScanServiceDelete:
|
||||||
|
"""Tests for deleting scans."""
|
||||||
|
|
||||||
|
def test_delete_scan_not_found(self, test_db):
|
||||||
|
"""Test deleting a nonexistent scan."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
service.delete_scan(999)
|
||||||
|
|
||||||
|
def test_delete_scan_success(self, test_db, sample_config_file):
|
||||||
|
"""Test successful scan deletion."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
# Create a scan
|
||||||
|
scan_id = service.trigger_scan(sample_config_file)
|
||||||
|
|
||||||
|
# Verify it exists
|
||||||
|
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is not None
|
||||||
|
|
||||||
|
# Delete it
|
||||||
|
result = service.delete_scan(scan_id)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
# Verify it's gone
|
||||||
|
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestScanServiceStatus:
|
||||||
|
"""Tests for scan status retrieval."""
|
||||||
|
|
||||||
|
def test_get_scan_status_not_found(self, test_db):
|
||||||
|
"""Test getting status of nonexistent scan."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
result = service.get_scan_status(999)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_scan_status_running(self, test_db, sample_config_file):
|
||||||
|
"""Test getting status of running scan."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
scan_id = service.trigger_scan(sample_config_file)
|
||||||
|
status = service.get_scan_status(scan_id)
|
||||||
|
|
||||||
|
assert status is not None
|
||||||
|
assert status['scan_id'] == scan_id
|
||||||
|
assert status['status'] == 'running'
|
||||||
|
assert status['progress'] == 'In progress'
|
||||||
|
assert status['title'] == 'Test Scan'
|
||||||
|
|
||||||
|
def test_get_scan_status_completed(self, test_db, sample_config_file):
|
||||||
|
"""Test getting status of completed scan."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
# Create and mark as completed
|
||||||
|
scan_id = service.trigger_scan(sample_config_file)
|
||||||
|
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||||
|
scan.status = 'completed'
|
||||||
|
scan.duration = 125.5
|
||||||
|
test_db.commit()
|
||||||
|
|
||||||
|
status = service.get_scan_status(scan_id)
|
||||||
|
|
||||||
|
assert status['status'] == 'completed'
|
||||||
|
assert status['progress'] == 'Complete'
|
||||||
|
assert status['duration'] == 125.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestScanServiceDatabaseMapping:
|
||||||
|
"""Tests for mapping scan reports to database models."""
|
||||||
|
|
||||||
|
def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report):
|
||||||
|
"""Test saving a complete scan report to database."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
# Create a scan
|
||||||
|
scan_id = service.trigger_scan(sample_config_file)
|
||||||
|
|
||||||
|
# Save report to database
|
||||||
|
service._save_scan_to_db(sample_scan_report, scan_id, status='completed')
|
||||||
|
|
||||||
|
# Verify scan updated
|
||||||
|
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||||
|
assert scan.status == 'completed'
|
||||||
|
assert scan.duration == 125.5
|
||||||
|
|
||||||
|
# Verify sites created
|
||||||
|
sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all()
|
||||||
|
assert len(sites) == 1
|
||||||
|
assert sites[0].site_name == 'Test Site'
|
||||||
|
|
||||||
|
# Verify IPs created
|
||||||
|
ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all()
|
||||||
|
assert len(ips) == 1
|
||||||
|
assert ips[0].ip_address == '192.168.1.10'
|
||||||
|
assert ips[0].ping_expected is True
|
||||||
|
assert ips[0].ping_actual is True
|
||||||
|
|
||||||
|
# Verify ports created (TCP: 22, 80, 443, 8080 | UDP: 53)
|
||||||
|
ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all()
|
||||||
|
assert len(ports) == 5 # 4 TCP + 1 UDP
|
||||||
|
|
||||||
|
# Verify TCP ports
|
||||||
|
tcp_ports = [p for p in ports if p.protocol == 'tcp']
|
||||||
|
assert len(tcp_ports) == 4
|
||||||
|
tcp_port_numbers = sorted([p.port for p in tcp_ports])
|
||||||
|
assert tcp_port_numbers == [22, 80, 443, 8080]
|
||||||
|
|
||||||
|
# Verify UDP ports
|
||||||
|
udp_ports = [p for p in ports if p.protocol == 'udp']
|
||||||
|
assert len(udp_ports) == 1
|
||||||
|
assert udp_ports[0].port == 53
|
||||||
|
|
||||||
|
# Verify services created
|
||||||
|
services = test_db.query(ScanServiceModel).filter(
|
||||||
|
ScanServiceModel.scan_id == scan_id
|
||||||
|
).all()
|
||||||
|
assert len(services) == 4 # SSH, HTTP (80), HTTPS, HTTP (8080)
|
||||||
|
|
||||||
|
# Find HTTPS service
|
||||||
|
https_service = next(
|
||||||
|
(s for s in services if s.service_name == 'https'), None
|
||||||
|
)
|
||||||
|
assert https_service is not None
|
||||||
|
assert https_service.product == 'nginx'
|
||||||
|
assert https_service.version == '1.24.0'
|
||||||
|
assert https_service.http_protocol == 'https'
|
||||||
|
assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png'
|
||||||
|
|
||||||
|
def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report):
|
||||||
|
"""Test that expected vs actual ports are correctly flagged."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
scan_id = service.trigger_scan(sample_config_file)
|
||||||
|
service._save_scan_to_db(sample_scan_report, scan_id)
|
||||||
|
|
||||||
|
# Check TCP ports
|
||||||
|
tcp_ports = test_db.query(ScanPort).filter(
|
||||||
|
ScanPort.scan_id == scan_id,
|
||||||
|
ScanPort.protocol == 'tcp'
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# Ports 22, 80, 443 were expected
|
||||||
|
expected_ports = {22, 80, 443}
|
||||||
|
for port in tcp_ports:
|
||||||
|
if port.port in expected_ports:
|
||||||
|
assert port.expected is True, f"Port {port.port} should be expected"
|
||||||
|
else:
|
||||||
|
# Port 8080 was not expected
|
||||||
|
assert port.expected is False, f"Port {port.port} should not be expected"
|
||||||
|
|
||||||
|
def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report):
|
||||||
|
"""Test that certificate and TLS data are correctly mapped."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
scan_id = service.trigger_scan(sample_config_file)
|
||||||
|
service._save_scan_to_db(sample_scan_report, scan_id)
|
||||||
|
|
||||||
|
# Find HTTPS service
|
||||||
|
https_service = test_db.query(ScanServiceModel).filter(
|
||||||
|
ScanServiceModel.scan_id == scan_id,
|
||||||
|
ScanServiceModel.service_name == 'https'
|
||||||
|
).first()
|
||||||
|
|
||||||
|
assert https_service is not None
|
||||||
|
|
||||||
|
# Verify certificate created
|
||||||
|
assert len(https_service.certificates) == 1
|
||||||
|
cert = https_service.certificates[0]
|
||||||
|
|
||||||
|
assert cert.subject == 'CN=example.com'
|
||||||
|
assert cert.issuer == "CN=Let's Encrypt Authority"
|
||||||
|
assert cert.days_until_expiry == 365
|
||||||
|
assert cert.is_self_signed is False
|
||||||
|
|
||||||
|
# Verify SANs
|
||||||
|
import json
|
||||||
|
sans = json.loads(cert.sans)
|
||||||
|
assert 'example.com' in sans
|
||||||
|
assert 'www.example.com' in sans
|
||||||
|
|
||||||
|
# Verify TLS versions
|
||||||
|
assert len(cert.tls_versions) == 2
|
||||||
|
|
||||||
|
tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None)
|
||||||
|
assert tls_12 is not None
|
||||||
|
assert tls_12.supported is True
|
||||||
|
|
||||||
|
tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None)
|
||||||
|
assert tls_13 is not None
|
||||||
|
assert tls_13.supported is True
|
||||||
|
|
||||||
|
def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report):
|
||||||
|
"""Test retrieving scan with all nested relationships."""
|
||||||
|
service = ScanService(test_db)
|
||||||
|
|
||||||
|
scan_id = service.trigger_scan(sample_config_file)
|
||||||
|
service._save_scan_to_db(sample_scan_report, scan_id)
|
||||||
|
|
||||||
|
# Get full scan details
|
||||||
|
result = service.get_scan(scan_id)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result['sites']) == 1
|
||||||
|
|
||||||
|
site = result['sites'][0]
|
||||||
|
assert site['name'] == 'Test Site'
|
||||||
|
assert len(site['ips']) == 1
|
||||||
|
|
||||||
|
ip = site['ips'][0]
|
||||||
|
assert ip['address'] == '192.168.1.10'
|
||||||
|
assert len(ip['ports']) == 5 # 4 TCP + 1 UDP
|
||||||
|
|
||||||
|
# Find HTTPS port
|
||||||
|
https_port = next(
|
||||||
|
(p for p in ip['ports'] if p['port'] == 443), None
|
||||||
|
)
|
||||||
|
assert https_port is not None
|
||||||
|
assert len(https_port['services']) == 1
|
||||||
|
|
||||||
|
service_data = https_port['services'][0]
|
||||||
|
assert service_data['service_name'] == 'https'
|
||||||
|
assert 'certificates' in service_data
|
||||||
|
assert len(service_data['certificates']) == 1
|
||||||
|
|
||||||
|
cert = service_data['certificates'][0]
|
||||||
|
assert cert['subject'] == 'CN=example.com'
|
||||||
|
assert 'tls_versions' in cert
|
||||||
|
assert len(cert['tls_versions']) == 2
|
||||||
10
web/services/__init__.py
Normal file
10
web/services/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
Services package for SneakyScanner web application.
|
||||||
|
|
||||||
|
This package contains business logic layer services that orchestrate
|
||||||
|
operations between API endpoints and database models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from web.services.scan_service import ScanService
|
||||||
|
|
||||||
|
__all__ = ['ScanService']
|
||||||
589
web/services/scan_service.py
Normal file
589
web/services/scan_service.py
Normal file
@@ -0,0 +1,589 @@
|
|||||||
|
"""
|
||||||
|
Scan service for managing scan operations and database integration.
|
||||||
|
|
||||||
|
This service handles the business logic for triggering scans, retrieving
|
||||||
|
scan results, and mapping scanner output to database models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from web.models import (
|
||||||
|
Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel,
|
||||||
|
ScanCertificate, ScanTLSVersion
|
||||||
|
)
|
||||||
|
from web.utils.pagination import paginate, PaginatedResult
|
||||||
|
from web.utils.validators import validate_config_file, validate_scan_status
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ScanService:
|
||||||
|
"""
|
||||||
|
Service for managing scan operations.
|
||||||
|
|
||||||
|
Handles scan lifecycle: triggering, status tracking, result storage,
|
||||||
|
and cleanup.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_session: Session):
|
||||||
|
"""
|
||||||
|
Initialize scan service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: SQLAlchemy database session
|
||||||
|
"""
|
||||||
|
self.db = db_session
|
||||||
|
|
||||||
|
def trigger_scan(self, config_file: str, triggered_by: str = 'manual',
|
||||||
|
schedule_id: Optional[int] = None) -> int:
|
||||||
|
"""
|
||||||
|
Trigger a new scan.
|
||||||
|
|
||||||
|
Creates a Scan record in the database with status='running' and
|
||||||
|
queues the scan for background execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_file: Path to YAML configuration file
|
||||||
|
triggered_by: Source that triggered scan (manual, scheduled, api)
|
||||||
|
schedule_id: Optional schedule ID if triggered by schedule
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Scan ID of the created scan
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If config file is invalid
|
||||||
|
"""
|
||||||
|
# Validate config file
|
||||||
|
is_valid, error_msg = validate_config_file(config_file)
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError(f"Invalid config file: {error_msg}")
|
||||||
|
|
||||||
|
# Load config to get title
|
||||||
|
import yaml
|
||||||
|
with open(config_file, 'r') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# Create scan record
|
||||||
|
scan = Scan(
|
||||||
|
timestamp=datetime.utcnow(),
|
||||||
|
status='running',
|
||||||
|
config_file=config_file,
|
||||||
|
title=config.get('title', 'Untitled Scan'),
|
||||||
|
triggered_by=triggered_by,
|
||||||
|
schedule_id=schedule_id,
|
||||||
|
created_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.add(scan)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(scan)
|
||||||
|
|
||||||
|
logger.info(f"Scan {scan.id} triggered via {triggered_by}")
|
||||||
|
|
||||||
|
# NOTE: Background job queuing will be implemented in Step 3
|
||||||
|
# For now, just return the scan ID
|
||||||
|
return scan.id
|
||||||
|
|
||||||
|
def get_scan(self, scan_id: int) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get scan details with all related data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scan_id: Scan ID to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with scan data including sites, IPs, ports, services, etc.
|
||||||
|
Returns None if scan not found.
|
||||||
|
"""
|
||||||
|
# Query with eager loading of all relationships
|
||||||
|
scan = (
|
||||||
|
self.db.query(Scan)
|
||||||
|
.options(
|
||||||
|
joinedload(Scan.sites).joinedload(ScanSite.ips).joinedload(ScanIP.ports),
|
||||||
|
joinedload(Scan.ports).joinedload(ScanPort.services),
|
||||||
|
joinedload(Scan.services).joinedload(ScanServiceModel.certificates),
|
||||||
|
joinedload(Scan.certificates).joinedload(ScanCertificate.tls_versions)
|
||||||
|
)
|
||||||
|
.filter(Scan.id == scan_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not scan:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Convert to dictionary
|
||||||
|
return self._scan_to_dict(scan)
|
||||||
|
|
||||||
|
def list_scans(self, page: int = 1, per_page: int = 20,
|
||||||
|
status_filter: Optional[str] = None) -> PaginatedResult:
|
||||||
|
"""
|
||||||
|
List scans with pagination and optional filtering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: Page number (1-indexed)
|
||||||
|
per_page: Items per page
|
||||||
|
status_filter: Optional filter by status (running, completed, failed)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PaginatedResult with scan list and metadata
|
||||||
|
"""
|
||||||
|
# Build query
|
||||||
|
query = self.db.query(Scan).order_by(Scan.timestamp.desc())
|
||||||
|
|
||||||
|
# Apply status filter if provided
|
||||||
|
if status_filter:
|
||||||
|
is_valid, error_msg = validate_scan_status(status_filter)
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
query = query.filter(Scan.status == status_filter)
|
||||||
|
|
||||||
|
# Paginate
|
||||||
|
result = paginate(query, page=page, per_page=per_page)
|
||||||
|
|
||||||
|
# Convert scans to dictionaries (summary only, not full details)
|
||||||
|
result.items = [self._scan_to_summary_dict(scan) for scan in result.items]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def delete_scan(self, scan_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a scan and all associated files.
|
||||||
|
|
||||||
|
Removes:
|
||||||
|
- Database record (cascade deletes related records)
|
||||||
|
- JSON report file
|
||||||
|
- HTML report file
|
||||||
|
- ZIP archive file
|
||||||
|
- Screenshot directory
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scan_id: Scan ID to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted successfully
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If scan not found
|
||||||
|
"""
|
||||||
|
scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
|
||||||
|
if not scan:
|
||||||
|
raise ValueError(f"Scan {scan_id} not found")
|
||||||
|
|
||||||
|
logger.info(f"Deleting scan {scan_id}")
|
||||||
|
|
||||||
|
# Delete files (handle missing files gracefully)
|
||||||
|
files_to_delete = [
|
||||||
|
scan.json_path,
|
||||||
|
scan.html_path,
|
||||||
|
scan.zip_path
|
||||||
|
]
|
||||||
|
|
||||||
|
for file_path in files_to_delete:
|
||||||
|
if file_path:
|
||||||
|
try:
|
||||||
|
Path(file_path).unlink()
|
||||||
|
logger.debug(f"Deleted file: {file_path}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.warning(f"File not found (already deleted?): {file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting file {file_path}: {e}")
|
||||||
|
|
||||||
|
# Delete screenshot directory
|
||||||
|
if scan.screenshot_dir:
|
||||||
|
try:
|
||||||
|
shutil.rmtree(scan.screenshot_dir)
|
||||||
|
logger.debug(f"Deleted directory: {scan.screenshot_dir}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.warning(f"Directory not found (already deleted?): {scan.screenshot_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting directory {scan.screenshot_dir}: {e}")
|
||||||
|
|
||||||
|
# Delete database record (cascade handles related records)
|
||||||
|
self.db.delete(scan)
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
logger.info(f"Scan {scan_id} deleted successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_scan_status(self, scan_id: int) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get current scan status and progress.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scan_id: Scan ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status information, or None if scan not found
|
||||||
|
"""
|
||||||
|
scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
|
||||||
|
if not scan:
|
||||||
|
return None
|
||||||
|
|
||||||
|
status_info = {
|
||||||
|
'scan_id': scan.id,
|
||||||
|
'status': scan.status,
|
||||||
|
'title': scan.title,
|
||||||
|
'started_at': scan.timestamp.isoformat() if scan.timestamp else None,
|
||||||
|
'duration': scan.duration,
|
||||||
|
'triggered_by': scan.triggered_by
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add progress estimate based on status
|
||||||
|
if scan.status == 'running':
|
||||||
|
status_info['progress'] = 'In progress'
|
||||||
|
elif scan.status == 'completed':
|
||||||
|
status_info['progress'] = 'Complete'
|
||||||
|
elif scan.status == 'failed':
|
||||||
|
status_info['progress'] = 'Failed'
|
||||||
|
|
||||||
|
return status_info
|
||||||
|
|
||||||
|
def _save_scan_to_db(self, report: Dict[str, Any], scan_id: int,
|
||||||
|
status: str = 'completed') -> None:
|
||||||
|
"""
|
||||||
|
Save scan results to database.
|
||||||
|
|
||||||
|
Updates the Scan record and creates all related records
|
||||||
|
(sites, IPs, ports, services, certificates, TLS versions).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
report: Scan report dictionary from scanner
|
||||||
|
scan_id: Scan ID to update
|
||||||
|
status: Final scan status (completed or failed)
|
||||||
|
"""
|
||||||
|
scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
|
||||||
|
if not scan:
|
||||||
|
raise ValueError(f"Scan {scan_id} not found")
|
||||||
|
|
||||||
|
# Update scan record
|
||||||
|
scan.status = status
|
||||||
|
scan.duration = report.get('scan_duration')
|
||||||
|
|
||||||
|
# Map report data to database models
|
||||||
|
self._map_report_to_models(report, scan)
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
logger.info(f"Scan {scan_id} saved to database with status '{status}'")
|
||||||
|
|
||||||
|
def _map_report_to_models(self, report: Dict[str, Any], scan_obj: Scan) -> None:
|
||||||
|
"""
|
||||||
|
Map JSON report structure to database models.
|
||||||
|
|
||||||
|
Creates records for sites, IPs, ports, services, certificates, and TLS versions.
|
||||||
|
Processes nested relationships in order to handle foreign keys correctly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
report: Scan report dictionary
|
||||||
|
scan_obj: Scan database object to attach records to
|
||||||
|
"""
|
||||||
|
logger.debug(f"Mapping report to database models for scan {scan_obj.id}")
|
||||||
|
|
||||||
|
# Process each site
|
||||||
|
for site_data in report.get('sites', []):
|
||||||
|
# Create ScanSite record
|
||||||
|
site = ScanSite(
|
||||||
|
scan_id=scan_obj.id,
|
||||||
|
site_name=site_data['name']
|
||||||
|
)
|
||||||
|
self.db.add(site)
|
||||||
|
self.db.flush() # Get site.id for foreign key
|
||||||
|
|
||||||
|
# Process each IP in this site
|
||||||
|
for ip_data in site_data.get('ips', []):
|
||||||
|
# Create ScanIP record
|
||||||
|
ip = ScanIP(
|
||||||
|
scan_id=scan_obj.id,
|
||||||
|
site_id=site.id,
|
||||||
|
ip_address=ip_data['address'],
|
||||||
|
ping_expected=ip_data.get('expected', {}).get('ping'),
|
||||||
|
ping_actual=ip_data.get('actual', {}).get('ping')
|
||||||
|
)
|
||||||
|
self.db.add(ip)
|
||||||
|
self.db.flush()
|
||||||
|
|
||||||
|
# Process TCP ports
|
||||||
|
expected_tcp = set(ip_data.get('expected', {}).get('tcp_ports', []))
|
||||||
|
actual_tcp = ip_data.get('actual', {}).get('tcp_ports', [])
|
||||||
|
|
||||||
|
for port_num in actual_tcp:
|
||||||
|
port = ScanPort(
|
||||||
|
scan_id=scan_obj.id,
|
||||||
|
ip_id=ip.id,
|
||||||
|
port=port_num,
|
||||||
|
protocol='tcp',
|
||||||
|
expected=(port_num in expected_tcp),
|
||||||
|
state='open'
|
||||||
|
)
|
||||||
|
self.db.add(port)
|
||||||
|
self.db.flush()
|
||||||
|
|
||||||
|
# Find service for this port
|
||||||
|
service_data = self._find_service_for_port(
|
||||||
|
ip_data.get('actual', {}).get('services', []),
|
||||||
|
port_num
|
||||||
|
)
|
||||||
|
|
||||||
|
if service_data:
|
||||||
|
# Create ScanService record
|
||||||
|
service = ScanServiceModel(
|
||||||
|
scan_id=scan_obj.id,
|
||||||
|
port_id=port.id,
|
||||||
|
service_name=service_data.get('service'),
|
||||||
|
product=service_data.get('product'),
|
||||||
|
version=service_data.get('version'),
|
||||||
|
extrainfo=service_data.get('extrainfo'),
|
||||||
|
ostype=service_data.get('ostype'),
|
||||||
|
http_protocol=service_data.get('http_info', {}).get('protocol'),
|
||||||
|
screenshot_path=service_data.get('http_info', {}).get('screenshot')
|
||||||
|
)
|
||||||
|
self.db.add(service)
|
||||||
|
self.db.flush()
|
||||||
|
|
||||||
|
# Process certificate and TLS info if present
|
||||||
|
http_info = service_data.get('http_info', {})
|
||||||
|
if http_info.get('certificate'):
|
||||||
|
self._process_certificate(
|
||||||
|
http_info['certificate'],
|
||||||
|
scan_obj.id,
|
||||||
|
service.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process UDP ports
|
||||||
|
expected_udp = set(ip_data.get('expected', {}).get('udp_ports', []))
|
||||||
|
actual_udp = ip_data.get('actual', {}).get('udp_ports', [])
|
||||||
|
|
||||||
|
for port_num in actual_udp:
|
||||||
|
port = ScanPort(
|
||||||
|
scan_id=scan_obj.id,
|
||||||
|
ip_id=ip.id,
|
||||||
|
port=port_num,
|
||||||
|
protocol='udp',
|
||||||
|
expected=(port_num in expected_udp),
|
||||||
|
state='open'
|
||||||
|
)
|
||||||
|
self.db.add(port)
|
||||||
|
|
||||||
|
logger.debug(f"Report mapping complete for scan {scan_obj.id}")
|
||||||
|
|
||||||
|
def _find_service_for_port(self, services: List[Dict], port: int) -> Optional[Dict]:
|
||||||
|
"""
|
||||||
|
Find service data for a specific port.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
services: List of service dictionaries
|
||||||
|
port: Port number to find
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Service dictionary if found, None otherwise
|
||||||
|
"""
|
||||||
|
for service in services:
|
||||||
|
if service.get('port') == port:
|
||||||
|
return service
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _process_certificate(self, cert_data: Dict[str, Any], scan_id: int,
|
||||||
|
service_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Process certificate and TLS version data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cert_data: Certificate data dictionary
|
||||||
|
scan_id: Scan ID
|
||||||
|
service_id: Service ID
|
||||||
|
"""
|
||||||
|
# Create ScanCertificate record
|
||||||
|
cert = ScanCertificate(
|
||||||
|
scan_id=scan_id,
|
||||||
|
service_id=service_id,
|
||||||
|
subject=cert_data.get('subject'),
|
||||||
|
issuer=cert_data.get('issuer'),
|
||||||
|
serial_number=cert_data.get('serial_number'),
|
||||||
|
not_valid_before=self._parse_datetime(cert_data.get('not_valid_before')),
|
||||||
|
not_valid_after=self._parse_datetime(cert_data.get('not_valid_after')),
|
||||||
|
days_until_expiry=cert_data.get('days_until_expiry'),
|
||||||
|
sans=json.dumps(cert_data.get('sans', [])),
|
||||||
|
is_self_signed=cert_data.get('is_self_signed', False)
|
||||||
|
)
|
||||||
|
self.db.add(cert)
|
||||||
|
self.db.flush()
|
||||||
|
|
||||||
|
# Process TLS versions
|
||||||
|
tls_versions = cert_data.get('tls_versions', {})
|
||||||
|
for version, version_data in tls_versions.items():
|
||||||
|
tls = ScanTLSVersion(
|
||||||
|
scan_id=scan_id,
|
||||||
|
certificate_id=cert.id,
|
||||||
|
tls_version=version,
|
||||||
|
supported=version_data.get('supported', False),
|
||||||
|
cipher_suites=json.dumps(version_data.get('cipher_suites', []))
|
||||||
|
)
|
||||||
|
self.db.add(tls)
|
||||||
|
|
||||||
|
def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
Parse ISO datetime string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
date_str: ISO format datetime string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
datetime object or None if parsing fails
|
||||||
|
"""
|
||||||
|
if not date_str:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle ISO format with 'Z' suffix
|
||||||
|
if date_str.endswith('Z'):
|
||||||
|
date_str = date_str[:-1] + '+00:00'
|
||||||
|
return datetime.fromisoformat(date_str)
|
||||||
|
except (ValueError, AttributeError) as e:
|
||||||
|
logger.warning(f"Failed to parse datetime '{date_str}': {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _scan_to_dict(self, scan: Scan) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert Scan object to dictionary with full details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scan: Scan database object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation with all related data
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'id': scan.id,
|
||||||
|
'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
|
||||||
|
'duration': scan.duration,
|
||||||
|
'status': scan.status,
|
||||||
|
'title': scan.title,
|
||||||
|
'config_file': scan.config_file,
|
||||||
|
'json_path': scan.json_path,
|
||||||
|
'html_path': scan.html_path,
|
||||||
|
'zip_path': scan.zip_path,
|
||||||
|
'screenshot_dir': scan.screenshot_dir,
|
||||||
|
'triggered_by': scan.triggered_by,
|
||||||
|
'created_at': scan.created_at.isoformat() if scan.created_at else None,
|
||||||
|
'sites': [self._site_to_dict(site) for site in scan.sites]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _scan_to_summary_dict(self, scan: Scan) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert Scan object to summary dictionary (no related data).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scan: Scan database object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Summary dictionary
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'id': scan.id,
|
||||||
|
'timestamp': scan.timestamp.isoformat() if scan.timestamp else None,
|
||||||
|
'duration': scan.duration,
|
||||||
|
'status': scan.status,
|
||||||
|
'title': scan.title,
|
||||||
|
'triggered_by': scan.triggered_by,
|
||||||
|
'created_at': scan.created_at.isoformat() if scan.created_at else None
|
||||||
|
}
|
||||||
|
|
||||||
|
def _site_to_dict(self, site: ScanSite) -> Dict[str, Any]:
|
||||||
|
"""Convert ScanSite to dictionary."""
|
||||||
|
return {
|
||||||
|
'id': site.id,
|
||||||
|
'name': site.site_name,
|
||||||
|
'ips': [self._ip_to_dict(ip) for ip in site.ips]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _ip_to_dict(self, ip: ScanIP) -> Dict[str, Any]:
|
||||||
|
"""Convert ScanIP to dictionary."""
|
||||||
|
return {
|
||||||
|
'id': ip.id,
|
||||||
|
'address': ip.ip_address,
|
||||||
|
'ping_expected': ip.ping_expected,
|
||||||
|
'ping_actual': ip.ping_actual,
|
||||||
|
'ports': [self._port_to_dict(port) for port in ip.ports]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _port_to_dict(self, port: ScanPort) -> Dict[str, Any]:
|
||||||
|
"""Convert ScanPort to dictionary."""
|
||||||
|
return {
|
||||||
|
'id': port.id,
|
||||||
|
'port': port.port,
|
||||||
|
'protocol': port.protocol,
|
||||||
|
'state': port.state,
|
||||||
|
'expected': port.expected,
|
||||||
|
'services': [self._service_to_dict(svc) for svc in port.services]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _service_to_dict(self, service: ScanServiceModel) -> Dict[str, Any]:
|
||||||
|
"""Convert ScanService to dictionary."""
|
||||||
|
result = {
|
||||||
|
'id': service.id,
|
||||||
|
'service_name': service.service_name,
|
||||||
|
'product': service.product,
|
||||||
|
'version': service.version,
|
||||||
|
'extrainfo': service.extrainfo,
|
||||||
|
'ostype': service.ostype,
|
||||||
|
'http_protocol': service.http_protocol,
|
||||||
|
'screenshot_path': service.screenshot_path
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add certificate info if present
|
||||||
|
if service.certificates:
|
||||||
|
result['certificates'] = [
|
||||||
|
self._certificate_to_dict(cert) for cert in service.certificates
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _certificate_to_dict(self, cert: ScanCertificate) -> Dict[str, Any]:
|
||||||
|
"""Convert ScanCertificate to dictionary."""
|
||||||
|
result = {
|
||||||
|
'id': cert.id,
|
||||||
|
'subject': cert.subject,
|
||||||
|
'issuer': cert.issuer,
|
||||||
|
'serial_number': cert.serial_number,
|
||||||
|
'not_valid_before': cert.not_valid_before.isoformat() if cert.not_valid_before else None,
|
||||||
|
'not_valid_after': cert.not_valid_after.isoformat() if cert.not_valid_after else None,
|
||||||
|
'days_until_expiry': cert.days_until_expiry,
|
||||||
|
'is_self_signed': cert.is_self_signed
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse SANs from JSON
|
||||||
|
if cert.sans:
|
||||||
|
try:
|
||||||
|
result['sans'] = json.loads(cert.sans)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
result['sans'] = []
|
||||||
|
|
||||||
|
# Add TLS versions
|
||||||
|
result['tls_versions'] = [
|
||||||
|
self._tls_version_to_dict(tls) for tls in cert.tls_versions
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _tls_version_to_dict(self, tls: ScanTLSVersion) -> Dict[str, Any]:
|
||||||
|
"""Convert ScanTLSVersion to dictionary."""
|
||||||
|
result = {
|
||||||
|
'id': tls.id,
|
||||||
|
'tls_version': tls.tls_version,
|
||||||
|
'supported': tls.supported
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse cipher suites from JSON
|
||||||
|
if tls.cipher_suites:
|
||||||
|
try:
|
||||||
|
result['cipher_suites'] = json.loads(tls.cipher_suites)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
result['cipher_suites'] = []
|
||||||
|
|
||||||
|
return result
|
||||||
158
web/utils/pagination.py
Normal file
158
web/utils/pagination.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""
|
||||||
|
Pagination utilities for SneakyScanner web application.
|
||||||
|
|
||||||
|
Provides helper functions for paginating SQLAlchemy queries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
from sqlalchemy.orm import Query
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedResult:
|
||||||
|
"""Container for paginated query results."""
|
||||||
|
|
||||||
|
def __init__(self, items: List[Any], total: int, page: int, per_page: int):
|
||||||
|
"""
|
||||||
|
Initialize paginated result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items: List of items for current page
|
||||||
|
total: Total number of items across all pages
|
||||||
|
page: Current page number (1-indexed)
|
||||||
|
per_page: Number of items per page
|
||||||
|
"""
|
||||||
|
self.items = items
|
||||||
|
self.total = total
|
||||||
|
self.page = page
|
||||||
|
self.per_page = per_page
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pages(self) -> int:
|
||||||
|
"""Calculate total number of pages."""
|
||||||
|
if self.per_page == 0:
|
||||||
|
return 0
|
||||||
|
return (self.total + self.per_page - 1) // self.per_page
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_prev(self) -> bool:
|
||||||
|
"""Check if there is a previous page."""
|
||||||
|
return self.page > 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_next(self) -> bool:
|
||||||
|
"""Check if there is a next page."""
|
||||||
|
return self.page < self.pages
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prev_page(self) -> int:
|
||||||
|
"""Get previous page number."""
|
||||||
|
return self.page - 1 if self.has_prev else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def next_page(self) -> int:
|
||||||
|
"""Get next page number."""
|
||||||
|
return self.page + 1 if self.has_next else None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert to dictionary for API responses.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with pagination metadata and items
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'items': self.items,
|
||||||
|
'total': self.total,
|
||||||
|
'page': self.page,
|
||||||
|
'per_page': self.per_page,
|
||||||
|
'pages': self.pages,
|
||||||
|
'has_prev': self.has_prev,
|
||||||
|
'has_next': self.has_next,
|
||||||
|
'prev_page': self.prev_page,
|
||||||
|
'next_page': self.next_page,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def paginate(query: Query, page: int = 1, per_page: int = 20,
|
||||||
|
max_per_page: int = 100) -> PaginatedResult:
|
||||||
|
"""
|
||||||
|
Paginate a SQLAlchemy query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: SQLAlchemy query to paginate
|
||||||
|
page: Page number (1-indexed, default: 1)
|
||||||
|
per_page: Items per page (default: 20)
|
||||||
|
max_per_page: Maximum items per page (default: 100)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PaginatedResult with items and pagination metadata
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from web.models import Scan
|
||||||
|
>>> query = db.query(Scan).order_by(Scan.timestamp.desc())
|
||||||
|
>>> result = paginate(query, page=1, per_page=20)
|
||||||
|
>>> scans = result.items
|
||||||
|
>>> total_pages = result.pages
|
||||||
|
"""
|
||||||
|
# Validate and sanitize parameters
|
||||||
|
page = max(1, page) # Page must be at least 1
|
||||||
|
per_page = max(1, min(per_page, max_per_page)) # Clamp per_page
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
total = query.count()
|
||||||
|
|
||||||
|
# Calculate offset
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
|
||||||
|
# Execute query with limit and offset
|
||||||
|
items = query.limit(per_page).offset(offset).all()
|
||||||
|
|
||||||
|
return PaginatedResult(
|
||||||
|
items=items,
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
per_page=per_page
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_page_params(page: Any, per_page: Any,
|
||||||
|
max_per_page: int = 100) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Validate and sanitize pagination parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: Page number (any type, will be converted to int)
|
||||||
|
per_page: Items per page (any type, will be converted to int)
|
||||||
|
max_per_page: Maximum items per page (default: 100)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (validated_page, validated_per_page)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validate_page_params('2', '50')
|
||||||
|
(2, 50)
|
||||||
|
>>> validate_page_params(-1, 200)
|
||||||
|
(1, 100)
|
||||||
|
>>> validate_page_params(None, None)
|
||||||
|
(1, 20)
|
||||||
|
"""
|
||||||
|
# Default values
|
||||||
|
default_page = 1
|
||||||
|
default_per_page = 20
|
||||||
|
|
||||||
|
# Convert to int, use default if invalid
|
||||||
|
try:
|
||||||
|
page = int(page) if page is not None else default_page
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
page = default_page
|
||||||
|
|
||||||
|
try:
|
||||||
|
per_page = int(per_page) if per_page is not None else default_per_page
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
per_page = default_per_page
|
||||||
|
|
||||||
|
# Validate ranges
|
||||||
|
page = max(1, page)
|
||||||
|
per_page = max(1, min(per_page, max_per_page))
|
||||||
|
|
||||||
|
return page, per_page
|
||||||
284
web/utils/validators.py
Normal file
284
web/utils/validators.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
"""
|
||||||
|
Input validation utilities for SneakyScanner web application.
|
||||||
|
|
||||||
|
Provides validation functions for API inputs, file paths, and data integrity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def validate_config_file(file_path: str) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate that a configuration file exists and is valid YAML.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to configuration file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
If valid, returns (True, None)
|
||||||
|
If invalid, returns (False, error_message)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validate_config_file('/app/configs/example.yaml')
|
||||||
|
(True, None)
|
||||||
|
>>> validate_config_file('/nonexistent.yaml')
|
||||||
|
(False, 'File does not exist: /nonexistent.yaml')
|
||||||
|
"""
|
||||||
|
# Check if path is provided
|
||||||
|
if not file_path:
|
||||||
|
return False, 'Config file path is required'
|
||||||
|
|
||||||
|
# Convert to Path object
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# Check if file exists
|
||||||
|
if not path.exists():
|
||||||
|
return False, f'File does not exist: {file_path}'
|
||||||
|
|
||||||
|
# Check if it's a file (not directory)
|
||||||
|
if not path.is_file():
|
||||||
|
return False, f'Path is not a file: {file_path}'
|
||||||
|
|
||||||
|
# Check file extension
|
||||||
|
if path.suffix.lower() not in ['.yaml', '.yml']:
|
||||||
|
return False, f'File must be YAML (.yaml or .yml): {file_path}'
|
||||||
|
|
||||||
|
# Try to parse as YAML
|
||||||
|
try:
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# Check if it's a dictionary (basic structure validation)
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
return False, 'Config file must contain a YAML dictionary'
|
||||||
|
|
||||||
|
# Check for required top-level keys
|
||||||
|
if 'title' not in config:
|
||||||
|
return False, 'Config file missing required "title" field'
|
||||||
|
|
||||||
|
if 'sites' not in config:
|
||||||
|
return False, 'Config file missing required "sites" field'
|
||||||
|
|
||||||
|
# Validate sites structure
|
||||||
|
if not isinstance(config['sites'], list):
|
||||||
|
return False, '"sites" must be a list'
|
||||||
|
|
||||||
|
if len(config['sites']) == 0:
|
||||||
|
return False, '"sites" list cannot be empty'
|
||||||
|
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
return False, f'Invalid YAML syntax: {str(e)}'
|
||||||
|
except Exception as e:
|
||||||
|
return False, f'Error reading config file: {str(e)}'
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_scan_status(status: str) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate scan status value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status: Status string to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validate_scan_status('running')
|
||||||
|
(True, None)
|
||||||
|
>>> validate_scan_status('invalid')
|
||||||
|
(False, 'Invalid status: invalid. Must be one of: running, completed, failed')
|
||||||
|
"""
|
||||||
|
valid_statuses = ['running', 'completed', 'failed']
|
||||||
|
|
||||||
|
if status not in valid_statuses:
|
||||||
|
return False, f'Invalid status: {status}. Must be one of: {", ".join(valid_statuses)}'
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_triggered_by(triggered_by: str) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate triggered_by value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
triggered_by: Source that triggered the scan
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validate_triggered_by('manual')
|
||||||
|
(True, None)
|
||||||
|
>>> validate_triggered_by('api')
|
||||||
|
(True, None)
|
||||||
|
"""
|
||||||
|
valid_sources = ['manual', 'scheduled', 'api']
|
||||||
|
|
||||||
|
if triggered_by not in valid_sources:
|
||||||
|
return False, f'Invalid triggered_by: {triggered_by}. Must be one of: {", ".join(valid_sources)}'
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_scan_id(scan_id: any) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate scan ID is a positive integer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scan_id: Scan ID to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validate_scan_id(42)
|
||||||
|
(True, None)
|
||||||
|
>>> validate_scan_id('42')
|
||||||
|
(True, None)
|
||||||
|
>>> validate_scan_id(-1)
|
||||||
|
(False, 'Scan ID must be a positive integer')
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
scan_id_int = int(scan_id)
|
||||||
|
if scan_id_int <= 0:
|
||||||
|
return False, 'Scan ID must be a positive integer'
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return False, f'Invalid scan ID: {scan_id}'
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_file_path(file_path: str, must_exist: bool = True) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate a file path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to validate
|
||||||
|
must_exist: If True, file must exist. If False, only validate format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validate_file_path('/app/output/scan.json', must_exist=False)
|
||||||
|
(True, None)
|
||||||
|
>>> validate_file_path('', must_exist=False)
|
||||||
|
(False, 'File path is required')
|
||||||
|
"""
|
||||||
|
if not file_path:
|
||||||
|
return False, 'File path is required'
|
||||||
|
|
||||||
|
# Check for path traversal attempts
|
||||||
|
if '..' in file_path:
|
||||||
|
return False, 'Path traversal not allowed'
|
||||||
|
|
||||||
|
if must_exist:
|
||||||
|
path = Path(file_path)
|
||||||
|
if not path.exists():
|
||||||
|
return False, f'File does not exist: {file_path}'
|
||||||
|
if not path.is_file():
|
||||||
|
return False, f'Path is not a file: {file_path}'
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_filename(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize a filename by removing/replacing unsafe characters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Original filename
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized filename safe for filesystem
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> sanitize_filename('my scan.json')
|
||||||
|
'my_scan.json'
|
||||||
|
>>> sanitize_filename('../../etc/passwd')
|
||||||
|
'etc_passwd'
|
||||||
|
"""
|
||||||
|
# Remove path components
|
||||||
|
filename = os.path.basename(filename)
|
||||||
|
|
||||||
|
# Replace unsafe characters with underscore
|
||||||
|
unsafe_chars = ['/', '\\', '..', ' ', ':', '*', '?', '"', '<', '>', '|']
|
||||||
|
for char in unsafe_chars:
|
||||||
|
filename = filename.replace(char, '_')
|
||||||
|
|
||||||
|
# Remove leading/trailing underscores and dots
|
||||||
|
filename = filename.strip('_.')
|
||||||
|
|
||||||
|
# Ensure filename is not empty
|
||||||
|
if not filename:
|
||||||
|
filename = 'unnamed'
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
def validate_port(port: any) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate port number.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
port: Port number to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validate_port(443)
|
||||||
|
(True, None)
|
||||||
|
>>> validate_port(70000)
|
||||||
|
(False, 'Port must be between 1 and 65535')
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
port_int = int(port)
|
||||||
|
if port_int < 1 or port_int > 65535:
|
||||||
|
return False, 'Port must be between 1 and 65535'
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return False, f'Invalid port: {port}'
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_ip_address(ip: str) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate IPv4 address format (basic validation).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip: IP address string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validate_ip_address('192.168.1.1')
|
||||||
|
(True, None)
|
||||||
|
>>> validate_ip_address('256.1.1.1')
|
||||||
|
(False, 'Invalid IP address format')
|
||||||
|
"""
|
||||||
|
if not ip:
|
||||||
|
return False, 'IP address is required'
|
||||||
|
|
||||||
|
# Basic IPv4 validation
|
||||||
|
parts = ip.split('.')
|
||||||
|
if len(parts) != 4:
|
||||||
|
return False, 'Invalid IP address format'
|
||||||
|
|
||||||
|
try:
|
||||||
|
for part in parts:
|
||||||
|
num = int(part)
|
||||||
|
if num < 0 or num > 255:
|
||||||
|
return False, 'Invalid IP address format'
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return False, 'Invalid IP address format'
|
||||||
|
|
||||||
|
return True, None
|
||||||
Reference in New Issue
Block a user