restructure of dirs, huge docs update

This commit is contained in:
2025-11-17 16:29:14 -06:00
parent 456e052389
commit cd840cb8ca
87 changed files with 2827 additions and 1094 deletions

1
app/tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Test package for SneakyScanner."""

384
app/tests/conftest.py Normal file
View File

@@ -0,0 +1,384 @@
"""
Pytest configuration and fixtures for SneakyScanner tests.
"""
import os
import tempfile
from datetime import datetime
from pathlib import Path
import pytest
import yaml
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from web.app import create_app
from web.models import Base, Scan
from web.utils.settings import PasswordManager, SettingsManager
@pytest.fixture(scope='function')
def test_db():
"""
Create a temporary test database.
Yields a SQLAlchemy session for testing, then cleans up.
"""
# Create temporary database file
db_fd, db_path = tempfile.mkstemp(suffix='.db')
# Create engine and session
engine = create_engine(f'sqlite:///{db_path}', echo=False)
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()
yield session
# Cleanup
session.close()
os.close(db_fd)
os.unlink(db_path)
@pytest.fixture
def sample_scan_report():
"""
Sample scan report matching the structure from scanner.py.
Returns a dictionary representing a typical scan output.
"""
return {
'title': 'Test Scan',
'scan_time': '2025-11-14T10:30:00Z',
'scan_duration': 125.5,
'config_file': '/app/configs/test.yaml',
'sites': [
{
'name': 'Test Site',
'ips': [
{
'address': '192.168.1.10',
'expected': {
'ping': True,
'tcp_ports': [22, 80, 443],
'udp_ports': [53],
'services': []
},
'actual': {
'ping': True,
'tcp_ports': [22, 80, 443, 8080],
'udp_ports': [53],
'services': [
{
'port': 22,
'service': 'ssh',
'product': 'OpenSSH',
'version': '8.9p1',
'extrainfo': 'Ubuntu',
'ostype': 'Linux'
},
{
'port': 443,
'service': 'https',
'product': 'nginx',
'version': '1.24.0',
'extrainfo': '',
'ostype': '',
'http_info': {
'protocol': 'https',
'screenshot': 'screenshots/192_168_1_10_443.png',
'certificate': {
'subject': 'CN=example.com',
'issuer': 'CN=Let\'s Encrypt Authority',
'serial_number': '123456789',
'not_valid_before': '2025-01-01T00:00:00Z',
'not_valid_after': '2025-12-31T23:59:59Z',
'days_until_expiry': 365,
'sans': ['example.com', 'www.example.com'],
'is_self_signed': False,
'tls_versions': {
'TLS 1.2': {
'supported': True,
'cipher_suites': [
'TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384',
'TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256'
]
},
'TLS 1.3': {
'supported': True,
'cipher_suites': [
'TLS_AES_256_GCM_SHA384',
'TLS_AES_128_GCM_SHA256'
]
}
}
}
}
},
{
'port': 80,
'service': 'http',
'product': 'nginx',
'version': '1.24.0',
'extrainfo': '',
'ostype': '',
'http_info': {
'protocol': 'http',
'screenshot': 'screenshots/192_168_1_10_80.png'
}
},
{
'port': 8080,
'service': 'http',
'product': 'Jetty',
'version': '9.4.48',
'extrainfo': '',
'ostype': ''
}
]
}
}
]
}
]
}
@pytest.fixture
def sample_config_file(tmp_path):
"""
Create a sample YAML config file for testing.
Args:
tmp_path: pytest temporary directory fixture
Returns:
Path to created config file
"""
config_data = {
'title': 'Test Scan',
'sites': [
{
'name': 'Test Site',
'ips': [
{
'address': '192.168.1.10',
'expected': {
'ping': True,
'tcp_ports': [22, 80, 443],
'udp_ports': [53],
'services': ['ssh', 'http', 'https']
}
}
]
}
]
}
config_file = tmp_path / 'test_config.yaml'
with open(config_file, 'w') as f:
yaml.dump(config_data, f)
return str(config_file)
@pytest.fixture
def sample_invalid_config_file(tmp_path):
"""
Create an invalid config file for testing validation.
Returns:
Path to invalid config file
"""
config_file = tmp_path / 'invalid_config.yaml'
with open(config_file, 'w') as f:
f.write("invalid: yaml: content: [missing closing bracket")
return str(config_file)
@pytest.fixture(scope='function')
def app():
"""
Create Flask application for testing.
Returns:
Configured Flask app instance with test database
"""
# Create temporary database
db_fd, db_path = tempfile.mkstemp(suffix='.db')
# Create app with test config
test_config = {
'TESTING': True,
'SQLALCHEMY_DATABASE_URI': f'sqlite:///{db_path}',
'SECRET_KEY': 'test-secret-key'
}
app = create_app(test_config)
yield app
# Cleanup
os.close(db_fd)
os.unlink(db_path)
@pytest.fixture(scope='function')
def client(app):
"""
Create Flask test client.
Args:
app: Flask application fixture
Returns:
Flask test client for making API requests
"""
return app.test_client()
@pytest.fixture(scope='function')
def db(app):
"""
Alias for database session that works with Flask app context.
Args:
app: Flask application fixture
Returns:
SQLAlchemy session
"""
with app.app_context():
yield app.db_session
@pytest.fixture
def sample_scan(db):
"""
Create a sample scan in the database for testing.
Args:
db: Database session fixture
Returns:
Scan model instance
"""
scan = Scan(
timestamp=datetime.utcnow(),
status='completed',
config_file='/app/configs/test.yaml',
title='Test Scan',
duration=125.5,
triggered_by='test',
json_path='/app/output/scan_report_20251114_103000.json',
html_path='/app/output/scan_report_20251114_103000.html',
zip_path='/app/output/scan_report_20251114_103000.zip',
screenshot_dir='/app/output/scan_report_20251114_103000_screenshots'
)
db.add(scan)
db.commit()
db.refresh(scan)
return scan
# Authentication Fixtures
@pytest.fixture
def app_password():
"""
Test password for authentication tests.
Returns:
Test password string
"""
return 'testpassword123'
@pytest.fixture
def db_with_password(db, app_password):
"""
Database session with application password set.
Args:
db: Database session fixture
app_password: Test password fixture
Returns:
Database session with password configured
"""
settings_manager = SettingsManager(db)
PasswordManager.set_app_password(settings_manager, app_password)
return db
@pytest.fixture
def db_no_password(app):
"""
Database session without application password set.
Args:
app: Flask application fixture
Returns:
Database session without password
"""
with app.app_context():
# Clear any password that might be set
settings_manager = SettingsManager(app.db_session)
settings_manager.delete('app_password')
yield app.db_session
@pytest.fixture
def authenticated_client(client, db_with_password, app_password):
"""
Flask test client with authenticated session.
Args:
client: Flask test client fixture
db_with_password: Database with password set
app_password: Test password fixture
Returns:
Test client with active session
"""
# Log in
client.post('/auth/login', data={
'password': app_password
})
return client
@pytest.fixture
def client_no_password(app):
"""
Flask test client with no password set (for setup testing).
Args:
app: Flask application fixture
Returns:
Test client for testing setup flow
"""
# Create temporary database without password
db_fd, db_path = tempfile.mkstemp(suffix='.db')
test_config = {
'TESTING': True,
'SQLALCHEMY_DATABASE_URI': f'sqlite:///{db_path}',
'SECRET_KEY': 'test-secret-key'
}
test_app = create_app(test_config)
test_client = test_app.test_client()
yield test_client
# Cleanup
os.close(db_fd)
os.unlink(db_path)

View File

@@ -0,0 +1,279 @@
"""
Tests for authentication system.
Tests login, logout, session management, and API authentication.
"""
import pytest
from flask import url_for
from web.auth.models import User
from web.utils.settings import PasswordManager, SettingsManager
class TestUserModel:
"""Tests for User model."""
def test_user_get_valid_id(self, db):
"""Test getting user with valid ID."""
user = User.get('1', db)
assert user is not None
assert user.id == '1'
def test_user_get_invalid_id(self, db):
"""Test getting user with invalid ID."""
user = User.get('invalid', db)
assert user is None
def test_user_properties(self):
"""Test user properties."""
user = User('1')
assert user.is_authenticated is True
assert user.is_active is True
assert user.is_anonymous is False
assert user.get_id() == '1'
def test_user_authenticate_success(self, db, app_password):
"""Test successful authentication."""
user = User.authenticate(app_password, db)
assert user is not None
assert user.id == '1'
def test_user_authenticate_failure(self, db):
"""Test failed authentication with wrong password."""
user = User.authenticate('wrongpassword', db)
assert user is None
def test_user_has_password_set(self, db, app_password):
"""Test checking if password is set."""
# Password is set in fixture
assert User.has_password_set(db) is True
def test_user_has_password_not_set(self, db_no_password):
"""Test checking if password is not set."""
assert User.has_password_set(db_no_password) is False
class TestAuthRoutes:
"""Tests for authentication routes."""
def test_login_page_renders(self, client):
"""Test that login page renders correctly."""
response = client.get('/auth/login')
assert response.status_code == 200
# Note: This will fail until templates are created
# assert b'login' in response.data.lower()
def test_login_success(self, client, app_password):
"""Test successful login."""
response = client.post('/auth/login', data={
'password': app_password
}, follow_redirects=False)
# Should redirect to dashboard (or main.dashboard)
assert response.status_code == 302
def test_login_failure(self, client):
"""Test failed login with wrong password."""
response = client.post('/auth/login', data={
'password': 'wrongpassword'
}, follow_redirects=True)
# Should stay on login page
assert response.status_code == 200
def test_login_redirect_when_authenticated(self, authenticated_client):
"""Test that login page redirects when already logged in."""
response = authenticated_client.get('/auth/login', follow_redirects=False)
# Should redirect to dashboard
assert response.status_code == 302
def test_logout(self, authenticated_client):
"""Test logout functionality."""
response = authenticated_client.get('/auth/logout', follow_redirects=False)
# Should redirect to login page
assert response.status_code == 302
assert '/auth/login' in response.location
def test_logout_when_not_authenticated(self, client):
"""Test logout when not authenticated."""
response = client.get('/auth/logout', follow_redirects=False)
# Should redirect to login page anyway
assert response.status_code == 302
def test_setup_page_renders_when_no_password(self, client_no_password):
"""Test that setup page renders when no password is set."""
response = client_no_password.get('/auth/setup')
assert response.status_code == 200
def test_setup_redirects_when_password_set(self, client):
"""Test that setup page redirects when password already set."""
response = client.get('/auth/setup', follow_redirects=False)
assert response.status_code == 302
assert '/auth/login' in response.location
def test_setup_password_success(self, client_no_password):
"""Test setting password via setup page."""
response = client_no_password.post('/auth/setup', data={
'password': 'newpassword123',
'confirm_password': 'newpassword123'
}, follow_redirects=False)
# Should redirect to login
assert response.status_code == 302
assert '/auth/login' in response.location
def test_setup_password_too_short(self, client_no_password):
"""Test that setup rejects password that's too short."""
response = client_no_password.post('/auth/setup', data={
'password': 'short',
'confirm_password': 'short'
}, follow_redirects=True)
# Should stay on setup page
assert response.status_code == 200
def test_setup_passwords_dont_match(self, client_no_password):
"""Test that setup rejects mismatched passwords."""
response = client_no_password.post('/auth/setup', data={
'password': 'password123',
'confirm_password': 'different123'
}, follow_redirects=True)
# Should stay on setup page
assert response.status_code == 200
class TestAPIAuthentication:
"""Tests for API endpoint authentication."""
def test_scans_list_requires_auth(self, client):
"""Test that listing scans requires authentication."""
response = client.get('/api/scans')
assert response.status_code == 401
data = response.get_json()
assert 'error' in data
assert data['error'] == 'Authentication required'
def test_scans_list_with_auth(self, authenticated_client):
"""Test that listing scans works when authenticated."""
response = authenticated_client.get('/api/scans')
# Should succeed (200) even if empty
assert response.status_code == 200
data = response.get_json()
assert 'scans' in data
def test_scan_trigger_requires_auth(self, client):
"""Test that triggering scan requires authentication."""
response = client.post('/api/scans', json={
'config_file': '/app/configs/test.yaml'
})
assert response.status_code == 401
def test_scan_get_requires_auth(self, client):
"""Test that getting scan details requires authentication."""
response = client.get('/api/scans/1')
assert response.status_code == 401
def test_scan_delete_requires_auth(self, client):
"""Test that deleting scan requires authentication."""
response = client.delete('/api/scans/1')
assert response.status_code == 401
def test_scan_status_requires_auth(self, client):
"""Test that getting scan status requires authentication."""
response = client.get('/api/scans/1/status')
assert response.status_code == 401
def test_settings_get_requires_auth(self, client):
"""Test that getting settings requires authentication."""
response = client.get('/api/settings')
assert response.status_code == 401
def test_settings_update_requires_auth(self, client):
"""Test that updating settings requires authentication."""
response = client.put('/api/settings', json={
'settings': {'test_key': 'test_value'}
})
assert response.status_code == 401
def test_settings_get_with_auth(self, authenticated_client):
"""Test that getting settings works when authenticated."""
response = authenticated_client.get('/api/settings')
assert response.status_code == 200
data = response.get_json()
assert 'settings' in data
def test_schedules_list_requires_auth(self, client):
"""Test that listing schedules requires authentication."""
response = client.get('/api/schedules')
assert response.status_code == 401
def test_alerts_list_requires_auth(self, client):
"""Test that listing alerts requires authentication."""
response = client.get('/api/alerts')
assert response.status_code == 401
def test_health_check_no_auth_required(self, client):
"""Test that health check endpoints don't require authentication."""
# Health checks should be accessible without authentication
response = client.get('/api/scans/health')
assert response.status_code == 200
response = client.get('/api/settings/health')
assert response.status_code == 200
response = client.get('/api/schedules/health')
assert response.status_code == 200
response = client.get('/api/alerts/health')
assert response.status_code == 200
class TestSessionManagement:
"""Tests for session management."""
def test_session_persists_across_requests(self, authenticated_client):
"""Test that session persists across multiple requests."""
# First request - should succeed
response1 = authenticated_client.get('/api/scans')
assert response1.status_code == 200
# Second request - should also succeed (session persists)
response2 = authenticated_client.get('/api/settings')
assert response2.status_code == 200
def test_remember_me_cookie(self, client, app_password):
"""Test remember me functionality."""
response = client.post('/auth/login', data={
'password': app_password,
'remember': 'on'
}, follow_redirects=False)
# Should set remember_me cookie
assert response.status_code == 302
# Note: Actual cookie checking would require inspecting response.headers
class TestNextRedirect:
"""Tests for 'next' parameter redirect."""
def test_login_redirects_to_next(self, client, app_password):
"""Test that login redirects to 'next' parameter."""
response = client.post('/auth/login?next=/api/scans', data={
'password': app_password
}, follow_redirects=False)
assert response.status_code == 302
assert '/api/scans' in response.location
def test_login_without_next_redirects_to_dashboard(self, client, app_password):
"""Test that login without 'next' redirects to dashboard."""
response = client.post('/auth/login', data={
'password': app_password
}, follow_redirects=False)
assert response.status_code == 302
# Should redirect to dashboard
assert 'dashboard' in response.location or response.location == '/'

