"""
Session Cache Service

This service provides Redis-based caching for authenticated sessions to reduce
Appwrite API calls. Instead of validating every request with Appwrite, we cache
the session data and validate periodically (default: every 5 minutes).

Security Features:
- Session tokens are hashed (SHA-256) before use as cache keys
- Session expiry is checked on every cache hit
- Explicit invalidation on logout and password change
- Graceful degradation on Redis failure (falls back to Appwrite)

Usage:
    from app.services.session_cache_service import SessionCacheService

    cache = SessionCacheService()

    # Get cached user (returns None on miss)
    user = cache.get(token)

    # Cache a validated session
    cache.set(token, user_data, session_expire)

    # Invalidate on logout
    cache.invalidate_token(token)

    # Invalidate all user sessions on password change
    cache.invalidate_user(user_id)
"""

import hashlib
import time
from datetime import datetime, timezone
from typing import Optional, Any, Dict

from app.services.redis_service import RedisService, RedisServiceError
from app.services.appwrite_service import UserData
from app.config import get_config
from app.utils.logging import get_logger

# Initialize logger
logger = get_logger(__file__)

# Cache key prefixes
SESSION_CACHE_PREFIX = "session_cache:"
USER_INVALIDATION_PREFIX = "user_invalidated:"

# Number of hashed-token characters included in log output: enough to
# correlate log lines, too few to reconstruct the full cache key.
_LOG_HASH_CHARS = 8


