# Source listing metadata: 95 lines, 3.6 KiB, Python
"""Tests for the SeedService class."""
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
import bcrypt
|
|
from sqlmodel import SQLModel, Session, create_engine, select
|
|
|
|
from app.models.user import User
|
|
from app.models.exercise import Exercise
|
|
from app.models.warmup import Warmup
|
|
from app.models.workout_day import WorkoutDay
|
|
from app.models.user_exercise_program import UserExerciseProgram
|
|
from app.services.seed_service import SeedService
|
|
|
|
|
|
class TestSeedService:
    """Tests for YAML-based database seeding.

    Every test builds its own in-memory SQLite database through ``_setup`` and
    runs its body inside ``with session:`` so the session is closed even when
    an assertion fails.
    """

    # Environment variables patched in while seeding the admin account.
    _ADMIN_ENV = {
        "ADMIN_USERNAME": "admin",
        "ADMIN_PASSWORD": "testpass",
    }

    def _setup(self):
        """Create an in-memory DB and seed service.

        Returns:
            A ``(session, service)`` tuple: an open ``Session`` bound to a
            fresh in-memory SQLite engine with all ``SQLModel`` tables
            created, and a ``SeedService`` using that session. The caller is
            responsible for closing the session.
        """
        engine = create_engine("sqlite:///:memory:")
        SQLModel.metadata.create_all(engine)
        session = Session(engine)
        # Use the real config files from the project
        config_dir = Path(__file__).resolve().parent.parent / "config"
        service = SeedService(session, config_dir=config_dir)
        return session, service

    def test_seed_workout_days(self) -> None:
        """seed_workout_days should create 4 workout day records."""
        session, service = self._setup()
        with session:
            service.seed_workout_days()
            days = session.exec(select(WorkoutDay)).all()
            assert len(days) == 4
            names = {d.name for d in days}
            assert names == {"Push", "Pull", "Lower", "Full Body"}

    def test_seed_exercises_from_yaml(self) -> None:
        """seed_exercises should load all exercises from exercises.yaml."""
        session, service = self._setup()
        with session:
            service.seed_exercises()
            exercises = session.exec(select(Exercise)).all()
            assert len(exercises) == 20  # 5 Push + 5 Pull + 5 Lower + 5 Full Body

    def test_seed_warmups_from_yaml(self) -> None:
        """seed_warmups should load all warmups from exercises.yaml."""
        session, service = self._setup()
        with session:
            service.seed_warmups()
            warmups = session.exec(select(Warmup)).all()
            assert len(warmups) == 6

    def test_seed_admin_user(self) -> None:
        """seed_admin should create admin user with hashed password."""
        session, service = self._setup()
        with session:
            with patch.dict("os.environ", self._ADMIN_ENV):
                service.seed_admin()
            admin = session.exec(
                select(User).where(User.is_admin == True)  # noqa: E712
            ).first()
            assert admin is not None
            assert admin.username == "admin"
            # The stored hash must verify against the plaintext password.
            assert bcrypt.checkpw(b"testpass", admin.password_hash.encode())

    def test_seed_user_programs(self) -> None:
        """seed_user_programs should create user profiles and link exercises."""
        session, service = self._setup()
        with session:
            # Exercises are seeded first, then the user programs.
            service.seed_exercises()
            service.seed_user_programs()
            programs = session.exec(select(UserExerciseProgram)).all()
            assert len(programs) > 0
            users = session.exec(
                select(User).where(User.is_admin == False)  # noqa: E712
            ).all()
            assert len(users) == 2  # Phillip and Daughter

    def test_seed_is_idempotent(self) -> None:
        """Running seed_all twice should not create duplicate records."""
        session, service = self._setup()
        with session:
            with patch.dict("os.environ", self._ADMIN_ENV):
                service.seed_all()
                # Baseline after one pass; a second pass must not add users.
                users_after_first = len(session.exec(select(User)).all())
                service.seed_all()
            users = session.exec(select(User)).all()
            exercises = session.exec(select(Exercise)).all()
            # Should not be doubled
            assert len(exercises) == 20
            assert len(users) == users_after_first
|