Add AI summary feature (disabled by default)

This commit is contained in:
2026-01-26 17:13:11 -06:00
parent 2820944ec6
commit 921b6a81a4
13 changed files with 1357 additions and 43 deletions

View File

@@ -0,0 +1,251 @@
"""AI summarization service using Replicate API."""
import os
import time
from typing import Optional
import replicate
from app.models.ai_summary import ForecastContext, SummaryNotification
from app.models.alerts import AggregatedAlert
from app.services.change_detector import ChangeReport
from app.utils.logging_config import get_logger
class AISummaryServiceError(Exception):
    """Error raised by the AI summary service when a call fails."""
class AISummaryService:
    """Service for generating AI-powered weather alert summaries using Replicate."""

    # Retry configuration for transient API failures.
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0

    PROMPT_TEMPLATE = """You are a weather assistant. Generate a concise alert summary for {location}.
## Active Alerts
{alerts_section}
## Changes Since Last Run
{changes_section}
## Forecast for Alert Period
{forecast_section}
Instructions:
- Write 2-4 sentences max
- Prioritize safety-critical info (severe weather, extreme temps)
- Mention changes from previous alert
- Include specific temperatures, wind speeds, times
- No emojis, professional tone
Summary:"""

    def __init__(
        self,
        api_token: str,
        model: str = "meta/meta-llama-3-8b-instruct",
        api_timeout: int = 60,
        max_tokens: int = 500,
        http_client: Optional[object] = None,  # Kept for API compatibility
    ) -> None:
        """Initialize the AI summary service.

        Args:
            api_token: Replicate API token.
            model: Model identifier to use.
            api_timeout: Maximum time to wait for response in seconds (unused with library).
            max_tokens: Maximum tokens in generated response.
            http_client: Unused - kept for API compatibility.
        """
        self.api_token = api_token
        self.model = model
        self.api_timeout = api_timeout
        self.max_tokens = max_tokens
        self.logger = get_logger(__name__)
        # Cumulative count of failed call attempts across the service lifetime
        # (diagnostics only; never reset).
        self._retry_count = 0
        # The replicate library reads credentials from the environment.
        # NOTE(review): this mutates process-wide state — confirm no other
        # component depends on a different REPLICATE_API_TOKEN value.
        os.environ["REPLICATE_API_TOKEN"] = api_token
        self.logger.debug(
            "AISummaryService initialized",
            model=self.model,
        )

    def summarize(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> SummaryNotification:
        """Generate an AI summary of the weather alerts.

        Args:
            alerts: List of aggregated alerts to summarize.
            change_report: Report of changes from previous run.
            forecast_context: Forecast data for the alert period.
            location: Location name for the summary.

        Returns:
            SummaryNotification with the AI-generated summary.

        Raises:
            AISummaryServiceError: If API call fails or times out.
        """
        prompt = self._build_prompt(alerts, change_report, forecast_context, location)
        self.logger.debug(
            "generating_ai_summary",
            model=self.model,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )
        summary_text = self._run_with_retry(prompt)
        return SummaryNotification(
            title=f"Weather Alert Summary - {location}",
            message=summary_text,
            location=location,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

    def _build_prompt(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> str:
        """Build the prompt for the LLM.

        Args:
            alerts: List of aggregated alerts.
            change_report: Change report from detector.
            forecast_context: Forecast context data.
            location: Location name.

        Returns:
            Formatted prompt string.
        """
        # Format alerts section: one bullet per alert with its extreme value,
        # threshold, and time window.
        alert_lines = []
        for alert in alerts:
            alert_type = alert.alert_type.value.replace("_", " ").title()
            alert_lines.append(
                f"- {alert_type}: {alert.extreme_value:.0f} "
                f"(threshold: {alert.threshold:.0f}) "
                f"from {alert.start_time} to {alert.end_time}"
            )
        alerts_section = "\n".join(alert_lines) if alert_lines else "No active alerts."
        # Changes and forecast sections delegate their own formatting.
        changes_section = change_report.to_prompt_text()
        forecast_section = forecast_context.to_prompt_text()
        return self.PROMPT_TEMPLATE.format(
            location=location,
            alerts_section=alerts_section,
            changes_section=changes_section,
            forecast_section=forecast_section,
        )

    def _run_with_retry(self, prompt: str) -> str:
        """Execute model call with retry logic.

        Implements exponential backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx)
        - Timeout errors

        Args:
            prompt: The prompt to send to the model.

        Returns:
            Generated text from the model.

        Raises:
            AISummaryServiceError: If the call fails with a non-retryable
                error, or all retries are exhausted.
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS
        attempts_made = 0
        for attempt in range(1, self.MAX_RETRIES + 1):
            attempts_made = attempt
            try:
                output = replicate.run(
                    self.model,
                    input={
                        "prompt": prompt,
                        "max_tokens": self.max_tokens,
                        "temperature": 0.7,
                    },
                )
                # Replicate returns a generator for streaming models;
                # collect all output parts into a single string.
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)
                self.logger.debug(
                    "replicate_api_call_succeeded",
                    model=self.model,
                    attempt=attempt,
                    response_length=len(response),
                )
                return response
            except Exception as e:
                last_error = e
                self._retry_count += 1
                # Classify retryability from the error message, since the
                # exception here does not expose structured status codes.
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout
                if attempt < self.MAX_RETRIES and is_retryable:
                    self.logger.warning(
                        "replicate_api_call_failed_retrying",
                        model=self.model,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break
        # All retries exhausted or non-retryable error.
        # Bug fix: report the number of attempts actually made instead of
        # always MAX_RETRIES — a non-retryable error fails on attempt 1.
        self.logger.error(
            "replicate_api_call_failed_after_retries",
            model=self.model,
            retry_count=attempts_made,
            error=str(last_error),
        )
        raise AISummaryServiceError(
            f"Replicate API call failed after {attempts_made} attempts: {last_error}"
        )

View File

@@ -0,0 +1,233 @@
"""Change detection service for comparing alerts between runs."""
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
from app.config.loader import ChangeThresholds
from app.models.alerts import AggregatedAlert
from app.models.state import AlertSnapshot
from app.utils.logging_config import get_logger
class ChangeType(Enum):
    """Types of changes that can be detected between runs."""

    NEW = "new"  # alert type appears for the first time this run
    REMOVED = "removed"  # alert type from the previous run is no longer active
    VALUE_CHANGED = "value_changed"  # alert persists but its extreme value shifted significantly
@dataclass
class AlertChange:
    """Represents a detected change in an alert."""

    # Alert type identifier string (e.g. "wind_speed", "temperature_low").
    alert_type: str
    # Which kind of change this is: NEW, REMOVED, or VALUE_CHANGED.
    change_type: ChangeType
    # Human-readable description, used verbatim in the LLM prompt.
    description: str
    # Extreme value from the previous run (set for REMOVED / VALUE_CHANGED).
    previous_value: Optional[float] = None
    # Extreme value from the current run (set for NEW / VALUE_CHANGED).
    current_value: Optional[float] = None
    # Absolute difference between current and previous (VALUE_CHANGED only).
    value_delta: Optional[float] = None
@dataclass
class ChangeReport:
    """Aggregated set of alert changes detected in a single run."""

    changes: list[AlertChange] = field(default_factory=list)

    def _of_kind(self, kind: ChangeType) -> list[AlertChange]:
        """Return the subset of changes matching *kind*."""
        return [change for change in self.changes if change.change_type == kind]

    @property
    def has_changes(self) -> bool:
        """True when at least one change was detected."""
        return bool(self.changes)

    @property
    def new_alerts(self) -> list[AlertChange]:
        """Changes for alerts that appeared this run."""
        return self._of_kind(ChangeType.NEW)

    @property
    def removed_alerts(self) -> list[AlertChange]:
        """Changes for alerts that are no longer active."""
        return self._of_kind(ChangeType.REMOVED)

    @property
    def value_changes(self) -> list[AlertChange]:
        """Changes where an alert's extreme value shifted significantly."""
        return self._of_kind(ChangeType.VALUE_CHANGED)

    def to_prompt_text(self) -> str:
        """Render the report as plain text for inclusion in an LLM prompt."""
        if not self.has_changes:
            return "No significant changes from previous alert."
        lines = [f"NEW: {change.description}" for change in self.new_alerts]
        lines += [f"RESOLVED: {change.description}" for change in self.removed_alerts]
        lines += [f"CHANGED: {change.description}" for change in self.value_changes]
        return "\n".join(lines)
class ChangeDetector:
    """Compares current alerts against the previous run's snapshots.

    Produces a ChangeReport describing which alerts are new, which have
    been resolved, and which persisted but moved by more than the
    configured value thresholds.
    """

    # Maps an alert type to the ChangeThresholds attribute that governs it.
    ALERT_TYPE_TO_THRESHOLD_KEY = {
        "temperature_low": "temperature",
        "temperature_high": "temperature",
        "wind_speed": "wind_speed",
        "wind_gust": "wind_gust",
        "precipitation": "precipitation_prob",
    }

    def __init__(self, thresholds: ChangeThresholds) -> None:
        """Initialize the change detector.

        Args:
            thresholds: Thresholds for detecting significant changes.
        """
        self.thresholds = thresholds
        self.logger = get_logger(__name__)

    def detect(
        self,
        current_alerts: list[AggregatedAlert],
        previous_snapshots: dict[str, AlertSnapshot],
    ) -> ChangeReport:
        """Detect changes between current alerts and previous snapshots.

        Args:
            current_alerts: List of current aggregated alerts.
            previous_snapshots: Dict of previous alert snapshots keyed by alert type.

        Returns:
            ChangeReport containing all detected changes.
        """
        detected: list[AlertChange] = []
        seen_now = {alert.alert_type.value for alert in current_alerts}
        seen_before = set(previous_snapshots)

        # Pass 1: walk current alerts. Anything unseen last run is NEW;
        # anything seen before is checked for a significant value shift.
        for alert in current_alerts:
            kind = alert.alert_type.value
            if kind not in seen_before:
                detected.append(
                    AlertChange(
                        alert_type=kind,
                        change_type=ChangeType.NEW,
                        description=self._format_new_alert_description(alert),
                        current_value=alert.extreme_value,
                    )
                )
                self.logger.debug(
                    "change_detected_new",
                    alert_type=kind,
                )
                continue
            earlier = previous_snapshots[kind]
            shift = self._detect_value_change(alert, earlier)
            if shift is not None:
                detected.append(shift)
                self.logger.debug(
                    "change_detected_value",
                    alert_type=kind,
                    previous=earlier.extreme_value,
                    current=alert.extreme_value,
                )

        # Pass 2: anything present last run but absent now was resolved.
        for kind in seen_before - seen_now:
            earlier = previous_snapshots[kind]
            detected.append(
                AlertChange(
                    alert_type=kind,
                    change_type=ChangeType.REMOVED,
                    description=self._format_removed_alert_description(earlier),
                    previous_value=earlier.extreme_value,
                )
            )
            self.logger.debug(
                "change_detected_removed",
                alert_type=kind,
            )

        report = ChangeReport(changes=detected)
        self.logger.info(
            "change_detection_complete",
            total_changes=len(detected),
            new_count=len(report.new_alerts),
            removed_count=len(report.removed_alerts),
            value_changes_count=len(report.value_changes),
        )
        return report

    def _detect_value_change(
        self,
        current: AggregatedAlert,
        previous: AlertSnapshot,
    ) -> Optional[AlertChange]:
        """Return an AlertChange when the value moved by at least the threshold.

        Args:
            current: Current aggregated alert.
            previous: Previous alert snapshot.

        Returns:
            AlertChange if significant change detected, None otherwise.
        """
        key = self.ALERT_TYPE_TO_THRESHOLD_KEY.get(current.alert_type.value)
        if not key:
            return None
        limit = getattr(self.thresholds, key, None)
        if limit is None:
            return None
        delta = abs(current.extreme_value - previous.extreme_value)
        if delta < limit:
            return None
        return AlertChange(
            alert_type=current.alert_type.value,
            change_type=ChangeType.VALUE_CHANGED,
            description=self._format_value_change_description(
                current, previous, delta
            ),
            previous_value=previous.extreme_value,
            current_value=current.extreme_value,
            value_delta=delta,
        )

    def _format_new_alert_description(self, alert: AggregatedAlert) -> str:
        """Format description for a new alert."""
        label = alert.alert_type.value.replace("_", " ").title()
        return f"{label} alert: {alert.extreme_value:.0f} at {alert.extreme_hour}"

    def _format_removed_alert_description(self, snapshot: AlertSnapshot) -> str:
        """Format description for a removed alert."""
        label = snapshot.alert_type.replace("_", " ").title()
        return f"{label} alert no longer active"

    def _format_value_change_description(
        self,
        current: AggregatedAlert,
        previous: AlertSnapshot,
        delta: float,
    ) -> str:
        """Format description for a value change."""
        label = current.alert_type.value.replace("_", " ").title()
        trend = "increased" if current.extreme_value > previous.extreme_value else "decreased"
        return (
            f"{label} {trend} from {previous.extreme_value:.0f} "
            f"to {current.extreme_value:.0f}"
        )

View File

@@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Optional, Union
from app.models.ai_summary import SummaryNotification
from app.models.alerts import AggregatedAlert, AlertType, TriggeredAlert
from app.utils.http_client import HttpClient
from app.utils.logging_config import get_logger
@@ -20,6 +21,15 @@ class NotificationResult:
error: Optional[str] = None
@dataclass
class SummaryNotificationResult:
    """Result of sending a summary notification."""

    # The summary that was (or failed to be) delivered.
    summary: SummaryNotification
    # True when the HTTP POST reported success.
    success: bool
    # Failure detail ("HTTP <status>: <body prefix>"); None on success.
    error: Optional[str] = None
class NotificationServiceError(Exception):
    """Raised when the notification service encounters an error."""
@@ -175,3 +185,54 @@ class NotificationService:
tags = list(self.ALERT_TYPE_TAGS.get(alert.alert_type, self.default_tags))
return tags
def send_summary(self, summary: SummaryNotification) -> SummaryNotificationResult:
    """Send an AI-generated summary notification.

    Args:
        summary: The summary notification to send.

    Returns:
        SummaryNotificationResult indicating success or failure.
    """
    endpoint = f"{self.server_url}/{self.topic}"

    # Assemble notification headers; auth is optional.
    request_headers = {
        "Title": summary.title,
        "Priority": self.priority,
        "Tags": ",".join(summary.tags),
    }
    if self.access_token:
        request_headers["Authorization"] = f"Bearer {self.access_token}"

    self.logger.debug(
        "sending_summary_notification",
        location=summary.location,
        alert_count=summary.alert_count,
    )
    response = self.http_client.post(
        endpoint,
        data=summary.message.encode("utf-8"),
        headers=request_headers,
    )

    # Failure path first: log and return the error detail.
    if not response.success:
        error_msg = f"HTTP {response.status_code}: {response.text[:100]}"
        self.logger.error(
            "summary_notification_failed",
            error=error_msg,
        )
        return SummaryNotificationResult(
            summary=summary, success=False, error=error_msg
        )

    self.logger.info(
        "summary_notification_sent",
        location=summary.location,
        alert_count=summary.alert_count,
        has_changes=summary.has_changes,
    )
    return SummaryNotificationResult(summary=summary, success=True)

View File

@@ -7,7 +7,7 @@ from pathlib import Path
from typing import Optional, Union
from app.models.alerts import AggregatedAlert, TriggeredAlert
from app.models.state import AlertState
from app.models.state import AlertSnapshot, AlertState
from app.utils.logging_config import get_logger
# Type alias for alerts that can be deduplicated
@@ -183,3 +183,47 @@ class StateManager:
self.logger.info("old_records_purged", count=purged)
return purged
def save_alert_snapshots(self, alerts: list[AggregatedAlert]) -> None:
    """Save current alerts as snapshots for change detection.

    Snapshots are keyed by alert type so the next run can compare
    like-for-like when detecting new, removed, or changed alerts.

    Args:
        alerts: List of aggregated alerts to save as snapshots.
    """
    # Fix: the import was previously buried mid-function after the loop;
    # hoist it to the top of the function where it is visible.
    from datetime import datetime

    self.state.previous_alert_snapshots = {
        alert.alert_type.value: AlertSnapshot(
            alert_type=alert.alert_type.value,
            extreme_value=alert.extreme_value,
            threshold=alert.threshold,
            start_time=alert.start_time,
            end_time=alert.end_time,
            hour_count=alert.hour_count,
        )
        for alert in alerts
    }
    self.state.last_updated = datetime.now()
    self.logger.debug(
        "alert_snapshots_saved",
        count=len(self.state.previous_alert_snapshots),
    )
def get_previous_snapshots(self) -> dict[str, AlertSnapshot]:
    """Get previous alert snapshots for change detection.

    Returns:
        Dict of alert snapshots keyed by alert type, as written by
        save_alert_snapshots.
    """
    return self.state.previous_alert_snapshots
def record_ai_summary_sent(self) -> None:
    """Record that an AI summary was sent.

    Stores the current local time on ``state.last_ai_summary_sent`` so
    callers can track when the last summary went out.
    """
    # NOTE(review): local import — consider hoisting to module level.
    from datetime import datetime
    self.state.last_ai_summary_sent = datetime.now()
    self.logger.debug("ai_summary_sent_recorded")