restructure of dirs, huge docs update
This commit is contained in:
1
app/tests/__init__.py
Normal file
1
app/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test package for SneakyScanner."""
|
||||
384
app/tests/conftest.py
Normal file
384
app/tests/conftest.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Pytest configuration and fixtures for SneakyScanner tests.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from web.app import create_app
|
||||
from web.models import Base, Scan
|
||||
from web.utils.settings import PasswordManager, SettingsManager
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def test_db():
|
||||
"""
|
||||
Create a temporary test database.
|
||||
|
||||
Yields a SQLAlchemy session for testing, then cleans up.
|
||||
"""
|
||||
# Create temporary database file
|
||||
db_fd, db_path = tempfile.mkstemp(suffix='.db')
|
||||
|
||||
# Create engine and session
|
||||
engine = create_engine(f'sqlite:///{db_path}', echo=False)
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
|
||||
yield session
|
||||
|
||||
# Cleanup
|
||||
session.close()
|
||||
os.close(db_fd)
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_scan_report():
|
||||
"""
|
||||
Sample scan report matching the structure from scanner.py.
|
||||
|
||||
Returns a dictionary representing a typical scan output.
|
||||
"""
|
||||
return {
|
||||
'title': 'Test Scan',
|
||||
'scan_time': '2025-11-14T10:30:00Z',
|
||||
'scan_duration': 125.5,
|
||||
'config_file': '/app/configs/test.yaml',
|
||||
'sites': [
|
||||
{
|
||||
'name': 'Test Site',
|
||||
'ips': [
|
||||
{
|
||||
'address': '192.168.1.10',
|
||||
'expected': {
|
||||
'ping': True,
|
||||
'tcp_ports': [22, 80, 443],
|
||||
'udp_ports': [53],
|
||||
'services': []
|
||||
},
|
||||
'actual': {
|
||||
'ping': True,
|
||||
'tcp_ports': [22, 80, 443, 8080],
|
||||
'udp_ports': [53],
|
||||
'services': [
|
||||
{
|
||||
'port': 22,
|
||||
'service': 'ssh',
|
||||
'product': 'OpenSSH',
|
||||
'version': '8.9p1',
|
||||
'extrainfo': 'Ubuntu',
|
||||
'ostype': 'Linux'
|
||||
},
|
||||
{
|
||||
'port': 443,
|
||||
'service': 'https',
|
||||
'product': 'nginx',
|
||||
'version': '1.24.0',
|
||||
'extrainfo': '',
|
||||
'ostype': '',
|
||||
'http_info': {
|
||||
'protocol': 'https',
|
||||
'screenshot': 'screenshots/192_168_1_10_443.png',
|
||||
'certificate': {
|
||||
'subject': 'CN=example.com',
|
||||
'issuer': 'CN=Let\'s Encrypt Authority',
|
||||
'serial_number': '123456789',
|
||||
'not_valid_before': '2025-01-01T00:00:00Z',
|
||||
'not_valid_after': '2025-12-31T23:59:59Z',
|
||||
'days_until_expiry': 365,
|
||||
'sans': ['example.com', 'www.example.com'],
|
||||
'is_self_signed': False,
|
||||
'tls_versions': {
|
||||
'TLS 1.2': {
|
||||
'supported': True,
|
||||
'cipher_suites': [
|
||||
'TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384',
|
||||
'TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256'
|
||||
]
|
||||
},
|
||||
'TLS 1.3': {
|
||||
'supported': True,
|
||||
'cipher_suites': [
|
||||
'TLS_AES_256_GCM_SHA384',
|
||||
'TLS_AES_128_GCM_SHA256'
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
'port': 80,
|
||||
'service': 'http',
|
||||
'product': 'nginx',
|
||||
'version': '1.24.0',
|
||||
'extrainfo': '',
|
||||
'ostype': '',
|
||||
'http_info': {
|
||||
'protocol': 'http',
|
||||
'screenshot': 'screenshots/192_168_1_10_80.png'
|
||||
}
|
||||
},
|
||||
{
|
||||
'port': 8080,
|
||||
'service': 'http',
|
||||
'product': 'Jetty',
|
||||
'version': '9.4.48',
|
||||
'extrainfo': '',
|
||||
'ostype': ''
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config_file(tmp_path):
|
||||
"""
|
||||
Create a sample YAML config file for testing.
|
||||
|
||||
Args:
|
||||
tmp_path: pytest temporary directory fixture
|
||||
|
||||
Returns:
|
||||
Path to created config file
|
||||
"""
|
||||
config_data = {
|
||||
'title': 'Test Scan',
|
||||
'sites': [
|
||||
{
|
||||
'name': 'Test Site',
|
||||
'ips': [
|
||||
{
|
||||
'address': '192.168.1.10',
|
||||
'expected': {
|
||||
'ping': True,
|
||||
'tcp_ports': [22, 80, 443],
|
||||
'udp_ports': [53],
|
||||
'services': ['ssh', 'http', 'https']
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
config_file = tmp_path / 'test_config.yaml'
|
||||
with open(config_file, 'w') as f:
|
||||
yaml.dump(config_data, f)
|
||||
|
||||
return str(config_file)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_invalid_config_file(tmp_path):
|
||||
"""
|
||||
Create an invalid config file for testing validation.
|
||||
|
||||
Returns:
|
||||
Path to invalid config file
|
||||
"""
|
||||
config_file = tmp_path / 'invalid_config.yaml'
|
||||
with open(config_file, 'w') as f:
|
||||
f.write("invalid: yaml: content: [missing closing bracket")
|
||||
|
||||
return str(config_file)
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def app():
|
||||
"""
|
||||
Create Flask application for testing.
|
||||
|
||||
Returns:
|
||||
Configured Flask app instance with test database
|
||||
"""
|
||||
# Create temporary database
|
||||
db_fd, db_path = tempfile.mkstemp(suffix='.db')
|
||||
|
||||
# Create app with test config
|
||||
test_config = {
|
||||
'TESTING': True,
|
||||
'SQLALCHEMY_DATABASE_URI': f'sqlite:///{db_path}',
|
||||
'SECRET_KEY': 'test-secret-key'
|
||||
}
|
||||
|
||||
app = create_app(test_config)
|
||||
|
||||
yield app
|
||||
|
||||
# Cleanup
|
||||
os.close(db_fd)
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def client(app):
|
||||
"""
|
||||
Create Flask test client.
|
||||
|
||||
Args:
|
||||
app: Flask application fixture
|
||||
|
||||
Returns:
|
||||
Flask test client for making API requests
|
||||
"""
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def db(app):
|
||||
"""
|
||||
Alias for database session that works with Flask app context.
|
||||
|
||||
Args:
|
||||
app: Flask application fixture
|
||||
|
||||
Returns:
|
||||
SQLAlchemy session
|
||||
"""
|
||||
with app.app_context():
|
||||
yield app.db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_scan(db):
|
||||
"""
|
||||
Create a sample scan in the database for testing.
|
||||
|
||||
Args:
|
||||
db: Database session fixture
|
||||
|
||||
Returns:
|
||||
Scan model instance
|
||||
"""
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status='completed',
|
||||
config_file='/app/configs/test.yaml',
|
||||
title='Test Scan',
|
||||
duration=125.5,
|
||||
triggered_by='test',
|
||||
json_path='/app/output/scan_report_20251114_103000.json',
|
||||
html_path='/app/output/scan_report_20251114_103000.html',
|
||||
zip_path='/app/output/scan_report_20251114_103000.zip',
|
||||
screenshot_dir='/app/output/scan_report_20251114_103000_screenshots'
|
||||
)
|
||||
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
db.refresh(scan)
|
||||
|
||||
return scan
|
||||
|
||||
|
||||
# Authentication Fixtures
|
||||
|
||||
@pytest.fixture
|
||||
def app_password():
|
||||
"""
|
||||
Test password for authentication tests.
|
||||
|
||||
Returns:
|
||||
Test password string
|
||||
"""
|
||||
return 'testpassword123'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_with_password(db, app_password):
|
||||
"""
|
||||
Database session with application password set.
|
||||
|
||||
Args:
|
||||
db: Database session fixture
|
||||
app_password: Test password fixture
|
||||
|
||||
Returns:
|
||||
Database session with password configured
|
||||
"""
|
||||
settings_manager = SettingsManager(db)
|
||||
PasswordManager.set_app_password(settings_manager, app_password)
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_no_password(app):
|
||||
"""
|
||||
Database session without application password set.
|
||||
|
||||
Args:
|
||||
app: Flask application fixture
|
||||
|
||||
Returns:
|
||||
Database session without password
|
||||
"""
|
||||
with app.app_context():
|
||||
# Clear any password that might be set
|
||||
settings_manager = SettingsManager(app.db_session)
|
||||
settings_manager.delete('app_password')
|
||||
yield app.db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client(client, db_with_password, app_password):
|
||||
"""
|
||||
Flask test client with authenticated session.
|
||||
|
||||
Args:
|
||||
client: Flask test client fixture
|
||||
db_with_password: Database with password set
|
||||
app_password: Test password fixture
|
||||
|
||||
Returns:
|
||||
Test client with active session
|
||||
"""
|
||||
# Log in
|
||||
client.post('/auth/login', data={
|
||||
'password': app_password
|
||||
})
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client_no_password(app):
|
||||
"""
|
||||
Flask test client with no password set (for setup testing).
|
||||
|
||||
Args:
|
||||
app: Flask application fixture
|
||||
|
||||
Returns:
|
||||
Test client for testing setup flow
|
||||
"""
|
||||
# Create temporary database without password
|
||||
db_fd, db_path = tempfile.mkstemp(suffix='.db')
|
||||
|
||||
test_config = {
|
||||
'TESTING': True,
|
||||
'SQLALCHEMY_DATABASE_URI': f'sqlite:///{db_path}',
|
||||
'SECRET_KEY': 'test-secret-key'
|
||||
}
|
||||
|
||||
test_app = create_app(test_config)
|
||||
test_client = test_app.test_client()
|
||||
|
||||
yield test_client
|
||||
|
||||
# Cleanup
|
||||
os.close(db_fd)
|
||||
os.unlink(db_path)
|
||||
279
app/tests/test_authentication.py
Normal file
279
app/tests/test_authentication.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
Tests for authentication system.
|
||||
|
||||
Tests login, logout, session management, and API authentication.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from flask import url_for
|
||||
from web.auth.models import User
|
||||
from web.utils.settings import PasswordManager, SettingsManager
|
||||
|
||||
|
||||
class TestUserModel:
|
||||
"""Tests for User model."""
|
||||
|
||||
def test_user_get_valid_id(self, db):
|
||||
"""Test getting user with valid ID."""
|
||||
user = User.get('1', db)
|
||||
assert user is not None
|
||||
assert user.id == '1'
|
||||
|
||||
def test_user_get_invalid_id(self, db):
|
||||
"""Test getting user with invalid ID."""
|
||||
user = User.get('invalid', db)
|
||||
assert user is None
|
||||
|
||||
def test_user_properties(self):
|
||||
"""Test user properties."""
|
||||
user = User('1')
|
||||
assert user.is_authenticated is True
|
||||
assert user.is_active is True
|
||||
assert user.is_anonymous is False
|
||||
assert user.get_id() == '1'
|
||||
|
||||
def test_user_authenticate_success(self, db, app_password):
|
||||
"""Test successful authentication."""
|
||||
user = User.authenticate(app_password, db)
|
||||
assert user is not None
|
||||
assert user.id == '1'
|
||||
|
||||
def test_user_authenticate_failure(self, db):
|
||||
"""Test failed authentication with wrong password."""
|
||||
user = User.authenticate('wrongpassword', db)
|
||||
assert user is None
|
||||
|
||||
def test_user_has_password_set(self, db, app_password):
|
||||
"""Test checking if password is set."""
|
||||
# Password is set in fixture
|
||||
assert User.has_password_set(db) is True
|
||||
|
||||
def test_user_has_password_not_set(self, db_no_password):
|
||||
"""Test checking if password is not set."""
|
||||
assert User.has_password_set(db_no_password) is False
|
||||
|
||||
|
||||
class TestAuthRoutes:
|
||||
"""Tests for authentication routes."""
|
||||
|
||||
def test_login_page_renders(self, client):
|
||||
"""Test that login page renders correctly."""
|
||||
response = client.get('/auth/login')
|
||||
assert response.status_code == 200
|
||||
# Note: This will fail until templates are created
|
||||
# assert b'login' in response.data.lower()
|
||||
|
||||
def test_login_success(self, client, app_password):
|
||||
"""Test successful login."""
|
||||
response = client.post('/auth/login', data={
|
||||
'password': app_password
|
||||
}, follow_redirects=False)
|
||||
|
||||
# Should redirect to dashboard (or main.dashboard)
|
||||
assert response.status_code == 302
|
||||
|
||||
def test_login_failure(self, client):
|
||||
"""Test failed login with wrong password."""
|
||||
response = client.post('/auth/login', data={
|
||||
'password': 'wrongpassword'
|
||||
}, follow_redirects=True)
|
||||
|
||||
# Should stay on login page
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_login_redirect_when_authenticated(self, authenticated_client):
|
||||
"""Test that login page redirects when already logged in."""
|
||||
response = authenticated_client.get('/auth/login', follow_redirects=False)
|
||||
# Should redirect to dashboard
|
||||
assert response.status_code == 302
|
||||
|
||||
def test_logout(self, authenticated_client):
|
||||
"""Test logout functionality."""
|
||||
response = authenticated_client.get('/auth/logout', follow_redirects=False)
|
||||
|
||||
# Should redirect to login page
|
||||
assert response.status_code == 302
|
||||
assert '/auth/login' in response.location
|
||||
|
||||
def test_logout_when_not_authenticated(self, client):
|
||||
"""Test logout when not authenticated."""
|
||||
response = client.get('/auth/logout', follow_redirects=False)
|
||||
|
||||
# Should redirect to login page anyway
|
||||
assert response.status_code == 302
|
||||
|
||||
def test_setup_page_renders_when_no_password(self, client_no_password):
|
||||
"""Test that setup page renders when no password is set."""
|
||||
response = client_no_password.get('/auth/setup')
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_setup_redirects_when_password_set(self, client):
|
||||
"""Test that setup page redirects when password already set."""
|
||||
response = client.get('/auth/setup', follow_redirects=False)
|
||||
assert response.status_code == 302
|
||||
assert '/auth/login' in response.location
|
||||
|
||||
def test_setup_password_success(self, client_no_password):
|
||||
"""Test setting password via setup page."""
|
||||
response = client_no_password.post('/auth/setup', data={
|
||||
'password': 'newpassword123',
|
||||
'confirm_password': 'newpassword123'
|
||||
}, follow_redirects=False)
|
||||
|
||||
# Should redirect to login
|
||||
assert response.status_code == 302
|
||||
assert '/auth/login' in response.location
|
||||
|
||||
def test_setup_password_too_short(self, client_no_password):
|
||||
"""Test that setup rejects password that's too short."""
|
||||
response = client_no_password.post('/auth/setup', data={
|
||||
'password': 'short',
|
||||
'confirm_password': 'short'
|
||||
}, follow_redirects=True)
|
||||
|
||||
# Should stay on setup page
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_setup_passwords_dont_match(self, client_no_password):
|
||||
"""Test that setup rejects mismatched passwords."""
|
||||
response = client_no_password.post('/auth/setup', data={
|
||||
'password': 'password123',
|
||||
'confirm_password': 'different123'
|
||||
}, follow_redirects=True)
|
||||
|
||||
# Should stay on setup page
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestAPIAuthentication:
|
||||
"""Tests for API endpoint authentication."""
|
||||
|
||||
def test_scans_list_requires_auth(self, client):
|
||||
"""Test that listing scans requires authentication."""
|
||||
response = client.get('/api/scans')
|
||||
assert response.status_code == 401
|
||||
data = response.get_json()
|
||||
assert 'error' in data
|
||||
assert data['error'] == 'Authentication required'
|
||||
|
||||
def test_scans_list_with_auth(self, authenticated_client):
|
||||
"""Test that listing scans works when authenticated."""
|
||||
response = authenticated_client.get('/api/scans')
|
||||
# Should succeed (200) even if empty
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert 'scans' in data
|
||||
|
||||
def test_scan_trigger_requires_auth(self, client):
|
||||
"""Test that triggering scan requires authentication."""
|
||||
response = client.post('/api/scans', json={
|
||||
'config_file': '/app/configs/test.yaml'
|
||||
})
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_scan_get_requires_auth(self, client):
|
||||
"""Test that getting scan details requires authentication."""
|
||||
response = client.get('/api/scans/1')
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_scan_delete_requires_auth(self, client):
|
||||
"""Test that deleting scan requires authentication."""
|
||||
response = client.delete('/api/scans/1')
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_scan_status_requires_auth(self, client):
|
||||
"""Test that getting scan status requires authentication."""
|
||||
response = client.get('/api/scans/1/status')
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_settings_get_requires_auth(self, client):
|
||||
"""Test that getting settings requires authentication."""
|
||||
response = client.get('/api/settings')
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_settings_update_requires_auth(self, client):
|
||||
"""Test that updating settings requires authentication."""
|
||||
response = client.put('/api/settings', json={
|
||||
'settings': {'test_key': 'test_value'}
|
||||
})
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_settings_get_with_auth(self, authenticated_client):
|
||||
"""Test that getting settings works when authenticated."""
|
||||
response = authenticated_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert 'settings' in data
|
||||
|
||||
def test_schedules_list_requires_auth(self, client):
|
||||
"""Test that listing schedules requires authentication."""
|
||||
response = client.get('/api/schedules')
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_alerts_list_requires_auth(self, client):
|
||||
"""Test that listing alerts requires authentication."""
|
||||
response = client.get('/api/alerts')
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_health_check_no_auth_required(self, client):
|
||||
"""Test that health check endpoints don't require authentication."""
|
||||
# Health checks should be accessible without authentication
|
||||
response = client.get('/api/scans/health')
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get('/api/settings/health')
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get('/api/schedules/health')
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get('/api/alerts/health')
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestSessionManagement:
|
||||
"""Tests for session management."""
|
||||
|
||||
def test_session_persists_across_requests(self, authenticated_client):
|
||||
"""Test that session persists across multiple requests."""
|
||||
# First request - should succeed
|
||||
response1 = authenticated_client.get('/api/scans')
|
||||
assert response1.status_code == 200
|
||||
|
||||
# Second request - should also succeed (session persists)
|
||||
response2 = authenticated_client.get('/api/settings')
|
||||
assert response2.status_code == 200
|
||||
|
||||
def test_remember_me_cookie(self, client, app_password):
|
||||
"""Test remember me functionality."""
|
||||
response = client.post('/auth/login', data={
|
||||
'password': app_password,
|
||||
'remember': 'on'
|
||||
}, follow_redirects=False)
|
||||
|
||||
# Should set remember_me cookie
|
||||
assert response.status_code == 302
|
||||
# Note: Actual cookie checking would require inspecting response.headers
|
||||
|
||||
|
||||
class TestNextRedirect:
|
||||
"""Tests for 'next' parameter redirect."""
|
||||
|
||||
def test_login_redirects_to_next(self, client, app_password):
|
||||
"""Test that login redirects to 'next' parameter."""
|
||||
response = client.post('/auth/login?next=/api/scans', data={
|
||||
'password': app_password
|
||||
}, follow_redirects=False)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert '/api/scans' in response.location
|
||||
|
||||
def test_login_without_next_redirects_to_dashboard(self, client, app_password):
|
||||
"""Test that login without 'next' redirects to dashboard."""
|
||||
response = client.post('/auth/login', data={
|
||||
'password': app_password
|
||||
}, follow_redirects=False)
|
||||
|
||||
assert response.status_code == 302
|
||||
# Should redirect to dashboard
|
||||
assert 'dashboard' in response.location or response.location == '/'
|
||||
225
app/tests/test_background_jobs.py
Normal file
225
app/tests/test_background_jobs.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Tests for background job execution and scheduler integration.
|
||||
|
||||
Tests the APScheduler integration, job queuing, and background scan execution.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from web.models import Scan
|
||||
from web.services.scan_service import ScanService
|
||||
from web.services.scheduler_service import SchedulerService
|
||||
|
||||
|
||||
class TestBackgroundJobs:
|
||||
"""Test suite for background job execution."""
|
||||
|
||||
def test_scheduler_initialization(self, app):
|
||||
"""Test that scheduler is initialized with Flask app."""
|
||||
assert hasattr(app, 'scheduler')
|
||||
assert app.scheduler is not None
|
||||
assert app.scheduler.scheduler is not None
|
||||
assert app.scheduler.scheduler.running
|
||||
|
||||
def test_queue_scan_job(self, app, db, sample_config_file):
|
||||
"""Test queuing a scan for background execution."""
|
||||
# Create a scan via service
|
||||
scan_service = ScanService(db)
|
||||
scan_id = scan_service.trigger_scan(
|
||||
config_file=sample_config_file,
|
||||
triggered_by='test',
|
||||
scheduler=app.scheduler
|
||||
)
|
||||
|
||||
# Verify scan was created
|
||||
scan = db.query(Scan).filter_by(id=scan_id).first()
|
||||
assert scan is not None
|
||||
assert scan.status == 'running'
|
||||
|
||||
# Verify job was queued (check scheduler has the job)
|
||||
job = app.scheduler.scheduler.get_job(f'scan_{scan_id}')
|
||||
assert job is not None
|
||||
assert job.id == f'scan_{scan_id}'
|
||||
|
||||
def test_trigger_scan_without_scheduler(self, db, sample_config_file):
|
||||
"""Test triggering scan without scheduler logs warning."""
|
||||
# Create scan without scheduler
|
||||
scan_service = ScanService(db)
|
||||
scan_id = scan_service.trigger_scan(
|
||||
config_file=sample_config_file,
|
||||
triggered_by='test',
|
||||
scheduler=None # No scheduler
|
||||
)
|
||||
|
||||
# Verify scan was created but not queued
|
||||
scan = db.query(Scan).filter_by(id=scan_id).first()
|
||||
assert scan is not None
|
||||
assert scan.status == 'running'
|
||||
|
||||
def test_scheduler_service_queue_scan(self, app, db, sample_config_file):
|
||||
"""Test SchedulerService.queue_scan directly."""
|
||||
# Create scan record first
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status='running',
|
||||
config_file=sample_config_file,
|
||||
title='Test Scan',
|
||||
triggered_by='test'
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
|
||||
# Queue the scan
|
||||
job_id = app.scheduler.queue_scan(scan.id, sample_config_file)
|
||||
|
||||
# Verify job was queued
|
||||
assert job_id == f'scan_{scan.id}'
|
||||
job = app.scheduler.scheduler.get_job(job_id)
|
||||
assert job is not None
|
||||
|
||||
def test_scheduler_list_jobs(self, app, db, sample_config_file):
|
||||
"""Test listing scheduled jobs."""
|
||||
# Queue a few scans
|
||||
for i in range(3):
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status='running',
|
||||
config_file=sample_config_file,
|
||||
title=f'Test Scan {i}',
|
||||
triggered_by='test'
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
app.scheduler.queue_scan(scan.id, sample_config_file)
|
||||
|
||||
# List jobs
|
||||
jobs = app.scheduler.list_jobs()
|
||||
|
||||
# Should have at least 3 jobs (might have more from other tests)
|
||||
assert len(jobs) >= 3
|
||||
|
||||
# Each job should have required fields
|
||||
for job in jobs:
|
||||
assert 'id' in job
|
||||
assert 'name' in job
|
||||
assert 'trigger' in job
|
||||
|
||||
def test_scheduler_get_job_status(self, app, db, sample_config_file):
|
||||
"""Test getting status of a specific job."""
|
||||
# Create and queue a scan
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status='running',
|
||||
config_file=sample_config_file,
|
||||
title='Test Scan',
|
||||
triggered_by='test'
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
|
||||
job_id = app.scheduler.queue_scan(scan.id, sample_config_file)
|
||||
|
||||
# Get job status
|
||||
status = app.scheduler.get_job_status(job_id)
|
||||
|
||||
assert status is not None
|
||||
assert status['id'] == job_id
|
||||
assert status['name'] == f'Scan {scan.id}'
|
||||
|
||||
def test_scheduler_get_nonexistent_job(self, app):
|
||||
"""Test getting status of non-existent job."""
|
||||
status = app.scheduler.get_job_status('nonexistent_job_id')
|
||||
assert status is None
|
||||
|
||||
def test_scan_timing_fields(self, db, sample_config_file):
|
||||
"""Test that scan timing fields are properly set."""
|
||||
# Create scan with started_at
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status='running',
|
||||
config_file=sample_config_file,
|
||||
title='Test Scan',
|
||||
triggered_by='test',
|
||||
started_at=datetime.utcnow()
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
|
||||
# Verify fields exist
|
||||
assert scan.started_at is not None
|
||||
assert scan.completed_at is None
|
||||
assert scan.error_message is None
|
||||
|
||||
# Update to completed
|
||||
scan.status = 'completed'
|
||||
scan.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
# Verify fields updated
|
||||
assert scan.completed_at is not None
|
||||
assert (scan.completed_at - scan.started_at).total_seconds() >= 0
|
||||
|
||||
def test_scan_error_handling(self, db, sample_config_file):
|
||||
"""Test that error messages are stored correctly."""
|
||||
# Create failed scan
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status='failed',
|
||||
config_file=sample_config_file,
|
||||
title='Failed Scan',
|
||||
triggered_by='test',
|
||||
started_at=datetime.utcnow(),
|
||||
completed_at=datetime.utcnow(),
|
||||
error_message='Test error message'
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
|
||||
# Verify error message stored
|
||||
assert scan.error_message == 'Test error message'
|
||||
|
||||
# Verify status query works
|
||||
scan_service = ScanService(db)
|
||||
status = scan_service.get_scan_status(scan.id)
|
||||
|
||||
assert status['status'] == 'failed'
|
||||
assert status['error_message'] == 'Test error message'
|
||||
|
||||
@pytest.mark.skip(reason="Requires actual scanner execution - slow test")
|
||||
def test_background_scan_execution(self, app, db, sample_config_file):
|
||||
"""
|
||||
Integration test for actual background scan execution.
|
||||
|
||||
This test is skipped by default because it actually runs the scanner,
|
||||
which requires privileged operations and takes time.
|
||||
|
||||
To run: pytest -v -k test_background_scan_execution --run-slow
|
||||
"""
|
||||
# Trigger scan
|
||||
scan_service = ScanService(db)
|
||||
scan_id = scan_service.trigger_scan(
|
||||
config_file=sample_config_file,
|
||||
triggered_by='test',
|
||||
scheduler=app.scheduler
|
||||
)
|
||||
|
||||
# Wait for scan to complete (with timeout)
|
||||
max_wait = 300 # 5 minutes
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait:
|
||||
scan = db.query(Scan).filter_by(id=scan_id).first()
|
||||
if scan.status in ['completed', 'failed']:
|
||||
break
|
||||
time.sleep(5)
|
||||
|
||||
# Verify scan completed
|
||||
scan = db.query(Scan).filter_by(id=scan_id).first()
|
||||
assert scan.status in ['completed', 'failed']
|
||||
|
||||
if scan.status == 'completed':
|
||||
assert scan.duration is not None
|
||||
assert scan.json_path is not None
|
||||
else:
|
||||
assert scan.error_message is not None
|
||||
483
app/tests/test_config_api.py
Normal file
483
app/tests/test_config_api.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
Integration tests for Config API endpoints.
|
||||
|
||||
Tests all config API endpoints including CSV/YAML upload, listing, downloading,
|
||||
and deletion with schedule protection.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
from web.app import create_app
|
||||
from web.models import Base
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create test application"""
|
||||
# Create temporary database
|
||||
test_db = tempfile.mktemp(suffix='.db')
|
||||
|
||||
# Create temporary configs directory
|
||||
temp_configs_dir = tempfile.mkdtemp()
|
||||
|
||||
app = create_app({
|
||||
'TESTING': True,
|
||||
'SQLALCHEMY_DATABASE_URI': f'sqlite:///{test_db}',
|
||||
'SECRET_KEY': 'test-secret-key',
|
||||
'WTF_CSRF_ENABLED': False,
|
||||
})
|
||||
|
||||
# Override configs directory in ConfigService
|
||||
os.environ['CONFIGS_DIR'] = temp_configs_dir
|
||||
|
||||
# Create tables
|
||||
with app.app_context():
|
||||
Base.metadata.create_all(bind=app.db_session.get_bind())
|
||||
|
||||
yield app
|
||||
|
||||
# Cleanup
|
||||
os.unlink(test_db)
|
||||
shutil.rmtree(temp_configs_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create test client"""
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers(client):
|
||||
"""Get authentication headers"""
|
||||
# First register and login a user
|
||||
from web.auth.models import User
|
||||
|
||||
with client.application.app_context():
|
||||
# Create test user
|
||||
user = User(username='testuser')
|
||||
user.set_password('testpass')
|
||||
client.application.db_session.add(user)
|
||||
client.application.db_session.commit()
|
||||
|
||||
# Login
|
||||
response = client.post('/auth/login', data={
|
||||
'username': 'testuser',
|
||||
'password': 'testpass'
|
||||
}, follow_redirects=True)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Return empty headers (session-based auth)
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csv():
|
||||
"""Sample CSV content"""
|
||||
return """scan_title,site_name,ip_address,ping_expected,tcp_ports,udp_ports,services
|
||||
Test Scan,Web Servers,10.10.20.4,true,"22,80,443",53,"ssh,http,https"
|
||||
Test Scan,Web Servers,10.10.20.5,true,22,,"ssh"
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_yaml():
|
||||
"""Sample YAML content"""
|
||||
return """title: Test Scan
|
||||
sites:
|
||||
- name: Web Servers
|
||||
ips:
|
||||
- address: 10.10.20.4
|
||||
expected:
|
||||
ping: true
|
||||
tcp_ports: [22, 80, 443]
|
||||
udp_ports: [53]
|
||||
services: [ssh, http, https]
|
||||
"""
|
||||
|
||||
|
||||
class TestListConfigs:
|
||||
"""Tests for GET /api/configs"""
|
||||
|
||||
def test_list_configs_empty(self, client, auth_headers):
|
||||
"""Test listing configs when none exist"""
|
||||
response = client.get('/api/configs', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.get_json()
|
||||
assert 'configs' in data
|
||||
assert data['configs'] == []
|
||||
|
||||
def test_list_configs_with_files(self, client, auth_headers, app, sample_yaml):
|
||||
"""Test listing configs with existing files"""
|
||||
# Create a config file
|
||||
temp_configs_dir = os.environ.get('CONFIGS_DIR', '/app/configs')
|
||||
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(sample_yaml)
|
||||
|
||||
response = client.get('/api/configs', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data['configs']) == 1
|
||||
assert data['configs'][0]['filename'] == 'test-scan.yaml'
|
||||
assert data['configs'][0]['title'] == 'Test Scan'
|
||||
assert 'created_at' in data['configs'][0]
|
||||
assert 'size_bytes' in data['configs'][0]
|
||||
assert 'used_by_schedules' in data['configs'][0]
|
||||
|
||||
def test_list_configs_requires_auth(self, client):
|
||||
"""Test that listing configs requires authentication"""
|
||||
response = client.get('/api/configs')
|
||||
assert response.status_code in [401, 302] # Unauthorized or redirect
|
||||
|
||||
|
||||
class TestGetConfig:
|
||||
"""Tests for GET /api/configs/<filename>"""
|
||||
|
||||
def test_get_config_valid(self, client, auth_headers, app, sample_yaml):
|
||||
"""Test getting a valid config file"""
|
||||
# Create a config file
|
||||
temp_configs_dir = os.environ.get('CONFIGS_DIR', '/app/configs')
|
||||
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(sample_yaml)
|
||||
|
||||
response = client.get('/api/configs/test-scan.yaml', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.get_json()
|
||||
assert data['filename'] == 'test-scan.yaml'
|
||||
assert 'content' in data
|
||||
assert 'parsed' in data
|
||||
assert data['parsed']['title'] == 'Test Scan'
|
||||
|
||||
def test_get_config_not_found(self, client, auth_headers):
|
||||
"""Test getting non-existent config"""
|
||||
response = client.get('/api/configs/nonexistent.yaml', headers=auth_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
data = response.get_json()
|
||||
assert 'error' in data
|
||||
|
||||
def test_get_config_requires_auth(self, client):
|
||||
"""Test that getting config requires authentication"""
|
||||
response = client.get('/api/configs/test.yaml')
|
||||
assert response.status_code in [401, 302]
|
||||
|
||||
|
||||
class TestUploadCSV:
|
||||
"""Tests for POST /api/configs/upload-csv"""
|
||||
|
||||
def test_upload_csv_valid(self, client, auth_headers, sample_csv):
|
||||
"""Test uploading valid CSV"""
|
||||
from io import BytesIO
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(sample_csv.encode('utf-8')), 'test.csv')
|
||||
}
|
||||
|
||||
response = client.post('/api/configs/upload-csv', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 200
|
||||
|
||||
result = response.get_json()
|
||||
assert result['success'] is True
|
||||
assert 'filename' in result
|
||||
assert result['filename'].endswith('.yaml')
|
||||
assert 'preview' in result
|
||||
|
||||
def test_upload_csv_no_file(self, client, auth_headers):
|
||||
"""Test uploading without file"""
|
||||
response = client.post('/api/configs/upload-csv', data={},
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 400
|
||||
|
||||
data = response.get_json()
|
||||
assert 'error' in data
|
||||
|
||||
def test_upload_csv_invalid_format(self, client, auth_headers):
|
||||
"""Test uploading invalid CSV"""
|
||||
from io import BytesIO
|
||||
|
||||
invalid_csv = "not,a,valid,csv\nmissing,columns"
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(invalid_csv.encode('utf-8')), 'test.csv')
|
||||
}
|
||||
|
||||
response = client.post('/api/configs/upload-csv', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 400
|
||||
|
||||
result = response.get_json()
|
||||
assert 'error' in result
|
||||
|
||||
def test_upload_csv_wrong_extension(self, client, auth_headers):
|
||||
"""Test uploading file with wrong extension"""
|
||||
from io import BytesIO
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(b'test'), 'test.txt')
|
||||
}
|
||||
|
||||
response = client.post('/api/configs/upload-csv', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_upload_csv_duplicate_filename(self, client, auth_headers, sample_csv):
|
||||
"""Test uploading CSV that generates duplicate filename"""
|
||||
from io import BytesIO
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(sample_csv.encode('utf-8')), 'test.csv')
|
||||
}
|
||||
|
||||
# Upload first time
|
||||
response1 = client.post('/api/configs/upload-csv', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response1.status_code == 200
|
||||
|
||||
# Upload second time (should fail)
|
||||
response2 = client.post('/api/configs/upload-csv', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response2.status_code == 400
|
||||
|
||||
def test_upload_csv_requires_auth(self, client, sample_csv):
|
||||
"""Test that uploading CSV requires authentication"""
|
||||
from io import BytesIO
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(sample_csv.encode('utf-8')), 'test.csv')
|
||||
}
|
||||
|
||||
response = client.post('/api/configs/upload-csv', data=data,
|
||||
content_type='multipart/form-data')
|
||||
assert response.status_code in [401, 302]
|
||||
|
||||
|
||||
class TestUploadYAML:
|
||||
"""Tests for POST /api/configs/upload-yaml"""
|
||||
|
||||
def test_upload_yaml_valid(self, client, auth_headers, sample_yaml):
|
||||
"""Test uploading valid YAML"""
|
||||
from io import BytesIO
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(sample_yaml.encode('utf-8')), 'test.yaml')
|
||||
}
|
||||
|
||||
response = client.post('/api/configs/upload-yaml', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 200
|
||||
|
||||
result = response.get_json()
|
||||
assert result['success'] is True
|
||||
assert 'filename' in result
|
||||
|
||||
def test_upload_yaml_no_file(self, client, auth_headers):
|
||||
"""Test uploading without file"""
|
||||
response = client.post('/api/configs/upload-yaml', data={},
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_upload_yaml_invalid_syntax(self, client, auth_headers):
|
||||
"""Test uploading YAML with invalid syntax"""
|
||||
from io import BytesIO
|
||||
|
||||
invalid_yaml = "invalid: yaml: syntax: ["
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(invalid_yaml.encode('utf-8')), 'test.yaml')
|
||||
}
|
||||
|
||||
response = client.post('/api/configs/upload-yaml', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_upload_yaml_missing_required_fields(self, client, auth_headers):
|
||||
"""Test uploading YAML missing required fields"""
|
||||
from io import BytesIO
|
||||
|
||||
invalid_yaml = """sites:
|
||||
- name: Test
|
||||
ips:
|
||||
- address: 10.0.0.1
|
||||
"""
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(invalid_yaml.encode('utf-8')), 'test.yaml')
|
||||
}
|
||||
|
||||
response = client.post('/api/configs/upload-yaml', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_upload_yaml_wrong_extension(self, client, auth_headers):
|
||||
"""Test uploading file with wrong extension"""
|
||||
from io import BytesIO
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(b'test'), 'test.txt')
|
||||
}
|
||||
|
||||
response = client.post('/api/configs/upload-yaml', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_upload_yaml_requires_auth(self, client, sample_yaml):
|
||||
"""Test that uploading YAML requires authentication"""
|
||||
from io import BytesIO
|
||||
|
||||
data = {
|
||||
'file': (BytesIO(sample_yaml.encode('utf-8')), 'test.yaml')
|
||||
}
|
||||
|
||||
response = client.post('/api/configs/upload-yaml', data=data,
|
||||
content_type='multipart/form-data')
|
||||
assert response.status_code in [401, 302]
|
||||
|
||||
|
||||
class TestDownloadTemplate:
|
||||
"""Tests for GET /api/configs/template"""
|
||||
|
||||
def test_download_template(self, client, auth_headers):
|
||||
"""Test downloading CSV template"""
|
||||
response = client.get('/api/configs/template', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.content_type == 'text/csv; charset=utf-8'
|
||||
assert b'scan_title,site_name,ip_address' in response.data
|
||||
|
||||
def test_download_template_requires_auth(self, client):
|
||||
"""Test that downloading template requires authentication"""
|
||||
response = client.get('/api/configs/template')
|
||||
assert response.status_code in [401, 302]
|
||||
|
||||
|
||||
class TestDownloadConfig:
|
||||
"""Tests for GET /api/configs/<filename>/download"""
|
||||
|
||||
def test_download_config_valid(self, client, auth_headers, app, sample_yaml):
|
||||
"""Test downloading existing config"""
|
||||
# Create a config file
|
||||
temp_configs_dir = os.environ.get('CONFIGS_DIR', '/app/configs')
|
||||
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(sample_yaml)
|
||||
|
||||
response = client.get('/api/configs/test-scan.yaml/download', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.content_type == 'application/x-yaml; charset=utf-8'
|
||||
assert b'title: Test Scan' in response.data
|
||||
|
||||
def test_download_config_not_found(self, client, auth_headers):
|
||||
"""Test downloading non-existent config"""
|
||||
response = client.get('/api/configs/nonexistent.yaml/download', headers=auth_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_download_config_requires_auth(self, client):
|
||||
"""Test that downloading config requires authentication"""
|
||||
response = client.get('/api/configs/test.yaml/download')
|
||||
assert response.status_code in [401, 302]
|
||||
|
||||
|
||||
class TestDeleteConfig:
|
||||
"""Tests for DELETE /api/configs/<filename>"""
|
||||
|
||||
def test_delete_config_valid(self, client, auth_headers, app, sample_yaml):
|
||||
"""Test deleting a config file"""
|
||||
# Create a config file
|
||||
temp_configs_dir = os.environ.get('CONFIGS_DIR', '/app/configs')
|
||||
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(sample_yaml)
|
||||
|
||||
response = client.delete('/api/configs/test-scan.yaml', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.get_json()
|
||||
assert data['success'] is True
|
||||
|
||||
# Verify file is deleted
|
||||
assert not os.path.exists(config_path)
|
||||
|
||||
def test_delete_config_not_found(self, client, auth_headers):
|
||||
"""Test deleting non-existent config"""
|
||||
response = client.delete('/api/configs/nonexistent.yaml', headers=auth_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_delete_config_requires_auth(self, client):
|
||||
"""Test that deleting config requires authentication"""
|
||||
response = client.delete('/api/configs/test.yaml')
|
||||
assert response.status_code in [401, 302]
|
||||
|
||||
|
||||
class TestEndToEndWorkflow:
|
||||
"""End-to-end workflow tests"""
|
||||
|
||||
def test_complete_csv_workflow(self, client, auth_headers, sample_csv):
|
||||
"""Test complete CSV upload workflow"""
|
||||
from io import BytesIO
|
||||
|
||||
# 1. Download template
|
||||
response = client.get('/api/configs/template', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
# 2. Upload CSV
|
||||
data = {
|
||||
'file': (BytesIO(sample_csv.encode('utf-8')), 'workflow-test.csv')
|
||||
}
|
||||
response = client.post('/api/configs/upload-csv', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 200
|
||||
result = response.get_json()
|
||||
filename = result['filename']
|
||||
|
||||
# 3. List configs (should include new one)
|
||||
response = client.get('/api/configs', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
configs = response.get_json()['configs']
|
||||
assert any(c['filename'] == filename for c in configs)
|
||||
|
||||
# 4. Get config details
|
||||
response = client.get(f'/api/configs/{filename}', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
# 5. Download config
|
||||
response = client.get(f'/api/configs/{filename}/download', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
# 6. Delete config
|
||||
response = client.delete(f'/api/configs/{filename}', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
# 7. Verify deletion
|
||||
response = client.get(f'/api/configs/{filename}', headers=auth_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_yaml_upload_workflow(self, client, auth_headers, sample_yaml):
|
||||
"""Test YAML upload workflow"""
|
||||
from io import BytesIO
|
||||
|
||||
# Upload YAML
|
||||
data = {
|
||||
'file': (BytesIO(sample_yaml.encode('utf-8')), 'yaml-workflow.yaml')
|
||||
}
|
||||
response = client.post('/api/configs/upload-yaml', data=data,
|
||||
headers=auth_headers, content_type='multipart/form-data')
|
||||
assert response.status_code == 200
|
||||
|
||||
filename = response.get_json()['filename']
|
||||
|
||||
# Verify it exists
|
||||
response = client.get(f'/api/configs/{filename}', headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Clean up
|
||||
client.delete(f'/api/configs/{filename}', headers=auth_headers)
|
||||
545
app/tests/test_config_service.py
Normal file
545
app/tests/test_config_service.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""
|
||||
Unit tests for Config Service
|
||||
|
||||
Tests the ConfigService class which manages scan configuration files.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import yaml
|
||||
import tempfile
|
||||
import shutil
|
||||
from web.services.config_service import ConfigService
|
||||
|
||||
|
||||
class TestConfigService:
|
||||
"""Test suite for ConfigService"""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_configs_dir(self):
|
||||
"""Create a temporary directory for config files"""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
yield temp_dir
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, temp_configs_dir):
|
||||
"""Create a ConfigService instance with temp directory"""
|
||||
return ConfigService(configs_dir=temp_configs_dir)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_yaml_config(self):
|
||||
"""Sample YAML config content"""
|
||||
return """title: Test Scan
|
||||
sites:
|
||||
- name: Web Servers
|
||||
ips:
|
||||
- address: 10.10.20.4
|
||||
expected:
|
||||
ping: true
|
||||
tcp_ports: [22, 80, 443]
|
||||
udp_ports: [53]
|
||||
services: [ssh, http, https]
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csv_content(self):
|
||||
"""Sample CSV content"""
|
||||
return """scan_title,site_name,ip_address,ping_expected,tcp_ports,udp_ports,services
|
||||
Test Scan,Web Servers,10.10.20.4,true,"22,80,443",53,"ssh,http,https"
|
||||
Test Scan,Web Servers,10.10.20.5,true,22,,"ssh"
|
||||
"""
|
||||
|
||||
def test_list_configs_empty_directory(self, service):
|
||||
"""Test listing configs when directory is empty"""
|
||||
configs = service.list_configs()
|
||||
assert configs == []
|
||||
|
||||
def test_list_configs_with_files(self, service, temp_configs_dir, sample_yaml_config):
|
||||
"""Test listing configs with existing files"""
|
||||
# Create a config file
|
||||
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(sample_yaml_config)
|
||||
|
||||
configs = service.list_configs()
|
||||
assert len(configs) == 1
|
||||
assert configs[0]['filename'] == 'test-scan.yaml'
|
||||
assert configs[0]['title'] == 'Test Scan'
|
||||
assert 'created_at' in configs[0]
|
||||
assert 'size_bytes' in configs[0]
|
||||
assert 'used_by_schedules' in configs[0]
|
||||
|
||||
def test_list_configs_ignores_non_yaml_files(self, service, temp_configs_dir):
|
||||
"""Test that non-YAML files are ignored"""
|
||||
# Create non-YAML files
|
||||
with open(os.path.join(temp_configs_dir, 'test.txt'), 'w') as f:
|
||||
f.write('not a yaml file')
|
||||
with open(os.path.join(temp_configs_dir, 'readme.md'), 'w') as f:
|
||||
f.write('# README')
|
||||
|
||||
configs = service.list_configs()
|
||||
assert len(configs) == 0
|
||||
|
||||
def test_get_config_valid(self, service, temp_configs_dir, sample_yaml_config):
|
||||
"""Test getting a valid config file"""
|
||||
# Create a config file
|
||||
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(sample_yaml_config)
|
||||
|
||||
result = service.get_config('test-scan.yaml')
|
||||
assert result['filename'] == 'test-scan.yaml'
|
||||
assert 'content' in result
|
||||
assert 'parsed' in result
|
||||
assert result['parsed']['title'] == 'Test Scan'
|
||||
assert len(result['parsed']['sites']) == 1
|
||||
|
||||
def test_get_config_not_found(self, service):
|
||||
"""Test getting a non-existent config"""
|
||||
with pytest.raises(FileNotFoundError, match="not found"):
|
||||
service.get_config('nonexistent.yaml')
|
||||
|
||||
def test_get_config_invalid_yaml(self, service, temp_configs_dir):
|
||||
"""Test getting a config with invalid YAML syntax"""
|
||||
# Create invalid YAML file
|
||||
config_path = os.path.join(temp_configs_dir, 'invalid.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write("invalid: yaml: syntax: [")
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid YAML syntax"):
|
||||
service.get_config('invalid.yaml')
|
||||
|
||||
def test_create_from_yaml_valid(self, service, sample_yaml_config):
|
||||
"""Test creating config from valid YAML"""
|
||||
filename = service.create_from_yaml('test-scan.yaml', sample_yaml_config)
|
||||
assert filename == 'test-scan.yaml'
|
||||
assert service.config_exists('test-scan.yaml')
|
||||
|
||||
# Verify content
|
||||
result = service.get_config('test-scan.yaml')
|
||||
assert result['parsed']['title'] == 'Test Scan'
|
||||
|
||||
def test_create_from_yaml_adds_extension(self, service, sample_yaml_config):
|
||||
"""Test that .yaml extension is added if missing"""
|
||||
filename = service.create_from_yaml('test-scan', sample_yaml_config)
|
||||
assert filename == 'test-scan.yaml'
|
||||
assert service.config_exists('test-scan.yaml')
|
||||
|
||||
def test_create_from_yaml_sanitizes_filename(self, service, sample_yaml_config):
|
||||
"""Test that filename is sanitized"""
|
||||
filename = service.create_from_yaml('../../../etc/test.yaml', sample_yaml_config)
|
||||
# secure_filename should remove path traversal
|
||||
assert '..' not in filename
|
||||
assert '/' not in filename
|
||||
|
||||
def test_create_from_yaml_duplicate_filename(self, service, temp_configs_dir, sample_yaml_config):
|
||||
"""Test creating config with duplicate filename"""
|
||||
# Create first config
|
||||
service.create_from_yaml('test-scan.yaml', sample_yaml_config)
|
||||
|
||||
# Try to create duplicate
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
service.create_from_yaml('test-scan.yaml', sample_yaml_config)
|
||||
|
||||
def test_create_from_yaml_invalid_syntax(self, service):
|
||||
"""Test creating config with invalid YAML syntax"""
|
||||
invalid_yaml = "invalid: yaml: syntax: ["
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid YAML syntax"):
|
||||
service.create_from_yaml('test.yaml', invalid_yaml)
|
||||
|
||||
def test_create_from_yaml_invalid_structure(self, service):
|
||||
"""Test creating config with invalid structure (missing title)"""
|
||||
invalid_config = """sites:
|
||||
- name: Test
|
||||
ips:
|
||||
- address: 10.0.0.1
|
||||
expected:
|
||||
ping: true
|
||||
"""
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required field: 'title'"):
|
||||
service.create_from_yaml('test.yaml', invalid_config)
|
||||
|
||||
def test_create_from_csv_valid(self, service, sample_csv_content):
|
||||
"""Test creating config from valid CSV"""
|
||||
filename, yaml_content = service.create_from_csv(sample_csv_content)
|
||||
assert filename == 'test-scan.yaml'
|
||||
assert service.config_exists(filename)
|
||||
|
||||
# Verify YAML was created correctly
|
||||
result = service.get_config(filename)
|
||||
assert result['parsed']['title'] == 'Test Scan'
|
||||
assert len(result['parsed']['sites']) == 1
|
||||
assert len(result['parsed']['sites'][0]['ips']) == 2
|
||||
|
||||
def test_create_from_csv_with_suggested_filename(self, service, sample_csv_content):
|
||||
"""Test creating config with suggested filename"""
|
||||
filename, yaml_content = service.create_from_csv(sample_csv_content, 'custom-name.yaml')
|
||||
assert filename == 'custom-name.yaml'
|
||||
assert service.config_exists(filename)
|
||||
|
||||
def test_create_from_csv_invalid(self, service):
|
||||
"""Test creating config from invalid CSV"""
|
||||
invalid_csv = """scan_title,site_name,ip_address
|
||||
Missing,Columns,Here
|
||||
"""
|
||||
|
||||
with pytest.raises(ValueError, match="CSV parsing failed"):
|
||||
service.create_from_csv(invalid_csv)
|
||||
|
||||
def test_create_from_csv_duplicate_filename(self, service, sample_csv_content):
|
||||
"""Test creating CSV config with duplicate filename"""
|
||||
# Create first config
|
||||
service.create_from_csv(sample_csv_content)
|
||||
|
||||
# Try to create duplicate (same title generates same filename)
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
service.create_from_csv(sample_csv_content)
|
||||
|
||||
def test_delete_config_valid(self, service, temp_configs_dir, sample_yaml_config):
|
||||
"""Test deleting a config file"""
|
||||
# Create a config file
|
||||
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(sample_yaml_config)
|
||||
|
||||
assert service.config_exists('test-scan.yaml')
|
||||
service.delete_config('test-scan.yaml')
|
||||
assert not service.config_exists('test-scan.yaml')
|
||||
|
||||
def test_delete_config_not_found(self, service):
|
||||
"""Test deleting non-existent config"""
|
||||
with pytest.raises(FileNotFoundError, match="not found"):
|
||||
service.delete_config('nonexistent.yaml')
|
||||
|
||||
    def test_delete_config_used_by_schedule(self, service, temp_configs_dir, sample_yaml_config, monkeypatch):
        """Test deleting config that is used by schedules - should cascade delete schedules"""
        # Create a config file
        config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
        with open(config_path, 'w') as f:
            f.write(sample_yaml_config)

        # Mock schedule service interactions
        deleted_schedule_ids = []

        class MockScheduleService:
            def __init__(self, db):
                self.db = db

            def list_schedules(self, page=1, per_page=10000):
                return {
                    'schedules': [
                        {
                            'id': 1,
                            'name': 'Daily Scan',
                            'config_file': 'test-scan.yaml',
                            'enabled': True
                        },
                        {
                            'id': 2,
                            'name': 'Weekly Audit',
                            'config_file': 'test-scan.yaml',
                            'enabled': False  # Disabled schedule should also be deleted
                        }
                    ]
                }

            def delete_schedule(self, schedule_id):
                deleted_schedule_ids.append(schedule_id)
                return True

        # Mock the ScheduleService import
        import sys
        from unittest.mock import MagicMock
        mock_module = MagicMock()
        mock_module.ScheduleService = MockScheduleService
        monkeypatch.setitem(sys.modules, 'web.services.schedule_service', mock_module)

        # Mock current_app
        mock_app = MagicMock()
        mock_app.db_session = MagicMock()

        import flask
        monkeypatch.setattr(flask, 'current_app', mock_app)

        # Delete the config - should cascade delete associated schedules
        service.delete_config('test-scan.yaml')

        # Config should be deleted
        assert not service.config_exists('test-scan.yaml')

        # Both schedules (enabled and disabled) should be deleted
        assert deleted_schedule_ids == [1, 2]

def test_validate_config_content_valid(self, service):
|
||||
"""Test validating valid config content"""
|
||||
valid_config = {
|
||||
'title': 'Test Scan',
|
||||
'sites': [
|
||||
{
|
||||
'name': 'Web Servers',
|
||||
'ips': [
|
||||
{
|
||||
'address': '10.10.20.4',
|
||||
'expected': {
|
||||
'ping': True,
|
||||
'tcp_ports': [22, 80, 443],
|
||||
'udp_ports': [53]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(valid_config)
|
||||
assert is_valid is True
|
||||
assert error == ""
|
||||
|
||||
def test_validate_config_content_not_dict(self, service):
|
||||
"""Test validating non-dict content"""
|
||||
is_valid, error = service.validate_config_content(['not', 'a', 'dict'])
|
||||
assert is_valid is False
|
||||
assert 'must be a dictionary' in error
|
||||
|
||||
def test_validate_config_content_missing_title(self, service):
|
||||
"""Test validating config without title"""
|
||||
config = {
|
||||
'sites': []
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "Missing required field: 'title'" in error
|
||||
|
||||
def test_validate_config_content_missing_sites(self, service):
|
||||
"""Test validating config without sites"""
|
||||
config = {
|
||||
'title': 'Test'
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "Missing required field: 'sites'" in error
|
||||
|
||||
def test_validate_config_content_empty_title(self, service):
|
||||
"""Test validating config with empty title"""
|
||||
config = {
|
||||
'title': '',
|
||||
'sites': []
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "non-empty string" in error
|
||||
|
||||
def test_validate_config_content_sites_not_list(self, service):
|
||||
"""Test validating config with sites as non-list"""
|
||||
config = {
|
||||
'title': 'Test',
|
||||
'sites': 'not a list'
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "must be a list" in error
|
||||
|
||||
def test_validate_config_content_no_sites(self, service):
|
||||
"""Test validating config with empty sites list"""
|
||||
config = {
|
||||
'title': 'Test',
|
||||
'sites': []
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "at least one site" in error
|
||||
|
||||
def test_validate_config_content_site_missing_name(self, service):
|
||||
"""Test validating site without name"""
|
||||
config = {
|
||||
'title': 'Test',
|
||||
'sites': [
|
||||
{
|
||||
'ips': []
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "missing required field: 'name'" in error
|
||||
|
||||
def test_validate_config_content_site_missing_ips(self, service):
|
||||
"""Test validating site without ips"""
|
||||
config = {
|
||||
'title': 'Test',
|
||||
'sites': [
|
||||
{
|
||||
'name': 'Test Site'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "missing required field: 'ips'" in error
|
||||
|
||||
def test_validate_config_content_site_no_ips(self, service):
|
||||
"""Test validating site with empty ips list"""
|
||||
config = {
|
||||
'title': 'Test',
|
||||
'sites': [
|
||||
{
|
||||
'name': 'Test Site',
|
||||
'ips': []
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "at least one IP" in error
|
||||
|
||||
def test_validate_config_content_ip_missing_address(self, service):
|
||||
"""Test validating IP without address"""
|
||||
config = {
|
||||
'title': 'Test',
|
||||
'sites': [
|
||||
{
|
||||
'name': 'Test Site',
|
||||
'ips': [
|
||||
{
|
||||
'expected': {}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "missing required field: 'address'" in error
|
||||
|
||||
def test_validate_config_content_ip_missing_expected(self, service):
|
||||
"""Test validating IP without expected"""
|
||||
config = {
|
||||
'title': 'Test',
|
||||
'sites': [
|
||||
{
|
||||
'name': 'Test Site',
|
||||
'ips': [
|
||||
{
|
||||
'address': '10.0.0.1'
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
is_valid, error = service.validate_config_content(config)
|
||||
assert is_valid is False
|
||||
assert "missing required field: 'expected'" in error
|
||||
|
||||
def test_generate_filename_from_title_simple(self, service):
|
||||
"""Test generating filename from simple title"""
|
||||
filename = service.generate_filename_from_title('Production Scan')
|
||||
assert filename == 'production-scan.yaml'
|
||||
|
||||
def test_generate_filename_from_title_special_chars(self, service):
|
||||
"""Test generating filename with special characters"""
|
||||
filename = service.generate_filename_from_title('Prod Scan (2025)!')
|
||||
assert filename == 'prod-scan-2025.yaml'
|
||||
assert '(' not in filename
|
||||
assert ')' not in filename
|
||||
assert '!' not in filename
|
||||
|
||||
def test_generate_filename_from_title_multiple_spaces(self, service):
|
||||
"""Test generating filename with multiple spaces"""
|
||||
filename = service.generate_filename_from_title('Test Multiple Spaces')
|
||||
assert filename == 'test-multiple-spaces.yaml'
|
||||
# Should not have consecutive hyphens
|
||||
assert '--' not in filename
|
||||
|
||||
def test_generate_filename_from_title_leading_trailing_spaces(self, service):
|
||||
"""Test generating filename with leading/trailing spaces"""
|
||||
filename = service.generate_filename_from_title(' Test Scan ')
|
||||
assert filename == 'test-scan.yaml'
|
||||
assert not filename.startswith('-')
|
||||
assert not filename.endswith('-.yaml')
|
||||
|
||||
def test_generate_filename_from_title_long(self, service):
|
||||
"""Test generating filename from long title"""
|
||||
long_title = 'A' * 300
|
||||
filename = service.generate_filename_from_title(long_title)
|
||||
# Should be limited to 200 chars (195 + .yaml)
|
||||
assert len(filename) <= 200
|
||||
|
||||
def test_generate_filename_from_title_empty(self, service):
|
||||
"""Test generating filename from empty title"""
|
||||
filename = service.generate_filename_from_title('')
|
||||
assert filename == 'config.yaml'
|
||||
|
||||
def test_generate_filename_from_title_only_special_chars(self, service):
|
||||
"""Test generating filename from title with only special characters"""
|
||||
filename = service.generate_filename_from_title('!@#$%^&*()')
|
||||
assert filename == 'config.yaml'
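For reference, the behaviour these filename tests pin down (lowercasing, hyphen-collapsing, trimming, a 195-character cap, and a 'config' fallback) can be satisfied by a small slug helper along the following lines. This is an illustrative sketch, not code from this commit, and the function name is only assumed to match the service method.

import re

def generate_filename_from_title(title: str) -> str:
    # Lowercase, replace runs of non-alphanumerics with single hyphens,
    # and trim leading/trailing hyphens.
    slug = re.sub(r'[^a-z0-9]+', '-', title.lower()).strip('-')
    # Cap the slug at 195 chars so the full name stays within 200,
    # falling back to 'config' when nothing usable remains.
    slug = slug[:195] if slug else 'config'
    return f'{slug}.yaml'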
|
||||
|
||||
def test_get_config_path(self, service, temp_configs_dir):
|
||||
"""Test getting config path"""
|
||||
path = service.get_config_path('test.yaml')
|
||||
assert path == os.path.join(temp_configs_dir, 'test.yaml')
|
||||
|
||||
def test_config_exists_true(self, service, temp_configs_dir, sample_yaml_config):
|
||||
"""Test config_exists returns True for existing file"""
|
||||
config_path = os.path.join(temp_configs_dir, 'test-scan.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(sample_yaml_config)
|
||||
|
||||
assert service.config_exists('test-scan.yaml') is True
|
||||
|
||||
def test_config_exists_false(self, service):
|
||||
"""Test config_exists returns False for non-existent file"""
|
||||
assert service.config_exists('nonexistent.yaml') is False
|
||||
|
||||
def test_get_schedules_using_config_none(self, service):
|
||||
"""Test getting schedules when none use the config"""
|
||||
schedules = service.get_schedules_using_config('test.yaml')
|
||||
# Should return empty list (ScheduleService might not exist in test env)
|
||||
assert isinstance(schedules, list)
|
||||
|
||||
def test_list_configs_sorted_by_date(self, service, temp_configs_dir, sample_yaml_config):
|
||||
"""Test that configs are sorted by creation date (most recent first)"""
|
||||
import time
|
||||
|
||||
# Create first config
|
||||
config1_path = os.path.join(temp_configs_dir, 'config1.yaml')
|
||||
with open(config1_path, 'w') as f:
|
||||
f.write(sample_yaml_config)
|
||||
|
||||
time.sleep(0.1) # Ensure different timestamps
|
||||
|
||||
# Create second config
|
||||
config2_path = os.path.join(temp_configs_dir, 'config2.yaml')
|
||||
with open(config2_path, 'w') as f:
|
||||
f.write(sample_yaml_config)
|
||||
|
||||
configs = service.list_configs()
|
||||
assert len(configs) == 2
|
||||
# Most recent should be first
|
||||
assert configs[0]['filename'] == 'config2.yaml'
|
||||
assert configs[1]['filename'] == 'config1.yaml'
|
||||
|
||||
def test_list_configs_handles_parse_errors(self, service, temp_configs_dir):
|
||||
"""Test that list_configs handles files that can't be parsed"""
|
||||
# Create invalid YAML file
|
||||
config_path = os.path.join(temp_configs_dir, 'invalid.yaml')
|
||||
with open(config_path, 'w') as f:
|
||||
f.write("invalid: yaml: [")
|
||||
|
||||
# Should not raise error, just use filename as title
|
||||
configs = service.list_configs()
|
||||
assert len(configs) == 1
|
||||
assert configs[0]['filename'] == 'invalid.yaml'
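The tests above rely on fixtures (service, temp_configs_dir, sample_yaml_config, sample_csv_content) defined elsewhere in the suite. A minimal sketch of what two of them might look like, assuming the service under test is constructed with a configs directory (the class name and constructor signature are assumptions, not taken from this commit):

import pytest

@pytest.fixture
def temp_configs_dir(tmp_path):
    # Isolated per-test directory standing in for /app/configs
    configs_dir = tmp_path / 'configs'
    configs_dir.mkdir()
    return str(configs_dir)

@pytest.fixture
def service(temp_configs_dir):
    # Assumes a ConfigService that accepts the configs directory.
    from web.services.config_service import ConfigService
    return ConfigService(temp_configs_dir)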
267
app/tests/test_error_handling.py
Normal file
267
app/tests/test_error_handling.py
Normal file
@@ -0,0 +1,267 @@
"""
Tests for error handling and logging functionality.

Tests error handlers, request/response logging, database rollback on errors,
and proper error responses (JSON vs HTML).
"""

import json
import logging
import pytest
from flask import Flask
from sqlalchemy.exc import SQLAlchemyError

from web.app import create_app


@pytest.fixture
def app():
    """Create test Flask app."""
    test_config = {
        'TESTING': True,
        'SQLALCHEMY_DATABASE_URI': 'sqlite:///:memory:',
        'SECRET_KEY': 'test-secret-key',
        'WTF_CSRF_ENABLED': False
    }
    app = create_app(test_config)
    return app


@pytest.fixture
def client(app):
    """Create test client."""
    return app.test_client()


class TestErrorHandlers:
    """Test error handler functionality."""

def test_404_json_response(self, client):
|
||||
"""Test 404 error returns JSON for API requests."""
|
||||
response = client.get('/api/nonexistent')
|
||||
assert response.status_code == 404
|
||||
assert response.content_type == 'application/json'
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert data['error'] == 'Not Found'
|
||||
assert 'message' in data
|
||||
|
||||
def test_404_html_response(self, client):
|
||||
"""Test 404 error returns HTML for web requests."""
|
||||
response = client.get('/nonexistent')
|
||||
assert response.status_code == 404
|
||||
assert 'text/html' in response.content_type
|
||||
assert b'404' in response.data
|
||||
|
||||
def test_400_json_response(self, client):
|
||||
"""Test 400 error returns JSON for API requests."""
|
||||
# Trigger 400 by sending invalid JSON
|
||||
response = client.post(
|
||||
'/api/scans',
|
||||
data='invalid json',
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code in [400, 401] # 401 if auth required
|
||||
|
||||
def test_405_method_not_allowed(self, client):
|
||||
"""Test 405 error for method not allowed."""
|
||||
# Try POST to health check (only GET allowed)
|
||||
response = client.post('/api/scans/health')
|
||||
assert response.status_code == 405
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert data['error'] == 'Method Not Allowed'
|
||||
|
||||
def test_json_accept_header(self, client):
|
||||
"""Test JSON response when Accept header specifies JSON."""
|
||||
response = client.get(
|
||||
'/nonexistent',
|
||||
headers={'Accept': 'application/json'}
|
||||
)
|
||||
assert response.status_code == 404
|
||||
assert response.content_type == 'application/json'
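These handler tests assume the app chooses between JSON and HTML error bodies based on the request path and Accept header. A hedged sketch of how such a 404 handler is commonly written in Flask (not necessarily this app's actual implementation):

from flask import Flask, jsonify, render_template, request

def register_error_handlers(app: Flask) -> None:
    @app.errorhandler(404)
    def handle_404(error):
        # API callers (path prefix or Accept header) get JSON; browsers get HTML.
        wants_json = request.path.startswith('/api/') or \
            request.accept_mimetypes.best == 'application/json'
        if wants_json:
            return jsonify(error='Not Found', message=str(error)), 404
        return render_template('errors/404.html'), 404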
|
||||
|
||||
|
||||
class TestLogging:
|
||||
"""Test logging functionality."""
|
||||
|
||||
def test_request_logging(self, client, caplog):
|
||||
"""Test that requests are logged."""
|
||||
with caplog.at_level(logging.INFO):
|
||||
response = client.get('/api/scans/health')
|
||||
|
||||
# Check log messages
|
||||
log_messages = [record.message for record in caplog.records]
|
||||
# Should log incoming request and response
|
||||
assert any('GET /api/scans/health' in msg for msg in log_messages)
|
||||
|
||||
def test_error_logging(self, client, caplog):
|
||||
"""Test that errors are logged with full context."""
|
||||
with caplog.at_level(logging.INFO):
|
||||
client.get('/api/nonexistent')
|
||||
|
||||
# Check that 404 was logged
|
||||
log_messages = [record.message for record in caplog.records]
|
||||
assert any('not found' in msg.lower() or '404' in msg for msg in log_messages)
|
||||
|
||||
def test_request_id_in_logs(self, client, caplog):
|
||||
"""Test that request ID is included in log records."""
|
||||
with caplog.at_level(logging.INFO):
|
||||
client.get('/api/scans/health')
|
||||
|
||||
# Check that log records have request_id attribute
|
||||
for record in caplog.records:
|
||||
assert hasattr(record, 'request_id')
|
||||
assert record.request_id # Should not be empty
|
||||
|
||||
|
||||
class TestRequestResponseHandlers:
|
||||
"""Test request and response handler middleware."""
|
||||
|
||||
def test_request_id_header(self, client):
|
||||
"""Test that response includes X-Request-ID header for API requests."""
|
||||
response = client.get('/api/scans/health')
|
||||
assert 'X-Request-ID' in response.headers
|
||||
|
||||
def test_request_duration_header(self, client):
|
||||
"""Test that response includes X-Request-Duration-Ms header."""
|
||||
response = client.get('/api/scans/health')
|
||||
assert 'X-Request-Duration-Ms' in response.headers
|
||||
|
||||
duration = float(response.headers['X-Request-Duration-Ms'])
|
||||
assert duration >= 0 # Should be non-negative
|
||||
|
||||
def test_security_headers(self, client):
|
||||
"""Test that security headers are added to API responses."""
|
||||
response = client.get('/api/scans/health')
|
||||
|
||||
# Check security headers
|
||||
assert response.headers.get('X-Content-Type-Options') == 'nosniff'
|
||||
assert response.headers.get('X-Frame-Options') == 'DENY'
|
||||
assert response.headers.get('X-XSS-Protection') == '1; mode=block'
|
||||
|
||||
def test_request_timing(self, client):
|
||||
"""Test that request timing is calculated correctly."""
|
||||
response = client.get('/api/scans/health')
|
||||
|
||||
duration_header = response.headers.get('X-Request-Duration-Ms')
|
||||
assert duration_header is not None
|
||||
|
||||
duration = float(duration_header)
|
||||
# Should complete in reasonable time (less than 5 seconds)
|
||||
assert duration < 5000
|
||||
|
||||
|
||||
class TestDatabaseErrorHandling:
|
||||
"""Test database error handling and rollback."""
|
||||
|
||||
def test_database_rollback_on_error(self, app):
|
||||
"""Test that database session is rolled back on error."""
|
||||
# This test would require triggering a database error
|
||||
# For now, just verify the error handler is registered
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
# Check that SQLAlchemyError handler is registered
|
||||
assert SQLAlchemyError in app.error_handler_spec[None]
|
||||
|
||||
|
||||
class TestLogRotation:
|
||||
"""Test log rotation configuration."""
|
||||
|
||||
def test_log_files_created(self, app, tmp_path):
|
||||
"""Test that log files are created in logs directory."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Check that logs directory exists
|
||||
log_dir = Path('logs')
|
||||
# Note: In test environment, logs may not be created immediately
|
||||
# Just verify the configuration is set up
|
||||
|
||||
# Verify app logger has handlers
|
||||
assert len(app.logger.handlers) > 0
|
||||
|
||||
# Verify at least one handler is a RotatingFileHandler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
has_rotating_handler = any(
|
||||
isinstance(h, RotatingFileHandler)
|
||||
for h in app.logger.handlers
|
||||
)
|
||||
assert has_rotating_handler, "Should have RotatingFileHandler configured"
|
||||
|
||||
def test_log_handler_configuration(self, app):
|
||||
"""Test that log handlers are configured correctly."""
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
# Find RotatingFileHandler
|
||||
rotating_handlers = [
|
||||
h for h in app.logger.handlers
|
||||
if isinstance(h, RotatingFileHandler)
|
||||
]
|
||||
|
||||
assert len(rotating_handlers) > 0, "Should have rotating file handlers"
|
||||
|
||||
# Check handler configuration
|
||||
for handler in rotating_handlers:
|
||||
# Should have max size configured
|
||||
assert handler.maxBytes > 0
|
||||
# Should have backup count configured
|
||||
assert handler.backupCount > 0
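A logging configuration that would satisfy both rotation tests is sketched below; the file name and size limits are illustrative assumptions, not values taken from the application.

import logging
from logging.handlers import RotatingFileHandler

def configure_file_logging(app):
    # Rotate at roughly 10 MB and keep five backups.
    handler = RotatingFileHandler('logs/sneakyscanner.log',
                                  maxBytes=10 * 1024 * 1024,
                                  backupCount=5)
    handler.setLevel(logging.INFO)
    app.logger.addHandler(handler)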
|
||||
|
||||
|
||||
class TestStructuredLogging:
|
||||
"""Test structured logging features."""
|
||||
|
||||
def test_log_format_includes_request_id(self, client, caplog):
|
||||
"""Test that log format includes request ID."""
|
||||
with caplog.at_level(logging.INFO):
|
||||
client.get('/api/scans/health')
|
||||
|
||||
# Verify log records have request_id
|
||||
for record in caplog.records:
|
||||
assert hasattr(record, 'request_id')
|
||||
|
||||
def test_error_log_includes_traceback(self, app, caplog):
|
||||
"""Test that errors are logged with traceback."""
|
||||
with app.test_request_context('/api/test'):
|
||||
with caplog.at_level(logging.ERROR):
|
||||
                try:
                    raise ValueError("Test error")
                except ValueError:
                    # The exception object itself is not needed here;
                    # exc_info=True attaches the active traceback to the record.
                    app.logger.error("Test error occurred", exc_info=True)
|
||||
|
||||
# Check that traceback is in logs
|
||||
log_output = caplog.text
|
||||
assert 'Test error' in log_output
|
||||
assert 'Traceback' in log_output or 'ValueError' in log_output
|
||||
|
||||
|
||||
class TestErrorTemplates:
|
||||
"""Test error template rendering."""
|
||||
|
||||
def test_404_template_exists(self, client):
|
||||
"""Test that 404 error template is rendered."""
|
||||
response = client.get('/nonexistent')
|
||||
assert response.status_code == 404
|
||||
assert b'404' in response.data
|
||||
assert b'Page Not Found' in response.data or b'Not Found' in response.data
|
||||
|
||||
def test_500_template_exists(self, app):
|
||||
"""Test that 500 error template can be rendered."""
|
||||
# We can't easily trigger a 500 without breaking the app
|
||||
# Just verify the template file exists
|
||||
from pathlib import Path
|
||||
template_path = Path('web/templates/errors/500.html')
|
||||
assert template_path.exists(), "500 error template should exist"
|
||||
|
||||
def test_error_template_styling(self, client):
|
||||
"""Test that error templates include styling."""
|
||||
response = client.get('/nonexistent')
|
||||
# Should include CSS styling
|
||||
assert b'style' in response.data or b'css' in response.data.lower()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
267
app/tests/test_scan_api.py
Normal file
267
app/tests/test_scan_api.py
Normal file
@@ -0,0 +1,267 @@
"""
Integration tests for Scan API endpoints.

Tests all scan management endpoints including triggering scans,
listing, retrieving details, deleting, and status polling.
"""

import json
import pytest
from pathlib import Path
from datetime import datetime

from web.models import Scan


class TestScanAPIEndpoints:
    """Test suite for scan API endpoints."""

def test_list_scans_empty(self, client, db):
|
||||
"""Test listing scans when database is empty."""
|
||||
response = client.get('/api/scans')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['scans'] == []
|
||||
assert data['total'] == 0
|
||||
assert data['page'] == 1
|
||||
assert data['per_page'] == 20
|
||||
|
||||
def test_list_scans_with_data(self, client, db, sample_scan):
|
||||
"""Test listing scans with existing data."""
|
||||
response = client.get('/api/scans')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['total'] == 1
|
||||
assert len(data['scans']) == 1
|
||||
assert data['scans'][0]['id'] == sample_scan.id
|
||||
|
||||
def test_list_scans_pagination(self, client, db):
|
||||
"""Test scan list pagination."""
|
||||
# Create 25 scans
|
||||
for i in range(25):
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status='completed',
|
||||
config_file=f'/app/configs/test{i}.yaml',
|
||||
title=f'Test Scan {i}',
|
||||
triggered_by='test'
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
|
||||
# Test page 1
|
||||
response = client.get('/api/scans?page=1&per_page=10')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['total'] == 25
|
||||
assert len(data['scans']) == 10
|
||||
assert data['page'] == 1
|
||||
assert data['per_page'] == 10
|
||||
assert data['total_pages'] == 3
|
||||
assert data['has_next'] is True
|
||||
assert data['has_prev'] is False
|
||||
|
||||
# Test page 2
|
||||
response = client.get('/api/scans?page=2&per_page=10')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert len(data['scans']) == 10
|
||||
assert data['page'] == 2
|
||||
assert data['has_next'] is True
|
||||
assert data['has_prev'] is True
|
||||
|
||||
def test_list_scans_status_filter(self, client, db):
|
||||
"""Test filtering scans by status."""
|
||||
# Create scans with different statuses
|
||||
for status in ['running', 'completed', 'failed']:
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status=status,
|
||||
config_file='/app/configs/test.yaml',
|
||||
title=f'{status.capitalize()} Scan',
|
||||
triggered_by='test'
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
|
||||
# Filter by completed
|
||||
response = client.get('/api/scans?status=completed')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['total'] == 1
|
||||
assert data['scans'][0]['status'] == 'completed'
|
||||
|
||||
def test_list_scans_invalid_page(self, client, db):
|
||||
"""Test listing scans with invalid page parameter."""
|
||||
response = client.get('/api/scans?page=0')
|
||||
assert response.status_code == 400
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
def test_get_scan_success(self, client, db, sample_scan):
|
||||
"""Test retrieving a specific scan."""
|
||||
response = client.get(f'/api/scans/{sample_scan.id}')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['id'] == sample_scan.id
|
||||
assert data['title'] == sample_scan.title
|
||||
assert data['status'] == sample_scan.status
|
||||
|
||||
def test_get_scan_not_found(self, client, db):
|
||||
"""Test retrieving a non-existent scan."""
|
||||
response = client.get('/api/scans/99999')
|
||||
assert response.status_code == 404
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert data['error'] == 'Not found'
|
||||
|
||||
def test_trigger_scan_success(self, client, db, sample_config_file):
|
||||
"""Test triggering a new scan."""
|
||||
response = client.post('/api/scans',
|
||||
json={'config_file': str(sample_config_file)},
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'scan_id' in data
|
||||
assert data['status'] == 'running'
|
||||
assert data['message'] == 'Scan queued successfully'
|
||||
|
||||
# Verify scan was created in database
|
||||
scan = db.query(Scan).filter_by(id=data['scan_id']).first()
|
||||
assert scan is not None
|
||||
assert scan.status == 'running'
|
||||
assert scan.triggered_by == 'api'
|
||||
|
||||
def test_trigger_scan_missing_config_file(self, client, db):
|
||||
"""Test triggering scan without config_file."""
|
||||
response = client.post('/api/scans',
|
||||
json={},
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert 'config_file is required' in data['message']
|
||||
|
||||
def test_trigger_scan_invalid_config_file(self, client, db):
|
||||
"""Test triggering scan with non-existent config file."""
|
||||
response = client.post('/api/scans',
|
||||
json={'config_file': '/nonexistent/config.yaml'},
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
def test_delete_scan_success(self, client, db, sample_scan):
|
||||
"""Test deleting a scan."""
|
||||
scan_id = sample_scan.id
|
||||
|
||||
response = client.delete(f'/api/scans/{scan_id}')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['scan_id'] == scan_id
|
||||
assert 'deleted successfully' in data['message']
|
||||
|
||||
# Verify scan was deleted from database
|
||||
scan = db.query(Scan).filter_by(id=scan_id).first()
|
||||
assert scan is None
|
||||
|
||||
def test_delete_scan_not_found(self, client, db):
|
||||
"""Test deleting a non-existent scan."""
|
||||
response = client.delete('/api/scans/99999')
|
||||
assert response.status_code == 404
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
def test_get_scan_status_success(self, client, db, sample_scan):
|
||||
"""Test getting scan status."""
|
||||
response = client.get(f'/api/scans/{sample_scan.id}/status')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['scan_id'] == sample_scan.id
|
||||
assert data['status'] == sample_scan.status
|
||||
assert 'timestamp' in data
|
||||
|
||||
def test_get_scan_status_not_found(self, client, db):
|
||||
"""Test getting status for non-existent scan."""
|
||||
response = client.get('/api/scans/99999/status')
|
||||
assert response.status_code == 404
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
def test_api_error_handling(self, client, db):
|
||||
"""Test API error responses are properly formatted."""
|
||||
# Test 404
|
||||
response = client.get('/api/scans/99999')
|
||||
assert response.status_code == 404
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert 'message' in data
|
||||
|
||||
# Test 400
|
||||
response = client.post('/api/scans', json={})
|
||||
assert response.status_code == 400
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert 'message' in data
|
||||
|
||||
    def test_scan_workflow_integration(self, client, db, sample_config_file):
        """
        Test complete scan workflow: trigger → status → retrieve → delete.

        This integration test verifies the entire scan lifecycle through
        the API endpoints.
        """
        # Step 1: Trigger scan
        response = client.post('/api/scans',
            json={'config_file': str(sample_config_file)},
            content_type='application/json'
        )
        assert response.status_code == 201
        data = json.loads(response.data)
        scan_id = data['scan_id']

        # Step 2: Check status
        response = client.get(f'/api/scans/{scan_id}/status')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['scan_id'] == scan_id
        assert data['status'] == 'running'

        # Step 3: List scans (verify it appears)
        response = client.get('/api/scans')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['total'] == 1
        assert data['scans'][0]['id'] == scan_id

        # Step 4: Get scan details
        response = client.get(f'/api/scans/{scan_id}')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['id'] == scan_id

        # Step 5: Delete scan
        response = client.delete(f'/api/scans/{scan_id}')
        assert response.status_code == 200

        # Step 6: Verify deletion
        response = client.get(f'/api/scans/{scan_id}')
        assert response.status_code == 404
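The client, db, sample_scan and sample_config_file fixtures used throughout these API tests are defined outside this file. A plausible sketch of sample_scan, mirroring the Scan fields used in the pagination test above (the db fixture and exact column set are assumptions), would be:

import pytest
from datetime import datetime
from web.models import Scan

@pytest.fixture
def sample_scan(db):
    scan = Scan(
        timestamp=datetime.utcnow(),
        status='completed',
        config_file='/app/configs/test.yaml',
        title='Test Scan',
        triggered_by='test'
    )
    db.add(scan)
    db.commit()
    return scan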
319
app/tests/test_scan_comparison.py
Normal file
319
app/tests/test_scan_comparison.py
Normal file
@@ -0,0 +1,319 @@
"""
Unit tests for scan comparison functionality.

Tests scan comparison logic including port, service, and certificate comparisons,
as well as drift score calculation.
"""

import pytest
from datetime import datetime

from web.models import Scan, ScanSite, ScanIP, ScanPort
from web.models import ScanService as ScanServiceModel, ScanCertificate
from web.services.scan_service import ScanService


class TestScanComparison:
    """Tests for scan comparison methods."""

@pytest.fixture
|
||||
def scan1_data(self, test_db, sample_config_file):
|
||||
"""Create first scan with test data."""
|
||||
service = ScanService(test_db)
|
||||
scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')
|
||||
|
||||
# Get scan and add some test data
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
scan.status = 'completed'
|
||||
|
||||
# Create site
|
||||
site = ScanSite(scan_id=scan.id, site_name='Test Site')
|
||||
test_db.add(site)
|
||||
test_db.flush()
|
||||
|
||||
# Create IP
|
||||
ip = ScanIP(
|
||||
scan_id=scan.id,
|
||||
site_id=site.id,
|
||||
ip_address='192.168.1.100',
|
||||
ping_expected=True,
|
||||
ping_actual=True
|
||||
)
|
||||
test_db.add(ip)
|
||||
test_db.flush()
|
||||
|
||||
# Create ports
|
||||
port1 = ScanPort(
|
||||
scan_id=scan.id,
|
||||
ip_id=ip.id,
|
||||
port=80,
|
||||
protocol='tcp',
|
||||
state='open',
|
||||
expected=True
|
||||
)
|
||||
port2 = ScanPort(
|
||||
scan_id=scan.id,
|
||||
ip_id=ip.id,
|
||||
port=443,
|
||||
protocol='tcp',
|
||||
state='open',
|
||||
expected=True
|
||||
)
|
||||
test_db.add(port1)
|
||||
test_db.add(port2)
|
||||
test_db.flush()
|
||||
|
||||
# Create service
|
||||
svc1 = ScanServiceModel(
|
||||
scan_id=scan.id,
|
||||
port_id=port1.id,
|
||||
service_name='http',
|
||||
product='nginx',
|
||||
version='1.18.0'
|
||||
)
|
||||
test_db.add(svc1)
|
||||
|
||||
test_db.commit()
|
||||
return scan_id
|
||||
|
||||
@pytest.fixture
|
||||
def scan2_data(self, test_db, sample_config_file):
|
||||
"""Create second scan with modified test data."""
|
||||
service = ScanService(test_db)
|
||||
scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')
|
||||
|
||||
# Get scan and add some test data
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
scan.status = 'completed'
|
||||
|
||||
# Create site
|
||||
site = ScanSite(scan_id=scan.id, site_name='Test Site')
|
||||
test_db.add(site)
|
||||
test_db.flush()
|
||||
|
||||
# Create IP
|
||||
ip = ScanIP(
|
||||
scan_id=scan.id,
|
||||
site_id=site.id,
|
||||
ip_address='192.168.1.100',
|
||||
ping_expected=True,
|
||||
ping_actual=True
|
||||
)
|
||||
test_db.add(ip)
|
||||
test_db.flush()
|
||||
|
||||
# Create ports (port 80 removed, 443 kept, 8080 added)
|
||||
port2 = ScanPort(
|
||||
scan_id=scan.id,
|
||||
ip_id=ip.id,
|
||||
port=443,
|
||||
protocol='tcp',
|
||||
state='open',
|
||||
expected=True
|
||||
)
|
||||
port3 = ScanPort(
|
||||
scan_id=scan.id,
|
||||
ip_id=ip.id,
|
||||
port=8080,
|
||||
protocol='tcp',
|
||||
state='open',
|
||||
expected=False
|
||||
)
|
||||
test_db.add(port2)
|
||||
test_db.add(port3)
|
||||
test_db.flush()
|
||||
|
||||
# Create service with updated version
|
||||
svc2 = ScanServiceModel(
|
||||
scan_id=scan.id,
|
||||
port_id=port3.id,
|
||||
service_name='http',
|
||||
product='nginx',
|
||||
version='1.20.0' # Version changed
|
||||
)
|
||||
test_db.add(svc2)
|
||||
|
||||
test_db.commit()
|
||||
return scan_id
|
||||
|
||||
def test_compare_scans_basic(self, test_db, scan1_data, scan2_data):
|
||||
"""Test basic scan comparison."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.compare_scans(scan1_data, scan2_data)
|
||||
|
||||
assert result is not None
|
||||
assert 'scan1' in result
|
||||
assert 'scan2' in result
|
||||
assert 'ports' in result
|
||||
assert 'services' in result
|
||||
assert 'certificates' in result
|
||||
assert 'drift_score' in result
|
||||
|
||||
# Verify scan metadata
|
||||
assert result['scan1']['id'] == scan1_data
|
||||
assert result['scan2']['id'] == scan2_data
|
||||
|
||||
def test_compare_scans_not_found(self, test_db):
|
||||
"""Test comparison with nonexistent scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.compare_scans(999, 998)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_compare_ports(self, test_db, scan1_data, scan2_data):
|
||||
"""Test port comparison logic."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.compare_scans(scan1_data, scan2_data)
|
||||
|
||||
# Scan1 has ports 80, 443
|
||||
# Scan2 has ports 443, 8080
|
||||
# Expected: added=[8080], removed=[80], unchanged=[443]
|
||||
|
||||
ports = result['ports']
|
||||
assert len(ports['added']) == 1
|
||||
assert len(ports['removed']) == 1
|
||||
assert len(ports['unchanged']) == 1
|
||||
|
||||
# Check added port
|
||||
added_port = ports['added'][0]
|
||||
assert added_port['port'] == 8080
|
||||
|
||||
# Check removed port
|
||||
removed_port = ports['removed'][0]
|
||||
assert removed_port['port'] == 80
|
||||
|
||||
# Check unchanged port
|
||||
unchanged_port = ports['unchanged'][0]
|
||||
assert unchanged_port['port'] == 443
|
||||
|
||||
def test_compare_services(self, test_db, scan1_data, scan2_data):
|
||||
"""Test service comparison logic."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.compare_scans(scan1_data, scan2_data)
|
||||
|
||||
services = result['services']
|
||||
|
||||
# Scan1 has nginx 1.18.0 on port 80
|
||||
# Scan2 has nginx 1.20.0 on port 8080
|
||||
# These are on different ports, so they should be added/removed, not changed
|
||||
|
||||
assert len(services['added']) >= 0
|
||||
assert len(services['removed']) >= 0
|
||||
|
||||
def test_drift_score_calculation(self, test_db, scan1_data, scan2_data):
|
||||
"""Test drift score calculation."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.compare_scans(scan1_data, scan2_data)
|
||||
|
||||
drift_score = result['drift_score']
|
||||
|
||||
# Drift score should be between 0.0 and 1.0
|
||||
assert 0.0 <= drift_score <= 1.0
|
||||
|
||||
# Since we have changes (1 port added, 1 removed), drift should be > 0
|
||||
assert drift_score > 0.0
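The exact drift formula is not shown in this diff; one simple scheme that stays within [0.0, 1.0] and matches these assertions is the ratio of changed items to all compared items, sketched here purely for illustration:

def drift_score(added, removed, changed, unchanged):
    # Fraction of compared items that differ between the two scans.
    differing = len(added) + len(removed) + len(changed)
    total = differing + len(unchanged)
    return differing / total if total else 0.0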
|
||||
|
||||
def test_compare_identical_scans(self, test_db, scan1_data):
|
||||
"""Test comparing a scan with itself (should have zero drift)."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.compare_scans(scan1_data, scan1_data)
|
||||
|
||||
# Comparing scan with itself should have zero drift
|
||||
assert result['drift_score'] == 0.0
|
||||
assert len(result['ports']['added']) == 0
|
||||
assert len(result['ports']['removed']) == 0
|
||||
|
||||
|
||||
class TestScanComparisonAPI:
|
||||
"""Tests for scan comparison API endpoint."""
|
||||
|
||||
def test_compare_scans_api(self, client, auth_headers, scan1_data, scan2_data):
|
||||
"""Test scan comparison API endpoint."""
|
||||
response = client.get(
|
||||
f'/api/scans/{scan1_data}/compare/{scan2_data}',
|
||||
headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
|
||||
assert 'scan1' in data
|
||||
assert 'scan2' in data
|
||||
assert 'ports' in data
|
||||
assert 'services' in data
|
||||
assert 'drift_score' in data
|
||||
|
||||
def test_compare_scans_api_not_found(self, client, auth_headers):
|
||||
"""Test comparison API with nonexistent scans."""
|
||||
response = client.get(
|
||||
'/api/scans/999/compare/998',
|
||||
headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.get_json()
|
||||
assert 'error' in data
|
||||
|
||||
def test_compare_scans_api_requires_auth(self, client, scan1_data, scan2_data):
|
||||
"""Test that comparison API requires authentication."""
|
||||
response = client.get(f'/api/scans/{scan1_data}/compare/{scan2_data}')
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
class TestHistoricalChartAPI:
|
||||
"""Tests for historical scan chart API endpoint."""
|
||||
|
||||
def test_scan_history_api(self, client, auth_headers, scan1_data):
|
||||
"""Test scan history API endpoint."""
|
||||
response = client.get(
|
||||
f'/api/stats/scan-history/{scan1_data}',
|
||||
headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
|
||||
assert 'scans' in data
|
||||
assert 'labels' in data
|
||||
assert 'port_counts' in data
|
||||
assert 'config_file' in data
|
||||
|
||||
# Should include at least the scan we created
|
||||
assert len(data['scans']) >= 1
|
||||
|
||||
def test_scan_history_api_not_found(self, client, auth_headers):
|
||||
"""Test history API with nonexistent scan."""
|
||||
response = client.get(
|
||||
'/api/stats/scan-history/999',
|
||||
headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.get_json()
|
||||
assert 'error' in data
|
||||
|
||||
def test_scan_history_api_limit(self, client, auth_headers, scan1_data):
|
||||
"""Test scan history API with limit parameter."""
|
||||
response = client.get(
|
||||
f'/api/stats/scan-history/{scan1_data}?limit=5',
|
||||
headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
|
||||
# Should respect limit
|
||||
assert len(data['scans']) <= 5
|
||||
|
||||
def test_scan_history_api_requires_auth(self, client, scan1_data):
|
||||
"""Test that history API requires authentication."""
|
||||
response = client.get(f'/api/stats/scan-history/{scan1_data}')
|
||||
|
||||
assert response.status_code == 401
402
app/tests/test_scan_service.py
Normal file
402
app/tests/test_scan_service.py
Normal file
@@ -0,0 +1,402 @@
"""
Unit tests for ScanService class.

Tests scan lifecycle operations: trigger, get, list, delete, and database mapping.
"""

import pytest

from web.models import Scan, ScanSite, ScanIP, ScanPort, ScanService as ScanServiceModel
from web.services.scan_service import ScanService


class TestScanServiceTrigger:
    """Tests for triggering scans."""

def test_trigger_scan_valid_config(self, test_db, sample_config_file):
|
||||
"""Test triggering a scan with valid config file."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file, triggered_by='manual')
|
||||
|
||||
# Verify scan created
|
||||
assert scan_id is not None
|
||||
assert isinstance(scan_id, int)
|
||||
|
||||
# Verify scan in database
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
assert scan is not None
|
||||
assert scan.status == 'running'
|
||||
assert scan.title == 'Test Scan'
|
||||
assert scan.triggered_by == 'manual'
|
||||
assert scan.config_file == sample_config_file
|
||||
|
||||
def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file):
|
||||
"""Test triggering a scan with invalid config file."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid config file"):
|
||||
service.trigger_scan(sample_invalid_config_file)
|
||||
|
||||
def test_trigger_scan_nonexistent_file(self, test_db):
|
||||
"""Test triggering a scan with nonexistent config file."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
service.trigger_scan('/nonexistent/config.yaml')
|
||||
|
||||
def test_trigger_scan_with_schedule(self, test_db, sample_config_file):
|
||||
"""Test triggering a scan via schedule."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(
|
||||
sample_config_file,
|
||||
triggered_by='scheduled',
|
||||
schedule_id=42
|
||||
)
|
||||
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
assert scan.triggered_by == 'scheduled'
|
||||
assert scan.schedule_id == 42
|
||||
|
||||
|
||||
class TestScanServiceGet:
|
||||
"""Tests for retrieving scans."""
|
||||
|
||||
def test_get_scan_not_found(self, test_db):
|
||||
"""Test getting a nonexistent scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.get_scan(999)
|
||||
assert result is None
|
||||
|
||||
def test_get_scan_found(self, test_db, sample_config_file):
|
||||
"""Test getting an existing scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create a scan
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
|
||||
# Retrieve it
|
||||
result = service.get_scan(scan_id)
|
||||
|
||||
assert result is not None
|
||||
assert result['id'] == scan_id
|
||||
assert result['title'] == 'Test Scan'
|
||||
assert result['status'] == 'running'
|
||||
assert 'sites' in result
|
||||
|
||||
|
||||
class TestScanServiceList:
|
||||
"""Tests for listing scans."""
|
||||
|
||||
def test_list_scans_empty(self, test_db):
|
||||
"""Test listing scans when database is empty."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.list_scans(page=1, per_page=20)
|
||||
|
||||
assert result.total == 0
|
||||
assert len(result.items) == 0
|
||||
assert result.pages == 0
|
||||
|
||||
def test_list_scans_with_data(self, test_db, sample_config_file):
|
||||
"""Test listing scans with multiple scans."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create 3 scans
|
||||
for i in range(3):
|
||||
service.trigger_scan(sample_config_file, triggered_by='api')
|
||||
|
||||
# List all scans
|
||||
result = service.list_scans(page=1, per_page=20)
|
||||
|
||||
assert result.total == 3
|
||||
assert len(result.items) == 3
|
||||
assert result.pages == 1
|
||||
|
||||
def test_list_scans_pagination(self, test_db, sample_config_file):
|
||||
"""Test pagination."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create 5 scans
|
||||
for i in range(5):
|
||||
service.trigger_scan(sample_config_file)
|
||||
|
||||
# Get page 1 (2 items per page)
|
||||
result = service.list_scans(page=1, per_page=2)
|
||||
assert len(result.items) == 2
|
||||
assert result.total == 5
|
||||
assert result.pages == 3
|
||||
assert result.has_next is True
|
||||
|
||||
# Get page 2
|
||||
result = service.list_scans(page=2, per_page=2)
|
||||
assert len(result.items) == 2
|
||||
assert result.has_prev is True
|
||||
assert result.has_next is True
|
||||
|
||||
# Get page 3 (last page)
|
||||
result = service.list_scans(page=3, per_page=2)
|
||||
assert len(result.items) == 1
|
||||
assert result.has_next is False
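list_scans() evidently returns a small pagination object exposing items, total, pages, has_next and has_prev. A minimal container consistent with these assertions (an assumption about the real return type, shown only as a sketch) could look like:

import math
from dataclasses import dataclass

@dataclass
class Page:
    items: list
    total: int
    page: int
    per_page: int

    @property
    def pages(self) -> int:
        return math.ceil(self.total / self.per_page) if self.total else 0

    @property
    def has_next(self) -> bool:
        return self.page < self.pages

    @property
    def has_prev(self) -> bool:
        return self.page > 1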
|
||||
|
||||
def test_list_scans_filter_by_status(self, test_db, sample_config_file):
|
||||
"""Test filtering scans by status."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create scans with different statuses
|
||||
scan_id_1 = service.trigger_scan(sample_config_file)
|
||||
scan_id_2 = service.trigger_scan(sample_config_file)
|
||||
|
||||
# Mark one as completed
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id_1).first()
|
||||
scan.status = 'completed'
|
||||
test_db.commit()
|
||||
|
||||
# Filter by running
|
||||
result = service.list_scans(status_filter='running')
|
||||
assert result.total == 1
|
||||
|
||||
# Filter by completed
|
||||
result = service.list_scans(status_filter='completed')
|
||||
assert result.total == 1
|
||||
|
||||
def test_list_scans_invalid_status_filter(self, test_db):
|
||||
"""Test filtering with invalid status."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid status"):
|
||||
service.list_scans(status_filter='invalid_status')
|
||||
|
||||
|
||||
class TestScanServiceDelete:
|
||||
"""Tests for deleting scans."""
|
||||
|
||||
def test_delete_scan_not_found(self, test_db):
|
||||
"""Test deleting a nonexistent scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.delete_scan(999)
|
||||
|
||||
def test_delete_scan_success(self, test_db, sample_config_file):
|
||||
"""Test successful scan deletion."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create a scan
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
|
||||
# Verify it exists
|
||||
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is not None
|
||||
|
||||
# Delete it
|
||||
result = service.delete_scan(scan_id)
|
||||
assert result is True
|
||||
|
||||
# Verify it's gone
|
||||
assert test_db.query(Scan).filter(Scan.id == scan_id).first() is None
|
||||
|
||||
|
||||
class TestScanServiceStatus:
|
||||
"""Tests for scan status retrieval."""
|
||||
|
||||
def test_get_scan_status_not_found(self, test_db):
|
||||
"""Test getting status of nonexistent scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
result = service.get_scan_status(999)
|
||||
assert result is None
|
||||
|
||||
def test_get_scan_status_running(self, test_db, sample_config_file):
|
||||
"""Test getting status of running scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
status = service.get_scan_status(scan_id)
|
||||
|
||||
assert status is not None
|
||||
assert status['scan_id'] == scan_id
|
||||
assert status['status'] == 'running'
|
||||
assert status['progress'] == 'In progress'
|
||||
assert status['title'] == 'Test Scan'
|
||||
|
||||
def test_get_scan_status_completed(self, test_db, sample_config_file):
|
||||
"""Test getting status of completed scan."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create and mark as completed
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
scan.status = 'completed'
|
||||
scan.duration = 125.5
|
||||
test_db.commit()
|
||||
|
||||
status = service.get_scan_status(scan_id)
|
||||
|
||||
assert status['status'] == 'completed'
|
||||
assert status['progress'] == 'Complete'
|
||||
assert status['duration'] == 125.5
|
||||
|
||||
|
||||
class TestScanServiceDatabaseMapping:
|
||||
"""Tests for mapping scan reports to database models."""
|
||||
|
||||
def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report):
|
||||
"""Test saving a complete scan report to database."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
# Create a scan
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
|
||||
# Save report to database
|
||||
service._save_scan_to_db(sample_scan_report, scan_id, status='completed')
|
||||
|
||||
# Verify scan updated
|
||||
scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
assert scan.status == 'completed'
|
||||
assert scan.duration == 125.5
|
||||
|
||||
# Verify sites created
|
||||
sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all()
|
||||
assert len(sites) == 1
|
||||
assert sites[0].site_name == 'Test Site'
|
||||
|
||||
# Verify IPs created
|
||||
ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all()
|
||||
assert len(ips) == 1
|
||||
assert ips[0].ip_address == '192.168.1.10'
|
||||
assert ips[0].ping_expected is True
|
||||
assert ips[0].ping_actual is True
|
||||
|
||||
# Verify ports created (TCP: 22, 80, 443, 8080 | UDP: 53)
|
||||
ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all()
|
||||
assert len(ports) == 5 # 4 TCP + 1 UDP
|
||||
|
||||
# Verify TCP ports
|
||||
tcp_ports = [p for p in ports if p.protocol == 'tcp']
|
||||
assert len(tcp_ports) == 4
|
||||
tcp_port_numbers = sorted([p.port for p in tcp_ports])
|
||||
assert tcp_port_numbers == [22, 80, 443, 8080]
|
||||
|
||||
# Verify UDP ports
|
||||
udp_ports = [p for p in ports if p.protocol == 'udp']
|
||||
assert len(udp_ports) == 1
|
||||
assert udp_ports[0].port == 53
|
||||
|
||||
# Verify services created
|
||||
services = test_db.query(ScanServiceModel).filter(
|
||||
ScanServiceModel.scan_id == scan_id
|
||||
).all()
|
||||
assert len(services) == 4 # SSH, HTTP (80), HTTPS, HTTP (8080)
|
||||
|
||||
# Find HTTPS service
|
||||
https_service = next(
|
||||
(s for s in services if s.service_name == 'https'), None
|
||||
)
|
||||
assert https_service is not None
|
||||
assert https_service.product == 'nginx'
|
||||
assert https_service.version == '1.24.0'
|
||||
assert https_service.http_protocol == 'https'
|
||||
assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png'
|
||||
|
||||
def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report):
|
||||
"""Test that expected vs actual ports are correctly flagged."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
service._save_scan_to_db(sample_scan_report, scan_id)
|
||||
|
||||
# Check TCP ports
|
||||
tcp_ports = test_db.query(ScanPort).filter(
|
||||
ScanPort.scan_id == scan_id,
|
||||
ScanPort.protocol == 'tcp'
|
||||
).all()
|
||||
|
||||
# Ports 22, 80, 443 were expected
|
||||
expected_ports = {22, 80, 443}
|
||||
for port in tcp_ports:
|
||||
if port.port in expected_ports:
|
||||
assert port.expected is True, f"Port {port.port} should be expected"
|
||||
else:
|
||||
# Port 8080 was not expected
|
||||
assert port.expected is False, f"Port {port.port} should not be expected"
|
||||
|
||||
def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report):
|
||||
"""Test that certificate and TLS data are correctly mapped."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
service._save_scan_to_db(sample_scan_report, scan_id)
|
||||
|
||||
# Find HTTPS service
|
||||
https_service = test_db.query(ScanServiceModel).filter(
|
||||
ScanServiceModel.scan_id == scan_id,
|
||||
ScanServiceModel.service_name == 'https'
|
||||
).first()
|
||||
|
||||
assert https_service is not None
|
||||
|
||||
# Verify certificate created
|
||||
assert len(https_service.certificates) == 1
|
||||
cert = https_service.certificates[0]
|
||||
|
||||
assert cert.subject == 'CN=example.com'
|
||||
assert cert.issuer == "CN=Let's Encrypt Authority"
|
||||
assert cert.days_until_expiry == 365
|
||||
assert cert.is_self_signed is False
|
||||
|
||||
# Verify SANs
|
||||
import json
|
||||
sans = json.loads(cert.sans)
|
||||
assert 'example.com' in sans
|
||||
assert 'www.example.com' in sans
|
||||
|
||||
# Verify TLS versions
|
||||
assert len(cert.tls_versions) == 2
|
||||
|
||||
tls_12 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.2'), None)
|
||||
assert tls_12 is not None
|
||||
assert tls_12.supported is True
|
||||
|
||||
tls_13 = next((t for t in cert.tls_versions if t.tls_version == 'TLS 1.3'), None)
|
||||
assert tls_13 is not None
|
||||
assert tls_13.supported is True
|
||||
|
||||
def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report):
|
||||
"""Test retrieving scan with all nested relationships."""
|
||||
service = ScanService(test_db)
|
||||
|
||||
scan_id = service.trigger_scan(sample_config_file)
|
||||
service._save_scan_to_db(sample_scan_report, scan_id)
|
||||
|
||||
# Get full scan details
|
||||
result = service.get_scan(scan_id)
|
||||
|
||||
assert result is not None
|
||||
assert len(result['sites']) == 1
|
||||
|
||||
site = result['sites'][0]
|
||||
assert site['name'] == 'Test Site'
|
||||
assert len(site['ips']) == 1
|
||||
|
||||
ip = site['ips'][0]
|
||||
assert ip['address'] == '192.168.1.10'
|
||||
assert len(ip['ports']) == 5 # 4 TCP + 1 UDP
|
||||
|
||||
# Find HTTPS port
|
||||
https_port = next(
|
||||
(p for p in ip['ports'] if p['port'] == 443), None
|
||||
)
|
||||
assert https_port is not None
|
||||
assert len(https_port['services']) == 1
|
||||
|
||||
service_data = https_port['services'][0]
|
||||
assert service_data['service_name'] == 'https'
|
||||
assert 'certificates' in service_data
|
||||
assert len(service_data['certificates']) == 1
|
||||
|
||||
cert = service_data['certificates'][0]
|
||||
assert cert['subject'] == 'CN=example.com'
|
||||
assert 'tls_versions' in cert
|
||||
assert len(cert['tls_versions']) == 2
639
app/tests/test_schedule_api.py
Normal file
639
app/tests/test_schedule_api.py
Normal file
@@ -0,0 +1,639 @@
"""
Integration tests for Schedule API endpoints.

Tests all schedule management endpoints including creating, listing,
updating, deleting schedules, and manually triggering scheduled scans.
"""

import json
import pytest
from datetime import datetime

from web.models import Schedule, Scan


@pytest.fixture
def sample_schedule(db, sample_config_file):
    """
    Create a sample schedule in the database for testing.

    Args:
        db: Database session fixture
        sample_config_file: Path to test config file

    Returns:
        Schedule model instance
    """
    schedule = Schedule(
        name='Daily Test Scan',
        config_file=sample_config_file,
        cron_expression='0 2 * * *',
        enabled=True,
        last_run=None,
        next_run=datetime(2025, 11, 15, 2, 0, 0),
        created_at=datetime.utcnow(),
        updated_at=datetime.utcnow()
    )

    db.add(schedule)
    db.commit()
    db.refresh(schedule)

    return schedule


class TestScheduleAPIEndpoints:
    """Test suite for schedule API endpoints."""

def test_list_schedules_empty(self, client, db):
|
||||
"""Test listing schedules when database is empty."""
|
||||
response = client.get('/api/schedules')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['schedules'] == []
|
||||
assert data['total'] == 0
|
||||
assert data['page'] == 1
|
||||
assert data['per_page'] == 20
|
||||
|
||||
def test_list_schedules_populated(self, client, db, sample_schedule):
|
||||
"""Test listing schedules with existing data."""
|
||||
response = client.get('/api/schedules')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['total'] == 1
|
||||
assert len(data['schedules']) == 1
|
||||
assert data['schedules'][0]['id'] == sample_schedule.id
|
||||
assert data['schedules'][0]['name'] == sample_schedule.name
|
||||
assert data['schedules'][0]['cron_expression'] == sample_schedule.cron_expression
|
||||
|
||||
def test_list_schedules_pagination(self, client, db, sample_config_file):
|
||||
"""Test schedule list pagination."""
|
||||
# Create 25 schedules
|
||||
for i in range(25):
|
||||
schedule = Schedule(
|
||||
name=f'Schedule {i}',
|
||||
config_file=sample_config_file,
|
||||
cron_expression='0 2 * * *',
|
||||
enabled=True,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
db.add(schedule)
|
||||
db.commit()
|
||||
|
||||
# Test page 1
|
||||
response = client.get('/api/schedules?page=1&per_page=10')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['total'] == 25
|
||||
assert len(data['schedules']) == 10
|
||||
assert data['page'] == 1
|
||||
assert data['per_page'] == 10
|
||||
assert data['pages'] == 3
|
||||
|
||||
# Test page 2
|
||||
response = client.get('/api/schedules?page=2&per_page=10')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert len(data['schedules']) == 10
|
||||
assert data['page'] == 2
|
||||
|
||||
def test_list_schedules_filter_enabled(self, client, db, sample_config_file):
|
||||
"""Test filtering schedules by enabled status."""
|
||||
# Create enabled and disabled schedules
|
||||
for i in range(3):
|
||||
schedule = Schedule(
|
||||
name=f'Enabled Schedule {i}',
|
||||
config_file=sample_config_file,
|
||||
cron_expression='0 2 * * *',
|
||||
enabled=True,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
db.add(schedule)
|
||||
|
||||
for i in range(2):
|
||||
schedule = Schedule(
|
||||
name=f'Disabled Schedule {i}',
|
||||
config_file=sample_config_file,
|
||||
cron_expression='0 3 * * *',
|
||||
enabled=False,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
db.add(schedule)
|
||||
db.commit()
|
||||
|
||||
# Filter by enabled=true
|
||||
response = client.get('/api/schedules?enabled=true')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['total'] == 3
|
||||
for schedule in data['schedules']:
|
||||
assert schedule['enabled'] is True
|
||||
|
||||
# Filter by enabled=false
|
||||
response = client.get('/api/schedules?enabled=false')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['total'] == 2
|
||||
for schedule in data['schedules']:
|
||||
assert schedule['enabled'] is False
|
||||
|
||||
def test_get_schedule(self, client, db, sample_schedule):
|
||||
"""Test getting schedule details."""
|
||||
response = client.get(f'/api/schedules/{sample_schedule.id}')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['id'] == sample_schedule.id
|
||||
assert data['name'] == sample_schedule.name
|
||||
assert data['config_file'] == sample_schedule.config_file
|
||||
assert data['cron_expression'] == sample_schedule.cron_expression
|
||||
assert data['enabled'] == sample_schedule.enabled
|
||||
assert 'history' in data
|
||||
|
||||
def test_get_schedule_not_found(self, client, db):
|
||||
"""Test getting non-existent schedule."""
|
||||
response = client.get('/api/schedules/99999')
|
||||
assert response.status_code == 404
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert 'not found' in data['error'].lower()
|
||||
|
||||
def test_create_schedule(self, client, db, sample_config_file):
|
||||
"""Test creating a new schedule."""
|
||||
schedule_data = {
|
||||
'name': 'New Test Schedule',
|
||||
'config_file': sample_config_file,
|
||||
'cron_expression': '0 3 * * *',
|
||||
'enabled': True
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
'/api/schedules',
|
||||
data=json.dumps(schedule_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'schedule_id' in data
|
||||
assert data['message'] == 'Schedule created successfully'
|
||||
assert 'schedule' in data
|
||||
|
||||
# Verify schedule in database
|
||||
schedule = db.query(Schedule).filter(Schedule.id == data['schedule_id']).first()
|
||||
assert schedule is not None
|
||||
assert schedule.name == schedule_data['name']
|
||||
assert schedule.cron_expression == schedule_data['cron_expression']
|
||||
|
||||
def test_create_schedule_missing_fields(self, client, db):
|
||||
"""Test creating schedule with missing required fields."""
|
||||
# Missing cron_expression
|
||||
schedule_data = {
|
||||
'name': 'Incomplete Schedule',
|
||||
'config_file': '/app/configs/test.yaml'
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
'/api/schedules',
|
||||
data=json.dumps(schedule_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert 'missing' in data['error'].lower()
|
||||
|
||||
def test_create_schedule_invalid_cron(self, client, db, sample_config_file):
|
||||
"""Test creating schedule with invalid cron expression."""
|
||||
schedule_data = {
|
||||
'name': 'Invalid Cron Schedule',
|
||||
'config_file': sample_config_file,
|
||||
'cron_expression': 'invalid cron'
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
'/api/schedules',
|
||||
data=json.dumps(schedule_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert 'invalid' in data['error'].lower() or 'cron' in data['error'].lower()
|
||||
|
||||
def test_create_schedule_invalid_config(self, client, db):
|
||||
"""Test creating schedule with non-existent config file."""
|
||||
schedule_data = {
|
||||
'name': 'Invalid Config Schedule',
|
||||
'config_file': '/nonexistent/config.yaml',
|
||||
'cron_expression': '0 2 * * *'
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
'/api/schedules',
|
||||
data=json.dumps(schedule_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert 'not found' in data['error'].lower()
|
||||
|
||||
def test_update_schedule(self, client, db, sample_schedule):
|
||||
"""Test updating schedule fields."""
|
||||
update_data = {
|
||||
'name': 'Updated Schedule Name',
|
||||
'cron_expression': '0 4 * * *'
|
||||
}
|
||||
|
||||
response = client.put(
|
||||
f'/api/schedules/{sample_schedule.id}',
|
||||
data=json.dumps(update_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['message'] == 'Schedule updated successfully'
|
||||
assert data['schedule']['name'] == update_data['name']
|
||||
assert data['schedule']['cron_expression'] == update_data['cron_expression']
|
||||
|
||||
# Verify in database
|
||||
db.refresh(sample_schedule)
|
||||
assert sample_schedule.name == update_data['name']
|
||||
assert sample_schedule.cron_expression == update_data['cron_expression']
|
||||
|
||||
def test_update_schedule_not_found(self, client, db):
|
||||
"""Test updating non-existent schedule."""
|
||||
update_data = {'name': 'New Name'}
|
||||
|
||||
response = client.put(
|
||||
'/api/schedules/99999',
|
||||
data=json.dumps(update_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
def test_update_schedule_invalid_cron(self, client, db, sample_schedule):
|
||||
"""Test updating schedule with invalid cron expression."""
|
||||
update_data = {'cron_expression': 'invalid'}
|
||||
|
||||
response = client.put(
|
||||
f'/api/schedules/{sample_schedule.id}',
|
||||
data=json.dumps(update_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
def test_update_schedule_toggle_enabled(self, client, db, sample_schedule):
|
||||
"""Test enabling/disabling schedule."""
|
||||
# Disable schedule
|
||||
response = client.put(
|
||||
f'/api/schedules/{sample_schedule.id}',
|
||||
data=json.dumps({'enabled': False}),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['schedule']['enabled'] is False
|
||||
|
||||
# Enable schedule
|
||||
response = client.put(
|
||||
f'/api/schedules/{sample_schedule.id}',
|
||||
data=json.dumps({'enabled': True}),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['schedule']['enabled'] is True
|
||||
|
||||
def test_update_schedule_no_data(self, client, db, sample_schedule):
|
||||
"""Test updating schedule with no data."""
|
||||
response = client.put(
|
||||
f'/api/schedules/{sample_schedule.id}',
|
||||
data=json.dumps({}),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
def test_delete_schedule(self, client, db, sample_schedule):
|
||||
"""Test deleting a schedule."""
|
||||
schedule_id = sample_schedule.id
|
||||
|
||||
response = client.delete(f'/api/schedules/{schedule_id}')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['message'] == 'Schedule deleted successfully'
|
||||
assert data['schedule_id'] == schedule_id
|
||||
|
||||
# Verify deletion in database
|
||||
schedule = db.query(Schedule).filter(Schedule.id == schedule_id).first()
|
||||
assert schedule is None
|
||||
|
||||
def test_delete_schedule_not_found(self, client, db):
|
||||
"""Test deleting non-existent schedule."""
|
||||
response = client.delete('/api/schedules/99999')
|
||||
assert response.status_code == 404
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
def test_delete_schedule_preserves_scans(self, client, db, sample_schedule, sample_config_file):
|
||||
"""Test that deleting schedule preserves associated scans."""
|
||||
# Create a scan associated with the schedule
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status='completed',
|
||||
config_file=sample_config_file,
|
||||
title='Test Scan',
|
||||
triggered_by='scheduled',
|
||||
schedule_id=sample_schedule.id
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
scan_id = scan.id
|
||||
|
||||
# Delete schedule
|
||||
response = client.delete(f'/api/schedules/{sample_schedule.id}')
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify scan still exists
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
assert scan is not None
|
||||
assert scan.schedule_id is None # Schedule ID becomes null
|
||||
|
||||
def test_trigger_schedule(self, client, db, sample_schedule):
|
||||
"""Test manually triggering a scheduled scan."""
|
||||
response = client.post(f'/api/schedules/{sample_schedule.id}/trigger')
|
||||
assert response.status_code == 201
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['message'] == 'Scan triggered successfully'
|
||||
assert 'scan_id' in data
|
||||
assert data['schedule_id'] == sample_schedule.id
|
||||
|
||||
# Verify scan was created
|
||||
scan = db.query(Scan).filter(Scan.id == data['scan_id']).first()
|
||||
assert scan is not None
|
||||
assert scan.triggered_by == 'manual'
|
||||
assert scan.schedule_id == sample_schedule.id
|
||||
assert scan.config_file == sample_schedule.config_file
|
||||
|
||||
def test_trigger_schedule_not_found(self, client, db):
|
||||
"""Test triggering non-existent schedule."""
|
||||
response = client.post('/api/schedules/99999/trigger')
|
||||
assert response.status_code == 404
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
def test_get_schedule_with_history(self, client, db, sample_schedule, sample_config_file):
|
||||
"""Test getting schedule includes execution history."""
|
||||
# Create some scans for this schedule
|
||||
for i in range(5):
|
||||
scan = Scan(
|
||||
timestamp=datetime.utcnow(),
|
||||
status='completed',
|
||||
config_file=sample_config_file,
|
||||
title=f'Scheduled Scan {i}',
|
||||
triggered_by='scheduled',
|
||||
schedule_id=sample_schedule.id
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
|
||||
response = client.get(f'/api/schedules/{sample_schedule.id}')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert 'history' in data
|
||||
assert len(data['history']) == 5
|
||||
|
||||
def test_schedule_workflow_integration(self, client, db, sample_config_file):
|
||||
"""Test complete schedule workflow: create → update → trigger → delete."""
|
||||
# 1. Create schedule
|
||||
schedule_data = {
|
||||
'name': 'Integration Test Schedule',
|
||||
'config_file': sample_config_file,
|
||||
'cron_expression': '0 2 * * *',
|
||||
'enabled': True
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
'/api/schedules',
|
||||
data=json.dumps(schedule_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 201
|
||||
schedule_id = json.loads(response.data)['schedule_id']
|
||||
|
||||
# 2. Get schedule
|
||||
response = client.get(f'/api/schedules/{schedule_id}')
|
||||
assert response.status_code == 200
|
||||
|
||||
# 3. Update schedule
|
||||
response = client.put(
|
||||
f'/api/schedules/{schedule_id}',
|
||||
data=json.dumps({'name': 'Updated Integration Test'}),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# 4. Trigger schedule
|
||||
response = client.post(f'/api/schedules/{schedule_id}/trigger')
|
||||
assert response.status_code == 201
|
||||
scan_id = json.loads(response.data)['scan_id']
|
||||
|
||||
# 5. Verify scan was created
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
assert scan is not None
|
||||
|
||||
# 6. Delete schedule
|
||||
response = client.delete(f'/api/schedules/{schedule_id}')
|
||||
assert response.status_code == 200
|
||||
|
||||
# 7. Verify schedule deleted
|
||||
response = client.get(f'/api/schedules/{schedule_id}')
|
||||
assert response.status_code == 404
|
||||
|
||||
# 8. Verify scan still exists
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
assert scan is not None
|
||||
|
||||
def test_list_schedules_ordering(self, client, db, sample_config_file):
|
||||
"""Test that schedules are ordered by next_run time."""
|
||||
# Create schedules with different next_run times
|
||||
schedules = []
|
||||
for i in range(3):
|
||||
schedule = Schedule(
|
||||
name=f'Schedule {i}',
|
||||
config_file=sample_config_file,
|
||||
cron_expression='0 2 * * *',
|
||||
enabled=True,
|
||||
next_run=datetime(2025, 11, 15 + i, 2, 0, 0),
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
db.add(schedule)
|
||||
schedules.append(schedule)
|
||||
|
||||
# Create a disabled schedule (next_run is None)
|
||||
disabled_schedule = Schedule(
|
||||
name='Disabled Schedule',
|
||||
config_file=sample_config_file,
|
||||
cron_expression='0 3 * * *',
|
||||
enabled=False,
|
||||
next_run=None,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
db.add(disabled_schedule)
|
||||
db.commit()
|
||||
|
||||
response = client.get('/api/schedules')
|
||||
assert response.status_code == 200
|
||||
|
||||
data = json.loads(response.data)
|
||||
returned_schedules = data['schedules']
|
||||
|
||||
# Schedules with next_run should come before those without
|
||||
# Within those with next_run, they should be ordered by time
|
||||
assert returned_schedules[0]['id'] == schedules[0].id
|
||||
assert returned_schedules[1]['id'] == schedules[1].id
|
||||
assert returned_schedules[2]['id'] == schedules[2].id
|
||||
assert returned_schedules[3]['id'] == disabled_schedule.id
|
||||
|
||||
def test_create_schedule_with_disabled(self, client, db, sample_config_file):
|
||||
"""Test creating a disabled schedule."""
|
||||
schedule_data = {
|
||||
'name': 'Disabled Schedule',
|
||||
'config_file': sample_config_file,
|
||||
'cron_expression': '0 2 * * *',
|
||||
'enabled': False
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
'/api/schedules',
|
||||
data=json.dumps(schedule_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
data = json.loads(response.data)
|
||||
assert data['schedule']['enabled'] is False
|
||||
assert data['schedule']['next_run'] is None # Disabled schedules have no next_run
|
||||
|
||||
|
||||
class TestScheduleAPIAuthentication:
|
||||
"""Test suite for schedule API authentication."""
|
||||
|
||||
def test_schedules_require_authentication(self, app):
|
||||
"""Test that all schedule endpoints require authentication."""
|
||||
# Create unauthenticated client
|
||||
client = app.test_client()
|
||||
|
||||
endpoints = [
|
||||
('GET', '/api/schedules'),
|
||||
('GET', '/api/schedules/1'),
|
||||
('POST', '/api/schedules'),
|
||||
('PUT', '/api/schedules/1'),
|
||||
('DELETE', '/api/schedules/1'),
|
||||
('POST', '/api/schedules/1/trigger')
|
||||
]
|
||||
|
||||
for method, endpoint in endpoints:
|
||||
if method == 'GET':
|
||||
response = client.get(endpoint)
|
||||
elif method == 'POST':
|
||||
response = client.post(
|
||||
endpoint,
|
||||
data=json.dumps({}),
|
||||
content_type='application/json'
|
||||
)
|
||||
elif method == 'PUT':
|
||||
response = client.put(
|
||||
endpoint,
|
||||
data=json.dumps({}),
|
||||
content_type='application/json'
|
||||
)
|
||||
elif method == 'DELETE':
|
||||
response = client.delete(endpoint)
|
||||
|
||||
# Should redirect to login or return 401
|
||||
assert response.status_code in [302, 401], \
|
||||
f"{method} {endpoint} should require authentication"
|
||||
|
||||
|
||||
class TestScheduleAPICronValidation:
|
||||
"""Test suite for cron expression validation."""
|
||||
|
||||
def test_valid_cron_expressions(self, client, db, sample_config_file):
|
||||
"""Test various valid cron expressions."""
|
||||
valid_expressions = [
|
||||
'0 2 * * *', # Daily at 2am
|
||||
'*/15 * * * *', # Every 15 minutes
|
||||
'0 0 * * 0', # Weekly on Sunday
|
||||
'0 0 1 * *', # Monthly on 1st
|
||||
'0 */4 * * *', # Every 4 hours
|
||||
]
|
||||
|
||||
for cron_expr in valid_expressions:
|
||||
schedule_data = {
|
||||
'name': f'Schedule for {cron_expr}',
|
||||
'config_file': sample_config_file,
|
||||
'cron_expression': cron_expr
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
'/api/schedules',
|
||||
data=json.dumps(schedule_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 201, \
|
||||
f"Valid cron expression '{cron_expr}' should be accepted"
|
||||
|
||||
def test_invalid_cron_expressions(self, client, db, sample_config_file):
|
||||
"""Test various invalid cron expressions."""
|
||||
invalid_expressions = [
|
||||
'invalid',
|
||||
'60 2 * * *', # Invalid minute
|
||||
'0 25 * * *', # Invalid hour
|
||||
'0 0 32 * *', # Invalid day
|
||||
'0 0 * 13 *', # Invalid month
|
||||
'0 0 * * 8', # Invalid day of week
|
||||
]
|
||||
|
||||
for cron_expr in invalid_expressions:
|
||||
schedule_data = {
|
||||
'name': f'Schedule for {cron_expr}',
|
||||
'config_file': sample_config_file,
|
||||
'cron_expression': cron_expr
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
'/api/schedules',
|
||||
data=json.dumps(schedule_data),
|
||||
content_type='application/json'
|
||||
)
|
||||
assert response.status_code == 400, \
|
||||
f"Invalid cron expression '{cron_expr}' should be rejected"
|
||||
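The endpoints exercised above can also be driven from outside the test suite. The snippet below is a minimal usage sketch, not part of the repository: it assumes the app is reachable at http://localhost:5000 and that the session cookie was already obtained by logging in; the payload and response fields simply mirror what the tests assert.

import requests

BASE_URL = 'http://localhost:5000'  # assumed local address, adjust as needed
session = requests.Session()        # assumes login already set the auth cookie

# Create a schedule; the API is expected to answer 201 with a schedule_id
resp = session.post(f'{BASE_URL}/api/schedules', json={
    'name': 'Nightly scan',
    'config_file': '/app/configs/test.yaml',
    'cron_expression': '0 2 * * *',
    'enabled': True,
})
schedule_id = resp.json()['schedule_id']

# Trigger it manually; the API is expected to answer 201 with the new scan's id
resp = session.post(f'{BASE_URL}/api/schedules/{schedule_id}/trigger')
print(resp.json()['scan_id'])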
671
app/tests/test_schedule_service.py
Normal file
@@ -0,0 +1,671 @@
"""
Unit tests for ScheduleService class.

Tests schedule lifecycle operations: create, get, list, update, delete, and
cron expression validation.
"""

import pytest
from datetime import datetime, timedelta

from web.models import Schedule, Scan
from web.services.schedule_service import ScheduleService


class TestScheduleServiceCreate:
    """Tests for creating schedules."""

    def test_create_schedule_valid(self, test_db, sample_config_file):
        """Test creating a schedule with valid parameters."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Daily Scan',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        # Verify schedule created
        assert schedule_id is not None
        assert isinstance(schedule_id, int)

        # Verify schedule in database
        schedule = test_db.query(Schedule).filter(Schedule.id == schedule_id).first()
        assert schedule is not None
        assert schedule.name == 'Daily Scan'
        assert schedule.config_file == sample_config_file
        assert schedule.cron_expression == '0 2 * * *'
        assert schedule.enabled is True
        assert schedule.next_run is not None
        assert schedule.last_run is None

    def test_create_schedule_disabled(self, test_db, sample_config_file):
        """Test creating a disabled schedule."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Disabled Scan',
            config_file=sample_config_file,
            cron_expression='0 3 * * *',
            enabled=False
        )

        schedule = test_db.query(Schedule).filter(Schedule.id == schedule_id).first()
        assert schedule.enabled is False
        assert schedule.next_run is None

    def test_create_schedule_invalid_cron(self, test_db, sample_config_file):
        """Test creating a schedule with invalid cron expression."""
        service = ScheduleService(test_db)

        with pytest.raises(ValueError, match="Invalid cron expression"):
            service.create_schedule(
                name='Invalid Schedule',
                config_file=sample_config_file,
                cron_expression='invalid cron',
                enabled=True
            )

    def test_create_schedule_nonexistent_config(self, test_db):
        """Test creating a schedule with nonexistent config file."""
        service = ScheduleService(test_db)

        with pytest.raises(ValueError, match="Config file not found"):
            service.create_schedule(
                name='Bad Config',
                config_file='/nonexistent/config.yaml',
                cron_expression='0 2 * * *',
                enabled=True
            )

    def test_create_schedule_various_cron_expressions(self, test_db, sample_config_file):
        """Test creating schedules with various valid cron expressions."""
        service = ScheduleService(test_db)

        cron_expressions = [
            '0 0 * * *',      # Daily at midnight
            '*/15 * * * *',   # Every 15 minutes
            '0 2 * * 0',      # Weekly on Sunday at 2 AM
            '0 0 1 * *',      # Monthly on the 1st at midnight
            '30 14 * * 1-5',  # Weekdays at 2:30 PM
        ]

        for i, cron in enumerate(cron_expressions):
            schedule_id = service.create_schedule(
                name=f'Schedule {i}',
                config_file=sample_config_file,
                cron_expression=cron,
                enabled=True
            )
            assert schedule_id is not None


class TestScheduleServiceGet:
    """Tests for retrieving schedules."""

    def test_get_schedule_not_found(self, test_db):
        """Test getting a nonexistent schedule."""
        service = ScheduleService(test_db)

        with pytest.raises(ValueError, match="Schedule .* not found"):
            service.get_schedule(999)

    def test_get_schedule_found(self, test_db, sample_config_file):
        """Test getting an existing schedule."""
        service = ScheduleService(test_db)

        # Create a schedule
        schedule_id = service.create_schedule(
            name='Test Schedule',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        # Retrieve it
        result = service.get_schedule(schedule_id)

        assert result is not None
        assert result['id'] == schedule_id
        assert result['name'] == 'Test Schedule'
        assert result['cron_expression'] == '0 2 * * *'
        assert result['enabled'] is True
        assert 'history' in result
        assert isinstance(result['history'], list)

    def test_get_schedule_with_history(self, test_db, sample_config_file):
        """Test getting schedule includes execution history."""
        service = ScheduleService(test_db)

        # Create schedule
        schedule_id = service.create_schedule(
            name='Test Schedule',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        # Create associated scans
        for i in range(3):
            scan = Scan(
                timestamp=datetime.utcnow() - timedelta(days=i),
                status='completed',
                config_file=sample_config_file,
                title=f'Scan {i}',
                triggered_by='scheduled',
                schedule_id=schedule_id
            )
            test_db.add(scan)
        test_db.commit()

        # Get schedule
        result = service.get_schedule(schedule_id)

        assert len(result['history']) == 3
        assert result['history'][0]['title'] == 'Scan 0'  # Most recent first


class TestScheduleServiceList:
    """Tests for listing schedules."""

    def test_list_schedules_empty(self, test_db):
        """Test listing schedules when database is empty."""
        service = ScheduleService(test_db)

        result = service.list_schedules(page=1, per_page=20)

        assert result['total'] == 0
        assert len(result['schedules']) == 0
        assert result['page'] == 1
        assert result['per_page'] == 20

    def test_list_schedules_populated(self, test_db, sample_config_file):
        """Test listing schedules with data."""
        service = ScheduleService(test_db)

        # Create multiple schedules
        for i in range(5):
            service.create_schedule(
                name=f'Schedule {i}',
                config_file=sample_config_file,
                cron_expression='0 2 * * *',
                enabled=True
            )

        result = service.list_schedules(page=1, per_page=20)

        assert result['total'] == 5
        assert len(result['schedules']) == 5
        assert all('name' in s for s in result['schedules'])

    def test_list_schedules_pagination(self, test_db, sample_config_file):
        """Test schedule pagination."""
        service = ScheduleService(test_db)

        # Create 25 schedules
        for i in range(25):
            service.create_schedule(
                name=f'Schedule {i:02d}',
                config_file=sample_config_file,
                cron_expression='0 2 * * *',
                enabled=True
            )

        # Get first page
        result_page1 = service.list_schedules(page=1, per_page=10)
        assert len(result_page1['schedules']) == 10
        assert result_page1['total'] == 25
        assert result_page1['pages'] == 3

        # Get second page
        result_page2 = service.list_schedules(page=2, per_page=10)
        assert len(result_page2['schedules']) == 10

        # Get third page
        result_page3 = service.list_schedules(page=3, per_page=10)
        assert len(result_page3['schedules']) == 5

    def test_list_schedules_filter_enabled(self, test_db, sample_config_file):
        """Test filtering schedules by enabled status."""
        service = ScheduleService(test_db)

        # Create enabled and disabled schedules
        for i in range(3):
            service.create_schedule(
                name=f'Enabled {i}',
                config_file=sample_config_file,
                cron_expression='0 2 * * *',
                enabled=True
            )
        for i in range(2):
            service.create_schedule(
                name=f'Disabled {i}',
                config_file=sample_config_file,
                cron_expression='0 2 * * *',
                enabled=False
            )

        # Filter enabled only
        result_enabled = service.list_schedules(enabled_filter=True)
        assert result_enabled['total'] == 3

        # Filter disabled only
        result_disabled = service.list_schedules(enabled_filter=False)
        assert result_disabled['total'] == 2

        # No filter
        result_all = service.list_schedules(enabled_filter=None)
        assert result_all['total'] == 5


class TestScheduleServiceUpdate:
    """Tests for updating schedules."""

    def test_update_schedule_name(self, test_db, sample_config_file):
        """Test updating schedule name."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Old Name',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        result = service.update_schedule(schedule_id, name='New Name')

        assert result['name'] == 'New Name'
        assert result['cron_expression'] == '0 2 * * *'

    def test_update_schedule_cron(self, test_db, sample_config_file):
        """Test updating cron expression recalculates next_run."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        original = service.get_schedule(schedule_id)
        original_next_run = original['next_run']

        # Update cron expression
        result = service.update_schedule(
            schedule_id,
            cron_expression='0 3 * * *'
        )

        # Next run should be recalculated
        assert result['cron_expression'] == '0 3 * * *'
        assert result['next_run'] != original_next_run

    def test_update_schedule_invalid_cron(self, test_db, sample_config_file):
        """Test updating with invalid cron expression fails."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        with pytest.raises(ValueError, match="Invalid cron expression"):
            service.update_schedule(schedule_id, cron_expression='invalid')

    def test_update_schedule_not_found(self, test_db):
        """Test updating nonexistent schedule fails."""
        service = ScheduleService(test_db)

        with pytest.raises(ValueError, match="Schedule .* not found"):
            service.update_schedule(999, name='New Name')

    def test_update_schedule_invalid_config_file(self, test_db, sample_config_file):
        """Test updating with nonexistent config file fails."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        with pytest.raises(ValueError, match="Config file not found"):
            service.update_schedule(schedule_id, config_file='/nonexistent.yaml')


class TestScheduleServiceDelete:
    """Tests for deleting schedules."""

    def test_delete_schedule(self, test_db, sample_config_file):
        """Test deleting a schedule."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='To Delete',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        # Verify exists
        assert test_db.query(Schedule).filter(Schedule.id == schedule_id).first() is not None

        # Delete
        result = service.delete_schedule(schedule_id)
        assert result is True

        # Verify deleted
        assert test_db.query(Schedule).filter(Schedule.id == schedule_id).first() is None

    def test_delete_schedule_not_found(self, test_db):
        """Test deleting nonexistent schedule fails."""
        service = ScheduleService(test_db)

        with pytest.raises(ValueError, match="Schedule .* not found"):
            service.delete_schedule(999)

    def test_delete_schedule_preserves_scans(self, test_db, sample_config_file):
        """Test that deleting schedule preserves associated scans."""
        service = ScheduleService(test_db)

        # Create schedule
        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        # Create associated scan
        scan = Scan(
            timestamp=datetime.utcnow(),
            status='completed',
            config_file=sample_config_file,
            title='Test Scan',
            triggered_by='scheduled',
            schedule_id=schedule_id
        )
        test_db.add(scan)
        test_db.commit()
        scan_id = scan.id

        # Delete schedule
        service.delete_schedule(schedule_id)

        # Verify scan still exists (schedule_id becomes null)
        remaining_scan = test_db.query(Scan).filter(Scan.id == scan_id).first()
        assert remaining_scan is not None
        assert remaining_scan.schedule_id is None

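# Note: the delete tests above imply that removing a Schedule nulls Scan.schedule_id
# rather than cascading the delete. Either a foreign key declared roughly as
#     schedule_id = Column(Integer, ForeignKey('schedules.id', ondelete='SET NULL'), nullable=True)
# or an explicit "set schedule_id to None" step inside ScheduleService.delete_schedule
# would satisfy them; which variant the project uses is not shown in this diff.
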
class TestScheduleServiceToggle:
    """Tests for toggling schedule enabled status."""

    def test_toggle_enabled_to_disabled(self, test_db, sample_config_file):
        """Test disabling an enabled schedule."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        result = service.toggle_enabled(schedule_id, enabled=False)

        assert result['enabled'] is False
        assert result['next_run'] is None

    def test_toggle_disabled_to_enabled(self, test_db, sample_config_file):
        """Test enabling a disabled schedule."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=False
        )

        result = service.toggle_enabled(schedule_id, enabled=True)

        assert result['enabled'] is True
        assert result['next_run'] is not None


class TestScheduleServiceRunTimes:
    """Tests for updating run times."""

    def test_update_run_times(self, test_db, sample_config_file):
        """Test updating last_run and next_run."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        last_run = datetime.utcnow()
        next_run = datetime.utcnow() + timedelta(days=1)

        result = service.update_run_times(schedule_id, last_run, next_run)
        assert result is True

        schedule = service.get_schedule(schedule_id)
        assert schedule['last_run'] is not None
        assert schedule['next_run'] is not None

    def test_update_run_times_not_found(self, test_db):
        """Test updating run times for nonexistent schedule."""
        service = ScheduleService(test_db)

        with pytest.raises(ValueError, match="Schedule .* not found"):
            service.update_run_times(
                999,
                datetime.utcnow(),
                datetime.utcnow() + timedelta(days=1)
            )


class TestCronValidation:
    """Tests for cron expression validation."""

    def test_validate_cron_valid_expressions(self, test_db):
        """Test validating various valid cron expressions."""
        service = ScheduleService(test_db)

        valid_expressions = [
            '0 0 * * *',      # Daily at midnight
            '*/15 * * * *',   # Every 15 minutes
            '0 2 * * 0',      # Weekly on Sunday
            '0 0 1 * *',      # Monthly
            '30 14 * * 1-5',  # Weekdays
            '0 */4 * * *',    # Every 4 hours
        ]

        for expr in valid_expressions:
            is_valid, error = service.validate_cron_expression(expr)
            assert is_valid is True, f"Expression '{expr}' should be valid"
            assert error is None

    def test_validate_cron_invalid_expressions(self, test_db):
        """Test validating invalid cron expressions."""
        service = ScheduleService(test_db)

        invalid_expressions = [
            'invalid',
            '60 0 * * *',  # Invalid minute (0-59)
            '0 24 * * *',  # Invalid hour (0-23)
            '0 0 32 * *',  # Invalid day (1-31)
            '0 0 * 13 *',  # Invalid month (1-12)
            '0 0 * * 7',   # Invalid weekday (0-6)
        ]

        for expr in invalid_expressions:
            is_valid, error = service.validate_cron_expression(expr)
            assert is_valid is False, f"Expression '{expr}' should be invalid"
            assert error is not None


class TestNextRunCalculation:
    """Tests for next run time calculation."""

    def test_calculate_next_run(self, test_db):
        """Test calculating next run time."""
        service = ScheduleService(test_db)

        # Daily at 2 AM
        next_run = service.calculate_next_run('0 2 * * *')

        assert next_run is not None
        assert isinstance(next_run, datetime)
        assert next_run > datetime.utcnow()

    def test_calculate_next_run_from_time(self, test_db):
        """Test calculating next run from specific time."""
        service = ScheduleService(test_db)

        base_time = datetime(2025, 1, 1, 0, 0, 0)
        next_run = service.calculate_next_run('0 2 * * *', from_time=base_time)

        # Should be 2 AM on same day
        assert next_run.hour == 2
        assert next_run.minute == 0

    def test_calculate_next_run_invalid_cron(self, test_db):
        """Test calculating next run with invalid cron raises error."""
        service = ScheduleService(test_db)

        with pytest.raises(ValueError, match="Invalid cron expression"):
            service.calculate_next_run('invalid cron')


class TestScheduleHistory:
    """Tests for schedule execution history."""

    def test_get_schedule_history_empty(self, test_db, sample_config_file):
        """Test getting history for schedule with no executions."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        history = service.get_schedule_history(schedule_id)
        assert len(history) == 0

    def test_get_schedule_history_with_scans(self, test_db, sample_config_file):
        """Test getting history with multiple scans."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        # Create 15 scans
        for i in range(15):
            scan = Scan(
                timestamp=datetime.utcnow() - timedelta(days=i),
                status='completed',
                config_file=sample_config_file,
                title=f'Scan {i}',
                triggered_by='scheduled',
                schedule_id=schedule_id
            )
            test_db.add(scan)
        test_db.commit()

        # Get history (default limit 10)
        history = service.get_schedule_history(schedule_id, limit=10)
        assert len(history) == 10
        assert history[0]['title'] == 'Scan 0'  # Most recent first

    def test_get_schedule_history_custom_limit(self, test_db, sample_config_file):
        """Test getting history with custom limit."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        # Create 10 scans
        for i in range(10):
            scan = Scan(
                timestamp=datetime.utcnow() - timedelta(days=i),
                status='completed',
                config_file=sample_config_file,
                title=f'Scan {i}',
                triggered_by='scheduled',
                schedule_id=schedule_id
            )
            test_db.add(scan)
        test_db.commit()

        # Get only 5
        history = service.get_schedule_history(schedule_id, limit=5)
        assert len(history) == 5


class TestScheduleSerialization:
    """Tests for schedule serialization."""

    def test_schedule_to_dict(self, test_db, sample_config_file):
        """Test converting schedule to dictionary."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test Schedule',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        result = service.get_schedule(schedule_id)

        # Verify all required fields
        assert 'id' in result
        assert 'name' in result
        assert 'config_file' in result
        assert 'cron_expression' in result
        assert 'enabled' in result
        assert 'last_run' in result
        assert 'next_run' in result
        assert 'next_run_relative' in result
        assert 'created_at' in result
        assert 'updated_at' in result
        assert 'history' in result

    def test_schedule_relative_time_formatting(self, test_db, sample_config_file):
        """Test relative time formatting in schedule dict."""
        service = ScheduleService(test_db)

        schedule_id = service.create_schedule(
            name='Test',
            config_file=sample_config_file,
            cron_expression='0 2 * * *',
            enabled=True
        )

        result = service.get_schedule(schedule_id)

        # Should have relative time for next_run
        assert result['next_run_relative'] is not None
        assert isinstance(result['next_run_relative'], str)
        assert 'in' in result['next_run_relative'].lower()
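The cron handling that TestCronValidation and TestNextRunCalculation pin down could plausibly sit on top of the croniter package. The sketch below is an assumption-laden illustration, not the project's actual ScheduleService code; the function names simply mirror the methods the tests call. One caveat: croniter itself is lenient in places (for instance it accepts 7 as an alias for Sunday), so the real service presumably layers stricter field-range checks on top.

from datetime import datetime
from croniter import croniter  # assumption: the scheduler is croniter-based


def validate_cron_expression(expr):
    """Return (is_valid, error) for a five-field cron expression."""
    try:
        croniter(expr)
        return True, None
    except (ValueError, KeyError) as exc:
        return False, str(exc)


def calculate_next_run(expr, from_time=None):
    """Return the next datetime matching expr, raising ValueError on a bad expression."""
    is_valid, error = validate_cron_expression(expr)
    if not is_valid:
        raise ValueError(f'Invalid cron expression: {error}')
    base = from_time or datetime.utcnow()
    return croniter(expr, base).get_next(datetime)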
325
app/tests/test_stats_api.py
Normal file
@@ -0,0 +1,325 @@
"""
Tests for stats API endpoints.

Tests dashboard statistics and trending data endpoints.
"""

import pytest
from datetime import datetime, timedelta
from web.models import Scan


class TestStatsAPI:
    """Test suite for stats API endpoints."""

    def test_scan_trend_default_30_days(self, client, auth_headers, db_session):
        """Test scan trend endpoint with default 30 days."""
        # Create test scans over multiple days
        today = datetime.utcnow()
        for i in range(5):
            scan_date = today - timedelta(days=i)
            for j in range(i + 1):  # Create 1, 2, 3, 4, 5 scans per day
                scan = Scan(
                    config_file='/app/configs/test.yaml',
                    timestamp=scan_date,
                    status='completed',
                    duration=10.5
                )
                db_session.add(scan)
        db_session.commit()

        # Request trend data
        response = client.get('/api/stats/scan-trend', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()
        assert 'labels' in data
        assert 'values' in data
        assert 'start_date' in data
        assert 'end_date' in data
        assert 'total_scans' in data

        # Should have 30 days of data
        assert len(data['labels']) == 30
        assert len(data['values']) == 30

        # Total scans should match (1+2+3+4+5 = 15)
        assert data['total_scans'] == 15

        # Values should be non-negative integers
        assert all(isinstance(v, int) for v in data['values'])
        assert all(v >= 0 for v in data['values'])

    def test_scan_trend_custom_days(self, client, auth_headers, db_session):
        """Test scan trend endpoint with custom number of days."""
        # Create test scans
        today = datetime.utcnow()
        for i in range(10):
            scan = Scan(
                config_file='/app/configs/test.yaml',
                timestamp=today - timedelta(days=i),
                status='completed',
                duration=10.5
            )
            db_session.add(scan)
        db_session.commit()

        # Request 7 days of data
        response = client.get('/api/stats/scan-trend?days=7', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()
        assert len(data['labels']) == 7
        assert len(data['values']) == 7
        assert data['total_scans'] == 7

    def test_scan_trend_max_days_365(self, client, auth_headers):
        """Test scan trend endpoint accepts maximum 365 days."""
        response = client.get('/api/stats/scan-trend?days=365', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()
        assert len(data['labels']) == 365

    def test_scan_trend_rejects_days_over_365(self, client, auth_headers):
        """Test scan trend endpoint rejects more than 365 days."""
        response = client.get('/api/stats/scan-trend?days=366', headers=auth_headers)
        assert response.status_code == 400

        data = response.get_json()
        assert 'error' in data
        assert '365' in data['error']

    def test_scan_trend_rejects_days_less_than_1(self, client, auth_headers):
        """Test scan trend endpoint rejects days less than 1."""
        response = client.get('/api/stats/scan-trend?days=0', headers=auth_headers)
        assert response.status_code == 400

        data = response.get_json()
        assert 'error' in data

    def test_scan_trend_fills_missing_days_with_zero(self, client, auth_headers, db_session):
        """Test scan trend fills days with no scans as zero."""
        # Create scans only on specific days
        today = datetime.utcnow()

        # Create scan 5 days ago
        scan1 = Scan(
            config_file='/app/configs/test.yaml',
            timestamp=today - timedelta(days=5),
            status='completed',
            duration=10.5
        )
        db_session.add(scan1)

        # Create scan 10 days ago
        scan2 = Scan(
            config_file='/app/configs/test.yaml',
            timestamp=today - timedelta(days=10),
            status='completed',
            duration=10.5
        )
        db_session.add(scan2)
        db_session.commit()

        # Request 15 days
        response = client.get('/api/stats/scan-trend?days=15', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()

        # Should have 15 days of data
        assert len(data['values']) == 15

        # Most days should be zero
        zero_days = sum(1 for v in data['values'] if v == 0)
        assert zero_days >= 13  # At least 13 days with no scans

    def test_scan_trend_requires_authentication(self, client):
        """Test scan trend endpoint requires authentication."""
        response = client.get('/api/stats/scan-trend')
        assert response.status_code == 401

    def test_summary_endpoint(self, client, auth_headers, db_session):
        """Test summary statistics endpoint."""
        # Create test scans with different statuses
        today = datetime.utcnow()

        # 5 completed scans
        for i in range(5):
            scan = Scan(
                config_file='/app/configs/test.yaml',
                timestamp=today - timedelta(days=i),
                status='completed',
                duration=10.5
            )
            db_session.add(scan)

        # 2 failed scans
        for i in range(2):
            scan = Scan(
                config_file='/app/configs/test.yaml',
                timestamp=today - timedelta(days=i),
                status='failed',
                duration=5.0
            )
            db_session.add(scan)

        # 1 running scan
        scan = Scan(
            config_file='/app/configs/test.yaml',
            timestamp=today,
            status='running',
            duration=None
        )
        db_session.add(scan)

        db_session.commit()

        # Request summary
        response = client.get('/api/stats/summary', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()
        assert 'total_scans' in data
        assert 'completed_scans' in data
        assert 'failed_scans' in data
        assert 'running_scans' in data
        assert 'scans_today' in data
        assert 'scans_this_week' in data

        # Verify counts
        assert data['total_scans'] == 8
        assert data['completed_scans'] == 5
        assert data['failed_scans'] == 2
        assert data['running_scans'] == 1
        assert data['scans_today'] >= 1
        assert data['scans_this_week'] >= 1

    def test_summary_with_no_scans(self, client, auth_headers):
        """Test summary endpoint with no scans."""
        response = client.get('/api/stats/summary', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()
        assert data['total_scans'] == 0
        assert data['completed_scans'] == 0
        assert data['failed_scans'] == 0
        assert data['running_scans'] == 0
        assert data['scans_today'] == 0
        assert data['scans_this_week'] == 0

    def test_summary_scans_today(self, client, auth_headers, db_session):
        """Test summary counts scans today correctly."""
        today = datetime.utcnow()
        yesterday = today - timedelta(days=1)

        # Create 3 scans today
        for i in range(3):
            scan = Scan(
                config_file='/app/configs/test.yaml',
                timestamp=today,
                status='completed',
                duration=10.5
            )
            db_session.add(scan)

        # Create 2 scans yesterday
        for i in range(2):
            scan = Scan(
                config_file='/app/configs/test.yaml',
                timestamp=yesterday,
                status='completed',
                duration=10.5
            )
            db_session.add(scan)

        db_session.commit()

        response = client.get('/api/stats/summary', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()
        assert data['scans_today'] == 3
        assert data['scans_this_week'] >= 3

    def test_summary_scans_this_week(self, client, auth_headers, db_session):
        """Test summary counts scans this week correctly."""
        today = datetime.utcnow()

        # Create scans over the last 10 days
        for i in range(10):
            scan = Scan(
                config_file='/app/configs/test.yaml',
                timestamp=today - timedelta(days=i),
                status='completed',
                duration=10.5
            )
            db_session.add(scan)

        db_session.commit()

        response = client.get('/api/stats/summary', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()
        # Last 7 days (0-6) = 7 scans
        assert data['scans_this_week'] == 7

    def test_summary_requires_authentication(self, client):
        """Test summary endpoint requires authentication."""
        response = client.get('/api/stats/summary')
        assert response.status_code == 401

    def test_scan_trend_date_format(self, client, auth_headers, db_session):
        """Test scan trend returns dates in correct format."""
        # Create a scan
        scan = Scan(
            config_file='/app/configs/test.yaml',
            timestamp=datetime.utcnow(),
            status='completed',
            duration=10.5
        )
        db_session.add(scan)
        db_session.commit()

        response = client.get('/api/stats/scan-trend?days=7', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()

        # Check date format (YYYY-MM-DD)
        for label in data['labels']:
            assert len(label) == 10
            assert label[4] == '-'
            assert label[7] == '-'
            # Try parsing to ensure valid date
            datetime.strptime(label, '%Y-%m-%d')

    def test_scan_trend_consecutive_dates(self, client, auth_headers):
        """Test scan trend returns consecutive dates."""
        response = client.get('/api/stats/scan-trend?days=7', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()
        labels = data['labels']

        # Convert to datetime objects
        dates = [datetime.strptime(label, '%Y-%m-%d') for label in labels]

        # Check dates are consecutive
        for i in range(len(dates) - 1):
            diff = dates[i + 1] - dates[i]
            assert diff.days == 1, f"Dates not consecutive: {dates[i]} to {dates[i+1]}"

    def test_scan_trend_ends_with_today(self, client, auth_headers):
        """Test scan trend ends with today's date."""
        response = client.get('/api/stats/scan-trend?days=7', headers=auth_headers)
        assert response.status_code == 200

        data = response.get_json()

        # Last date should be today
        today = datetime.utcnow().date()
        last_date = datetime.strptime(data['labels'][-1], '%Y-%m-%d').date()
        assert last_date == today
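Taken together, the trend tests describe a specific aggregation contract: N consecutive UTC days ending today, YYYY-MM-DD labels, zero-filled gaps, and a 1-365 day bound. The following is a minimal sketch of that bucketing under those assumptions; it is not the actual /api/stats/scan-trend handler, and `scan_trend` is a hypothetical helper name.

from collections import Counter
from datetime import datetime, timedelta


def scan_trend(scan_timestamps, days=30):
    """Bucket scan timestamps into `days` consecutive UTC days ending today, zero-filled."""
    if not 1 <= days <= 365:
        raise ValueError('days must be between 1 and 365')
    today = datetime.utcnow().date()
    start = today - timedelta(days=days - 1)
    counts = Counter(ts.date() for ts in scan_timestamps if ts.date() >= start)
    labels = [(start + timedelta(days=i)).isoformat() for i in range(days)]
    values = [counts.get(start + timedelta(days=i), 0) for i in range(days)]
    return {
        'labels': labels,
        'values': values,
        'start_date': labels[0],
        'end_date': labels[-1],
        'total_scans': sum(values),
    }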