"""Service layer for per-user minimum-workload settings.

Covers two tickets:

* BE-CAL-004 — reading and writing a user's workload configuration.
* BE-CAL-007 — workload warnings: total the minutes actually scheduled in the
  daily, weekly, monthly and yearly periods around a date and compare them
  with the configured minimums.
"""

from __future__ import annotations

import copy
from datetime import date, timedelta
from typing import Optional

from sqlalchemy import func as sa_func
from sqlalchemy.orm import Session

from app.models.calendar import SlotStatus, SlotType, TimeSlot
from app.models.minimum_workload import (
    CATEGORIES,
    DEFAULT_WORKLOAD_CONFIG,
    PERIODS,
    MinimumWorkload,
)
from app.schemas.calendar import (
    MinimumWorkloadConfig,
    MinimumWorkloadUpdate,
    WorkloadWarningItem,
)
from app.services.plan_slot import get_virtual_slots_for_date

# Which slot types contribute to which workload category. Any type missing
# from this table (``SlotType.SYSTEM`` in particular) contributes nothing.
_SLOT_TYPE_TO_CATEGORY: dict[SlotType, str] = {
    SlotType.WORK: "work",
    SlotType.ON_CALL: "on_call",
    SlotType.ENTERTAINMENT: "entertainment",
}

# Slot statuses whose minutes are left out of every workload total
# (slots that were skipped or aborted).
_EXCLUDED_STATUSES = {SlotStatus.SKIPPED, SlotStatus.ABORTED}


# ---------------------------------------------------------------------------
# Read
# ---------------------------------------------------------------------------

def get_workload_config(db: Session, user_id: int) -> dict:
    """Return the raw config dict for *user_id*, falling back to defaults.

    A fresh deep copy of ``DEFAULT_WORKLOAD_CONFIG`` is returned when the user
    has no row, or when the stored ``config`` is empty/``None``. This is the
    same fallback :func:`upsert_workload_config` applies before merging.
    Otherwise the row's stored ``config`` object itself is returned, not a
    copy.
    """
    row = get_workload_row(db, user_id)
    if row is None or not row.config:
        return copy.deepcopy(DEFAULT_WORKLOAD_CONFIG)
    return row.config


def get_workload_row(db: Session, user_id: int) -> Optional[MinimumWorkload]:
    """Return the ``MinimumWorkload`` ORM row for *user_id*, or ``None``."""
    return db.query(MinimumWorkload).filter(MinimumWorkload.user_id == user_id).first()


# ---------------------------------------------------------------------------
# Write (upsert)
# ---------------------------------------------------------------------------

def upsert_workload_config(
    db: Session,
    user_id: int,
    update: MinimumWorkloadUpdate,
) -> MinimumWorkload:
    """Create or update the workload config for *user_id*.

    Only the periods that are present (non-``None``) in *update* are
    overwritten. The rest keep their current values, or the defaults when the
    stored config is empty.

    A new row is added to the session if the user has none. The session is
    flushed but not committed.

    Returns:
        The created or updated ``MinimumWorkload`` row.
    """
    row = get_workload_row(db, user_id)
    if row is None:
        row = MinimumWorkload(
            user_id=user_id,
            config=copy.deepcopy(DEFAULT_WORKLOAD_CONFIG),
        )
        db.add(row)

    # Merge into a copy, then assign the copy back as a new object.
    # Re-assigning the attribute is what marks a plain JSON column as changed;
    # in-place mutation is not tracked unless a mutable-tracking type is used.
    if row.config:
        current = copy.deepcopy(row.config)
    else:
        current = copy.deepcopy(DEFAULT_WORKLOAD_CONFIG)
    for period in PERIODS:
        period_data = getattr(update, period, None)
        if period_data is not None:
            current[period] = period_data.model_dump()

    row.config = current
    db.flush()
    return row


