diff --git a/.env.example b/.env.example
index 400fe67..0a4a0b7 100644
--- a/.env.example
+++ b/.env.example
@@ -9,3 +9,6 @@ VISUALCROSSING_API_KEY=your_api_key_here
 # Ntfy Access Token
 # Required if your ntfy server requires authentication
 NTFY_ACCESS_TOKEN=your_ntfy_token_here
+
+# Replicate API Token
+REPLICATE_API_TOKEN=your_replicate_token_here
\ No newline at end of file
diff --git a/.gitea/workflows/weather-alerts.yml b/.gitea/workflows/weather-alerts.yml
index 72e1909..da13260 100644
--- a/.gitea/workflows/weather-alerts.yml
+++ b/.gitea/workflows/weather-alerts.yml
@@ -2,7 +2,7 @@ name: Weather Alerts
 
 on:
   schedule:
-    - cron: '0 * * * *' # Every hour at :00
+    - cron: '0 */4 * * *' # Every 4th hour at :00
   workflow_dispatch: {} # Manual trigger
 
 jobs:
diff --git a/app/config/loader.py b/app/config/loader.py
index a39a631..d8c7763 100644
--- a/app/config/loader.py
+++ b/app/config/loader.py
@@ -63,6 +63,35 @@ class AlertSettings:
     rules: AlertRules = field(default_factory=AlertRules)
 
 
+@dataclass
+class AISettings:
+    """AI summarization settings."""
+
+    enabled: bool = False
+    model: str = "meta/meta-llama-3-8b-instruct"
+    api_timeout: int = 60
+    max_tokens: int = 500
+    api_token: str = ""
+
+
+@dataclass
+class ChangeThresholds:
+    """Thresholds for detecting significant changes between runs."""
+
+    temperature: float = 5.0
+    wind_speed: float = 10.0
+    wind_gust: float = 10.0
+    precipitation_prob: float = 20.0
+
+
+@dataclass
+class ChangeDetectionSettings:
+    """Change detection settings."""
+
+    enabled: bool = True
+    thresholds: ChangeThresholds = field(default_factory=ChangeThresholds)
+
+
 @dataclass
 class AppConfig:
     """Complete application configuration."""
@@ -72,6 +101,8 @@ class AppConfig:
     alerts: AlertSettings = field(default_factory=AlertSettings)
     notifications: NotificationSettings = field(default_factory=NotificationSettings)
     state: StateSettings = field(default_factory=StateSettings)
+    ai: AISettings = field(default_factory=AISettings)
+    change_detection: ChangeDetectionSettings = field(default_factory=ChangeDetectionSettings)
 
 
 def load_config(
@@ -113,6 +144,8 @@ def load_config(
     alerts_data = config_data.get("alerts", {})
     notifications_data = config_data.get("notifications", {})
     state_data = config_data.get("state", {})
+    ai_data = config_data.get("ai", {})
+    change_detection_data = config_data.get("change_detection", {})
 
     # Build app settings
     app_settings = AppSettings(
@@ -150,10 +183,34 @@ def load_config(
         dedup_window_hours=state_data.get("dedup_window_hours", 24),
     )
 
+    # Build AI settings with token from environment
+    ai_settings = AISettings(
+        enabled=ai_data.get("enabled", False),
+        model=ai_data.get("model", "meta/meta-llama-3-8b-instruct"),
+        api_timeout=ai_data.get("api_timeout", 60),
+        max_tokens=ai_data.get("max_tokens", 500),
+        api_token=os.environ.get("REPLICATE_API_TOKEN", ""),
+    )
+
+    # Build change detection settings
+    thresholds_data = change_detection_data.get("thresholds", {})
+    change_thresholds = ChangeThresholds(
+        temperature=thresholds_data.get("temperature", 5.0),
+        wind_speed=thresholds_data.get("wind_speed", 10.0),
+        wind_gust=thresholds_data.get("wind_gust", 10.0),
+        precipitation_prob=thresholds_data.get("precipitation_prob", 20.0),
+    )
+    change_detection_settings = ChangeDetectionSettings(
+        enabled=change_detection_data.get("enabled", True),
+        thresholds=change_thresholds,
+    )
+
     return AppConfig(
         app=app_settings,
         weather=weather_settings,
         alerts=alert_settings,
         notifications=notification_settings,
         state=state_settings,
+        ai=ai_settings,
+        change_detection=change_detection_settings,
     )
diff --git a/app/config/settings.yaml b/app/config/settings.yaml
index 656225a..bb3a225 100644
--- a/app/config/settings.yaml
+++ b/app/config/settings.yaml
@@ -5,7 +5,7 @@ app:
 
 weather:
   location: "viola,tn"
-  hours_ahead: 24
+  hours_ahead: 4 # Matches the 4-hour run frequency
   unit_group: "us"
 
 alerts:
@@ -34,3 +34,18 @@ notifications:
 state:
   file_path: "./data/state.json"
   dedup_window_hours: 24
+
+ai:
+  enabled: true
+  model: "meta/meta-llama-3-8b-instruct"
+  api_timeout: 60
+  max_tokens: 500
+
+# These thresholds drive change detection against the previous forecast
+change_detection:
+  enabled: true
+  thresholds:
+    temperature: 5.0 # degrees F
+    wind_speed: 10.0 # mph
+    wind_gust: 10.0 # mph
+    precipitation_prob: 20.0 # percentage points
diff --git a/app/main.py b/app/main.py
index cb5e93c..18949a7 100644
--- a/app/main.py
+++ b/app/main.py
@@ -4,7 +4,12 @@ import sys
 from typing import Optional
 
 from app.config.loader import AppConfig, load_config
+from app.models.ai_summary import ForecastContext
+from app.models.alerts import AggregatedAlert
+from app.models.weather import WeatherForecast
+from app.services.ai_summary_service import AISummaryService, AISummaryServiceError
 from app.services.alert_aggregator import AlertAggregator
+from app.services.change_detector import ChangeDetector, ChangeReport
 from app.services.notification_service import NotificationService
 from app.services.rule_engine import RuleEngine
 from app.services.state_manager import StateManager
@@ -52,6 +57,29 @@ class WeatherAlertsApp:
             http_client=self.http_client,
         )
+
+        # Initialize AI services conditionally
+        self.ai_enabled = config.ai.enabled and bool(config.ai.api_token)
+        self.ai_summary_service: Optional[AISummaryService] = None
+        self.change_detector: Optional[ChangeDetector] = None
+
+        if self.ai_enabled:
+            self.ai_summary_service = AISummaryService(
+                api_token=config.ai.api_token,
+                model=config.ai.model,
+                api_timeout=config.ai.api_timeout,
+                max_tokens=config.ai.max_tokens,
+                http_client=self.http_client,
+            )
+            self.logger.info("ai_summary_enabled", model=config.ai.model)
+
+        if config.change_detection.enabled:
+            self.change_detector = ChangeDetector(
+                thresholds=config.change_detection.thresholds
+            )
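+
+        # Change detection is initialized independently of AI so snapshots
+        # are still recorded when summaries are disabled.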
 
     def run(self) -> int:
         """Execute the main application flow.
 
@@ -62,6 +87,7 @@ class WeatherAlertsApp:
             "app_starting",
             version=self.config.app.version,
             location=self.config.weather.location,
+            ai_enabled=self.ai_enabled,
         )
 
         try:
@@ -79,6 +105,9 @@
 
             if not triggered_alerts:
                 self.logger.info("no_alerts_triggered")
+                # Clear snapshots when no alerts
+                self.state_manager.save_alert_snapshots([])
+                self.state_manager.save()
                 return 0
 
             self.logger.info(
@@ -96,42 +125,11 @@
                 output_count=len(aggregated_alerts),
             )
 
-            # Step 3: Filter duplicates
-            self.logger.info("step_filter_duplicates")
-            new_alerts = self.state_manager.filter_duplicates(aggregated_alerts)
-
-            if not new_alerts:
-                self.logger.info("all_alerts_are_duplicates")
-                return 0
-
-            # Step 4: Send notifications
-            self.logger.info(
-                "step_send_notifications",
-                count=len(new_alerts),
-            )
-            results = self.notification_service.send_batch(new_alerts)
-
-            # Step 5: Record sent alerts
-            self.logger.info("step_record_sent")
-            for result in results:
-                if result.success:
-                    self.state_manager.record_sent(result.alert)
-
-            # Step 6: Purge old records and save state
-            self.state_manager.purge_old_records()
-            self.state_manager.save()
-
-            # Report results
-            success_count = sum(1 for r in results if r.success)
-            failed_count = len(results) - success_count
-
-            self.logger.info(
-                "app_complete",
-                alerts_sent=success_count,
-                alerts_failed=failed_count,
-            )
-
-            return 0 if failed_count == 0 else 1
+            # Branch based on AI enabled
+            if self.ai_enabled:
+                return self._run_ai_flow(forecast, aggregated_alerts)
+            else:
+                return self._run_standard_flow(aggregated_alerts)
 
         except WeatherServiceError as e:
             self.logger.error("weather_service_error", error=str(e))
@@ -144,6 +142,168 @@ class WeatherAlertsApp:
         finally:
             self.http_client.close()
 
+    def _run_standard_flow(self, aggregated_alerts: list[AggregatedAlert]) -> int:
+        """Run the standard notification flow without AI.
+
+        Args:
+            aggregated_alerts: List of aggregated alerts.
+
+        Returns:
+            Exit code (0 for success, 1 for error).
+        """
+        # Filter duplicates
+        self.logger.info("step_filter_duplicates")
+        new_alerts = self.state_manager.filter_duplicates(aggregated_alerts)
+
+        if not new_alerts:
+            self.logger.info("all_alerts_are_duplicates")
+            self._finalize(aggregated_alerts)
+            return 0
+
+        # Send notifications
+        self.logger.info(
+            "step_send_notifications",
+            count=len(new_alerts),
+        )
+        results = self.notification_service.send_batch(new_alerts)
+
+        # Record sent alerts
+        self.logger.info("step_record_sent")
+        for result in results:
+            if result.success:
+                self.state_manager.record_sent(result.alert)
+
+        self._finalize(aggregated_alerts)
+
+        # Report results
+        success_count = sum(1 for r in results if r.success)
+        failed_count = len(results) - success_count
+
+        self.logger.info(
+            "app_complete",
+            alerts_sent=success_count,
+            alerts_failed=failed_count,
+        )
+
+        return 0 if failed_count == 0 else 1
+
+    def _run_ai_flow(
+        self,
+        forecast: WeatherForecast,
+        aggregated_alerts: list[AggregatedAlert],
+    ) -> int:
+        """Run the AI summarization flow.
+
+        Args:
+            forecast: The weather forecast data.
+            aggregated_alerts: List of aggregated alerts.
+
+        Returns:
+            Exit code (0 for success, 1 for error).
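+
+        Note: falls back to _run_standard_flow if summary generation or
+        delivery fails, so alerts are never dropped on AI errors.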
+ """ + self.logger.info("step_ai_summary_flow") + + # Build forecast context + forecast_context = self._build_forecast_context(forecast) + + # Detect changes from previous run + change_report = ChangeReport() + if self.change_detector: + previous_snapshots = self.state_manager.get_previous_snapshots() + change_report = self.change_detector.detect( + aggregated_alerts, previous_snapshots + ) + + # Try to generate AI summary + try: + assert self.ai_summary_service is not None + summary = self.ai_summary_service.summarize( + alerts=aggregated_alerts, + change_report=change_report, + forecast_context=forecast_context, + location=self.config.weather.location, + ) + + # Send summary notification + self.logger.info("step_send_ai_summary") + result = self.notification_service.send_summary(summary) + + if result.success: + self.state_manager.record_ai_summary_sent() + self._finalize(aggregated_alerts) + self.logger.info( + "app_complete_ai", + summary_sent=True, + alert_count=len(aggregated_alerts), + ) + return 0 + else: + self.logger.warning( + "ai_summary_send_failed", + error=result.error, + fallback="individual_alerts", + ) + return self._run_standard_flow(aggregated_alerts) + + except AISummaryServiceError as e: + self.logger.warning( + "ai_summary_generation_failed", + error=str(e), + fallback="individual_alerts", + ) + return self._run_standard_flow(aggregated_alerts) + + def _build_forecast_context(self, forecast: WeatherForecast) -> ForecastContext: + """Build forecast context from weather forecast data. + + Args: + forecast: The weather forecast data. + + Returns: + ForecastContext for AI summary. + """ + if not forecast.hourly_forecasts: + return ForecastContext( + start_time="N/A", + end_time="N/A", + min_temp=0, + max_temp=0, + max_wind_speed=0, + max_wind_gust=0, + max_precip_prob=0, + conditions=[], + ) + + temps = [h.temp for h in forecast.hourly_forecasts] + wind_speeds = [h.wind_speed for h in forecast.hourly_forecasts] + wind_gusts = [h.wind_gust for h in forecast.hourly_forecasts] + precip_probs = [h.precip_prob for h in forecast.hourly_forecasts] + conditions = [h.conditions for h in forecast.hourly_forecasts] + + return ForecastContext( + start_time=forecast.hourly_forecasts[0].hour_key, + end_time=forecast.hourly_forecasts[-1].hour_key, + min_temp=min(temps), + max_temp=max(temps), + max_wind_speed=max(wind_speeds), + max_wind_gust=max(wind_gusts), + max_precip_prob=max(precip_probs), + conditions=conditions, + ) + + def _finalize(self, aggregated_alerts: list[AggregatedAlert]) -> None: + """Finalize the run by saving state. + + Args: + aggregated_alerts: Current alerts to save as snapshots. + """ + self.state_manager.save_alert_snapshots(aggregated_alerts) + self.state_manager.purge_old_records() + self.state_manager.save() + def main(config_path: Optional[str] = None) -> int: """Main entry point for the application. 
@@ -176,6 +333,12 @@ def main(config_path: Optional[str] = None) -> int:
             hint="Set NTFY_ACCESS_TOKEN if your server requires auth",
         )
 
+    if config.ai.enabled and not config.ai.api_token:
+        logger.warning(
+            "ai_enabled_but_no_token",
+            hint="Set REPLICATE_API_TOKEN to enable AI summaries",
+        )
+
     # Run the application
     app = WeatherAlertsApp(config)
     return app.run()
diff --git a/app/models/ai_summary.py b/app/models/ai_summary.py
new file mode 100644
index 0000000..cd64cc7
--- /dev/null
+++ b/app/models/ai_summary.py
@@ -0,0 +1,74 @@
+"""Data models for AI summarization feature."""
+
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any
+
+
+@dataclass
+class ForecastContext:
+    """Forecast data to include in AI summary prompt."""
+
+    start_time: str
+    end_time: str
+    min_temp: float
+    max_temp: float
+    max_wind_speed: float
+    max_wind_gust: float
+    max_precip_prob: float
+    conditions: list[str] = field(default_factory=list)
+
+    def to_prompt_text(self) -> str:
+        """Format forecast context for LLM prompt."""
+        unique_conditions = list(dict.fromkeys(self.conditions))
+        conditions_str = ", ".join(unique_conditions[:5]) if unique_conditions else "N/A"
+
+        return (
+            f"Period: {self.start_time} to {self.end_time}\n"
+            f"Temperature: {self.min_temp:.0f}F - {self.max_temp:.0f}F\n"
+            f"Wind: up to {self.max_wind_speed:.0f} mph, gusts to {self.max_wind_gust:.0f} mph\n"
+            f"Precipitation probability: up to {self.max_precip_prob:.0f}%\n"
+            f"Conditions: {conditions_str}"
+        )
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "start_time": self.start_time,
+            "end_time": self.end_time,
+            "min_temp": self.min_temp,
+            "max_temp": self.max_temp,
+            "max_wind_speed": self.max_wind_speed,
+            "max_wind_gust": self.max_wind_gust,
+            "max_precip_prob": self.max_precip_prob,
+            "conditions": self.conditions,
+        }
+
+
+@dataclass
+class SummaryNotification:
+    """AI-generated weather alert summary notification."""
+
+    title: str
+    message: str
+    location: str
+    alert_count: int
+    has_changes: bool
+    created_at: datetime = field(default_factory=datetime.now)
+
+    @property
+    def tags(self) -> list[str]:
+        """Get notification tags for AI summary."""
+        return ["robot", "weather", "summary"]
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "title": self.title,
+            "message": self.message,
+            "location": self.location,
+            "alert_count": self.alert_count,
+            "has_changes": self.has_changes,
+            "created_at": self.created_at.isoformat(),
+            "tags": self.tags,
+        }
diff --git a/app/models/state.py b/app/models/state.py
index 833e5f1..de2a9cb 100644
--- a/app/models/state.py
+++ b/app/models/state.py
@@ -1,8 +1,8 @@
-"""State management models for alert deduplication."""
+"""State management models for alert deduplication and change detection."""
 
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any
+from typing import Any, Optional
 
 
 @dataclass
@@ -41,12 +41,59 @@ class SentAlertRecord:
     )
 
 
+@dataclass
+class AlertSnapshot:
+    """Snapshot of an alert for change detection between runs."""
+
+    alert_type: str
+    extreme_value: float
+    threshold: float
+    start_time: str
+    end_time: str
+    hour_count: int
+    captured_at: datetime = field(default_factory=datetime.now)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for JSON serialization."""
+        return {
+            "alert_type": self.alert_type,
+            "extreme_value": self.extreme_value,
+            "threshold": self.threshold,
+            "start_time": self.start_time,
+            "end_time": self.end_time,
+            "hour_count": self.hour_count,
+            "captured_at": self.captured_at.isoformat(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "AlertSnapshot":
+        """Create from dictionary.
+
+        Args:
+            data: The serialized snapshot dict.
+
+        Returns:
+            An AlertSnapshot instance.
+        """
+        return cls(
+            alert_type=data["alert_type"],
+            extreme_value=data["extreme_value"],
+            threshold=data["threshold"],
+            start_time=data["start_time"],
+            end_time=data["end_time"],
+            hour_count=data["hour_count"],
+            captured_at=datetime.fromisoformat(data["captured_at"]),
+        )
+
+
 @dataclass
 class AlertState:
-    """State container for tracking sent alerts."""
+    """State container for tracking sent alerts and change detection."""
 
     sent_alerts: dict[str, SentAlertRecord] = field(default_factory=dict)
     last_updated: datetime = field(default_factory=datetime.now)
+    previous_alert_snapshots: dict[str, AlertSnapshot] = field(default_factory=dict)
+    last_ai_summary_sent: Optional[datetime] = None
 
     def is_duplicate(self, dedup_key: str) -> bool:
         """Check if an alert with this dedup key has already been sent.
@@ -106,6 +153,15 @@ class AlertState:
                 key: record.to_dict() for key, record in self.sent_alerts.items()
             },
             "last_updated": self.last_updated.isoformat(),
+            "previous_alert_snapshots": {
+                key: snapshot.to_dict()
+                for key, snapshot in self.previous_alert_snapshots.items()
+            },
+            "last_ai_summary_sent": (
+                self.last_ai_summary_sent.isoformat()
+                if self.last_ai_summary_sent
+                else None
+            ),
         }
 
     @classmethod
@@ -130,4 +186,21 @@ class AlertState:
             else datetime.now()
         )
 
-        return cls(sent_alerts=sent_alerts, last_updated=last_updated)
+        previous_alert_snapshots = {
+            key: AlertSnapshot.from_dict(snapshot_data)
+            for key, snapshot_data in data.get("previous_alert_snapshots", {}).items()
+        }
+
+        last_ai_summary_str = data.get("last_ai_summary_sent")
+        last_ai_summary_sent = (
+            datetime.fromisoformat(last_ai_summary_str)
+            if last_ai_summary_str
+            else None
+        )
+
+        return cls(
+            sent_alerts=sent_alerts,
+            last_updated=last_updated,
+            previous_alert_snapshots=previous_alert_snapshots,
+            last_ai_summary_sent=last_ai_summary_sent,
+        )
diff --git a/app/services/ai_summary_service.py b/app/services/ai_summary_service.py
new file mode 100644
index 0000000..3620645
--- /dev/null
+++ b/app/services/ai_summary_service.py
@@ -0,0 +1,264 @@
+"""AI summarization service using Replicate API."""
+
+import os
+import time
+from typing import Optional
+
+import replicate
+
+from app.models.ai_summary import ForecastContext, SummaryNotification
+from app.models.alerts import AggregatedAlert
+from app.services.change_detector import ChangeReport
+from app.utils.logging_config import get_logger
+
+
+class AISummaryServiceError(Exception):
+    """Raised when AI summary service encounters an error."""
+
+    pass
+
+
+class AISummaryService:
+    """Service for generating AI-powered weather alert summaries using Replicate.
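+
+    Example (illustrative sketch; the alert list and forecast context are
+    assumed to come from AlertAggregator and the app's _build_forecast_context):
+        >>> service = AISummaryService(api_token="r8_...")
+        >>> summary = service.summarize(
+        ...     alerts=aggregated_alerts,
+        ...     change_report=ChangeReport(),
+        ...     forecast_context=forecast_context,
+        ...     location="viola,tn",
+        ... )
+        >>> summary.alert_count == len(aggregated_alerts)
+        True
+    """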
+
+    # Retry configuration
+    MAX_RETRIES = 3
+    INITIAL_BACKOFF_SECONDS = 1.0
+    BACKOFF_MULTIPLIER = 2.0
+
+    PROMPT_TEMPLATE = """You are a weather assistant. Generate a concise alert summary for {location}.
+
+## Active Alerts
+{alerts_section}
+
+## Changes Since Last Run
+{changes_section}
+
+## Forecast for Alert Period
+{forecast_section}
+
+Instructions:
+- Write 2-4 sentences max
+- Prioritize safety-critical info (severe weather, extreme temps)
+- Mention changes from previous alert
+- Include specific temperatures, wind speeds, times
+- No emojis, professional tone
+
+Summary:"""
+
+    def __init__(
+        self,
+        api_token: str,
+        model: str = "meta/meta-llama-3-8b-instruct",
+        api_timeout: int = 60,
+        max_tokens: int = 500,
+        http_client: Optional[object] = None,  # Kept for API compatibility
+    ) -> None:
+        """Initialize the AI summary service.
+
+        Args:
+            api_token: Replicate API token.
+            model: Model identifier to use.
+            api_timeout: Maximum time to wait for response in seconds (unused with library).
+            max_tokens: Maximum tokens in generated response.
+            http_client: Unused - kept for API compatibility.
+        """
+        self.api_token = api_token
+        self.model = model
+        self.api_timeout = api_timeout
+        self.max_tokens = max_tokens
+        self.logger = get_logger(__name__)
+        self._retry_count = 0
+
+        # Set API token for replicate library
+        os.environ["REPLICATE_API_TOKEN"] = api_token
+
+        self.logger.debug(
+            "AISummaryService initialized",
+            model=self.model,
+        )
+
+    def summarize(
+        self,
+        alerts: list[AggregatedAlert],
+        change_report: ChangeReport,
+        forecast_context: ForecastContext,
+        location: str,
+    ) -> SummaryNotification:
+        """Generate an AI summary of the weather alerts.
+
+        Args:
+            alerts: List of aggregated alerts to summarize.
+            change_report: Report of changes from previous run.
+            forecast_context: Forecast data for the alert period.
+            location: Location name for the summary.
+
+        Returns:
+            SummaryNotification with the AI-generated summary.
+
+        Raises:
+            AISummaryServiceError: If API call fails or times out.
+        """
+        prompt = self._build_prompt(alerts, change_report, forecast_context, location)
+
+        self.logger.debug(
+            "generating_ai_summary",
+            model=self.model,
+            alert_count=len(alerts),
+            has_changes=change_report.has_changes,
+        )
+
+        summary_text = self._run_with_retry(prompt)
+
+        return SummaryNotification(
+            title=f"Weather Alert Summary - {location}",
+            message=summary_text,
+            location=location,
+            alert_count=len(alerts),
+            has_changes=change_report.has_changes,
+        )
+
+    def _build_prompt(
+        self,
+        alerts: list[AggregatedAlert],
+        change_report: ChangeReport,
+        forecast_context: ForecastContext,
+        location: str,
+    ) -> str:
+        """Build the prompt for the LLM.
+
+        Args:
+            alerts: List of aggregated alerts.
+            change_report: Change report from detector.
+            forecast_context: Forecast context data.
+            location: Location name.
+
+        Returns:
+            Formatted prompt string.
+        """
+        # Format alerts section
+        alert_lines = []
+        for alert in alerts:
+            alert_type = alert.alert_type.value.replace("_", " ").title()
+            alert_lines.append(
+                f"- {alert_type}: {alert.extreme_value:.0f} "
+                f"(threshold: {alert.threshold:.0f}) "
+                f"from {alert.start_time} to {alert.end_time}"
+            )
+        alerts_section = "\n".join(alert_lines) if alert_lines else "No active alerts."
+
+        # Format changes section
+        changes_section = change_report.to_prompt_text()
+
+        # Format forecast section
+        forecast_section = forecast_context.to_prompt_text()
+
+        return self.PROMPT_TEMPLATE.format(
+            location=location,
+            alerts_section=alerts_section,
+            changes_section=changes_section,
+            forecast_section=forecast_section,
+        )
+
+    def _run_with_retry(self, prompt: str) -> str:
+        """Execute model call with retry logic.
+
+        Implements exponential backoff retry strategy for transient errors:
+        - Rate limit errors (429)
+        - Server errors (5xx)
+        - Timeout errors
+
+        Args:
+            prompt: The prompt to send to the model.
+
+        Returns:
+            Generated text from the model.
+
+        Raises:
+            AISummaryServiceError: If all retries fail.
+        """
+        last_error: Optional[Exception] = None
+        backoff = self.INITIAL_BACKOFF_SECONDS
+
+        for attempt in range(1, self.MAX_RETRIES + 1):
+            try:
+                output = replicate.run(
+                    self.model,
+                    input={
+                        "prompt": prompt,
+                        "max_tokens": self.max_tokens,
+                        "temperature": 0.7,
+                    },
+                )
+
+                # Replicate returns a generator for streaming models
+                # Collect all output parts into a single string
+                if hasattr(output, "__iter__") and not isinstance(output, str):
+                    response = "".join(str(part) for part in output)
+                else:
+                    response = str(output)
+
+                self.logger.debug(
+                    "replicate_api_call_succeeded",
+                    model=self.model,
+                    attempt=attempt,
+                    response_length=len(response),
+                )
+
+                return response
+
+            except Exception as e:
+                last_error = e
+                self._retry_count += 1
+
+                # Determine if error is retryable based on error message
+                error_str = str(e).lower()
+                is_rate_limit = "rate" in error_str or "429" in error_str
+                is_server_error = (
+                    "500" in error_str
+                    or "502" in error_str
+                    or "503" in error_str
+                    or "504" in error_str
+                )
+                is_timeout = "timeout" in error_str
+                is_retryable = is_rate_limit or is_server_error or is_timeout
+
+                if attempt < self.MAX_RETRIES and is_retryable:
+                    self.logger.warning(
+                        "replicate_api_call_failed_retrying",
+                        model=self.model,
+                        attempt=attempt,
+                        max_retries=self.MAX_RETRIES,
+                        error=str(e),
+                        backoff_seconds=backoff,
+                    )
+                    time.sleep(backoff)
+                    backoff *= self.BACKOFF_MULTIPLIER
+                else:
+                    # Non-retryable error or max retries reached
+                    break
+
+        # All retries exhausted or non-retryable error
+        self.logger.error(
+            "replicate_api_call_failed_after_retries",
+            model=self.model,
+            retry_count=self.MAX_RETRIES,
+            error=str(last_error),
+        )
+
+        raise AISummaryServiceError(
+            f"Replicate API call failed after {self.MAX_RETRIES} attempts: {last_error}"
+        )
diff --git a/app/services/change_detector.py b/app/services/change_detector.py
new file mode 100644
index 0000000..25e344a
--- /dev/null
+++ b/app/services/change_detector.py
@@ -0,0 +1,245 @@
+"""Change detection service for comparing alerts between runs."""
+
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Optional
+
+from app.config.loader import ChangeThresholds
+from app.models.alerts import AggregatedAlert
+from app.models.state import AlertSnapshot
+from app.utils.logging_config import get_logger
+
+
+class ChangeType(Enum):
+    """Types of changes that can be detected between runs."""
+
+    NEW = "new"
+    REMOVED = "removed"
+    VALUE_CHANGED = "value_changed"
+
+
+@dataclass
+class AlertChange:
+    """Represents a detected change in an alert."""
+
+    alert_type: str
+    change_type: ChangeType
+    description: str
+    previous_value: Optional[float] = None
+    current_value: Optional[float] = None
+    value_delta: Optional[float] = None
+
+
+@dataclass
+class ChangeReport:
+    """Report of all changes detected between runs."""
+
+    changes: list[AlertChange] = field(default_factory=list)
+
+    @property
+    def has_changes(self) -> bool:
+        """Check if any changes were detected."""
+        return len(self.changes) > 0
+
+    @property
+    def new_alerts(self) -> list[AlertChange]:
+        """Get list of new alert changes."""
+        return [c for c in self.changes if c.change_type == ChangeType.NEW]
+
+    @property
+    def removed_alerts(self) -> list[AlertChange]:
+        """Get list of removed alert changes."""
+        return [c for c in self.changes if c.change_type == ChangeType.REMOVED]
+
+    @property
+    def value_changes(self) -> list[AlertChange]:
+        """Get list of value change alerts."""
+        return [c for c in self.changes if c.change_type == ChangeType.VALUE_CHANGED]
+
+    def to_prompt_text(self) -> str:
+        """Format change report for LLM prompt."""
+        if not self.has_changes:
+            return "No significant changes from previous alert."
+
+        lines = []
+
+        for change in self.new_alerts:
+            lines.append(f"NEW: {change.description}")
+
+        for change in self.removed_alerts:
+            lines.append(f"RESOLVED: {change.description}")
+
+        for change in self.value_changes:
+            lines.append(f"CHANGED: {change.description}")
+
+        return "\n".join(lines)
+
+
+class ChangeDetector:
+    """Detects changes between current alerts and previous snapshots."""
+
+    # Map alert types to their value thresholds
+    ALERT_TYPE_TO_THRESHOLD_KEY = {
+        "temperature_low": "temperature",
+        "temperature_high": "temperature",
+        "wind_speed": "wind_speed",
+        "wind_gust": "wind_gust",
+        "precipitation": "precipitation_prob",
+    }
+
+    def __init__(self, thresholds: ChangeThresholds) -> None:
+        """Initialize the change detector.
+
+        Args:
+            thresholds: Thresholds for detecting significant changes.
+        """
+        self.thresholds = thresholds
+        self.logger = get_logger(__name__)
+
+    def detect(
+        self,
+        current_alerts: list[AggregatedAlert],
+        previous_snapshots: dict[str, AlertSnapshot],
+    ) -> ChangeReport:
+        """Detect changes between current alerts and previous snapshots.
+
+        Args:
+            current_alerts: List of current aggregated alerts.
+            previous_snapshots: Dict of previous alert snapshots keyed by alert type.
+
+        Returns:
+            ChangeReport containing all detected changes.
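+
+        Example (illustrative; assumes a previous wind_gust snapshot of 45.0,
+        a current wind_gust alert of 58.0, and the default 10.0 mph threshold):
+            >>> report = detector.detect(current_alerts, previous_snapshots)
+            >>> report.has_changes
+            True
+            >>> report.value_changes[0].description
+            'Wind Gust increased from 45 to 58'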
+ """ + changes: list[AlertChange] = [] + current_types = {alert.alert_type.value for alert in current_alerts} + previous_types = set(previous_snapshots.keys()) + + # Detect new alerts + for alert in current_alerts: + alert_type = alert.alert_type.value + if alert_type not in previous_types: + changes.append( + AlertChange( + alert_type=alert_type, + change_type=ChangeType.NEW, + description=self._format_new_alert_description(alert), + current_value=alert.extreme_value, + ) + ) + self.logger.debug( + "change_detected_new", + alert_type=alert_type, + ) + else: + # Check for significant value changes + prev_snapshot = previous_snapshots[alert_type] + value_change = self._detect_value_change(alert, prev_snapshot) + if value_change: + changes.append(value_change) + self.logger.debug( + "change_detected_value", + alert_type=alert_type, + previous=prev_snapshot.extreme_value, + current=alert.extreme_value, + ) + + # Detect removed alerts + for alert_type in previous_types - current_types: + prev_snapshot = previous_snapshots[alert_type] + changes.append( + AlertChange( + alert_type=alert_type, + change_type=ChangeType.REMOVED, + description=self._format_removed_alert_description(prev_snapshot), + previous_value=prev_snapshot.extreme_value, + ) + ) + self.logger.debug( + "change_detected_removed", + alert_type=alert_type, + ) + + report = ChangeReport(changes=changes) + + self.logger.info( + "change_detection_complete", + total_changes=len(changes), + new_count=len(report.new_alerts), + removed_count=len(report.removed_alerts), + value_changes_count=len(report.value_changes), + ) + + return report + + def _detect_value_change( + self, + current: AggregatedAlert, + previous: AlertSnapshot, + ) -> Optional[AlertChange]: + """Detect if there's a significant value change between current and previous. + + Args: + current: Current aggregated alert. + previous: Previous alert snapshot. + + Returns: + AlertChange if significant change detected, None otherwise. 
+ """ + threshold_key = self.ALERT_TYPE_TO_THRESHOLD_KEY.get(current.alert_type.value) + if not threshold_key: + return None + + threshold = getattr(self.thresholds, threshold_key, None) + if threshold is None: + return None + + delta = abs(current.extreme_value - previous.extreme_value) + if delta >= threshold: + return AlertChange( + alert_type=current.alert_type.value, + change_type=ChangeType.VALUE_CHANGED, + description=self._format_value_change_description( + current, previous, delta + ), + previous_value=previous.extreme_value, + current_value=current.extreme_value, + value_delta=delta, + ) + + return None + + def _format_new_alert_description(self, alert: AggregatedAlert) -> str: + """Format description for a new alert.""" + alert_type = alert.alert_type.value.replace("_", " ").title() + return f"{alert_type} alert: {alert.extreme_value:.0f} at {alert.extreme_hour}" + + def _format_removed_alert_description(self, snapshot: AlertSnapshot) -> str: + """Format description for a removed alert.""" + alert_type = snapshot.alert_type.replace("_", " ").title() + return f"{alert_type} alert no longer active" + + def _format_value_change_description( + self, + current: AggregatedAlert, + previous: AlertSnapshot, + delta: float, + ) -> str: + """Format description for a value change.""" + alert_type = current.alert_type.value.replace("_", " ").title() + direction = "increased" if current.extreme_value > previous.extreme_value else "decreased" + return ( + f"{alert_type} {direction} from {previous.extreme_value:.0f} " + f"to {current.extreme_value:.0f}" + ) diff --git a/app/services/notification_service.py b/app/services/notification_service.py index 67fbf83..ef458c4 100644 --- a/app/services/notification_service.py +++ b/app/services/notification_service.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Optional, Union +from app.models.ai_summary import SummaryNotification from app.models.alerts import AggregatedAlert, AlertType, TriggeredAlert from app.utils.http_client import HttpClient from app.utils.logging_config import get_logger @@ -20,6 +21,15 @@ class NotificationResult: error: Optional[str] = None +@dataclass +class SummaryNotificationResult: + """Result of sending a summary notification.""" + + summary: SummaryNotification + success: bool + error: Optional[str] = None + + class NotificationServiceError(Exception): """Raised when notification service encounters an error.""" @@ -175,3 +185,54 @@ class NotificationService: tags = list(self.ALERT_TYPE_TAGS.get(alert.alert_type, self.default_tags)) return tags + + def send_summary(self, summary: SummaryNotification) -> SummaryNotificationResult: + """Send an AI-generated summary notification. + + Args: + summary: The summary notification to send. + + Returns: + SummaryNotificationResult indicating success or failure. 
+ """ + url = f"{self.server_url}/{self.topic}" + + # Build headers + headers = { + "Title": summary.title, + "Priority": self.priority, + "Tags": ",".join(summary.tags), + } + + if self.access_token: + headers["Authorization"] = f"Bearer {self.access_token}" + + self.logger.debug( + "sending_summary_notification", + location=summary.location, + alert_count=summary.alert_count, + ) + + response = self.http_client.post( + url, + data=summary.message.encode("utf-8"), + headers=headers, + ) + + if response.success: + self.logger.info( + "summary_notification_sent", + location=summary.location, + alert_count=summary.alert_count, + has_changes=summary.has_changes, + ) + return SummaryNotificationResult(summary=summary, success=True) + else: + error_msg = f"HTTP {response.status_code}: {response.text[:100]}" + self.logger.error( + "summary_notification_failed", + error=error_msg, + ) + return SummaryNotificationResult( + summary=summary, success=False, error=error_msg + ) diff --git a/app/services/state_manager.py b/app/services/state_manager.py index ad99311..cfbb0d6 100644 --- a/app/services/state_manager.py +++ b/app/services/state_manager.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Optional, Union from app.models.alerts import AggregatedAlert, TriggeredAlert -from app.models.state import AlertState +from app.models.state import AlertSnapshot, AlertState from app.utils.logging_config import get_logger # Type alias for alerts that can be deduplicated @@ -183,3 +183,47 @@ class StateManager: self.logger.info("old_records_purged", count=purged) return purged + + def save_alert_snapshots(self, alerts: list[AggregatedAlert]) -> None: + """Save current alerts as snapshots for change detection. + + Args: + alerts: List of aggregated alerts to save as snapshots. + """ + snapshots: dict[str, AlertSnapshot] = {} + + for alert in alerts: + snapshot = AlertSnapshot( + alert_type=alert.alert_type.value, + extreme_value=alert.extreme_value, + threshold=alert.threshold, + start_time=alert.start_time, + end_time=alert.end_time, + hour_count=alert.hour_count, + ) + snapshots[alert.alert_type.value] = snapshot + + from datetime import datetime + + self.state.previous_alert_snapshots = snapshots + self.state.last_updated = datetime.now() + + self.logger.debug( + "alert_snapshots_saved", + count=len(snapshots), + ) + + def get_previous_snapshots(self) -> dict[str, AlertSnapshot]: + """Get previous alert snapshots for change detection. + + Returns: + Dict of alert snapshots keyed by alert type. + """ + return self.state.previous_alert_snapshots + + def record_ai_summary_sent(self) -> None: + """Record that an AI summary was sent.""" + from datetime import datetime + + self.state.last_ai_summary_sent = datetime.now() + self.logger.debug("ai_summary_sent_recorded") diff --git a/replicate_client.py b/replicate_client.py new file mode 100644 index 0000000..fc8e831 --- /dev/null +++ b/replicate_client.py @@ -0,0 +1,339 @@ +""" +Replicate API client wrapper for the Sneaky Intel Feed Aggregator. + +Provides low-level API access with retry logic and error handling +for all Replicate model interactions. +""" + +import json +import os +import time +from typing import Any, Dict, List, Optional, Tuple + +import replicate + +from feed_aggregator.models.config import ReplicateConfig +from feed_aggregator.utils.exceptions import APIError +from feed_aggregator.utils.logging import get_logger + +logger = get_logger(__name__) + + +class ReplicateClient: + """ + Low-level wrapper for Replicate API calls. 
+
+    Handles:
+    - API authentication via REPLICATE_API_TOKEN
+    - Retry logic with exponential backoff
+    - Error classification and wrapping
+    - Response normalization
+
+    Attributes:
+        MAX_RETRIES: Maximum retry attempts for failed calls.
+        INITIAL_BACKOFF_SECONDS: Initial delay before first retry.
+        BACKOFF_MULTIPLIER: Multiplier for exponential backoff.
+        LLAMA_MODEL: Model identifier for Llama 3 8B Instruct.
+        EMBEDDING_MODEL: Model identifier for text embeddings.
+
+    Example:
+        >>> from feed_aggregator.services.replicate_client import ReplicateClient
+        >>> from feed_aggregator.models.config import ReplicateConfig
+        >>>
+        >>> config = ReplicateConfig(api_key="r8_...")
+        >>> client = ReplicateClient(config)
+        >>> response, error = client.run_llama("Summarize: ...")
+    """
+
+    # Retry configuration (matches DatabaseService pattern)
+    MAX_RETRIES = 3
+    INITIAL_BACKOFF_SECONDS = 1.0
+    BACKOFF_MULTIPLIER = 2.0
+
+    # Model identifiers
+    LLAMA_MODEL = "meta/meta-llama-3-8b-instruct"
+    EMBEDDING_MODEL = "beautyyuyanli/multilingual-e5-large:a06276a89f1a902d5fc225a9ca32b6e8e6292b7f3b136518878da97c458e2bad"
+
+    # Expected embedding dimensions for multilingual-e5-large
+    EMBEDDING_DIMENSIONS = 1024
+
+    # Default parameters for Llama
+    DEFAULT_MAX_TOKENS = 500
+    DEFAULT_TEMPERATURE = 0.3
+
+    def __init__(self, config: ReplicateConfig):
+        """
+        Initialize the Replicate client.
+
+        Sets the REPLICATE_API_TOKEN environment variable for the
+        replicate library to use.
+
+        Args:
+            config: Replicate configuration with API key.
+        """
+        self._config = config
+        self._retry_count = 0  # Track total retries for metrics
+
+        # Set API token for replicate library
+        os.environ["REPLICATE_API_TOKEN"] = config.api_key
+
+        logger.debug(
+            "ReplicateClient initialized",
+            model=self.LLAMA_MODEL,
+        )
+
+    def run_llama(
+        self,
+        prompt: str,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+        temperature: float = DEFAULT_TEMPERATURE,
+    ) -> Tuple[Optional[str], Optional[APIError]]:
+        """
+        Run Llama 3 8B Instruct with the given prompt.
+
+        Args:
+            prompt: The prompt to send to the model.
+            max_tokens: Maximum tokens in response.
+            temperature: Sampling temperature (0-1).
+
+        Returns:
+            Tuple of (response_text, error).
+            On success: (str, None)
+            On failure: (None, APIError)
+        """
+        return self._run_with_retry(
+            model=self.LLAMA_MODEL,
+            input_params={
+                "prompt": prompt,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+            },
+            operation="summarization",
+        )
+
+    def run_embedding(
+        self,
+        text: str,
+    ) -> Tuple[Optional[List[float]], Optional[APIError]]:
+        """
+        Generate embeddings using multilingual-e5-large.
+
+        Uses the same retry logic as run_llama but preserves the
+        list response format instead of converting to string.
+
+        Args:
+            text: Text to embed (title + summary, typically 200-500 tokens).
+
+        Returns:
+            Tuple of (embedding_vector, error).
+            On success: (List[float] with 1024 dimensions, None)
+            On failure: (None, APIError)
+
+        Example:
+            >>> embedding, error = client.run_embedding("Article about security")
+            >>> if embedding:
print(f"Generated {len(embedding)}-dimensional embedding") + """ + last_error: Optional[Exception] = None + backoff = self.INITIAL_BACKOFF_SECONDS + + for attempt in range(1, self.MAX_RETRIES + 1): + try: + # E5 model expects texts as a JSON-formatted string + output = replicate.run( + self.EMBEDDING_MODEL, + input={"texts": json.dumps([text])}, + ) + + # Output is a list of embeddings (one per input text) + # We only pass one text, so extract the first embedding + if isinstance(output, list) and len(output) > 0: + embedding = output[0] + + # Validate embedding dimensions + if len(embedding) != self.EMBEDDING_DIMENSIONS: + raise ValueError( + f"Expected {self.EMBEDDING_DIMENSIONS} dimensions, " + f"got {len(embedding)}" + ) + + # Ensure all values are floats + embedding = [float(v) for v in embedding] + + logger.debug( + "Replicate embedding succeeded", + model=self.EMBEDDING_MODEL, + operation="embedding", + attempt=attempt, + dimensions=len(embedding), + ) + + return embedding, None + else: + raise ValueError( + f"Unexpected output format from embedding model: {type(output)}" + ) + + except Exception as e: + last_error = e + self._retry_count += 1 + + # Determine if error is retryable based on error message + error_str = str(e).lower() + is_rate_limit = "rate" in error_str or "429" in error_str + is_server_error = ( + "500" in error_str + or "502" in error_str + or "503" in error_str + or "504" in error_str + ) + is_timeout = "timeout" in error_str + is_retryable = is_rate_limit or is_server_error or is_timeout + + if attempt < self.MAX_RETRIES and is_retryable: + logger.warning( + "Replicate embedding failed, retrying", + model=self.EMBEDDING_MODEL, + operation="embedding", + attempt=attempt, + max_retries=self.MAX_RETRIES, + error=str(e), + backoff_seconds=backoff, + ) + time.sleep(backoff) + backoff *= self.BACKOFF_MULTIPLIER + else: + # Non-retryable error or max retries reached + break + + # All retries exhausted or non-retryable error + logger.error( + "Replicate embedding failed after retries", + model=self.EMBEDDING_MODEL, + operation="embedding", + retry_count=self.MAX_RETRIES, + error=str(last_error), + ) + + return None, APIError( + message=f"Replicate embedding failed: {str(last_error)}", + api_name="Replicate", + operation="embedding", + retry_count=self.MAX_RETRIES, + context={"model": self.EMBEDDING_MODEL, "error": str(last_error)}, + ) + + def _run_with_retry( + self, + model: str, + input_params: Dict[str, Any], + operation: str, + ) -> Tuple[Optional[str], Optional[APIError]]: + """ + Execute model call with retry logic. + + Implements exponential backoff retry strategy for transient errors: + - Rate limit errors (429) + - Server errors (5xx) + - Timeout errors + + Non-retryable errors (auth failures, etc.) fail immediately. + + Args: + model: Replicate model identifier. + input_params: Model input parameters. + operation: Operation name for error reporting. + + Returns: + Tuple of (response, error). 
+ """ + last_error: Optional[Exception] = None + backoff = self.INITIAL_BACKOFF_SECONDS + + for attempt in range(1, self.MAX_RETRIES + 1): + try: + output = replicate.run(model, input=input_params) + + # Replicate returns a generator for streaming models + # Collect all output parts into a single string + if hasattr(output, "__iter__") and not isinstance(output, str): + response = "".join(str(part) for part in output) + else: + response = str(output) + + logger.debug( + "Replicate API call succeeded", + model=model, + operation=operation, + attempt=attempt, + response_length=len(response), + ) + + return response, None + + except Exception as e: + last_error = e + self._retry_count += 1 + + # Determine if error is retryable based on error message + error_str = str(e).lower() + is_rate_limit = "rate" in error_str or "429" in error_str + is_server_error = ( + "500" in error_str + or "502" in error_str + or "503" in error_str + or "504" in error_str + ) + is_timeout = "timeout" in error_str + is_retryable = is_rate_limit or is_server_error or is_timeout + + if attempt < self.MAX_RETRIES and is_retryable: + logger.warning( + "Replicate API call failed, retrying", + model=model, + operation=operation, + attempt=attempt, + max_retries=self.MAX_RETRIES, + error=str(e), + backoff_seconds=backoff, + ) + time.sleep(backoff) + backoff *= self.BACKOFF_MULTIPLIER + else: + # Non-retryable error or max retries reached + break + + # All retries exhausted or non-retryable error + logger.error( + "Replicate API call failed after retries", + model=model, + operation=operation, + retry_count=self.MAX_RETRIES, + error=str(last_error), + ) + + return None, APIError( + message=f"Replicate API call failed: {str(last_error)}", + api_name="Replicate", + operation=operation, + retry_count=self.MAX_RETRIES, + context={"model": model, "error": str(last_error)}, + ) + + @property + def total_retries(self) -> int: + """ + Return total retry count across all calls. + + Returns: + Total number of retry attempts made. + """ + return self._retry_count + + def reset_retry_count(self) -> None: + """ + Reset retry counter. + + Useful for metrics collection between batch operations. + """ + self._retry_count = 0 diff --git a/requirements.txt b/requirements.txt index 12f3989..f981a9d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,6 @@ requests>=2.28.0,<3.0.0 structlog>=23.0.0,<25.0.0 python-dotenv>=1.0.0,<2.0.0 PyYAML>=6.0,<7.0 +replicate>=0.25.0,<1.0.0 pytest>=7.0.0,<9.0.0 responses>=0.23.0,<1.0.0