Add AI summary (default is disabled)
app/services/ai_summary_service.py | 251 additions | new file
@@ -0,0 +1,251 @@
"""AI summarization service using Replicate API."""

import os
import time
from typing import Optional

import replicate

from app.models.ai_summary import ForecastContext, SummaryNotification
from app.models.alerts import AggregatedAlert
from app.services.change_detector import ChangeReport
from app.utils.logging_config import get_logger


class AISummaryServiceError(Exception):
    """Raised when AI summary service encounters an error."""


class AISummaryService:
    """Service for generating AI-powered weather alert summaries using Replicate."""

    # Retry configuration
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0

    PROMPT_TEMPLATE = """You are a weather assistant. Generate a concise alert summary for {location}.

## Active Alerts
{alerts_section}

## Changes Since Last Run
{changes_section}

## Forecast for Alert Period
{forecast_section}

Instructions:
- Write 2-4 sentences max
- Prioritize safety-critical info (severe weather, extreme temps)
- Mention changes from previous alert
- Include specific temperatures, wind speeds, times
- No emojis, professional tone

Summary:"""
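
    # For reference, a filled-in prompt might begin like this. All values below
    # are invented for illustration; the changes and forecast sections come from
    # the models' `to_prompt_text()` methods, whose exact output is not shown here:
    #
    #   You are a weather assistant. Generate a concise alert summary for Boston, MA.
    #
    #   ## Active Alerts
    #   - High Wind: 45 (threshold: 40) from 2025-01-15 09:00 to 2025-01-15 18:00
    #
    #   ## Changes Since Last Run
    #   New alert: High Wind.
    #   ...
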
    def __init__(
        self,
        api_token: str,
        model: str = "meta/meta-llama-3-8b-instruct",
        api_timeout: int = 60,
        max_tokens: int = 500,
        http_client: Optional[object] = None,  # Unused; kept for API compatibility
    ) -> None:
        """Initialize the AI summary service.

        Args:
            api_token: Replicate API token.
            model: Model identifier to use.
            api_timeout: Maximum time to wait for a response, in seconds
                (unused when calling through the replicate library).
            max_tokens: Maximum number of tokens in the generated response.
            http_client: Unused; kept for API compatibility.
        """
        self.api_token = api_token
        self.model = model
        self.api_timeout = api_timeout
        self.max_tokens = max_tokens
        self.logger = get_logger(__name__)
        self._retry_count = 0

        # Set the API token for the replicate library, which reads it
        # from the environment
        os.environ["REPLICATE_API_TOKEN"] = api_token

        self.logger.debug(
            "AISummaryService initialized",
            model=self.model,
        )
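
    # A minimal construction sketch (the token below is a placeholder, not a
    # real credential; omitted parameters fall back to the defaults above):
    #
    #   service = AISummaryService(
    #       api_token="r8_placeholder",
    #       max_tokens=300,
    #   )
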
    def summarize(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> SummaryNotification:
        """Generate an AI summary of the weather alerts.

        Args:
            alerts: List of aggregated alerts to summarize.
            change_report: Report of changes from the previous run.
            forecast_context: Forecast data for the alert period.
            location: Location name for the summary.

        Returns:
            SummaryNotification with the AI-generated summary.

        Raises:
            AISummaryServiceError: If the API call fails or times out.
        """
        prompt = self._build_prompt(alerts, change_report, forecast_context, location)

        self.logger.debug(
            "generating_ai_summary",
            model=self.model,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

        summary_text = self._run_with_retry(prompt)

        return SummaryNotification(
            title=f"Weather Alert Summary - {location}",
            message=summary_text,
            location=location,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )
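
    # A hedged usage sketch, assuming `alerts`, `report`, and `context` were
    # produced upstream by the alert aggregator, change detector, and forecast
    # client (the location string is invented):
    #
    #   try:
    #       notification = service.summarize(
    #           alerts=alerts,
    #           change_report=report,
    #           forecast_context=context,
    #           location="Boston, MA",
    #       )
    #   except AISummaryServiceError:
    #       ...  # all retries failed or the error was non-retryable
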
    def _build_prompt(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> str:
        """Build the prompt for the LLM.

        Args:
            alerts: List of aggregated alerts.
            change_report: Change report from the detector.
            forecast_context: Forecast context data.
            location: Location name.

        Returns:
            Formatted prompt string.
        """
        # Format the alerts section, one bullet line per alert
        alert_lines = []
        for alert in alerts:
            alert_type = alert.alert_type.value.replace("_", " ").title()
            alert_lines.append(
                f"- {alert_type}: {alert.extreme_value:.0f} "
                f"(threshold: {alert.threshold:.0f}) "
                f"from {alert.start_time} to {alert.end_time}"
            )
        alerts_section = "\n".join(alert_lines) if alert_lines else "No active alerts."

        # The changes and forecast sections are rendered by their own models
        changes_section = change_report.to_prompt_text()
        forecast_section = forecast_context.to_prompt_text()

        return self.PROMPT_TEMPLATE.format(
            location=location,
            alerts_section=alerts_section,
            changes_section=changes_section,
            forecast_section=forecast_section,
        )
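
    # For a hypothetical alert whose enum value is "high_wind", the f-string
    # above would produce a line like (":.0f" drops any decimals):
    #
    #   - High Wind: 45 (threshold: 40) from 2025-01-15 09:00 to 2025-01-15 18:00
    #
    # The exact timestamp format depends on how `start_time` and `end_time`
    # stringify; the values shown are invented.
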
    def _run_with_retry(self, prompt: str) -> str:
        """Execute the model call with retry logic.

        Implements an exponential-backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx)
        - Timeout errors

        Args:
            prompt: The prompt to send to the model.

        Returns:
            Generated text from the model.

        Raises:
            AISummaryServiceError: If all retries fail.
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                output = replicate.run(
                    self.model,
                    input={
                        "prompt": prompt,
                        "max_tokens": self.max_tokens,
                        "temperature": 0.7,
                    },
                )

                # Replicate returns a generator for streaming models;
                # collect all output parts into a single string
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)

                self.logger.debug(
                    "replicate_api_call_succeeded",
                    model=self.model,
                    attempt=attempt,
                    response_length=len(response),
                )

                return response

            except Exception as e:
                last_error = e
                self._retry_count += 1

                # Determine if the error is retryable based on its message
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout

                if attempt < self.MAX_RETRIES and is_retryable:
                    self.logger.warning(
                        "replicate_api_call_failed_retrying",
                        model=self.model,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break

        # All retries exhausted or a non-retryable error occurred;
        # `attempt` reflects how many calls were actually made
        self.logger.error(
            "replicate_api_call_failed_after_retries",
            model=self.model,
            attempts=attempt,
            error=str(last_error),
        )

        raise AISummaryServiceError(
            f"Replicate API call failed after {attempt} attempt(s): {last_error}"
        )
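
With the defaults above, a transient failure sleeps 1.0 s after attempt 1 and 2.0 s after attempt 2; the third failure raises. A minimal sketch of exercising that path in a test, assuming pytest's monkeypatch fixture (the failing stub, test name, and token are invented for illustration):

import pytest

from app.services.ai_summary_service import AISummaryService, AISummaryServiceError


def test_retries_then_raises(monkeypatch):
    calls = []

    def always_rate_limited(model, input):
        calls.append(model)
        raise RuntimeError("429 rate limit exceeded")  # matches the retryable checks

    monkeypatch.setattr("replicate.run", always_rate_limited)
    monkeypatch.setattr("time.sleep", lambda seconds: None)  # skip real backoff waits

    service = AISummaryService(api_token="test-token")
    with pytest.raises(AISummaryServiceError):
        service._run_with_retry("prompt")

    assert len(calls) == 3  # MAX_RETRIES attempts; backoff would be 1.0 s, then 2.0 s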