Files
weather-alerts/app/services/ai_summary_service.py

252 lines
7.9 KiB
Python

"""AI summarization service using Replicate API."""
import os
import time
from typing import Optional
import replicate
from app.models.ai_summary import ForecastContext, SummaryNotification
from app.models.alerts import AggregatedAlert
from app.services.change_detector import ChangeReport
from app.utils.logging_config import get_logger
class AISummaryServiceError(Exception):
    """Error raised for any failure inside the AI summary service."""
class AISummaryService:
    """Service for generating AI-powered weather alert summaries using Replicate.

    Builds a structured prompt from the active alerts, the change report, and
    the forecast context, then calls the Replicate API. Transient failures
    (rate limits, 5xx responses, timeouts) are retried with exponential
    backoff; all other errors fail immediately with AISummaryServiceError.
    """

    # Retry configuration: up to MAX_RETRIES attempts total, sleeping
    # INITIAL_BACKOFF_SECONDS after the first failure and multiplying the
    # delay by BACKOFF_MULTIPLIER after each subsequent failure.
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0

    PROMPT_TEMPLATE = """You are a weather assistant. Generate a concise alert summary for {location}.
## Active Alerts
{alerts_section}
## Changes Since Last Run
{changes_section}
## Forecast for Alert Period
{forecast_section}
Instructions:
- Write 2-4 sentences max
- Prioritize safety-critical info (severe weather, extreme temps)
- Mention changes from previous alert
- Include specific temperatures, wind speeds, times
- No emojis, professional tone
Summary:"""

    def __init__(
        self,
        api_token: str,
        model: str = "meta/meta-llama-3-8b-instruct",
        api_timeout: int = 60,
        max_tokens: int = 500,
        http_client: Optional[object] = None,  # Kept for API compatibility
    ) -> None:
        """Initialize the AI summary service.

        Args:
            api_token: Replicate API token.
            model: Model identifier to use.
            api_timeout: Maximum time to wait for response in seconds (unused with library).
            max_tokens: Maximum tokens in generated response.
            http_client: Unused - kept for API compatibility.
        """
        self.api_token = api_token
        self.model = model
        self.api_timeout = api_timeout
        self.max_tokens = max_tokens
        self.logger = get_logger(__name__)
        # Cumulative count of failed attempts across all summarize() calls.
        self._retry_count = 0
        # The replicate library reads credentials from the environment.
        # NOTE(review): this mutates process-global state and affects any
        # other replicate user in the same process.
        os.environ["REPLICATE_API_TOKEN"] = api_token
        self.logger.debug(
            "AISummaryService initialized",
            model=self.model,
        )

    def summarize(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> SummaryNotification:
        """Generate an AI summary of the weather alerts.

        Args:
            alerts: List of aggregated alerts to summarize.
            change_report: Report of changes from previous run.
            forecast_context: Forecast data for the alert period.
            location: Location name for the summary.

        Returns:
            SummaryNotification with the AI-generated summary.

        Raises:
            AISummaryServiceError: If API call fails or times out.
        """
        prompt = self._build_prompt(alerts, change_report, forecast_context, location)
        self.logger.debug(
            "generating_ai_summary",
            model=self.model,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )
        summary_text = self._run_with_retry(prompt)
        return SummaryNotification(
            title=f"Weather Alert Summary - {location}",
            message=summary_text,
            location=location,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

    def _build_prompt(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> str:
        """Build the prompt for the LLM.

        Args:
            alerts: List of aggregated alerts.
            change_report: Change report from detector.
            forecast_context: Forecast context data.
            location: Location name.

        Returns:
            Formatted prompt string.
        """
        # Format alerts section: one bullet per alert with its extreme value,
        # threshold, and active window.
        alert_lines = []
        for alert in alerts:
            # e.g. "high_wind" -> "High Wind"
            alert_type = alert.alert_type.value.replace("_", " ").title()
            alert_lines.append(
                f"- {alert_type}: {alert.extreme_value:.0f} "
                f"(threshold: {alert.threshold:.0f}) "
                f"from {alert.start_time} to {alert.end_time}"
            )
        alerts_section = "\n".join(alert_lines) if alert_lines else "No active alerts."
        # Changes and forecast sections render their own prompt text.
        changes_section = change_report.to_prompt_text()
        forecast_section = forecast_context.to_prompt_text()
        return self.PROMPT_TEMPLATE.format(
            location=location,
            alerts_section=alerts_section,
            changes_section=changes_section,
            forecast_section=forecast_section,
        )

    def _run_with_retry(self, prompt: str) -> str:
        """Execute model call with retry logic.

        Implements exponential backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx)
        - Timeout errors

        Args:
            prompt: The prompt to send to the model.

        Returns:
            Generated text from the model.

        Raises:
            AISummaryServiceError: If all retries fail or the error is
                non-retryable.
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS
        # Track how many attempts were actually made so the final log/error
        # reports the truth (the loop may exit early on a non-retryable error).
        attempts_made = 0
        for attempt in range(1, self.MAX_RETRIES + 1):
            attempts_made = attempt
            try:
                output = replicate.run(
                    self.model,
                    input={
                        "prompt": prompt,
                        "max_tokens": self.max_tokens,
                        "temperature": 0.7,
                    },
                )
                # Replicate returns a generator for streaming models;
                # collect all output parts into a single string.
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)
                self.logger.debug(
                    "replicate_api_call_succeeded",
                    model=self.model,
                    attempt=attempt,
                    response_length=len(response),
                )
                return response
            except Exception as e:  # broad catch is deliberate: classified below
                last_error = e
                self._retry_count += 1
                # The replicate client does not expose structured status codes
                # here, so retryability is inferred from the error message.
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = any(
                    code in error_str for code in ("500", "502", "503", "504")
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout
                if attempt < self.MAX_RETRIES and is_retryable:
                    self.logger.warning(
                        "replicate_api_call_failed_retrying",
                        model=self.model,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached.
                    break
        # Fix: report the number of attempts actually made. Previously this
        # always claimed MAX_RETRIES attempts, even when a non-retryable
        # error aborted on the first call.
        self.logger.error(
            "replicate_api_call_failed_after_retries",
            model=self.model,
            retry_count=attempts_made,
            error=str(last_error),
        )
        # Chain the original exception so the root cause stays visible.
        raise AISummaryServiceError(
            f"Replicate API call failed after {attempts_made} attempts: {last_error}"
        ) from last_error