def replace_workload_config(
    db: Session,
    user_id: int,
    config: MinimumWorkloadConfig,
) -> MinimumWorkload:
    """Fully replace the workload config for *user_id*.

    Creates the row if it does not exist. The session is flushed but not
    committed.

    Returns:
        The created or updated ``MinimumWorkload`` row.
    """
    row = get_workload_row(db, user_id)
    if row is None:
        row = MinimumWorkload(user_id=user_id, config=config.model_dump())
        db.add(row)
    else:
        row.config = config.model_dump()
    db.flush()
    return row


# ---------------------------------------------------------------------------
# Workload computation (BE-CAL-007)
# ---------------------------------------------------------------------------

def _date_range_for_period(
    period: str,
    reference_date: date,
) -> tuple[date, date]:
    """Return inclusive ``(start, end)`` date bounds for *period* containing *reference_date*.

    - daily   → just the reference date itself
    - weekly  → the Monday–Sunday week containing the reference date
    - monthly → calendar month containing the reference date
    - yearly  → calendar year containing the reference date

    Raises:
        ValueError: if *period* is not one of the four names above.
    """
    if period == "daily":
        return reference_date, reference_date

    if period == "weekly":
        # date.weekday() is 0-based (Monday=0 … Sunday=6), so subtracting it
        # lands on the Monday of the same week.
        start = reference_date - timedelta(days=reference_date.weekday())
        end = start + timedelta(days=6)  # Sunday
        return start, end

    if period == "monthly":
        start = reference_date.replace(day=1)
        # Last day of month = (first day of next month) - 1 day; December has
        # no "next month" within the same year, so it is special-cased.
        if reference_date.month == 12:
            end = reference_date.replace(month=12, day=31)
        else:
            end = reference_date.replace(month=reference_date.month + 1, day=1) - timedelta(days=1)
        return start, end

    if period == "yearly":
        start = reference_date.replace(month=1, day=1)
        end = reference_date.replace(month=12, day=31)
        return start, end

    raise ValueError(f"Unknown period: {period}")


def _empty_totals() -> dict[str, int]:
    """Return a fresh zeroed ``{"work", "on_call", "entertainment"}`` totals dict."""
    return {"work": 0, "on_call": 0, "entertainment": 0}


def _category_for_slot_type(slot_type) -> Optional[str]:
    """Map a slot type (a ``SlotType`` member or its raw value) to a category.

    Returns ``None`` for types absent from ``_SLOT_TYPE_TO_CATEGORY``, such as
    ``system``. Raises ``ValueError`` (from ``SlotType(...)``) for a value that
    is not a valid ``SlotType``.
    """
    raw = slot_type.value if hasattr(slot_type, "value") else slot_type
    return _SLOT_TYPE_TO_CATEGORY.get(SlotType(raw))


def _sum_real_slots(
    db: Session,
    user_id: int,
    start_date: date,
    end_date: date,
) -> dict[str, int]:
    """Sum ``estimated_duration`` of real (materialized) slots by category.

    Returns ``{"work": N, "on_call": N, "entertainment": N}``, in minutes.
    Slots with a status in ``_EXCLUDED_STATUSES``, or with
    ``slot_type=system``, are skipped. NULL durations are ignored by SQL
    ``SUM``.

    NOTE(review): SQL ``NOT IN`` also drops rows whose ``status`` is NULL.
    Confirm that the column is non-nullable, or that this is intended.
    """
    excluded = [status.value for status in _EXCLUDED_STATUSES]
    rows = (
        db.query(
            TimeSlot.slot_type,
            sa_func.coalesce(sa_func.sum(TimeSlot.estimated_duration), 0),
        )
        .filter(
            TimeSlot.user_id == user_id,
            TimeSlot.date >= start_date,
            TimeSlot.date <= end_date,
            TimeSlot.status.notin_(excluded),
            TimeSlot.slot_type != SlotType.SYSTEM.value,
        )
        .group_by(TimeSlot.slot_type)
        .all()
    )

    totals = _empty_totals()
    for slot_type_val, total in rows:
        category = _category_for_slot_type(slot_type_val)
        if category:
            # int() normalises numeric types the DB driver may return for SUM.
            totals[category] += int(total)
    return totals


