diff --git a/app/init_db.py b/app/init_db.py index edad5d4..baade36 100755 --- a/app/init_db.py +++ b/app/init_db.py @@ -54,7 +54,7 @@ def init_default_alert_rules(session): 'webhook_enabled': False, 'severity': 'warning', 'filter_conditions': None, - 'config_file': None + 'config_id': None }, { 'name': 'Drift Detection', @@ -65,7 +65,7 @@ def init_default_alert_rules(session): 'webhook_enabled': False, 'severity': 'info', 'filter_conditions': None, - 'config_file': None + 'config_id': None }, { 'name': 'Certificate Expiry Warning', @@ -76,7 +76,7 @@ def init_default_alert_rules(session): 'webhook_enabled': False, 'severity': 'warning', 'filter_conditions': None, - 'config_file': None + 'config_id': None }, { 'name': 'Weak TLS Detection', @@ -87,7 +87,7 @@ def init_default_alert_rules(session): 'webhook_enabled': False, 'severity': 'warning', 'filter_conditions': None, - 'config_file': None + 'config_id': None }, { 'name': 'Host Down Detection', diff --git a/app/src/report_generator.py b/app/src/report_generator.py index ba86f83..3f18b6c 100755 --- a/app/src/report_generator.py +++ b/app/src/report_generator.py @@ -78,7 +78,7 @@ class HTMLReportGenerator: 'title': self.report_data.get('title', 'SneakyScanner Report'), 'scan_time': self.report_data.get('scan_time'), 'scan_duration': self.report_data.get('scan_duration'), - 'config_file': self.report_data.get('config_file'), + 'config_id': self.report_data.get('config_id'), 'sites': self.report_data.get('sites', []), 'summary_stats': summary_stats, 'drift_alerts': drift_alerts, diff --git a/app/src/scanner.py b/app/src/scanner.py index 7881783..69027f5 100644 --- a/app/src/scanner.py +++ b/app/src/scanner.py @@ -948,7 +948,6 @@ class SneakyScanner: 'title': self.config['title'], 'scan_time': datetime.utcnow().isoformat() + 'Z', 'scan_duration': scan_duration, - 'config_file': str(self.config_path) if self.config_path else None, 'config_id': self.config_id, 'sites': [] } diff --git a/app/templates/report_template.html b/app/templates/report_template.html index 0e668bb..f4328f2 100644 --- a/app/templates/report_template.html +++ b/app/templates/report_template.html @@ -490,8 +490,8 @@
diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 6c30ab7..c182d31 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -13,7 +13,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from web.app import create_app -from web.models import Base, Scan +from web.models import Base, Scan, ScanConfig from web.utils.settings import PasswordManager, SettingsManager @@ -53,7 +53,7 @@ def sample_scan_report(): 'title': 'Test Scan', 'scan_time': '2025-11-14T10:30:00Z', 'scan_duration': 125.5, - 'config_file': '/app/configs/test.yaml', + 'config_id': 1, 'sites': [ { 'name': 'Test Site', @@ -199,6 +199,53 @@ def sample_invalid_config_file(tmp_path): return str(config_file) +@pytest.fixture +def sample_db_config(db): + """ + Create a sample database config for testing. + + Args: + db: Database session fixture + + Returns: + ScanConfig model instance with ID + """ + import json + + config_data = { + 'title': 'Test Scan', + 'sites': [ + { + 'name': 'Test Site', + 'ips': [ + { + 'address': '192.168.1.10', + 'expected': { + 'ping': True, + 'tcp_ports': [22, 80, 443], + 'udp_ports': [53], + 'services': ['ssh', 'http', 'https'] + } + } + ] + } + ] + } + + scan_config = ScanConfig( + title='Test Scan', + config_data=json.dumps(config_data), + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + + db.add(scan_config) + db.commit() + db.refresh(scan_config) + + return scan_config + + @pytest.fixture(scope='function') def app(): """ @@ -269,7 +316,7 @@ def sample_scan(db): scan = Scan( timestamp=datetime.utcnow(), status='completed', - config_file='/app/configs/test.yaml', + config_id=1, title='Test Scan', duration=125.5, triggered_by='test', diff --git a/app/tests/test_background_jobs.py b/app/tests/test_background_jobs.py index ddc754f..fb574a5 100644 --- a/app/tests/test_background_jobs.py +++ b/app/tests/test_background_jobs.py @@ -23,12 +23,12 @@ class TestBackgroundJobs: assert app.scheduler.scheduler is not None assert app.scheduler.scheduler.running - def test_queue_scan_job(self, app, db, sample_config_file): + def test_queue_scan_job(self, app, db, sample_db_config): """Test queuing a scan for background execution.""" # Create a scan via service scan_service = ScanService(db) scan_id = scan_service.trigger_scan( - config_file=sample_config_file, + config_id=sample_db_config.id, triggered_by='test', scheduler=app.scheduler ) @@ -43,12 +43,12 @@ class TestBackgroundJobs: assert job is not None assert job.id == f'scan_{scan_id}' - def test_trigger_scan_without_scheduler(self, db, sample_config_file): + def test_trigger_scan_without_scheduler(self, db, sample_db_config): """Test triggering scan without scheduler logs warning.""" # Create scan without scheduler scan_service = ScanService(db) scan_id = scan_service.trigger_scan( - config_file=sample_config_file, + config_id=sample_db_config.id, triggered_by='test', scheduler=None # No scheduler ) @@ -58,13 +58,13 @@ class TestBackgroundJobs: assert scan is not None assert scan.status == 'running' - def test_scheduler_service_queue_scan(self, app, db, sample_config_file): + def test_scheduler_service_queue_scan(self, app, db, sample_db_config): """Test SchedulerService.queue_scan directly.""" # Create scan record first scan = Scan( timestamp=datetime.utcnow(), status='running', - config_file=sample_config_file, + config_id=sample_db_config.id, title='Test Scan', triggered_by='test' ) @@ -72,27 +72,27 @@ class TestBackgroundJobs: db.commit() # Queue the scan - job_id = app.scheduler.queue_scan(scan.id, sample_config_file) + job_id = app.scheduler.queue_scan(scan.id, sample_db_config) # Verify job was queued assert job_id == f'scan_{scan.id}' job = app.scheduler.scheduler.get_job(job_id) assert job is not None - def test_scheduler_list_jobs(self, app, db, sample_config_file): + def test_scheduler_list_jobs(self, app, db, sample_db_config): """Test listing scheduled jobs.""" # Queue a few scans for i in range(3): scan = Scan( timestamp=datetime.utcnow(), status='running', - config_file=sample_config_file, + config_id=sample_db_config.id, title=f'Test Scan {i}', triggered_by='test' ) db.add(scan) db.commit() - app.scheduler.queue_scan(scan.id, sample_config_file) + app.scheduler.queue_scan(scan.id, sample_db_config) # List jobs jobs = app.scheduler.list_jobs() @@ -106,20 +106,20 @@ class TestBackgroundJobs: assert 'name' in job assert 'trigger' in job - def test_scheduler_get_job_status(self, app, db, sample_config_file): + def test_scheduler_get_job_status(self, app, db, sample_db_config): """Test getting status of a specific job.""" # Create and queue a scan scan = Scan( timestamp=datetime.utcnow(), status='running', - config_file=sample_config_file, + config_id=sample_db_config.id, title='Test Scan', triggered_by='test' ) db.add(scan) db.commit() - job_id = app.scheduler.queue_scan(scan.id, sample_config_file) + job_id = app.scheduler.queue_scan(scan.id, sample_db_config) # Get job status status = app.scheduler.get_job_status(job_id) @@ -133,13 +133,13 @@ class TestBackgroundJobs: status = app.scheduler.get_job_status('nonexistent_job_id') assert status is None - def test_scan_timing_fields(self, db, sample_config_file): + def test_scan_timing_fields(self, db, sample_db_config): """Test that scan timing fields are properly set.""" # Create scan with started_at scan = Scan( timestamp=datetime.utcnow(), status='running', - config_file=sample_config_file, + config_id=sample_db_config.id, title='Test Scan', triggered_by='test', started_at=datetime.utcnow() @@ -161,13 +161,13 @@ class TestBackgroundJobs: assert scan.completed_at is not None assert (scan.completed_at - scan.started_at).total_seconds() >= 0 - def test_scan_error_handling(self, db, sample_config_file): + def test_scan_error_handling(self, db, sample_db_config): """Test that error messages are stored correctly.""" # Create failed scan scan = Scan( timestamp=datetime.utcnow(), status='failed', - config_file=sample_config_file, + config_id=sample_db_config.id, title='Failed Scan', triggered_by='test', started_at=datetime.utcnow(), @@ -188,7 +188,7 @@ class TestBackgroundJobs: assert status['error_message'] == 'Test error message' @pytest.mark.skip(reason="Requires actual scanner execution - slow test") - def test_background_scan_execution(self, app, db, sample_config_file): + def test_background_scan_execution(self, app, db, sample_db_config): """ Integration test for actual background scan execution. @@ -200,7 +200,7 @@ class TestBackgroundJobs: # Trigger scan scan_service = ScanService(db) scan_id = scan_service.trigger_scan( - config_file=sample_config_file, + config_id=sample_db_config.id, triggered_by='test', scheduler=app.scheduler ) diff --git a/app/tests/test_scan_api.py b/app/tests/test_scan_api.py index 5cd98eb..648ed29 100644 --- a/app/tests/test_scan_api.py +++ b/app/tests/test_scan_api.py @@ -44,7 +44,7 @@ class TestScanAPIEndpoints: scan = Scan( timestamp=datetime.utcnow(), status='completed', - config_file=f'/app/configs/test{i}.yaml', + config_id=sample_db_config.id, title=f'Test Scan {i}', triggered_by='test' ) @@ -81,7 +81,7 @@ class TestScanAPIEndpoints: scan = Scan( timestamp=datetime.utcnow(), status=status, - config_file='/app/configs/test.yaml', + config_id=1, title=f'{status.capitalize()} Scan', triggered_by='test' ) @@ -123,10 +123,10 @@ class TestScanAPIEndpoints: assert 'error' in data assert data['error'] == 'Not found' - def test_trigger_scan_success(self, client, db, sample_config_file): + def test_trigger_scan_success(self, client, db, sample_db_config): """Test triggering a new scan.""" response = client.post('/api/scans', - json={'config_file': str(sample_config_file)}, + json={'config_file': str(sample_db_config)}, content_type='application/json' ) assert response.status_code == 201 @@ -222,7 +222,7 @@ class TestScanAPIEndpoints: assert 'error' in data assert 'message' in data - def test_scan_workflow_integration(self, client, db, sample_config_file): + def test_scan_workflow_integration(self, client, db, sample_db_config): """ Test complete scan workflow: trigger → status → retrieve → delete. @@ -231,7 +231,7 @@ class TestScanAPIEndpoints: """ # Step 1: Trigger scan response = client.post('/api/scans', - json={'config_file': str(sample_config_file)}, + json={'config_file': str(sample_db_config)}, content_type='application/json' ) assert response.status_code == 201 diff --git a/app/tests/test_scan_comparison.py b/app/tests/test_scan_comparison.py index 30e3f91..136a760 100644 --- a/app/tests/test_scan_comparison.py +++ b/app/tests/test_scan_comparison.py @@ -17,10 +17,10 @@ class TestScanComparison: """Tests for scan comparison methods.""" @pytest.fixture - def scan1_data(self, test_db, sample_config_file): + def scan1_data(self, test_db, sample_db_config): """Create first scan with test data.""" service = ScanService(test_db) - scan_id = service.trigger_scan(sample_config_file, triggered_by='manual') + scan_id = service.trigger_scan(sample_db_config, triggered_by='manual') # Get scan and add some test data scan = test_db.query(Scan).filter(Scan.id == scan_id).first() @@ -77,10 +77,10 @@ class TestScanComparison: return scan_id @pytest.fixture - def scan2_data(self, test_db, sample_config_file): + def scan2_data(self, test_db, sample_db_config): """Create second scan with modified test data.""" service = ScanService(test_db) - scan_id = service.trigger_scan(sample_config_file, triggered_by='manual') + scan_id = service.trigger_scan(sample_db_config, triggered_by='manual') # Get scan and add some test data scan = test_db.query(Scan).filter(Scan.id == scan_id).first() diff --git a/app/tests/test_scan_service.py b/app/tests/test_scan_service.py index 456c5bb..2936bb9 100644 --- a/app/tests/test_scan_service.py +++ b/app/tests/test_scan_service.py @@ -13,49 +13,42 @@ from web.services.scan_service import ScanService class TestScanServiceTrigger: """Tests for triggering scans.""" - def test_trigger_scan_valid_config(self, test_db, sample_config_file): - """Test triggering a scan with valid config file.""" - service = ScanService(test_db) + def test_trigger_scan_valid_config(self, db, sample_db_config): + """Test triggering a scan with valid config.""" + service = ScanService(db) - scan_id = service.trigger_scan(sample_config_file, triggered_by='manual') + scan_id = service.trigger_scan(config_id=sample_db_config.id, triggered_by='manual') # Verify scan created assert scan_id is not None assert isinstance(scan_id, int) # Verify scan in database - scan = test_db.query(Scan).filter(Scan.id == scan_id).first() + scan = db.query(Scan).filter(Scan.id == scan_id).first() assert scan is not None assert scan.status == 'running' assert scan.title == 'Test Scan' assert scan.triggered_by == 'manual' - assert scan.config_file == sample_config_file + assert scan.config_id == sample_db_config.id - def test_trigger_scan_invalid_config(self, test_db, sample_invalid_config_file): - """Test triggering a scan with invalid config file.""" - service = ScanService(test_db) + def test_trigger_scan_invalid_config(self, db): + """Test triggering a scan with invalid config ID.""" + service = ScanService(db) - with pytest.raises(ValueError, match="Invalid config file"): - service.trigger_scan(sample_invalid_config_file) + with pytest.raises(ValueError, match="not found"): + service.trigger_scan(config_id=99999) - def test_trigger_scan_nonexistent_file(self, test_db): - """Test triggering a scan with nonexistent config file.""" - service = ScanService(test_db) - - with pytest.raises(ValueError, match="does not exist"): - service.trigger_scan('/nonexistent/config.yaml') - - def test_trigger_scan_with_schedule(self, test_db, sample_config_file): + def test_trigger_scan_with_schedule(self, db, sample_db_config): """Test triggering a scan via schedule.""" - service = ScanService(test_db) + service = ScanService(db) scan_id = service.trigger_scan( - sample_config_file, + config_id=sample_db_config.id, triggered_by='scheduled', schedule_id=42 ) - scan = test_db.query(Scan).filter(Scan.id == scan_id).first() + scan = db.query(Scan).filter(Scan.id == scan_id).first() assert scan.triggered_by == 'scheduled' assert scan.schedule_id == 42 @@ -63,19 +56,19 @@ class TestScanServiceTrigger: class TestScanServiceGet: """Tests for retrieving scans.""" - def test_get_scan_not_found(self, test_db): + def test_get_scan_not_found(self, db): """Test getting a nonexistent scan.""" - service = ScanService(test_db) + service = ScanService(db) result = service.get_scan(999) assert result is None - def test_get_scan_found(self, test_db, sample_config_file): + def test_get_scan_found(self, db, sample_db_config): """Test getting an existing scan.""" - service = ScanService(test_db) + service = ScanService(db) # Create a scan - scan_id = service.trigger_scan(sample_config_file) + scan_id = service.trigger_scan(config_id=sample_db_config.id) # Retrieve it result = service.get_scan(scan_id) @@ -90,9 +83,9 @@ class TestScanServiceGet: class TestScanServiceList: """Tests for listing scans.""" - def test_list_scans_empty(self, test_db): + def test_list_scans_empty(self, db): """Test listing scans when database is empty.""" - service = ScanService(test_db) + service = ScanService(db) result = service.list_scans(page=1, per_page=20) @@ -100,13 +93,13 @@ class TestScanServiceList: assert len(result.items) == 0 assert result.pages == 0 - def test_list_scans_with_data(self, test_db, sample_config_file): + def test_list_scans_with_data(self, db, sample_db_config): """Test listing scans with multiple scans.""" - service = ScanService(test_db) + service = ScanService(db) # Create 3 scans for i in range(3): - service.trigger_scan(sample_config_file, triggered_by='api') + service.trigger_scan(config_id=sample_db_config.id, triggered_by='api') # List all scans result = service.list_scans(page=1, per_page=20) @@ -115,13 +108,13 @@ class TestScanServiceList: assert len(result.items) == 3 assert result.pages == 1 - def test_list_scans_pagination(self, test_db, sample_config_file): + def test_list_scans_pagination(self, db, sample_db_config): """Test pagination.""" - service = ScanService(test_db) + service = ScanService(db) # Create 5 scans for i in range(5): - service.trigger_scan(sample_config_file) + service.trigger_scan(config_id=sample_db_config.id) # Get page 1 (2 items per page) result = service.list_scans(page=1, per_page=2) @@ -141,18 +134,18 @@ class TestScanServiceList: assert len(result.items) == 1 assert result.has_next is False - def test_list_scans_filter_by_status(self, test_db, sample_config_file): + def test_list_scans_filter_by_status(self, db, sample_db_config): """Test filtering scans by status.""" - service = ScanService(test_db) + service = ScanService(db) # Create scans with different statuses - scan_id_1 = service.trigger_scan(sample_config_file) - scan_id_2 = service.trigger_scan(sample_config_file) + scan_id_1 = service.trigger_scan(config_id=sample_db_config.id) + scan_id_2 = service.trigger_scan(config_id=sample_db_config.id) # Mark one as completed - scan = test_db.query(Scan).filter(Scan.id == scan_id_1).first() + scan = db.query(Scan).filter(Scan.id == scan_id_1).first() scan.status = 'completed' - test_db.commit() + db.commit() # Filter by running result = service.list_scans(status_filter='running') @@ -162,9 +155,9 @@ class TestScanServiceList: result = service.list_scans(status_filter='completed') assert result.total == 1 - def test_list_scans_invalid_status_filter(self, test_db): + def test_list_scans_invalid_status_filter(self, db): """Test filtering with invalid status.""" - service = ScanService(test_db) + service = ScanService(db) with pytest.raises(ValueError, match="Invalid status"): service.list_scans(status_filter='invalid_status') @@ -173,46 +166,46 @@ class TestScanServiceList: class TestScanServiceDelete: """Tests for deleting scans.""" - def test_delete_scan_not_found(self, test_db): + def test_delete_scan_not_found(self, db): """Test deleting a nonexistent scan.""" - service = ScanService(test_db) + service = ScanService(db) with pytest.raises(ValueError, match="not found"): service.delete_scan(999) - def test_delete_scan_success(self, test_db, sample_config_file): + def test_delete_scan_success(self, db, sample_db_config): """Test successful scan deletion.""" - service = ScanService(test_db) + service = ScanService(db) # Create a scan - scan_id = service.trigger_scan(sample_config_file) + scan_id = service.trigger_scan(config_id=sample_db_config.id) # Verify it exists - assert test_db.query(Scan).filter(Scan.id == scan_id).first() is not None + assert db.query(Scan).filter(Scan.id == scan_id).first() is not None # Delete it result = service.delete_scan(scan_id) assert result is True # Verify it's gone - assert test_db.query(Scan).filter(Scan.id == scan_id).first() is None + assert db.query(Scan).filter(Scan.id == scan_id).first() is None class TestScanServiceStatus: """Tests for scan status retrieval.""" - def test_get_scan_status_not_found(self, test_db): + def test_get_scan_status_not_found(self, db): """Test getting status of nonexistent scan.""" - service = ScanService(test_db) + service = ScanService(db) result = service.get_scan_status(999) assert result is None - def test_get_scan_status_running(self, test_db, sample_config_file): + def test_get_scan_status_running(self, db, sample_db_config): """Test getting status of running scan.""" - service = ScanService(test_db) + service = ScanService(db) - scan_id = service.trigger_scan(sample_config_file) + scan_id = service.trigger_scan(config_id=sample_db_config.id) status = service.get_scan_status(scan_id) assert status is not None @@ -221,16 +214,16 @@ class TestScanServiceStatus: assert status['progress'] == 'In progress' assert status['title'] == 'Test Scan' - def test_get_scan_status_completed(self, test_db, sample_config_file): + def test_get_scan_status_completed(self, db, sample_db_config): """Test getting status of completed scan.""" - service = ScanService(test_db) + service = ScanService(db) # Create and mark as completed - scan_id = service.trigger_scan(sample_config_file) - scan = test_db.query(Scan).filter(Scan.id == scan_id).first() + scan_id = service.trigger_scan(config_id=sample_db_config.id) + scan = db.query(Scan).filter(Scan.id == scan_id).first() scan.status = 'completed' scan.duration = 125.5 - test_db.commit() + db.commit() status = service.get_scan_status(scan_id) @@ -242,35 +235,35 @@ class TestScanServiceStatus: class TestScanServiceDatabaseMapping: """Tests for mapping scan reports to database models.""" - def test_save_scan_to_db(self, test_db, sample_config_file, sample_scan_report): + def test_save_scan_to_db(self, db, sample_db_config, sample_scan_report): """Test saving a complete scan report to database.""" - service = ScanService(test_db) + service = ScanService(db) # Create a scan - scan_id = service.trigger_scan(sample_config_file) + scan_id = service.trigger_scan(config_id=sample_db_config.id) # Save report to database service._save_scan_to_db(sample_scan_report, scan_id, status='completed') # Verify scan updated - scan = test_db.query(Scan).filter(Scan.id == scan_id).first() + scan = db.query(Scan).filter(Scan.id == scan_id).first() assert scan.status == 'completed' assert scan.duration == 125.5 # Verify sites created - sites = test_db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all() + sites = db.query(ScanSite).filter(ScanSite.scan_id == scan_id).all() assert len(sites) == 1 assert sites[0].site_name == 'Test Site' # Verify IPs created - ips = test_db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all() + ips = db.query(ScanIP).filter(ScanIP.scan_id == scan_id).all() assert len(ips) == 1 assert ips[0].ip_address == '192.168.1.10' assert ips[0].ping_expected is True assert ips[0].ping_actual is True # Verify ports created (TCP: 22, 80, 443, 8080 | UDP: 53) - ports = test_db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all() + ports = db.query(ScanPort).filter(ScanPort.scan_id == scan_id).all() assert len(ports) == 5 # 4 TCP + 1 UDP # Verify TCP ports @@ -285,7 +278,7 @@ class TestScanServiceDatabaseMapping: assert udp_ports[0].port == 53 # Verify services created - services = test_db.query(ScanServiceModel).filter( + services = db.query(ScanServiceModel).filter( ScanServiceModel.scan_id == scan_id ).all() assert len(services) == 4 # SSH, HTTP (80), HTTPS, HTTP (8080) @@ -300,15 +293,15 @@ class TestScanServiceDatabaseMapping: assert https_service.http_protocol == 'https' assert https_service.screenshot_path == 'screenshots/192_168_1_10_443.png' - def test_map_port_expected_vs_actual(self, test_db, sample_config_file, sample_scan_report): + def test_map_port_expected_vs_actual(self, db, sample_db_config, sample_scan_report): """Test that expected vs actual ports are correctly flagged.""" - service = ScanService(test_db) + service = ScanService(db) - scan_id = service.trigger_scan(sample_config_file) + scan_id = service.trigger_scan(config_id=sample_db_config.id) service._save_scan_to_db(sample_scan_report, scan_id) # Check TCP ports - tcp_ports = test_db.query(ScanPort).filter( + tcp_ports = db.query(ScanPort).filter( ScanPort.scan_id == scan_id, ScanPort.protocol == 'tcp' ).all() @@ -322,15 +315,15 @@ class TestScanServiceDatabaseMapping: # Port 8080 was not expected assert port.expected is False, f"Port {port.port} should not be expected" - def test_map_certificate_and_tls(self, test_db, sample_config_file, sample_scan_report): + def test_map_certificate_and_tls(self, db, sample_db_config, sample_scan_report): """Test that certificate and TLS data are correctly mapped.""" - service = ScanService(test_db) + service = ScanService(db) - scan_id = service.trigger_scan(sample_config_file) + scan_id = service.trigger_scan(config_id=sample_db_config.id) service._save_scan_to_db(sample_scan_report, scan_id) # Find HTTPS service - https_service = test_db.query(ScanServiceModel).filter( + https_service = db.query(ScanServiceModel).filter( ScanServiceModel.scan_id == scan_id, ScanServiceModel.service_name == 'https' ).first() @@ -363,11 +356,11 @@ class TestScanServiceDatabaseMapping: assert tls_13 is not None assert tls_13.supported is True - def test_get_scan_with_full_details(self, test_db, sample_config_file, sample_scan_report): + def test_get_scan_with_full_details(self, db, sample_db_config, sample_scan_report): """Test retrieving scan with all nested relationships.""" - service = ScanService(test_db) + service = ScanService(db) - scan_id = service.trigger_scan(sample_config_file) + scan_id = service.trigger_scan(config_id=sample_db_config.id) service._save_scan_to_db(sample_scan_report, scan_id) # Get full scan details diff --git a/app/tests/test_schedule_api.py b/app/tests/test_schedule_api.py index 7601986..2212f7a 100644 --- a/app/tests/test_schedule_api.py +++ b/app/tests/test_schedule_api.py @@ -13,20 +13,20 @@ from web.models import Schedule, Scan @pytest.fixture -def sample_schedule(db, sample_config_file): +def sample_schedule(db, sample_db_config): """ Create a sample schedule in the database for testing. Args: db: Database session fixture - sample_config_file: Path to test config file + sample_db_config: Path to test config file Returns: Schedule model instance """ schedule = Schedule( name='Daily Test Scan', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True, last_run=None, @@ -68,13 +68,13 @@ class TestScheduleAPIEndpoints: assert data['schedules'][0]['name'] == sample_schedule.name assert data['schedules'][0]['cron_expression'] == sample_schedule.cron_expression - def test_list_schedules_pagination(self, client, db, sample_config_file): + def test_list_schedules_pagination(self, client, db, sample_db_config): """Test schedule list pagination.""" # Create 25 schedules for i in range(25): schedule = Schedule( name=f'Schedule {i}', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True, created_at=datetime.utcnow() @@ -101,13 +101,13 @@ class TestScheduleAPIEndpoints: assert len(data['schedules']) == 10 assert data['page'] == 2 - def test_list_schedules_filter_enabled(self, client, db, sample_config_file): + def test_list_schedules_filter_enabled(self, client, db, sample_db_config): """Test filtering schedules by enabled status.""" # Create enabled and disabled schedules for i in range(3): schedule = Schedule( name=f'Enabled Schedule {i}', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True, created_at=datetime.utcnow() @@ -117,7 +117,7 @@ class TestScheduleAPIEndpoints: for i in range(2): schedule = Schedule( name=f'Disabled Schedule {i}', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 3 * * *', enabled=False, created_at=datetime.utcnow() @@ -165,11 +165,11 @@ class TestScheduleAPIEndpoints: assert 'error' in data assert 'not found' in data['error'].lower() - def test_create_schedule(self, client, db, sample_config_file): + def test_create_schedule(self, client, db, sample_db_config): """Test creating a new schedule.""" schedule_data = { 'name': 'New Test Schedule', - 'config_file': sample_config_file, + 'config_file': sample_db_config, 'cron_expression': '0 3 * * *', 'enabled': True } @@ -211,11 +211,11 @@ class TestScheduleAPIEndpoints: assert 'error' in data assert 'missing' in data['error'].lower() - def test_create_schedule_invalid_cron(self, client, db, sample_config_file): + def test_create_schedule_invalid_cron(self, client, db, sample_db_config): """Test creating schedule with invalid cron expression.""" schedule_data = { 'name': 'Invalid Cron Schedule', - 'config_file': sample_config_file, + 'config_file': sample_db_config, 'cron_expression': 'invalid cron' } @@ -360,13 +360,13 @@ class TestScheduleAPIEndpoints: data = json.loads(response.data) assert 'error' in data - def test_delete_schedule_preserves_scans(self, client, db, sample_schedule, sample_config_file): + def test_delete_schedule_preserves_scans(self, client, db, sample_schedule, sample_db_config): """Test that deleting schedule preserves associated scans.""" # Create a scan associated with the schedule scan = Scan( timestamp=datetime.utcnow(), status='completed', - config_file=sample_config_file, + config_id=sample_db_config.id, title='Test Scan', triggered_by='scheduled', schedule_id=sample_schedule.id @@ -409,14 +409,14 @@ class TestScheduleAPIEndpoints: data = json.loads(response.data) assert 'error' in data - def test_get_schedule_with_history(self, client, db, sample_schedule, sample_config_file): + def test_get_schedule_with_history(self, client, db, sample_schedule, sample_db_config): """Test getting schedule includes execution history.""" # Create some scans for this schedule for i in range(5): scan = Scan( timestamp=datetime.utcnow(), status='completed', - config_file=sample_config_file, + config_id=sample_db_config.id, title=f'Scheduled Scan {i}', triggered_by='scheduled', schedule_id=sample_schedule.id @@ -431,12 +431,12 @@ class TestScheduleAPIEndpoints: assert 'history' in data assert len(data['history']) == 5 - def test_schedule_workflow_integration(self, client, db, sample_config_file): + def test_schedule_workflow_integration(self, client, db, sample_db_config): """Test complete schedule workflow: create → update → trigger → delete.""" # 1. Create schedule schedule_data = { 'name': 'Integration Test Schedule', - 'config_file': sample_config_file, + 'config_file': sample_db_config, 'cron_expression': '0 2 * * *', 'enabled': True } @@ -482,14 +482,14 @@ class TestScheduleAPIEndpoints: scan = db.query(Scan).filter(Scan.id == scan_id).first() assert scan is not None - def test_list_schedules_ordering(self, client, db, sample_config_file): + def test_list_schedules_ordering(self, client, db, sample_db_config): """Test that schedules are ordered by next_run time.""" # Create schedules with different next_run times schedules = [] for i in range(3): schedule = Schedule( name=f'Schedule {i}', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True, next_run=datetime(2025, 11, 15 + i, 2, 0, 0), @@ -501,7 +501,7 @@ class TestScheduleAPIEndpoints: # Create a disabled schedule (next_run is None) disabled_schedule = Schedule( name='Disabled Schedule', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 3 * * *', enabled=False, next_run=None, @@ -523,11 +523,11 @@ class TestScheduleAPIEndpoints: assert returned_schedules[2]['id'] == schedules[2].id assert returned_schedules[3]['id'] == disabled_schedule.id - def test_create_schedule_with_disabled(self, client, db, sample_config_file): + def test_create_schedule_with_disabled(self, client, db, sample_db_config): """Test creating a disabled schedule.""" schedule_data = { 'name': 'Disabled Schedule', - 'config_file': sample_config_file, + 'config_file': sample_db_config, 'cron_expression': '0 2 * * *', 'enabled': False } @@ -587,7 +587,7 @@ class TestScheduleAPIAuthentication: class TestScheduleAPICronValidation: """Test suite for cron expression validation.""" - def test_valid_cron_expressions(self, client, db, sample_config_file): + def test_valid_cron_expressions(self, client, db, sample_db_config): """Test various valid cron expressions.""" valid_expressions = [ '0 2 * * *', # Daily at 2am @@ -600,7 +600,7 @@ class TestScheduleAPICronValidation: for cron_expr in valid_expressions: schedule_data = { 'name': f'Schedule for {cron_expr}', - 'config_file': sample_config_file, + 'config_file': sample_db_config, 'cron_expression': cron_expr } @@ -612,7 +612,7 @@ class TestScheduleAPICronValidation: assert response.status_code == 201, \ f"Valid cron expression '{cron_expr}' should be accepted" - def test_invalid_cron_expressions(self, client, db, sample_config_file): + def test_invalid_cron_expressions(self, client, db, sample_db_config): """Test various invalid cron expressions.""" invalid_expressions = [ 'invalid', @@ -626,7 +626,7 @@ class TestScheduleAPICronValidation: for cron_expr in invalid_expressions: schedule_data = { 'name': f'Schedule for {cron_expr}', - 'config_file': sample_config_file, + 'config_file': sample_db_config, 'cron_expression': cron_expr } diff --git a/app/tests/test_schedule_service.py b/app/tests/test_schedule_service.py index 4e4741d..308e35e 100644 --- a/app/tests/test_schedule_service.py +++ b/app/tests/test_schedule_service.py @@ -15,13 +15,13 @@ from web.services.schedule_service import ScheduleService class TestScheduleServiceCreate: """Tests for creating schedules.""" - def test_create_schedule_valid(self, test_db, sample_config_file): + def test_create_schedule_valid(self, db, sample_db_config): """Test creating a schedule with valid parameters.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Daily Scan', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -31,57 +31,57 @@ class TestScheduleServiceCreate: assert isinstance(schedule_id, int) # Verify schedule in database - schedule = test_db.query(Schedule).filter(Schedule.id == schedule_id).first() + schedule = db.query(Schedule).filter(Schedule.id == schedule_id).first() assert schedule is not None assert schedule.name == 'Daily Scan' - assert schedule.config_file == sample_config_file + assert schedule.config_id == sample_db_config.id assert schedule.cron_expression == '0 2 * * *' assert schedule.enabled is True assert schedule.next_run is not None assert schedule.last_run is None - def test_create_schedule_disabled(self, test_db, sample_config_file): + def test_create_schedule_disabled(self, db, sample_db_config): """Test creating a disabled schedule.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Disabled Scan', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 3 * * *', enabled=False ) - schedule = test_db.query(Schedule).filter(Schedule.id == schedule_id).first() + schedule = db.query(Schedule).filter(Schedule.id == schedule_id).first() assert schedule.enabled is False assert schedule.next_run is None - def test_create_schedule_invalid_cron(self, test_db, sample_config_file): + def test_create_schedule_invalid_cron(self, db, sample_db_config): """Test creating a schedule with invalid cron expression.""" - service = ScheduleService(test_db) + service = ScheduleService(db) with pytest.raises(ValueError, match="Invalid cron expression"): service.create_schedule( name='Invalid Schedule', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='invalid cron', enabled=True ) - def test_create_schedule_nonexistent_config(self, test_db): - """Test creating a schedule with nonexistent config file.""" - service = ScheduleService(test_db) + def test_create_schedule_nonexistent_config(self, db): + """Test creating a schedule with nonexistent config.""" + service = ScheduleService(db) - with pytest.raises(ValueError, match="Config file not found"): + with pytest.raises(ValueError, match="not found"): service.create_schedule( name='Bad Config', - config_file='/nonexistent/config.yaml', + config_id=99999, cron_expression='0 2 * * *', enabled=True ) - def test_create_schedule_various_cron_expressions(self, test_db, sample_config_file): + def test_create_schedule_various_cron_expressions(self, db, sample_db_config): """Test creating schedules with various valid cron expressions.""" - service = ScheduleService(test_db) + service = ScheduleService(db) cron_expressions = [ '0 0 * * *', # Daily at midnight @@ -94,7 +94,7 @@ class TestScheduleServiceCreate: for i, cron in enumerate(cron_expressions): schedule_id = service.create_schedule( name=f'Schedule {i}', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression=cron, enabled=True ) @@ -104,21 +104,21 @@ class TestScheduleServiceCreate: class TestScheduleServiceGet: """Tests for retrieving schedules.""" - def test_get_schedule_not_found(self, test_db): + def test_get_schedule_not_found(self, db): """Test getting a nonexistent schedule.""" - service = ScheduleService(test_db) + service = ScheduleService(db) with pytest.raises(ValueError, match="Schedule .* not found"): service.get_schedule(999) - def test_get_schedule_found(self, test_db, sample_config_file): + def test_get_schedule_found(self, db, sample_db_config): """Test getting an existing schedule.""" - service = ScheduleService(test_db) + service = ScheduleService(db) # Create a schedule schedule_id = service.create_schedule( name='Test Schedule', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -134,14 +134,14 @@ class TestScheduleServiceGet: assert 'history' in result assert isinstance(result['history'], list) - def test_get_schedule_with_history(self, test_db, sample_config_file): + def test_get_schedule_with_history(self, db, sample_db_config): """Test getting schedule includes execution history.""" - service = ScheduleService(test_db) + service = ScheduleService(db) # Create schedule schedule_id = service.create_schedule( name='Test Schedule', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -151,13 +151,13 @@ class TestScheduleServiceGet: scan = Scan( timestamp=datetime.utcnow() - timedelta(days=i), status='completed', - config_file=sample_config_file, + config_id=sample_db_config.id, title=f'Scan {i}', triggered_by='scheduled', schedule_id=schedule_id ) - test_db.add(scan) - test_db.commit() + db.add(scan) + db.commit() # Get schedule result = service.get_schedule(schedule_id) @@ -169,9 +169,9 @@ class TestScheduleServiceGet: class TestScheduleServiceList: """Tests for listing schedules.""" - def test_list_schedules_empty(self, test_db): + def test_list_schedules_empty(self, db): """Test listing schedules when database is empty.""" - service = ScheduleService(test_db) + service = ScheduleService(db) result = service.list_schedules(page=1, per_page=20) @@ -180,15 +180,15 @@ class TestScheduleServiceList: assert result['page'] == 1 assert result['per_page'] == 20 - def test_list_schedules_populated(self, test_db, sample_config_file): + def test_list_schedules_populated(self, db, sample_db_config): """Test listing schedules with data.""" - service = ScheduleService(test_db) + service = ScheduleService(db) # Create multiple schedules for i in range(5): service.create_schedule( name=f'Schedule {i}', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -199,15 +199,15 @@ class TestScheduleServiceList: assert len(result['schedules']) == 5 assert all('name' in s for s in result['schedules']) - def test_list_schedules_pagination(self, test_db, sample_config_file): + def test_list_schedules_pagination(self, db, sample_db_config): """Test schedule pagination.""" - service = ScheduleService(test_db) + service = ScheduleService(db) # Create 25 schedules for i in range(25): service.create_schedule( name=f'Schedule {i:02d}', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -226,22 +226,22 @@ class TestScheduleServiceList: result_page3 = service.list_schedules(page=3, per_page=10) assert len(result_page3['schedules']) == 5 - def test_list_schedules_filter_enabled(self, test_db, sample_config_file): + def test_list_schedules_filter_enabled(self, db, sample_db_config): """Test filtering schedules by enabled status.""" - service = ScheduleService(test_db) + service = ScheduleService(db) # Create enabled and disabled schedules for i in range(3): service.create_schedule( name=f'Enabled {i}', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) for i in range(2): service.create_schedule( name=f'Disabled {i}', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=False ) @@ -262,13 +262,13 @@ class TestScheduleServiceList: class TestScheduleServiceUpdate: """Tests for updating schedules.""" - def test_update_schedule_name(self, test_db, sample_config_file): + def test_update_schedule_name(self, db, sample_db_config): """Test updating schedule name.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Old Name', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -278,13 +278,13 @@ class TestScheduleServiceUpdate: assert result['name'] == 'New Name' assert result['cron_expression'] == '0 2 * * *' - def test_update_schedule_cron(self, test_db, sample_config_file): + def test_update_schedule_cron(self, db, sample_db_config): """Test updating cron expression recalculates next_run.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -302,13 +302,13 @@ class TestScheduleServiceUpdate: assert result['cron_expression'] == '0 3 * * *' assert result['next_run'] != original_next_run - def test_update_schedule_invalid_cron(self, test_db, sample_config_file): + def test_update_schedule_invalid_cron(self, db, sample_db_config): """Test updating with invalid cron expression fails.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -316,67 +316,67 @@ class TestScheduleServiceUpdate: with pytest.raises(ValueError, match="Invalid cron expression"): service.update_schedule(schedule_id, cron_expression='invalid') - def test_update_schedule_not_found(self, test_db): + def test_update_schedule_not_found(self, db): """Test updating nonexistent schedule fails.""" - service = ScheduleService(test_db) + service = ScheduleService(db) with pytest.raises(ValueError, match="Schedule .* not found"): service.update_schedule(999, name='New Name') - def test_update_schedule_invalid_config_file(self, test_db, sample_config_file): - """Test updating with nonexistent config file fails.""" - service = ScheduleService(test_db) + def test_update_schedule_invalid_config_id(self, db, sample_db_config): + """Test updating with nonexistent config ID fails.""" + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) - with pytest.raises(ValueError, match="Config file not found"): - service.update_schedule(schedule_id, config_file='/nonexistent.yaml') + with pytest.raises(ValueError, match="not found"): + service.update_schedule(schedule_id, config_id=99999) class TestScheduleServiceDelete: """Tests for deleting schedules.""" - def test_delete_schedule(self, test_db, sample_config_file): + def test_delete_schedule(self, db, sample_db_config): """Test deleting a schedule.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='To Delete', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) # Verify exists - assert test_db.query(Schedule).filter(Schedule.id == schedule_id).first() is not None + assert db.query(Schedule).filter(Schedule.id == schedule_id).first() is not None # Delete result = service.delete_schedule(schedule_id) assert result is True # Verify deleted - assert test_db.query(Schedule).filter(Schedule.id == schedule_id).first() is None + assert db.query(Schedule).filter(Schedule.id == schedule_id).first() is None - def test_delete_schedule_not_found(self, test_db): + def test_delete_schedule_not_found(self, db): """Test deleting nonexistent schedule fails.""" - service = ScheduleService(test_db) + service = ScheduleService(db) with pytest.raises(ValueError, match="Schedule .* not found"): service.delete_schedule(999) - def test_delete_schedule_preserves_scans(self, test_db, sample_config_file): + def test_delete_schedule_preserves_scans(self, db, sample_db_config): """Test that deleting schedule preserves associated scans.""" - service = ScheduleService(test_db) + service = ScheduleService(db) # Create schedule schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -385,20 +385,20 @@ class TestScheduleServiceDelete: scan = Scan( timestamp=datetime.utcnow(), status='completed', - config_file=sample_config_file, + config_id=sample_db_config.id, title='Test Scan', triggered_by='scheduled', schedule_id=schedule_id ) - test_db.add(scan) - test_db.commit() + db.add(scan) + db.commit() scan_id = scan.id # Delete schedule service.delete_schedule(schedule_id) # Verify scan still exists (schedule_id becomes null) - remaining_scan = test_db.query(Scan).filter(Scan.id == scan_id).first() + remaining_scan = db.query(Scan).filter(Scan.id == scan_id).first() assert remaining_scan is not None assert remaining_scan.schedule_id is None @@ -406,13 +406,13 @@ class TestScheduleServiceDelete: class TestScheduleServiceToggle: """Tests for toggling schedule enabled status.""" - def test_toggle_enabled_to_disabled(self, test_db, sample_config_file): + def test_toggle_enabled_to_disabled(self, db, sample_db_config): """Test disabling an enabled schedule.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -422,13 +422,13 @@ class TestScheduleServiceToggle: assert result['enabled'] is False assert result['next_run'] is None - def test_toggle_disabled_to_enabled(self, test_db, sample_config_file): + def test_toggle_disabled_to_enabled(self, db, sample_db_config): """Test enabling a disabled schedule.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=False ) @@ -442,13 +442,13 @@ class TestScheduleServiceToggle: class TestScheduleServiceRunTimes: """Tests for updating run times.""" - def test_update_run_times(self, test_db, sample_config_file): + def test_update_run_times(self, db, sample_db_config): """Test updating last_run and next_run.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -463,9 +463,9 @@ class TestScheduleServiceRunTimes: assert schedule['last_run'] is not None assert schedule['next_run'] is not None - def test_update_run_times_not_found(self, test_db): + def test_update_run_times_not_found(self, db): """Test updating run times for nonexistent schedule.""" - service = ScheduleService(test_db) + service = ScheduleService(db) with pytest.raises(ValueError, match="Schedule .* not found"): service.update_run_times( @@ -478,9 +478,9 @@ class TestScheduleServiceRunTimes: class TestCronValidation: """Tests for cron expression validation.""" - def test_validate_cron_valid_expressions(self, test_db): + def test_validate_cron_valid_expressions(self, db): """Test validating various valid cron expressions.""" - service = ScheduleService(test_db) + service = ScheduleService(db) valid_expressions = [ '0 0 * * *', # Daily at midnight @@ -496,9 +496,9 @@ class TestCronValidation: assert is_valid is True, f"Expression '{expr}' should be valid" assert error is None - def test_validate_cron_invalid_expressions(self, test_db): + def test_validate_cron_invalid_expressions(self, db): """Test validating invalid cron expressions.""" - service = ScheduleService(test_db) + service = ScheduleService(db) invalid_expressions = [ 'invalid', @@ -518,9 +518,9 @@ class TestCronValidation: class TestNextRunCalculation: """Tests for next run time calculation.""" - def test_calculate_next_run(self, test_db): + def test_calculate_next_run(self, db): """Test calculating next run time.""" - service = ScheduleService(test_db) + service = ScheduleService(db) # Daily at 2 AM next_run = service.calculate_next_run('0 2 * * *') @@ -529,9 +529,9 @@ class TestNextRunCalculation: assert isinstance(next_run, datetime) assert next_run > datetime.utcnow() - def test_calculate_next_run_from_time(self, test_db): + def test_calculate_next_run_from_time(self, db): """Test calculating next run from specific time.""" - service = ScheduleService(test_db) + service = ScheduleService(db) base_time = datetime(2025, 1, 1, 0, 0, 0) next_run = service.calculate_next_run('0 2 * * *', from_time=base_time) @@ -540,9 +540,9 @@ class TestNextRunCalculation: assert next_run.hour == 2 assert next_run.minute == 0 - def test_calculate_next_run_invalid_cron(self, test_db): + def test_calculate_next_run_invalid_cron(self, db): """Test calculating next run with invalid cron raises error.""" - service = ScheduleService(test_db) + service = ScheduleService(db) with pytest.raises(ValueError, match="Invalid cron expression"): service.calculate_next_run('invalid cron') @@ -551,13 +551,13 @@ class TestNextRunCalculation: class TestScheduleHistory: """Tests for schedule execution history.""" - def test_get_schedule_history_empty(self, test_db, sample_config_file): + def test_get_schedule_history_empty(self, db, sample_db_config): """Test getting history for schedule with no executions.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -565,13 +565,13 @@ class TestScheduleHistory: history = service.get_schedule_history(schedule_id) assert len(history) == 0 - def test_get_schedule_history_with_scans(self, test_db, sample_config_file): + def test_get_schedule_history_with_scans(self, db, sample_db_config): """Test getting history with multiple scans.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -581,26 +581,26 @@ class TestScheduleHistory: scan = Scan( timestamp=datetime.utcnow() - timedelta(days=i), status='completed', - config_file=sample_config_file, + config_id=sample_db_config.id, title=f'Scan {i}', triggered_by='scheduled', schedule_id=schedule_id ) - test_db.add(scan) - test_db.commit() + db.add(scan) + db.commit() # Get history (default limit 10) history = service.get_schedule_history(schedule_id, limit=10) assert len(history) == 10 assert history[0]['title'] == 'Scan 0' # Most recent first - def test_get_schedule_history_custom_limit(self, test_db, sample_config_file): + def test_get_schedule_history_custom_limit(self, db, sample_db_config): """Test getting history with custom limit.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -610,13 +610,13 @@ class TestScheduleHistory: scan = Scan( timestamp=datetime.utcnow() - timedelta(days=i), status='completed', - config_file=sample_config_file, + config_id=sample_db_config.id, title=f'Scan {i}', triggered_by='scheduled', schedule_id=schedule_id ) - test_db.add(scan) - test_db.commit() + db.add(scan) + db.commit() # Get only 5 history = service.get_schedule_history(schedule_id, limit=5) @@ -626,13 +626,13 @@ class TestScheduleHistory: class TestScheduleSerialization: """Tests for schedule serialization.""" - def test_schedule_to_dict(self, test_db, sample_config_file): + def test_schedule_to_dict(self, db, sample_db_config): """Test converting schedule to dictionary.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test Schedule', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) @@ -642,7 +642,7 @@ class TestScheduleSerialization: # Verify all required fields assert 'id' in result assert 'name' in result - assert 'config_file' in result + assert 'config_id' in result assert 'cron_expression' in result assert 'enabled' in result assert 'last_run' in result @@ -652,13 +652,13 @@ class TestScheduleSerialization: assert 'updated_at' in result assert 'history' in result - def test_schedule_relative_time_formatting(self, test_db, sample_config_file): + def test_schedule_relative_time_formatting(self, db, sample_db_config): """Test relative time formatting in schedule dict.""" - service = ScheduleService(test_db) + service = ScheduleService(db) schedule_id = service.create_schedule( name='Test', - config_file=sample_config_file, + config_id=sample_db_config.id, cron_expression='0 2 * * *', enabled=True ) diff --git a/app/tests/test_stats_api.py b/app/tests/test_stats_api.py index bd72631..2d94608 100644 --- a/app/tests/test_stats_api.py +++ b/app/tests/test_stats_api.py @@ -20,7 +20,7 @@ class TestStatsAPI: scan_date = today - timedelta(days=i) for j in range(i + 1): # Create 1, 2, 3, 4, 5 scans per day scan = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=scan_date, status='completed', duration=10.5 @@ -56,7 +56,7 @@ class TestStatsAPI: today = datetime.utcnow() for i in range(10): scan = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=today - timedelta(days=i), status='completed', duration=10.5 @@ -105,7 +105,7 @@ class TestStatsAPI: # Create scan 5 days ago scan1 = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=today - timedelta(days=5), status='completed', duration=10.5 @@ -114,7 +114,7 @@ class TestStatsAPI: # Create scan 10 days ago scan2 = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=today - timedelta(days=10), status='completed', duration=10.5 @@ -148,7 +148,7 @@ class TestStatsAPI: # 5 completed scans for i in range(5): scan = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=today - timedelta(days=i), status='completed', duration=10.5 @@ -158,7 +158,7 @@ class TestStatsAPI: # 2 failed scans for i in range(2): scan = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=today - timedelta(days=i), status='failed', duration=5.0 @@ -167,7 +167,7 @@ class TestStatsAPI: # 1 running scan scan = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=today, status='running', duration=None @@ -217,7 +217,7 @@ class TestStatsAPI: # Create 3 scans today for i in range(3): scan = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=today, status='completed', duration=10.5 @@ -227,7 +227,7 @@ class TestStatsAPI: # Create 2 scans yesterday for i in range(2): scan = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=yesterday, status='completed', duration=10.5 @@ -250,7 +250,7 @@ class TestStatsAPI: # Create scans over the last 10 days for i in range(10): scan = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=today - timedelta(days=i), status='completed', duration=10.5 @@ -275,7 +275,7 @@ class TestStatsAPI: """Test scan trend returns dates in correct format.""" # Create a scan scan = Scan( - config_file='/app/configs/test.yaml', + config_id=1, timestamp=datetime.utcnow(), status='completed', duration=10.5 diff --git a/app/web/api/scans.py b/app/web/api/scans.py index b2095d8..f1d0edf 100644 --- a/app/web/api/scans.py +++ b/app/web/api/scans.py @@ -11,7 +11,6 @@ from sqlalchemy.exc import SQLAlchemyError from web.auth.decorators import api_auth_required from web.services.scan_service import ScanService -from web.utils.validators import validate_config_file from web.utils.pagination import validate_page_params bp = Blueprint('scans', __name__) diff --git a/app/web/api/schedules.py b/app/web/api/schedules.py index a21cdd8..acf92a6 100644 --- a/app/web/api/schedules.py +++ b/app/web/api/schedules.py @@ -88,7 +88,7 @@ def create_schedule(): Request body: name: Schedule name (required) - config_file: Path to YAML config (required) + config_id: Database config ID (required) cron_expression: Cron expression (required, e.g., '0 2 * * *') enabled: Whether schedule is active (optional, default: true) @@ -99,7 +99,7 @@ def create_schedule(): data = request.get_json() or {} # Validate required fields - required = ['name', 'config_file', 'cron_expression'] + required = ['name', 'config_id', 'cron_expression'] missing = [field for field in required if field not in data] if missing: return jsonify({'error': f'Missing required fields: {", ".join(missing)}'}), 400 @@ -108,7 +108,7 @@ def create_schedule(): schedule_service = ScheduleService(current_app.db_session) schedule_id = schedule_service.create_schedule( name=data['name'], - config_file=data['config_file'], + config_id=data['config_id'], cron_expression=data['cron_expression'], enabled=data.get('enabled', True) ) @@ -121,7 +121,7 @@ def create_schedule(): try: current_app.scheduler.add_scheduled_scan( schedule_id=schedule_id, - config_file=schedule['config_file'], + config_id=schedule['config_id'], cron_expression=schedule['cron_expression'] ) logger.info(f"Schedule {schedule_id} added to APScheduler") @@ -154,7 +154,7 @@ def update_schedule(schedule_id): Request body: name: Schedule name (optional) - config_file: Path to YAML config (optional) + config_id: Database config ID (optional) cron_expression: Cron expression (optional) enabled: Whether schedule is active (optional) @@ -181,7 +181,7 @@ def update_schedule(schedule_id): try: # If cron expression or config changed, or enabled status changed cron_changed = 'cron_expression' in data - config_changed = 'config_file' in data + config_changed = 'config_id' in data enabled_changed = 'enabled' in data if enabled_changed: @@ -189,7 +189,7 @@ def update_schedule(schedule_id): # Re-add to scheduler (replaces existing) current_app.scheduler.add_scheduled_scan( schedule_id=schedule_id, - config_file=updated_schedule['config_file'], + config_id=updated_schedule['config_id'], cron_expression=updated_schedule['cron_expression'] ) logger.info(f"Schedule {schedule_id} enabled and added to APScheduler") @@ -201,7 +201,7 @@ def update_schedule(schedule_id): # Reload schedule in APScheduler current_app.scheduler.add_scheduled_scan( schedule_id=schedule_id, - config_file=updated_schedule['config_file'], + config_id=updated_schedule['config_id'], cron_expression=updated_schedule['cron_expression'] ) logger.info(f"Schedule {schedule_id} reloaded in APScheduler") @@ -293,7 +293,7 @@ def trigger_schedule(schedule_id): scheduler = current_app.scheduler if hasattr(current_app, 'scheduler') else None scan_id = scan_service.trigger_scan( - config_file=schedule['config_file'], + config_id=schedule['config_id'], triggered_by='manual', schedule_id=schedule_id, scheduler=scheduler diff --git a/app/web/api/stats.py b/app/web/api/stats.py index b6e1088..eac1627 100644 --- a/app/web/api/stats.py +++ b/app/web/api/stats.py @@ -198,12 +198,12 @@ def scan_history(scan_id): if not reference_scan: return jsonify({'error': 'Scan not found'}), 404 - config_file = reference_scan.config_file + config_id = reference_scan.config_id - # Query historical scans with the same config file + # Query historical scans with the same config_id historical_scans = ( db_session.query(Scan) - .filter(Scan.config_file == config_file) + .filter(Scan.config_id == config_id) .filter(Scan.status == 'completed') .order_by(Scan.timestamp.desc()) .limit(limit) @@ -247,7 +247,7 @@ def scan_history(scan_id): 'scans': scans_data, 'labels': labels, 'port_counts': port_counts, - 'config_file': config_file + 'config_id': config_id }), 200 except SQLAlchemyError as e: diff --git a/app/web/jobs/scan_job.py b/app/web/jobs/scan_job.py index e1575a4..ace3f8c 100644 --- a/app/web/jobs/scan_job.py +++ b/app/web/jobs/scan_job.py @@ -21,7 +21,7 @@ from web.services.alert_service import AlertService logger = logging.getLogger(__name__) -def execute_scan(scan_id: int, config_file: str = None, config_id: int = None, db_url: str = None): +def execute_scan(scan_id: int, config_id: int, db_url: str = None): """ Execute a scan in the background. @@ -31,12 +31,9 @@ def execute_scan(scan_id: int, config_file: str = None, config_id: int = None, d Args: scan_id: ID of the scan record in database - config_file: Path to YAML configuration file (legacy, optional) - config_id: Database config ID (preferred, optional) + config_id: Database config ID db_url: Database connection URL - Note: Provide exactly one of config_file or config_id - Workflow: 1. Create new database session for this thread 2. Update scan status to 'running' @@ -45,8 +42,7 @@ def execute_scan(scan_id: int, config_file: str = None, config_id: int = None, d 5. Save results to database 6. Update status to 'completed' or 'failed' """ - config_desc = f"config_id={config_id}" if config_id else f"config_file={config_file}" - logger.info(f"Starting background scan execution: scan_id={scan_id}, {config_desc}") + logger.info(f"Starting background scan execution: scan_id={scan_id}, config_id={config_id}") # Create new database session for this thread engine = create_engine(db_url, echo=False) @@ -65,21 +61,10 @@ def execute_scan(scan_id: int, config_file: str = None, config_id: int = None, d scan.started_at = datetime.utcnow() session.commit() - logger.info(f"Scan {scan_id}: Initializing scanner with {config_desc}") + logger.info(f"Scan {scan_id}: Initializing scanner with config_id={config_id}") - # Initialize scanner based on config type - if config_id: - # Use database config - scanner = SneakyScanner(config_id=config_id) - else: - # Use YAML config file - # Convert config_file to full path if it's just a filename - if not config_file.startswith('/'): - config_path = f'/app/configs/{config_file}' - else: - config_path = config_file - - scanner = SneakyScanner(config_path=config_path) + # Initialize scanner with database config + scanner = SneakyScanner(config_id=config_id) # Execute scan logger.info(f"Scan {scan_id}: Running scanner...") diff --git a/app/web/models.py b/app/web/models.py index 6bd4d99..caa543a 100644 --- a/app/web/models.py +++ b/app/web/models.py @@ -46,7 +46,6 @@ class Scan(Base): timestamp = Column(DateTime, nullable=False, index=True, comment="Scan start time (UTC)") duration = Column(Float, nullable=True, comment="Total scan duration in seconds") status = Column(String(20), nullable=False, default='running', comment="running, completed, failed") - config_file = Column(Text, nullable=True, comment="Path to YAML config used (deprecated)") config_id = Column(Integer, ForeignKey('scan_configs.id'), nullable=True, index=True, comment="FK to scan_configs table") title = Column(Text, nullable=True, comment="Scan title from config") json_path = Column(Text, nullable=True, comment="Path to JSON report") @@ -403,7 +402,6 @@ class Schedule(Base): id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String(255), nullable=False, comment="Schedule name (e.g., 'Daily prod scan')") - config_file = Column(Text, nullable=True, comment="Path to YAML config (deprecated)") config_id = Column(Integer, ForeignKey('scan_configs.id'), nullable=True, index=True, comment="FK to scan_configs table") cron_expression = Column(String(100), nullable=False, comment="Cron-like schedule (e.g., '0 2 * * *')") enabled = Column(Boolean, nullable=False, default=True, comment="Is schedule active?") diff --git a/app/web/routes/main.py b/app/web/routes/main.py index 99277d5..6a55880 100644 --- a/app/web/routes/main.py +++ b/app/web/routes/main.py @@ -101,22 +101,19 @@ def create_schedule(): Create new schedule form page. Returns: - Rendered schedule create template with available config files + Rendered schedule create template with available configs """ - import os + from web.models import ScanConfig - # Get list of available config files - configs_dir = '/app/configs' - config_files = [] + # Get list of available configs from database + configs = [] try: - if os.path.exists(configs_dir): - config_files = [f for f in os.listdir(configs_dir) if f.endswith('.yaml')] - config_files.sort() + configs = current_app.db_session.query(ScanConfig).order_by(ScanConfig.title).all() except Exception as e: - logger.error(f"Error listing config files: {e}") + logger.error(f"Error listing configs: {e}") - return render_template('schedule_create.html', config_files=config_files) + return render_template('schedule_create.html', configs=configs) @bp.route('/schedules/