"""
|
|
Alert Service Module
|
|
|
|
Handles alert evaluation, rule processing, and notification triggering
|
|
for SneakyScan Phase 5.
|
|
"""
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import List, Dict, Optional, Any
|
|
from sqlalchemy.orm import Session
|
|
|
|
from ..models import (
|
|
Alert, AlertRule, Scan, ScanPort, ScanIP, ScanService as ScanServiceModel,
|
|
ScanCertificate, ScanTLSVersion
|
|
)
|
|
from .scan_service import ScanService
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AlertService:
|
|
"""
|
|
Service for evaluating alert rules and generating alerts based on scan results.
|
|
|
|
    Supports five rule types:
    1. Unexpected Port Detection ('unexpected_port') - Alerts when ports marked as unexpected are found open
    2. Drift Detection ('drift_detection') - Alerts when scan results differ from the previous scan
    3. Certificate Expiry ('cert_expiry') - Alerts when a certificate expires within the configured threshold
    4. Weak TLS ('weak_tls') - Alerts when TLS 1.0 or 1.1 is still supported
    5. Ping Failure ('ping_failed') - Alerts when a host expected to answer ping does not respond
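
    Example (illustrative sketch; assumes an open SQLAlchemy session and a
    completed scan record):

        service = AlertService(db_session)
        new_alerts = service.evaluate_alert_rules(scan_id=scan.id)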
    """

    def __init__(self, db_session: Session):
        self.db = db_session
        self.scan_service = ScanService(db_session)

    def evaluate_alert_rules(self, scan_id: int) -> List[Alert]:
        """
        Main entry point for alert evaluation after scan completion.

        Args:
            scan_id: ID of the completed scan to evaluate

        Returns:
            List of Alert objects that were created
        """
        logger.info(f"Starting alert evaluation for scan {scan_id}")

        # Get the scan
        scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            logger.error(f"Scan {scan_id} not found")
            return []

        # Get all enabled alert rules
        rules = self.db.query(AlertRule).filter(AlertRule.enabled == True).all()
        logger.info(f"Found {len(rules)} enabled alert rules to evaluate")

        alerts_created = []

        for rule in rules:
            try:
                # Check if rule applies to this scan's config
                if rule.config_id and scan.config_id != rule.config_id:
                    logger.debug(f"Skipping rule {rule.id} - config mismatch")
                    continue

                # Evaluate based on rule type
                alert_data = []

                if rule.rule_type == 'unexpected_port':
                    alert_data = self.check_unexpected_ports(scan, rule)
                elif rule.rule_type == 'drift_detection':
                    alert_data = self.check_drift_from_previous(scan, rule)
                elif rule.rule_type == 'cert_expiry':
                    alert_data = self.check_certificate_expiry(scan, rule)
                elif rule.rule_type == 'weak_tls':
                    alert_data = self.check_weak_tls(scan, rule)
                elif rule.rule_type == 'ping_failed':
                    alert_data = self.check_ping_failures(scan, rule)
                else:
                    logger.warning(f"Unknown rule type: {rule.rule_type}")
                    continue

                # Create alerts for any findings
                for alert_info in alert_data:
                    alert = self.create_alert(scan_id, rule, alert_info)
                    if alert:
                        alerts_created.append(alert)

                        # Trigger notifications if configured
                        if rule.email_enabled or rule.webhook_enabled:
                            self.trigger_notifications(alert, rule)

                logger.info(f"Rule {rule.name or rule.id} generated {len(alert_data)} alerts")

            except Exception as e:
                logger.error(f"Error evaluating rule {rule.id}: {str(e)}")
                continue

        logger.info(f"Alert evaluation complete. Created {len(alerts_created)} alerts")
        return alerts_created

    def check_unexpected_ports(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Detect ports that are open but not in the expected_ports list.

        Args:
            scan: The scan to check
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
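            Each entry is shaped like the dicts built in this module, e.g.
            (illustrative values only):
                {'alert_type': 'unexpected_port', 'severity': 'critical',
                 'message': 'Unexpected port open on ...', 'ip_address': '...', 'port': 22}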
        """
        alerts_to_create = []

        # Get all ports where expected=False
        unexpected_ports = (
            self.db.query(ScanPort, ScanIP)
            .join(ScanIP, ScanPort.ip_id == ScanIP.id)
            .filter(ScanPort.scan_id == scan.id)
            .filter(ScanPort.expected == False)  # Not in config's expected_ports
            .filter(ScanPort.state == 'open')
            .all()
        )

        # High-risk ports that should trigger critical alerts
        high_risk_ports = {
            22,     # SSH
            23,     # Telnet
            135,    # Windows RPC
            139,    # NetBIOS
            445,    # SMB
            1433,   # SQL Server
            3306,   # MySQL
            3389,   # RDP
            5432,   # PostgreSQL
            5900,   # VNC
            6379,   # Redis
            9200,   # Elasticsearch
            27017,  # MongoDB
        }

        for port, ip in unexpected_ports:
            # Determine severity based on port number
            severity = rule.severity or ('critical' if port.port in high_risk_ports else 'warning')

            # Get service info if available
            service = (
                self.db.query(ScanServiceModel)
                .filter(ScanServiceModel.port_id == port.id)
                .first()
            )

            service_info = ""
            if service:
                product = service.product or "Unknown"
                version = service.version or ""
service_info = f" (Service: {service.service_name}: {product} {version}".strip() + ")"

            alerts_to_create.append({
                'alert_type': 'unexpected_port',
                'severity': severity,
                'message': f"Unexpected port open on {ip.ip_address}:{port.port}/{port.protocol}{service_info}",
                'ip_address': ip.ip_address,
                'port': port.port
            })

        return alerts_to_create

    def check_drift_from_previous(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Compare current scan to the last scan with the same config.

        Args:
            scan: The current scan
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
        """
        alerts_to_create = []

        # Find previous scan with same config_id
        previous_scan = (
            self.db.query(Scan)
            .filter(Scan.config_id == scan.config_id)
            .filter(Scan.id < scan.id)
            .filter(Scan.status == 'completed')
            .order_by(Scan.started_at.desc())
            .first()
        )

        if not previous_scan:
            logger.info(f"No previous scan found for config_id {scan.config_id}")
            return []

        try:
            # Use existing comparison logic from scan_service
            comparison = self.scan_service.compare_scans(previous_scan.id, scan.id)
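
            # Shape of `comparison` as consumed below (inferred from the keys this
            # method reads; not a documented contract of ScanService.compare_scans):
            #   'ports': {'added': [...], 'removed': [...]}  entries with 'ip', 'port', 'protocol'
            #   'services': {'changed': [...]}               entries with 'ip', 'port', 'old', 'new'
            #   'certificates': {'changed': [...]}           entries with 'ip', 'port', 'old', 'new'
            #   'drift_score': fraction in [0, 1], compared against a percent threshold below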

            # Alert on new ports
            for port_data in comparison.get('ports', {}).get('added', []):
                severity = rule.severity or 'warning'
                alerts_to_create.append({
                    'alert_type': 'drift_new_port',
                    'severity': severity,
                    'message': f"New port detected: {port_data['ip']}:{port_data['port']}/{port_data['protocol']}",
                    'ip_address': port_data['ip'],
                    'port': port_data['port']
                })

            # Alert on removed ports
            for port_data in comparison.get('ports', {}).get('removed', []):
                severity = rule.severity or 'info'
                alerts_to_create.append({
                    'alert_type': 'drift_missing_port',
                    'severity': severity,
                    'message': f"Port no longer open: {port_data['ip']}:{port_data['port']}/{port_data['protocol']}",
                    'ip_address': port_data['ip'],
                    'port': port_data['port']
                })

            # Alert on service changes
            for svc_data in comparison.get('services', {}).get('changed', []):
                old_svc = svc_data.get('old', {})
                new_svc = svc_data.get('new', {})

                old_desc = f"{old_svc.get('product', 'Unknown')} {old_svc.get('version', '')}".strip()
                new_desc = f"{new_svc.get('product', 'Unknown')} {new_svc.get('version', '')}".strip()

                severity = rule.severity or 'info'
                alerts_to_create.append({
                    'alert_type': 'drift_service_change',
                    'severity': severity,
                    'message': f"Service changed on {svc_data['ip']}:{svc_data['port']}: {old_desc} → {new_desc}",
                    'ip_address': svc_data['ip'],
                    'port': svc_data['port']
                })

            # Alert on certificate changes
            for cert_data in comparison.get('certificates', {}).get('changed', []):
                old_cert = cert_data.get('old', {})
                new_cert = cert_data.get('new', {})

                severity = rule.severity or 'warning'
                alerts_to_create.append({
                    'alert_type': 'drift_cert_change',
                    'severity': severity,
                    'message': f"Certificate changed on {cert_data['ip']}:{cert_data['port']} - "
                               f"Subject: {old_cert.get('subject', 'Unknown')} → {new_cert.get('subject', 'Unknown')}",
                    'ip_address': cert_data['ip'],
                    'port': cert_data['port']
                })

            # Check drift score threshold if configured
            if rule.threshold and comparison.get('drift_score', 0) * 100 >= rule.threshold:
                alerts_to_create.append({
                    'alert_type': 'drift_threshold_exceeded',
                    'severity': rule.severity or 'warning',
                    'message': f"Drift score {comparison['drift_score']*100:.1f}% exceeds threshold {rule.threshold}%",
                    'ip_address': None,
                    'port': None
                })

        except Exception as e:
            logger.error(f"Error comparing scans: {str(e)}")

        return alerts_to_create

    def check_certificate_expiry(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Check for certificates expiring within the threshold days.

        Args:
            scan: The scan to check
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
        """
        alerts_to_create = []
        threshold_days = rule.threshold or 30  # Default 30 days

        # Get all certificates from the scan
        certificates = (
            self.db.query(ScanCertificate, ScanPort, ScanIP)
            .join(ScanServiceModel, ScanCertificate.service_id == ScanServiceModel.id)
            .join(ScanPort, ScanServiceModel.port_id == ScanPort.id)
            .join(ScanIP, ScanPort.ip_id == ScanIP.id)
            .filter(ScanPort.scan_id == scan.id)
            .all()
        )

        for cert, port, ip in certificates:
            if cert.days_until_expiry is not None and cert.days_until_expiry <= threshold_days:
                if cert.days_until_expiry <= 0:
                    severity = 'critical'
                    message = f"Certificate EXPIRED on {ip.ip_address}:{port.port}"
                elif cert.days_until_expiry <= 7:
                    severity = 'critical'
                    message = f"Certificate expires in {cert.days_until_expiry} days on {ip.ip_address}:{port.port}"
                elif cert.days_until_expiry <= 14:
                    severity = 'warning'
                    message = f"Certificate expires in {cert.days_until_expiry} days on {ip.ip_address}:{port.port}"
                else:
                    severity = 'info'
                    message = f"Certificate expires in {cert.days_until_expiry} days on {ip.ip_address}:{port.port}"

                alerts_to_create.append({
                    'alert_type': 'cert_expiry',
                    'severity': severity,
                    'message': message,
                    'ip_address': ip.ip_address,
                    'port': port.port
                })

        return alerts_to_create

    def check_weak_tls(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Check for weak TLS versions (1.0, 1.1).

        Args:
            scan: The scan to check
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
        """
        alerts_to_create = []

        # Get all TLS version data from the scan
        tls_versions = (
            self.db.query(ScanTLSVersion, ScanPort, ScanIP)
            .join(ScanCertificate, ScanTLSVersion.certificate_id == ScanCertificate.id)
            .join(ScanServiceModel, ScanCertificate.service_id == ScanServiceModel.id)
            .join(ScanPort, ScanServiceModel.port_id == ScanPort.id)
            .join(ScanIP, ScanPort.ip_id == ScanIP.id)
            .filter(ScanPort.scan_id == scan.id)
            .all()
        )

        # Group TLS versions by port/IP to create one alert per host
        tls_by_host = {}
        for tls, port, ip in tls_versions:
            # Only alert on weak TLS versions that are supported
            if tls.supported and tls.tls_version in ['TLS 1.0', 'TLS 1.1']:
                key = (ip.ip_address, port.port)
                if key not in tls_by_host:
                    tls_by_host[key] = {'ip': ip.ip_address, 'port': port.port, 'versions': []}
                tls_by_host[key]['versions'].append(tls.tls_version)

        # Create alerts for hosts with weak TLS
        for host_key, host_data in tls_by_host.items():
            severity = rule.severity or 'warning'
            alerts_to_create.append({
                'alert_type': 'weak_tls',
                'severity': severity,
                'message': f"Weak TLS versions supported on {host_data['ip']}:{host_data['port']}: {', '.join(host_data['versions'])}",
                'ip_address': host_data['ip'],
                'port': host_data['port']
            })

        return alerts_to_create

    def check_ping_failures(self, scan: Scan, rule: AlertRule) -> List[Dict[str, Any]]:
        """
        Check for hosts that were expected to respond to ping but didn't.

        Args:
            scan: The scan to check
            rule: The alert rule configuration

        Returns:
            List of alert data dictionaries
        """
        alerts_to_create = []

        # Get all IPs where ping was expected but failed
        failed_pings = (
            self.db.query(ScanIP)
            .filter(ScanIP.scan_id == scan.id)
            .filter(ScanIP.ping_expected == True)
            .filter(ScanIP.ping_actual == False)
            .all()
        )

        for ip in failed_pings:
            severity = rule.severity or 'warning'
            alerts_to_create.append({
                'alert_type': 'ping_failed',
                'severity': severity,
                'message': f"Host {ip.ip_address} did not respond to ping (expected to be up)",
                'ip_address': ip.ip_address,
                'port': None
            })

        return alerts_to_create

    def create_alert(self, scan_id: int, rule: AlertRule, alert_data: Dict[str, Any]) -> Optional[Alert]:
        """
        Create an alert record in the database.

        Args:
            scan_id: ID of the scan that triggered the alert
            rule: The alert rule that was triggered
            alert_data: Dictionary with alert details
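                (keys as built by the check_* methods above: 'alert_type',
                'severity', 'message', plus optional 'ip_address' and 'port')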

        Returns:
            Created Alert object or None if creation failed
        """
        try:
            alert = Alert(
                scan_id=scan_id,
                rule_id=rule.id,
                alert_type=alert_data['alert_type'],
                severity=alert_data['severity'],
                message=alert_data['message'],
                ip_address=alert_data.get('ip_address'),
                port=alert_data.get('port'),
                created_at=datetime.now(timezone.utc)
            )

            self.db.add(alert)
            self.db.commit()

            logger.info(f"Created alert: {alert.message}")
            return alert

        except Exception as e:
            logger.error(f"Failed to create alert: {str(e)}")
            self.db.rollback()
            return None

    def trigger_notifications(self, alert: Alert, rule: AlertRule):
        """
        Send notifications for an alert based on rule configuration.

        Args:
            alert: The alert to send notifications for
            rule: The rule that specifies notification settings
        """
        # Email notification will be implemented in email_service.py
        if rule.email_enabled:
            logger.info(f"Email notification would be sent for alert {alert.id}")
            # TODO: Call email service

        # Webhook notification - queue for delivery
        if rule.webhook_enabled:
            try:
                from flask import current_app
                from .webhook_service import WebhookService

                webhook_service = WebhookService(self.db)

                # Get matching webhooks for this alert
                matching_webhooks = webhook_service.get_matching_webhooks(alert)

                if matching_webhooks:
                    # Get scheduler from app context
                    scheduler = getattr(current_app, 'scheduler', None)

                    # Queue delivery for each matching webhook
                    for webhook in matching_webhooks:
                        webhook_service.queue_webhook_delivery(
                            webhook.id,
                            alert.id,
                            scheduler_service=scheduler
                        )
                        logger.info(f"Queued webhook {webhook.id} ({webhook.name}) for alert {alert.id}")
                else:
                    logger.debug(f"No matching webhooks found for alert {alert.id}")

            except Exception as e:
                logger.error(f"Failed to queue webhook notifications for alert {alert.id}: {e}", exc_info=True)
                # Don't fail alert creation if webhook queueing fails

    def acknowledge_alert(self, alert_id: int, acknowledged_by: str = "system") -> bool:
        """
        Acknowledge an alert.

        Args:
            alert_id: ID of the alert to acknowledge
            acknowledged_by: Username or system identifier

        Returns:
            True if successful, False otherwise
        """
        try:
            alert = self.db.query(Alert).filter(Alert.id == alert_id).first()
            if not alert:
                logger.error(f"Alert {alert_id} not found")
                return False

            alert.acknowledged = True
            alert.acknowledged_at = datetime.now(timezone.utc)
            alert.acknowledged_by = acknowledged_by

            self.db.commit()
            logger.info(f"Alert {alert_id} acknowledged by {acknowledged_by}")
            return True

        except Exception as e:
            logger.error(f"Failed to acknowledge alert {alert_id}: {str(e)}")
            self.db.rollback()
            return False

    def get_alerts_for_scan(self, scan_id: int) -> List[Alert]:
        """
        Get all alerts for a specific scan.

        Args:
            scan_id: ID of the scan

        Returns:
            List of Alert objects
        """
        return (
            self.db.query(Alert)
            .filter(Alert.scan_id == scan_id)
            .order_by(Alert.severity.desc(), Alert.created_at.desc())
            .all()
        )