- Add SessionCacheService with 5-minute TTL Redis cache - Cache validated sessions to avoid redundant Appwrite calls - Add /api/v1/auth/me endpoint for retrieving current user - Invalidate cache on logout and password reset - Add session_cache config to auth section (Redis db 2) - Fix Docker Redis hostname (localhost -> redis) - Handle timezone-aware datetime comparisons Security: tokens hashed before use as cache keys, explicit invalidation on logout/password change, graceful degradation when Redis unavailable. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
347 lines
12 KiB
Python
347 lines
12 KiB
Python
"""
|
|
Session Cache Service
|
|
|
|
This service provides Redis-based caching for authenticated sessions to reduce
|
|
Appwrite API calls. Instead of validating every request with Appwrite, we cache
|
|
the session data and validate periodically (default: every 5 minutes).
|
|
|
|
Security Features:
|
|
- Session tokens are hashed (SHA-256) before use as cache keys
|
|
- Session expiry is checked on every cache hit
|
|
- Explicit invalidation on logout and password change
|
|
- Graceful degradation on Redis failure (falls back to Appwrite)
|
|
|
|
Usage:
|
|
from app.services.session_cache_service import SessionCacheService
|
|
|
|
cache = SessionCacheService()
|
|
|
|
# Get cached user (returns None on miss)
|
|
user = cache.get(token)
|
|
|
|
# Cache a validated session
|
|
cache.set(token, user_data, session_expire)
|
|
|
|
# Invalidate on logout
|
|
cache.invalidate_token(token)
|
|
|
|
# Invalidate all user sessions on password change
|
|
cache.invalidate_user(user_id)
|
|
"""
|
|
|
|
import hashlib
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from typing import Optional, Any, Dict
|
|
|
|
from app.services.redis_service import RedisService, RedisServiceError
|
|
from app.services.appwrite_service import UserData
|
|
from app.config import get_config
|
|
from app.utils.logging import get_logger
|
|
|
|
|
|
# Module logger shared by every SessionCacheService instance.
logger = get_logger(__file__)

