Files
api_test/app/utils/token_store.py
2025-10-15 13:58:10 -05:00

219 lines
6.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# token_store.py
from __future__ import annotations

import io
import json
import os
import secrets
import tempfile
import threading
import time
import uuid
from dataclasses import asdict, dataclass, replace
from typing import Dict, List, Optional

import bcrypt
import yaml

@dataclass
class ApiKey:
    """One stored API key: metadata plus the bcrypt hash of its token."""

    id: str  # key identifier (ApiKeyStore.mint assigns a uuid4 string)
    label: str  # free-form human-readable label
    hash: str  # bcrypt hash, kept as a UTF-8 str
    created_at: str  # local time "YYYY-MM-DDTHH:MM:SS", seconds precision
    active: bool = True  # inactive keys are skipped by ApiKeyStore.verify

    @staticmethod
    def now_iso() -> str:
        """Return the current local time formatted as "YYYY-MM-DDTHH:MM:SS"."""
        # Deliberately no timezone suffix: these stamps only live in a local file.
        stamp_format = "%Y-%m-%dT%H:%M:%S"
        return time.strftime(stamp_format, time.localtime())
class ApiKeyStore:
    """
    Minimal API key manager:
    - YAML file on disk
    - stores bcrypt hashes only
    - mint/list/get/activate/deactivate/delete/rotate/verify

    File format:
    ---
    version: 1
    keys:
      - id: <uuid4>
        label: "build bot"
        hash: "$2b$12$..."
        created_at: "2025-10-15T10:22:00"
        active: true
    _meta:               # rewritten on every commit, ignored on load
      updated: "2025-10-15T10:22:00"
      count: 1

    Concurrency model:
    - The file is read once, in ``__init__``. Later changes made to it by
      other processes (or other instances) are not picked up, and separate
      writers of the same file overwrite each other's changes.
    - Within one instance, writers are serialized by ``self._lock``. They
      never modify the live key list or its ``ApiKey`` objects in place:
      they build a new list, write it to disk, and only then swap it in.
      Readers (``list``, ``get``, ``verify``) therefore iterate a consistent
      snapshot without locking, and a failed write leaves memory unchanged.
    """

    def __init__(self, path: str, bcrypt_rounds: int = 12):
        """
        Open the store at ``path``, loading it if the file already exists.

        Args:
            path: YAML file location; it is created on the first write.
            bcrypt_rounds: cost factor passed to ``bcrypt.gensalt`` when
                hashing newly minted or rotated tokens.

        Raises:
            ValueError: the existing file is malformed or has an
                unsupported version.
        """
        self.path = path
        self.bcrypt_rounds = bcrypt_rounds
        # Serializes writers only; readers rely on copy-on-write publishing.
        self._lock = threading.Lock()
        self._data: Dict[str, object] = {"version": 1, "keys": []}
        self._load_if_exists()

    # ---------- public API ----------

    def mint(self, label: str = "") -> Dict[str, str]:
        """
        Create a new API token.

        Returns a dict with the plaintext ``token`` (show it once! only its
        bcrypt hash is stored) and the new key's ``id``.
        """
        token = secrets.token_urlsafe(32)
        # Hash before taking the lock: bcrypt work needs no shared state.
        key = ApiKey(
            id=str(uuid.uuid4()),
            label=label,
            hash=self._hash_token(token),
            created_at=ApiKey.now_iso(),
            active=True,
        )
        with self._lock:
            self._publish(self._keys() + [key])
        return {"id": key.id, "token": token}

    def list(self, include_inactive: bool = True) -> List[Dict[str, object]]:
        """
        Return each key's metadata (id, label, created_at, active) as a dict.

        Hashes are never included. Inactive keys are skipped unless
        ``include_inactive`` is true.
        """
        return [
            self._public_view(k)
            for k in self._keys()
            if include_inactive or k.active
        ]

    def get(self, key_id: str) -> Optional[Dict[str, object]]:
        """Return the metadata dict (without hash) for ``key_id``, or None."""
        k = self._find(key_id)
        return None if k is None else self._public_view(k)

    def deactivate(self, key_id: str) -> bool:
        """
        Mark ``key_id`` inactive so ``verify`` ignores it.

        Returns False if no key has that id; True otherwise, including when
        it was already inactive (no write happens in that case).
        """
        return self._set_active(key_id, False)

    def activate(self, key_id: str) -> bool:
        """
        Mark ``key_id`` active again.

        Returns False if no key has that id; True otherwise, including when
        it was already active (no write happens in that case).
        """
        return self._set_active(key_id, True)

    def delete(self, key_id: str) -> bool:
        """Remove every key with ``key_id``; return True if anything was removed."""
        with self._lock:
            keys = self._keys()
            remaining = [k for k in keys if k.id != key_id]
            if len(remaining) == len(keys):
                return False
            self._publish(remaining)
            return True

    def rotate(self, key_id: str) -> Optional[Dict[str, str]]:
        """
        Mint a brand-new token for an existing key ID (keeps the label).

        The old hash is replaced, ``created_at`` is reset to now and the key
        is (re)activated. Returns {"id", "token"} or None if not found.
        """
        # Cheap lock-free pre-check so an unknown id doesn't cost a bcrypt hash.
        if self._find(key_id) is None:
            return None
        new_token = secrets.token_urlsafe(32)
        new_hash = self._hash_token(new_token)
        with self._lock:
            # Re-check under the lock: the key may have been deleted meanwhile.
            k = self._find(key_id)
            if k is None:
                return None
            rotated = replace(
                k, hash=new_hash, created_at=ApiKey.now_iso(), active=True
            )
            self._publish(self._swapped(k, rotated))
        return {"id": rotated.id, "token": new_token}

    def verify(self, token: str) -> bool:
        """
        Check whether a plaintext token matches any ACTIVE key.

        Performs one bcrypt check per active key until a match is found, so
        the cost grows with the number of active keys.
        """
        if not token:
            return False
        token_b = token.encode("utf-8")
        # Lock-free: _keys() returns a list that writers never mutate in place.
        for k in self._keys():
            if not k.active:
                continue
            try:
                if bcrypt.checkpw(token_b, k.hash.encode("utf-8")):
                    return True
            except ValueError:
                # bcrypt rejected the input (e.g. a malformed hash in the
                # file) -- treat as a non-match.
                continue
        return False

    # ---------- internals ----------

    def _hash_token(self, token: str) -> str:
        """Return the bcrypt hash of ``token`` as a UTF-8 str."""
        salt = bcrypt.gensalt(rounds=self.bcrypt_rounds)
        return bcrypt.hashpw(token.encode("utf-8"), salt).decode("utf-8")

    @staticmethod
    def _public_view(k: ApiKey) -> Dict[str, object]:
        """Return ``k`` as a dict with the hash removed (never leak hashes)."""
        d = asdict(k)
        d.pop("hash")
        return d

    def _find(self, key_id: str) -> Optional[ApiKey]:
        """Return the first key whose id equals ``key_id``, or None."""
        return next((k for k in self._keys() if k.id == key_id), None)

    def _swapped(self, old: ApiKey, new: ApiKey) -> List[ApiKey]:
        """Return a new key list with ``old`` (matched by identity) replaced by ``new``."""
        return [new if k is old else k for k in self._keys()]

    def _set_active(self, key_id: str, active: bool) -> bool:
        """Shared body of activate/deactivate; False only when the id is unknown."""
        with self._lock:
            k = self._find(key_id)
            if k is None:
                return False
            if k.active != active:
                self._publish(self._swapped(k, replace(k, active=active)))
            return True

    def _publish(self, keys: List[ApiKey]) -> None:
        """
        Persist ``keys`` and, only if that succeeds, make them the live list.

        The caller must hold ``self._lock`` and must not mutate ``keys``
        afterwards; readers may be iterating it without the lock.
        """
        self._commit(keys)
        self._data["keys"] = keys

    def _load_if_exists(self) -> None:
        """
        Load keys from ``self.path``; a missing file leaves the store empty.

        Raises ValueError for malformed content or an unsupported version.
        """
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                raw = yaml.safe_load(f) or {}
        except FileNotFoundError:
            return
        if not isinstance(raw, dict):
            raise ValueError(
                f"Malformed token store {self.path!r}: top level must be a mapping"
            )
        version = raw.get("version", 1)
        if version != 1:
            raise ValueError(f"Unsupported token store version: {version}")
        keys_raw = raw.get("keys") or []  # tolerate "keys:" with no value
        if not isinstance(keys_raw, list):
            raise ValueError(
                f"Malformed token store {self.path!r}: 'keys' must be a list"
            )
        keys: List[ApiKey] = [self._key_from_dict(item) for item in keys_raw]
        self._data = {"version": 1, "keys": keys}

    @staticmethod
    def _key_from_dict(item: object) -> ApiKey:
        """Build an ApiKey from one YAML ``keys`` entry, normalizing field types."""
        if not isinstance(item, dict):
            # Report only the type: the entry may contain a hash.
            raise ValueError(
                f"Malformed key entry of type {type(item).__name__}"
            )
        label = item.get("label")
        created = item.get("created_at")
        if created is None:
            created = ApiKey.now_iso()
        elif hasattr(created, "isoformat"):
            # yaml.safe_load turns unquoted timestamps into date/datetime
            # objects; isoformat() keeps the "T"-separated form that str()
            # would replace with a space.
            created = created.isoformat()
        active = item.get("active", True)
        if isinstance(active, str):
            # bool("false") is True -- refuse rather than silently keep the
            # key active.
            raise ValueError(
                f"Key {item.get('id')!r}: 'active' must be a boolean, got {active!r}"
            )
        return ApiKey(
            id=str(item["id"]),
            label="" if label is None else str(label),
            hash=str(item["hash"]),
            created_at=str(created),
            active=bool(active),
        )

    def _keys(self) -> List[ApiKey]:
        """Return the live key list (treat as read-only; see ``_publish``)."""
        return self._data["keys"]  # type: ignore[return-value]

    def _commit(self, keys: Optional[List[ApiKey]] = None) -> None:
        """
        Write ``keys`` (default: the live list) to ``self.path`` atomically.

        The YAML goes to a temp file in the same directory, is flushed and
        fsync'ed, then moved over the target with ``os.replace`` so readers
        of the file never see a partial write.
        """
        if keys is None:
            keys = self._keys()
        directory = os.path.dirname(os.path.abspath(self.path)) or "."
        os.makedirs(directory, exist_ok=True)
        # Convert dataclasses to serializable dicts
        payload = {
            "version": 1,
            "keys": [asdict(k) for k in keys],
            # optional: a tiny meta block to help with diffs
            "_meta": {"updated": ApiKey.now_iso(), "count": len(keys)},
        }
        # Per the tempfile docs, mkstemp creates the file readable/writable
        # only by its owner; os.replace renames that file into place.
        tmp_fd, tmp_path = tempfile.mkstemp(prefix=".apikeystore.", dir=directory)
        try:
            with io.open(tmp_fd, "w", encoding="utf-8") as tmp:
                yaml.safe_dump(payload, tmp, sort_keys=False)
                # Make the data durable before the rename publishes it.
                tmp.flush()
                os.fsync(tmp.fileno())
            os.replace(tmp_path, self.path)  # atomic on POSIX & Windows
        finally:
            # If writing or os.replace failed, remove the leftover temp file.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # best-effort cleanup; the original error propagates