Add AI summary support (disabled by default)
app/services/ai_summary_service.py (new file, 251 lines)
@@ -0,0 +1,251 @@
"""AI summarization service using Replicate API."""

import os
import time
from typing import Optional

import replicate

from app.models.ai_summary import ForecastContext, SummaryNotification
from app.models.alerts import AggregatedAlert
from app.services.change_detector import ChangeReport
from app.utils.logging_config import get_logger


class AISummaryServiceError(Exception):
    """Raised when AI summary service encounters an error."""

    pass


class AISummaryService:
    """Service for generating AI-powered weather alert summaries using Replicate."""

    # Retry configuration
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0

    PROMPT_TEMPLATE = """You are a weather assistant. Generate a concise alert summary for {location}.

## Active Alerts
{alerts_section}

## Changes Since Last Run
{changes_section}

## Forecast for Alert Period
{forecast_section}

Instructions:
- Write 2-4 sentences max
- Prioritize safety-critical info (severe weather, extreme temps)
- Mention changes from previous alert
- Include specific temperatures, wind speeds, times
- No emojis, professional tone

Summary:"""

    def __init__(
        self,
        api_token: str,
        model: str = "meta/meta-llama-3-8b-instruct",
        api_timeout: int = 60,
        max_tokens: int = 500,
        http_client: Optional[object] = None,  # Kept for API compatibility
    ) -> None:
        """Initialize the AI summary service.

        Args:
            api_token: Replicate API token.
            model: Model identifier to use.
            api_timeout: Maximum time to wait for response in seconds (unused with library).
            max_tokens: Maximum tokens in generated response.
            http_client: Unused - kept for API compatibility.
        """
        self.api_token = api_token
        self.model = model
        self.api_timeout = api_timeout
        self.max_tokens = max_tokens
        self.logger = get_logger(__name__)
        self._retry_count = 0

        # Set API token for replicate library
        os.environ["REPLICATE_API_TOKEN"] = api_token

        self.logger.debug(
            "AISummaryService initialized",
            model=self.model,
        )

    def summarize(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> SummaryNotification:
        """Generate an AI summary of the weather alerts.

        Args:
            alerts: List of aggregated alerts to summarize.
            change_report: Report of changes from previous run.
            forecast_context: Forecast data for the alert period.
            location: Location name for the summary.

        Returns:
            SummaryNotification with the AI-generated summary.

        Raises:
            AISummaryServiceError: If API call fails or times out.
        """
        prompt = self._build_prompt(alerts, change_report, forecast_context, location)

        self.logger.debug(
            "generating_ai_summary",
            model=self.model,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

        summary_text = self._run_with_retry(prompt)

        return SummaryNotification(
            title=f"Weather Alert Summary - {location}",
            message=summary_text,
            location=location,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

    def _build_prompt(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> str:
        """Build the prompt for the LLM.

        Args:
            alerts: List of aggregated alerts.
            change_report: Change report from detector.
            forecast_context: Forecast context data.
            location: Location name.

        Returns:
            Formatted prompt string.
        """
        # Format alerts section
        alert_lines = []
        for alert in alerts:
            alert_type = alert.alert_type.value.replace("_", " ").title()
            alert_lines.append(
                f"- {alert_type}: {alert.extreme_value:.0f} "
                f"(threshold: {alert.threshold:.0f}) "
                f"from {alert.start_time} to {alert.end_time}"
            )
        alerts_section = "\n".join(alert_lines) if alert_lines else "No active alerts."

        # Format changes section
        changes_section = change_report.to_prompt_text()

        # Format forecast section
        forecast_section = forecast_context.to_prompt_text()

        return self.PROMPT_TEMPLATE.format(
            location=location,
            alerts_section=alerts_section,
            changes_section=changes_section,
            forecast_section=forecast_section,
        )

    def _run_with_retry(self, prompt: str) -> str:
        """Execute model call with retry logic.

        Implements exponential backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx)
        - Timeout errors

        Args:
            prompt: The prompt to send to the model.

        Returns:
            Generated text from the model.

        Raises:
            AISummaryServiceError: If all retries fail.
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                output = replicate.run(
                    self.model,
                    input={
                        "prompt": prompt,
                        "max_tokens": self.max_tokens,
                        "temperature": 0.7,
                    },
                )

                # Replicate returns a generator for streaming models.
                # Collect all output parts into a single string.
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)

                self.logger.debug(
                    "replicate_api_call_succeeded",
                    model=self.model,
                    attempt=attempt,
                    response_length=len(response),
                )

                return response

            except Exception as e:
                last_error = e
                self._retry_count += 1

                # Determine if error is retryable based on error message
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout

                if attempt < self.MAX_RETRIES and is_retryable:
                    self.logger.warning(
                        "replicate_api_call_failed_retrying",
                        model=self.model,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break

        # All retries exhausted or non-retryable error
        self.logger.error(
            "replicate_api_call_failed_after_retries",
            model=self.model,
            retry_count=self.MAX_RETRIES,
            error=str(last_error),
        )

        raise AISummaryServiceError(
            f"Replicate API call failed after {self.MAX_RETRIES} attempts: {last_error}"
        )
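Usage sketch (not part of the commit): a minimal call sequence, assuming a Replicate token in REPLICATE_API_TOKEN and alert/forecast objects produced by the rest of the pipeline. The variable names below are placeholders, not names from this repository.

import os

from app.services.ai_summary_service import AISummaryService

# Placeholders: current_alerts (list[AggregatedAlert]), change_report (ChangeReport),
# and forecast_context (ForecastContext) are assumed to come from the alert pipeline.
service = AISummaryService(api_token=os.environ["REPLICATE_API_TOKEN"])
notification = service.summarize(
    alerts=current_alerts,
    change_report=change_report,
    forecast_context=forecast_context,
    location="Oslo",  # illustrative location label
)
print(notification.message)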
app/services/change_detector.py (new file, 233 lines)
@@ -0,0 +1,233 @@
"""Change detection service for comparing alerts between runs."""

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

from app.config.loader import ChangeThresholds
from app.models.alerts import AggregatedAlert
from app.models.state import AlertSnapshot
from app.utils.logging_config import get_logger


class ChangeType(Enum):
    """Types of changes that can be detected between runs."""

    NEW = "new"
    REMOVED = "removed"
    VALUE_CHANGED = "value_changed"


@dataclass
class AlertChange:
    """Represents a detected change in an alert."""

    alert_type: str
    change_type: ChangeType
    description: str
    previous_value: Optional[float] = None
    current_value: Optional[float] = None
    value_delta: Optional[float] = None


@dataclass
class ChangeReport:
    """Report of all changes detected between runs."""

    changes: list[AlertChange] = field(default_factory=list)

    @property
    def has_changes(self) -> bool:
        """Check if any changes were detected."""
        return len(self.changes) > 0

    @property
    def new_alerts(self) -> list[AlertChange]:
        """Get list of new alert changes."""
        return [c for c in self.changes if c.change_type == ChangeType.NEW]

    @property
    def removed_alerts(self) -> list[AlertChange]:
        """Get list of removed alert changes."""
        return [c for c in self.changes if c.change_type == ChangeType.REMOVED]

    @property
    def value_changes(self) -> list[AlertChange]:
        """Get list of value change alerts."""
        return [c for c in self.changes if c.change_type == ChangeType.VALUE_CHANGED]

    def to_prompt_text(self) -> str:
        """Format change report for LLM prompt."""
        if not self.has_changes:
            return "No significant changes from previous alert."

        lines = []

        for change in self.new_alerts:
            lines.append(f"NEW: {change.description}")

        for change in self.removed_alerts:
            lines.append(f"RESOLVED: {change.description}")

        for change in self.value_changes:
            lines.append(f"CHANGED: {change.description}")

        return "\n".join(lines)


class ChangeDetector:
    """Detects changes between current alerts and previous snapshots."""

    # Map alert types to their value thresholds
    ALERT_TYPE_TO_THRESHOLD_KEY = {
        "temperature_low": "temperature",
        "temperature_high": "temperature",
        "wind_speed": "wind_speed",
        "wind_gust": "wind_gust",
        "precipitation": "precipitation_prob",
    }

    def __init__(self, thresholds: ChangeThresholds) -> None:
        """Initialize the change detector.

        Args:
            thresholds: Thresholds for detecting significant changes.
        """
        self.thresholds = thresholds
        self.logger = get_logger(__name__)

    def detect(
        self,
        current_alerts: list[AggregatedAlert],
        previous_snapshots: dict[str, AlertSnapshot],
    ) -> ChangeReport:
        """Detect changes between current alerts and previous snapshots.

        Args:
            current_alerts: List of current aggregated alerts.
            previous_snapshots: Dict of previous alert snapshots keyed by alert type.

        Returns:
            ChangeReport containing all detected changes.
        """
        changes: list[AlertChange] = []
        current_types = {alert.alert_type.value for alert in current_alerts}
        previous_types = set(previous_snapshots.keys())

        # Detect new alerts
        for alert in current_alerts:
            alert_type = alert.alert_type.value
            if alert_type not in previous_types:
                changes.append(
                    AlertChange(
                        alert_type=alert_type,
                        change_type=ChangeType.NEW,
                        description=self._format_new_alert_description(alert),
                        current_value=alert.extreme_value,
                    )
                )
                self.logger.debug(
                    "change_detected_new",
                    alert_type=alert_type,
                )
            else:
                # Check for significant value changes
                prev_snapshot = previous_snapshots[alert_type]
                value_change = self._detect_value_change(alert, prev_snapshot)
                if value_change:
                    changes.append(value_change)
                    self.logger.debug(
                        "change_detected_value",
                        alert_type=alert_type,
                        previous=prev_snapshot.extreme_value,
                        current=alert.extreme_value,
                    )

        # Detect removed alerts
        for alert_type in previous_types - current_types:
            prev_snapshot = previous_snapshots[alert_type]
            changes.append(
                AlertChange(
                    alert_type=alert_type,
                    change_type=ChangeType.REMOVED,
                    description=self._format_removed_alert_description(prev_snapshot),
                    previous_value=prev_snapshot.extreme_value,
                )
            )
            self.logger.debug(
                "change_detected_removed",
                alert_type=alert_type,
            )

        report = ChangeReport(changes=changes)

        self.logger.info(
            "change_detection_complete",
            total_changes=len(changes),
            new_count=len(report.new_alerts),
            removed_count=len(report.removed_alerts),
            value_changes_count=len(report.value_changes),
        )

        return report

    def _detect_value_change(
        self,
        current: AggregatedAlert,
        previous: AlertSnapshot,
    ) -> Optional[AlertChange]:
        """Detect if there's a significant value change between current and previous.

        Args:
            current: Current aggregated alert.
            previous: Previous alert snapshot.

        Returns:
            AlertChange if significant change detected, None otherwise.
        """
        threshold_key = self.ALERT_TYPE_TO_THRESHOLD_KEY.get(current.alert_type.value)
        if not threshold_key:
            return None

        threshold = getattr(self.thresholds, threshold_key, None)
        if threshold is None:
            return None

        delta = abs(current.extreme_value - previous.extreme_value)
        if delta >= threshold:
            return AlertChange(
                alert_type=current.alert_type.value,
                change_type=ChangeType.VALUE_CHANGED,
                description=self._format_value_change_description(
                    current, previous, delta
                ),
                previous_value=previous.extreme_value,
                current_value=current.extreme_value,
                value_delta=delta,
            )

        return None

    def _format_new_alert_description(self, alert: AggregatedAlert) -> str:
        """Format description for a new alert."""
        alert_type = alert.alert_type.value.replace("_", " ").title()
        return f"{alert_type} alert: {alert.extreme_value:.0f} at {alert.extreme_hour}"

    def _format_removed_alert_description(self, snapshot: AlertSnapshot) -> str:
        """Format description for a removed alert."""
        alert_type = snapshot.alert_type.replace("_", " ").title()
        return f"{alert_type} alert no longer active"

    def _format_value_change_description(
        self,
        current: AggregatedAlert,
        previous: AlertSnapshot,
        delta: float,
    ) -> str:
        """Format description for a value change."""
        alert_type = current.alert_type.value.replace("_", " ").title()
        direction = "increased" if current.extreme_value > previous.extreme_value else "decreased"
        return (
            f"{alert_type} {direction} from {previous.extreme_value:.0f} "
            f"to {current.extreme_value:.0f}"
        )
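Usage sketch (not part of the commit): running the detector against the previous run's snapshots. The ChangeThresholds field names below (temperature, wind_speed, wind_gust, precipitation_prob) are assumed from the ALERT_TYPE_TO_THRESHOLD_KEY mapping and the getattr lookup above; the real values come from the loaded config.

from app.config.loader import ChangeThresholds
from app.services.change_detector import ChangeDetector

# Assumed constructor fields and example values; not taken from the repository config.
thresholds = ChangeThresholds(
    temperature=2.0, wind_speed=10.0, wind_gust=10.0, precipitation_prob=20.0
)
detector = ChangeDetector(thresholds)

# current_alerts is a placeholder for this run's list[AggregatedAlert];
# previous snapshots come from the state manager (empty dict on the first run).
report = detector.detect(current_alerts, state_manager.get_previous_snapshots())
if report.has_changes:
    print(report.to_prompt_text())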
@@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Optional, Union

from app.models.ai_summary import SummaryNotification
from app.models.alerts import AggregatedAlert, AlertType, TriggeredAlert
from app.utils.http_client import HttpClient
from app.utils.logging_config import get_logger
@@ -20,6 +21,15 @@ class NotificationResult:
    error: Optional[str] = None


@dataclass
class SummaryNotificationResult:
    """Result of sending a summary notification."""

    summary: SummaryNotification
    success: bool
    error: Optional[str] = None


class NotificationServiceError(Exception):
    """Raised when notification service encounters an error."""
@@ -175,3 +185,54 @@ class NotificationService:
        tags = list(self.ALERT_TYPE_TAGS.get(alert.alert_type, self.default_tags))

        return tags

    def send_summary(self, summary: SummaryNotification) -> SummaryNotificationResult:
        """Send an AI-generated summary notification.

        Args:
            summary: The summary notification to send.

        Returns:
            SummaryNotificationResult indicating success or failure.
        """
        url = f"{self.server_url}/{self.topic}"

        # Build headers
        headers = {
            "Title": summary.title,
            "Priority": self.priority,
            "Tags": ",".join(summary.tags),
        }

        if self.access_token:
            headers["Authorization"] = f"Bearer {self.access_token}"

        self.logger.debug(
            "sending_summary_notification",
            location=summary.location,
            alert_count=summary.alert_count,
        )

        response = self.http_client.post(
            url,
            data=summary.message.encode("utf-8"),
            headers=headers,
        )

        if response.success:
            self.logger.info(
                "summary_notification_sent",
                location=summary.location,
                alert_count=summary.alert_count,
                has_changes=summary.has_changes,
            )
            return SummaryNotificationResult(summary=summary, success=True)
        else:
            error_msg = f"HTTP {response.status_code}: {response.text[:100]}"
            self.logger.error(
                "summary_notification_failed",
                error=error_msg,
            )
            return SummaryNotificationResult(
                summary=summary, success=False, error=error_msg
            )
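For reference, the request that send_summary issues is roughly the plain HTTP POST sketched below, shown with the requests library purely as an illustration of the wire format (the service itself goes through its own HttpClient). The server URL, topic, priority, and token values are placeholders, and the ntfy-style interpretation of the Title/Priority/Tags headers is an assumption.

import requests  # illustration only; not used by the service

resp = requests.post(
    "https://ntfy.example.com/weather-alerts",  # f"{server_url}/{topic}" in the service
    data=summary.message.encode("utf-8"),       # UTF-8 message body
    headers={
        "Title": summary.title,
        "Priority": "default",                  # self.priority in the service
        "Tags": ",".join(summary.tags),
        "Authorization": "Bearer <token>",      # only when an access token is configured
    },
    timeout=30,
)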
@@ -7,7 +7,7 @@ from pathlib import Path
 from typing import Optional, Union

 from app.models.alerts import AggregatedAlert, TriggeredAlert
-from app.models.state import AlertState
+from app.models.state import AlertSnapshot, AlertState
 from app.utils.logging_config import get_logger

 # Type alias for alerts that can be deduplicated
@@ -183,3 +183,47 @@ class StateManager:
        self.logger.info("old_records_purged", count=purged)

        return purged

    def save_alert_snapshots(self, alerts: list[AggregatedAlert]) -> None:
        """Save current alerts as snapshots for change detection.

        Args:
            alerts: List of aggregated alerts to save as snapshots.
        """
        snapshots: dict[str, AlertSnapshot] = {}

        for alert in alerts:
            snapshot = AlertSnapshot(
                alert_type=alert.alert_type.value,
                extreme_value=alert.extreme_value,
                threshold=alert.threshold,
                start_time=alert.start_time,
                end_time=alert.end_time,
                hour_count=alert.hour_count,
            )
            snapshots[alert.alert_type.value] = snapshot

        from datetime import datetime

        self.state.previous_alert_snapshots = snapshots
        self.state.last_updated = datetime.now()

        self.logger.debug(
            "alert_snapshots_saved",
            count=len(snapshots),
        )

    def get_previous_snapshots(self) -> dict[str, AlertSnapshot]:
        """Get previous alert snapshots for change detection.

        Returns:
            Dict of alert snapshots keyed by alert type.
        """
        return self.state.previous_alert_snapshots

    def record_ai_summary_sent(self) -> None:
        """Record that an AI summary was sent."""
        from datetime import datetime

        self.state.last_ai_summary_sent = datetime.now()
        self.logger.debug("ai_summary_sent_recorded")
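Taken together, a single run would wire these pieces roughly as sketched below. The instance names, the enabling flag, and the location field are assumptions (per the commit title, the feature ships disabled by default); the call order follows the methods added in this commit.

# Hypothetical glue code; services are assumed to be constructed from config elsewhere.
if config.ai_summary_enabled:  # assumed flag; disabled by default per the commit title
    previous = state_manager.get_previous_snapshots()
    report = change_detector.detect(current_alerts, previous)
    summary = ai_summary_service.summarize(
        current_alerts, report, forecast_context, location=config.location_name
    )
    result = notification_service.send_summary(summary)
    if result.success:
        state_manager.record_ai_summary_sent()
    state_manager.save_alert_snapshots(current_alerts)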