"""AI summarization service using Replicate API."""

import os
import re
import time
from typing import Optional

import replicate

from app.models.ai_summary import ForecastContext, SummaryNotification
from app.models.alerts import AggregatedAlert
from app.services.change_detector import ChangeReport
from app.utils.logging_config import get_logger


class AISummaryServiceError(Exception):
    """Raised when AI summary service encounters an error."""

    pass


class AISummaryService:
    """Service for generating AI-powered weather alert summaries using Replicate."""

    # Retry configuration. MAX_RETRIES is the total number of model calls
    # attempted (the first call included), not the number of extra retries.
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0

    # HTTP status codes treated as transient: request timeout, rate limiting,
    # and server/gateway errors.
    RETRYABLE_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})

    # Fallback classification when an exception carries no HTTP status code.
    # Word boundaries keep "rate" from matching inside words like "generate" or
    # "accurate", and keep "500" from matching inside numbers like "1500".
    _RETRYABLE_MESSAGE_PATTERN = re.compile(
        r"\brate|\b(?:408|429|500|502|503|504)\b|too many requests|timeout|timed out",
        re.IGNORECASE,
    )

    PROMPT_TEMPLATE = """You are a weather assistant. Generate a concise alert summary for {location}.

## Active Alerts
{alerts_section}

## Changes Since Last Run
{changes_section}

## Forecast for Alert Period
{forecast_section}

Instructions:
- Write 2-4 sentences max
- Prioritize safety-critical info (severe weather, extreme temps)
- Mention changes from previous alert
- Include specific temperatures, wind speeds, times
- No emojis, professional tone

Summary:"""

    def __init__(
        self,
        api_token: str,
        model: str = "meta/meta-llama-3-8b-instruct",
        api_timeout: int = 60,
        max_tokens: int = 500,
        http_client: Optional[object] = None,  # Kept for API compatibility
    ) -> None:
        """Initialize the AI summary service.

        Args:
            api_token: Replicate API token.
            model: Model identifier to use.
            api_timeout: Maximum time to wait for response in seconds
                (stored but currently not passed to the replicate library).
            max_tokens: Maximum tokens in generated response.
            http_client: Unused - kept for API compatibility.
        """
        self.api_token = api_token
        self.model = model
        self.api_timeout = api_timeout
        self.max_tokens = max_tokens
        self.logger = get_logger(__name__)
        # Cumulative number of failed model calls made by this instance.
        self._retry_count = 0

        # Set API token for replicate library.
        # NOTE: this mutates the process-wide environment, so the last
        # constructed instance's token is the one module-level
        # ``replicate.run`` calls will use.
        os.environ["REPLICATE_API_TOKEN"] = api_token

        self.logger.debug(
            "AISummaryService initialized",
            model=self.model,
        )

    def summarize(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> SummaryNotification:
        """Generate an AI summary of the weather alerts.

        Args:
            alerts: List of aggregated alerts to summarize.
            change_report: Report of changes from previous run.
            forecast_context: Forecast data for the alert period.
            location: Location name for the summary.

        Returns:
            SummaryNotification with the AI-generated summary.

        Raises:
            AISummaryServiceError: If API call fails or times out.
        """
        prompt = self._build_prompt(alerts, change_report, forecast_context, location)

        self.logger.debug(
            "generating_ai_summary",
            model=self.model,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

        summary_text = self._run_with_retry(prompt)

        return SummaryNotification(
            title=f"Weather Alert Summary - {location}",
            message=summary_text,
            location=location,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

    def _build_prompt(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> str:
        """Build the prompt for the LLM.

        Args:
            alerts: List of aggregated alerts.
            change_report: Change report from detector.
            forecast_context: Forecast context data.
            location: Location name.

        Returns:
            Formatted prompt string.
        """
        # Format alerts section: one bullet per alert, e.g.
        # "- Heat Wave: 102 (threshold: 100) from <start> to <end>".
        alert_lines = []
        for alert in alerts:
            alert_type = alert.alert_type.value.replace("_", " ").title()
            alert_lines.append(
                f"- {alert_type}: {alert.extreme_value:.0f} "
                f"(threshold: {alert.threshold:.0f}) "
                f"from {alert.start_time} to {alert.end_time}"
            )
        alerts_section = "\n".join(alert_lines) if alert_lines else "No active alerts."

        # Format changes section
        changes_section = change_report.to_prompt_text()

        # Format forecast section
        forecast_section = forecast_context.to_prompt_text()

        return self.PROMPT_TEMPLATE.format(
            location=location,
            alerts_section=alerts_section,
            changes_section=changes_section,
            forecast_section=forecast_section,
        )

    def _run_with_retry(self, prompt: str) -> str:
        """Execute model call with retry logic.

        Implements exponential backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx) and request timeouts (408)
        - Timeout errors

        Non-transient errors abort immediately without sleeping.

        Args:
            prompt: The prompt to send to the model.

        Returns:
            Generated text from the model.

        Raises:
            AISummaryServiceError: If all attempts fail or a non-retryable
                error occurs. The underlying exception is chained as
                ``__cause__``.
        """
        last_error: Optional[Exception] = None
        attempts_made = 0
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            attempts_made = attempt
            try:
                output = replicate.run(
                    self.model,
                    input={
                        "prompt": prompt,
                        "max_tokens": self.max_tokens,
                        "temperature": 0.7,
                    },
                )
                # Streaming models yield parts lazily; consuming them here keeps
                # mid-stream failures inside the retry handling.
                response = self._collect_output(output)
            except Exception as e:  # broad on purpose: classified just below
                last_error = e
                self._retry_count += 1

                if attempt >= self.MAX_RETRIES or not self._is_retryable_error(e):
                    # Non-retryable error or max attempts reached
                    break

                self.logger.warning(
                    "replicate_api_call_failed_retrying",
                    model=self.model,
                    attempt=attempt,
                    max_retries=self.MAX_RETRIES,
                    error=str(e),
                    backoff_seconds=backoff,
                )
                time.sleep(backoff)
                backoff *= self.BACKOFF_MULTIPLIER
            else:
                self.logger.debug(
                    "replicate_api_call_succeeded",
                    model=self.model,
                    attempt=attempt,
                    response_length=len(response),
                )
                return response

        # All attempts exhausted or non-retryable error: report how many calls
        # were actually made, not the configured maximum.
        self.logger.error(
            "replicate_api_call_failed_after_retries",
            model=self.model,
            retry_count=attempts_made,
            error=str(last_error),
        )
        attempt_word = "attempt" if attempts_made == 1 else "attempts"
        raise AISummaryServiceError(
            f"Replicate API call failed after {attempts_made} {attempt_word}: {last_error}"
        ) from last_error

    @staticmethod
    def _collect_output(output: object) -> str:
        """Flatten a ``replicate.run`` result into a single string.

        Strings are returned unchanged. Other iterables (e.g. the generator
        returned for streaming models) are joined part by part. Anything else
        is converted with ``str()``.
        """
        if isinstance(output, str):
            return output
        if hasattr(output, "__iter__"):
            return "".join(str(part) for part in output)
        return str(output)

    @classmethod
    def _is_retryable_error(cls, error: Exception) -> bool:
        """Return True when *error* looks transient and is worth retrying.

        An integer HTTP status exposed as ``error.status`` or
        ``error.response.status_code`` is authoritative when present.
        Otherwise the exception text is matched against
        ``_RETRYABLE_MESSAGE_PATTERN``.
        """
        status = getattr(error, "status", None)
        if not isinstance(status, int):
            response = getattr(error, "response", None)
            status = getattr(response, "status_code", None)
        if isinstance(status, int):
            return status in cls.RETRYABLE_STATUS_CODES
        return cls._RETRYABLE_MESSAGE_PATTERN.search(str(error)) is not None