View File

@@ -0,0 +1,225 @@
"""
Tests for background job execution and scheduler integration.
Tests the APScheduler integration, job queuing, and background scan execution.
"""
import pytest
import time
from datetime import datetime
from web.models import Scan
from web.services.scan_service import ScanService
from web.services.scheduler_service import SchedulerService
class TestBackgroundJobs:
"""Test suite for background job execution."""
def test_scheduler_initialization(self, app):
"""Test that scheduler is initialized with Flask app."""
assert hasattr(app, 'scheduler')
assert app.scheduler is not None
assert app.scheduler.scheduler is not None
assert app.scheduler.scheduler.running
def test_queue_scan_job(self, app, db, sample_config_file):
"""Test queuing a scan for background execution."""
# Create a scan via service
scan_service = ScanService(db)
scan_id = scan_service.trigger_scan(
config_file=sample_config_file,
triggered_by='test',
scheduler=app.scheduler
)
# Verify scan was created
scan = db.query(Scan).filter_by(id=scan_id).first()
assert scan is not None
assert scan.status == 'running'
# Verify job was queued (check scheduler has the job)
job = app.scheduler.scheduler.get_job(f'scan_{scan_id}')
assert job is not None
assert job.id == f'scan_{scan_id}'
def test_trigger_scan_without_scheduler(self, db, sample_config_file):
"""Test triggering scan without scheduler logs warning."""
# Create scan without scheduler
scan_service = ScanService(db)
scan_id = scan_service.trigger_scan(
config_file=sample_config_file,
triggered_by='test',
scheduler=None # No scheduler
)
# Verify scan was created but not queued
scan = db.query(Scan).filter_by(id=scan_id).first()
assert scan is not None
assert scan.status == 'running'
def test_scheduler_service_queue_scan(self, app, db, sample_config_file):
"""Test SchedulerService.queue_scan directly."""
# Create scan record first
scan = Scan(
timestamp=datetime.utcnow(),
status='running',
config_file=sample_config_file,
title='Test Scan',
triggered_by='test'
)
db.add(scan)
db.commit()
# Queue the scan
job_id = app.scheduler.queue_scan(scan.id, sample_config_file)
# Verify job was queued
assert job_id == f'scan_{scan.id}'
job = app.scheduler.scheduler.get_job(job_id)
assert job is not None
def test_scheduler_list_jobs(self, app, db, sample_config_file):
"""Test listing scheduled jobs."""
# Queue a few scans
for i in range(3):
scan = Scan(
timestamp=datetime.utcnow(),
status='running',
config_file=sample_config_file,
title=f'Test Scan {i}',
triggered_by='test'
)
db.add(scan)
db.commit()
app.scheduler.queue_scan(scan.id, sample_config_file)
# List jobs
jobs = app.scheduler.list_jobs()
# Should have at least 3 jobs (might have more from other tests)
assert len(jobs) >= 3
# Each job should have required fields
for job in jobs:
assert 'id' in job
assert 'name' in job
assert 'trigger' in job
def test_scheduler_get_job_status(self, app, db, sample_config_file):
"""Test getting status of a specific job."""
# Create and queue a scan
scan = Scan(
timestamp=datetime.utcnow(),
status='running',
config_file=sample_config_file,
title='Test Scan',
triggered_by='test'
)
db.add(scan)
db.commit()
job_id = app.scheduler.queue_scan(scan.id, sample_config_file)
# Get job status
status = app.scheduler.get_job_status(job_id)
assert status is not None
assert status['id'] == job_id
assert status['name'] == f'Scan {scan.id}'
def test_scheduler_get_nonexistent_job(self, app):
"""Test getting status of non-existent job."""
status = app.scheduler.get_job_status('nonexistent_job_id')
assert status is None
def test_scan_timing_fields(self, db, sample_config_file):
"""Test that scan timing fields are properly set."""
# Create scan with started_at
scan = Scan(
timestamp=datetime.utcnow(),
status='running',
config_file=sample_config_file,
title='Test Scan',
triggered_by='test',
started_at=datetime.utcnow()
)
db.add(scan)
db.commit()
# Verify fields exist
assert scan.started_at is not None
assert scan.completed_at is None
assert scan.error_message is None
# Update to completed
scan.status = 'completed'
scan.completed_at = datetime.utcnow()
db.commit()
# Verify fields updated
assert scan.completed_at is not None
assert (scan.completed_at - scan.started_at).total_seconds() >= 0
def test_scan_error_handling(self, db, sample_config_file):
"""Test that error messages are stored correctly."""
# Create failed scan
scan = Scan(
timestamp=datetime.utcnow(),
status='failed',
config_file=sample_config_file,
title='Failed Scan',
triggered_by='test',
started_at=datetime.utcnow(),
completed_at=datetime.utcnow(),
error_message='Test error message'
)
db.add(scan)
db.commit()
# Verify error message stored
assert scan.error_message == 'Test error message'
# Verify status query works
scan_service = ScanService(db)
status = scan_service.get_scan_status(scan.id)
assert status['status'] == 'failed'
assert status['error_message'] == 'Test error message'
@pytest.mark.skip(reason="Requires actual scanner execution - slow test")
def test_background_scan_execution(self, app, db, sample_config_file):
"""
Integration test for actual background scan execution.
This test is skipped by default because it actually runs the scanner,
which requires privileged operations and takes time.
To run: pytest -v -k test_background_scan_execution --run-slow
"""
# Trigger scan
scan_service = ScanService(db)
scan_id = scan_service.trigger_scan(
config_file=sample_config_file,
triggered_by='test',
scheduler=app.scheduler
)
# Wait for scan to complete (with timeout)
max_wait = 300 # 5 minutes
start_time = time.time()
while time.time() - start_time < max_wait:
scan = db.query(Scan).filter_by(id=scan_id).first()
if scan.status in ['completed', 'failed']:
break
time.sleep(5)
# Verify scan completed
scan = db.query(Scan).filter_by(id=scan_id).first()
assert scan.status in ['completed', 'failed']
if scan.status == 'completed':
assert scan.duration is not None
assert scan.json_path is not None
else:
assert scan.error_message is not None

View File