# Redis key namespaces owned by this service:
#   SESSION_CACHE_PREFIX + <hashed session token>  -> cached session JSON
#   USER_INVALIDATION_PREFIX + <user id>           -> invalidation timestamp
SESSION_CACHE_PREFIX: str = "session_cache:"
USER_INVALIDATION_PREFIX: str = "user_invalidated:"
|
|
|
|
|
|
class SessionCacheService:
    """
    Redis-based session cache service.

    This service caches validated session data to reduce the number of
    Appwrite API calls per request. Sessions are cached with a configurable
    TTL (auth.session_cache.ttl_seconds), capped by the session's own
    remaining lifetime, and are explicitly invalidated on logout or
    password change.

    Every operation degrades gracefully: when caching is disabled or Redis
    raises RedisServiceError, reads return None and writes/invalidations
    return False, so callers fall back to validating with Appwrite.

    Attributes:
        config: Application configuration returned by get_config()
        enabled: Whether caching is enabled (forced to False if the Redis
            client cannot be created)
        ttl_seconds: Maximum cache TTL in seconds
        redis_db: Redis database index used for the session cache
        _redis: RedisService instance, or None when caching is disabled
    """

    def __init__(self):
        """
        Initialize the session cache service.

        Reads configuration from the auth.session_cache config section.
        If caching is disabled, or creating the Redis client raises
        RedisServiceError, the service operates in pass-through mode:
        get() always returns None and the write/invalidate methods
        return False.
        """
        self.config = get_config()
        session_cfg = self.config.auth.session_cache
        self.enabled = session_cfg.enabled
        self.ttl_seconds = session_cfg.ttl_seconds
        self.redis_db = session_cfg.redis_db
        self._redis: Optional[RedisService] = None

        if not self.enabled:
            return

        try:
            # Use the dedicated session-cache database on the shared Redis host.
            # NOTE(review): only host/port/db go into the URL; any credentials
            # or TLS settings in the Redis config are not applied here — confirm.
            redis_url = f"redis://{self.config.redis.host}:{self.config.redis.port}/{self.redis_db}"
            self._redis = RedisService(redis_url=redis_url)
        except RedisServiceError as e:
            logger.warning(
                "Failed to initialize session cache, operating in pass-through mode",
                error=str(e)
            )
            self.enabled = False
            return

        logger.info(
            "Session cache service initialized",
            ttl_seconds=self.ttl_seconds,
            redis_db=self.redis_db
        )

    def _hash_token(self, token: str) -> str:
        """
        Hash a session token for use as a cache key.

        Hashing means raw bearer tokens are never stored in Redis, so
        reading the keyspace does not reveal usable tokens. The first 32
        hex characters (128 bits) of the SHA-256 digest are kept as a
        balance between collision resistance and key length.

        Args:
            token: The raw session token

        Returns:
            First 32 characters of the SHA-256 hex digest
        """
        return hashlib.sha256(token.encode()).hexdigest()[:32]

    def _get_cache_key(self, token: str) -> str:
        """
        Generate a cache key for a session token.

        Args:
            token: The raw session token

        Returns:
            Cache key in format "session_cache:{hashed_token}"
        """
        return f"{SESSION_CACHE_PREFIX}{self._hash_token(token)}"

    def _get_invalidation_key(self, user_id: str) -> str:
        """
        Generate an invalidation marker key for a user.

        Args:
            user_id: The user ID

        Returns:
            Invalidation key in format "user_invalidated:{user_id}"
        """
        return f"{USER_INVALIDATION_PREFIX}{user_id}"

    @staticmethod
    def _key_fingerprint(cache_key: str) -> str:
        """
        Return a short prefix of the hashed token for log correlation.

        The cache key starts with SESSION_CACHE_PREFIX, so slicing the key
        itself would log little more than that constant prefix; slice the
        hash portion instead.
        """
        return cache_key[len(SESSION_CACHE_PREFIX):][:8]

    @staticmethod
    def _to_iso(value: Any) -> Any:
        """Serialize datetimes to ISO-8601 strings; pass other values through unchanged."""
        return value.isoformat() if isinstance(value, datetime) else value

    def _discard(self, cache_key: str) -> None:
        """Best-effort delete of a cache entry; Redis errors are logged, never raised."""
        if not self._redis:
            return
        try:
            self._redis.delete(cache_key)
        except RedisServiceError as e:
            logger.warning("Failed to evict session cache entry", error=str(e))

    def get(self, token: str) -> Optional[UserData]:
        """
        Retrieve cached user data for a session token.

        This method:
        1. Checks if caching is enabled
        2. Retrieves cached data from Redis
        3. Validates the session hasn't expired (expired entries are deleted)
        4. Checks the user hasn't been invalidated since the entry was cached
           (password change); stale entries are deleted
        5. Returns UserData if valid, None otherwise

        Malformed cache entries are treated as a miss and evicted, so a
        corrupt entry does not keep failing until its TTL runs out.

        Args:
            token: The session token to look up

        Returns:
            UserData if cache hit and valid, None if miss, invalid, or Redis
            is unavailable
        """
        if not self.enabled or not self._redis:
            return None

        cache_key = self._get_cache_key(token)
        key_id = self._key_fingerprint(cache_key)

        try:
            cached_data = self._redis.get_json(cache_key)

            if cached_data is None:
                logger.debug("Session cache miss", cache_key=key_id)
                return None

            # Anything other than a JSON object cannot be a valid entry; route
            # it through the corrupt-data handler below.
            if not isinstance(cached_data, dict):
                raise TypeError(
                    f"expected a JSON object, got {type(cached_data).__name__}"
                )

            # Check if session has expired
            session_expire = cached_data.get("session_expire")
            if session_expire:
                expire_time = datetime.fromisoformat(session_expire)
                # Naive timestamps are interpreted as UTC, matching set().
                if expire_time.tzinfo is None:
                    expire_time = expire_time.replace(tzinfo=timezone.utc)
                if expire_time < datetime.now(timezone.utc):
                    logger.debug("Cached session expired", cache_key=key_id)
                    self._redis.delete(cache_key)
                    return None

            # Check if user has been invalidated (password change) after this
            # entry was cached. A missing cached_at defaults to 0, so such an
            # entry is always rejected once a marker exists.
            user_id = cached_data.get("user_id")
            if user_id:
                invalidated_at = self._redis.get(self._get_invalidation_key(user_id))
                if invalidated_at and float(invalidated_at) > cached_data.get("cached_at", 0):
                    logger.debug(
                        "User invalidated after cache, rejecting",
                        user_id=user_id
                    )
                    self._redis.delete(cache_key)
                    return None

            # Reconstruct UserData
            user_data = UserData(
                id=cached_data["user_id"],
                email=cached_data["email"],
                name=cached_data["name"],
                email_verified=cached_data["email_verified"],
                tier=cached_data["tier"],
                created_at=datetime.fromisoformat(cached_data["created_at"]),
                updated_at=datetime.fromisoformat(cached_data["updated_at"])
            )

        except RedisServiceError as e:
            logger.warning("Session cache read failed, falling back to Appwrite", error=str(e))
            return None
        except (KeyError, ValueError, TypeError) as e:
            logger.warning("Invalid cached session data", error=str(e))
            # Evict the corrupt entry so subsequent requests repopulate it.
            self._discard(cache_key)
            return None

        logger.debug("Session cache hit", user_id=user_data.id)
        return user_data

    def set(
        self,
        token: str,
        user_data: UserData,
        session_expire: datetime
    ) -> bool:
        """
        Cache a validated session.

        The entry's TTL is the smaller of the configured TTL and the
        session's remaining lifetime (at least 1 second). Sessions that
        have already expired are not cached.

        Args:
            token: The session token
            user_data: The validated user data to cache
            session_expire: When the session expires (from Appwrite); naive
                datetimes are treated as UTC

        Returns:
            True if cached successfully, False otherwise (caching disabled,
            session already expired, or Redis write failed)
        """
        if not self.enabled or not self._redis:
            return False

        # Normalize to an aware datetime so the comparison and the stored
        # value are unambiguous.
        if session_expire.tzinfo is None:
            session_expire = session_expire.replace(tzinfo=timezone.utc)

        session_remaining = (session_expire - datetime.now(timezone.utc)).total_seconds()
        if session_remaining <= 0:
            # Caching a dead session would only produce an entry that get()
            # immediately rejects.
            logger.debug("Session already expired, not caching", user_id=user_data.id)
            return False

        # Sub-second remainders round down to 0, which is why the floor of 1 is kept.
        effective_ttl = min(self.ttl_seconds, max(1, int(session_remaining)))

        cache_data: Dict[str, Any] = {
            "user_id": user_data.id,
            "email": user_data.email,
            "name": user_data.name,
            "email_verified": user_data.email_verified,
            "tier": user_data.tier,
            "created_at": self._to_iso(user_data.created_at),
            "updated_at": self._to_iso(user_data.updated_at),
            "session_expire": session_expire.isoformat(),
            # Compared against per-user invalidation markers in get().
            "cached_at": time.time()
        }

        try:
            success = self._redis.set_json(self._get_cache_key(token), cache_data, ttl=effective_ttl)
        except RedisServiceError as e:
            logger.warning("Session cache write failed", error=str(e))
            return False

        if success:
            logger.debug(
                "Session cached",
                user_id=user_data.id,
                ttl=effective_ttl
            )

        return success

    def invalidate_token(self, token: str) -> bool:
        """
        Invalidate a specific session token (used on logout).

        Args:
            token: The session token to invalidate

        Returns:
            True if a cache entry was deleted, False if none existed,
            caching is disabled, or Redis failed
        """
        if not self.enabled or not self._redis:
            return False

        try:
            deleted = self._redis.delete(self._get_cache_key(token))
        except RedisServiceError as e:
            logger.warning("Session cache invalidation failed", error=str(e))
            return False

        logger.debug("Session cache invalidated", deleted_count=deleted)
        return deleted > 0

    def invalidate_user(self, user_id: str) -> bool:
        """
        Invalidate all sessions for a user (used on password change).

        This sets an invalidation marker holding the current timestamp.
        get() rejects any cached session whose cached_at is older than it.

        Args:
            user_id: The user ID to invalidate

        Returns:
            True if invalidation marker set successfully, False otherwise
        """
        if not self.enabled or not self._redis:
            return False

        invalidation_key = self._get_invalidation_key(user_id)

        # Keep the marker around for the longest session duration
        # (remember_me) so it covers every session the user could hold.
        # NOTE(review): assumes duration_remember_me is in seconds, matching
        # the ttl= unit expected by RedisService.set — confirm.
        marker_ttl = self.config.auth.duration_remember_me

        try:
            success = self._redis.set(
                invalidation_key,
                str(time.time()),
                ttl=marker_ttl
            )
        except RedisServiceError as e:
            logger.warning("User invalidation failed", error=str(e), user_id=user_id)
            return False

        if success:
            logger.info(
                "User sessions invalidated",
                user_id=user_id,
                marker_ttl=marker_ttl
            )

        return success

    def health_check(self) -> bool:
        """
        Check if the session cache is healthy.

        Returns:
            True if caching is enabled and Redis reports healthy; False if
            caching is disabled, Redis is unhealthy, or the check raises
            RedisServiceError
        """
        if not self.enabled or not self._redis:
            return False

        try:
            return self._redis.health_check()
        except RedisServiceError as e:
            logger.warning("Session cache health check failed", error=str(e))
            return False