Files
weather-alerts/app/services/state_manager.py

230 lines
6.9 KiB
Python

"""State manager for alert deduplication with atomic file persistence."""
import json
import os
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Optional, Union

from app.models.alerts import AggregatedAlert, TriggeredAlert
from app.models.state import AlertSnapshot, AlertState
from app.utils.logging_config import get_logger
# Type alias for alerts that can be deduplicated: both carry a ``dedup_key``
# attribute used by the state manager to detect already-sent alerts.
DeduplicableAlert = Union[TriggeredAlert, AggregatedAlert]
class StateManagerError(Exception):
    """Raised when state management encounters an error."""
class StateManager:
"""Manages alert state persistence for deduplication."""
def __init__(
    self,
    file_path: str,
    dedup_window_hours: int = 24,
) -> None:
    """Initialize the state manager; state is loaded lazily on first access.

    Args:
        file_path: Path to the state JSON file.
        dedup_window_hours: Hours to retain sent alert records.
    """
    self.logger = get_logger(__name__)
    self.file_path = Path(file_path)
    self.dedup_window_hours = dedup_window_hours
    # Populated on demand by the ``state`` property.
    self._state: Optional[AlertState] = None
@property
def state(self) -> AlertState:
"""Get the current state, loading from file if necessary."""
if self._state is None:
self._state = self.load()
return self._state
def load(self) -> AlertState:
    """Load state from file.

    Falls back to a fresh empty state when the file is missing, unreadable,
    or does not deserialize into a valid ``AlertState`` — a bad state file
    should never prevent startup.

    Returns:
        AlertState instance, empty if file doesn't exist or is unusable.
    """
    if not self.file_path.exists():
        self.logger.info("state_file_not_found", path=str(self.file_path))
        return AlertState()
    try:
        with open(self.file_path, encoding="utf-8") as f:
            data = json.load(f)
        state = AlertState.from_dict(data)
    except (OSError, ValueError, KeyError, TypeError) as e:
        # json.JSONDecodeError is a ValueError subclass; KeyError/TypeError
        # cover valid JSON whose structure from_dict rejects; OSError covers
        # read failures (permissions, I/O errors).
        self.logger.warning(
            "state_file_corrupt",
            path=str(self.file_path),
            error=str(e),
        )
        return AlertState()
    self.logger.info(
        "state_loaded",
        path=str(self.file_path),
        record_count=len(state.sent_alerts),
    )
    return state
def save(self) -> None:
"""Save state to file with atomic write.
Uses write-to-temp-then-rename for crash safety.
"""
if self._state is None:
return
# Ensure directory exists
self.file_path.parent.mkdir(parents=True, exist_ok=True)
# Write to temp file first
dir_path = self.file_path.parent
try:
fd, temp_path = tempfile.mkstemp(
suffix=".tmp",
prefix="state_",
dir=dir_path,
)
try:
with os.fdopen(fd, "w") as f:
json.dump(self._state.to_dict(), f, indent=2)
# Atomic rename
os.replace(temp_path, self.file_path)
self.logger.debug(
"state_saved",
path=str(self.file_path),
record_count=len(self._state.sent_alerts),
)
except Exception:
# Clean up temp file on error
if os.path.exists(temp_path):
os.unlink(temp_path)
raise
except OSError as e:
self.logger.error("state_save_failed", error=str(e))
raise StateManagerError(f"Failed to save state: {e}")
def filter_duplicates(
self,
alerts: list[DeduplicableAlert],
) -> list[DeduplicableAlert]:
"""Filter out alerts that have already been sent.
Args:
alerts: List of triggered or aggregated alerts.
Returns:
List of alerts that haven't been sent within the dedup window.
"""
new_alerts: list[DeduplicableAlert] = []
for alert in alerts:
if not self.state.is_duplicate(alert.dedup_key):
new_alerts.append(alert)
else:
self.logger.debug(
"alert_filtered_duplicate",
dedup_key=alert.dedup_key,
)
filtered_count = len(alerts) - len(new_alerts)
if filtered_count > 0:
self.logger.info(
"duplicates_filtered",
total=len(alerts),
new=len(new_alerts),
duplicates=filtered_count,
)
return new_alerts
def record_sent(self, alert: DeduplicableAlert) -> None:
    """Record that an alert was sent so future runs can deduplicate it.

    Args:
        alert: The triggered or aggregated alert that was sent.
    """
    # AggregatedAlert spans a time range (use its start); TriggeredAlert
    # carries a single forecast hour.
    hour = (
        alert.start_time
        if isinstance(alert, AggregatedAlert)
        else alert.forecast_hour
    )
    self.state.record_sent(
        dedup_key=alert.dedup_key,
        alert_type=alert.alert_type.value,
        forecast_hour=hour,
    )
    self.logger.debug("alert_recorded", dedup_key=alert.dedup_key)
def purge_old_records(self) -> int:
"""Remove records older than the deduplication window.
Returns:
Number of records purged.
"""
purged = self.state.purge_old_records(self.dedup_window_hours)
if purged > 0:
self.logger.info("old_records_purged", count=purged)
return purged
def save_alert_snapshots(self, alerts: list[AggregatedAlert]) -> None:
"""Save current alerts as snapshots for change detection.
Args:
alerts: List of aggregated alerts to save as snapshots.
"""
snapshots: dict[str, AlertSnapshot] = {}
for alert in alerts:
snapshot = AlertSnapshot(
alert_type=alert.alert_type.value,
extreme_value=alert.extreme_value,
threshold=alert.threshold,
start_time=alert.start_time,
end_time=alert.end_time,
hour_count=alert.hour_count,
)
snapshots[alert.alert_type.value] = snapshot
from datetime import datetime
self.state.previous_alert_snapshots = snapshots
self.state.last_updated = datetime.now()
self.logger.debug(
"alert_snapshots_saved",
count=len(snapshots),
)
def get_previous_snapshots(self) -> dict[str, AlertSnapshot]:
"""Get previous alert snapshots for change detection.
Returns:
Dict of alert snapshots keyed by alert type.
"""
return self.state.previous_alert_snapshots
def record_ai_summary_sent(self) -> None:
"""Record that an AI summary was sent."""
from datetime import datetime
self.state.last_ai_summary_sent = datetime.now()
self.logger.debug("ai_summary_sent_recorded")