class SessionCacheService:
    """
    Redis-based session cache service.

    This service caches validated session data to reduce the number of
    Appwrite API calls per request. Sessions are cached with a configurable
    TTL (default: 5 minutes) and are explicitly invalidated on logout or
    password change.

    Attributes:
        enabled: Whether caching is enabled
        ttl_seconds: Cache TTL in seconds
        redis: RedisService instance
    """

    def __init__(self):
        """
        Initialize the session cache service.

        Reads configuration from the auth.session_cache config section.
        If caching is disabled or Redis connection fails, operates in
        pass-through mode (always returns None).
        """
        self.config = get_config()
        self.enabled = self.config.auth.session_cache.enabled
        self.ttl_seconds = self.config.auth.session_cache.ttl_seconds
        self.redis_db = self.config.auth.session_cache.redis_db
        self._redis: Optional[RedisService] = None

        if self.enabled:
            try:
                # Build Redis URL with the session cache database
                redis_url = f"redis://{self.config.redis.host}:{self.config.redis.port}/{self.redis_db}"
                self._redis = RedisService(redis_url=redis_url)
                logger.info(
                    "Session cache service initialized",
                    ttl_seconds=self.ttl_seconds,
                    redis_db=self.redis_db
                )
            except RedisServiceError as e:
                logger.warning(
                    "Failed to initialize session cache, operating in pass-through mode",
                    error=str(e)
                )
                self.enabled = False

    def _hash_token(self, token: str) -> str:
        """
        Hash a session token for use as a cache key.

        Tokens are hashed to prevent enumeration attacks if Redis is
        compromised. We use the first 32 characters of the SHA-256 hash as
        a balance between collision resistance and key length.

        Args:
            token: The raw session token

        Returns:
            First 32 characters of the SHA-256 hash
        """
        return hashlib.sha256(token.encode()).hexdigest()[:32]

    def _get_cache_key(self, token: str) -> str:
        """
        Generate a cache key for a session token.

        Args:
            token: The raw session token

        Returns:
            Cache key in format "session_cache:{hashed_token}"
        """
        return f"{SESSION_CACHE_PREFIX}{self._hash_token(token)}"

    def _get_invalidation_key(self, user_id: str) -> str:
        """
        Generate an invalidation marker key for a user.

        Args:
            user_id: The user ID

        Returns:
            Invalidation key in format "user_invalidated:{user_id}"
        """
        return f"{USER_INVALIDATION_PREFIX}{user_id}"

    @staticmethod
    def _log_key(cache_key: str) -> str:
        """
        Return a truncated cache key suitable for log output.

        Keeps the key prefix plus the first few hashed characters, so log
        lines for the same session can be correlated without writing the
        full key to the logs.

        Args:
            cache_key: A key produced by _get_cache_key

        Returns:
            The prefix followed by the first _LOG_HASH_CHARS hash characters
        """
        return cache_key[:len(SESSION_CACHE_PREFIX) + _LOG_HASH_CHARS]

    def get(self, token: str) -> Optional[UserData]:
        """
        Retrieve cached user data for a session token.

        This method:
        1. Checks if caching is enabled
        2. Retrieves cached data from Redis
        3. Validates session hasn't expired
        4. Checks user hasn't been invalidated (password change)
        5. Returns UserData if valid, None otherwise

        Args:
            token: The session token to look up

        Returns:
            UserData if cache hit and valid, None if miss, invalid, or on
            any Redis / data error (callers then fall back to Appwrite)
        """
        if not self.enabled or not self._redis:
            return None

        try:
            cache_key = self._get_cache_key(token)
            cached_data = self._redis.get_json(cache_key)

            if cached_data is None:
                logger.debug("Session cache miss", cache_key=self._log_key(cache_key))
                return None

            # get_json can yield any JSON value; a non-object payload is
            # corrupt and would otherwise raise AttributeError on .get() below,
            # escaping the graceful-degradation handlers.
            if not isinstance(cached_data, dict):
                logger.warning(
                    "Invalid cached session data",
                    error=f"expected object, got {type(cached_data).__name__}"
                )
                return None

            # Check if session has expired
            session_expire = cached_data.get("session_expire")
            if session_expire:
                expire_time = datetime.fromisoformat(session_expire)
                # Ensure both datetimes are timezone-aware for comparison
                now = datetime.now(timezone.utc)
                if expire_time.tzinfo is None:
                    expire_time = expire_time.replace(tzinfo=timezone.utc)
                if expire_time < now:
                    logger.debug("Cached session expired", cache_key=self._log_key(cache_key))
                    self._redis.delete(cache_key)
                    return None

            # Check if user has been invalidated (password change)
            user_id = cached_data.get("user_id")
            if user_id:
                invalidation_key = self._get_invalidation_key(user_id)
                invalidated_at = self._redis.get(invalidation_key)
                if invalidated_at:
                    cached_at = cached_data.get("cached_at", 0)
                    # '>=' so that an entry cached in the same clock tick as
                    # the invalidation is rejected (fail safe: the caller just
                    # re-validates with Appwrite).
                    if float(invalidated_at) >= cached_at:
                        logger.debug(
                            "User invalidated after cache, rejecting",
                            user_id=user_id
                        )
                        self._redis.delete(cache_key)
                        return None

            # Reconstruct UserData
            user_data = UserData(
                id=cached_data["user_id"],
                email=cached_data["email"],
                name=cached_data["name"],
                email_verified=cached_data["email_verified"],
                tier=cached_data["tier"],
                created_at=datetime.fromisoformat(cached_data["created_at"]),
                updated_at=datetime.fromisoformat(cached_data["updated_at"])
            )

            logger.debug("Session cache hit", user_id=user_data.id)
            return user_data

        except RedisServiceError as e:
            logger.warning("Session cache read failed, falling back to Appwrite", error=str(e))
            return None
        except (KeyError, ValueError, TypeError) as e:
            logger.warning("Invalid cached session data", error=str(e))
            return None

    def set(
        self,
        token: str,
        user_data: UserData,
        session_expire: datetime
    ) -> bool:
        """
        Cache a validated session.

        The entry's TTL is the smaller of the configured TTL and the time
        remaining until the session expires. Sessions that have already
        expired are not cached.

        Args:
            token: The session token
            user_data: The validated user data to cache
            session_expire: When the session expires (from Appwrite); a naive
                datetime is interpreted as UTC

        Returns:
            True if cached successfully, False otherwise (caching disabled,
            session already expired, Redis failure, or unserializable data)
        """
        if not self.enabled or not self._redis:
            return False

        try:
            cache_key = self._get_cache_key(token)

            # Calculate effective TTL (min of config TTL and session remaining time)
            # Ensure timezone-aware comparison
            now = datetime.now(timezone.utc)
            expire_aware = (
                session_expire
                if session_expire.tzinfo
                else session_expire.replace(tzinfo=timezone.utc)
            )
            session_remaining = (expire_aware - now).total_seconds()
            if session_remaining <= 0:
                # get() would reject this entry immediately; don't write it.
                logger.debug("Session already expired, not caching", user_id=user_data.id)
                return False
            effective_ttl = min(self.ttl_seconds, max(1, int(session_remaining)))

            cache_data: Dict[str, Any] = {
                "user_id": user_data.id,
                "email": user_data.email,
                "name": user_data.name,
                "email_verified": user_data.email_verified,
                "tier": user_data.tier,
                "created_at": user_data.created_at.isoformat() if isinstance(user_data.created_at, datetime) else user_data.created_at,
                "updated_at": user_data.updated_at.isoformat() if isinstance(user_data.updated_at, datetime) else user_data.updated_at,
                # Store the normalised (aware) value so the stored expiry is unambiguous.
                "session_expire": expire_aware.isoformat(),
                "cached_at": time.time()
            }

            success = self._redis.set_json(cache_key, cache_data, ttl=effective_ttl)

            if success:
                logger.debug(
                    "Session cached",
                    user_id=user_data.id,
                    ttl=effective_ttl
                )

            return success

        except (RedisServiceError, TypeError, ValueError) as e:
            # TypeError/ValueError cover payloads that cannot be serialised;
            # a cache write failure must never fail the authenticated request.
            logger.warning("Session cache write failed", error=str(e))
            return False

    def invalidate_token(self, token: str) -> bool:
        """
        Invalidate a specific session token (used on logout).

        Args:
            token: The session token to invalidate

        Returns:
            True if invalidated successfully, False otherwise
        """
        if not self.enabled or not self._redis:
            return False

        try:
            cache_key = self._get_cache_key(token)
            deleted = self._redis.delete(cache_key)
            logger.debug("Session cache invalidated", deleted_count=deleted)
            return deleted > 0

        except RedisServiceError as e:
            logger.warning("Session cache invalidation failed", error=str(e))
            return False

    def invalidate_user(self, user_id: str) -> bool:
        """
        Invalidate all sessions for a user (used on password change).

        This sets an invalidation marker with the current timestamp.
        Any cached sessions created at or before this timestamp will be
        rejected by get().

        Args:
            user_id: The user ID to invalidate

        Returns:
            True if invalidation marker set successfully, False otherwise
        """
        if not self.enabled or not self._redis:
            return False

        try:
            invalidation_key = self._get_invalidation_key(user_id)
            # Set invalidation marker with TTL matching session duration
            # Use the longer duration (remember_me) to ensure coverage
            marker_ttl = self.config.auth.duration_remember_me
            success = self._redis.set(
                invalidation_key,
                str(time.time()),
                ttl=marker_ttl
            )

            if success:
                logger.info(
                    "User sessions invalidated",
                    user_id=user_id,
                    marker_ttl=marker_ttl
                )

            return success

        except RedisServiceError as e:
            logger.warning("User invalidation failed", error=str(e), user_id=user_id)
            return False

    def health_check(self) -> bool:
        """
        Check if the session cache is healthy.

        Returns:
            True if Redis is healthy and caching is enabled, False otherwise
        """
        if not self.enabled or not self._redis:
            return False

        return self._redis.health_check()