Compare commits
38 Commits
f1982157aa
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7933e39216 | ||
|
|
d094ba7979 | ||
|
|
0849c35bb5 | ||
|
|
28f21a68c7 | ||
|
|
ab872503b4 | ||
|
|
2db4479942 | ||
|
|
3cce50410f | ||
|
|
7e04615843 | ||
|
|
ba95624675 | ||
|
|
3bdce183e0 | ||
|
|
e9849b66f5 | ||
|
|
1e72cd7423 | ||
|
|
6a7b5ef51f | ||
|
|
978c4ebbaa | ||
|
|
1def229e37 | ||
|
|
46344854f0 | ||
|
|
2b660451cd | ||
|
|
1bed55aff4 | ||
|
|
6011ee3012 | ||
|
|
7b9b2395de | ||
|
|
a9be527a1a | ||
|
|
c416ce643f | ||
|
|
c34d807681 | ||
|
|
1bd2d55951 | ||
|
|
0a7d3c41e7 | ||
|
|
26767cdcef | ||
|
|
3297bdea5e | ||
|
|
ef0a271554 | ||
|
|
f425f72e24 | ||
|
|
355c82e39e | ||
|
|
7becfaeba2 | ||
|
|
8695ccd9ab | ||
|
|
4459b3278d | ||
|
|
34867d0911 | ||
|
|
d68ee280b6 | ||
|
|
0661acb961 | ||
| 921b6a81a4 | |||
|
|
2820944ec6 |
@@ -9,3 +9,6 @@ VISUALCROSSING_API_KEY=your_api_key_here
|
|||||||
# Ntfy Access Token
|
# Ntfy Access Token
|
||||||
# Required if your ntfy server requires authentication
|
# Required if your ntfy server requires authentication
|
||||||
NTFY_ACCESS_TOKEN=your_ntfy_token_here
|
NTFY_ACCESS_TOKEN=your_ntfy_token_here
|
||||||
|
|
||||||
|
# Replicate API Token
|
||||||
|
REPLICATE_API_TOKEN=your_replicate_token_here
|
||||||
@@ -2,7 +2,7 @@ name: Weather Alerts
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 * * * *' # Every hour at :00
|
- cron: '0 */4 * * *' # Every 4th hour at :00
|
||||||
workflow_dispatch: {} # Manual trigger
|
workflow_dispatch: {} # Manual trigger
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|||||||
@@ -63,6 +63,35 @@ class AlertSettings:
|
|||||||
rules: AlertRules = field(default_factory=AlertRules)
|
rules: AlertRules = field(default_factory=AlertRules)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AISettings:
|
||||||
|
"""AI summarization settings."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
model: str = "meta/meta-llama-3-8b-instruct"
|
||||||
|
api_timeout: int = 60
|
||||||
|
max_tokens: int = 500
|
||||||
|
api_token: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChangeThresholds:
|
||||||
|
"""Thresholds for detecting significant changes between runs."""
|
||||||
|
|
||||||
|
temperature: float = 5.0
|
||||||
|
wind_speed: float = 10.0
|
||||||
|
wind_gust: float = 10.0
|
||||||
|
precipitation_prob: float = 20.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChangeDetectionSettings:
|
||||||
|
"""Change detection settings."""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
thresholds: ChangeThresholds = field(default_factory=ChangeThresholds)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AppConfig:
|
class AppConfig:
|
||||||
"""Complete application configuration."""
|
"""Complete application configuration."""
|
||||||
@@ -72,6 +101,8 @@ class AppConfig:
|
|||||||
alerts: AlertSettings = field(default_factory=AlertSettings)
|
alerts: AlertSettings = field(default_factory=AlertSettings)
|
||||||
notifications: NotificationSettings = field(default_factory=NotificationSettings)
|
notifications: NotificationSettings = field(default_factory=NotificationSettings)
|
||||||
state: StateSettings = field(default_factory=StateSettings)
|
state: StateSettings = field(default_factory=StateSettings)
|
||||||
|
ai: AISettings = field(default_factory=AISettings)
|
||||||
|
change_detection: ChangeDetectionSettings = field(default_factory=ChangeDetectionSettings)
|
||||||
|
|
||||||
|
|
||||||
def load_config(
|
def load_config(
|
||||||
@@ -113,6 +144,8 @@ def load_config(
|
|||||||
alerts_data = config_data.get("alerts", {})
|
alerts_data = config_data.get("alerts", {})
|
||||||
notifications_data = config_data.get("notifications", {})
|
notifications_data = config_data.get("notifications", {})
|
||||||
state_data = config_data.get("state", {})
|
state_data = config_data.get("state", {})
|
||||||
|
ai_data = config_data.get("ai", {})
|
||||||
|
change_detection_data = config_data.get("change_detection", {})
|
||||||
|
|
||||||
# Build app settings
|
# Build app settings
|
||||||
app_settings = AppSettings(
|
app_settings = AppSettings(
|
||||||
@@ -150,10 +183,34 @@ def load_config(
|
|||||||
dedup_window_hours=state_data.get("dedup_window_hours", 24),
|
dedup_window_hours=state_data.get("dedup_window_hours", 24),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build AI settings with token from environment
|
||||||
|
ai_settings = AISettings(
|
||||||
|
enabled=ai_data.get("enabled", False),
|
||||||
|
model=ai_data.get("model", "meta/meta-llama-3-8b-instruct"),
|
||||||
|
api_timeout=ai_data.get("api_timeout", 60),
|
||||||
|
max_tokens=ai_data.get("max_tokens", 500),
|
||||||
|
api_token=os.environ.get("REPLICATE_API_TOKEN", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build change detection settings
|
||||||
|
thresholds_data = change_detection_data.get("thresholds", {})
|
||||||
|
change_thresholds = ChangeThresholds(
|
||||||
|
temperature=thresholds_data.get("temperature", 5.0),
|
||||||
|
wind_speed=thresholds_data.get("wind_speed", 10.0),
|
||||||
|
wind_gust=thresholds_data.get("wind_gust", 10.0),
|
||||||
|
precipitation_prob=thresholds_data.get("precipitation_prob", 20.0),
|
||||||
|
)
|
||||||
|
change_detection_settings = ChangeDetectionSettings(
|
||||||
|
enabled=change_detection_data.get("enabled", True),
|
||||||
|
thresholds=change_thresholds,
|
||||||
|
)
|
||||||
|
|
||||||
return AppConfig(
|
return AppConfig(
|
||||||
app=app_settings,
|
app=app_settings,
|
||||||
weather=weather_settings,
|
weather=weather_settings,
|
||||||
alerts=alert_settings,
|
alerts=alert_settings,
|
||||||
notifications=notification_settings,
|
notifications=notification_settings,
|
||||||
state=state_settings,
|
state=state_settings,
|
||||||
|
ai=ai_settings,
|
||||||
|
change_detection=change_detection_settings,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ app:
|
|||||||
|
|
||||||
weather:
|
weather:
|
||||||
location: "viola,tn"
|
location: "viola,tn"
|
||||||
hours_ahead: 24
|
hours_ahead: 4 # Matches 4-hour run frequency
|
||||||
unit_group: "us"
|
unit_group: "us"
|
||||||
|
|
||||||
alerts:
|
alerts:
|
||||||
@@ -34,3 +34,18 @@ notifications:
|
|||||||
state:
|
state:
|
||||||
file_path: "./data/state.json"
|
file_path: "./data/state.json"
|
||||||
dedup_window_hours: 24
|
dedup_window_hours: 24
|
||||||
|
|
||||||
|
ai:
|
||||||
|
enabled: true
|
||||||
|
model: "meta/meta-llama-3-8b-instruct"
|
||||||
|
api_timeout: 60
|
||||||
|
max_tokens: 500
|
||||||
|
|
||||||
|
# These metrics are used for change detection from previous forecast
|
||||||
|
change_detection:
|
||||||
|
enabled: true
|
||||||
|
thresholds:
|
||||||
|
temperature: 5.0 # degrees F
|
||||||
|
wind_speed: 10.0 # mph
|
||||||
|
wind_gust: 10.0 # mph
|
||||||
|
precipitation_prob: 20.0 # percentage points
|
||||||
|
|||||||
191
app/main.py
191
app/main.py
@@ -4,7 +4,12 @@ import sys
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from app.config.loader import AppConfig, load_config
|
from app.config.loader import AppConfig, load_config
|
||||||
|
from app.models.ai_summary import ForecastContext
|
||||||
|
from app.models.alerts import AggregatedAlert
|
||||||
|
from app.models.weather import WeatherForecast
|
||||||
|
from app.services.ai_summary_service import AISummaryService, AISummaryServiceError
|
||||||
from app.services.alert_aggregator import AlertAggregator
|
from app.services.alert_aggregator import AlertAggregator
|
||||||
|
from app.services.change_detector import ChangeDetector, ChangeReport
|
||||||
from app.services.notification_service import NotificationService
|
from app.services.notification_service import NotificationService
|
||||||
from app.services.rule_engine import RuleEngine
|
from app.services.rule_engine import RuleEngine
|
||||||
from app.services.state_manager import StateManager
|
from app.services.state_manager import StateManager
|
||||||
@@ -52,6 +57,26 @@ class WeatherAlertsApp:
|
|||||||
http_client=self.http_client,
|
http_client=self.http_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize AI services conditionally
|
||||||
|
self.ai_enabled = config.ai.enabled and bool(config.ai.api_token)
|
||||||
|
self.ai_summary_service: Optional[AISummaryService] = None
|
||||||
|
self.change_detector: Optional[ChangeDetector] = None
|
||||||
|
|
||||||
|
if self.ai_enabled:
|
||||||
|
self.ai_summary_service = AISummaryService(
|
||||||
|
api_token=config.ai.api_token,
|
||||||
|
model=config.ai.model,
|
||||||
|
api_timeout=config.ai.api_timeout,
|
||||||
|
max_tokens=config.ai.max_tokens,
|
||||||
|
http_client=self.http_client,
|
||||||
|
)
|
||||||
|
self.logger.info("ai_summary_enabled", model=config.ai.model)
|
||||||
|
|
||||||
|
if config.change_detection.enabled:
|
||||||
|
self.change_detector = ChangeDetector(
|
||||||
|
thresholds=config.change_detection.thresholds
|
||||||
|
)
|
||||||
|
|
||||||
def run(self) -> int:
|
def run(self) -> int:
|
||||||
"""Execute the main application flow.
|
"""Execute the main application flow.
|
||||||
|
|
||||||
@@ -62,6 +87,7 @@ class WeatherAlertsApp:
|
|||||||
"app_starting",
|
"app_starting",
|
||||||
version=self.config.app.version,
|
version=self.config.app.version,
|
||||||
location=self.config.weather.location,
|
location=self.config.weather.location,
|
||||||
|
ai_enabled=self.ai_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -79,6 +105,9 @@ class WeatherAlertsApp:
|
|||||||
|
|
||||||
if not triggered_alerts:
|
if not triggered_alerts:
|
||||||
self.logger.info("no_alerts_triggered")
|
self.logger.info("no_alerts_triggered")
|
||||||
|
# Clear snapshots when no alerts
|
||||||
|
self.state_manager.save_alert_snapshots([])
|
||||||
|
self.state_manager.save()
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
@@ -96,30 +125,55 @@ class WeatherAlertsApp:
|
|||||||
output_count=len(aggregated_alerts),
|
output_count=len(aggregated_alerts),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3: Filter duplicates
|
# Branch based on AI enabled
|
||||||
|
if self.ai_enabled:
|
||||||
|
return self._run_ai_flow(forecast, aggregated_alerts)
|
||||||
|
else:
|
||||||
|
return self._run_standard_flow(aggregated_alerts)
|
||||||
|
|
||||||
|
except WeatherServiceError as e:
|
||||||
|
self.logger.error("weather_service_error", error=str(e))
|
||||||
|
return 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception("unexpected_error", error=str(e))
|
||||||
|
return 1
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.http_client.close()
|
||||||
|
|
||||||
|
def _run_standard_flow(self, aggregated_alerts: list[AggregatedAlert]) -> int:
|
||||||
|
"""Run the standard notification flow without AI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
aggregated_alerts: List of aggregated alerts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exit code (0 for success, 1 for error).
|
||||||
|
"""
|
||||||
|
# Filter duplicates
|
||||||
self.logger.info("step_filter_duplicates")
|
self.logger.info("step_filter_duplicates")
|
||||||
new_alerts = self.state_manager.filter_duplicates(aggregated_alerts)
|
new_alerts = self.state_manager.filter_duplicates(aggregated_alerts)
|
||||||
|
|
||||||
if not new_alerts:
|
if not new_alerts:
|
||||||
self.logger.info("all_alerts_are_duplicates")
|
self.logger.info("all_alerts_are_duplicates")
|
||||||
|
self._finalize(aggregated_alerts)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Step 4: Send notifications
|
# Send notifications
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"step_send_notifications",
|
"step_send_notifications",
|
||||||
count=len(new_alerts),
|
count=len(new_alerts),
|
||||||
)
|
)
|
||||||
results = self.notification_service.send_batch(new_alerts)
|
results = self.notification_service.send_batch(new_alerts)
|
||||||
|
|
||||||
# Step 5: Record sent alerts
|
# Record sent alerts
|
||||||
self.logger.info("step_record_sent")
|
self.logger.info("step_record_sent")
|
||||||
for result in results:
|
for result in results:
|
||||||
if result.success:
|
if result.success:
|
||||||
self.state_manager.record_sent(result.alert)
|
self.state_manager.record_sent(result.alert)
|
||||||
|
|
||||||
# Step 6: Purge old records and save state
|
self._finalize(aggregated_alerts)
|
||||||
self.state_manager.purge_old_records()
|
|
||||||
self.state_manager.save()
|
|
||||||
|
|
||||||
# Report results
|
# Report results
|
||||||
success_count = sum(1 for r in results if r.success)
|
success_count = sum(1 for r in results if r.success)
|
||||||
@@ -133,16 +187,119 @@ class WeatherAlertsApp:
|
|||||||
|
|
||||||
return 0 if failed_count == 0 else 1
|
return 0 if failed_count == 0 else 1
|
||||||
|
|
||||||
except WeatherServiceError as e:
|
def _run_ai_flow(
|
||||||
self.logger.error("weather_service_error", error=str(e))
|
self,
|
||||||
return 1
|
forecast: WeatherForecast,
|
||||||
|
aggregated_alerts: list[AggregatedAlert],
|
||||||
|
) -> int:
|
||||||
|
"""Run the AI summarization flow.
|
||||||
|
|
||||||
except Exception as e:
|
Args:
|
||||||
self.logger.exception("unexpected_error", error=str(e))
|
forecast: The weather forecast data.
|
||||||
return 1
|
aggregated_alerts: List of aggregated alerts.
|
||||||
|
|
||||||
finally:
|
Returns:
|
||||||
self.http_client.close()
|
Exit code (0 for success, 1 for error).
|
||||||
|
"""
|
||||||
|
self.logger.info("step_ai_summary_flow")
|
||||||
|
|
||||||
|
# Build forecast context
|
||||||
|
forecast_context = self._build_forecast_context(forecast)
|
||||||
|
|
||||||
|
# Detect changes from previous run
|
||||||
|
change_report = ChangeReport()
|
||||||
|
if self.change_detector:
|
||||||
|
previous_snapshots = self.state_manager.get_previous_snapshots()
|
||||||
|
change_report = self.change_detector.detect(
|
||||||
|
aggregated_alerts, previous_snapshots
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to generate AI summary
|
||||||
|
try:
|
||||||
|
assert self.ai_summary_service is not None
|
||||||
|
summary = self.ai_summary_service.summarize(
|
||||||
|
alerts=aggregated_alerts,
|
||||||
|
change_report=change_report,
|
||||||
|
forecast_context=forecast_context,
|
||||||
|
location=self.config.weather.location,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send summary notification
|
||||||
|
self.logger.info("step_send_ai_summary")
|
||||||
|
result = self.notification_service.send_summary(summary)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
self.state_manager.record_ai_summary_sent()
|
||||||
|
self._finalize(aggregated_alerts)
|
||||||
|
self.logger.info(
|
||||||
|
"app_complete_ai",
|
||||||
|
summary_sent=True,
|
||||||
|
alert_count=len(aggregated_alerts),
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
self.logger.warning(
|
||||||
|
"ai_summary_send_failed",
|
||||||
|
error=result.error,
|
||||||
|
fallback="individual_alerts",
|
||||||
|
)
|
||||||
|
return self._run_standard_flow(aggregated_alerts)
|
||||||
|
|
||||||
|
except AISummaryServiceError as e:
|
||||||
|
self.logger.warning(
|
||||||
|
"ai_summary_generation_failed",
|
||||||
|
error=str(e),
|
||||||
|
fallback="individual_alerts",
|
||||||
|
)
|
||||||
|
return self._run_standard_flow(aggregated_alerts)
|
||||||
|
|
||||||
|
def _build_forecast_context(self, forecast: WeatherForecast) -> ForecastContext:
|
||||||
|
"""Build forecast context from weather forecast data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
forecast: The weather forecast data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ForecastContext for AI summary.
|
||||||
|
"""
|
||||||
|
if not forecast.hourly_forecasts:
|
||||||
|
return ForecastContext(
|
||||||
|
start_time="N/A",
|
||||||
|
end_time="N/A",
|
||||||
|
min_temp=0,
|
||||||
|
max_temp=0,
|
||||||
|
max_wind_speed=0,
|
||||||
|
max_wind_gust=0,
|
||||||
|
max_precip_prob=0,
|
||||||
|
conditions=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
temps = [h.temp for h in forecast.hourly_forecasts]
|
||||||
|
wind_speeds = [h.wind_speed for h in forecast.hourly_forecasts]
|
||||||
|
wind_gusts = [h.wind_gust for h in forecast.hourly_forecasts]
|
||||||
|
precip_probs = [h.precip_prob for h in forecast.hourly_forecasts]
|
||||||
|
conditions = [h.conditions for h in forecast.hourly_forecasts]
|
||||||
|
|
||||||
|
return ForecastContext(
|
||||||
|
start_time=forecast.hourly_forecasts[0].hour_key,
|
||||||
|
end_time=forecast.hourly_forecasts[-1].hour_key,
|
||||||
|
min_temp=min(temps),
|
||||||
|
max_temp=max(temps),
|
||||||
|
max_wind_speed=max(wind_speeds),
|
||||||
|
max_wind_gust=max(wind_gusts),
|
||||||
|
max_precip_prob=max(precip_probs),
|
||||||
|
conditions=conditions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _finalize(self, aggregated_alerts: list[AggregatedAlert]) -> None:
|
||||||
|
"""Finalize the run by saving state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
aggregated_alerts: Current alerts to save as snapshots.
|
||||||
|
"""
|
||||||
|
self.state_manager.save_alert_snapshots(aggregated_alerts)
|
||||||
|
self.state_manager.purge_old_records()
|
||||||
|
self.state_manager.save()
|
||||||
|
|
||||||
|
|
||||||
def main(config_path: Optional[str] = None) -> int:
|
def main(config_path: Optional[str] = None) -> int:
|
||||||
@@ -176,6 +333,12 @@ def main(config_path: Optional[str] = None) -> int:
|
|||||||
hint="Set NTFY_ACCESS_TOKEN if your server requires auth",
|
hint="Set NTFY_ACCESS_TOKEN if your server requires auth",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.ai.enabled and not config.ai.api_token:
|
||||||
|
logger.warning(
|
||||||
|
"ai_enabled_but_no_token",
|
||||||
|
hint="Set REPLICATE_API_TOKEN to enable AI summaries",
|
||||||
|
)
|
||||||
|
|
||||||
# Run the application
|
# Run the application
|
||||||
app = WeatherAlertsApp(config)
|
app = WeatherAlertsApp(config)
|
||||||
return app.run()
|
return app.run()
|
||||||
|
|||||||
74
app/models/ai_summary.py
Normal file
74
app/models/ai_summary.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Data models for AI summarization feature."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ForecastContext:
|
||||||
|
"""Forecast data to include in AI summary prompt."""
|
||||||
|
|
||||||
|
start_time: str
|
||||||
|
end_time: str
|
||||||
|
min_temp: float
|
||||||
|
max_temp: float
|
||||||
|
max_wind_speed: float
|
||||||
|
max_wind_gust: float
|
||||||
|
max_precip_prob: float
|
||||||
|
conditions: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_prompt_text(self) -> str:
|
||||||
|
"""Format forecast context for LLM prompt."""
|
||||||
|
unique_conditions = list(dict.fromkeys(self.conditions))
|
||||||
|
conditions_str = ", ".join(unique_conditions[:5]) if unique_conditions else "N/A"
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"Period: {self.start_time} to {self.end_time}\n"
|
||||||
|
f"Temperature: {self.min_temp:.0f}F - {self.max_temp:.0f}F\n"
|
||||||
|
f"Wind: up to {self.max_wind_speed:.0f} mph, gusts to {self.max_wind_gust:.0f} mph\n"
|
||||||
|
f"Precipitation probability: up to {self.max_precip_prob:.0f}%\n"
|
||||||
|
f"Conditions: {conditions_str}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
return {
|
||||||
|
"start_time": self.start_time,
|
||||||
|
"end_time": self.end_time,
|
||||||
|
"min_temp": self.min_temp,
|
||||||
|
"max_temp": self.max_temp,
|
||||||
|
"max_wind_speed": self.max_wind_speed,
|
||||||
|
"max_wind_gust": self.max_wind_gust,
|
||||||
|
"max_precip_prob": self.max_precip_prob,
|
||||||
|
"conditions": self.conditions,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SummaryNotification:
|
||||||
|
"""AI-generated weather alert summary notification."""
|
||||||
|
|
||||||
|
title: str
|
||||||
|
message: str
|
||||||
|
location: str
|
||||||
|
alert_count: int
|
||||||
|
has_changes: bool
|
||||||
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tags(self) -> list[str]:
|
||||||
|
"""Get notification tags for AI summary."""
|
||||||
|
return ["robot", "weather", "summary"]
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
return {
|
||||||
|
"title": self.title,
|
||||||
|
"message": self.message,
|
||||||
|
"location": self.location,
|
||||||
|
"alert_count": self.alert_count,
|
||||||
|
"has_changes": self.has_changes,
|
||||||
|
"created_at": self.created_at.isoformat(),
|
||||||
|
"tags": self.tags,
|
||||||
|
}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
"""State management models for alert deduplication."""
|
"""State management models for alert deduplication and change detection."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -41,12 +41,59 @@ class SentAlertRecord:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AlertSnapshot:
|
||||||
|
"""Snapshot of an alert for change detection between runs."""
|
||||||
|
|
||||||
|
alert_type: str
|
||||||
|
extreme_value: float
|
||||||
|
threshold: float
|
||||||
|
start_time: str
|
||||||
|
end_time: str
|
||||||
|
hour_count: int
|
||||||
|
captured_at: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for JSON serialization."""
|
||||||
|
return {
|
||||||
|
"alert_type": self.alert_type,
|
||||||
|
"extreme_value": self.extreme_value,
|
||||||
|
"threshold": self.threshold,
|
||||||
|
"start_time": self.start_time,
|
||||||
|
"end_time": self.end_time,
|
||||||
|
"hour_count": self.hour_count,
|
||||||
|
"captured_at": self.captured_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "AlertSnapshot":
|
||||||
|
"""Create from dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The serialized snapshot dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An AlertSnapshot instance.
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
alert_type=data["alert_type"],
|
||||||
|
extreme_value=data["extreme_value"],
|
||||||
|
threshold=data["threshold"],
|
||||||
|
start_time=data["start_time"],
|
||||||
|
end_time=data["end_time"],
|
||||||
|
hour_count=data["hour_count"],
|
||||||
|
captured_at=datetime.fromisoformat(data["captured_at"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AlertState:
|
class AlertState:
|
||||||
"""State container for tracking sent alerts."""
|
"""State container for tracking sent alerts and change detection."""
|
||||||
|
|
||||||
sent_alerts: dict[str, SentAlertRecord] = field(default_factory=dict)
|
sent_alerts: dict[str, SentAlertRecord] = field(default_factory=dict)
|
||||||
last_updated: datetime = field(default_factory=datetime.now)
|
last_updated: datetime = field(default_factory=datetime.now)
|
||||||
|
previous_alert_snapshots: dict[str, AlertSnapshot] = field(default_factory=dict)
|
||||||
|
last_ai_summary_sent: Optional[datetime] = None
|
||||||
|
|
||||||
def is_duplicate(self, dedup_key: str) -> bool:
|
def is_duplicate(self, dedup_key: str) -> bool:
|
||||||
"""Check if an alert with this dedup key has already been sent.
|
"""Check if an alert with this dedup key has already been sent.
|
||||||
@@ -106,6 +153,15 @@ class AlertState:
|
|||||||
key: record.to_dict() for key, record in self.sent_alerts.items()
|
key: record.to_dict() for key, record in self.sent_alerts.items()
|
||||||
},
|
},
|
||||||
"last_updated": self.last_updated.isoformat(),
|
"last_updated": self.last_updated.isoformat(),
|
||||||
|
"previous_alert_snapshots": {
|
||||||
|
key: snapshot.to_dict()
|
||||||
|
for key, snapshot in self.previous_alert_snapshots.items()
|
||||||
|
},
|
||||||
|
"last_ai_summary_sent": (
|
||||||
|
self.last_ai_summary_sent.isoformat()
|
||||||
|
if self.last_ai_summary_sent
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -130,4 +186,21 @@ class AlertState:
|
|||||||
else datetime.now()
|
else datetime.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(sent_alerts=sent_alerts, last_updated=last_updated)
|
previous_alert_snapshots = {
|
||||||
|
key: AlertSnapshot.from_dict(snapshot_data)
|
||||||
|
for key, snapshot_data in data.get("previous_alert_snapshots", {}).items()
|
||||||
|
}
|
||||||
|
|
||||||
|
last_ai_summary_str = data.get("last_ai_summary_sent")
|
||||||
|
last_ai_summary_sent = (
|
||||||
|
datetime.fromisoformat(last_ai_summary_str)
|
||||||
|
if last_ai_summary_str
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
sent_alerts=sent_alerts,
|
||||||
|
last_updated=last_updated,
|
||||||
|
previous_alert_snapshots=previous_alert_snapshots,
|
||||||
|
last_ai_summary_sent=last_ai_summary_sent,
|
||||||
|
)
|
||||||
|
|||||||
251
app/services/ai_summary_service.py
Normal file
251
app/services/ai_summary_service.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
"""AI summarization service using Replicate API."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import replicate
|
||||||
|
|
||||||
|
from app.models.ai_summary import ForecastContext, SummaryNotification
|
||||||
|
from app.models.alerts import AggregatedAlert
|
||||||
|
from app.services.change_detector import ChangeReport
|
||||||
|
from app.utils.logging_config import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class AISummaryServiceError(Exception):
|
||||||
|
"""Raised when AI summary service encounters an error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AISummaryService:
|
||||||
|
"""Service for generating AI-powered weather alert summaries using Replicate."""
|
||||||
|
|
||||||
|
# Retry configuration
|
||||||
|
MAX_RETRIES = 3
|
||||||
|
INITIAL_BACKOFF_SECONDS = 1.0
|
||||||
|
BACKOFF_MULTIPLIER = 2.0
|
||||||
|
|
||||||
|
PROMPT_TEMPLATE = """You are a weather assistant. Generate a concise alert summary for {location}.
|
||||||
|
|
||||||
|
## Active Alerts
|
||||||
|
{alerts_section}
|
||||||
|
|
||||||
|
## Changes Since Last Run
|
||||||
|
{changes_section}
|
||||||
|
|
||||||
|
## Forecast for Alert Period
|
||||||
|
{forecast_section}
|
||||||
|
|
||||||
|
Instructions:
|
||||||
|
- Write 2-4 sentences max
|
||||||
|
- Prioritize safety-critical info (severe weather, extreme temps)
|
||||||
|
- Mention changes from previous alert
|
||||||
|
- Include specific temperatures, wind speeds, times
|
||||||
|
- No emojis, professional tone
|
||||||
|
|
||||||
|
Summary:"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_token: str,
|
||||||
|
model: str = "meta/meta-llama-3-8b-instruct",
|
||||||
|
api_timeout: int = 60,
|
||||||
|
max_tokens: int = 500,
|
||||||
|
http_client: Optional[object] = None, # Kept for API compatibility
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the AI summary service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_token: Replicate API token.
|
||||||
|
model: Model identifier to use.
|
||||||
|
api_timeout: Maximum time to wait for response in seconds (unused with library).
|
||||||
|
max_tokens: Maximum tokens in generated response.
|
||||||
|
http_client: Unused - kept for API compatibility.
|
||||||
|
"""
|
||||||
|
self.api_token = api_token
|
||||||
|
self.model = model
|
||||||
|
self.api_timeout = api_timeout
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
self._retry_count = 0
|
||||||
|
|
||||||
|
# Set API token for replicate library
|
||||||
|
os.environ["REPLICATE_API_TOKEN"] = api_token
|
||||||
|
|
||||||
|
self.logger.debug(
|
||||||
|
"AISummaryService initialized",
|
||||||
|
model=self.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
def summarize(
|
||||||
|
self,
|
||||||
|
alerts: list[AggregatedAlert],
|
||||||
|
change_report: ChangeReport,
|
||||||
|
forecast_context: ForecastContext,
|
||||||
|
location: str,
|
||||||
|
) -> SummaryNotification:
|
||||||
|
"""Generate an AI summary of the weather alerts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alerts: List of aggregated alerts to summarize.
|
||||||
|
change_report: Report of changes from previous run.
|
||||||
|
forecast_context: Forecast data for the alert period.
|
||||||
|
location: Location name for the summary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SummaryNotification with the AI-generated summary.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AISummaryServiceError: If API call fails or times out.
|
||||||
|
"""
|
||||||
|
prompt = self._build_prompt(alerts, change_report, forecast_context, location)
|
||||||
|
|
||||||
|
self.logger.debug(
|
||||||
|
"generating_ai_summary",
|
||||||
|
model=self.model,
|
||||||
|
alert_count=len(alerts),
|
||||||
|
has_changes=change_report.has_changes,
|
||||||
|
)
|
||||||
|
|
||||||
|
summary_text = self._run_with_retry(prompt)
|
||||||
|
|
||||||
|
return SummaryNotification(
|
||||||
|
title=f"Weather Alert Summary - {location}",
|
||||||
|
message=summary_text,
|
||||||
|
location=location,
|
||||||
|
alert_count=len(alerts),
|
||||||
|
has_changes=change_report.has_changes,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_prompt(
|
||||||
|
self,
|
||||||
|
alerts: list[AggregatedAlert],
|
||||||
|
change_report: ChangeReport,
|
||||||
|
forecast_context: ForecastContext,
|
||||||
|
location: str,
|
||||||
|
) -> str:
|
||||||
|
"""Build the prompt for the LLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alerts: List of aggregated alerts.
|
||||||
|
change_report: Change report from detector.
|
||||||
|
forecast_context: Forecast context data.
|
||||||
|
location: Location name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted prompt string.
|
||||||
|
"""
|
||||||
|
# Format alerts section
|
||||||
|
alert_lines = []
|
||||||
|
for alert in alerts:
|
||||||
|
alert_type = alert.alert_type.value.replace("_", " ").title()
|
||||||
|
alert_lines.append(
|
||||||
|
f"- {alert_type}: {alert.extreme_value:.0f} "
|
||||||
|
f"(threshold: {alert.threshold:.0f}) "
|
||||||
|
f"from {alert.start_time} to {alert.end_time}"
|
||||||
|
)
|
||||||
|
alerts_section = "\n".join(alert_lines) if alert_lines else "No active alerts."
|
||||||
|
|
||||||
|
# Format changes section
|
||||||
|
changes_section = change_report.to_prompt_text()
|
||||||
|
|
||||||
|
# Format forecast section
|
||||||
|
forecast_section = forecast_context.to_prompt_text()
|
||||||
|
|
||||||
|
return self.PROMPT_TEMPLATE.format(
|
||||||
|
location=location,
|
||||||
|
alerts_section=alerts_section,
|
||||||
|
changes_section=changes_section,
|
||||||
|
forecast_section=forecast_section,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_with_retry(self, prompt: str) -> str:
    """Execute model call with retry logic.

    Implements exponential backoff retry strategy for transient errors:
    - Rate limit errors (429)
    - Server errors (5xx)
    - Timeout errors

    Non-retryable errors (or exhausting MAX_RETRIES) raise immediately
    after the loop via AISummaryServiceError.

    Args:
        prompt: The prompt to send to the model.

    Returns:
        Generated text from the model.

    Raises:
        AISummaryServiceError: If all retries fail.
    """
    last_error: Optional[Exception] = None
    backoff = self.INITIAL_BACKOFF_SECONDS

    for attempt in range(1, self.MAX_RETRIES + 1):
        try:
            output = replicate.run(
                self.model,
                input={
                    "prompt": prompt,
                    "max_tokens": self.max_tokens,
                    # Fixed sampling temperature for summaries.
                    "temperature": 0.7,
                },
            )

            # Replicate returns a generator for streaming models
            # Collect all output parts into a single string
            if hasattr(output, "__iter__") and not isinstance(output, str):
                response = "".join(str(part) for part in output)
            else:
                response = str(output)

            self.logger.debug(
                "replicate_api_call_succeeded",
                model=self.model,
                attempt=attempt,
                response_length=len(response),
            )

            return response

        except Exception as e:
            last_error = e
            self._retry_count += 1

            # Determine if error is retryable based on error message.
            # NOTE(review): substring matching on the message is heuristic —
            # a message that merely contains "429" or "timeout" is treated
            # as transient; confirm against the replicate client's error types.
            error_str = str(e).lower()
            is_rate_limit = "rate" in error_str or "429" in error_str
            is_server_error = (
                "500" in error_str
                or "502" in error_str
                or "503" in error_str
                or "504" in error_str
            )
            is_timeout = "timeout" in error_str
            is_retryable = is_rate_limit or is_server_error or is_timeout

            if attempt < self.MAX_RETRIES and is_retryable:
                self.logger.warning(
                    "replicate_api_call_failed_retrying",
                    model=self.model,
                    attempt=attempt,
                    max_retries=self.MAX_RETRIES,
                    error=str(e),
                    backoff_seconds=backoff,
                )
                time.sleep(backoff)
                backoff *= self.BACKOFF_MULTIPLIER
            else:
                # Non-retryable error or max retries reached
                break

    # All retries exhausted or non-retryable error
    self.logger.error(
        "replicate_api_call_failed_after_retries",
        model=self.model,
        retry_count=self.MAX_RETRIES,
        error=str(last_error),
    )

    raise AISummaryServiceError(
        f"Replicate API call failed after {self.MAX_RETRIES} attempts: {last_error}"
    )
|
||||||
233
app/services/change_detector.py
Normal file
233
app/services/change_detector.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""Change detection service for comparing alerts between runs."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.config.loader import ChangeThresholds
|
||||||
|
from app.models.alerts import AggregatedAlert
|
||||||
|
from app.models.state import AlertSnapshot
|
||||||
|
from app.utils.logging_config import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class ChangeType(Enum):
    """Kinds of alert transitions the change detector can report."""

    NEW = "new"
    REMOVED = "removed"
    VALUE_CHANGED = "value_changed"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AlertChange:
    """One detected difference between the current run and the previous one."""

    alert_type: str  # raw alert-type key, e.g. "temperature_low"
    change_type: ChangeType  # NEW / REMOVED / VALUE_CHANGED
    description: str  # human-readable text used in prompts
    previous_value: Optional[float] = None  # extreme value from the prior run
    current_value: Optional[float] = None  # extreme value from this run
    value_delta: Optional[float] = None  # absolute difference for VALUE_CHANGED
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ChangeReport:
    """Aggregated view of every change detected in one comparison run."""

    changes: list[AlertChange] = field(default_factory=list)

    def _matching(self, kind: ChangeType) -> list[AlertChange]:
        # Shared filter for the typed accessors below; preserves input order.
        return [c for c in self.changes if c.change_type == kind]

    @property
    def has_changes(self) -> bool:
        """True when at least one change was recorded."""
        return len(self.changes) > 0

    @property
    def new_alerts(self) -> list[AlertChange]:
        """Changes for alerts that appeared this run."""
        return self._matching(ChangeType.NEW)

    @property
    def removed_alerts(self) -> list[AlertChange]:
        """Changes for alerts that are no longer active."""
        return self._matching(ChangeType.REMOVED)

    @property
    def value_changes(self) -> list[AlertChange]:
        """Changes where an alert's extreme value shifted significantly."""
        return self._matching(ChangeType.VALUE_CHANGED)

    def to_prompt_text(self) -> str:
        """Render the report as plain text for inclusion in an LLM prompt."""
        if not self.has_changes:
            return "No significant changes from previous alert."

        lines: list[str] = []
        lines.extend(f"NEW: {c.description}" for c in self.new_alerts)
        lines.extend(f"RESOLVED: {c.description}" for c in self.removed_alerts)
        lines.extend(f"CHANGED: {c.description}" for c in self.value_changes)
        return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class ChangeDetector:
    """Detects changes between current alerts and previous snapshots."""

    # Map alert types to their value thresholds
    # (keys are AlertType values; values are attribute names on ChangeThresholds).
    ALERT_TYPE_TO_THRESHOLD_KEY = {
        "temperature_low": "temperature",
        "temperature_high": "temperature",
        "wind_speed": "wind_speed",
        "wind_gust": "wind_gust",
        "precipitation": "precipitation_prob",
    }

    def __init__(self, thresholds: ChangeThresholds) -> None:
        """Initialize the change detector.

        Args:
            thresholds: Thresholds for detecting significant changes.
        """
        self.thresholds = thresholds
        self.logger = get_logger(__name__)

    def detect(
        self,
        current_alerts: list[AggregatedAlert],
        previous_snapshots: dict[str, AlertSnapshot],
    ) -> ChangeReport:
        """Detect changes between current alerts and previous snapshots.

        Changes are appended in order: new alerts and value changes (in
        current-alert order), then removed alerts.

        Args:
            current_alerts: List of current aggregated alerts.
            previous_snapshots: Dict of previous alert snapshots keyed by alert type.

        Returns:
            ChangeReport containing all detected changes.
        """
        changes: list[AlertChange] = []
        current_types = {alert.alert_type.value for alert in current_alerts}
        previous_types = set(previous_snapshots.keys())

        # Detect new alerts
        for alert in current_alerts:
            alert_type = alert.alert_type.value
            if alert_type not in previous_types:
                changes.append(
                    AlertChange(
                        alert_type=alert_type,
                        change_type=ChangeType.NEW,
                        description=self._format_new_alert_description(alert),
                        current_value=alert.extreme_value,
                    )
                )
                self.logger.debug(
                    "change_detected_new",
                    alert_type=alert_type,
                )
            else:
                # Check for significant value changes
                prev_snapshot = previous_snapshots[alert_type]
                value_change = self._detect_value_change(alert, prev_snapshot)
                if value_change:
                    changes.append(value_change)
                    self.logger.debug(
                        "change_detected_value",
                        alert_type=alert_type,
                        previous=prev_snapshot.extreme_value,
                        current=alert.extreme_value,
                    )

        # Detect removed alerts (present previously, absent now)
        for alert_type in previous_types - current_types:
            prev_snapshot = previous_snapshots[alert_type]
            changes.append(
                AlertChange(
                    alert_type=alert_type,
                    change_type=ChangeType.REMOVED,
                    description=self._format_removed_alert_description(prev_snapshot),
                    previous_value=prev_snapshot.extreme_value,
                )
            )
            self.logger.debug(
                "change_detected_removed",
                alert_type=alert_type,
            )

        report = ChangeReport(changes=changes)

        self.logger.info(
            "change_detection_complete",
            total_changes=len(changes),
            new_count=len(report.new_alerts),
            removed_count=len(report.removed_alerts),
            value_changes_count=len(report.value_changes),
        )

        return report

    def _detect_value_change(
        self,
        current: AggregatedAlert,
        previous: AlertSnapshot,
    ) -> Optional[AlertChange]:
        """Detect if there's a significant value change between current and previous.

        Alert types without a configured threshold key (or with a missing
        threshold attribute) are never reported as value changes.

        Args:
            current: Current aggregated alert.
            previous: Previous alert snapshot.

        Returns:
            AlertChange if significant change detected, None otherwise.
        """
        threshold_key = self.ALERT_TYPE_TO_THRESHOLD_KEY.get(current.alert_type.value)
        if not threshold_key:
            return None

        threshold = getattr(self.thresholds, threshold_key, None)
        if threshold is None:
            return None

        # Significant means the absolute delta meets or exceeds the threshold.
        delta = abs(current.extreme_value - previous.extreme_value)
        if delta >= threshold:
            return AlertChange(
                alert_type=current.alert_type.value,
                change_type=ChangeType.VALUE_CHANGED,
                description=self._format_value_change_description(
                    current, previous, delta
                ),
                previous_value=previous.extreme_value,
                current_value=current.extreme_value,
                value_delta=delta,
            )

        return None

    def _format_new_alert_description(self, alert: AggregatedAlert) -> str:
        """Format description for a new alert."""
        alert_type = alert.alert_type.value.replace("_", " ").title()
        return f"{alert_type} alert: {alert.extreme_value:.0f} at {alert.extreme_hour}"

    def _format_removed_alert_description(self, snapshot: AlertSnapshot) -> str:
        """Format description for a removed alert."""
        alert_type = snapshot.alert_type.replace("_", " ").title()
        return f"{alert_type} alert no longer active"

    def _format_value_change_description(
        self,
        current: AggregatedAlert,
        previous: AlertSnapshot,
        delta: float,
    ) -> str:
        """Format description for a value change."""
        alert_type = current.alert_type.value.replace("_", " ").title()
        direction = "increased" if current.extreme_value > previous.extreme_value else "decreased"
        return (
            f"{alert_type} {direction} from {previous.extreme_value:.0f} "
            f"to {current.extreme_value:.0f}"
        )
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from app.models.ai_summary import SummaryNotification
|
||||||
from app.models.alerts import AggregatedAlert, AlertType, TriggeredAlert
|
from app.models.alerts import AggregatedAlert, AlertType, TriggeredAlert
|
||||||
from app.utils.http_client import HttpClient
|
from app.utils.http_client import HttpClient
|
||||||
from app.utils.logging_config import get_logger
|
from app.utils.logging_config import get_logger
|
||||||
@@ -20,6 +21,15 @@ class NotificationResult:
|
|||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SummaryNotificationResult:
    """Outcome of attempting to deliver an AI summary notification."""

    summary: SummaryNotification  # the notification that was (or was not) delivered
    success: bool  # True when the server accepted the post
    error: Optional[str] = None  # failure detail when success is False
|
||||||
|
|
||||||
|
|
||||||
class NotificationServiceError(Exception):
|
class NotificationServiceError(Exception):
|
||||||
"""Raised when notification service encounters an error."""
|
"""Raised when notification service encounters an error."""
|
||||||
|
|
||||||
@@ -175,3 +185,54 @@ class NotificationService:
|
|||||||
tags = list(self.ALERT_TYPE_TAGS.get(alert.alert_type, self.default_tags))
|
tags = list(self.ALERT_TYPE_TAGS.get(alert.alert_type, self.default_tags))
|
||||||
|
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
def send_summary(self, summary: SummaryNotification) -> SummaryNotificationResult:
    """Send an AI-generated summary notification.

    Args:
        summary: The summary notification to send.

    Returns:
        SummaryNotificationResult indicating success or failure.
    """
    target = f"{self.server_url}/{self.topic}"

    # ntfy metadata travels in request headers; the body is the message text.
    request_headers = {
        "Title": summary.title,
        "Priority": self.priority,
        "Tags": ",".join(summary.tags),
    }
    if self.access_token:
        request_headers["Authorization"] = f"Bearer {self.access_token}"

    self.logger.debug(
        "sending_summary_notification",
        location=summary.location,
        alert_count=summary.alert_count,
    )

    response = self.http_client.post(
        target,
        data=summary.message.encode("utf-8"),
        headers=request_headers,
    )

    if not response.success:
        failure = f"HTTP {response.status_code}: {response.text[:100]}"
        self.logger.error(
            "summary_notification_failed",
            error=failure,
        )
        return SummaryNotificationResult(
            summary=summary, success=False, error=failure
        )

    self.logger.info(
        "summary_notification_sent",
        location=summary.location,
        alert_count=summary.alert_count,
        has_changes=summary.has_changes,
    )
    return SummaryNotificationResult(summary=summary, success=True)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from pathlib import Path
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from app.models.alerts import AggregatedAlert, TriggeredAlert
|
from app.models.alerts import AggregatedAlert, TriggeredAlert
|
||||||
from app.models.state import AlertState
|
from app.models.state import AlertSnapshot, AlertState
|
||||||
from app.utils.logging_config import get_logger
|
from app.utils.logging_config import get_logger
|
||||||
|
|
||||||
# Type alias for alerts that can be deduplicated
|
# Type alias for alerts that can be deduplicated
|
||||||
@@ -183,3 +183,47 @@ class StateManager:
|
|||||||
self.logger.info("old_records_purged", count=purged)
|
self.logger.info("old_records_purged", count=purged)
|
||||||
|
|
||||||
return purged
|
return purged
|
||||||
|
|
||||||
|
def save_alert_snapshots(self, alerts: list[AggregatedAlert]) -> None:
    """Persist the current alerts as snapshots for next-run change detection.

    Args:
        alerts: List of aggregated alerts to save as snapshots.
    """
    from datetime import datetime

    # Key snapshots by alert-type value so the detector can look them up.
    snapshots: dict[str, AlertSnapshot] = {
        alert.alert_type.value: AlertSnapshot(
            alert_type=alert.alert_type.value,
            extreme_value=alert.extreme_value,
            threshold=alert.threshold,
            start_time=alert.start_time,
            end_time=alert.end_time,
            hour_count=alert.hour_count,
        )
        for alert in alerts
    }

    self.state.previous_alert_snapshots = snapshots
    self.state.last_updated = datetime.now()

    self.logger.debug(
        "alert_snapshots_saved",
        count=len(snapshots),
    )
|
||||||
|
|
||||||
|
def get_previous_snapshots(self) -> dict[str, AlertSnapshot]:
    """Return the alert snapshots captured on the previous run.

    Returns:
        Dict of alert snapshots keyed by alert type.
    """
    return self.state.previous_alert_snapshots
|
||||||
|
|
||||||
|
def record_ai_summary_sent(self) -> None:
    """Stamp the state with the time an AI summary was last sent."""
    from datetime import datetime

    self.state.last_ai_summary_sent = datetime.now()
    self.logger.debug("ai_summary_sent_recorded")
|
||||||
|
|||||||
@@ -1,23 +1,38 @@
|
|||||||
{
|
{
|
||||||
"sent_alerts": {
|
"sent_alerts": {
|
||||||
"severe_weather:urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.2": {
|
"temperature_low:2026-02-01": {
|
||||||
"dedup_key": "severe_weather:urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.2",
|
"dedup_key": "temperature_low:2026-02-01",
|
||||||
"alert_type": "severe_weather",
|
|
||||||
"sent_at": "2026-01-26T21:15:25.784599",
|
|
||||||
"forecast_hour": "urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.2"
|
|
||||||
},
|
|
||||||
"severe_weather:urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.1": {
|
|
||||||
"dedup_key": "severe_weather:urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.1",
|
|
||||||
"alert_type": "severe_weather",
|
|
||||||
"sent_at": "2026-01-26T21:15:25.784621",
|
|
||||||
"forecast_hour": "urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.1"
|
|
||||||
},
|
|
||||||
"temperature_low:2026-01-26": {
|
|
||||||
"dedup_key": "temperature_low:2026-01-26",
|
|
||||||
"alert_type": "temperature_low",
|
"alert_type": "temperature_low",
|
||||||
"sent_at": "2026-01-26T21:15:25.784625",
|
"sent_at": "2026-02-01T00:01:05.309466",
|
||||||
"forecast_hour": "2026-01-26-22"
|
"forecast_hour": "2026-02-01-01"
|
||||||
|
},
|
||||||
|
"severe_weather:urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1": {
|
||||||
|
"dedup_key": "severe_weather:urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1",
|
||||||
|
"alert_type": "severe_weather",
|
||||||
|
"sent_at": "2026-02-01T08:01:04.797087",
|
||||||
|
"forecast_hour": "urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"last_updated": "2026-01-26T21:15:25.784626"
|
"last_updated": "2026-02-01T16:01:04.531986",
|
||||||
|
"previous_alert_snapshots": {
|
||||||
|
"severe_weather": {
|
||||||
|
"alert_type": "severe_weather",
|
||||||
|
"extreme_value": 1.0,
|
||||||
|
"threshold": 0.0,
|
||||||
|
"start_time": "urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1",
|
||||||
|
"end_time": "urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1",
|
||||||
|
"hour_count": 1,
|
||||||
|
"captured_at": "2026-02-01T16:01:04.531981"
|
||||||
|
},
|
||||||
|
"temperature_low": {
|
||||||
|
"alert_type": "temperature_low",
|
||||||
|
"extreme_value": 25.1,
|
||||||
|
"threshold": 32,
|
||||||
|
"start_time": "2026-02-01-17",
|
||||||
|
"end_time": "2026-02-01-20",
|
||||||
|
"hour_count": 4,
|
||||||
|
"captured_at": "2026-02-01T16:01:04.531984"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"last_ai_summary_sent": null
|
||||||
}
|
}
|
||||||
339
replicate_client.py
Normal file
339
replicate_client.py
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
"""
|
||||||
|
Replicate API client wrapper for the Sneaky Intel Feed Aggregator.
|
||||||
|
|
||||||
|
Provides low-level API access with retry logic and error handling
|
||||||
|
for all Replicate model interactions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import replicate
|
||||||
|
|
||||||
|
from feed_aggregator.models.config import ReplicateConfig
|
||||||
|
from feed_aggregator.utils.exceptions import APIError
|
||||||
|
from feed_aggregator.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicateClient:
    """
    Low-level wrapper for Replicate API calls.

    Handles:
    - API authentication via REPLICATE_API_TOKEN
    - Retry logic with exponential backoff
    - Error classification and wrapping
    - Response normalization

    Attributes:
        MAX_RETRIES: Maximum retry attempts for failed calls.
        INITIAL_BACKOFF_SECONDS: Initial delay before first retry.
        BACKOFF_MULTIPLIER: Multiplier for exponential backoff.
        LLAMA_MODEL: Model identifier for Llama 3 8B Instruct.
        EMBEDDING_MODEL: Model identifier for text embeddings.

    Example:
        >>> from feed_aggregator.services.replicate_client import ReplicateClient
        >>> from feed_aggregator.models.config import ReplicateConfig
        >>>
        >>> config = ReplicateConfig(api_key="r8_...")
        >>> client = ReplicateClient(config)
        >>> response, error = client.run_llama("Summarize: ...")
    """

    # Retry configuration (matches DatabaseService pattern)
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0

    # Model identifiers
    LLAMA_MODEL = "meta/meta-llama-3-8b-instruct"
    EMBEDDING_MODEL = "beautyyuyanli/multilingual-e5-large:a06276a89f1a902d5fc225a9ca32b6e8e6292b7f3b136518878da97c458e2bad"

    # Default parameters for Llama
    DEFAULT_MAX_TOKENS = 500
    DEFAULT_TEMPERATURE = 0.3

    def __init__(self, config: ReplicateConfig):
        """
        Initialize the Replicate client.

        Sets the REPLICATE_API_TOKEN environment variable for the
        replicate library to use.

        Args:
            config: Replicate configuration with API key.
        """
        self._config = config
        self._retry_count = 0  # Track total retries for metrics

        # Set API token for replicate library
        os.environ["REPLICATE_API_TOKEN"] = config.api_key

        logger.debug(
            "ReplicateClient initialized",
            model=self.LLAMA_MODEL,
        )

    def run_llama(
        self,
        prompt: str,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        temperature: float = DEFAULT_TEMPERATURE,
    ) -> Tuple[Optional[str], Optional[APIError]]:
        """
        Run Llama 3 8B Instruct with the given prompt.

        Args:
            prompt: The prompt to send to the model.
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature (0-1).

        Returns:
            Tuple of (response_text, error).
            On success: (str, None)
            On failure: (None, APIError)
        """
        return self._run_with_retry(
            model=self.LLAMA_MODEL,
            input_params={
                "prompt": prompt,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
            operation="summarization",
        )

    # Expected embedding dimensions for the multilingual-e5-large model
    # (see EMBEDDING_MODEL above).
    EMBEDDING_DIMENSIONS = 1024

    def run_embedding(
        self,
        text: str,
    ) -> Tuple[Optional[List[float]], Optional[APIError]]:
        """
        Generate embeddings using the multilingual-e5-large model
        (EMBEDDING_MODEL).

        Uses the same retry logic as run_llama but preserves the
        list response format instead of converting to string.

        Args:
            text: Text to embed (title + summary, typically 200-500 tokens).

        Returns:
            Tuple of (embedding_vector, error).
            On success: (List[float] with 1024 dimensions, None)
            On failure: (None, APIError)

        Example:
            >>> embedding, error = client.run_embedding("Article about security")
            >>> if embedding:
            ...     print(f"Generated {len(embedding)}-dimensional embedding")
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                # E5 model expects texts as a JSON-formatted string
                output = replicate.run(
                    self.EMBEDDING_MODEL,
                    input={"texts": json.dumps([text])},
                )

                # Output is a list of embeddings (one per input text)
                # We only pass one text, so extract the first embedding
                if isinstance(output, list) and len(output) > 0:
                    embedding = output[0]

                    # Validate embedding dimensions
                    if len(embedding) != self.EMBEDDING_DIMENSIONS:
                        raise ValueError(
                            f"Expected {self.EMBEDDING_DIMENSIONS} dimensions, "
                            f"got {len(embedding)}"
                        )

                    # Ensure all values are floats
                    embedding = [float(v) for v in embedding]

                    logger.debug(
                        "Replicate embedding succeeded",
                        model=self.EMBEDDING_MODEL,
                        operation="embedding",
                        attempt=attempt,
                        dimensions=len(embedding),
                    )

                    return embedding, None
                else:
                    raise ValueError(
                        f"Unexpected output format from embedding model: {type(output)}"
                    )

            except Exception as e:
                last_error = e
                self._retry_count += 1

                # Determine if error is retryable based on error message
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout

                if attempt < self.MAX_RETRIES and is_retryable:
                    logger.warning(
                        "Replicate embedding failed, retrying",
                        model=self.EMBEDDING_MODEL,
                        operation="embedding",
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break

        # All retries exhausted or non-retryable error
        logger.error(
            "Replicate embedding failed after retries",
            model=self.EMBEDDING_MODEL,
            operation="embedding",
            retry_count=self.MAX_RETRIES,
            error=str(last_error),
        )

        return None, APIError(
            message=f"Replicate embedding failed: {str(last_error)}",
            api_name="Replicate",
            operation="embedding",
            retry_count=self.MAX_RETRIES,
            context={"model": self.EMBEDDING_MODEL, "error": str(last_error)},
        )

    def _run_with_retry(
        self,
        model: str,
        input_params: Dict[str, Any],
        operation: str,
    ) -> Tuple[Optional[str], Optional[APIError]]:
        """
        Execute model call with retry logic.

        Implements exponential backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx)
        - Timeout errors

        Non-retryable errors (auth failures, etc.) fail immediately.

        Args:
            model: Replicate model identifier.
            input_params: Model input parameters.
            operation: Operation name for error reporting.

        Returns:
            Tuple of (response, error).
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                output = replicate.run(model, input=input_params)

                # Replicate returns a generator for streaming models
                # Collect all output parts into a single string
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)

                logger.debug(
                    "Replicate API call succeeded",
                    model=model,
                    operation=operation,
                    attempt=attempt,
                    response_length=len(response),
                )

                return response, None

            except Exception as e:
                last_error = e
                self._retry_count += 1

                # Determine if error is retryable based on error message
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout

                if attempt < self.MAX_RETRIES and is_retryable:
                    logger.warning(
                        "Replicate API call failed, retrying",
                        model=model,
                        operation=operation,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break

        # All retries exhausted or non-retryable error
        logger.error(
            "Replicate API call failed after retries",
            model=model,
            operation=operation,
            retry_count=self.MAX_RETRIES,
            error=str(last_error),
        )

        return None, APIError(
            message=f"Replicate API call failed: {str(last_error)}",
            api_name="Replicate",
            operation=operation,
            retry_count=self.MAX_RETRIES,
            context={"model": model, "error": str(last_error)},
        )

    @property
    def total_retries(self) -> int:
        """
        Return total retry count across all calls.

        Returns:
            Total number of retry attempts made.
        """
        return self._retry_count

    def reset_retry_count(self) -> None:
        """
        Reset retry counter.

        Useful for metrics collection between batch operations.
        """
        self._retry_count = 0
|
||||||
@@ -2,5 +2,6 @@ requests>=2.28.0,<3.0.0
|
|||||||
structlog>=23.0.0,<25.0.0
|
structlog>=23.0.0,<25.0.0
|
||||||
python-dotenv>=1.0.0,<2.0.0
|
python-dotenv>=1.0.0,<2.0.0
|
||||||
PyYAML>=6.0,<7.0
|
PyYAML>=6.0,<7.0
|
||||||
|
replicate>=0.25.0,<1.0.0
|
||||||
pytest>=7.0.0,<9.0.0
|
pytest>=7.0.0,<9.0.0
|
||||||
responses>=0.23.0,<1.0.0
|
responses>=0.23.0,<1.0.0
|
||||||
|
|||||||
Reference in New Issue
Block a user