def _sum_virtual_slots(
    db: Session,
    user_id: int,
    start_date: date,
    end_date: date,
) -> dict[str, int]:
    """Sum ``estimated_duration`` of virtual (plan-generated) slots by category.

    Virtual slots are the not-yet-materialized slots generated from plans.
    The function walks the inclusive date range one day at a time and calls
    :func:`get_virtual_slots_for_date` once per day. That is acceptable
    because a period is at most one year long.

    A ``None`` duration counts as 0. This matches the real-slot path, where
    SQL ``SUM`` ignores NULLs.
    """
    totals = _empty_totals()
    current = start_date
    while current <= end_date:
        for vs in get_virtual_slots_for_date(db, user_id, current):
            category = _category_for_slot_type(vs["slot_type"])
            if category:
                totals[category] += vs["estimated_duration"] or 0
        current += timedelta(days=1)
    return totals


def compute_scheduled_minutes(
    db: Session,
    user_id: int,
    reference_date: date,
) -> dict[str, dict[str, int]]:
    """Compute total scheduled minutes for each period containing *reference_date*.

    Returns the canonical shape consumed by :func:`check_workload_warnings`::

        {
            "daily":   {"work": N, "on_call": N, "entertainment": N},
            "weekly":  { ... },
            "monthly": { ... },
            "yearly":  { ... },
        }

    Includes both real (materialized) and virtual (plan-generated) slots.
    """
    result: dict[str, dict[str, int]] = {}
    for period in PERIODS:
        start, end = _date_range_for_period(period, reference_date)
        real = _sum_real_slots(db, user_id, start, end)
        virtual = _sum_virtual_slots(db, user_id, start, end)
        result[period] = {
            cat: real.get(cat, 0) + virtual.get(cat, 0)
            for cat in CATEGORIES
        }
    return result


# ---------------------------------------------------------------------------
# Warning comparison
# ---------------------------------------------------------------------------

def check_workload_warnings(
    db: Session,
    user_id: int,
    scheduled_minutes: dict[str, dict[str, int]],
) -> list[WorkloadWarningItem]:
    """Compare *scheduled_minutes* against the user's configured thresholds.

    ``scheduled_minutes`` has the same shape as the config::

        {"daily": {"work": N, ...}, "weekly": {...}, ...}

    Returns one warning for every (period, category) whose scheduled total is
    below the configured minimum. An empty list means no warnings. A minimum
    that is missing, ``None`` or <= 0 means "no minimum" for that pair.
    """
    config = get_workload_config(db, user_id)
    warnings: list[WorkloadWarningItem] = []

    for period in PERIODS:
        # `or {}` / `or 0` guard against explicit nulls in the stored JSON;
        # without them `None.get` / `None <= 0` would raise TypeError.
        cfg_period = config.get(period) or {}
        sch_period = scheduled_minutes.get(period, {})
        for cat in CATEGORIES:
            minimum = cfg_period.get(cat) or 0
            if minimum <= 0:
                continue
            current = sch_period.get(cat, 0)
            if current < minimum:
                shortfall = minimum - current
                warnings.append(WorkloadWarningItem(
                    period=period,
                    category=cat,
                    current_minutes=current,
                    minimum_minutes=minimum,
                    shortfall_minutes=shortfall,
                    message=(
                        f"{period.capitalize()} {cat.replace('_', '-')} workload "
                        f"is {current} min, below minimum of {minimum} min "
                        f"(shortfall: {shortfall} min)"
                    ),
                ))

    return warnings


# ---------------------------------------------------------------------------
# High-level convenience: compute + check in one call (BE-CAL-007)
# ---------------------------------------------------------------------------

def get_workload_warnings_for_date(
    db: Session,
    user_id: int,
    reference_date: date,
) -> list[WorkloadWarningItem]:
    """Compute scheduled minutes for *reference_date* and return any warnings.

    This is a one-shot helper. Calendar API endpoints should call it after a
    create/edit mutation so the response can include warnings. Warnings are
    advisory and do NOT prevent the operation.
    """
    scheduled = compute_scheduled_minutes(db, user_id, reference_date)
    return check_workload_warnings(db, user_id, scheduled)