Compare commits: abe116f518...master (39 commits)

| SHA1 |
|---|
| d094ba7979 |
| 0849c35bb5 |
| 28f21a68c7 |
| ab872503b4 |
| 2db4479942 |
| 3cce50410f |
| 7e04615843 |
| ba95624675 |
| 3bdce183e0 |
| e9849b66f5 |
| 1e72cd7423 |
| 6a7b5ef51f |
| 978c4ebbaa |
| 1def229e37 |
| 46344854f0 |
| 2b660451cd |
| 1bed55aff4 |
| 6011ee3012 |
| 7b9b2395de |
| a9be527a1a |
| c416ce643f |
| c34d807681 |
| 1bd2d55951 |
| 0a7d3c41e7 |
| 26767cdcef |
| 3297bdea5e |
| ef0a271554 |
| f425f72e24 |
| 355c82e39e |
| 7becfaeba2 |
| 8695ccd9ab |
| 4459b3278d |
| 34867d0911 |
| d68ee280b6 |
| 0661acb961 |
| 921b6a81a4 |
| 2820944ec6 |
| f1982157aa |
| bd9c0c40a6 |
@@ -9,3 +9,6 @@ VISUALCROSSING_API_KEY=your_api_key_here
# Ntfy Access Token
# Required if your ntfy server requires authentication
NTFY_ACCESS_TOKEN=your_ntfy_token_here
+
+# Replicate API Token
+REPLICATE_API_TOKEN=your_replicate_token_here
@@ -2,7 +2,7 @@ name: Weather Alerts

on:
  schedule:
-    - cron: '0 * * * *' # Every hour at :00
+    - cron: '0 */4 * * *' # Every 4th hour at :00
  workflow_dispatch: {} # Manual trigger

jobs:
173 README.md Normal file
@@ -0,0 +1,173 @@
# Weather Alerts

A Python application that monitors weather forecasts and sends push notifications when conditions match your alert rules.

## How It Works

1. Queries the Visual Crossing Weather API for upcoming forecast data
2. Evaluates forecast conditions against configurable alert rules
3. Sends notifications to an ntfy server when alert conditions are triggered
4. Tracks sent alerts to prevent duplicate notifications
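In code terms, each scheduled run is one short fetch → evaluate → notify → record loop. The sketch below is illustrative only (simplified stand-ins, not the app's actual functions or data model):

```python
"""Minimal sketch of one scheduled run (illustrative, not the app's real code)."""

def run_once(forecast_hours: list[dict], rules: dict, sent_keys: set[str]) -> list[str]:
    """Evaluate a simple threshold rule and return dedup keys of newly sent alerts."""
    sent = []
    for hour in forecast_hours:                       # step 1: forecast already fetched
        below = rules.get("temp_below")
        if below is not None and hour["temp"] < below:  # step 2: rule check
            key = f"temperature_low:{hour['time']}"
            if key not in sent_keys:                    # step 4: dedup check
                print(f"notify: low temp {hour['temp']} at {hour['time']}")  # step 3
                sent_keys.add(key)
                sent.append(key)
    return sent

# One cold hour triggers a single notification; re-running is a no-op.
hours = [{"time": "2026-02-01-13", "temp": 15.0}]
seen: set[str] = set()
print(run_once(hours, {"temp_below": 32}, seen))  # -> ['temperature_low:2026-02-01-13']
print(run_once(hours, {"temp_below": 32}, seen))  # -> []
```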
## Configuration

All configuration is managed in `app/config/settings.yaml`. Copy from `settings.example.yaml` to get started:

```bash
cp app/config/settings.example.yaml app/config/settings.yaml
```

## Location Configuration

Set your location in `settings.yaml` under the `weather.location` field.

### Supported Location Formats

Visual Crossing supports several location formats:

| Format | Example | Notes |
|--------|---------|-------|
| City, State (US) | `"Chicago,IL"` | State abbreviation or full name |
| City, Country | `"Paris,France"`, `"Tokyo,Japan"` | For international locations |
| City, State, Country | `"Chicago,IL,USA"` | Most explicit format |
| Full Address | `"620 Herndon Parkway, Herndon, Virginia, 20170, USA"` | Street-level precision |
| ZIP/Postal Code | `"20170"` | US ZIP codes work directly |
| Coordinates | `"40.7128,-74.0060"` | Latitude,longitude in decimal degrees |

### Location Best Practices

- **Include state/country** to avoid ambiguity (many cities share names like "Springfield")
- **Coordinates are fastest** and most precise - no geocoding lookup required
- **URL encoding is handled automatically** by the application
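On that last point, what happens under the hood is ordinary percent-encoding of the location in the request path, which Python's standard library provides. A small sketch (the endpoint shape follows Visual Crossing's Timeline API; the key is a placeholder):

```python
from urllib.parse import quote

location = "620 Herndon Parkway, Herndon, Virginia, 20170, USA"
base = "https://weather.visualcrossing.com/VisualCrossingWebServices/rest/services/timeline"
# quote() percent-encodes spaces and commas so the address is path-safe
url = f"{base}/{quote(location)}?unitGroup=us&key=YOUR_API_KEY"
print(url)
```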
### Coordinate Format

For latitude/longitude coordinates:

- Latitude: -90 to 90 (negative values = southern hemisphere)
- Longitude: -180 to 180 (negative values = western hemisphere)
- Format: `"latitude,longitude"` (e.g., `"35.8456,-86.4244"` for Viola, TN)

## Alert Rules Configuration

Configure alert rules under `alerts.rules` in `settings.yaml`. Each rule can be independently enabled or disabled.
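To make the enable/disable gating concrete, a simplified stand-in for a single rule check might look like this (the app's real rule engine is more general):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TemperatureRule:
    enabled: bool = False
    below: Optional[float] = None   # thresholds are optional
    above: Optional[float] = None

def check(rule: TemperatureRule, temp: float) -> list[str]:
    """Return triggered alert names for one forecast hour."""
    if not rule.enabled:            # disabled rules are skipped entirely
        return []
    hits = []
    if rule.below is not None and temp < rule.below:
        hits.append("temperature_low")
    if rule.above is not None and temp > rule.above:
        hits.append("temperature_high")
    return hits

print(check(TemperatureRule(enabled=True, below=32), 15.0))  # ['temperature_low']
```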
### Temperature Alerts

```yaml
temperature:
  enabled: true
  below: 32   # Alert when temperature falls below (freezing warning)
  above: 95   # Alert when temperature exceeds (heat warning)
```

Both thresholds are optional - set only the ones you need.

### Precipitation Alerts

```yaml
precipitation:
  enabled: true
  probability_above: 60   # Alert when rain/snow chance exceeds this percentage
```

### Wind Alerts

```yaml
wind:
  enabled: true
  speed_above: 25   # Sustained wind threshold (mph for "us" units)
  gust_above: 30    # Wind gust threshold (mph for "us" units)
```

Either threshold can trigger an alert independently.

### Severe Weather Alerts

```yaml
severe_weather:
  enabled: true   # Forward NWS severe weather alerts from the API
```

When enabled, forwards official severe weather alerts (tornado warnings, flash flood warnings, etc.) from the National Weather Service via the Visual Crossing API.

## Other Configuration Options

### Weather Settings

```yaml
weather:
  location: "viola,tn"   # Your location (see formats above)
  hours_ahead: 24        # How many hours of forecast to check
  unit_group: "us"       # "us" for Fahrenheit/mph, "metric" for Celsius/kph
```

### Notification Settings

```yaml
notifications:
  ntfy:
    server_url: "https://ntfy.sneakygeek.net"
    topic: "weather-alerts"
    priority: "high"             # min, low, default, high, urgent
    tags: ["cloud", "warning"]   # Emoji tags for notifications
```
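These settings map onto ntfy's HTTP publish API: the app POSTs the message body to `server_url/topic` with `Title`, `Priority`, and `Tags` headers. A minimal sketch using the `requests` package (values are examples; the Authorization header applies only if your server requires auth):

```python
import requests

resp = requests.post(
    "https://ntfy.sneakygeek.net/weather-alerts",
    data="Wind gusts to 45 mph expected between 13:00 and 16:00".encode("utf-8"),
    headers={
        "Title": "Wind Alert",
        "Priority": "high",
        "Tags": "cloud,warning",              # comma-separated emoji tags
        # "Authorization": "Bearer <token>",  # only if the server requires auth
    },
    timeout=10,
)
resp.raise_for_status()
```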
### State Management

```yaml
state:
  file_path: "./data/state.json"
  dedup_window_hours: 24   # Prevent duplicate alerts within this time window
```

The deduplication window prevents the same alert from being sent repeatedly. A "High Wind Warning" at 3 PM won't be sent again until 24 hours later (with default settings).
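Deduplication is essentially a keyed lookup with an expiry: each sent alert is remembered under a dedup key with its send time, and the same key is suppressed until the window elapses. A minimal sketch of that check (the real logic lives in the app's state manager):

```python
from datetime import datetime, timedelta

sent_at: dict[str, datetime] = {}   # dedup_key -> time the alert was sent

def should_send(dedup_key: str, now: datetime, window_hours: int = 24) -> bool:
    """True if this alert hasn't been sent within the dedup window."""
    last = sent_at.get(dedup_key)
    if last is not None and now - last < timedelta(hours=window_hours):
        return False                # still inside the window: suppress
    sent_at[dedup_key] = now
    return True

now = datetime(2026, 2, 1, 15, 0)
print(should_send("wind_gust:2026-02-01", now))                        # True
print(should_send("wind_gust:2026-02-01", now + timedelta(hours=3)))   # False
print(should_send("wind_gust:2026-02-01", now + timedelta(hours=25)))  # True
```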
### Application Settings

```yaml
app:
  name: "weather-alerts"
  version: "1.0.0"
  log_level: "INFO"   # DEBUG, INFO, WARNING, ERROR
```

## Environment Variables

Create a `.env` file in the project root with your secrets:

```bash
VISUALCROSSING_API_KEY=your_api_key_here
NTFY_ACCESS_TOKEN=your_ntfy_token_here
```

| Variable | Required | Description |
|----------|----------|-------------|
| `VISUALCROSSING_API_KEY` | Yes | API key from [Visual Crossing](https://www.visualcrossing.com/) |
| `NTFY_ACCESS_TOKEN` | Yes* | Authentication token for ntfy server (*if server requires auth) |
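Since `python-dotenv` is already pinned in `requirements.txt`, these secrets can be loaded into the process environment at startup, roughly like this:

```python
import os
from dotenv import load_dotenv  # provided by the python-dotenv package

load_dotenv()  # reads .env from the project root into os.environ

api_key = os.environ["VISUALCROSSING_API_KEY"]        # required
ntfy_token = os.environ.get("NTFY_ACCESS_TOKEN", "")  # optional if no auth
```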
## Example Configuration

Here's a complete example for monitoring weather in New York City:

```yaml
weather:
  location: "New York,NY,USA"
  hours_ahead: 12
  unit_group: "us"

alerts:
  rules:
    temperature:
      enabled: true
      below: 20   # Alert for very cold temps
      above: 95   # Alert for heat waves
    precipitation:
      enabled: true
      probability_above: 70
    wind:
      enabled: true
      speed_above: 30
      gust_above: 45
    severe_weather:
      enabled: true
```
@@ -63,6 +63,35 @@ class AlertSettings:
    rules: AlertRules = field(default_factory=AlertRules)


@dataclass
class AISettings:
    """AI summarization settings."""

    enabled: bool = False
    model: str = "meta/meta-llama-3-8b-instruct"
    api_timeout: int = 60
    max_tokens: int = 500
    api_token: str = ""


@dataclass
class ChangeThresholds:
    """Thresholds for detecting significant changes between runs."""

    temperature: float = 5.0
    wind_speed: float = 10.0
    wind_gust: float = 10.0
    precipitation_prob: float = 20.0


@dataclass
class ChangeDetectionSettings:
    """Change detection settings."""

    enabled: bool = True
    thresholds: ChangeThresholds = field(default_factory=ChangeThresholds)


@dataclass
class AppConfig:
    """Complete application configuration."""

@@ -72,6 +101,8 @@ class AppConfig:
    alerts: AlertSettings = field(default_factory=AlertSettings)
    notifications: NotificationSettings = field(default_factory=NotificationSettings)
    state: StateSettings = field(default_factory=StateSettings)
    ai: AISettings = field(default_factory=AISettings)
    change_detection: ChangeDetectionSettings = field(default_factory=ChangeDetectionSettings)


def load_config(
@@ -113,6 +144,8 @@ def load_config(
    alerts_data = config_data.get("alerts", {})
    notifications_data = config_data.get("notifications", {})
    state_data = config_data.get("state", {})
    ai_data = config_data.get("ai", {})
    change_detection_data = config_data.get("change_detection", {})

    # Build app settings
    app_settings = AppSettings(
@@ -150,10 +183,34 @@ def load_config(
        dedup_window_hours=state_data.get("dedup_window_hours", 24),
    )

    # Build AI settings with token from environment
    ai_settings = AISettings(
        enabled=ai_data.get("enabled", False),
        model=ai_data.get("model", "meta/meta-llama-3-8b-instruct"),
        api_timeout=ai_data.get("api_timeout", 60),
        max_tokens=ai_data.get("max_tokens", 500),
        api_token=os.environ.get("REPLICATE_API_TOKEN", ""),
    )

    # Build change detection settings
    thresholds_data = change_detection_data.get("thresholds", {})
    change_thresholds = ChangeThresholds(
        temperature=thresholds_data.get("temperature", 5.0),
        wind_speed=thresholds_data.get("wind_speed", 10.0),
        wind_gust=thresholds_data.get("wind_gust", 10.0),
        precipitation_prob=thresholds_data.get("precipitation_prob", 20.0),
    )
    change_detection_settings = ChangeDetectionSettings(
        enabled=change_detection_data.get("enabled", True),
        thresholds=change_thresholds,
    )

    return AppConfig(
        app=app_settings,
        weather=weather_settings,
        alerts=alert_settings,
        notifications=notification_settings,
        state=state_settings,
        ai=ai_settings,
        change_detection=change_detection_settings,
    )
@@ -5,7 +5,7 @@ app:

weather:
  location: "viola,tn"
-  hours_ahead: 24
+  hours_ahead: 4   # Matches 4-hour run frequency
  unit_group: "us"

alerts:
@@ -34,3 +34,18 @@ notifications:
state:
  file_path: "./data/state.json"
  dedup_window_hours: 24

ai:
  enabled: true
  model: "meta/meta-llama-3-8b-instruct"
  api_timeout: 60
  max_tokens: 500

# These metrics are used for change detection from previous forecast
change_detection:
  enabled: true
  thresholds:
    temperature: 5.0          # degrees F
    wind_speed: 10.0          # mph
    wind_gust: 10.0           # mph
    precipitation_prob: 20.0  # percentage points
235 app/main.py
@@ -4,7 +4,12 @@ import sys
from typing import Optional

from app.config.loader import AppConfig, load_config
from app.models.ai_summary import ForecastContext
from app.models.alerts import AggregatedAlert
from app.models.weather import WeatherForecast
from app.services.ai_summary_service import AISummaryService, AISummaryServiceError
from app.services.alert_aggregator import AlertAggregator
from app.services.change_detector import ChangeDetector, ChangeReport
from app.services.notification_service import NotificationService
from app.services.rule_engine import RuleEngine
from app.services.state_manager import StateManager
@@ -52,6 +57,26 @@ class WeatherAlertsApp:
            http_client=self.http_client,
        )

        # Initialize AI services conditionally
        self.ai_enabled = config.ai.enabled and bool(config.ai.api_token)
        self.ai_summary_service: Optional[AISummaryService] = None
        self.change_detector: Optional[ChangeDetector] = None

        if self.ai_enabled:
            self.ai_summary_service = AISummaryService(
                api_token=config.ai.api_token,
                model=config.ai.model,
                api_timeout=config.ai.api_timeout,
                max_tokens=config.ai.max_tokens,
                http_client=self.http_client,
            )
            self.logger.info("ai_summary_enabled", model=config.ai.model)

        if config.change_detection.enabled:
            self.change_detector = ChangeDetector(
                thresholds=config.change_detection.thresholds
            )

    def run(self) -> int:
        """Execute the main application flow.

@@ -62,6 +87,7 @@ class WeatherAlertsApp:
            "app_starting",
            version=self.config.app.version,
            location=self.config.weather.location,
            ai_enabled=self.ai_enabled,
        )

        try:
@@ -79,6 +105,9 @@ class WeatherAlertsApp:

            if not triggered_alerts:
                self.logger.info("no_alerts_triggered")
                # Clear snapshots when no alerts
                self.state_manager.save_alert_snapshots([])
                self.state_manager.save()
                return 0

            self.logger.info(
@@ -96,42 +125,11 @@ class WeatherAlertsApp:
                output_count=len(aggregated_alerts),
            )

-            # Step 3: Filter duplicates
-            self.logger.info("step_filter_duplicates")
-            new_alerts = self.state_manager.filter_duplicates(aggregated_alerts)
-
-            if not new_alerts:
-                self.logger.info("all_alerts_are_duplicates")
-                return 0
-
-            # Step 4: Send notifications
-            self.logger.info(
-                "step_send_notifications",
-                count=len(new_alerts),
-            )
-            results = self.notification_service.send_batch(new_alerts)
-
-            # Step 5: Record sent alerts
-            self.logger.info("step_record_sent")
-            for result in results:
-                if result.success:
-                    self.state_manager.record_sent(result.alert)
-
-            # Step 6: Purge old records and save state
-            self.state_manager.purge_old_records()
-            self.state_manager.save()
-
-            # Report results
-            success_count = sum(1 for r in results if r.success)
-            failed_count = len(results) - success_count
-
-            self.logger.info(
-                "app_complete",
-                alerts_sent=success_count,
-                alerts_failed=failed_count,
-            )
-
-            return 0 if failed_count == 0 else 1
+            # Branch based on AI enabled
+            if self.ai_enabled:
+                return self._run_ai_flow(forecast, aggregated_alerts)
+            else:
+                return self._run_standard_flow(aggregated_alerts)

        except WeatherServiceError as e:
            self.logger.error("weather_service_error", error=str(e))
@@ -144,6 +142,165 @@ class WeatherAlertsApp:
        finally:
            self.http_client.close()

    def _run_standard_flow(self, aggregated_alerts: list[AggregatedAlert]) -> int:
        """Run the standard notification flow without AI.

        Args:
            aggregated_alerts: List of aggregated alerts.

        Returns:
            Exit code (0 for success, 1 for error).
        """
        # Filter duplicates
        self.logger.info("step_filter_duplicates")
        new_alerts = self.state_manager.filter_duplicates(aggregated_alerts)

        if not new_alerts:
            self.logger.info("all_alerts_are_duplicates")
            self._finalize(aggregated_alerts)
            return 0

        # Send notifications
        self.logger.info(
            "step_send_notifications",
            count=len(new_alerts),
        )
        results = self.notification_service.send_batch(new_alerts)

        # Record sent alerts
        self.logger.info("step_record_sent")
        for result in results:
            if result.success:
                self.state_manager.record_sent(result.alert)

        self._finalize(aggregated_alerts)

        # Report results
        success_count = sum(1 for r in results if r.success)
        failed_count = len(results) - success_count

        self.logger.info(
            "app_complete",
            alerts_sent=success_count,
            alerts_failed=failed_count,
        )

        return 0 if failed_count == 0 else 1

    def _run_ai_flow(
        self,
        forecast: WeatherForecast,
        aggregated_alerts: list[AggregatedAlert],
    ) -> int:
        """Run the AI summarization flow.

        Args:
            forecast: The weather forecast data.
            aggregated_alerts: List of aggregated alerts.

        Returns:
            Exit code (0 for success, 1 for error).
        """
        self.logger.info("step_ai_summary_flow")

        # Build forecast context
        forecast_context = self._build_forecast_context(forecast)

        # Detect changes from previous run
        change_report = ChangeReport()
        if self.change_detector:
            previous_snapshots = self.state_manager.get_previous_snapshots()
            change_report = self.change_detector.detect(
                aggregated_alerts, previous_snapshots
            )

        # Try to generate AI summary
        try:
            assert self.ai_summary_service is not None
            summary = self.ai_summary_service.summarize(
                alerts=aggregated_alerts,
                change_report=change_report,
                forecast_context=forecast_context,
                location=self.config.weather.location,
            )

            # Send summary notification
            self.logger.info("step_send_ai_summary")
            result = self.notification_service.send_summary(summary)

            if result.success:
                self.state_manager.record_ai_summary_sent()
                self._finalize(aggregated_alerts)
                self.logger.info(
                    "app_complete_ai",
                    summary_sent=True,
                    alert_count=len(aggregated_alerts),
                )
                return 0
            else:
                self.logger.warning(
                    "ai_summary_send_failed",
                    error=result.error,
                    fallback="individual_alerts",
                )
                return self._run_standard_flow(aggregated_alerts)

        except AISummaryServiceError as e:
            self.logger.warning(
                "ai_summary_generation_failed",
                error=str(e),
                fallback="individual_alerts",
            )
            return self._run_standard_flow(aggregated_alerts)

    def _build_forecast_context(self, forecast: WeatherForecast) -> ForecastContext:
        """Build forecast context from weather forecast data.

        Args:
            forecast: The weather forecast data.

        Returns:
            ForecastContext for AI summary.
        """
        if not forecast.hourly_forecasts:
            return ForecastContext(
                start_time="N/A",
                end_time="N/A",
                min_temp=0,
                max_temp=0,
                max_wind_speed=0,
                max_wind_gust=0,
                max_precip_prob=0,
                conditions=[],
            )

        temps = [h.temp for h in forecast.hourly_forecasts]
        wind_speeds = [h.wind_speed for h in forecast.hourly_forecasts]
        wind_gusts = [h.wind_gust for h in forecast.hourly_forecasts]
        precip_probs = [h.precip_prob for h in forecast.hourly_forecasts]
        conditions = [h.conditions for h in forecast.hourly_forecasts]

        return ForecastContext(
            start_time=forecast.hourly_forecasts[0].hour_key,
            end_time=forecast.hourly_forecasts[-1].hour_key,
            min_temp=min(temps),
            max_temp=max(temps),
            max_wind_speed=max(wind_speeds),
            max_wind_gust=max(wind_gusts),
            max_precip_prob=max(precip_probs),
            conditions=conditions,
        )

    def _finalize(self, aggregated_alerts: list[AggregatedAlert]) -> None:
        """Finalize the run by saving state.

        Args:
            aggregated_alerts: Current alerts to save as snapshots.
        """
        self.state_manager.save_alert_snapshots(aggregated_alerts)
        self.state_manager.purge_old_records()
        self.state_manager.save()


def main(config_path: Optional[str] = None) -> int:
    """Main entry point for the application.

@@ -176,6 +333,12 @@ def main(config_path: Optional[str] = None) -> int:
            hint="Set NTFY_ACCESS_TOKEN if your server requires auth",
        )

    if config.ai.enabled and not config.ai.api_token:
        logger.warning(
            "ai_enabled_but_no_token",
            hint="Set REPLICATE_API_TOKEN to enable AI summaries",
        )

    # Run the application
    app = WeatherAlertsApp(config)
    return app.run()
74 app/models/ai_summary.py Normal file
@@ -0,0 +1,74 @@
"""Data models for AI summarization feature."""

from dataclasses import dataclass, field
from datetime import datetime
from typing import Any


@dataclass
class ForecastContext:
    """Forecast data to include in AI summary prompt."""

    start_time: str
    end_time: str
    min_temp: float
    max_temp: float
    max_wind_speed: float
    max_wind_gust: float
    max_precip_prob: float
    conditions: list[str] = field(default_factory=list)

    def to_prompt_text(self) -> str:
        """Format forecast context for LLM prompt."""
        unique_conditions = list(dict.fromkeys(self.conditions))
        conditions_str = ", ".join(unique_conditions[:5]) if unique_conditions else "N/A"

        return (
            f"Period: {self.start_time} to {self.end_time}\n"
            f"Temperature: {self.min_temp:.0f}F - {self.max_temp:.0f}F\n"
            f"Wind: up to {self.max_wind_speed:.0f} mph, gusts to {self.max_wind_gust:.0f} mph\n"
            f"Precipitation probability: up to {self.max_precip_prob:.0f}%\n"
            f"Conditions: {conditions_str}"
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "start_time": self.start_time,
            "end_time": self.end_time,
            "min_temp": self.min_temp,
            "max_temp": self.max_temp,
            "max_wind_speed": self.max_wind_speed,
            "max_wind_gust": self.max_wind_gust,
            "max_precip_prob": self.max_precip_prob,
            "conditions": self.conditions,
        }


@dataclass
class SummaryNotification:
    """AI-generated weather alert summary notification."""

    title: str
    message: str
    location: str
    alert_count: int
    has_changes: bool
    created_at: datetime = field(default_factory=datetime.now)

    @property
    def tags(self) -> list[str]:
        """Get notification tags for AI summary."""
        return ["robot", "weather", "summary"]

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "title": self.title,
            "message": self.message,
            "location": self.location,
            "alert_count": self.alert_count,
            "has_changes": self.has_changes,
            "created_at": self.created_at.isoformat(),
            "tags": self.tags,
        }
@@ -1,8 +1,8 @@
-"""State management models for alert deduplication."""
+"""State management models for alert deduplication and change detection."""

from dataclasses import dataclass, field
from datetime import datetime
-from typing import Any
+from typing import Any, Optional


@dataclass
@@ -41,12 +41,59 @@ class SentAlertRecord:
        )


@dataclass
class AlertSnapshot:
    """Snapshot of an alert for change detection between runs."""

    alert_type: str
    extreme_value: float
    threshold: float
    start_time: str
    end_time: str
    hour_count: int
    captured_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for JSON serialization."""
        return {
            "alert_type": self.alert_type,
            "extreme_value": self.extreme_value,
            "threshold": self.threshold,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "hour_count": self.hour_count,
            "captured_at": self.captured_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AlertSnapshot":
        """Create from dictionary.

        Args:
            data: The serialized snapshot dict.

        Returns:
            An AlertSnapshot instance.
        """
        return cls(
            alert_type=data["alert_type"],
            extreme_value=data["extreme_value"],
            threshold=data["threshold"],
            start_time=data["start_time"],
            end_time=data["end_time"],
            hour_count=data["hour_count"],
            captured_at=datetime.fromisoformat(data["captured_at"]),
        )


@dataclass
class AlertState:
-    """State container for tracking sent alerts."""
+    """State container for tracking sent alerts and change detection."""

    sent_alerts: dict[str, SentAlertRecord] = field(default_factory=dict)
    last_updated: datetime = field(default_factory=datetime.now)
    previous_alert_snapshots: dict[str, AlertSnapshot] = field(default_factory=dict)
    last_ai_summary_sent: Optional[datetime] = None

    def is_duplicate(self, dedup_key: str) -> bool:
        """Check if an alert with this dedup key has already been sent.

@@ -106,6 +153,15 @@ class AlertState:
                key: record.to_dict() for key, record in self.sent_alerts.items()
            },
            "last_updated": self.last_updated.isoformat(),
            "previous_alert_snapshots": {
                key: snapshot.to_dict()
                for key, snapshot in self.previous_alert_snapshots.items()
            },
            "last_ai_summary_sent": (
                self.last_ai_summary_sent.isoformat()
                if self.last_ai_summary_sent
                else None
            ),
        }

    @classmethod
@@ -130,4 +186,21 @@ class AlertState:
            else datetime.now()
        )

-        return cls(sent_alerts=sent_alerts, last_updated=last_updated)
+        previous_alert_snapshots = {
+            key: AlertSnapshot.from_dict(snapshot_data)
+            for key, snapshot_data in data.get("previous_alert_snapshots", {}).items()
+        }
+
+        last_ai_summary_str = data.get("last_ai_summary_sent")
+        last_ai_summary_sent = (
+            datetime.fromisoformat(last_ai_summary_str)
+            if last_ai_summary_str
+            else None
+        )
+
+        return cls(
+            sent_alerts=sent_alerts,
+            last_updated=last_updated,
+            previous_alert_snapshots=previous_alert_snapshots,
+            last_ai_summary_sent=last_ai_summary_sent,
+        )
251 app/services/ai_summary_service.py Normal file
@@ -0,0 +1,251 @@
"""AI summarization service using Replicate API."""

import os
import time
from typing import Optional

import replicate

from app.models.ai_summary import ForecastContext, SummaryNotification
from app.models.alerts import AggregatedAlert
from app.services.change_detector import ChangeReport
from app.utils.logging_config import get_logger


class AISummaryServiceError(Exception):
    """Raised when AI summary service encounters an error."""

    pass


class AISummaryService:
    """Service for generating AI-powered weather alert summaries using Replicate."""

    # Retry configuration
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0

    PROMPT_TEMPLATE = """You are a weather assistant. Generate a concise alert summary for {location}.

## Active Alerts
{alerts_section}

## Changes Since Last Run
{changes_section}

## Forecast for Alert Period
{forecast_section}

Instructions:
- Write 2-4 sentences max
- Prioritize safety-critical info (severe weather, extreme temps)
- Mention changes from previous alert
- Include specific temperatures, wind speeds, times
- No emojis, professional tone

Summary:"""

    def __init__(
        self,
        api_token: str,
        model: str = "meta/meta-llama-3-8b-instruct",
        api_timeout: int = 60,
        max_tokens: int = 500,
        http_client: Optional[object] = None,  # Kept for API compatibility
    ) -> None:
        """Initialize the AI summary service.

        Args:
            api_token: Replicate API token.
            model: Model identifier to use.
            api_timeout: Maximum time to wait for response in seconds (unused with library).
            max_tokens: Maximum tokens in generated response.
            http_client: Unused - kept for API compatibility.
        """
        self.api_token = api_token
        self.model = model
        self.api_timeout = api_timeout
        self.max_tokens = max_tokens
        self.logger = get_logger(__name__)
        self._retry_count = 0

        # Set API token for replicate library
        os.environ["REPLICATE_API_TOKEN"] = api_token

        self.logger.debug(
            "AISummaryService initialized",
            model=self.model,
        )

    def summarize(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> SummaryNotification:
        """Generate an AI summary of the weather alerts.

        Args:
            alerts: List of aggregated alerts to summarize.
            change_report: Report of changes from previous run.
            forecast_context: Forecast data for the alert period.
            location: Location name for the summary.

        Returns:
            SummaryNotification with the AI-generated summary.

        Raises:
            AISummaryServiceError: If API call fails or times out.
        """
        prompt = self._build_prompt(alerts, change_report, forecast_context, location)

        self.logger.debug(
            "generating_ai_summary",
            model=self.model,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

        summary_text = self._run_with_retry(prompt)

        return SummaryNotification(
            title=f"Weather Alert Summary - {location}",
            message=summary_text,
            location=location,
            alert_count=len(alerts),
            has_changes=change_report.has_changes,
        )

    def _build_prompt(
        self,
        alerts: list[AggregatedAlert],
        change_report: ChangeReport,
        forecast_context: ForecastContext,
        location: str,
    ) -> str:
        """Build the prompt for the LLM.

        Args:
            alerts: List of aggregated alerts.
            change_report: Change report from detector.
            forecast_context: Forecast context data.
            location: Location name.

        Returns:
            Formatted prompt string.
        """
        # Format alerts section
        alert_lines = []
        for alert in alerts:
            alert_type = alert.alert_type.value.replace("_", " ").title()
            alert_lines.append(
                f"- {alert_type}: {alert.extreme_value:.0f} "
                f"(threshold: {alert.threshold:.0f}) "
                f"from {alert.start_time} to {alert.end_time}"
            )
        alerts_section = "\n".join(alert_lines) if alert_lines else "No active alerts."

        # Format changes section
        changes_section = change_report.to_prompt_text()

        # Format forecast section
        forecast_section = forecast_context.to_prompt_text()

        return self.PROMPT_TEMPLATE.format(
            location=location,
            alerts_section=alerts_section,
            changes_section=changes_section,
            forecast_section=forecast_section,
        )

    def _run_with_retry(self, prompt: str) -> str:
        """Execute model call with retry logic.

        Implements exponential backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx)
        - Timeout errors

        Args:
            prompt: The prompt to send to the model.

        Returns:
            Generated text from the model.

        Raises:
            AISummaryServiceError: If all retries fail.
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                output = replicate.run(
                    self.model,
                    input={
                        "prompt": prompt,
                        "max_tokens": self.max_tokens,
                        "temperature": 0.7,
                    },
                )

                # Replicate returns a generator for streaming models
                # Collect all output parts into a single string
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)

                self.logger.debug(
                    "replicate_api_call_succeeded",
                    model=self.model,
                    attempt=attempt,
                    response_length=len(response),
                )

                return response

            except Exception as e:
                last_error = e
                self._retry_count += 1

                # Determine if error is retryable based on error message
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout

                if attempt < self.MAX_RETRIES and is_retryable:
                    self.logger.warning(
                        "replicate_api_call_failed_retrying",
                        model=self.model,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break

        # All retries exhausted or non-retryable error
        self.logger.error(
            "replicate_api_call_failed_after_retries",
            model=self.model,
            retry_count=self.MAX_RETRIES,
            error=str(last_error),
        )

        raise AISummaryServiceError(
            f"Replicate API call failed after {self.MAX_RETRIES} attempts: {last_error}"
        )
233 app/services/change_detector.py Normal file
@@ -0,0 +1,233 @@
"""Change detection service for comparing alerts between runs."""

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

from app.config.loader import ChangeThresholds
from app.models.alerts import AggregatedAlert
from app.models.state import AlertSnapshot
from app.utils.logging_config import get_logger


class ChangeType(Enum):
    """Types of changes that can be detected between runs."""

    NEW = "new"
    REMOVED = "removed"
    VALUE_CHANGED = "value_changed"


@dataclass
class AlertChange:
    """Represents a detected change in an alert."""

    alert_type: str
    change_type: ChangeType
    description: str
    previous_value: Optional[float] = None
    current_value: Optional[float] = None
    value_delta: Optional[float] = None


@dataclass
class ChangeReport:
    """Report of all changes detected between runs."""

    changes: list[AlertChange] = field(default_factory=list)

    @property
    def has_changes(self) -> bool:
        """Check if any changes were detected."""
        return len(self.changes) > 0

    @property
    def new_alerts(self) -> list[AlertChange]:
        """Get list of new alert changes."""
        return [c for c in self.changes if c.change_type == ChangeType.NEW]

    @property
    def removed_alerts(self) -> list[AlertChange]:
        """Get list of removed alert changes."""
        return [c for c in self.changes if c.change_type == ChangeType.REMOVED]

    @property
    def value_changes(self) -> list[AlertChange]:
        """Get list of value change alerts."""
        return [c for c in self.changes if c.change_type == ChangeType.VALUE_CHANGED]

    def to_prompt_text(self) -> str:
        """Format change report for LLM prompt."""
        if not self.has_changes:
            return "No significant changes from previous alert."

        lines = []

        for change in self.new_alerts:
            lines.append(f"NEW: {change.description}")

        for change in self.removed_alerts:
            lines.append(f"RESOLVED: {change.description}")

        for change in self.value_changes:
            lines.append(f"CHANGED: {change.description}")

        return "\n".join(lines)


class ChangeDetector:
    """Detects changes between current alerts and previous snapshots."""

    # Map alert types to their value thresholds
    ALERT_TYPE_TO_THRESHOLD_KEY = {
        "temperature_low": "temperature",
        "temperature_high": "temperature",
        "wind_speed": "wind_speed",
        "wind_gust": "wind_gust",
        "precipitation": "precipitation_prob",
    }

    def __init__(self, thresholds: ChangeThresholds) -> None:
        """Initialize the change detector.

        Args:
            thresholds: Thresholds for detecting significant changes.
        """
        self.thresholds = thresholds
        self.logger = get_logger(__name__)

    def detect(
        self,
        current_alerts: list[AggregatedAlert],
        previous_snapshots: dict[str, AlertSnapshot],
    ) -> ChangeReport:
        """Detect changes between current alerts and previous snapshots.

        Args:
            current_alerts: List of current aggregated alerts.
            previous_snapshots: Dict of previous alert snapshots keyed by alert type.

        Returns:
            ChangeReport containing all detected changes.
        """
        changes: list[AlertChange] = []
        current_types = {alert.alert_type.value for alert in current_alerts}
        previous_types = set(previous_snapshots.keys())

        # Detect new alerts
        for alert in current_alerts:
            alert_type = alert.alert_type.value
            if alert_type not in previous_types:
                changes.append(
                    AlertChange(
                        alert_type=alert_type,
                        change_type=ChangeType.NEW,
                        description=self._format_new_alert_description(alert),
                        current_value=alert.extreme_value,
                    )
                )
                self.logger.debug(
                    "change_detected_new",
                    alert_type=alert_type,
                )
            else:
                # Check for significant value changes
                prev_snapshot = previous_snapshots[alert_type]
                value_change = self._detect_value_change(alert, prev_snapshot)
                if value_change:
                    changes.append(value_change)
                    self.logger.debug(
                        "change_detected_value",
                        alert_type=alert_type,
                        previous=prev_snapshot.extreme_value,
                        current=alert.extreme_value,
                    )

        # Detect removed alerts
        for alert_type in previous_types - current_types:
            prev_snapshot = previous_snapshots[alert_type]
            changes.append(
                AlertChange(
                    alert_type=alert_type,
                    change_type=ChangeType.REMOVED,
                    description=self._format_removed_alert_description(prev_snapshot),
                    previous_value=prev_snapshot.extreme_value,
                )
            )
            self.logger.debug(
                "change_detected_removed",
                alert_type=alert_type,
            )

        report = ChangeReport(changes=changes)

        self.logger.info(
            "change_detection_complete",
            total_changes=len(changes),
            new_count=len(report.new_alerts),
            removed_count=len(report.removed_alerts),
            value_changes_count=len(report.value_changes),
        )

        return report

    def _detect_value_change(
        self,
        current: AggregatedAlert,
        previous: AlertSnapshot,
    ) -> Optional[AlertChange]:
        """Detect if there's a significant value change between current and previous.

        Args:
            current: Current aggregated alert.
            previous: Previous alert snapshot.

        Returns:
            AlertChange if significant change detected, None otherwise.
        """
        threshold_key = self.ALERT_TYPE_TO_THRESHOLD_KEY.get(current.alert_type.value)
        if not threshold_key:
            return None

        threshold = getattr(self.thresholds, threshold_key, None)
        if threshold is None:
            return None

        delta = abs(current.extreme_value - previous.extreme_value)
        if delta >= threshold:
            return AlertChange(
                alert_type=current.alert_type.value,
                change_type=ChangeType.VALUE_CHANGED,
                description=self._format_value_change_description(
                    current, previous, delta
                ),
                previous_value=previous.extreme_value,
                current_value=current.extreme_value,
                value_delta=delta,
            )

        return None

    def _format_new_alert_description(self, alert: AggregatedAlert) -> str:
        """Format description for a new alert."""
        alert_type = alert.alert_type.value.replace("_", " ").title()
        return f"{alert_type} alert: {alert.extreme_value:.0f} at {alert.extreme_hour}"

    def _format_removed_alert_description(self, snapshot: AlertSnapshot) -> str:
        """Format description for a removed alert."""
        alert_type = snapshot.alert_type.replace("_", " ").title()
        return f"{alert_type} alert no longer active"

    def _format_value_change_description(
        self,
        current: AggregatedAlert,
        previous: AlertSnapshot,
        delta: float,
    ) -> str:
        """Format description for a value change."""
        alert_type = current.alert_type.value.replace("_", " ").title()
        direction = "increased" if current.extreme_value > previous.extreme_value else "decreased"
        return (
            f"{alert_type} {direction} from {previous.extreme_value:.0f} "
            f"to {current.extreme_value:.0f}"
        )
@@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Optional, Union

from app.models.ai_summary import SummaryNotification
from app.models.alerts import AggregatedAlert, AlertType, TriggeredAlert
from app.utils.http_client import HttpClient
from app.utils.logging_config import get_logger
@@ -20,6 +21,15 @@ class NotificationResult:
    error: Optional[str] = None


@dataclass
class SummaryNotificationResult:
    """Result of sending a summary notification."""

    summary: SummaryNotification
    success: bool
    error: Optional[str] = None


class NotificationServiceError(Exception):
    """Raised when notification service encounters an error."""

@@ -175,3 +185,54 @@ class NotificationService:
        tags = list(self.ALERT_TYPE_TAGS.get(alert.alert_type, self.default_tags))

        return tags

    def send_summary(self, summary: SummaryNotification) -> SummaryNotificationResult:
        """Send an AI-generated summary notification.

        Args:
            summary: The summary notification to send.

        Returns:
            SummaryNotificationResult indicating success or failure.
        """
        url = f"{self.server_url}/{self.topic}"

        # Build headers
        headers = {
            "Title": summary.title,
            "Priority": self.priority,
            "Tags": ",".join(summary.tags),
        }

        if self.access_token:
            headers["Authorization"] = f"Bearer {self.access_token}"

        self.logger.debug(
            "sending_summary_notification",
            location=summary.location,
            alert_count=summary.alert_count,
        )

        response = self.http_client.post(
            url,
            data=summary.message.encode("utf-8"),
            headers=headers,
        )

        if response.success:
            self.logger.info(
                "summary_notification_sent",
                location=summary.location,
                alert_count=summary.alert_count,
                has_changes=summary.has_changes,
            )
            return SummaryNotificationResult(summary=summary, success=True)
        else:
            error_msg = f"HTTP {response.status_code}: {response.text[:100]}"
            self.logger.error(
                "summary_notification_failed",
                error=error_msg,
            )
            return SummaryNotificationResult(
                summary=summary, success=False, error=error_msg
            )
@@ -7,7 +7,7 @@ from pathlib import Path
from typing import Optional, Union

from app.models.alerts import AggregatedAlert, TriggeredAlert
-from app.models.state import AlertState
+from app.models.state import AlertSnapshot, AlertState
from app.utils.logging_config import get_logger

# Type alias for alerts that can be deduplicated
@@ -183,3 +183,47 @@ class StateManager:
        self.logger.info("old_records_purged", count=purged)

        return purged

    def save_alert_snapshots(self, alerts: list[AggregatedAlert]) -> None:
        """Save current alerts as snapshots for change detection.

        Args:
            alerts: List of aggregated alerts to save as snapshots.
        """
        snapshots: dict[str, AlertSnapshot] = {}

        for alert in alerts:
            snapshot = AlertSnapshot(
                alert_type=alert.alert_type.value,
                extreme_value=alert.extreme_value,
                threshold=alert.threshold,
                start_time=alert.start_time,
                end_time=alert.end_time,
                hour_count=alert.hour_count,
            )
            snapshots[alert.alert_type.value] = snapshot

        from datetime import datetime

        self.state.previous_alert_snapshots = snapshots
        self.state.last_updated = datetime.now()

        self.logger.debug(
            "alert_snapshots_saved",
            count=len(snapshots),
        )

    def get_previous_snapshots(self) -> dict[str, AlertSnapshot]:
        """Get previous alert snapshots for change detection.

        Returns:
            Dict of alert snapshots keyed by alert type.
        """
        return self.state.previous_alert_snapshots

    def record_ai_summary_sent(self) -> None:
        """Record that an AI summary was sent."""
        from datetime import datetime

        self.state.last_ai_summary_sent = datetime.now()
        self.logger.debug("ai_summary_sent_recorded")
@@ -1,23 +1,38 @@
{
  "sent_alerts": {
    "severe_weather:urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.2": {
      "dedup_key": "severe_weather:urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.2",
      "alert_type": "severe_weather",
      "sent_at": "2026-01-26T21:15:25.784599",
      "forecast_hour": "urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.2"
    },
    "severe_weather:urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.1": {
      "dedup_key": "severe_weather:urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.1",
      "alert_type": "severe_weather",
      "sent_at": "2026-01-26T21:15:25.784621",
      "forecast_hour": "urn:oid:2.49.0.1.840.0.60e26c16bab4970c623ad5eaf13913a10a0c4662.001.1"
    },
-    "temperature_low:2026-01-26": {
-      "dedup_key": "temperature_low:2026-01-26",
+    "temperature_low:2026-02-01": {
+      "dedup_key": "temperature_low:2026-02-01",
      "alert_type": "temperature_low",
-      "sent_at": "2026-01-26T21:15:25.784625",
-      "forecast_hour": "2026-01-26-22"
-    }
+      "sent_at": "2026-02-01T00:01:05.309466",
+      "forecast_hour": "2026-02-01-01"
+    },
+    "severe_weather:urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1": {
+      "dedup_key": "severe_weather:urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1",
+      "alert_type": "severe_weather",
+      "sent_at": "2026-02-01T08:01:04.797087",
+      "forecast_hour": "urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1"
+    }
  },
-  "last_updated": "2026-01-26T21:15:25.784626"
+  "last_updated": "2026-02-01T12:01:03.966306",
+  "previous_alert_snapshots": {
+    "severe_weather": {
+      "alert_type": "severe_weather",
+      "extreme_value": 1.0,
+      "threshold": 0.0,
+      "start_time": "urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1",
+      "end_time": "urn:oid:2.49.0.1.840.0.2e5d7c4b4f9e162788039415050daaff74f78960.001.1",
+      "hour_count": 1,
+      "captured_at": "2026-02-01T12:01:03.966302"
+    },
+    "temperature_low": {
+      "alert_type": "temperature_low",
+      "extreme_value": 15.0,
+      "threshold": 32,
+      "start_time": "2026-02-01-13",
+      "end_time": "2026-02-01-16",
+      "hour_count": 4,
+      "captured_at": "2026-02-01T12:01:03.966304"
+    }
+  },
+  "last_ai_summary_sent": null
}
339 replicate_client.py Normal file
@@ -0,0 +1,339 @@
"""
Replicate API client wrapper for the Sneaky Intel Feed Aggregator.

Provides low-level API access with retry logic and error handling
for all Replicate model interactions.
"""

import json
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import replicate

from feed_aggregator.models.config import ReplicateConfig
from feed_aggregator.utils.exceptions import APIError
from feed_aggregator.utils.logging import get_logger

logger = get_logger(__name__)


class ReplicateClient:
    """
    Low-level wrapper for Replicate API calls.

    Handles:
    - API authentication via REPLICATE_API_TOKEN
    - Retry logic with exponential backoff
    - Error classification and wrapping
    - Response normalization

    Attributes:
        MAX_RETRIES: Maximum retry attempts for failed calls.
        INITIAL_BACKOFF_SECONDS: Initial delay before first retry.
        BACKOFF_MULTIPLIER: Multiplier for exponential backoff.
        LLAMA_MODEL: Model identifier for Llama 3 8B Instruct.
        EMBEDDING_MODEL: Model identifier for text embeddings.

    Example:
        >>> from feed_aggregator.services.replicate_client import ReplicateClient
        >>> from feed_aggregator.models.config import ReplicateConfig
        >>>
        >>> config = ReplicateConfig(api_key="r8_...")
        >>> client = ReplicateClient(config)
        >>> response, error = client.run_llama("Summarize: ...")
    """

    # Retry configuration (matches DatabaseService pattern)
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0

    # Model identifiers
    LLAMA_MODEL = "meta/meta-llama-3-8b-instruct"
    EMBEDDING_MODEL = "beautyyuyanli/multilingual-e5-large:a06276a89f1a902d5fc225a9ca32b6e8e6292b7f3b136518878da97c458e2bad"

    # Default parameters for Llama
    DEFAULT_MAX_TOKENS = 500
    DEFAULT_TEMPERATURE = 0.3

    def __init__(self, config: ReplicateConfig):
        """
        Initialize the Replicate client.

        Sets the REPLICATE_API_TOKEN environment variable for the
        replicate library to use.

        Args:
            config: Replicate configuration with API key.
        """
        self._config = config
        self._retry_count = 0  # Track total retries for metrics

        # Set API token for replicate library
        os.environ["REPLICATE_API_TOKEN"] = config.api_key

        logger.debug(
            "ReplicateClient initialized",
            model=self.LLAMA_MODEL,
        )

    def run_llama(
        self,
        prompt: str,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        temperature: float = DEFAULT_TEMPERATURE,
    ) -> Tuple[Optional[str], Optional[APIError]]:
        """
        Run Llama 3 8B Instruct with the given prompt.

        Args:
            prompt: The prompt to send to the model.
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature (0-1).

        Returns:
            Tuple of (response_text, error).
            On success: (str, None)
            On failure: (None, APIError)
        """
        return self._run_with_retry(
            model=self.LLAMA_MODEL,
            input_params={
                "prompt": prompt,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
            operation="summarization",
        )

    # Expected embedding dimensions for multilingual-e5-large
    EMBEDDING_DIMENSIONS = 1024

    def run_embedding(
        self,
        text: str,
    ) -> Tuple[Optional[List[float]], Optional[APIError]]:
        """
        Generate embeddings using multilingual-e5-large.

        Uses the same retry logic as run_llama but preserves the
        list response format instead of converting to string.

        Args:
            text: Text to embed (title + summary, typically 200-500 tokens).

        Returns:
            Tuple of (embedding_vector, error).
            On success: (List[float] with 1024 dimensions, None)
            On failure: (None, APIError)

        Example:
            >>> embedding, error = client.run_embedding("Article about security")
            >>> if embedding:
            ...     print(f"Generated {len(embedding)}-dimensional embedding")
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                # E5 model expects texts as a JSON-formatted string
                output = replicate.run(
                    self.EMBEDDING_MODEL,
                    input={"texts": json.dumps([text])},
                )

                # Output is a list of embeddings (one per input text)
                # We only pass one text, so extract the first embedding
                if isinstance(output, list) and len(output) > 0:
                    embedding = output[0]

                    # Validate embedding dimensions
                    if len(embedding) != self.EMBEDDING_DIMENSIONS:
                        raise ValueError(
                            f"Expected {self.EMBEDDING_DIMENSIONS} dimensions, "
                            f"got {len(embedding)}"
                        )

                    # Ensure all values are floats
                    embedding = [float(v) for v in embedding]

                    logger.debug(
                        "Replicate embedding succeeded",
                        model=self.EMBEDDING_MODEL,
                        operation="embedding",
                        attempt=attempt,
                        dimensions=len(embedding),
                    )

                    return embedding, None
                else:
                    raise ValueError(
                        f"Unexpected output format from embedding model: {type(output)}"
                    )

            except Exception as e:
                last_error = e
                self._retry_count += 1

                # Determine if error is retryable based on error message
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout

                if attempt < self.MAX_RETRIES and is_retryable:
                    logger.warning(
                        "Replicate embedding failed, retrying",
                        model=self.EMBEDDING_MODEL,
                        operation="embedding",
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break

        # All retries exhausted or non-retryable error
        logger.error(
            "Replicate embedding failed after retries",
            model=self.EMBEDDING_MODEL,
            operation="embedding",
            retry_count=self.MAX_RETRIES,
            error=str(last_error),
        )

        return None, APIError(
            message=f"Replicate embedding failed: {str(last_error)}",
            api_name="Replicate",
            operation="embedding",
            retry_count=self.MAX_RETRIES,
            context={"model": self.EMBEDDING_MODEL, "error": str(last_error)},
        )

    def _run_with_retry(
        self,
        model: str,
        input_params: Dict[str, Any],
        operation: str,
    ) -> Tuple[Optional[str], Optional[APIError]]:
        """
        Execute model call with retry logic.

        Implements exponential backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx)
        - Timeout errors

        Non-retryable errors (auth failures, etc.) fail immediately.

        Args:
            model: Replicate model identifier.
            input_params: Model input parameters.
            operation: Operation name for error reporting.

        Returns:
            Tuple of (response, error).
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                output = replicate.run(model, input=input_params)

                # Replicate returns a generator for streaming models
                # Collect all output parts into a single string
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)

                logger.debug(
                    "Replicate API call succeeded",
                    model=model,
                    operation=operation,
                    attempt=attempt,
                    response_length=len(response),
                )

                return response, None

            except Exception as e:
                last_error = e
                self._retry_count += 1

                # Determine if error is retryable based on error message
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout

                if attempt < self.MAX_RETRIES and is_retryable:
                    logger.warning(
                        "Replicate API call failed, retrying",
                        model=model,
                        operation=operation,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break

        # All retries exhausted or non-retryable error
        logger.error(
            "Replicate API call failed after retries",
            model=model,
            operation=operation,
            retry_count=self.MAX_RETRIES,
            error=str(last_error),
        )

        return None, APIError(
            message=f"Replicate API call failed: {str(last_error)}",
            api_name="Replicate",
            operation=operation,
            retry_count=self.MAX_RETRIES,
            context={"model": model, "error": str(last_error)},
        )

    @property
    def total_retries(self) -> int:
        """
        Return total retry count across all calls.

        Returns:
            Total number of retry attempts made.
        """
        return self._retry_count

    def reset_retry_count(self) -> None:
        """
        Reset retry counter.

        Useful for metrics collection between batch operations.
        """
        self._retry_count = 0
@@ -2,5 +2,6 @@ requests>=2.28.0,<3.0.0
structlog>=23.0.0,<25.0.0
python-dotenv>=1.0.0,<2.0.0
PyYAML>=6.0,<7.0
+replicate>=0.25.0,<1.0.0
pytest>=7.0.0,<9.0.0
responses>=0.23.0,<1.0.0