"""State manager for alert deduplication with atomic file persistence."""

import json
import os
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Optional, Union

from app.models.alerts import AggregatedAlert, TriggeredAlert
from app.models.state import AlertSnapshot, AlertState
from app.utils.logging_config import get_logger

# Type alias for alerts that can be deduplicated
DeduplicableAlert = Union[TriggeredAlert, AggregatedAlert]


class StateManagerError(Exception):
    """Raised when state management encounters an error."""


class StateManager:
    """Manages alert state persistence for deduplication.

    The state is loaded lazily on first access to ``state`` and kept in
    memory; nothing is written back to disk until ``save()`` is called.
    """

    def __init__(
        self,
        file_path: str,
        dedup_window_hours: int = 24,
    ) -> None:
        """Initialize the state manager.

        Args:
            file_path: Path to the state JSON file.
            dedup_window_hours: Hours to retain sent alert records.
        """
        self.file_path = Path(file_path)
        self.dedup_window_hours = dedup_window_hours
        self.logger = get_logger(__name__)
        # Populated on first access to ``state``; None means "not loaded yet".
        self._state: Optional[AlertState] = None

    @property
    def state(self) -> AlertState:
        """Get the current state, loading from file if necessary."""
        if self._state is None:
            self._state = self.load()
        return self._state

    def load(self) -> AlertState:
        """Load state from file.

        A missing file yields an empty state (logged at info level). A file
        that cannot be decoded or parsed also yields an empty state (logged
        as a warning). Other I/O errors (e.g. permission denied) propagate
        to the caller.

        Returns:
            AlertState instance, empty if the file doesn't exist or is
            corrupt.
        """
        try:
            # EAFP: opening directly avoids the race between an exists()
            # check and the open() call.
            with open(self.file_path, encoding="utf-8") as f:
                data = json.load(f)
        except (FileNotFoundError, NotADirectoryError):
            self.logger.info("state_file_not_found", path=str(self.file_path))
            return AlertState()
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError are both
            # ValueError subclasses: the file is not readable JSON text.
            return self._empty_state_for_corrupt_file(e)

        try:
            state = AlertState.from_dict(data)
        except (KeyError, TypeError, ValueError) as e:
            # NOTE(review): from_dict is defined elsewhere; presumably these
            # are what it raises for valid JSON of the wrong shape - confirm.
            return self._empty_state_for_corrupt_file(e)

        self.logger.info(
            "state_loaded",
            path=str(self.file_path),
            record_count=len(state.sent_alerts),
        )
        return state

    def _empty_state_for_corrupt_file(self, error: Exception) -> AlertState:
        """Log an unreadable state file and return a fresh, empty state."""
        self.logger.warning(
            "state_file_corrupt",
            path=str(self.file_path),
            error=str(error),
        )
        return AlertState()

    def save(self) -> None:
        """Save state to file with atomic write.

        Uses write-to-temp-then-rename for crash safety: the JSON is written
        to a temp file in the destination directory, flushed and fsynced,
        and then moved over the target with ``os.replace``. Does nothing if
        the state has never been loaded.

        Raises:
            StateManagerError: If creating, writing or renaming the temp
                file fails with an OSError (the OSError is chained as the
                cause).
        """
        if self._state is None:
            return

        # Ensure directory exists. Kept outside the try block so a failure
        # here still surfaces to callers as the original OSError.
        target_dir = self.file_path.parent
        target_dir.mkdir(parents=True, exist_ok=True)

        try:
            # The temp file lives in the target directory so that the final
            # os.replace never has to cross a filesystem boundary.
            fd, temp_path = tempfile.mkstemp(
                suffix=".tmp",
                prefix="state_",
                dir=target_dir,
            )
            try:
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    json.dump(self._state.to_dict(), f, indent=2)
                    # Push the data to disk before the rename; otherwise a
                    # crash could leave the renamed file empty or truncated.
                    f.flush()
                    os.fsync(f.fileno())

                # Atomic rename
                os.replace(temp_path, self.file_path)

                self.logger.debug(
                    "state_saved",
                    path=str(self.file_path),
                    record_count=len(self._state.sent_alerts),
                )
            except Exception:
                # Clean up temp file on error; it is already gone if the
                # failure happened after the rename.
                Path(temp_path).unlink(missing_ok=True)
                raise
        except OSError as e:
            self.logger.error("state_save_failed", error=str(e))
            raise StateManagerError(f"Failed to save state: {e}") from e

    def filter_duplicates(
        self,
        alerts: list[DeduplicableAlert],
    ) -> list[DeduplicableAlert]:
        """Filter out alerts that have already been sent.

        Args:
            alerts: List of triggered or aggregated alerts.

        Returns:
            List of alerts whose dedup key the state does not report as a
            duplicate, in their original order.
        """
        new_alerts: list[DeduplicableAlert] = []

        for alert in alerts:
            if self.state.is_duplicate(alert.dedup_key):
                self.logger.debug(
                    "alert_filtered_duplicate",
                    dedup_key=alert.dedup_key,
                )
                continue
            new_alerts.append(alert)

        filtered_count = len(alerts) - len(new_alerts)
        if filtered_count > 0:
            self.logger.info(
                "duplicates_filtered",
                total=len(alerts),
                new=len(new_alerts),
                duplicates=filtered_count,
            )

        return new_alerts

    def record_sent(self, alert: DeduplicableAlert) -> None:
        """Record that an alert was sent.

        Args:
            alert: The triggered or aggregated alert that was sent.
        """
        # AggregatedAlert uses start_time, TriggeredAlert uses forecast_hour
        if isinstance(alert, AggregatedAlert):
            forecast_hour = alert.start_time
        else:
            forecast_hour = alert.forecast_hour

        self.state.record_sent(
            dedup_key=alert.dedup_key,
            alert_type=alert.alert_type.value,
            forecast_hour=forecast_hour,
        )
        self.logger.debug("alert_recorded", dedup_key=alert.dedup_key)

    def purge_old_records(self) -> int:
        """Remove records older than the deduplication window.

        Returns:
            Number of records purged.
        """
        purged = self.state.purge_old_records(self.dedup_window_hours)

        if purged > 0:
            self.logger.info("old_records_purged", count=purged)

        return purged

    def save_alert_snapshots(self, alerts: list[AggregatedAlert]) -> None:
        """Save current alerts as snapshots for change detection.

        Replaces the previously stored snapshots in memory and updates the
        state's ``last_updated`` timestamp; call ``save()`` to persist.

        Args:
            alerts: List of aggregated alerts to save as snapshots.
        """
        snapshots: dict[str, AlertSnapshot] = {}

        for alert in alerts:
            # Keyed by alert type, so a later alert of the same type
            # overwrites an earlier one.
            snapshots[alert.alert_type.value] = AlertSnapshot(
                alert_type=alert.alert_type.value,
                extreme_value=alert.extreme_value,
                threshold=alert.threshold,
                start_time=alert.start_time,
                end_time=alert.end_time,
                hour_count=alert.hour_count,
            )

        self.state.previous_alert_snapshots = snapshots
        self.state.last_updated = datetime.now()

        self.logger.debug(
            "alert_snapshots_saved",
            count=len(snapshots),
        )

    def get_previous_snapshots(self) -> dict[str, AlertSnapshot]:
        """Get previous alert snapshots for change detection.

        Returns:
            Dict of alert snapshots keyed by alert type.
        """
        return self.state.previous_alert_snapshots

    def record_ai_summary_sent(self) -> None:
        """Record that an AI summary was sent."""
        self.state.last_ai_summary_sent = datetime.now()
        self.logger.debug("ai_summary_sent_recorded")