@@ -0,0 +1,483 @@
"""
Integration tests for Config API endpoints.
Tests all config API endpoints including CSV/YAML upload, listing, downloading,
and deletion with schedule protection.
"""
import pytest
import os
import tempfile
import shutil
from web.app import create_app
from web.models import Base
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
@pytest.fixture
def app():
"""Create test application"""
# Create temporary database
test_db = tempfile.mktemp(suffix='.db')
# Create temporary configs directory
temp_configs_dir = tempfile.mkdtemp()
app = create_app({
'TESTING': True,
'SQLALCHEMY_DATABASE_URI': f'sqlite:///{test_db}',
'SECRET_KEY': 'test-secret-key',
'WTF_CSRF_ENABLED': False,
})
# Override configs directory in ConfigService
os.environ['CONFIGS_DIR'] = temp_configs_dir
# Create tables
with app.app_context():
Base.metadata.create_all(bind=app.db_session.get_bind())
yield app
# Cleanup
os.unlink(test_db)
shutil.rmtree(temp_configs_dir)
@pytest.fixture
def client(app):
"""Create test client"""
return app.test_client()
@pytest.fixture
def auth_headers(client):
"""Get authentication headers"""
# First register and login a user
from web.auth.models import User
with client.application.app_context():
# Create test user
user = User(username='testuser')
user.set_password('testpass')
client.application.db_session.add(user)
client.application.db_session.commit()
# Login
response = client.post('/auth/login', data={
'username': 'testuser',
'password': 'testpass'
}, follow_redirects=True)
assert response.status_code == 200
# Return empty headers (session-based auth)
return {}
@pytest.fixture
def sample_csv():
"""Sample CSV content"""
return """scan_title,site_name,ip_address,ping_expected,tcp_ports,udp_ports,services
Test Scan,Web Servers,10.10.20.4,true,"22,80,443",53,"ssh,http,https"
Test Scan,Web Servers,10.10.20.5,true,22,,"ssh"
"""
@pytest.fixture
def sample_yaml():
"""Sample YAML content"""
return """title: Test Scan
sites:
- name: Web Servers
ips:
- address: 10.10.20.4
expected:
ping: true
tcp_ports: [22, 80, 443]
udp_ports: [53]
services: [ssh, http, https]
"""
class TestListConfigs:
"""Tests for GET /api/configs"""
def test_list_configs_empty(self, client, auth_headers):
"""Test listing configs when none exist"""
response = client.get('/api/configs', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert 'configs' in data
assert data['configs'] == []
def test_list_configs_with_files(self, client, auth_headers, app, sample_yaml):
"""Test listing configs with existing files"""
# Create a config file
temp_configs_dir = os.environ.get('CONFIGS_DIR', '/app/configs')
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
with open(config_path, 'w') as f:
f.write(sample_yaml)
response = client.get('/api/configs', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert len(data['configs']) == 1
assert data['configs'][0]['filename'] == 'test-scan.yaml'
assert data['configs'][0]['title'] == 'Test Scan'
assert 'created_at' in data['configs'][0]
assert 'size_bytes' in data['configs'][0]
assert 'used_by_schedules' in data['configs'][0]
def test_list_configs_requires_auth(self, client):
"""Test that listing configs requires authentication"""
response = client.get('/api/configs')
assert response.status_code in [401, 302] # Unauthorized or redirect
class TestGetConfig:
"""Tests for GET /api/configs/<filename>"""
def test_get_config_valid(self, client, auth_headers, app, sample_yaml):
"""Test getting a valid config file"""
# Create a config file
temp_configs_dir = os.environ.get('CONFIGS_DIR', '/app/configs')
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
with open(config_path, 'w') as f:
f.write(sample_yaml)
response = client.get('/api/configs/test-scan.yaml', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert data['filename'] == 'test-scan.yaml'
assert 'content' in data
assert 'parsed' in data
assert data['parsed']['title'] == 'Test Scan'
def test_get_config_not_found(self, client, auth_headers):
"""Test getting non-existent config"""
response = client.get('/api/configs/nonexistent.yaml', headers=auth_headers)
assert response.status_code == 404
data = response.get_json()
assert 'error' in data
def test_get_config_requires_auth(self, client):
"""Test that getting config requires authentication"""
response = client.get('/api/configs/test.yaml')
assert response.status_code in [401, 302]
class TestUploadCSV:
"""Tests for POST /api/configs/upload-csv"""
def test_upload_csv_valid(self, client, auth_headers, sample_csv):
"""Test uploading valid CSV"""
from io import BytesIO
data = {
'file': (BytesIO(sample_csv.encode('utf-8')), 'test.csv')
}
response = client.post('/api/configs/upload-csv', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 200
result = response.get_json()
assert result['success'] is True
assert 'filename' in result
assert result['filename'].endswith('.yaml')
assert 'preview' in result
def test_upload_csv_no_file(self, client, auth_headers):
"""Test uploading without file"""
response = client.post('/api/configs/upload-csv', data={},
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 400
data = response.get_json()
assert 'error' in data
def test_upload_csv_invalid_format(self, client, auth_headers):
"""Test uploading invalid CSV"""
from io import BytesIO
invalid_csv = "not,a,valid,csv\nmissing,columns"
data = {
'file': (BytesIO(invalid_csv.encode('utf-8')), 'test.csv')
}
response = client.post('/api/configs/upload-csv', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 400
result = response.get_json()
assert 'error' in result
def test_upload_csv_wrong_extension(self, client, auth_headers):
"""Test uploading file with wrong extension"""
from io import BytesIO
data = {
'file': (BytesIO(b'test'), 'test.txt')
}
response = client.post('/api/configs/upload-csv', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 400
def test_upload_csv_duplicate_filename(self, client, auth_headers, sample_csv):
"""Test uploading CSV that generates duplicate filename"""
from io import BytesIO
data = {
'file': (BytesIO(sample_csv.encode('utf-8')), 'test.csv')
}
# Upload first time
response1 = client.post('/api/configs/upload-csv', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response1.status_code == 200
# Upload second time (should fail)
response2 = client.post('/api/configs/upload-csv', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response2.status_code == 400
def test_upload_csv_requires_auth(self, client, sample_csv):
"""Test that uploading CSV requires authentication"""
from io import BytesIO
data = {
'file': (BytesIO(sample_csv.encode('utf-8')), 'test.csv')
}
response = client.post('/api/configs/upload-csv', data=data,
content_type='multipart/form-data')
assert response.status_code in [401, 302]
class TestUploadYAML:
"""Tests for POST /api/configs/upload-yaml"""
def test_upload_yaml_valid(self, client, auth_headers, sample_yaml):
"""Test uploading valid YAML"""
from io import BytesIO
data = {
'file': (BytesIO(sample_yaml.encode('utf-8')), 'test.yaml')
}
response = client.post('/api/configs/upload-yaml', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 200
result = response.get_json()
assert result['success'] is True
assert 'filename' in result
def test_upload_yaml_no_file(self, client, auth_headers):
"""Test uploading without file"""
response = client.post('/api/configs/upload-yaml', data={},
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 400
def test_upload_yaml_invalid_syntax(self, client, auth_headers):
"""Test uploading YAML with invalid syntax"""
from io import BytesIO
invalid_yaml = "invalid: yaml: syntax: ["
data = {
'file': (BytesIO(invalid_yaml.encode('utf-8')), 'test.yaml')
}
response = client.post('/api/configs/upload-yaml', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 400
def test_upload_yaml_missing_required_fields(self, client, auth_headers):
"""Test uploading YAML missing required fields"""
from io import BytesIO
invalid_yaml = """sites:
- name: Test
ips:
- address: 10.0.0.1
"""
data = {
'file': (BytesIO(invalid_yaml.encode('utf-8')), 'test.yaml')
}
response = client.post('/api/configs/upload-yaml', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 400
def test_upload_yaml_wrong_extension(self, client, auth_headers):
"""Test uploading file with wrong extension"""
from io import BytesIO
data = {
'file': (BytesIO(b'test'), 'test.txt')
}
response = client.post('/api/configs/upload-yaml', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 400
def test_upload_yaml_requires_auth(self, client, sample_yaml):
"""Test that uploading YAML requires authentication"""
from io import BytesIO
data = {
'file': (BytesIO(sample_yaml.encode('utf-8')), 'test.yaml')
}
response = client.post('/api/configs/upload-yaml', data=data,
content_type='multipart/form-data')
assert response.status_code in [401, 302]
class TestDownloadTemplate:
"""Tests for GET /api/configs/template"""
def test_download_template(self, client, auth_headers):
"""Test downloading CSV template"""
response = client.get('/api/configs/template', headers=auth_headers)
assert response.status_code == 200
assert response.content_type == 'text/csv; charset=utf-8'
assert b'scan_title,site_name,ip_address' in response.data
def test_download_template_requires_auth(self, client):
"""Test that downloading template requires authentication"""
response = client.get('/api/configs/template')
assert response.status_code in [401, 302]
class TestDownloadConfig:
"""Tests for GET /api/configs/<filename>/download"""
def test_download_config_valid(self, client, auth_headers, app, sample_yaml):
"""Test downloading existing config"""
# Create a config file
temp_configs_dir = os.environ.get('CONFIGS_DIR', '/app/configs')
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
with open(config_path, 'w') as f:
f.write(sample_yaml)
response = client.get('/api/configs/test-scan.yaml/download', headers=auth_headers)
assert response.status_code == 200
assert response.content_type == 'application/x-yaml; charset=utf-8'
assert b'title: Test Scan' in response.data
def test_download_config_not_found(self, client, auth_headers):
"""Test downloading non-existent config"""
response = client.get('/api/configs/nonexistent.yaml/download', headers=auth_headers)
assert response.status_code == 404
def test_download_config_requires_auth(self, client):
"""Test that downloading config requires authentication"""
response = client.get('/api/configs/test.yaml/download')
assert response.status_code in [401, 302]
class TestDeleteConfig:
"""Tests for DELETE /api/configs/<filename>"""
def test_delete_config_valid(self, client, auth_headers, app, sample_yaml):
"""Test deleting a config file"""
# Create a config file
temp_configs_dir = os.environ.get('CONFIGS_DIR', '/app/configs')
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
with open(config_path, 'w') as f:
f.write(sample_yaml)
response = client.delete('/api/configs/test-scan.yaml', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert data['success'] is True
# Verify file is deleted
assert not os.path.exists(config_path)
def test_delete_config_not_found(self, client, auth_headers):
"""Test deleting non-existent config"""
response = client.delete('/api/configs/nonexistent.yaml', headers=auth_headers)
assert response.status_code == 404
def test_delete_config_requires_auth(self, client):
"""Test that deleting config requires authentication"""
response = client.delete('/api/configs/test.yaml')
assert response.status_code in [401, 302]
class TestEndToEndWorkflow:
"""End-to-end workflow tests"""
def test_complete_csv_workflow(self, client, auth_headers, sample_csv):
"""Test complete CSV upload workflow"""
from io import BytesIO
# 1. Download template
response = client.get('/api/configs/template', headers=auth_headers)
assert response.status_code == 200
# 2. Upload CSV
data = {
'file': (BytesIO(sample_csv.encode('utf-8')), 'workflow-test.csv')
}
response = client.post('/api/configs/upload-csv', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 200
result = response.get_json()
filename = result['filename']
# 3. List configs (should include new one)
response = client.get('/api/configs', headers=auth_headers)
assert response.status_code == 200
configs = response.get_json()['configs']
assert any(c['filename'] == filename for c in configs)
# 4. Get config details
response = client.get(f'/api/configs/{filename}', headers=auth_headers)
assert response.status_code == 200
# 5. Download config
response = client.get(f'/api/configs/{filename}/download', headers=auth_headers)
assert response.status_code == 200
# 6. Delete config
response = client.delete(f'/api/configs/{filename}', headers=auth_headers)
assert response.status_code == 200
# 7. Verify deletion
response = client.get(f'/api/configs/{filename}', headers=auth_headers)
assert response.status_code == 404
def test_yaml_upload_workflow(self, client, auth_headers, sample_yaml):
"""Test YAML upload workflow"""
from io import BytesIO
# Upload YAML
data = {
'file': (BytesIO(sample_yaml.encode('utf-8')), 'yaml-workflow.yaml')
}
response = client.post('/api/configs/upload-yaml', data=data,
headers=auth_headers, content_type='multipart/form-data')
assert response.status_code == 200
filename = response.get_json()['filename']
# Verify it exists
response = client.get(f'/api/configs/{filename}', headers=auth_headers)
assert response.status_code == 200
# Clean up
client.delete(f'/api/configs/{filename}', headers=auth_headers)

View File

@@ -0,0 +1,545 @@
"""
Unit tests for Config Service
Tests the ConfigService class which manages scan configuration files.
"""
import pytest
import os
import yaml
import tempfile
import shutil
from web.services.config_service import ConfigService
class TestConfigService:
"""Test suite for ConfigService"""
@pytest.fixture
def temp_configs_dir(self):
"""Create a temporary directory for config files"""
temp_dir = tempfile.mkdtemp()
yield temp_dir
shutil.rmtree(temp_dir)
@pytest.fixture
def service(self, temp_configs_dir):
"""Create a ConfigService instance with temp directory"""
return ConfigService(configs_dir=temp_configs_dir)
@pytest.fixture
def sample_yaml_config(self):
"""Sample YAML config content"""
return """title: Test Scan
sites:
- name: Web Servers
ips:
- address: 10.10.20.4
expected:
ping: true
tcp_ports: [22, 80, 443]
udp_ports: [53]
services: [ssh, http, https]
"""
@pytest.fixture
def sample_csv_content(self):
"""Sample CSV content"""
return """scan_title,site_name,ip_address,ping_expected,tcp_ports,udp_ports,services
Test Scan,Web Servers,10.10.20.4,true,"22,80,443",53,"ssh,http,https"
Test Scan,Web Servers,10.10.20.5,true,22,,"ssh"
"""
def test_list_configs_empty_directory(self, service):
"""Test listing configs when directory is empty"""
configs = service.list_configs()
assert configs == []
def test_list_configs_with_files(self, service, temp_configs_dir, sample_yaml_config):
"""Test listing configs with existing files"""
# Create a config file
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
with open(config_path, 'w') as f:
f.write(sample_yaml_config)
configs = service.list_configs()
assert len(configs) == 1
assert configs[0]['filename'] == 'test-scan.yaml'
assert configs[0]['title'] == 'Test Scan'
assert 'created_at' in configs[0]
assert 'size_bytes' in configs[0]
assert 'used_by_schedules' in configs[0]
def test_list_configs_ignores_non_yaml_files(self, service, temp_configs_dir):
"""Test that non-YAML files are ignored"""
# Create non-YAML files
with open(os.path.join(temp_configs_dir, 'test.txt'), 'w') as f:
f.write('not a yaml file')
with open(os.path.join(temp_configs_dir, 'readme.md'), 'w') as f:
f.write('# README')
configs = service.list_configs()
assert len(configs) == 0
def test_get_config_valid(self, service, temp_configs_dir, sample_yaml_config):
"""Test getting a valid config file"""
# Create a config file
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
with open(config_path, 'w') as f:
f.write(sample_yaml_config)
result = service.get_config('test-scan.yaml')
assert result['filename'] == 'test-scan.yaml'
assert 'content' in result
assert 'parsed' in result
assert result['parsed']['title'] == 'Test Scan'
assert len(result['parsed']['sites']) == 1
def test_get_config_not_found(self, service):
"""Test getting a non-existent config"""
with pytest.raises(FileNotFoundError, match="not found"):
service.get_config('nonexistent.yaml')
def test_get_config_invalid_yaml(self, service, temp_configs_dir):
"""Test getting a config with invalid YAML syntax"""
# Create invalid YAML file
config_path = os.path.join(temp_configs_dir, 'invalid.yaml')
with open(config_path, 'w') as f:
f.write("invalid: yaml: syntax: [")
with pytest.raises(ValueError, match="Invalid YAML syntax"):
service.get_config('invalid.yaml')
def test_create_from_yaml_valid(self, service, sample_yaml_config):
"""Test creating config from valid YAML"""
filename = service.create_from_yaml('test-scan.yaml', sample_yaml_config)
assert filename == 'test-scan.yaml'
assert service.config_exists('test-scan.yaml')
# Verify content
result = service.get_config('test-scan.yaml')
assert result['parsed']['title'] == 'Test Scan'
def test_create_from_yaml_adds_extension(self, service, sample_yaml_config):
"""Test that .yaml extension is added if missing"""
filename = service.create_from_yaml('test-scan', sample_yaml_config)
assert filename == 'test-scan.yaml'
assert service.config_exists('test-scan.yaml')
def test_create_from_yaml_sanitizes_filename(self, service, sample_yaml_config):
"""Test that filename is sanitized"""
filename = service.create_from_yaml('../../../etc/test.yaml', sample_yaml_config)
# secure_filename should remove path traversal
assert '..' not in filename
assert '/' not in filename
def test_create_from_yaml_duplicate_filename(self, service, temp_configs_dir, sample_yaml_config):
"""Test creating config with duplicate filename"""
# Create first config
service.create_from_yaml('test-scan.yaml', sample_yaml_config)
# Try to create duplicate
with pytest.raises(ValueError, match="already exists"):
service.create_from_yaml('test-scan.yaml', sample_yaml_config)
def test_create_from_yaml_invalid_syntax(self, service):
"""Test creating config with invalid YAML syntax"""
invalid_yaml = "invalid: yaml: syntax: ["
with pytest.raises(ValueError, match="Invalid YAML syntax"):
service.create_from_yaml('test.yaml', invalid_yaml)
def test_create_from_yaml_invalid_structure(self, service):
"""Test creating config with invalid structure (missing title)"""
invalid_config = """sites:
- name: Test
ips:
- address: 10.0.0.1
expected:
ping: true
"""
with pytest.raises(ValueError, match="Missing required field: 'title'"):
service.create_from_yaml('test.yaml', invalid_config)
def test_create_from_csv_valid(self, service, sample_csv_content):
"""Test creating config from valid CSV"""
filename, yaml_content = service.create_from_csv(sample_csv_content)
assert filename == 'test-scan.yaml'
assert service.config_exists(filename)
# Verify YAML was created correctly
result = service.get_config(filename)
assert result['parsed']['title'] == 'Test Scan'
assert len(result['parsed']['sites']) == 1
assert len(result['parsed']['sites'][0]['ips']) == 2
def test_create_from_csv_with_suggested_filename(self, service, sample_csv_content):
"""Test creating config with suggested filename"""
filename, yaml_content = service.create_from_csv(sample_csv_content, 'custom-name.yaml')
assert filename == 'custom-name.yaml'
assert service.config_exists(filename)
def test_create_from_csv_invalid(self, service):
"""Test creating config from invalid CSV"""
invalid_csv = """scan_title,site_name,ip_address
Missing,Columns,Here
"""
with pytest.raises(ValueError, match="CSV parsing failed"):
service.create_from_csv(invalid_csv)
def test_create_from_csv_duplicate_filename(self, service, sample_csv_content):
"""Test creating CSV config with duplicate filename"""
# Create first config
service.create_from_csv(sample_csv_content)
# Try to create duplicate (same title generates same filename)
with pytest.raises(ValueError, match="already exists"):
service.create_from_csv(sample_csv_content)
def test_delete_config_valid(self, service, temp_configs_dir, sample_yaml_config):
"""Test deleting a config file"""
# Create a config file
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
with open(config_path, 'w') as f:
f.write(sample_yaml_config)
assert service.config_exists('test-scan.yaml')
service.delete_config('test-scan.yaml')
assert not service.config_exists('test-scan.yaml')
def test_delete_config_not_found(self, service):
"""Test deleting non-existent config"""
with pytest.raises(FileNotFoundError, match="not found"):
service.delete_config('nonexistent.yaml')
def test_delete_config_used_by_schedule(self, service, temp_configs_dir, sample_yaml_config, monkeypatch):
"""Test deleting config that is used by schedules - should cascade delete schedules"""
# Create a config file
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
with open(config_path, 'w') as f:
f.write(sample_yaml_config)
# Mock schedule service interactions
deleted_schedule_ids = []
class MockScheduleService:
def __init__(self, db):
self.db = db
def list_schedules(self, page=1, per_page=10000):
return {
'schedules': [
{
'id': 1,
'name': 'Daily Scan',
'config_file': 'test-scan.yaml',
'enabled': True
},
{
'id': 2,
'name': 'Weekly Audit',
'config_file': 'test-scan.yaml',
'enabled': False # Disabled schedule should also be deleted
}
]
}
def delete_schedule(self, schedule_id):
deleted_schedule_ids.append(schedule_id)
return True
# Mock the ScheduleService import
import sys
from unittest.mock import MagicMock
mock_module = MagicMock()
mock_module.ScheduleService = MockScheduleService
monkeypatch.setitem(sys.modules, 'web.services.schedule_service', mock_module)
# Mock current_app
mock_app = MagicMock()
mock_app.db_session = MagicMock()
import flask
monkeypatch.setattr(flask, 'current_app', mock_app)
# Delete the config - should cascade delete associated schedules
service.delete_config('test-scan.yaml')
# Config should be deleted
assert not service.config_exists('test-scan.yaml')
# Both schedules (enabled and disabled) should be deleted
assert deleted_schedule_ids == [1, 2]
def test_validate_config_content_valid(self, service):
"""Test validating valid config content"""
valid_config = {
'title': 'Test Scan',
'sites': [
{
'name': 'Web Servers',
'ips': [
{
'address': '10.10.20.4',
'expected': {
'ping': True,
'tcp_ports': [22, 80, 443],
'udp_ports': [53]
}
}
]
}
]
}
is_valid, error = service.validate_config_content(valid_config)
assert is_valid is True
assert error == ""
def test_validate_config_content_not_dict(self, service):
"""Test validating non-dict content"""
is_valid, error = service.validate_config_content(['not', 'a', 'dict'])
assert is_valid is False
assert 'must be a dictionary' in error
def test_validate_config_content_missing_title(self, service):
"""Test validating config without title"""
config = {
'sites': []
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "Missing required field: 'title'" in error
def test_validate_config_content_missing_sites(self, service):
"""Test validating config without sites"""
config = {
'title': 'Test'
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "Missing required field: 'sites'" in error
def test_validate_config_content_empty_title(self, service):
"""Test validating config with empty title"""
config = {
'title': '',
'sites': []
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "non-empty string" in error
def test_validate_config_content_sites_not_list(self, service):
"""Test validating config with sites as non-list"""
config = {
'title': 'Test',
'sites': 'not a list'
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "must be a list" in error
def test_validate_config_content_no_sites(self, service):
"""Test validating config with empty sites list"""
config = {
'title': 'Test',
'sites': []
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "at least one site" in error
def test_validate_config_content_site_missing_name(self, service):
"""Test validating site without name"""
config = {
'title': 'Test',
'sites': [
{
'ips': []
}
]
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "missing required field: 'name'" in error
def test_validate_config_content_site_missing_ips(self, service):
"""Test validating site without ips"""
config = {
'title': 'Test',
'sites': [
{
'name': 'Test Site'
}
]
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "missing required field: 'ips'" in error
def test_validate_config_content_site_no_ips(self, service):
"""Test validating site with empty ips list"""
config = {
'title': 'Test',
'sites': [
{
'name': 'Test Site',
'ips': []
}
]
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "at least one IP" in error
def test_validate_config_content_ip_missing_address(self, service):
"""Test validating IP without address"""
config = {
'title': 'Test',
'sites': [
{
'name': 'Test Site',
'ips': [
{
'expected': {}
}
]
}
]
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "missing required field: 'address'" in error
def test_validate_config_content_ip_missing_expected(self, service):
"""Test validating IP without expected"""
config = {
'title': 'Test',
'sites': [
{
'name': 'Test Site',
'ips': [
{
'address': '10.0.0.1'
}
]
}
]
}
is_valid, error = service.validate_config_content(config)
assert is_valid is False
assert "missing required field: 'expected'" in error
def test_generate_filename_from_title_simple(self, service):
"""Test generating filename from simple title"""
filename = service.generate_filename_from_title('Production Scan')
assert filename == 'production-scan.yaml'
def test_generate_filename_from_title_special_chars(self, service):
"""Test generating filename with special characters"""
filename = service.generate_filename_from_title('Prod Scan (2025)!')
assert filename == 'prod-scan-2025.yaml'
assert '(' not in filename
assert ')' not in filename
assert '!' not in filename
def test_generate_filename_from_title_multiple_spaces(self, service):
"""Test generating filename with multiple spaces"""
filename = service.generate_filename_from_title('Test Multiple Spaces')
assert filename == 'test-multiple-spaces.yaml'
# Should not have consecutive hyphens
assert '--' not in filename
def test_generate_filename_from_title_leading_trailing_spaces(self, service):
"""Test generating filename with leading/trailing spaces"""
filename = service.generate_filename_from_title(' Test Scan ')
assert filename == 'test-scan.yaml'
assert not filename.startswith('-')
assert not filename.endswith('-.yaml')
def test_generate_filename_from_title_long(self, service):
"""Test generating filename from long title"""
long_title = 'A' * 300
filename = service.generate_filename_from_title(long_title)
# Should be limited to 200 chars (195 + .yaml)
assert len(filename) <= 200
def test_generate_filename_from_title_empty(self, service):
"""Test generating filename from empty title"""
filename = service.generate_filename_from_title('')
assert filename == 'config.yaml'
def test_generate_filename_from_title_only_special_chars(self, service):
"""Test generating filename from title with only special characters"""
filename = service.generate_filename_from_title('!@#$%^&*()')
assert filename == 'config.yaml'
def test_get_config_path(self, service, temp_configs_dir):
"""Test getting config path"""
path = service.get_config_path('test.yaml')
assert path == os.path.join(temp_configs_dir, 'test.yaml')
def test_config_exists_true(self, service, temp_configs_dir, sample_yaml_config):
"""Test config_exists returns True for existing file"""
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
with open(config_path, 'w') as f:
f.write(sample_yaml_config)
assert service.config_exists('test-scan.yaml') is True
def test_config_exists_false(self, service):
"""Test config_exists returns False for non-existent file"""
assert service.config_exists('nonexistent.yaml') is False
def test_get_schedules_using_config_none(self, service):
"""Test getting schedules when none use the config"""
schedules = service.get_schedules_using_config('test.yaml')
# Should return empty list (ScheduleService might not exist in test env)
assert isinstance(schedules, list)
def test_list_configs_sorted_by_date(self, service, temp_configs_dir, sample_yaml_config):
"""Test that configs are sorted by creation date (most recent first)"""
import time
# Create first config
config1_path = os.path.join(temp_configs_dir, 'config1.yaml')
with open(config1_path, 'w') as f:
f.write(sample_yaml_config)
time.sleep(0.1) # Ensure different timestamps
# Create second config
config2_path = os.path.join(temp_configs_dir, 'config2.yaml')
with open(config2_path, 'w') as f:
f.write(sample_yaml_config)
configs = service.list_configs()
assert len(configs) == 2
# Most recent should be first
assert configs[0]['filename'] == 'config2.yaml'
assert configs[1]['filename'] == 'config1.yaml'
def test_list_configs_handles_parse_errors(self, service, temp_configs_dir):
"""Test that list_configs handles files that can't be parsed"""
# Create invalid YAML file
config_path = os.path.join(temp_configs_dir, 'invalid.yaml')
with open(config_path, 'w') as f:
f.write("invalid: yaml: [")
# Should not raise error, just use filename as title
configs = service.list_configs()
assert len(configs) == 1
assert configs[0]['filename'] == 'invalid.yaml'

View File

@@ -0,0 +1,267 @@
"""
Tests for error handling and logging functionality.
Tests error handlers, request/response logging, database rollback on errors,
and proper error responses (JSON vs HTML).
"""
import json
import logging
import pytest
from flask import Flask
from sqlalchemy.exc import SQLAlchemyError
from web.app import create_app
@pytest.fixture
def app():
"""Create test Flask app."""
test_config = {
'TESTING': True,
'SQLALCHEMY_DATABASE_URI': 'sqlite:///:memory:',
'SECRET_KEY': 'test-secret-key',
'WTF_CSRF_ENABLED': False
}
app = create_app(test_config)
return app
@pytest.fixture
def client(app):
"""Create test client."""
return app.test_client()
class TestErrorHandlers:
"""Test error handler functionality."""
def test_404_json_response(self, client):
"""Test 404 error returns JSON for API requests."""
response = client.get('/api/nonexistent')
assert response.status_code == 404
assert response.content_type == 'application/json'
data = json.loads(response.data)
assert 'error' in data
assert data['error'] == 'Not Found'
assert 'message' in data
def test_404_html_response(self, client):
"""Test 404 error returns HTML for web requests."""
response = client.get('/nonexistent')
assert response.status_code == 404
assert 'text/html' in response.content_type
assert b'404' in response.data
def test_400_json_response(self, client):
"""Test 400 error returns JSON for API requests."""
# Trigger 400 by sending invalid JSON
response = client.post(
'/api/scans',
data='invalid json',
content_type='application/json'
)
assert response.status_code in [400, 401] # 401 if auth required
def test_405_method_not_allowed(self, client):
"""Test 405 error for method not allowed."""
# Try POST to health check (only GET allowed)
response = client.post('/api/scans/health')
assert response.status_code == 405
data = json.loads(response.data)
assert 'error' in data
assert data['error'] == 'Method Not Allowed'
def test_json_accept_header(self, client):
"""Test JSON response when Accept header specifies JSON."""
response = client.get(
'/nonexistent',
headers={'Accept': 'application/json'}
)
assert response.status_code == 404
assert response.content_type == 'application/json'
class TestLogging:
"""Test logging functionality."""
def test_request_logging(self, client, caplog):
"""Test that requests are logged."""
with caplog.at_level(logging.INFO):
response = client.get('/api/scans/health')
# Check log messages
log_messages = [record.message for record in caplog.records]
# Should log incoming request and response
assert any('GET /api/scans/health' in msg for msg in log_messages)
def test_error_logging(self, client, caplog):
"""Test that errors are logged with full context."""
with caplog.at_level(logging.INFO):
client.get('/api/nonexistent')
# Check that 404 was logged
log_messages = [record.message for record in caplog.records]
assert any('not found' in msg.lower() or '404' in msg for msg in log_messages)
def test_request_id_in_logs(self, client, caplog):
"""Test that request ID is included in log records."""
with caplog.at_level(logging.INFO):
client.get('/api/scans/health')
# Check that log records have request_id attribute
for record in caplog.records:
assert hasattr(record, 'request_id')
assert record.request_id # Should not be empty
class TestRequestResponseHandlers:
"""Test request and response handler middleware."""
def test_request_id_header(self, client):
"""Test that response includes X-Request-ID header for API requests."""
response = client.get('/api/scans/health')
assert 'X-Request-ID' in response.headers
def test_request_duration_header(self, client):
"""Test that response includes X-Request-Duration-Ms header."""
response = client.get('/api/scans/health')
assert 'X-Request-Duration-Ms' in response.headers
duration = float(response.headers['X-Request-Duration-Ms'])
assert duration >= 0 # Should be non-negative
def test_security_headers(self, client):
"""Test that security headers are added to API responses."""
response = client.get('/api/scans/health')
# Check security headers
assert response.headers.get('X-Content-Type-Options') == 'nosniff'
assert response.headers.get('X-Frame-Options') == 'DENY'
assert response.headers.get('X-XSS-Protection') == '1; mode=block'
def test_request_timing(self, client):
"""Test that request timing is calculated correctly."""
response = client.get('/api/scans/health')
duration_header = response.headers.get('X-Request-Duration-Ms')
assert duration_header is not None
duration = float(duration_header)
# Should complete in reasonable time (less than 5 seconds)
assert duration < 5000
class TestDatabaseErrorHandling:
"""Test database error handling and rollback."""
def test_database_rollback_on_error(self, app):
"""Test that database session is rolled back on error."""
# This test would require triggering a database error
# For now, just verify the error handler is registered
from sqlalchemy.exc import SQLAlchemyError
# Check that SQLAlchemyError handler is registered
assert SQLAlchemyError in app.error_handler_spec[None]
class TestLogRotation:
"""Test log rotation configuration."""
def test_log_files_created(self, app, tmp_path):
"""Test that log files are created in logs directory."""
import os
from pathlib import Path
# Check that logs directory exists
log_dir = Path('logs')
# Note: In test environment, logs may not be created immediately
# Just verify the configuration is set up
# Verify app logger has handlers
assert len(app.logger.handlers) > 0
# Verify at least one handler is a RotatingFileHandler
from logging.handlers import RotatingFileHandler
has_rotating_handler = any(
isinstance(h, RotatingFileHandler)
for h in app.logger.handlers
)
assert has_rotating_handler, "Should have RotatingFileHandler configured"
def test_log_handler_configuration(self, app):
"""Test that log handlers are configured correctly."""
from logging.handlers import RotatingFileHandler
# Find RotatingFileHandler
rotating_handlers = [
h for h in app.logger.handlers
if isinstance(h, RotatingFileHandler)
]
assert len(rotating_handlers) > 0, "Should have rotating file handlers"
# Check handler configuration
for handler in rotating_handlers:
# Should have max size configured
assert handler.maxBytes > 0
# Should have backup count configured
assert handler.backupCount > 0
class TestStructuredLogging:
"""Test structured logging features."""
def test_log_format_includes_request_id(self, client, caplog):
"""Test that log format includes request ID."""
with caplog.at_level(logging.INFO):
client.get('/api/scans/health')
# Verify log records have request_id
for record in caplog.records:
assert hasattr(record, 'request_id')
def test_error_log_includes_traceback(self, app, caplog):
"""Test that errors are logged with traceback."""
with app.test_request_context('/api/test'):
with caplog.at_level(logging.ERROR):
try:
raise ValueError("Test error")
except ValueError as e:
app.logger.error("Test error occurred", exc_info=True)
# Check that traceback is in logs
log_output = caplog.text
assert 'Test error' in log_output
assert 'Traceback' in log_output or 'ValueError' in log_output
class TestErrorTemplates:
"""Test error template rendering."""
def test_404_template_exists(self, client):
"""Test that 404 error template is rendered."""
response = client.get('/nonexistent')
assert response.status_code == 404
assert b'404' in response.data
assert b'Page Not Found' in response.data or b'Not Found' in response.data
def test_500_template_exists(self, app):
"""Test that 500 error template can be rendered."""
# We can't easily trigger a 500 without breaking the app
# Just verify the template file exists
from pathlib import Path
template_path = Path('web/templates/errors/500.html')
assert template_path.exists(), "500 error template should exist"
def test_error_template_styling(self, client):
"""Test that error templates include styling."""
response = client.get('/nonexistent')
# Should include CSS styling
assert b'style' in response.data or b'css' in response.data.lower()
if __name__ == '__main__':
pytest.main([__file__, '-v'])

267
app/tests/test_scan_api.py Normal file
View File

@@ -0,0 +1,267 @@
"""
Integration tests for Scan API endpoints.
Tests all scan management endpoints including triggering scans,
listing, retrieving details, deleting, and status polling.
"""
import json
import pytest
from pathlib import Path
from datetime import datetime
from web.models import Scan
class TestScanAPIEndpoints:
"""Test suite for scan API endpoints."""
def test_list_scans_empty(self, client, db):
"""Test listing scans when database is empty."""
response = client.get('/api/scans')
assert response.status_code == 200
data = json.loads(response.data)
assert data['scans'] == []
assert data['total'] == 0
assert data['page'] == 1
assert data['per_page'] == 20
def test_list_scans_with_data(self, client, db, sample_scan):
"""Test listing scans with existing data."""
response = client.get('/api/scans')
assert response.status_code == 200
data = json.loads(response.data)
assert data['total'] == 1
assert len(data['scans']) == 1
assert data['scans'][0]['id'] == sample_scan.id
def test_list_scans_pagination(self, client, db):
"""Test scan list pagination."""
# Create 25 scans
for i in range(25):
scan = Scan(
timestamp=datetime.utcnow(),
status='completed',
config_file=f'/app/configs/test{i}.yaml',
title=f'Test Scan {i}',
triggered_by='test'
)
db.add(scan)
db.commit()
# Test page 1
response = client.get('/api/scans?page=1&per_page=10')
assert response.status_code == 200
data = json.loads(response.data)
assert data['total'] == 25
assert len(data['scans']) == 10
assert data['page'] == 1
assert data['per_page'] == 10
assert data['total_pages'] == 3
assert data['has_next'] is True
assert data['has_prev'] is False
# Test page 2
response = client.get('/api/scans?page=2&per_page=10')
assert response.status_code == 200
data = json.loads(response.data)
assert len(data['scans']) == 10
assert data['page'] == 2
assert data['has_next'] is True
assert data['has_prev'] is True
def test_list_scans_status_filter(self, client, db):
"""Test filtering scans by status."""
# Create scans with different statuses
for status in ['running', 'completed', 'failed']:
scan = Scan(
timestamp=datetime.utcnow(),
status=status,
config_file='/app/configs/test.yaml',
title=f'{status.capitalize()} Scan',
triggered_by='test'
)
db.add(scan)
db.commit()
# Filter by completed
response = client.get('/api/scans?status=completed')
assert response.status_code == 200
data = json.loads(response.data)
assert data['total'] == 1
assert data['scans'][0]['status'] == 'completed'
def test_list_scans_invalid_page(self, client, db):
"""Test listing scans with invalid page parameter."""
response = client.get('/api/scans?page=0')
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
def test_get_scan_success(self, client, db, sample_scan):
"""Test retrieving a specific scan."""
response = client.get(f'/api/scans/{sample_scan.id}')
assert response.status_code == 200
data = json.loads(response.data)
assert data['id'] == sample_scan.id
assert data['title'] == sample_scan.title
assert data['status'] == sample_scan.status
def test_get_scan_not_found(self, client, db):
"""Test retrieving a non-existent scan."""
response = client.get('/api/scans/99999')
assert response.status_code == 404
data = json.loads(response.data)
assert 'error' in data
assert data['error'] == 'Not found'
def test_trigger_scan_success(self, client, db, sample_config_file):
"""Test triggering a new scan."""
response = client.post('/api/scans',
json={'config_file': str(sample_config_file)},
content_type='application/json'
)
assert response.status_code == 201
data = json.loads(response.data)
assert 'scan_id' in data
assert data['status'] == 'running'
assert data['message'] == 'Scan queued successfully'
# Verify scan was created in database
scan = db.query(Scan).filter_by(id=data['scan_id']).first()
assert scan is not None
assert scan.status == 'running'
assert scan.triggered_by == 'api'
def test_trigger_scan_missing_config_file(self, client, db):
"""Test triggering scan without config_file."""
response = client.post('/api/scans',
json={},
content_type='application/json'
)
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
assert 'config_file is required' in data['message']
def test_trigger_scan_invalid_config_file(self, client, db):
"""Test triggering scan with non-existent config file."""
response = client.post('/api/scans',
json={'config_file': '/nonexistent/config.yaml'},
content_type='application/json'
)
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
def test_delete_scan_success(self, client, db, sample_scan):
"""Test deleting a scan."""
scan_id = sample_scan.id
response = client.delete(f'/api/scans/{scan_id}')
assert response.status_code == 200
data = json.loads(response.data)
assert data['scan_id'] == scan_id
assert 'deleted successfully' in data['message']
# Verify scan was deleted from database
scan = db.query(Scan).filter_by(id=scan_id).first()
assert scan is None
def test_delete_scan_not_found(self, client, db):
"""Test deleting a non-existent scan."""
response = client.delete('/api/scans/99999')
assert response.status_code == 404
data = json.loads(response.data)
assert 'error' in data
def test_get_scan_status_success(self, client, db, sample_scan):
"""Test getting scan status."""
response = client.get(f'/api/scans/{sample_scan.id}/status')
assert response.status_code == 200
data = json.loads(response.data)
assert data['scan_id'] == sample_scan.id
assert data['status'] == sample_scan.status
assert 'timestamp' in data
def test_get_scan_status_not_found(self, client, db):
"""Test getting status for non-existent scan."""
response = client.get('/api/scans/99999/status')
assert response.status_code == 404
data = json.loads(response.data)
assert 'error' in data
def test_api_error_handling(self, client, db):
"""Test API error responses are properly formatted."""
# Test 404
response = client.get('/api/scans/99999')
assert response.status_code == 404
data = json.loads(response.data)
assert 'error' in data
assert 'message' in data
# Test 400
response = client.post('/api/scans', json={})
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
assert 'message' in data
def test_scan_workflow_integration(self, client, db, sample_config_file):
"""
Test complete scan workflow: trigger → status → retrieve → delete.
This integration test verifies the entire scan lifecycle through
the API endpoints.
"""
# Step 1: Trigger scan
response = client.post('/api/scans',
json={'config_file': str(sample_config_file)},
content_type='application/json'
)
assert response.status_code == 201
data = json.loads(response.data)
scan_id = data['scan_id']
# Step 2: Check status
response = client.get(f'/api/scans/{scan_id}/status')
assert response.status_code == 200
data = json.loads(response.data)
assert data['scan_id'] == scan_id
assert data['status'] == 'running'
# Step 3: List scans (verify it appears)
response = client.get('/api/scans')
assert response.status_code == 200
data = json.loads(response.data)
assert data['total'] == 1
assert data['scans'][0]['id'] == scan_id
# Step 4: Get scan details
response = client.get(f'/api/scans/{scan_id}')
assert response.status_code == 200
data = json.loads(response.data)
assert data['id'] == scan_id
# Step 5: Delete scan
response = client.delete(f'/api/scans/{scan_id}')
assert response.status_code == 200
# Step 6: Verify deletion
response = client.get(f'/api/scans/{scan_id}')
assert response.status_code == 404

View File

@@ -0,0 +1,319 @@
"""
Unit tests for scan comparison functionality.
Tests scan comparison logic including port, service, and certificate comparisons,
as well as drift score calculation.
"""
import pytest
from datetime import datetime
from web.models import Scan, ScanSite, ScanIP, ScanPort
from web.models import ScanService as ScanServiceModel, ScanCertificate
from web.services.scan_service import ScanService
class TestScanComparison:
"""Tests for scan comparison methods."""
@pytest.fixture
def scan1_data(self, test_db, sample_config_file):
"""Create first scan with test data."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')
# Get scan and add some test data
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
scan.status = 'completed'
# Create site
site = ScanSite(scan_id=scan.id, site_name='Test Site')
test_db.add(site)
test_db.flush()
# Create IP
ip = ScanIP(
scan_id=scan.id,
site_id=site.id,
ip_address='192.168.1.100',
ping_expected=True,
ping_actual=True
)
test_db.add(ip)
test_db.flush()
# Create ports
port1 = ScanPort(
scan_id=scan.id,
ip_id=ip.id,
port=80,
protocol='tcp',
state='open',
expected=True
)
port2 = ScanPort(
scan_id=scan.id,
ip_id=ip.id,
port=443,
protocol='tcp',
state='open',
expected=True
)
test_db.add(port1)
test_db.add(port2)
test_db.flush()
# Create service
svc1 = ScanServiceModel(
scan_id=scan.id,
port_id=port1.id,
service_name='http',
product='nginx',
version='1.18.0'
)
test_db.add(svc1)
test_db.commit()
return scan_id
@pytest.fixture
def scan2_data(self, test_db, sample_config_file):
"""Create second scan with modified test data."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')
# Get scan and add some test data
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
scan.status = 'completed'
# Create site
site = ScanSite(scan_id=scan.id, site_name='Test Site')
test_db.add(site)
test_db.flush()
# Create IP
ip = ScanIP(
scan_id=scan.id,
site_id=site.id,
ip_address='192.168.1.100',
ping_expected=True,
ping_actual=True
)
test_db.add(ip)
test_db.flush()
# Create ports (port 80 removed, 443 kept, 8080 added)
port2 = ScanPort(
scan_id=scan.id,
ip_id=ip.id,
port=443,
protocol='tcp',
state='open',
expected=True
)
port3 = ScanPort(
scan_id=scan.id,
ip_id=ip.id,
port=8080,
protocol='tcp',
state='open',
expected=False
)
test_db.add(port2)
test_db.add(port3)
test_db.flush()
# Create service with updated version
svc2 = ScanServiceModel(
scan_id=scan.id,
port_id=port3.id,
service_name='http',
product='nginx',
version='1.20.0' # Version changed
)
test_db.add(svc2)
test_db.commit()
return scan_id
def test_compare_scans_basic(self, test_db, scan1_data, scan2_data):
"""Test basic scan comparison."""
service = ScanService(test_db)
result = service.compare_scans(scan1_data, scan2_data)
assert result is not None
assert 'scan1' in result
assert 'scan2' in result
assert 'ports' in result
assert 'services' in result
assert 'certificates' in result
assert 'drift_score' in result
# Verify scan metadata
assert result['scan1']['id'] == scan1_data
assert result['scan2']['id'] == scan2_data
def test_compare_scans_not_found(self, test_db):
"""Test comparison with nonexistent scan."""
service = ScanService(test_db)
result = service.compare_scans(999, 998)
assert result is None
def test_compare_ports(self, test_db, scan1_data, scan2_data):
"""Test port comparison logic."""
service = ScanService(test_db)
result = service.compare_scans(scan1_data, scan2_data)
# Scan1 has ports 80, 443
# Scan2 has ports 443, 8080
# Expected: added=[8080], removed=[80], unchanged=[443]
ports = result['ports']
assert len(ports['added']) == 1
assert len(ports['removed']) == 1
assert len(ports['unchanged']) == 1
# Check added port
added_port = ports['added'][0]
assert added_port['port'] == 8080
# Check removed port
removed_port = ports['removed'][0]
assert removed_port['port'] == 80
# Check unchanged port
unchanged_port = ports['unchanged'][0]
assert unchanged_port['port'] == 443
def test_compare_services(self, test_db, scan1_data, scan2_data):
"""Test service comparison logic."""
service = ScanService(test_db)
result = service.compare_scans(scan1_data, scan2_data)
services = result['services']
# Scan1 has nginx 1.18.0 on port 80
# Scan2 has nginx 1.20.0 on port 8080
# These are on different ports, so they should be added/removed, not changed
assert len(services['added']) >= 0
assert len(services['removed']) >= 0
def test_drift_score_calculation(self, test_db, scan1_data, scan2_data):
"""Test drift score calculation."""
service = ScanService(test_db)
result = service.compare_scans(scan1_data, scan2_data)
drift_score = result['drift_score']
# Drift score should be between 0.0 and 1.0
assert 0.0 <= drift_score <= 1.0
# Since we have changes (1 port added, 1 removed), drift should be > 0
assert drift_score > 0.0
def test_compare_identical_scans(self, test_db, scan1_data):
"""Test comparing a scan with itself (should have zero drift)."""
service = ScanService(test_db)
result = service.compare_scans(scan1_data, scan1_data)
# Comparing scan with itself should have zero drift
assert result['drift_score'] == 0.0
assert len(result['ports']['added']) == 0
assert len(result['ports']['removed']) == 0
class TestScanComparisonAPI:
"""Tests for scan comparison API endpoint."""
def test_compare_scans_api(self, client, auth_headers, scan1_data, scan2_data):
"""Test scan comparison API endpoint."""
response = client.get(
f'/api/scans/{scan1_data}/compare/{scan2_data}',
headers=auth_headers
)
assert response.status_code == 200
data = response.get_json()
assert 'scan1' in data
assert 'scan2' in data
assert 'ports' in data
assert 'services' in data
assert 'drift_score' in data
def test_compare_scans_api_not_found(self, client, auth_headers):
"""Test comparison API with nonexistent scans."""
response = client.get(
'/api/scans/999/compare/998',
headers=auth_headers
)
assert response.status_code == 404
data = response.get_json()
assert 'error' in data
def test_compare_scans_api_requires_auth(self, client, scan1_data, scan2_data):
"""Test that comparison API requires authentication."""
response = client.get(f'/api/scans/{scan1_data}/compare/{scan2_data}')
assert response.status_code == 401
class TestHistoricalChartAPI:
"""Tests for historical scan chart API endpoint."""
def test_scan_history_api(self, client, auth_headers, scan1_data):
"""Test scan history API endpoint."""
response = client.get(
f'/api/stats/scan-history/{scan1_data}',
headers=auth_headers
)
assert response.status_code == 200
data = response.get_json()
assert 'scans' in data
assert 'labels' in data
assert 'port_counts' in data
assert 'config_file' in data
# Should include at least the scan we created
assert len(data['scans']) >= 1
def test_scan_history_api_not_found(self, client, auth_headers):
"""Test history API with nonexistent scan."""
response = client.get(
'/api/stats/scan-history/999',
headers=auth_headers
)
assert response.status_code == 404
data = response.get_json()
assert 'error' in data
def test_scan_history_api_limit(self, client, auth_headers, scan1_data):
"""Test scan history API with limit parameter."""
response = client.get(
f'/api/stats/scan-history/{scan1_data}?limit=5',
headers=auth_headers
)
assert response.status_code == 200
data = response.get_json()
# Should respect limit
assert len(data['scans']) <= 5
def test_scan_history_api_requires_auth(self, client, scan1_data):
"""Test that history API requires authentication."""
response = client.get(f'/api/stats/scan-history/{scan1_data}')
assert response.status_code == 401

View File

@@ -0,0 +1,402 @@
"""
Unit tests for ScanService class.
Tests scan lifecycle operations: trigger, get, list, delete, and database mapping.
"""
import pytest
from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel
from web.services.scan_service import ScanService
class TestScanServiceTrigger:
"""Tests for triggering scans."""
def test_trigger_scan_valid_config(self, test_db, sample_config_file):
"""Test triggering a scan with valid config file."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')
# Verify scan created
assert scan_id is not None
assert isinstance(scan_id, int)
# Verify scan in database
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
assert scan is not None
assert scan.status == 'running'
assert scan.title == 'Test Scan'
assert scan.triggered_by == 'manual'
assert scan.config_file == sample_config_file
def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file):
"""Test triggering a scan with invalid config file."""
service = ScanService(test_db)
with pytest.raises(ValueError, match="Invalid config file"):
service.trigger_scan(sample_invalid_config_file)
def test_trigger_scan_nonexistent_file(self, test_db):
"""Test triggering a scan with nonexistent config file."""
service = ScanService(test_db)
with pytest.raises(ValueError, match="does not exist"):
service.trigger_scan('/nonexistent/config.yaml')
def test_trigger_scan_with_schedule(self, test_db, sample_config_file):
"""Test triggering a scan via schedule."""
service = ScanService(test_db)
scan_id = service.trigger_scan(
sample_config_file,
triggered_by='scheduled',
schedule_id=42
)
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
assert scan.triggered_by == 'scheduled'
assert scan.schedule_id == 42
class TestScanServiceGet:
"""Tests for retrieving scans."""
def test_get_scan_not_found(self, test_db):
"""Test getting a nonexistent scan."""
service = ScanService(test_db)
result = service.get_scan(999)
assert result is None
def test_get_scan_found(self, test_db, sample_config_file):
"""Test getting an existing scan."""
service = ScanService(test_db)
# Create a scan
scan_id = service.trigger_scan(sample_config_file)
# Retrieve it
result = service.get_scan(scan_id)
assert result is not None
assert result['id'] == scan_id
assert result['title'] == 'Test Scan'
assert result['status'] == 'running'
assert 'sites' in result
class TestScanServiceList:
"""Tests for listing scans."""
def test_list_scans_empty(self, test_db):
"""Test listing scans when database is empty."""
service = ScanService(test_db)
result = service.list_scans(page=1, per_page=20)
assert result.total == 0
assert len(result.items) == 0
assert result.pages == 0
def test_list_scans_with_data(self, test_db, sample_config_file):
"""Test listing scans with multiple scans."""
service = ScanService(test_db)
# Create 3 scans
for i in range(3):
service.trigger_scan(sample_config_file, triggered_by='api')
# List all scans
result = service.list_scans(page=1, per_page=20)
assert result.total == 3
assert len(result.items) == 3
assert result.pages == 1
def test_list_scans_pagination(self, test_db, sample_config_file):
"""Test pagination."""
service = ScanService(test_db)
# Create 5 scans
for i in range(5):
service.trigger_scan(sample_config_file)
# Get page 1 (2 items per page)
result = service.list_scans(page=1, per_page=2)
assert len(result.items) == 2
assert result.total == 5
assert result.pages == 3
assert result.has_next is True
# Get page 2
result = service.list_scans(page=2, per_page=2)
assert len(result.items) == 2
assert result.has_prev is True
assert result.has_next is True
# Get page 3 (last page)
result = service.list_scans(page=3, per_page=2)
assert len(result.items) == 1
assert result.has_next is False
def test_list_scans_filter_by_status(self, test_db, sample_config_file):
"""Test filtering scans by status."""
service = ScanService(test_db)
# Create scans with different statuses
scan_id_1 = service.trigger_scan(sample_config_file)
scan_id_2 = service.trigger_scan(sample_config_file)
# Mark one as completed
scan = test_db.query(Scan).filter(Scan.id == scan_id_1).first()
scan.status = 'completed'
test_db.commit()
# Filter by running
result = service.list_scans(status_filter='running')
assert result.total == 1
# Filter by completed
result = service.list_scans(status_filter='completed')
assert result.total == 1
def test_list_scans_invalid_status_filter(self, test_db):
"""Test filtering with invalid status."""
service = ScanService(test_db)
with pytest.raises(ValueError, match="Invalid status"):
service.list_scans(status_filter='invalid_status')
class TestScanServiceDelete:
"""Tests for deleting scans."""
def test_delete_scan_not_found(self, test_db):
"""Test deleting a nonexistent scan."""
service = ScanService(test_db)
with pytest.raises(ValueError, match="not found"):
service.delete_scan(999)
def test_delete_scan_success(self, test_db, sample_config_file):
"""Test successful scan deletion."""
service = ScanService(test_db)
# Create a scan
scan_id = service.trigger_scan(sample_config_file)
# Verify it exists
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is not None
# Delete it
result = service.delete_scan(scan_id)
assert result is True
# Verify it's gone
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is None
class TestScanServiceStatus:
"""Tests for scan status retrieval."""
def test_get_scan_status_not_found(self, test_db):
"""Test getting status of nonexistent scan."""
service = ScanService(test_db)
result = service.get_scan_status(999)
assert result is None
def test_get_scan_status_running(self, test_db, sample_config_file):
"""Test getting status of running scan."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file)
status = service.get_scan_status(scan_id)
assert status is not None
assert status['scan_id'] == scan_id
assert status['status'] == 'running'
assert status['progress'] == 'In progress'
assert status['title'] == 'Test Scan'
def test_get_scan_status_completed(self, test_db, sample_config_file):
"""Test getting status of completed scan."""
service = ScanService(test_db)
# Create and mark as completed
scan_id = service.trigger_scan(sample_config_file)
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
scan.status = 'completed'
scan.duration = 125.5
test_db.commit()
status = service.get_scan_status(scan_id)
assert status['status'] == 'completed'
assert status['progress'] == 'Complete'
assert status['duration'] == 125.5
class TestScanServiceDatabaseMapping:
"""Tests for mapping scan reports to database models."""
def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report):
"""Test saving a complete scan report to database."""
service = ScanService(test_db)
# Create a scan
scan_id = service.trigger_scan(sample_config_file)
# Save report to database
service._save_scan_to_db(sample_scan_report, scan_id, status='completed')
# Verify scan updated
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
assert scan.status == 'completed'
assert scan.duration == 125.5
# Verify sites created
sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all()
assert len(sites) == 1
assert sites[0].site_name == 'Test Site'
# Verify IPs created
ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all()
assert len(ips) == 1
assert ips[0].ip_address == '192.168.1.10'
assert ips[0].ping_expected is True
assert ips[0].ping_actual is True
# Verify ports created (TCP: 22, 80, 443, 8080 | UDP: 53)
ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all()
assert len(ports) == 5 # 4 TCP + 1 UDP
# Verify TCP ports
tcp_ports = [p for p in ports if p.protocol == 'tcp']
assert len(tcp_ports) == 4
tcp_port_numbers = sorted([p.port for p in tcp_ports])
assert tcp_port_numbers == [22, 80, 443, 8080]
# Verify UDP ports
udp_ports = [p for p in ports if p.protocol == 'udp']
assert len(udp_ports) == 1
assert udp_ports[0].port == 53
# Verify services created
services = test_db.query(ScanServiceModel).filter(
ScanServiceModel.scan_id == scan_id
).all()
assert len(services) == 4 # SSH, HTTP (80), HTTPS, HTTP (8080)
# Find HTTPS service
https_service = next(
(s for s in services if s.service_name == 'https'), None
)
assert https_service is not None
assert https_service.product == 'nginx'
assert https_service.version == '1.24.0'
assert https_service.http_protocol == 'https'
assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png'
def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report):
"""Test that expected vs actual ports are correctly flagged."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file)
service._save_scan_to_db(sample_scan_report, scan_id)
# Check TCP ports
tcp_ports = test_db.query(ScanPort).filter(
ScanPort.scan_id == scan_id,
ScanPort.protocol == 'tcp'
).all()
# Ports 22, 80, 443 were expected
expected_ports = {22, 80, 443}
for port in tcp_ports:
if port.port in expected_ports:
assert port.expected is True, f"Port {port.port} should be expected"
else:
# Port 8080 was not expected
assert port.expected is False, f"Port {port.port} should not be expected"
def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report):
"""Test that certificate and TLS data are correctly mapped."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file)
service._save_scan_to_db(sample_scan_report, scan_id)
# Find HTTPS service
https_service = test_db.query(ScanServiceModel).filter(
ScanServiceModel.scan_id == scan_id,
ScanServiceModel.service_name == 'https'
).first()
assert https_service is not None
# Verify certificate created
assert len(https_service.certificates) == 1
cert = https_service.certificates[0]
assert cert.subject == 'CN=example.com'
assert cert.issuer == "CN=Let's Encrypt Authority"
assert cert.days_until_expiry == 365
assert cert.is_self_signed is False
# Verify SANs
import json
sans = json.loads(cert.sans)
assert 'example.com' in sans
assert 'www.example.com' in sans
# Verify TLS versions
assert len(cert.tls_versions) == 2
tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None)
assert tls_12 is not None
assert tls_12.supported is True
tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None)
assert tls_13 is not None
assert tls_13.supported is True
def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report):
"""Test retrieving scan with all nested relationships."""
service = ScanService(test_db)
scan_id = service.trigger_scan(sample_config_file)
service._save_scan_to_db(sample_scan_report, scan_id)
# Get full scan details
result = service.get_scan(scan_id)
assert result is not None
assert len(result['sites']) == 1
site = result['sites'][0]
assert site['name'] == 'Test Site'
assert len(site['ips']) == 1
ip = site['ips'][0]
assert ip['address'] == '192.168.1.10'
assert len(ip['ports']) == 5 # 4 TCP + 1 UDP
# Find HTTPS port
https_port = next(
(p for p in ip['ports'] if p['port'] == 443), None
)
assert https_port is not None
assert len(https_port['services']) == 1
service_data = https_port['services'][0]
assert service_data['service_name'] == 'https'
assert 'certificates' in service_data
assert len(service_data['certificates']) == 1
cert = service_data['certificates'][0]
assert cert['subject'] == 'CN=example.com'
assert 'tls_versions' in cert
assert len(cert['tls_versions']) == 2

View File

@@ -0,0 +1,639 @@
"""
Integration tests for Schedule API endpoints.
Tests all schedule management endpoints including creating, listing,
updating, deleting schedules, and manually triggering scheduled scans.
"""
import json
import pytest
from datetime import datetime
from web.models import Schedule, Scan
@pytest.fixture
def sample_schedule(db, sample_config_file):
"""
Create a sample schedule in the database for testing.
Args:
db: Database session fixture
sample_config_file: Path to test config file
Returns:
Schedule model instance
"""
schedule = Schedule(
name='Daily Test Scan',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True,
last_run=None,
next_run=datetime(2025, 11, 15, 2, 0, 0),
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
db.add(schedule)
db.commit()
db.refresh(schedule)
return schedule
class TestScheduleAPIEndpoints:
"""Test suite for schedule API endpoints."""
def test_list_schedules_empty(self, client, db):
"""Test listing schedules when database is empty."""
response = client.get('/api/schedules')
assert response.status_code == 200
data = json.loads(response.data)
assert data['schedules'] == []
assert data['total'] == 0
assert data['page'] == 1
assert data['per_page'] == 20
def test_list_schedules_populated(self, client, db, sample_schedule):
"""Test listing schedules with existing data."""
response = client.get('/api/schedules')
assert response.status_code == 200
data = json.loads(response.data)
assert data['total'] == 1
assert len(data['schedules']) == 1
assert data['schedules'][0]['id'] == sample_schedule.id
assert data['schedules'][0]['name'] == sample_schedule.name
assert data['schedules'][0]['cron_expression'] == sample_schedule.cron_expression
def test_list_schedules_pagination(self, client, db, sample_config_file):
"""Test schedule list pagination."""
# Create 25 schedules
for i in range(25):
schedule = Schedule(
name=f'Schedule {i}',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True,
created_at=datetime.utcnow()
)
db.add(schedule)
db.commit()
# Test page 1
response = client.get('/api/schedules?page=1&per_page=10')
assert response.status_code == 200
data = json.loads(response.data)
assert data['total'] == 25
assert len(data['schedules']) == 10
assert data['page'] == 1
assert data['per_page'] == 10
assert data['pages'] == 3
# Test page 2
response = client.get('/api/schedules?page=2&per_page=10')
assert response.status_code == 200
data = json.loads(response.data)
assert len(data['schedules']) == 10
assert data['page'] == 2
def test_list_schedules_filter_enabled(self, client, db, sample_config_file):
"""Test filtering schedules by enabled status."""
# Create enabled and disabled schedules
for i in range(3):
schedule = Schedule(
name=f'Enabled Schedule {i}',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True,
created_at=datetime.utcnow()
)
db.add(schedule)
for i in range(2):
schedule = Schedule(
name=f'Disabled Schedule {i}',
config_file=sample_config_file,
cron_expression='0 3 * * *',
enabled=False,
created_at=datetime.utcnow()
)
db.add(schedule)
db.commit()
# Filter by enabled=true
response = client.get('/api/schedules?enabled=true')
assert response.status_code == 200
data = json.loads(response.data)
assert data['total'] == 3
for schedule in data['schedules']:
assert schedule['enabled'] is True
# Filter by enabled=false
response = client.get('/api/schedules?enabled=false')
assert response.status_code == 200
data = json.loads(response.data)
assert data['total'] == 2
for schedule in data['schedules']:
assert schedule['enabled'] is False
def test_get_schedule(self, client, db, sample_schedule):
"""Test getting schedule details."""
response = client.get(f'/api/schedules/{sample_schedule.id}')
assert response.status_code == 200
data = json.loads(response.data)
assert data['id'] == sample_schedule.id
assert data['name'] == sample_schedule.name
assert data['config_file'] == sample_schedule.config_file
assert data['cron_expression'] == sample_schedule.cron_expression
assert data['enabled'] == sample_schedule.enabled
assert 'history' in data
def test_get_schedule_not_found(self, client, db):
"""Test getting non-existent schedule."""
response = client.get('/api/schedules/99999')
assert response.status_code == 404
data = json.loads(response.data)
assert 'error' in data
assert 'not found' in data['error'].lower()
def test_create_schedule(self, client, db, sample_config_file):
"""Test creating a new schedule."""
schedule_data = {
'name': 'New Test Schedule',
'config_file': sample_config_file,
'cron_expression': '0 3 * * *',
'enabled': True
}
response = client.post(
'/api/schedules',
data=json.dumps(schedule_data),
content_type='application/json'
)
assert response.status_code == 201
data = json.loads(response.data)
assert 'schedule_id' in data
assert data['message'] == 'Schedule created successfully'
assert 'schedule' in data
# Verify schedule in database
schedule = db.query(Schedule).filter(Schedule.id == data['schedule_id']).first()
assert schedule is not None
assert schedule.name == schedule_data['name']
assert schedule.cron_expression == schedule_data['cron_expression']
def test_create_schedule_missing_fields(self, client, db):
"""Test creating schedule with missing required fields."""
# Missing cron_expression
schedule_data = {
'name': 'Incomplete Schedule',
'config_file': '/app/configs/test.yaml'
}
response = client.post(
'/api/schedules',
data=json.dumps(schedule_data),
content_type='application/json'
)
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
assert 'missing' in data['error'].lower()
def test_create_schedule_invalid_cron(self, client, db, sample_config_file):
"""Test creating schedule with invalid cron expression."""
schedule_data = {
'name': 'Invalid Cron Schedule',
'config_file': sample_config_file,
'cron_expression': 'invalid cron'
}
response = client.post(
'/api/schedules',
data=json.dumps(schedule_data),
content_type='application/json'
)
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
assert 'invalid' in data['error'].lower() or 'cron' in data['error'].lower()
def test_create_schedule_invalid_config(self, client, db):
"""Test creating schedule with non-existent config file."""
schedule_data = {
'name': 'Invalid Config Schedule',
'config_file': '/nonexistent/config.yaml',
'cron_expression': '0 2 * * *'
}
response = client.post(
'/api/schedules',
data=json.dumps(schedule_data),
content_type='application/json'
)
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
assert 'not found' in data['error'].lower()
def test_update_schedule(self, client, db, sample_schedule):
"""Test updating schedule fields."""
update_data = {
'name': 'Updated Schedule Name',
'cron_expression': '0 4 * * *'
}
response = client.put(
f'/api/schedules/{sample_schedule.id}',
data=json.dumps(update_data),
content_type='application/json'
)
assert response.status_code == 200
data = json.loads(response.data)
assert data['message'] == 'Schedule updated successfully'
assert data['schedule']['name'] == update_data['name']
assert data['schedule']['cron_expression'] == update_data['cron_expression']
# Verify in database
db.refresh(sample_schedule)
assert sample_schedule.name == update_data['name']
assert sample_schedule.cron_expression == update_data['cron_expression']
def test_update_schedule_not_found(self, client, db):
"""Test updating non-existent schedule."""
update_data = {'name': 'New Name'}
response = client.put(
'/api/schedules/99999',
data=json.dumps(update_data),
content_type='application/json'
)
assert response.status_code == 404
data = json.loads(response.data)
assert 'error' in data
def test_update_schedule_invalid_cron(self, client, db, sample_schedule):
"""Test updating schedule with invalid cron expression."""
update_data = {'cron_expression': 'invalid'}
response = client.put(
f'/api/schedules/{sample_schedule.id}',
data=json.dumps(update_data),
content_type='application/json'
)
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
def test_update_schedule_toggle_enabled(self, client, db, sample_schedule):
"""Test enabling/disabling schedule."""
# Disable schedule
response = client.put(
f'/api/schedules/{sample_schedule.id}',
data=json.dumps({'enabled': False}),
content_type='application/json'
)
assert response.status_code == 200
data = json.loads(response.data)
assert data['schedule']['enabled'] is False
# Enable schedule
response = client.put(
f'/api/schedules/{sample_schedule.id}',
data=json.dumps({'enabled': True}),
content_type='application/json'
)
assert response.status_code == 200
data = json.loads(response.data)
assert data['schedule']['enabled'] is True
def test_update_schedule_no_data(self, client, db, sample_schedule):
"""Test updating schedule with no data."""
response = client.put(
f'/api/schedules/{sample_schedule.id}',
data=json.dumps({}),
content_type='application/json'
)
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
def test_delete_schedule(self, client, db, sample_schedule):
"""Test deleting a schedule."""
schedule_id = sample_schedule.id
response = client.delete(f'/api/schedules/{schedule_id}')
assert response.status_code == 200
data = json.loads(response.data)
assert data['message'] == 'Schedule deleted successfully'
assert data['schedule_id'] == schedule_id
# Verify deletion in database
schedule = db.query(Schedule).filter(Schedule.id == schedule_id).first()
assert schedule is None
def test_delete_schedule_not_found(self, client, db):
"""Test deleting non-existent schedule."""
response = client.delete('/api/schedules/99999')
assert response.status_code == 404
data = json.loads(response.data)
assert 'error' in data
def test_delete_schedule_preserves_scans(self, client, db, sample_schedule, sample_config_file):
"""Test that deleting schedule preserves associated scans."""
# Create a scan associated with the schedule
scan = Scan(
timestamp=datetime.utcnow(),
status='completed',
config_file=sample_config_file,
title='Test Scan',
triggered_by='scheduled',
schedule_id=sample_schedule.id
)
db.add(scan)
db.commit()
scan_id = scan.id
# Delete schedule
response = client.delete(f'/api/schedules/{sample_schedule.id}')
assert response.status_code == 200
# Verify scan still exists
scan = db.query(Scan).filter(Scan.id == scan_id).first()
assert scan is not None
assert scan.schedule_id is None # Schedule ID becomes null
def test_trigger_schedule(self, client, db, sample_schedule):
"""Test manually triggering a scheduled scan."""
response = client.post(f'/api/schedules/{sample_schedule.id}/trigger')
assert response.status_code == 201
data = json.loads(response.data)
assert data['message'] == 'Scan triggered successfully'
assert 'scan_id' in data
assert data['schedule_id'] == sample_schedule.id
# Verify scan was created
scan = db.query(Scan).filter(Scan.id == data['scan_id']).first()
assert scan is not None
assert scan.triggered_by == 'manual'
assert scan.schedule_id == sample_schedule.id
assert scan.config_file == sample_schedule.config_file
def test_trigger_schedule_not_found(self, client, db):
"""Test triggering non-existent schedule."""
response = client.post('/api/schedules/99999/trigger')
assert response.status_code == 404
data = json.loads(response.data)
assert 'error' in data
def test_get_schedule_with_history(self, client, db, sample_schedule, sample_config_file):
"""Test getting schedule includes execution history."""
# Create some scans for this schedule
for i in range(5):
scan = Scan(
timestamp=datetime.utcnow(),
status='completed',
config_file=sample_config_file,
title=f'Scheduled Scan {i}',
triggered_by='scheduled',
schedule_id=sample_schedule.id
)
db.add(scan)
db.commit()
response = client.get(f'/api/schedules/{sample_schedule.id}')
assert response.status_code == 200
data = json.loads(response.data)
assert 'history' in data
assert len(data['history']) == 5
def test_schedule_workflow_integration(self, client, db, sample_config_file):
"""Test complete schedule workflow: create → update → trigger → delete."""
# 1. Create schedule
schedule_data = {
'name': 'Integration Test Schedule',
'config_file': sample_config_file,
'cron_expression': '0 2 * * *',
'enabled': True
}
response = client.post(
'/api/schedules',
data=json.dumps(schedule_data),
content_type='application/json'
)
assert response.status_code == 201
schedule_id = json.loads(response.data)['schedule_id']
# 2. Get schedule
response = client.get(f'/api/schedules/{schedule_id}')
assert response.status_code == 200
# 3. Update schedule
response = client.put(
f'/api/schedules/{schedule_id}',
data=json.dumps({'name': 'Updated Integration Test'}),
content_type='application/json'
)
assert response.status_code == 200
# 4. Trigger schedule
response = client.post(f'/api/schedules/{schedule_id}/trigger')
assert response.status_code == 201
scan_id = json.loads(response.data)['scan_id']
# 5. Verify scan was created
scan = db.query(Scan).filter(Scan.id == scan_id).first()
assert scan is not None
# 6. Delete schedule
response = client.delete(f'/api/schedules/{schedule_id}')
assert response.status_code == 200
# 7. Verify schedule deleted
response = client.get(f'/api/schedules/{schedule_id}')
assert response.status_code == 404
# 8. Verify scan still exists
scan = db.query(Scan).filter(Scan.id == scan_id).first()
assert scan is not None
def test_list_schedules_ordering(self, client, db, sample_config_file):
"""Test that schedules are ordered by next_run time."""
# Create schedules with different next_run times
schedules = []
for i in range(3):
schedule = Schedule(
name=f'Schedule {i}',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True,
next_run=datetime(2025, 11, 15 + i, 2, 0, 0),
created_at=datetime.utcnow()
)
db.add(schedule)
schedules.append(schedule)
# Create a disabled schedule (next_run is None)
disabled_schedule = Schedule(
name='Disabled Schedule',
config_file=sample_config_file,
cron_expression='0 3 * * *',
enabled=False,
next_run=None,
created_at=datetime.utcnow()
)
db.add(disabled_schedule)
db.commit()
response = client.get('/api/schedules')
assert response.status_code == 200
data = json.loads(response.data)
returned_schedules = data['schedules']
# Schedules with next_run should come before those without
# Within those with next_run, they should be ordered by time
assert returned_schedules[0]['id'] == schedules[0].id
assert returned_schedules[1]['id'] == schedules[1].id
assert returned_schedules[2]['id'] == schedules[2].id
assert returned_schedules[3]['id'] == disabled_schedule.id
def test_create_schedule_with_disabled(self, client, db, sample_config_file):
"""Test creating a disabled schedule."""
schedule_data = {
'name': 'Disabled Schedule',
'config_file': sample_config_file,
'cron_expression': '0 2 * * *',
'enabled': False
}
response = client.post(
'/api/schedules',
data=json.dumps(schedule_data),
content_type='application/json'
)
assert response.status_code == 201
data = json.loads(response.data)
assert data['schedule']['enabled'] is False
assert data['schedule']['next_run'] is None # Disabled schedules have no next_run
class TestScheduleAPIAuthentication:
"""Test suite for schedule API authentication."""
def test_schedules_require_authentication(self, app):
"""Test that all schedule endpoints require authentication."""
# Create unauthenticated client
client = app.test_client()
endpoints = [
('GET', '/api/schedules'),
('GET', '/api/schedules/1'),
('POST', '/api/schedules'),
('PUT', '/api/schedules/1'),
('DELETE', '/api/schedules/1'),
('POST', '/api/schedules/1/trigger')
]
for method, endpoint in endpoints:
if method == 'GET':
response = client.get(endpoint)
elif method == 'POST':
response = client.post(
endpoint,
data=json.dumps({}),
content_type='application/json'
)
elif method == 'PUT':
response = client.put(
endpoint,
data=json.dumps({}),
content_type='application/json'
)
elif method == 'DELETE':
response = client.delete(endpoint)
# Should redirect to login or return 401
assert response.status_code in [302, 401], \
f"{method} {endpoint} should require authentication"
class TestScheduleAPICronValidation:
"""Test suite for cron expression validation."""
def test_valid_cron_expressions(self, client, db, sample_config_file):
"""Test various valid cron expressions."""
valid_expressions = [
'0 2 * * *', # Daily at 2am
'*/15 * * * *', # Every 15 minutes
'0 0 * * 0', # Weekly on Sunday
'0 0 1 * *', # Monthly on 1st
'0 */4 * * *', # Every 4 hours
]
for cron_expr in valid_expressions:
schedule_data = {
'name': f'Schedule for {cron_expr}',
'config_file': sample_config_file,
'cron_expression': cron_expr
}
response = client.post(
'/api/schedules',
data=json.dumps(schedule_data),
content_type='application/json'
)
assert response.status_code == 201, \
f"Valid cron expression '{cron_expr}' should be accepted"
def test_invalid_cron_expressions(self, client, db, sample_config_file):
"""Test various invalid cron expressions."""
invalid_expressions = [
'invalid',
'60 2 * * *', # Invalid minute
'0 25 * * *', # Invalid hour
'0 0 32 * *', # Invalid day
'0 0 * 13 *', # Invalid month
'0 0 * * 8', # Invalid day of week
]
for cron_expr in invalid_expressions:
schedule_data = {
'name': f'Schedule for {cron_expr}',
'config_file': sample_config_file,
'cron_expression': cron_expr
}
response = client.post(
'/api/schedules',
data=json.dumps(schedule_data),
content_type='application/json'
)
assert response.status_code == 400, \
f"Invalid cron expression '{cron_expr}' should be rejected"

View File

@@ -0,0 +1,671 @@
"""
Unit tests for ScheduleService class.
Tests schedule lifecycle operations: create, get, list, update, delete, and
cron expression validation.
"""
import pytest
from datetime import datetime, timedelta
from web.models import Schedule, Scan
from web.services.schedule_service import ScheduleService
class TestScheduleServiceCreate:
"""Tests for creating schedules."""
def test_create_schedule_valid(self, test_db, sample_config_file):
"""Test creating a schedule with valid parameters."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Daily Scan',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
# Verify schedule created
assert schedule_id is not None
assert isinstance(schedule_id, int)
# Verify schedule in database
schedule = test_db.query(Schedule).filter(Schedule.id == schedule_id).first()
assert schedule is not None
assert schedule.name == 'Daily Scan'
assert schedule.config_file == sample_config_file
assert schedule.cron_expression == '0 2 * * *'
assert schedule.enabled is True
assert schedule.next_run is not None
assert schedule.last_run is None
def test_create_schedule_disabled(self, test_db, sample_config_file):
"""Test creating a disabled schedule."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Disabled Scan',
config_file=sample_config_file,
cron_expression='0 3 * * *',
enabled=False
)
schedule = test_db.query(Schedule).filter(Schedule.id == schedule_id).first()
assert schedule.enabled is False
assert schedule.next_run is None
def test_create_schedule_invalid_cron(self, test_db, sample_config_file):
"""Test creating a schedule with invalid cron expression."""
service = ScheduleService(test_db)
with pytest.raises(ValueError, match="Invalid cron expression"):
service.create_schedule(
name='Invalid Schedule',
config_file=sample_config_file,
cron_expression='invalid cron',
enabled=True
)
def test_create_schedule_nonexistent_config(self, test_db):
"""Test creating a schedule with nonexistent config file."""
service = ScheduleService(test_db)
with pytest.raises(ValueError, match="Config file not found"):
service.create_schedule(
name='Bad Config',
config_file='/nonexistent/config.yaml',
cron_expression='0 2 * * *',
enabled=True
)
def test_create_schedule_various_cron_expressions(self, test_db, sample_config_file):
"""Test creating schedules with various valid cron expressions."""
service = ScheduleService(test_db)
cron_expressions = [
'0 0 * * *', # Daily at midnight
'*/15 * * * *', # Every 15 minutes
'0 2 * * 0', # Weekly on Sunday at 2 AM
'0 0 1 * *', # Monthly on the 1st at midnight
'30 14 * * 1-5', # Weekdays at 2:30 PM
]
for i, cron in enumerate(cron_expressions):
schedule_id = service.create_schedule(
name=f'Schedule {i}',
config_file=sample_config_file,
cron_expression=cron,
enabled=True
)
assert schedule_id is not None
class TestScheduleServiceGet:
"""Tests for retrieving schedules."""
def test_get_schedule_not_found(self, test_db):
"""Test getting a nonexistent schedule."""
service = ScheduleService(test_db)
with pytest.raises(ValueError, match="Schedule .* not found"):
service.get_schedule(999)
def test_get_schedule_found(self, test_db, sample_config_file):
"""Test getting an existing schedule."""
service = ScheduleService(test_db)
# Create a schedule
schedule_id = service.create_schedule(
name='Test Schedule',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
# Retrieve it
result = service.get_schedule(schedule_id)
assert result is not None
assert result['id'] == schedule_id
assert result['name'] == 'Test Schedule'
assert result['cron_expression'] == '0 2 * * *'
assert result['enabled'] is True
assert 'history' in result
assert isinstance(result['history'], list)
def test_get_schedule_with_history(self, test_db, sample_config_file):
"""Test getting schedule includes execution history."""
service = ScheduleService(test_db)
# Create schedule
schedule_id = service.create_schedule(
name='Test Schedule',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
# Create associated scans
for i in range(3):
scan = Scan(
timestamp=datetime.utcnow() - timedelta(days=i),
status='completed',
config_file=sample_config_file,
title=f'Scan {i}',
triggered_by='scheduled',
schedule_id=schedule_id
)
test_db.add(scan)
test_db.commit()
# Get schedule
result = service.get_schedule(schedule_id)
assert len(result['history']) == 3
assert result['history'][0]['title'] == 'Scan 0' # Most recent first
class TestScheduleServiceList:
"""Tests for listing schedules."""
def test_list_schedules_empty(self, test_db):
"""Test listing schedules when database is empty."""
service = ScheduleService(test_db)
result = service.list_schedules(page=1, per_page=20)
assert result['total'] == 0
assert len(result['schedules']) == 0
assert result['page'] == 1
assert result['per_page'] == 20
def test_list_schedules_populated(self, test_db, sample_config_file):
"""Test listing schedules with data."""
service = ScheduleService(test_db)
# Create multiple schedules
for i in range(5):
service.create_schedule(
name=f'Schedule {i}',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
result = service.list_schedules(page=1, per_page=20)
assert result['total'] == 5
assert len(result['schedules']) == 5
assert all('name' in s for s in result['schedules'])
def test_list_schedules_pagination(self, test_db, sample_config_file):
"""Test schedule pagination."""
service = ScheduleService(test_db)
# Create 25 schedules
for i in range(25):
service.create_schedule(
name=f'Schedule {i:02d}',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
# Get first page
result_page1 = service.list_schedules(page=1, per_page=10)
assert len(result_page1['schedules']) == 10
assert result_page1['total'] == 25
assert result_page1['pages'] == 3
# Get second page
result_page2 = service.list_schedules(page=2, per_page=10)
assert len(result_page2['schedules']) == 10
# Get third page
result_page3 = service.list_schedules(page=3, per_page=10)
assert len(result_page3['schedules']) == 5
def test_list_schedules_filter_enabled(self, test_db, sample_config_file):
"""Test filtering schedules by enabled status."""
service = ScheduleService(test_db)
# Create enabled and disabled schedules
for i in range(3):
service.create_schedule(
name=f'Enabled {i}',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
for i in range(2):
service.create_schedule(
name=f'Disabled {i}',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=False
)
# Filter enabled only
result_enabled = service.list_schedules(enabled_filter=True)
assert result_enabled['total'] == 3
# Filter disabled only
result_disabled = service.list_schedules(enabled_filter=False)
assert result_disabled['total'] == 2
# No filter
result_all = service.list_schedules(enabled_filter=None)
assert result_all['total'] == 5
class TestScheduleServiceUpdate:
"""Tests for updating schedules."""
def test_update_schedule_name(self, test_db, sample_config_file):
"""Test updating schedule name."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Old Name',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
result = service.update_schedule(schedule_id, name='New Name')
assert result['name'] == 'New Name'
assert result['cron_expression'] == '0 2 * * *'
def test_update_schedule_cron(self, test_db, sample_config_file):
"""Test updating cron expression recalculates next_run."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
original = service.get_schedule(schedule_id)
original_next_run = original['next_run']
# Update cron expression
result = service.update_schedule(
schedule_id,
cron_expression='0 3 * * *'
)
# Next run should be recalculated
assert result['cron_expression'] == '0 3 * * *'
assert result['next_run'] != original_next_run
def test_update_schedule_invalid_cron(self, test_db, sample_config_file):
"""Test updating with invalid cron expression fails."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
with pytest.raises(ValueError, match="Invalid cron expression"):
service.update_schedule(schedule_id, cron_expression='invalid')
def test_update_schedule_not_found(self, test_db):
"""Test updating nonexistent schedule fails."""
service = ScheduleService(test_db)
with pytest.raises(ValueError, match="Schedule .* not found"):
service.update_schedule(999, name='New Name')
def test_update_schedule_invalid_config_file(self, test_db, sample_config_file):
"""Test updating with nonexistent config file fails."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
with pytest.raises(ValueError, match="Config file not found"):
service.update_schedule(schedule_id, config_file='/nonexistent.yaml')
class TestScheduleServiceDelete:
"""Tests for deleting schedules."""
def test_delete_schedule(self, test_db, sample_config_file):
"""Test deleting a schedule."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='To Delete',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
# Verify exists
assert test_db.query(Schedule).filter(Schedule.id == schedule_id).first() is not None
# Delete
result = service.delete_schedule(schedule_id)
assert result is True
# Verify deleted
assert test_db.query(Schedule).filter(Schedule.id == schedule_id).first() is None
def test_delete_schedule_not_found(self, test_db):
"""Test deleting nonexistent schedule fails."""
service = ScheduleService(test_db)
with pytest.raises(ValueError, match="Schedule .* not found"):
service.delete_schedule(999)
def test_delete_schedule_preserves_scans(self, test_db, sample_config_file):
"""Test that deleting schedule preserves associated scans."""
service = ScheduleService(test_db)
# Create schedule
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
# Create associated scan
scan = Scan(
timestamp=datetime.utcnow(),
status='completed',
config_file=sample_config_file,
title='Test Scan',
triggered_by='scheduled',
schedule_id=schedule_id
)
test_db.add(scan)
test_db.commit()
scan_id = scan.id
# Delete schedule
service.delete_schedule(schedule_id)
# Verify scan still exists (schedule_id becomes null)
remaining_scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
assert remaining_scan is not None
assert remaining_scan.schedule_id is None
class TestScheduleServiceToggle:
"""Tests for toggling schedule enabled status."""
def test_toggle_enabled_to_disabled(self, test_db, sample_config_file):
"""Test disabling an enabled schedule."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
result = service.toggle_enabled(schedule_id, enabled=False)
assert result['enabled'] is False
assert result['next_run'] is None
def test_toggle_disabled_to_enabled(self, test_db, sample_config_file):
"""Test enabling a disabled schedule."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=False
)
result = service.toggle_enabled(schedule_id, enabled=True)
assert result['enabled'] is True
assert result['next_run'] is not None
class TestScheduleServiceRunTimes:
"""Tests for updating run times."""
def test_update_run_times(self, test_db, sample_config_file):
"""Test updating last_run and next_run."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
last_run = datetime.utcnow()
next_run = datetime.utcnow() + timedelta(days=1)
result = service.update_run_times(schedule_id, last_run, next_run)
assert result is True
schedule = service.get_schedule(schedule_id)
assert schedule['last_run'] is not None
assert schedule['next_run'] is not None
def test_update_run_times_not_found(self, test_db):
"""Test updating run times for nonexistent schedule."""
service = ScheduleService(test_db)
with pytest.raises(ValueError, match="Schedule .* not found"):
service.update_run_times(
999,
datetime.utcnow(),
datetime.utcnow() + timedelta(days=1)
)
class TestCronValidation:
"""Tests for cron expression validation."""
def test_validate_cron_valid_expressions(self, test_db):
"""Test validating various valid cron expressions."""
service = ScheduleService(test_db)
valid_expressions = [
'0 0 * * *', # Daily at midnight
'*/15 * * * *', # Every 15 minutes
'0 2 * * 0', # Weekly on Sunday
'0 0 1 * *', # Monthly
'30 14 * * 1-5', # Weekdays
'0 */4 * * *', # Every 4 hours
]
for expr in valid_expressions:
is_valid, error = service.validate_cron_expression(expr)
assert is_valid is True, f"Expression '{expr}' should be valid"
assert error is None
def test_validate_cron_invalid_expressions(self, test_db):
"""Test validating invalid cron expressions."""
service = ScheduleService(test_db)
invalid_expressions = [
'invalid',
'60 0 * * *', # Invalid minute (0-59)
'0 24 * * *', # Invalid hour (0-23)
'0 0 32 * *', # Invalid day (1-31)
'0 0 * 13 *', # Invalid month (1-12)
'0 0 * * 7', # Invalid weekday (0-6)
]
for expr in invalid_expressions:
is_valid, error = service.validate_cron_expression(expr)
assert is_valid is False, f"Expression '{expr}' should be invalid"
assert error is not None
class TestNextRunCalculation:
"""Tests for next run time calculation."""
def test_calculate_next_run(self, test_db):
"""Test calculating next run time."""
service = ScheduleService(test_db)
# Daily at 2 AM
next_run = service.calculate_next_run('0 2 * * *')
assert next_run is not None
assert isinstance(next_run, datetime)
assert next_run > datetime.utcnow()
def test_calculate_next_run_from_time(self, test_db):
"""Test calculating next run from specific time."""
service = ScheduleService(test_db)
base_time = datetime(2025, 1, 1, 0, 0, 0)
next_run = service.calculate_next_run('0 2 * * *', from_time=base_time)
# Should be 2 AM on same day
assert next_run.hour == 2
assert next_run.minute == 0
def test_calculate_next_run_invalid_cron(self, test_db):
"""Test calculating next run with invalid cron raises error."""
service = ScheduleService(test_db)
with pytest.raises(ValueError, match="Invalid cron expression"):
service.calculate_next_run('invalid cron')
class TestScheduleHistory:
"""Tests for schedule execution history."""
def test_get_schedule_history_empty(self, test_db, sample_config_file):
"""Test getting history for schedule with no executions."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
history = service.get_schedule_history(schedule_id)
assert len(history) == 0
def test_get_schedule_history_with_scans(self, test_db, sample_config_file):
"""Test getting history with multiple scans."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
# Create 15 scans
for i in range(15):
scan = Scan(
timestamp=datetime.utcnow() - timedelta(days=i),
status='completed',
config_file=sample_config_file,
title=f'Scan {i}',
triggered_by='scheduled',
schedule_id=schedule_id
)
test_db.add(scan)
test_db.commit()
# Get history (default limit 10)
history = service.get_schedule_history(schedule_id, limit=10)
assert len(history) == 10
assert history[0]['title'] == 'Scan 0' # Most recent first
def test_get_schedule_history_custom_limit(self, test_db, sample_config_file):
"""Test getting history with custom limit."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
# Create 10 scans
for i in range(10):
scan = Scan(
timestamp=datetime.utcnow() - timedelta(days=i),
status='completed',
config_file=sample_config_file,
title=f'Scan {i}',
triggered_by='scheduled',
schedule_id=schedule_id
)
test_db.add(scan)
test_db.commit()
# Get only 5
history = service.get_schedule_history(schedule_id, limit=5)
assert len(history) == 5
class TestScheduleSerialization:
"""Tests for schedule serialization."""
def test_schedule_to_dict(self, test_db, sample_config_file):
"""Test converting schedule to dictionary."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test Schedule',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
result = service.get_schedule(schedule_id)
# Verify all required fields
assert 'id' in result
assert 'name' in result
assert 'config_file' in result
assert 'cron_expression' in result
assert 'enabled' in result
assert 'last_run' in result
assert 'next_run' in result
assert 'next_run_relative' in result
assert 'created_at' in result
assert 'updated_at' in result
assert 'history' in result
def test_schedule_relative_time_formatting(self, test_db, sample_config_file):
"""Test relative time formatting in schedule dict."""
service = ScheduleService(test_db)
schedule_id = service.create_schedule(
name='Test',
config_file=sample_config_file,
cron_expression='0 2 * * *',
enabled=True
)
result = service.get_schedule(schedule_id)
# Should have relative time for next_run
assert result['next_run_relative'] is not None
assert isinstance(result['next_run_relative'], str)
assert 'in' in result['next_run_relative'].lower()

325
app/tests/test_stats_api.py Normal file
View File

@@ -0,0 +1,325 @@
"""
Tests for stats API endpoints.
Tests dashboard statistics and trending data endpoints.
"""
import pytest
from datetime import datetime, timedelta
from web.models import Scan
class TestStatsAPI:
"""Test suite for stats API endpoints."""
def test_scan_trend_default_30_days(self, client, auth_headers, db_session):
"""Test scan trend endpoint with default 30 days."""
# Create test scans over multiple days
today = datetime.utcnow()
for i in range(5):
scan_date = today - timedelta(days=i)
for j in range(i + 1): # Create 1, 2, 3, 4, 5 scans per day
scan = Scan(
config_file='/app/configs/test.yaml',
timestamp=scan_date,
status='completed',
duration=10.5
)
db_session.add(scan)
db_session.commit()
# Request trend data
response = client.get('/api/stats/scan-trend', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert 'labels' in data
assert 'values' in data
assert 'start_date' in data
assert 'end_date' in data
assert 'total_scans' in data
# Should have 30 days of data
assert len(data['labels']) == 30
assert len(data['values']) == 30
# Total scans should match (1+2+3+4+5 = 15)
assert data['total_scans'] == 15
# Values should be non-negative integers
assert all(isinstance(v, int) for v in data['values'])
assert all(v >= 0 for v in data['values'])
def test_scan_trend_custom_days(self, client, auth_headers, db_session):
"""Test scan trend endpoint with custom number of days."""
# Create test scans
today = datetime.utcnow()
for i in range(10):
scan = Scan(
config_file='/app/configs/test.yaml',
timestamp=today - timedelta(days=i),
status='completed',
duration=10.5
)
db_session.add(scan)
db_session.commit()
# Request 7 days of data
response = client.get('/api/stats/scan-trend?days=7', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert len(data['labels']) == 7
assert len(data['values']) == 7
assert data['total_scans'] == 7
def test_scan_trend_max_days_365(self, client, auth_headers):
"""Test scan trend endpoint accepts maximum 365 days."""
response = client.get('/api/stats/scan-trend?days=365', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert len(data['labels']) == 365
def test_scan_trend_rejects_days_over_365(self, client, auth_headers):
"""Test scan trend endpoint rejects more than 365 days."""
response = client.get('/api/stats/scan-trend?days=366', headers=auth_headers)
assert response.status_code == 400
data = response.get_json()
assert 'error' in data
assert '365' in data['error']
def test_scan_trend_rejects_days_less_than_1(self, client, auth_headers):
"""Test scan trend endpoint rejects days less than 1."""
response = client.get('/api/stats/scan-trend?days=0', headers=auth_headers)
assert response.status_code == 400
data = response.get_json()
assert 'error' in data
def test_scan_trend_fills_missing_days_with_zero(self, client, auth_headers, db_session):
"""Test scan trend fills days with no scans as zero."""
# Create scans only on specific days
today = datetime.utcnow()
# Create scan 5 days ago
scan1 = Scan(
config_file='/app/configs/test.yaml',
timestamp=today - timedelta(days=5),
status='completed',
duration=10.5
)
db_session.add(scan1)
# Create scan 10 days ago
scan2 = Scan(
config_file='/app/configs/test.yaml',
timestamp=today - timedelta(days=10),
status='completed',
duration=10.5
)
db_session.add(scan2)
db_session.commit()
# Request 15 days
response = client.get('/api/stats/scan-trend?days=15', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
# Should have 15 days of data
assert len(data['values']) == 15
# Most days should be zero
zero_days = sum(1 for v in data['values'] if v == 0)
assert zero_days >= 13 # At least 13 days with no scans
def test_scan_trend_requires_authentication(self, client):
"""Test scan trend endpoint requires authentication."""
response = client.get('/api/stats/scan-trend')
assert response.status_code == 401
def test_summary_endpoint(self, client, auth_headers, db_session):
"""Test summary statistics endpoint."""
# Create test scans with different statuses
today = datetime.utcnow()
# 5 completed scans
for i in range(5):
scan = Scan(
config_file='/app/configs/test.yaml',
timestamp=today - timedelta(days=i),
status='completed',
duration=10.5
)
db_session.add(scan)
# 2 failed scans
for i in range(2):
scan = Scan(
config_file='/app/configs/test.yaml',
timestamp=today - timedelta(days=i),
status='failed',
duration=5.0
)
db_session.add(scan)
# 1 running scan
scan = Scan(
config_file='/app/configs/test.yaml',
timestamp=today,
status='running',
duration=None
)
db_session.add(scan)
db_session.commit()
# Request summary
response = client.get('/api/stats/summary', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert 'total_scans' in data
assert 'completed_scans' in data
assert 'failed_scans' in data
assert 'running_scans' in data
assert 'scans_today' in data
assert 'scans_this_week' in data
# Verify counts
assert data['total_scans'] == 8
assert data['completed_scans'] == 5
assert data['failed_scans'] == 2
assert data['running_scans'] == 1
assert data['scans_today'] >= 1
assert data['scans_this_week'] >= 1
def test_summary_with_no_scans(self, client, auth_headers):
"""Test summary endpoint with no scans."""
response = client.get('/api/stats/summary', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert data['total_scans'] == 0
assert data['completed_scans'] == 0
assert data['failed_scans'] == 0
assert data['running_scans'] == 0
assert data['scans_today'] == 0
assert data['scans_this_week'] == 0
def test_summary_scans_today(self, client, auth_headers, db_session):
"""Test summary counts scans today correctly."""
today = datetime.utcnow()
yesterday = today - timedelta(days=1)
# Create 3 scans today
for i in range(3):
scan = Scan(
config_file='/app/configs/test.yaml',
timestamp=today,
status='completed',
duration=10.5
)
db_session.add(scan)
# Create 2 scans yesterday
for i in range(2):
scan = Scan(
config_file='/app/configs/test.yaml',
timestamp=yesterday,
status='completed',
duration=10.5
)
db_session.add(scan)
db_session.commit()
response = client.get('/api/stats/summary', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
assert data['scans_today'] == 3
assert data['scans_this_week'] >= 3
def test_summary_scans_this_week(self, client, auth_headers, db_session):
"""Test summary counts scans this week correctly."""
today = datetime.utcnow()
# Create scans over the last 10 days
for i in range(10):
scan = Scan(
config_file='/app/configs/test.yaml',
timestamp=today - timedelta(days=i),
status='completed',
duration=10.5
)
db_session.add(scan)
db_session.commit()
response = client.get('/api/stats/summary', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
# Last 7 days (0-6) = 7 scans
assert data['scans_this_week'] == 7
def test_summary_requires_authentication(self, client):
"""Test summary endpoint requires authentication."""
response = client.get('/api/stats/summary')
assert response.status_code == 401
def test_scan_trend_date_format(self, client, auth_headers, db_session):
"""Test scan trend returns dates in correct format."""
# Create a scan
scan = Scan(
config_file='/app/configs/test.yaml',
timestamp=datetime.utcnow(),
status='completed',
duration=10.5
)
db_session.add(scan)
db_session.commit()
response = client.get('/api/stats/scan-trend?days=7', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
# Check date format (YYYY-MM-DD)
for label in data['labels']:
assert len(label) == 10
assert label[4] == '-'
assert label[7] == '-'
# Try parsing to ensure valid date
datetime.strptime(label, '%Y-%m-%d')
def test_scan_trend_consecutive_dates(self, client, auth_headers):
"""Test scan trend returns consecutive dates."""
response = client.get('/api/stats/scan-trend?days=7', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
labels = data['labels']
# Convert to datetime objects
dates = [datetime.strptime(label, '%Y-%m-%d') for label in labels]
# Check dates are consecutive
for i in range(len(dates) - 1):
diff = dates[i + 1] - dates[i]
assert diff.days == 1, f"Dates not consecutive: {dates[i]} to {dates[i+1]}"
def test_scan_trend_ends_with_today(self, client, auth_headers):
"""Test scan trend ends with today's date."""
response = client.get('/api/stats/scan-trend?days=7', headers=auth_headers)
assert response.status_code == 200
data = response.get_json()
# Last date should be today
today = datetime.utcnow().date()
last_date = datetime.strptime(data['labels'][-1], '%Y-%m-%d').date()
assert last_date == today