252 lines
7.9 KiB
Python
252 lines
7.9 KiB
Python
"""AI summarization service using Replicate API."""
|
|
|
|
import os
|
|
import time
|
|
from typing import Optional
|
|
|
|
import replicate
|
|
|
|
from app.models.ai_summary import ForecastContext, SummaryNotification
|
|
from app.models.alerts import AggregatedAlert
|
|
from app.services.change_detector import ChangeReport
|
|
from app.utils.logging_config import get_logger
|
|
|
|
|
|
class AISummaryServiceError(Exception):
    """Raised when AI summary service encounters an error."""
|
|
|
|
|
|
class AISummaryService:
    """Service for generating AI-powered weather alert summaries using Replicate.

    Builds a structured prompt from alerts, change reports, and forecast
    context, then calls the ``replicate`` library with an exponential-backoff
    retry loop around the model invocation.
    """

    # Retry configuration for transient API failures.
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0

    # Runtime string: filled in via str.format() in _build_prompt().
    PROMPT_TEMPLATE = """You are a weather assistant. Generate a concise alert summary for {location}.

## Active Alerts
{alerts_section}

## Changes Since Last Run
{changes_section}

## Forecast for Alert Period
{forecast_section}

Instructions:
- Write 2-4 sentences max
- Prioritize safety-critical info (severe weather, extreme temps)
- Mention changes from previous alert
- Include specific temperatures, wind speeds, times
- No emojis, professional tone

Summary:"""

    def __init__(
        self,
        api_token: str,
        model: str = "meta/meta-llama-3-8b-instruct",
        api_timeout: int = 60,
        max_tokens: int = 500,
        http_client: Optional[object] = None,  # Kept for API compatibility
    ) -> None:
        """Initialize the AI summary service.

        Args:
            api_token: Replicate API token.
            model: Model identifier to use.
            api_timeout: Maximum time to wait for response in seconds (unused with library).
            max_tokens: Maximum tokens in generated response.
            http_client: Unused - kept for API compatibility.
        """
        self.api_token = api_token
        self.model = model
        self.api_timeout = api_timeout
        self.max_tokens = max_tokens
        self.logger = get_logger(__name__)
        # Cumulative count of failed attempts across all calls (diagnostics).
        self._retry_count = 0

        # The replicate library reads its credentials from the environment.
        # NOTE(review): this mutates process-global state; if multiple tokens
        # are ever needed concurrently, switch to replicate.Client(api_token=...).
        os.environ["REPLICATE_API_TOKEN"] = api_token

        self.logger.debug(
            "AISummaryService initialized",
            model=self.model,
        )

    def summarize(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> SummaryNotification:
        """Generate an AI summary of the weather alerts.

        Args:
            alerts: List of aggregated alerts to summarize.
            change_report: Report of changes from previous run.
            forecast_context: Forecast data for the alert period.
            location: Location name for the summary.

        Returns:
            SummaryNotification with the AI-generated summary.

        Raises:
            AISummaryServiceError: If API call fails or times out.
        """
        prompt = self._build_prompt(alerts, change_report, forecast_context, location)

        self.logger.debug(
            "generating_ai_summary",
            model=self.model,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

        summary_text = self._run_with_retry(prompt)

        return SummaryNotification(
            title=f"Weather Alert Summary - {location}",
            message=summary_text,
            location=location,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

    def _build_prompt(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> str:
        """Build the prompt for the LLM.

        Args:
            alerts: List of aggregated alerts.
            change_report: Change report from detector.
            forecast_context: Forecast context data.
            location: Location name.

        Returns:
            Formatted prompt string.
        """
        # One bullet per alert: "Heat Wave: 105 (threshold: 100) from ... to ...".
        alert_lines = [
            f"- {alert.alert_type.value.replace('_', ' ').title()}: "
            f"{alert.extreme_value:.0f} "
            f"(threshold: {alert.threshold:.0f}) "
            f"from {alert.start_time} to {alert.end_time}"
            for alert in alerts
        ]
        alerts_section = "\n".join(alert_lines) if alert_lines else "No active alerts."

        return self.PROMPT_TEMPLATE.format(
            location=location,
            alerts_section=alerts_section,
            changes_section=change_report.to_prompt_text(),
            forecast_section=forecast_context.to_prompt_text(),
        )

    @staticmethod
    def _is_retryable_error(error: Exception) -> bool:
        """Heuristically classify an error as transient (worth retrying).

        Classification is message-based since the replicate library does not
        expose typed errors here: rate limits (429), server errors (5xx),
        and timeouts are considered retryable.
        """
        error_str = str(error).lower()
        if "rate" in error_str or "429" in error_str:
            return True
        if any(code in error_str for code in ("500", "502", "503", "504")):
            return True
        return "timeout" in error_str

    def _run_with_retry(self, prompt: str) -> str:
        """Execute model call with exponential-backoff retry logic.

        Retries transient errors (rate limits, 5xx, timeouts) up to
        MAX_RETRIES attempts, doubling the backoff between attempts.

        Args:
            prompt: The prompt to send to the model.

        Returns:
            Generated text from the model.

        Raises:
            AISummaryServiceError: If all retries fail or the error is
                not retryable.
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS
        attempts_made = 0

        for attempt in range(1, self.MAX_RETRIES + 1):
            attempts_made = attempt
            try:
                output = replicate.run(
                    self.model,
                    input={
                        "prompt": prompt,
                        "max_tokens": self.max_tokens,
                        "temperature": 0.7,
                    },
                )

                # Replicate returns a generator for streaming models;
                # collect all output parts into a single string.
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)

                self.logger.debug(
                    "replicate_api_call_succeeded",
                    model=self.model,
                    attempt=attempt,
                    response_length=len(response),
                )
                return response

            except Exception as e:  # Service boundary: classified, logged, re-raised.
                last_error = e
                self._retry_count += 1

                if attempt < self.MAX_RETRIES and self._is_retryable_error(e):
                    self.logger.warning(
                        "replicate_api_call_failed_retrying",
                        model=self.model,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached.
                    break

        # BUGFIX: report the number of attempts actually made rather than
        # MAX_RETRIES — a non-retryable error aborts after a single attempt.
        self.logger.error(
            "replicate_api_call_failed_after_retries",
            model=self.model,
            retry_count=attempts_made,
            error=str(last_error),
        )

        raise AISummaryServiceError(
            f"Replicate API call failed after {attempts_made} attempts: {last_error}"
        )
|