"""
Alert Service Module

Handles alert evaluation, rule processing, and notification triggering
for SneakyScan Phase 5.
"""
import logging
from datetime import datetime, timezone
from typing import List, Dict, Optional, Any

from sqlalchemy.orm import Session

from ..models import (
    Alert, AlertRule, Scan, ScanPort, ScanIP, ScanService as ScanServiceModel,
    ScanCertificate, ScanTLSVersion
)
from .scan_service import ScanService

logger = logging.getLogger(__name__)


class AlertService:
    """
    Service for evaluating alert rules and generating alerts based on scan results.

    Rule types dispatched by ``evaluate_alert_rules``:
        - ``unexpected_port``: open ports flagged as not expected
        - ``drift_detection``: differences from the previous completed scan
        - ``cert_expiry``: certificates expiring within a threshold
        - ``weak_tls``: services that still support TLS 1.0 / 1.1
        - ``ping_failed``: hosts expected to answer ping that did not
    """

    # Ports whose unexpected exposure is escalated to 'critical' when the rule
    # does not specify its own severity.
    HIGH_RISK_PORTS = frozenset({
        22,     # SSH
        23,     # Telnet
        135,    # Windows RPC
        139,    # NetBIOS
        445,    # SMB
        1433,   # SQL Server
        3306,   # MySQL
        3389,   # RDP
        5432,   # PostgreSQL
        5900,   # VNC
        6379,   # Redis
        9200,   # Elasticsearch
        27017,  # MongoDB
    })

    # TLS protocol versions reported by the weak_tls rule.
    WEAK_TLS_VERSIONS = ('TLS 1.0', 'TLS 1.1')

    # Display order for severities (lower sorts first). Used instead of a SQL
    # string sort, which would order alphabetically and put 'critical' last.
    SEVERITY_ORDER = {'critical': 0, 'warning': 1, 'info': 2}

    def __init__(self, db_session: Session):
        """
        Args:
            db_session: SQLAlchemy session used for all queries and writes
        """
        self.db = db_session
        self.scan_service = ScanService(db_session)

    def evaluate_alert_rules(self, scan_id: int) -> List[Alert]:
        """
        Main entry point for alert evaluation after scan completion.

        Every enabled rule is evaluated independently; an exception raised
        while processing one rule is logged and does not stop the others.

        Args:
            scan_id: ID of the completed scan to evaluate

        Returns:
            List of Alert objects that were created
        """
        logger.info(f"Starting alert evaluation for scan {scan_id}")

        scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            logger.error(f"Scan {scan_id} not found")
            return []

        rules = self.db.query(AlertRule).filter(AlertRule.enabled == True).all()  # noqa: E712
        logger.info(f"Found {len(rules)} enabled alert rules to evaluate")

        # rule_type -> checker; each checker returns a list of alert dicts.
        checkers = {
            'unexpected_port': self.check_unexpected_ports,
            'drift_detection': self.check_drift_from_previous,
            'cert_expiry': self.check_certificate_expiry,
            'weak_tls': self.check_weak_tls,
            'ping_failed': self.check_ping_failures,
        }

        alerts_created = []
        for rule in rules:
            try:
                # A rule bound to a config file only applies to scans of that config.
                if rule.config_file and scan.config_file != rule.config_file:
                    logger.debug(f"Skipping rule {rule.id} - config mismatch")
                    continue

                checker = checkers.get(rule.rule_type)
                if checker is None:
                    logger.warning(f"Unknown rule type: {rule.rule_type}")
                    continue

                alert_data = checker(scan, rule)

                for alert_info in alert_data:
                    alert = self.create_alert(scan_id, rule, alert_info)
                    if not alert:
                        continue
                    alerts_created.append(alert)

                    if rule.email_enabled or rule.webhook_enabled:
                        self.trigger_notifications(alert, rule)

                logger.info(f"Rule {rule.name or rule.id} generated {len(alert_data)} alerts")

            except Exception as e:
                # Best effort: one failing rule must not block the remaining rules.
                logger.error(f"Error evaluating rule {rule.id}: {str(e)}", exc_info=True)
                continue

        logger.info(f"Alert evaluation complete. Created {len(alerts_created)} alerts")
        return alerts_created

    def check_unexpected_ports(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Detect open ports that are flagged as not expected for this scan.

        Severity is taken from the rule when set; otherwise it is 'critical'
        for ports in ``HIGH_RISK_PORTS`` and 'warning' for all others.

        Args:
            scan: The scan to check
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
        """
        alerts_to_create = []

        unexpected_ports = (
            self.db.query(ScanPort, ScanIP)
            .join(ScanIP, ScanPort.ip_id == ScanIP.id)
            .filter(ScanPort.scan_id == scan.id)
            .filter(ScanPort.expected == False)  # noqa: E712 - not in config's expected_ports
            .filter(ScanPort.state == 'open')
            .all()
        )

        for port, ip in unexpected_ports:
            severity = rule.severity or (
                'critical' if port.port in self.HIGH_RISK_PORTS else 'warning'
            )

            # Add service details to the message when the scan identified one.
            service = (
                self.db.query(ScanServiceModel)
                .filter(ScanServiceModel.port_id == port.id)
                .first()
            )
            service_info = ""
            if service:
                product = service.product or "Unknown"
                version = service.version or ""
                # Strip only the description (an empty version leaves a trailing
                # space) so the separating space before "(Service:" is kept.
                description = f"{service.service_name}: {product} {version}".strip()
                service_info = f" (Service: {description})"

            alerts_to_create.append({
                'alert_type': 'unexpected_port',
                'severity': severity,
                'message': f"Unexpected port open on {ip.ip_address}:{port.port}/{port.protocol}{service_info}",
                'ip_address': ip.ip_address,
                'port': port.port
            })

        return alerts_to_create

    def check_drift_from_previous(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Compare the current scan to the most recent earlier completed scan
        that used the same config file.

        Alerts are produced for added ports, removed ports, changed services,
        changed certificates and, when the rule has a threshold, for a drift
        score (as a percentage) at or above that threshold. If the comparison
        raises, the error is logged and the alerts collected so far are returned.

        Args:
            scan: The current scan
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
        """
        alerts_to_create = []

        # The previous version wrote
        # `Scan.started_at.desc() if Scan.started_at else Scan.timestamp.desc()`.
        # That tests the mapped class attribute, not a row value, so the
        # fallback was never selected; order by started_at directly.
        previous_scan = (
            self.db.query(Scan)
            .filter(Scan.config_file == scan.config_file)
            .filter(Scan.id < scan.id)
            .filter(Scan.status == 'completed')
            .order_by(Scan.started_at.desc())
            .first()
        )

        if not previous_scan:
            logger.info(f"No previous scan found for config {scan.config_file}")
            return []

        try:
            # Reuse the comparison logic from scan_service.
            comparison = self.scan_service.compare_scans(previous_scan.id, scan.id)

            ports_diff = comparison.get('ports', {})

            for port_data in ports_diff.get('added', []):
                alerts_to_create.append({
                    'alert_type': 'drift_new_port',
                    'severity': rule.severity or 'warning',
                    'message': f"New port detected: {port_data['ip']}:{port_data['port']}/{port_data['protocol']}",
                    'ip_address': port_data['ip'],
                    'port': port_data['port']
                })

            for port_data in ports_diff.get('removed', []):
                alerts_to_create.append({
                    'alert_type': 'drift_missing_port',
                    'severity': rule.severity or 'info',
                    'message': f"Port no longer open: {port_data['ip']}:{port_data['port']}/{port_data['protocol']}",
                    'ip_address': port_data['ip'],
                    'port': port_data['port']
                })

            for svc_data in comparison.get('services', {}).get('changed', []):
                old_svc = svc_data.get('old', {})
                new_svc = svc_data.get('new', {})
                old_desc = f"{old_svc.get('product', 'Unknown')} {old_svc.get('version', '')}".strip()
                new_desc = f"{new_svc.get('product', 'Unknown')} {new_svc.get('version', '')}".strip()

                alerts_to_create.append({
                    'alert_type': 'drift_service_change',
                    'severity': rule.severity or 'info',
                    'message': f"Service changed on {svc_data['ip']}:{svc_data['port']}: {old_desc} → {new_desc}",
                    'ip_address': svc_data['ip'],
                    'port': svc_data['port']
                })

            for cert_data in comparison.get('certificates', {}).get('changed', []):
                old_cert = cert_data.get('old', {})
                new_cert = cert_data.get('new', {})

                alerts_to_create.append({
                    'alert_type': 'drift_cert_change',
                    'severity': rule.severity or 'warning',
                    'message': f"Certificate changed on {cert_data['ip']}:{cert_data['port']} - "
                               f"Subject: {old_cert.get('subject', 'Unknown')} → {new_cert.get('subject', 'Unknown')}",
                    'ip_address': cert_data['ip'],
                    'port': cert_data['port']
                })

            # Drift score is multiplied by 100 to compare against the rule's
            # percentage threshold; a missing score counts as 0.
            if rule.threshold:
                drift_pct = comparison.get('drift_score', 0) * 100
                if drift_pct >= rule.threshold:
                    alerts_to_create.append({
                        'alert_type': 'drift_threshold_exceeded',
                        'severity': rule.severity or 'warning',
                        'message': f"Drift score {drift_pct:.1f}% exceeds threshold {rule.threshold}%",
                        'ip_address': None,
                        'port': None
                    })

        except Exception as e:
            # Best effort: keep whatever alerts were collected before the failure.
            logger.error(f"Error comparing scans: {str(e)}", exc_info=True)

        return alerts_to_create

    def check_certificate_expiry(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Check for certificates expiring within the threshold days.

        The threshold comes from ``rule.threshold`` and defaults to 30 days
        when that value is falsy. Severity is tiered by remaining days:
        expired or <= 7 days is 'critical', <= 14 days is 'warning',
        anything else within the threshold is 'info'.

        Args:
            scan: The scan to check
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
        """
        alerts_to_create = []
        threshold_days = rule.threshold or 30  # Default 30 days

        certificates = (
            self.db.query(ScanCertificate, ScanPort, ScanIP)
            .join(ScanServiceModel, ScanCertificate.service_id == ScanServiceModel.id)
            .join(ScanPort, ScanServiceModel.port_id == ScanPort.id)
            .join(ScanIP, ScanPort.ip_id == ScanIP.id)
            .filter(ScanPort.scan_id == scan.id)
            .all()
        )

        for cert, port, ip in certificates:
            days_left = cert.days_until_expiry
            # Skip certificates with no expiry data or outside the alert window.
            if days_left is None or days_left > threshold_days:
                continue

            if days_left <= 0:
                severity = 'critical'
                message = f"Certificate EXPIRED on {ip.ip_address}:{port.port}"
            else:
                if days_left <= 7:
                    severity = 'critical'
                elif days_left <= 14:
                    severity = 'warning'
                else:
                    severity = 'info'
                message = f"Certificate expires in {days_left} days on {ip.ip_address}:{port.port}"

            alerts_to_create.append({
                'alert_type': 'cert_expiry',
                'severity': severity,
                'message': message,
                'ip_address': ip.ip_address,
                'port': port.port
            })

        return alerts_to_create

    def check_weak_tls(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Check for weak TLS versions (1.0, 1.1).

        One alert is produced per (IP, port) pair, listing every weak
        version that the scan recorded as supported there.

        Args:
            scan: The scan to check
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
        """
        alerts_to_create = []

        tls_versions = (
            self.db.query(ScanTLSVersion, ScanPort, ScanIP)
            .join(ScanCertificate, ScanTLSVersion.certificate_id == ScanCertificate.id)
            .join(ScanServiceModel, ScanCertificate.service_id == ScanServiceModel.id)
            .join(ScanPort, ScanServiceModel.port_id == ScanPort.id)
            .join(ScanIP, ScanPort.ip_id == ScanIP.id)
            .filter(ScanPort.scan_id == scan.id)
            .all()
        )

        # Group weak versions by (ip, port) so each endpoint gets a single alert.
        tls_by_host = {}
        for tls, port, ip in tls_versions:
            if not (tls.supported and tls.tls_version in self.WEAK_TLS_VERSIONS):
                continue
            host_entry = tls_by_host.setdefault(
                (ip.ip_address, port.port),
                {'ip': ip.ip_address, 'port': port.port, 'versions': []},
            )
            host_entry['versions'].append(tls.tls_version)

        for host_data in tls_by_host.values():
            alerts_to_create.append({
                'alert_type': 'weak_tls',
                'severity': rule.severity or 'warning',
                'message': f"Weak TLS versions supported on {host_data['ip']}:{host_data['port']}: {', '.join(host_data['versions'])}",
                'ip_address': host_data['ip'],
                'port': host_data['port']
            })

        return alerts_to_create

    def check_ping_failures(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Check for hosts that were expected to respond to ping but didn't.

        Args:
            scan: The scan to check
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
        """
        failed_pings = (
            self.db.query(ScanIP)
            .filter(ScanIP.scan_id == scan.id)
            .filter(ScanIP.ping_expected == True)  # noqa: E712
            .filter(ScanIP.ping_actual == False)  # noqa: E712
            .all()
        )

        return [
            {
                'alert_type': 'ping_failed',
                'severity': rule.severity or 'warning',
                'message': f"Host {ip.ip_address} did not respond to ping (expected to be up)",
                'ip_address': ip.ip_address,
                'port': None
            }
            for ip in failed_pings
        ]

    def create_alert(self, scan_id: int, rule: AlertRule, alert_data: Dict[str, Any]) -> Optional[Alert]:
        """
        Create an alert record in the database and commit it.

        Args:
            scan_id: ID of the scan that triggered the alert
            rule: The alert rule that was triggered
            alert_data: Dictionary with alert details; must contain
                'alert_type', 'severity' and 'message', and may contain
                'ip_address' and 'port'

        Returns:
            Created Alert object, or None if creation failed (the session
            is rolled back in that case)
        """
        try:
            alert = Alert(
                scan_id=scan_id,
                rule_id=rule.id,
                alert_type=alert_data['alert_type'],
                severity=alert_data['severity'],
                message=alert_data['message'],
                ip_address=alert_data.get('ip_address'),
                port=alert_data.get('port'),
                created_at=datetime.now(timezone.utc)
            )

            self.db.add(alert)
            self.db.commit()

            logger.info(f"Created alert: {alert.message}")
            return alert

        except Exception as e:
            logger.error(f"Failed to create alert: {str(e)}", exc_info=True)
            self.db.rollback()
            return None

    def trigger_notifications(self, alert: Alert, rule: AlertRule):
        """
        Send notifications for an alert based on rule configuration.

        Email delivery is not implemented yet and is only logged. Webhook
        deliveries are queued through WebhookService; any failure while
        queueing is logged and swallowed.

        Args:
            alert: The alert to send notifications for
            rule: The rule that specifies notification settings
        """
        # Email notification will be implemented in email_service.py
        if rule.email_enabled:
            logger.info(f"Email notification would be sent for alert {alert.id}")
            # TODO: Call email service

        # Webhook notification - queue for delivery
        if rule.webhook_enabled:
            try:
                # Imported here rather than at module level, as in the original.
                from flask import current_app
                from .webhook_service import WebhookService

                webhook_service = WebhookService(self.db)
                matching_webhooks = webhook_service.get_matching_webhooks(alert)

                if matching_webhooks:
                    # Scheduler is read from the Flask app; None when the app has none.
                    scheduler = getattr(current_app, 'scheduler', None)

                    for webhook in matching_webhooks:
                        webhook_service.queue_webhook_delivery(
                            webhook.id,
                            alert.id,
                            scheduler_service=scheduler
                        )
                        logger.info(f"Queued webhook {webhook.id} ({webhook.name}) for alert {alert.id}")
                else:
                    logger.debug(f"No matching webhooks found for alert {alert.id}")

            except Exception as e:
                logger.error(f"Failed to queue webhook notifications for alert {alert.id}: {e}", exc_info=True)
                # Don't fail alert creation if webhook queueing fails

    def acknowledge_alert(self, alert_id: int, acknowledged_by: str = "system") -> bool:
        """
        Acknowledge an alert.

        Args:
            alert_id: ID of the alert to acknowledge
            acknowledged_by: Username or system identifier

        Returns:
            True if successful, False if the alert does not exist or the
            update failed (the session is rolled back on failure)
        """
        try:
            alert = self.db.query(Alert).filter(Alert.id == alert_id).first()
            if not alert:
                logger.error(f"Alert {alert_id} not found")
                return False

            alert.acknowledged = True
            alert.acknowledged_at = datetime.now(timezone.utc)
            alert.acknowledged_by = acknowledged_by

            self.db.commit()
            logger.info(f"Alert {alert_id} acknowledged by {acknowledged_by}")
            return True

        except Exception as e:
            logger.error(f"Failed to acknowledge alert {alert_id}: {str(e)}", exc_info=True)
            self.db.rollback()
            return False

    def get_alerts_for_scan(self, scan_id: int) -> List[Alert]:
        """
        Get all alerts for a specific scan, most severe first.

        Alerts are ordered critical, warning, info, then any unrecognised
        severity; within one severity the newest alert comes first.

        Args:
            scan_id: ID of the scan

        Returns:
            List of Alert objects
        """
        newest_first = (
            self.db.query(Alert)
            .filter(Alert.scan_id == scan_id)
            .order_by(Alert.created_at.desc())
            .all()
        )

        # Rank in Python: a SQL "ORDER BY severity DESC" on the string values
        # sorts alphabetically (warning, info, critical), which puts critical
        # alerts last. sorted() is stable, so created_at order is preserved
        # within each severity.
        unknown_rank = len(self.SEVERITY_ORDER)
        return sorted(
            newest_first,
            key=lambda alert: self.SEVERITY_ORDER.get(alert.severity, unknown_rank),
        )