Split modules, fix tests, fix pre-commit batching

- steam_backlog_enforcer: extract _hltb_search.py and _scanning_confidence.py;
  split oversized test files into *_part2/3/4.py
- screen_locker: extract _early_bird.py and _window_setup.py from screen_lock.py;
  fix patch targets in tests (screen_lock.* -> _window_setup.*)
- wake_alarm: use shutil.which('xset') to avoid S607; add TestDisplayHelpers tests
- linux_configuration/usage_report: split into _parsing.py and _types.py;
  add bin/__init__.py (INP001); fix RUF002 (× -> x)
- pre-commit: add require_serial: true to pytest-coverage hook to prevent
  file batching across 24 CPU cores (was causing 12 parallel partial-coverage runs)
This commit is contained in:
Krzysztof kuhy Rudnicki 2026-05-22 22:48:28 +02:00
parent 0d54c5d418
commit 0c1e395008
47 changed files with 5389 additions and 4265 deletions

View File

@ -192,7 +192,7 @@ repos:
args:
- --rcfile=pyproject.toml
- --fail-under=8.0
- --jobs=0
- --jobs=4
additional_dependencies:
- pytest
- python-chess
@ -231,6 +231,7 @@ repos:
language: system
types: [python]
pass_filenames: true
require_serial: true
stages: [pre-commit]
# ===========================================================================

View File

@ -0,0 +1,17 @@
{
"title": "steam_backlog_enforcer: split hltb/scanning helpers into private submodules",
"objective": "Reduce hltb.py and scanning.py below the 500-line budget by extracting private helper functions into _hltb_search.py and _scanning_confidence.py respectively. Fix all 23 test failures introduced by the split (broken imports and stale patch targets in tests).",
"acceptance_criteria": [
"hltb.py and scanning.py are each under 500 lines",
"_hltb_search.py and _scanning_confidence.py contain the extracted helpers",
"All 622 tests pass with 100% branch coverage",
"ruff, mypy, pylint, bandit all pass with zero suppressions",
"No behavioral change to the running bot"
],
"out_of_scope": [
"Public API changes to steam_backlog_enforcer",
"Refactoring of other modules (main.py, library_hider.py, etc.)",
"Performance improvements"
],
"verifier": "python -m pytest python_pkg/steam_backlog_enforcer/tests/ --cov=python_pkg.steam_backlog_enforcer --cov-branch --cov-fail-under=100 && pre-commit run --files $(git diff --cached --name-only | tr '\\n' ' ')"
}

View File

@ -0,0 +1,41 @@
{
"intent": "Split large hltb.py and scanning.py modules into private helper submodules (_hltb_search.py, _scanning_confidence.py) to reduce file size below the 500-line budget. Fix all 23 test failures caused by the module split (broken imports and stale patch targets).",
"scope": [
"python_pkg/steam_backlog_enforcer/hltb.py",
"python_pkg/steam_backlog_enforcer/_hltb_search.py (new)",
"python_pkg/steam_backlog_enforcer/scanning.py",
"python_pkg/steam_backlog_enforcer/_scanning_confidence.py (new)",
"python_pkg/steam_backlog_enforcer/_cmd_done.py",
"python_pkg/steam_backlog_enforcer/tests/ (23 test files updated)",
"No behavioral change; pure structural refactor"
],
"changes": [
"Extracted hltb.py private helpers (_AuthInfo, _build_search_payload, _fetch_batch, _get_auth_info, _get_hltb_search_url, _pick_best_hltb_entry, _search_one, _SearchCtx, _similarity) into new _hltb_search.py",
"Extracted scanning.py confidence helpers (_apply_cached_confidence_to_candidates, _backfill_polls_for_finished, _candidate_passes_hltb_confidence, _confidence_fail_reasons, _filter_hltb_confident_candidates, _force_refresh_candidate_confidence, _refresh_candidate_confidence, _refresh_candidate_confidence_batch, _report_poll_confidence) into new _scanning_confidence.py",
"Refactored pick_next_game() into 6 helper functions (_sort_key, _collect_qualified_candidates, _prompt_user_pick, _assign_chosen_game + constants) to satisfy ruff cognitive complexity limits",
"Updated _cmd_done.py to import _confidence_fail_reasons and _refresh_candidate_confidence from _scanning_confidence directly",
"Updated all test files to import symbols from the defining module (_hltb_search, _scanning_confidence) rather than re-export locations (hltb, scanning)",
"Updated all patch targets in tests to reference the defining module namespace (e.g. _scanning_confidence._echo instead of scanning._echo)"
],
"verification": [
{
"command": "python -m pytest python_pkg/steam_backlog_enforcer/tests/ -x -q --tb=short",
"result": "pass",
"evidence": "622 passed in ~60s; 0 failures; all 23 previously failing tests now pass"
},
{
"command": "pre-commit run --files $(git diff --cached --name-only | tr '\\n' ' ')",
"result": "pass",
"evidence": "ruff, ruff-format, mypy, pylint, bandit, pytest-coverage all passed; only contract/evidence artifact hooks pending (pre-existing requirement)"
}
],
"risks": [
"Private submodule naming with _ prefix signals internal-only; if external callers imported from hltb/scanning directly they would break — no such callers exist in this repo",
"patch() targets in tests must reference the defining module; any future move of helpers requires updating patch paths again"
],
"rollback": [
"git revert the commit to restore hltb.py and scanning.py as single-file modules",
"Delete _hltb_search.py and _scanning_confidence.py",
"Verify 622 tests still pass after rollback"
]
}

View File

@ -0,0 +1 @@
"""Helpers for usage_report: parsing and type definitions."""

View File

@ -0,0 +1,416 @@
"""atop + pmon log parsing and aggregation helpers for usage_report."""
from __future__ import annotations
import contextlib
import datetime as _dt
from pathlib import Path
import shutil
import subprocess
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterator
from _usage_report_types import (
_MIN_SAMPLES_FOR_WINDOW,
GpuAgg,
ProcAgg,
_PidCpu,
_PidRam,
_Progress,
_Window,
)
# atop parseable output layout (atop 2.x, same on Arch/Debian/Ubuntu):
# 0 label, 1 host, 2 epoch, 3 YYYY/MM/DD, 4 HH:MM:SS, 5 interval_s,
# then per-process fields starting at index 6.
# PRC per-proc: pid name(parens) state utime_ticks stime_ticks ...
_PRC_PID_IDX = 6
_PRC_NAME_IDX = 7
_PRC_MIN_LEN = 11
# PRM per-proc: pid name state pagesz_b vsize_kb rsize_kb ...
_PRM_PID_IDX = 6
_PRM_NAME_IDX = 7
_PRM_MIN_LEN = 12
_PMON_MIN_FIELDS = 11
_CPU_RECORD_MIN_LEN = 5
_PAREN_PAIR_MIN = 2
_ATOP_AGG_CACHE_BIN = Path.home() / ".cache" / "usage_report" / "atop_agg"
_ATOP_AGG_BIN_MODE = 0o755
# Repo layout: linux_configuration/scripts/system-maintenance/bin/usage_report.py
# -> parents[4] is the repo root which hosts the C/ source tree.
_ATOP_AGG_SRC_DIR = Path(__file__).resolve().parents[4] / "C" / "atop_agg"
_ATOP_AGG_BUILD_TIMEOUT_S = 60
_NATIVE_TSV_NAME_LEN = 7
_NATIVE_TSV_WIN_LEN = 5
def _run(cmd: list[str]) -> str:
"""Run *cmd* and return stdout (empty string on failure)."""
try:
proc = subprocess.run(
cmd,
capture_output=True,
text=True,
check=False,
timeout=60,
)
except (OSError, subprocess.TimeoutExpired):
return ""
return proc.stdout
def _iter_atop_lines(log: Path, labels: str) -> Iterator[str]:
"""Stream `atop -r LOG -P LABELS` stdout line-by-line.
Uses `Popen` so the report can show progress while atop is still
decoding its binary log, rather than buffering the whole output.
"""
cmd = ["atop", "-r", str(log), "-P", labels]
with subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
text=True,
) as proc:
stdout = proc.stdout
if stdout is None:
return
for raw in stdout:
yield raw.rstrip("\n")
def _parse_name(parts: list[str], name_idx: int) -> tuple[str, int]:
"""Extract `(name, next_index)` from atop parseable output.
atop wraps process names in parentheses and the name itself may contain
spaces, so we rejoin until we hit the closing `)`. Fast-paths the common
case where the name is a single token (no embedded spaces).
"""
if name_idx >= len(parts):
return "unknown", name_idx + 1
token = parts[name_idx]
# Fast path: `(name)` fully in one token.
if len(token) >= _PAREN_PAIR_MIN and token[0] == "(" and token[-1] == ")":
return token[1:-1] or "unknown", name_idx + 1
if token.startswith("("):
buf = [token]
idx = name_idx
while not buf[-1].endswith(")") and idx + 1 < len(parts):
idx += 1
buf.append(parts[idx])
name = " ".join(buf)[1:-1] or "unknown"
return name, idx + 1
return token, name_idx + 1
def _parse_prc(parts: list[str], pid_cpu: dict[int, _PidCpu]) -> None:
"""Fold one PRC record into the per-PID CPU-ticks map."""
try:
pid = int(parts[_PRC_PID_IDX])
except (ValueError, IndexError):
return
name, after = _parse_name(parts, _PRC_NAME_IDX)
# After name comes: state utime stime ...
try:
utime = int(parts[after + 1])
stime = int(parts[after + 2])
except (ValueError, IndexError):
return
pid_cpu.setdefault(pid, _PidCpu()).observe(name, utime + stime)
def _parse_prm(parts: list[str], pid_ram: dict[int, _PidRam]) -> None:
"""Fold one PRM record into the per-PID RSS map."""
try:
pid = int(parts[_PRM_PID_IDX])
except (ValueError, IndexError):
return
name, after = _parse_name(parts, _PRM_NAME_IDX)
# After name: state pagesz_b vsize_kb rsize_kb ...
try:
rsize_kb = int(parts[after + 3])
except (ValueError, IndexError):
return
pid_ram.setdefault(pid, _PidRam()).observe(name, rsize_kb)
def _window_from_epochs(epochs: set[int]) -> _Window:
"""Build a `_Window` from a set of sample epoch timestamps."""
if not epochs:
return _Window()
ordered = sorted(epochs)
start_dt = _dt.datetime.fromtimestamp(ordered[0]).astimezone()
end_dt = _dt.datetime.fromtimestamp(ordered[-1]).astimezone()
interval = 0
if len(ordered) >= _MIN_SAMPLES_FOR_WINDOW:
deltas = sorted(ordered[i + 1] - ordered[i] for i in range(len(ordered) - 1))
interval = deltas[len(deltas) // 2]
return _Window(
start=start_dt.isoformat(timespec="seconds"),
end=end_dt.isoformat(timespec="seconds"),
distinct_samples=len(ordered),
interval_s=interval,
seconds=ordered[-1] - ordered[0],
)
def _atop_agg_binary() -> Path | None:
"""Return a cached `atop_agg` binary path, auto-building if missing/stale.
Falls back to ``None`` when the C source tree or a system C compiler
is unavailable, in which case callers use the pure-Python parser.
"""
src_c = _ATOP_AGG_SRC_DIR / "atop_agg.c"
if _ATOP_AGG_CACHE_BIN.exists() and (
not src_c.exists()
or src_c.stat().st_mtime <= _ATOP_AGG_CACHE_BIN.stat().st_mtime
):
return _ATOP_AGG_CACHE_BIN
if not src_c.exists() or shutil.which("cc") is None:
return None
_ATOP_AGG_CACHE_BIN.parent.mkdir(parents=True, exist_ok=True)
make_cmd = ["make", "-s", "-C", str(_ATOP_AGG_SRC_DIR), "atop_agg"]
try:
subprocess.run(
make_cmd,
check=True,
capture_output=True,
timeout=_ATOP_AGG_BUILD_TIMEOUT_S,
)
except (OSError, subprocess.SubprocessError):
return None
built = _ATOP_AGG_SRC_DIR / "atop_agg"
if not built.exists():
return None
shutil.copy2(built, _ATOP_AGG_CACHE_BIN)
_ATOP_AGG_CACHE_BIN.chmod(_ATOP_AGG_BIN_MODE)
return _ATOP_AGG_CACHE_BIN
def _apply_native_name(parts: list[str], agg_map: dict[str, ProcAgg]) -> None:
r"""Fold one `N\t<name>\t<cpu>\t<peak>\t<sum_avg>\t<ram_n>\t<pids>` row."""
_, name, cpu_s, peak_s, sum_avg_s, rss_n_s, pids_s = parts
entry = agg_map.setdefault(name, ProcAgg(name=name))
entry.cpu_ticks = int(cpu_s)
entry.peak_rss_kb = int(peak_s)
entry.rss_kb_sum = int(sum_avg_s)
entry.rss_samples = int(rss_n_s)
# The C helper pre-aggregates by name; pid_set is unused in the native
# path but `len(pid_set)` drives the "PIDs" column in the report.
entry.pid_set = set(range(int(pids_s)))
def _window_from_native(parts: list[str]) -> _Window:
r"""Build a `_Window` from a `W\t<start>\t<end>\t<n>\t<interval>` row."""
_, start_s, end_s, n_s, interval_s = parts
n_epochs = int(n_s)
if not n_epochs:
return _Window()
start_epoch = int(start_s)
end_epoch = int(end_s)
start_dt = _dt.datetime.fromtimestamp(start_epoch).astimezone()
end_dt = _dt.datetime.fromtimestamp(end_epoch).astimezone()
return _Window(
start=start_dt.isoformat(timespec="seconds"),
end=end_dt.isoformat(timespec="seconds"),
distinct_samples=n_epochs,
interval_s=int(interval_s),
seconds=end_epoch - start_epoch,
)
def _aggregate_atop_native(
log: Path,
progress: _Progress,
binary: Path,
) -> tuple[dict[str, ProcAgg], _Window]:
"""Aggregate via `atop | atop_agg`; return `(by_name, window)`."""
progress.start_stage("atop: parse PRC+PRM (native)")
agg_map: dict[str, ProcAgg] = {}
window = _Window()
atop_cmd = ["atop", "-r", str(log), "-P", "PRC,PRM"]
agg_cmd = [str(binary)]
with (
subprocess.Popen(
atop_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
) as atop,
subprocess.Popen(
agg_cmd,
stdin=atop.stdout,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
text=True,
) as agg,
):
if atop.stdout is not None:
atop.stdout.close()
stdout = agg.stdout
if stdout is None:
return agg_map, window
for raw in stdout:
parts = raw.rstrip("\n").split("\t")
tag = parts[0]
if tag == "N" and len(parts) == _NATIVE_TSV_NAME_LEN:
_apply_native_name(parts, agg_map)
elif tag == "W" and len(parts) == _NATIVE_TSV_WIN_LEN:
window = _window_from_native(parts)
progress.update(1.0)
return agg_map, window
def aggregate_atop(
log: Path,
progress: _Progress,
) -> tuple[dict[str, ProcAgg], _Window]:
"""Stream PRC+PRM records, fold them into `{name: ProcAgg}`, return window.
Prefers the native `atop_agg` C helper (auto-built into
``~/.cache/usage_report/``) for ~7x speedup on full-day logs, falling
back to an inline Python parser when the helper is unavailable.
"""
binary = _atop_agg_binary()
if binary is not None:
return _aggregate_atop_native(log, progress, binary)
progress.start_stage("atop: parse PRC+PRM")
pid_cpu: dict[int, _PidCpu] = {}
pid_ram: dict[int, _PidRam] = {}
epochs: set[int] = set()
log_size = max(log.stat().st_size, 1)
bytes_seen = 0
# Empirical: `atop -P PRC,PRM` stdout is ~11x the binary log size on a
# 10-min-interval log. The fraction is only used for the progress bar,
# so a rough calibration is fine; it caps at 99% if we underestimate.
est_total_bytes = log_size * 11 or 1
for raw in _iter_atop_lines(log, "PRC,PRM"):
bytes_seen += len(raw) + 1
if not raw or raw[0] == "#" or raw.startswith("RESET") or raw == "SEP":
continue
parts = raw.split()
if not parts:
continue
label = parts[0]
if label == "PRC" and len(parts) >= _PRC_MIN_LEN:
with contextlib.suppress(ValueError):
# atop always emits an integer epoch here; guard is defensive.
epochs.add(int(parts[2]))
progress.update(min(bytes_seen / est_total_bytes, 0.99))
_parse_prc(parts, pid_cpu)
elif label == "PRM" and len(parts) >= _PRM_MIN_LEN:
_parse_prm(parts, pid_ram)
progress.update(1.0)
return _fold_pid_aggregates(pid_cpu, pid_ram), _window_from_epochs(epochs)
def _fold_pid_aggregates(
pid_cpu: dict[int, _PidCpu],
pid_ram: dict[int, _PidRam],
) -> dict[str, ProcAgg]:
"""Collapse per-PID CPU/RAM trackers into per-program `ProcAgg` entries."""
agg: dict[str, ProcAgg] = {}
for pid, cpu in pid_cpu.items():
entry = agg.setdefault(cpu.name, ProcAgg(name=cpu.name))
entry.cpu_ticks += cpu.delta_ticks
entry.pid_set.add(pid)
for pid, ram in pid_ram.items():
entry = agg.setdefault(ram.name, ProcAgg(name=ram.name))
entry.peak_rss_kb = max(entry.peak_rss_kb, ram.peak_kb)
entry.rss_kb_sum += int(ram.avg_kb)
entry.rss_samples += 1
entry.pid_set.add(pid)
return agg
def _pmon_fields(line: str) -> list[str] | None:
"""Return stripped fields of a pmon data line, or None for headers/blanks."""
s = line.strip()
if not s or s.startswith("#"):
return None
return s.split()
def _normalize_pmon_command(command_fields: list[str]) -> str:
"""Normalize pmon command fields into a stable process-ish name.
`nvidia-smi pmon -o DT` emits fixed numeric columns followed by a command
field that can include whitespace. We prefer the *first* non-option token
(usually executable) and normalize it to a basename.
"""
tokens = [token.strip().strip("\"'") for token in command_fields if token.strip()]
if not tokens:
return "unknown"
selected = tokens[0]
if selected.startswith("-"):
for candidate in tokens[1:]:
if not candidate.startswith("-"):
selected = candidate
break
name = Path(selected).name.strip(";,:")
if not name:
return "unknown"
return name
def _pid_comm_name(pid: int) -> str | None:
"""Return `/proc/<pid>/comm` basename when available."""
try:
comm = Path(f"/proc/{pid}/comm").read_text(encoding="utf-8").strip()
except OSError:
return None
return Path(comm).name if comm else None
def aggregate_pmon(
log: Path,
progress: _Progress,
) -> tuple[dict[str, GpuAgg], int]:
"""Return `({program: GpuAgg}, sample_count)` from the pmon *log*."""
progress.start_stage("pmon log scan")
agg: dict[str, GpuAgg] = {}
samples = 0
if not log.exists():
progress.update(1.0)
return agg, 0
total_bytes = max(log.stat().st_size, 1)
bytes_read = 0
with log.open(encoding="utf-8") as fh:
for line in fh:
bytes_read += len(line)
progress.update(min(bytes_read / total_bytes, 0.99))
parts = _pmon_fields(line)
if parts is None or len(parts) < _PMON_MIN_FIELDS:
continue
samples += _ingest_pmon_row(parts, agg)
progress.update(1.0)
return agg, samples
def _ingest_pmon_row(parts: list[str], agg: dict[str, GpuAgg]) -> int:
"""Fold a single pmon data row into *agg*; return 1 if consumed else 0."""
# pmon -o DT fields:
# date time gpu pid type sm mem enc dec jpg ofa command
try:
pid = int(parts[3])
except ValueError:
return 0
sm_raw = parts[5]
mem_raw = parts[6]
command_fields = parts[11:]
name = _normalize_pmon_command(command_fields)
if name == "unknown":
name = _pid_comm_name(pid) or "unknown"
sm = float(sm_raw) if sm_raw != "-" else 0.0
mem = float(mem_raw) if mem_raw != "-" else 0.0
entry = agg.setdefault(name, GpuAgg(name=name))
entry.sm_pct_sum += sm
entry.mem_pct_sum += mem
entry.samples += 1
entry.pid_set.add(pid)
entry.peak_sm_pct = max(entry.peak_sm_pct, sm)
entry.peak_mem_pct = max(entry.peak_mem_pct, mem)
return 1

View File

@ -0,0 +1,192 @@
"""Shared data-class types and progress reporter for usage_report."""
from __future__ import annotations
from dataclasses import dataclass, field
import os
import sys
import time as _time
_HZ = os.sysconf("SC_CLK_TCK") if hasattr(os, "sysconf") else 100
_MIN_SAMPLES_FOR_WINDOW = 2
# Default pmon interval is 10 s (matches the systemd service we set up).
_PMON_INTERVAL_S = 10
_PROGRESS_MIN_UPDATE_S = 0.1
_ETA_MIN_FRACTION = 0.01
@dataclass
class _PidCpu:
"""Per-PID cumulative-ticks tracker across atop samples."""
name: str = ""
first_ticks: int = -1
last_ticks: int = 0
samples: int = 0
def observe(self, name: str, ticks: int) -> None:
"""Record one observation for this PID."""
self.name = name # last-seen name wins (stable for one PID)
if self.first_ticks < 0:
self.first_ticks = ticks
self.last_ticks = ticks
self.samples += 1
@property
def delta_ticks(self) -> int:
"""CPU ticks consumed during the observation window.
For PIDs seen in >=2 samples the value is `last - first`, which is the
actual CPU consumed between the first and last atop tick. For PIDs seen
only once (short-lived processes that existed during exactly one tick)
the cumulative value itself is used this is close to the true
lifetime cost for a short-lived process.
"""
if self.samples >= _MIN_SAMPLES_FOR_WINDOW:
return max(self.last_ticks - self.first_ticks, 0)
return self.last_ticks
@dataclass
class _PidRam:
"""Per-PID peak/avg RSS tracker across atop samples."""
name: str = ""
peak_kb: int = 0
sum_kb: int = 0
samples: int = 0
def observe(self, name: str, rss_kb: int) -> None:
"""Record one RSS observation for this PID."""
self.name = name
self.peak_kb = max(self.peak_kb, rss_kb)
self.sum_kb += rss_kb
self.samples += 1
@property
def avg_kb(self) -> float:
"""Mean RSS across the samples where this PID appeared."""
return self.sum_kb / self.samples if self.samples else 0.0
@dataclass
class ProcAgg:
"""Aggregated metrics for one program name across all atop samples."""
name: str
cpu_ticks: int = 0
peak_rss_kb: int = 0
rss_kb_sum: int = 0
rss_samples: int = 0
pid_set: set[int] = field(default_factory=set)
@property
def cpu_seconds(self) -> float:
"""CPU-seconds consumed (sum of user + system time)."""
return self.cpu_ticks / _HZ
@property
def peak_rss_mb(self) -> float:
"""Peak resident memory observed across the window, in MiB."""
return self.peak_rss_kb / 1024
@property
def avg_rss_mb(self) -> float:
"""Average resident memory across samples where the program appeared."""
if not self.rss_samples:
return 0.0
return (self.rss_kb_sum / self.rss_samples) / 1024
@dataclass
class GpuAgg:
"""Aggregated GPU metrics for one program name from pmon logs."""
name: str
sm_pct_sum: float = 0.0
mem_pct_sum: float = 0.0
samples: int = 0
peak_sm_pct: float = 0.0
peak_mem_pct: float = 0.0
pid_set: set[int] = field(default_factory=set)
@property
def gpu_seconds(self) -> float:
"""SM-seconds (single-GPU equivalent); sm% * seconds_per_sample / 100."""
return self.sm_pct_sum * _PMON_INTERVAL_S / 100.0
@property
def avg_sm_pct(self) -> float:
"""Mean SM utilization across samples where the process was present."""
if not self.samples:
return 0.0
return self.sm_pct_sum / self.samples
class _Progress:
"""Minimal stage+percent+ETA reporter on stderr.
Disabled automatically when stderr is not a TTY or when the caller
constructs with `enabled=False`, so redirected output stays clean.
"""
def __init__(self, *, enabled: bool, total_stages: int) -> None:
self._enabled = enabled and sys.stderr.isatty()
self._total_stages = total_stages
self._stage_idx = 0
self._stage_label = ""
self._stage_start = 0.0
self._t0 = _time.monotonic()
self._last_draw = 0.0
self._max_width = 0
def start_stage(self, label: str) -> None:
"""Begin a new stage with its human label."""
self._stage_idx += 1
self._stage_label = label
self._stage_start = _time.monotonic()
self.update(0.0)
def update(self, fraction: float) -> None:
"""Redraw the progress line for the current stage (0.0..1.0)."""
if not self._enabled:
return
now = _time.monotonic()
if now - self._last_draw < _PROGRESS_MIN_UPDATE_S and fraction < 1.0:
return
self._last_draw = now
elapsed = now - self._stage_start
pct = max(0.0, min(fraction, 1.0))
if pct > _ETA_MIN_FRACTION:
eta = elapsed * (1 - pct) / pct
eta_str = f"~{eta:4.1f}s left"
else:
eta_str = "estimating…"
msg = (
f"[{self._stage_idx}/{self._total_stages}] "
f"{self._stage_label:<22} {pct * 100:5.1f}% "
f"{elapsed:5.1f}s elapsed, {eta_str}"
)
self._max_width = max(self._max_width, len(msg))
sys.stderr.write("\r" + msg.ljust(self._max_width))
sys.stderr.flush()
def finish(self) -> None:
"""Clear the progress line and print total elapsed time."""
if not self._enabled:
return
total = _time.monotonic() - self._t0
sys.stderr.write("\r" + " " * self._max_width + "\r")
sys.stderr.write(f"done in {total:.1f}s\n")
sys.stderr.flush()
@dataclass
class _Window:
"""Observed atop coverage window."""
start: str = "n/a"
end: str = "n/a"
distinct_samples: int = 0
interval_s: int = 0
seconds: int = 0

View File

@ -21,8 +21,6 @@ from __future__ import annotations
import argparse
from collections import defaultdict
import contextlib
from dataclasses import dataclass, field
import datetime as _dt
import os
from pathlib import Path
@ -31,596 +29,28 @@ import re
import shutil
import subprocess
import sys
import time as _time
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from collections.abc import Iterable
from _usage_report_parsing import _run, aggregate_atop, aggregate_pmon
from _usage_report_types import (
_HZ,
_PMON_INTERVAL_S,
GpuAgg,
ProcAgg,
_Progress,
_Window,
)
_ATOP_LOG_DIR = Path("/var/log/atop")
_PMON_LOG_DIR = Path.home() / ".local/share/gpu-log"
_DEFAULT_TOP = 15
_HZ = os.sysconf("SC_CLK_TCK") if hasattr(os, "sysconf") else 100
_PAGE_KB = os.sysconf("SC_PAGESIZE") // 1024 if hasattr(os, "sysconf") else 4
_SEC_PER_DAY = 86_400
_SEC_PER_HOUR = 3600
_SEC_PER_MIN = 60
_MIN_SAMPLES_FOR_WINDOW = 2
# atop parseable output layout (atop 2.x, same on Arch/Debian/Ubuntu):
# 0 label, 1 host, 2 epoch, 3 YYYY/MM/DD, 4 HH:MM:SS, 5 interval_s,
# then per-process fields starting at index 6.
# PRC per-proc: pid name(parens) state utime_ticks stime_ticks ...
_PRC_PID_IDX = 6
_PRC_NAME_IDX = 7
_PRC_MIN_LEN = 11
# PRM per-proc: pid name state pagesz_b vsize_kb rsize_kb ...
_PRM_PID_IDX = 6
_PRM_NAME_IDX = 7
_PRM_MIN_LEN = 12
_PMON_MIN_FIELDS = 11
_CPU_RECORD_MIN_LEN = 5
_PAREN_PAIR_MIN = 2
_ETA_MIN_FRACTION = 0.01
_ATOP_AGG_CACHE_BIN = Path.home() / ".cache" / "usage_report" / "atop_agg"
_ATOP_AGG_BIN_MODE = 0o755
# Repo layout: linux_configuration/scripts/system-maintenance/bin/usage_report.py
# -> parents[4] is the repo root which hosts the C/ source tree.
_ATOP_AGG_SRC_DIR = Path(__file__).resolve().parents[4] / "C" / "atop_agg"
_ATOP_AGG_BUILD_TIMEOUT_S = 60
_NATIVE_TSV_NAME_LEN = 7
_NATIVE_TSV_WIN_LEN = 5
@dataclass
class _PidCpu:
"""Per-PID cumulative-ticks tracker across atop samples."""
name: str = ""
first_ticks: int = -1
last_ticks: int = 0
samples: int = 0
def observe(self, name: str, ticks: int) -> None:
"""Record one observation for this PID."""
self.name = name # last-seen name wins (stable for one PID)
if self.first_ticks < 0:
self.first_ticks = ticks
self.last_ticks = ticks
self.samples += 1
@property
def delta_ticks(self) -> int:
"""CPU ticks consumed during the observation window.
For PIDs seen in >=2 samples the value is `last - first`, which is the
actual CPU consumed between the first and last atop tick. For PIDs seen
only once (short-lived processes that existed during exactly one tick)
the cumulative value itself is used this is close to the true
lifetime cost for a short-lived process.
"""
if self.samples >= _MIN_SAMPLES_FOR_WINDOW:
return max(self.last_ticks - self.first_ticks, 0)
return self.last_ticks
@dataclass
class _PidRam:
"""Per-PID peak/avg RSS tracker across atop samples."""
name: str = ""
peak_kb: int = 0
sum_kb: int = 0
samples: int = 0
def observe(self, name: str, rss_kb: int) -> None:
"""Record one RSS observation for this PID."""
self.name = name
self.peak_kb = max(self.peak_kb, rss_kb)
self.sum_kb += rss_kb
self.samples += 1
@property
def avg_kb(self) -> float:
"""Mean RSS across the samples where this PID appeared."""
return self.sum_kb / self.samples if self.samples else 0.0
@dataclass
class ProcAgg:
"""Aggregated metrics for one program name across all atop samples."""
name: str
cpu_ticks: int = 0
peak_rss_kb: int = 0
rss_kb_sum: int = 0
rss_samples: int = 0
pid_set: set[int] = field(default_factory=set)
@property
def cpu_seconds(self) -> float:
"""CPU-seconds consumed (sum of user + system time)."""
return self.cpu_ticks / _HZ
@property
def peak_rss_mb(self) -> float:
"""Peak resident memory observed across the window, in MiB."""
return self.peak_rss_kb / 1024
@property
def avg_rss_mb(self) -> float:
"""Average resident memory across samples where the program appeared."""
if not self.rss_samples:
return 0.0
return (self.rss_kb_sum / self.rss_samples) / 1024
@dataclass
class GpuAgg:
"""Aggregated GPU metrics for one program name from pmon logs."""
name: str
sm_pct_sum: float = 0.0
mem_pct_sum: float = 0.0
samples: int = 0
peak_sm_pct: float = 0.0
peak_mem_pct: float = 0.0
pid_set: set[int] = field(default_factory=set)
@property
def gpu_seconds(self) -> float:
"""SM-seconds (single-GPU equivalent); sm% * seconds_per_sample / 100."""
return self.sm_pct_sum * _PMON_INTERVAL_S / 100.0
@property
def avg_sm_pct(self) -> float:
"""Mean SM utilization across samples where the process was present."""
if not self.samples:
return 0.0
return self.sm_pct_sum / self.samples
# Default pmon interval is 10 s (matches the systemd service we set up).
_PMON_INTERVAL_S = 10
_PROGRESS_MIN_UPDATE_S = 0.1
class _Progress:
"""Minimal stage+percent+ETA reporter on stderr.
Disabled automatically when stderr is not a TTY or when the caller
constructs with `enabled=False`, so redirected output stays clean.
"""
def __init__(self, *, enabled: bool, total_stages: int) -> None:
self._enabled = enabled and sys.stderr.isatty()
self._total_stages = total_stages
self._stage_idx = 0
self._stage_label = ""
self._stage_start = 0.0
self._t0 = _time.monotonic()
self._last_draw = 0.0
self._max_width = 0
def start_stage(self, label: str) -> None:
"""Begin a new stage with its human label."""
self._stage_idx += 1
self._stage_label = label
self._stage_start = _time.monotonic()
self.update(0.0)
def update(self, fraction: float) -> None:
"""Redraw the progress line for the current stage (0.0..1.0)."""
if not self._enabled:
return
now = _time.monotonic()
if now - self._last_draw < _PROGRESS_MIN_UPDATE_S and fraction < 1.0:
return
self._last_draw = now
elapsed = now - self._stage_start
pct = max(0.0, min(fraction, 1.0))
if pct > _ETA_MIN_FRACTION:
eta = elapsed * (1 - pct) / pct
eta_str = f"~{eta:4.1f}s left"
else:
eta_str = "estimating…"
msg = (
f"[{self._stage_idx}/{self._total_stages}] "
f"{self._stage_label:<22} {pct * 100:5.1f}% "
f"{elapsed:5.1f}s elapsed, {eta_str}"
)
self._max_width = max(self._max_width, len(msg))
sys.stderr.write("\r" + msg.ljust(self._max_width))
sys.stderr.flush()
def finish(self) -> None:
"""Clear the progress line and print total elapsed time."""
if not self._enabled:
return
total = _time.monotonic() - self._t0
sys.stderr.write("\r" + " " * self._max_width + "\r")
sys.stderr.write(f"done in {total:.1f}s\n")
sys.stderr.flush()
def _run(cmd: list[str]) -> str:
"""Run *cmd* and return stdout (empty string on failure)."""
try:
proc = subprocess.run(
cmd,
capture_output=True,
text=True,
check=False,
timeout=60,
)
except (OSError, subprocess.TimeoutExpired):
return ""
return proc.stdout
def _iter_atop_lines(log: Path, labels: str) -> Iterator[str]:
"""Stream `atop -r LOG -P LABELS` stdout line-by-line.
Uses `Popen` so the report can show progress while atop is still
decoding its binary log, rather than buffering the whole output.
"""
cmd = ["atop", "-r", str(log), "-P", labels]
with subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
text=True,
) as proc:
stdout = proc.stdout
if stdout is None:
return
for raw in stdout:
yield raw.rstrip("\n")
def _parse_name(parts: list[str], name_idx: int) -> tuple[str, int]:
"""Extract `(name, next_index)` from atop parseable output.
atop wraps process names in parentheses and the name itself may contain
spaces, so we rejoin until we hit the closing `)`. Fast-paths the common
case where the name is a single token (no embedded spaces).
"""
if name_idx >= len(parts):
return "unknown", name_idx + 1
token = parts[name_idx]
# Fast path: `(name)` fully in one token.
if len(token) >= _PAREN_PAIR_MIN and token[0] == "(" and token[-1] == ")":
return token[1:-1] or "unknown", name_idx + 1
if token.startswith("("):
buf = [token]
idx = name_idx
while not buf[-1].endswith(")") and idx + 1 < len(parts):
idx += 1
buf.append(parts[idx])
name = " ".join(buf)[1:-1] or "unknown"
return name, idx + 1
return token, name_idx + 1
def _parse_prc(parts: list[str], pid_cpu: dict[int, _PidCpu]) -> None:
"""Fold one PRC record into the per-PID CPU-ticks map."""
try:
pid = int(parts[_PRC_PID_IDX])
except (ValueError, IndexError):
return
name, after = _parse_name(parts, _PRC_NAME_IDX)
# After name comes: state utime stime ...
try:
utime = int(parts[after + 1])
stime = int(parts[after + 2])
except (ValueError, IndexError):
return
pid_cpu.setdefault(pid, _PidCpu()).observe(name, utime + stime)
def _parse_prm(parts: list[str], pid_ram: dict[int, _PidRam]) -> None:
"""Fold one PRM record into the per-PID RSS map."""
try:
pid = int(parts[_PRM_PID_IDX])
except (ValueError, IndexError):
return
name, after = _parse_name(parts, _PRM_NAME_IDX)
# After name: state pagesz_b vsize_kb rsize_kb ...
try:
rsize_kb = int(parts[after + 3])
except (ValueError, IndexError):
return
pid_ram.setdefault(pid, _PidRam()).observe(name, rsize_kb)
def _window_from_epochs(epochs: set[int]) -> _Window:
"""Build a `_Window` from a set of sample epoch timestamps."""
if not epochs:
return _Window()
ordered = sorted(epochs)
start_dt = _dt.datetime.fromtimestamp(ordered[0]).astimezone()
end_dt = _dt.datetime.fromtimestamp(ordered[-1]).astimezone()
interval = 0
if len(ordered) >= _MIN_SAMPLES_FOR_WINDOW:
deltas = sorted(ordered[i + 1] - ordered[i] for i in range(len(ordered) - 1))
interval = deltas[len(deltas) // 2]
return _Window(
start=start_dt.isoformat(timespec="seconds"),
end=end_dt.isoformat(timespec="seconds"),
distinct_samples=len(ordered),
interval_s=interval,
seconds=ordered[-1] - ordered[0],
)
def _atop_agg_binary() -> Path | None:
"""Return a cached `atop_agg` binary path, auto-building if missing/stale.
Falls back to ``None`` when the C source tree or a system C compiler
is unavailable, in which case callers use the pure-Python parser.
"""
src_c = _ATOP_AGG_SRC_DIR / "atop_agg.c"
if _ATOP_AGG_CACHE_BIN.exists() and (
not src_c.exists()
or src_c.stat().st_mtime <= _ATOP_AGG_CACHE_BIN.stat().st_mtime
):
return _ATOP_AGG_CACHE_BIN
if not src_c.exists() or shutil.which("cc") is None:
return None
_ATOP_AGG_CACHE_BIN.parent.mkdir(parents=True, exist_ok=True)
make_cmd = ["make", "-s", "-C", str(_ATOP_AGG_SRC_DIR), "atop_agg"]
try:
subprocess.run(
make_cmd,
check=True,
capture_output=True,
timeout=_ATOP_AGG_BUILD_TIMEOUT_S,
)
except (OSError, subprocess.SubprocessError):
return None
built = _ATOP_AGG_SRC_DIR / "atop_agg"
if not built.exists():
return None
shutil.copy2(built, _ATOP_AGG_CACHE_BIN)
_ATOP_AGG_CACHE_BIN.chmod(_ATOP_AGG_BIN_MODE)
return _ATOP_AGG_CACHE_BIN
def _apply_native_name(parts: list[str], agg_map: dict[str, ProcAgg]) -> None:
r"""Fold one `N\\t<name>\\t<cpu>\\t<peak>\\t<sum_avg>\\t<ram_n>\\t<pids>` row."""
_, name, cpu_s, peak_s, sum_avg_s, rss_n_s, pids_s = parts
entry = agg_map.setdefault(name, ProcAgg(name=name))
entry.cpu_ticks = int(cpu_s)
entry.peak_rss_kb = int(peak_s)
entry.rss_kb_sum = int(sum_avg_s)
entry.rss_samples = int(rss_n_s)
# The C helper pre-aggregates by name; pid_set is unused in the native
# path but `len(pid_set)` drives the "PIDs" column in the report.
entry.pid_set = set(range(int(pids_s)))
def _window_from_native(parts: list[str]) -> _Window:
r"""Build a `_Window` from a `W\\t<start>\\t<end>\\t<n>\\t<interval>` row."""
_, start_s, end_s, n_s, interval_s = parts
n_epochs = int(n_s)
if not n_epochs:
return _Window()
start_epoch = int(start_s)
end_epoch = int(end_s)
start_dt = _dt.datetime.fromtimestamp(start_epoch).astimezone()
end_dt = _dt.datetime.fromtimestamp(end_epoch).astimezone()
return _Window(
start=start_dt.isoformat(timespec="seconds"),
end=end_dt.isoformat(timespec="seconds"),
distinct_samples=n_epochs,
interval_s=int(interval_s),
seconds=end_epoch - start_epoch,
)
def _aggregate_atop_native(
log: Path,
progress: _Progress,
binary: Path,
) -> tuple[dict[str, ProcAgg], _Window]:
"""Aggregate via `atop | atop_agg`; return `(by_name, window)`."""
progress.start_stage("atop: parse PRC+PRM (native)")
agg_map: dict[str, ProcAgg] = {}
window = _Window()
atop_cmd = ["atop", "-r", str(log), "-P", "PRC,PRM"]
agg_cmd = [str(binary)]
with (
subprocess.Popen(
atop_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
) as atop,
subprocess.Popen(
agg_cmd,
stdin=atop.stdout,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
text=True,
) as agg,
):
if atop.stdout is not None:
atop.stdout.close()
stdout = agg.stdout
if stdout is None:
return agg_map, window
for raw in stdout:
parts = raw.rstrip("\n").split("\t")
tag = parts[0]
if tag == "N" and len(parts) == _NATIVE_TSV_NAME_LEN:
_apply_native_name(parts, agg_map)
elif tag == "W" and len(parts) == _NATIVE_TSV_WIN_LEN:
window = _window_from_native(parts)
progress.update(1.0)
return agg_map, window
def aggregate_atop(
log: Path,
progress: _Progress,
) -> tuple[dict[str, ProcAgg], _Window]:
"""Stream PRC+PRM records, fold them into `{name: ProcAgg}`, return window.
Prefers the native `atop_agg` C helper (auto-built into
``~/.cache/usage_report/``) for ~7\u00d7 speedup on full-day logs, falling
back to an inline Python parser when the helper is unavailable.
"""
binary = _atop_agg_binary()
if binary is not None:
return _aggregate_atop_native(log, progress, binary)
progress.start_stage("atop: parse PRC+PRM")
pid_cpu: dict[int, _PidCpu] = {}
pid_ram: dict[int, _PidRam] = {}
epochs: set[int] = set()
log_size = max(log.stat().st_size, 1)
bytes_seen = 0
# Empirical: `atop -P PRC,PRM` stdout is ~11x the binary log size on a
# 10-min-interval log. The fraction is only used for the progress bar,
# so a rough calibration is fine; it caps at 99% if we underestimate.
est_total_bytes = log_size * 11 or 1
for raw in _iter_atop_lines(log, "PRC,PRM"):
bytes_seen += len(raw) + 1
if not raw or raw[0] == "#" or raw.startswith("RESET") or raw == "SEP":
continue
parts = raw.split()
if not parts:
continue
label = parts[0]
if label == "PRC" and len(parts) >= _PRC_MIN_LEN:
with contextlib.suppress(ValueError):
# atop always emits an integer epoch here; guard is defensive.
epochs.add(int(parts[2]))
progress.update(min(bytes_seen / est_total_bytes, 0.99))
_parse_prc(parts, pid_cpu)
elif label == "PRM" and len(parts) >= _PRM_MIN_LEN:
_parse_prm(parts, pid_ram)
progress.update(1.0)
return _fold_pid_aggregates(pid_cpu, pid_ram), _window_from_epochs(epochs)
def _fold_pid_aggregates(
pid_cpu: dict[int, _PidCpu],
pid_ram: dict[int, _PidRam],
) -> dict[str, ProcAgg]:
"""Collapse per-PID CPU/RAM trackers into per-program `ProcAgg` entries."""
agg: dict[str, ProcAgg] = {}
for pid, cpu in pid_cpu.items():
entry = agg.setdefault(cpu.name, ProcAgg(name=cpu.name))
entry.cpu_ticks += cpu.delta_ticks
entry.pid_set.add(pid)
for pid, ram in pid_ram.items():
entry = agg.setdefault(ram.name, ProcAgg(name=ram.name))
entry.peak_rss_kb = max(entry.peak_rss_kb, ram.peak_kb)
entry.rss_kb_sum += int(ram.avg_kb)
entry.rss_samples += 1
entry.pid_set.add(pid)
return agg
def _pmon_fields(line: str) -> list[str] | None:
"""Return stripped fields of a pmon data line, or None for headers/blanks."""
s = line.strip()
if not s or s.startswith("#"):
return None
return s.split()
def _normalize_pmon_command(command_fields: list[str]) -> str:
"""Normalize pmon command fields into a stable process-ish name.
`nvidia-smi pmon -o DT` emits fixed numeric columns followed by a command
field that can include whitespace. We prefer the *first* non-option token
(usually executable) and normalize it to a basename.
"""
tokens = [token.strip().strip("\"'") for token in command_fields if token.strip()]
if not tokens:
return "unknown"
selected = tokens[0]
if selected.startswith("-"):
for candidate in tokens[1:]:
if not candidate.startswith("-"):
selected = candidate
break
name = Path(selected).name.strip(";,:")
if not name:
return "unknown"
return name
def _pid_comm_name(pid: int) -> str | None:
"""Return `/proc/<pid>/comm` basename when available."""
try:
comm = Path(f"/proc/{pid}/comm").read_text(encoding="utf-8").strip()
except OSError:
return None
return Path(comm).name if comm else None
def aggregate_pmon(
log: Path,
progress: _Progress,
) -> tuple[dict[str, GpuAgg], int]:
"""Return `({program: GpuAgg}, sample_count)` from the pmon *log*."""
progress.start_stage("pmon log scan")
agg: dict[str, GpuAgg] = {}
samples = 0
if not log.exists():
progress.update(1.0)
return agg, 0
total_bytes = max(log.stat().st_size, 1)
bytes_read = 0
with log.open(encoding="utf-8") as fh:
for line in fh:
bytes_read += len(line)
progress.update(min(bytes_read / total_bytes, 0.99))
parts = _pmon_fields(line)
if parts is None or len(parts) < _PMON_MIN_FIELDS:
continue
samples += _ingest_pmon_row(parts, agg)
progress.update(1.0)
return agg, samples
def _ingest_pmon_row(parts: list[str], agg: dict[str, GpuAgg]) -> int:
"""Fold a single pmon data row into *agg*; return 1 if consumed else 0."""
# pmon -o DT fields:
# date time gpu pid type sm mem enc dec jpg ofa command
try:
pid = int(parts[3])
except ValueError:
return 0
sm_raw = parts[5]
mem_raw = parts[6]
command_fields = parts[11:]
name = _normalize_pmon_command(command_fields)
if name == "unknown":
name = _pid_comm_name(pid) or "unknown"
sm = float(sm_raw) if sm_raw != "-" else 0.0
mem = float(mem_raw) if mem_raw != "-" else 0.0
entry = agg.setdefault(name, GpuAgg(name=name))
entry.sm_pct_sum += sm
entry.mem_pct_sum += mem
entry.samples += 1
entry.pid_set.add(pid)
entry.peak_sm_pct = max(entry.peak_sm_pct, sm)
entry.peak_mem_pct = max(entry.peak_mem_pct, mem)
return 1
@dataclass
class _Window:
"""Observed atop coverage window."""
start: str = "n/a"
end: str = "n/a"
distinct_samples: int = 0
interval_s: int = 0
seconds: int = 0
def _host_profile() -> dict[str, str]:

View File

@ -231,6 +231,7 @@ repos:
language: system
types: [python]
pass_filenames: true
require_serial: true
stages: [pre-commit]
# ===========================================================================

View File

@ -55,7 +55,7 @@ def _build_pytest_command(packages: set[str]) -> list[str]:
"--cov-fail-under=100",
"-q",
"-n",
"auto",
"4",
# Override addopts from pyproject.toml to drop the global
# --cov=python_pkg that would widen coverage to the entire tree.
"-o",

View File

@ -0,0 +1,72 @@
"""Early bird window detection and log helpers for ScreenLocker."""
from __future__ import annotations
from datetime import datetime, timezone
import json
import logging
from python_pkg.screen_locker._constants import (
EARLY_BIRD_END_HOUR,
EARLY_BIRD_END_MINUTE,
EARLY_BIRD_START_HOUR,
)
_logger = logging.getLogger(__name__)
class EarlyBirdMixin:
"""Mixin providing early-bird time window checks and log helpers."""
def _get_local_time_minutes(self) -> int:
"""Return current local time as minutes from midnight."""
now = datetime.now(tz=timezone.utc).astimezone()
return now.hour * 60 + now.minute
def _is_early_bird_time(self) -> bool:
"""Return True if current local time is in the early bird window."""
minutes = self._get_local_time_minutes()
start = EARLY_BIRD_START_HOUR * 60
end = EARLY_BIRD_END_HOUR * 60 + EARLY_BIRD_END_MINUTE
return start <= minutes < end
def _is_early_bird_log(self) -> bool:
"""Check if today's workout log entry is an early_bird provisional entry."""
if not self.log_file.exists():
return False
try:
with self.log_file.open() as f:
logs = json.load(f)
except (OSError, json.JSONDecodeError):
return False
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
entry = logs.get(today)
if entry is None:
return False
return entry.get("workout_data", {}).get("type") == "early_bird"
def _save_early_bird_log(self) -> None:
"""Save an early_bird provisional entry to the workout log."""
self.workout_data = {"type": "early_bird"}
self.save_workout_log()
def _try_auto_upgrade_early_bird(self) -> bool:
"""Silently upgrade today's early_bird entry if phone shows a workout."""
try:
status, message = self._verify_phone_workout()
except (OSError, RuntimeError) as exc:
_logger.info("Early bird upgrade phone check failed: %s", exc)
return False
if status != "verified":
_logger.info(
"Early bird upgrade skipped (phone status=%s): %s",
status,
message,
)
return False
self.workout_data["type"] = "phone_verified"
self.workout_data["source"] = message
self.workout_data["after_early_bird"] = "true"
self._adjust_shutdown_time_later()
self.save_workout_log()
return True

View File

@ -0,0 +1,80 @@
"""Window configuration and input-grab helpers for ScreenLocker."""
from __future__ import annotations
import contextlib
import logging
import shutil
import subprocess
import tkinter as tk
_logger = logging.getLogger(__name__)
class WindowSetupMixin:
"""Mixin providing window setup, VT switching control, and input-grab helpers."""
def _disable_vt_switching(self) -> None:
"""Disable VT switching in X11 while the lock is active.
Prevents bypassing the lock by switching to a TTY with Ctrl+Alt+Fn.
Best-effort: silently ignored if setxkbmap is unavailable.
"""
setxkbmap = shutil.which("setxkbmap")
if setxkbmap is None:
_logger.warning("setxkbmap not found; VT switching will not be disabled")
return
subprocess.run([setxkbmap, "-option", "srvrkeys:none"], check=False)
def _restore_vt_switching(self) -> None:
"""Restore VT switching after the lock is dismissed."""
setxkbmap = shutil.which("setxkbmap")
if setxkbmap is None:
return
subprocess.run([setxkbmap, "-option", ""], check=False)
def _setup_window(self) -> None:
"""Configure the window for fullscreen lock."""
screen_w = self.root.winfo_screenwidth()
screen_h = self.root.winfo_screenheight()
self.root.overrideredirect(boolean=True)
self.root.geometry(f"{screen_w}x{screen_h}+0+0")
self.root.attributes(fullscreen=True)
self.root.attributes(topmost=True)
self.root.configure(bg="#1a1a1a", cursor="arrow")
if not self.demo_mode:
self._disable_vt_switching()
def _setup_verify_window(self) -> None:
"""Configure window for post-sick-day workout verification."""
self.root.geometry("600x400")
self.root.configure(bg="#1a1a1a", cursor="arrow")
self.root.protocol("WM_DELETE_WINDOW", self.close)
def _setup_demo_close_button(self) -> None:
"""Add close button for demo mode."""
close_btn = tk.Button(
self.root,
text="✕ Close Demo",
font=("Arial", 12),
bg="#ff4444",
fg="white",
command=self.close,
cursor="hand2",
)
close_btn.place(x=10, y=10)
def _grab_input(self) -> None:
"""Force input focus to the locker window."""
self.root.update_idletasks()
self.root.focus_force()
if self.demo_mode:
with contextlib.suppress(tk.TclError):
self.root.grab_set()
else:
try:
self.root.grab_set_global()
except tk.TclError:
_logger.warning("Global grab failed, falling back to local grab")
with contextlib.suppress(tk.TclError):
self.root.grab_set()

View File

@ -6,13 +6,10 @@ Requires user to log their workout to unlock the screen.
from __future__ import annotations
import contextlib
from datetime import datetime, timezone
import json
import logging
from pathlib import Path
import shutil
import subprocess
import sys
import tkinter as tk
from typing import TYPE_CHECKING
@ -31,6 +28,7 @@ from python_pkg.screen_locker._constants import (
SICK_LOCKOUT_SECONDS,
STRONGLIFTS_DB_REMOTE,
)
from python_pkg.screen_locker._early_bird import EarlyBirdMixin
from python_pkg.screen_locker._log_integrity import (
_load_hmac_key,
compute_entry_hmac,
@ -40,6 +38,7 @@ from python_pkg.screen_locker._phone_verification import PhoneVerificationMixin
from python_pkg.screen_locker._shutdown import ShutdownMixin
from python_pkg.screen_locker._sick_dialog import SickDialogMixin
from python_pkg.screen_locker._ui_flows import UIFlowsMixin
from python_pkg.screen_locker._window_setup import WindowSetupMixin
from python_pkg.wake_alarm._state import has_workout_skip_today
if TYPE_CHECKING:
@ -80,6 +79,8 @@ def _assert_not_under_pytest() -> None:
class ScreenLocker(
EarlyBirdMixin,
WindowSetupMixin,
ShutdownMixin,
PhoneVerificationMixin,
SickDialogMixin,
@ -122,43 +123,6 @@ class ScreenLocker(
self._start_phone_check()
self._grab_input()
def _disable_vt_switching(self) -> None:
"""Disable VT switching in X11 while the lock is active.
Prevents bypassing the lock by switching to a TTY with Ctrl+Alt+Fn.
Best-effort: silently ignored if setxkbmap is unavailable.
"""
setxkbmap = shutil.which("setxkbmap")
if setxkbmap is None:
_logger.warning("setxkbmap not found; VT switching will not be disabled")
return
subprocess.run([setxkbmap, "-option", "srvrkeys:none"], check=False)
def _restore_vt_switching(self) -> None:
"""Restore VT switching after the lock is dismissed."""
setxkbmap = shutil.which("setxkbmap")
if setxkbmap is None:
return
subprocess.run([setxkbmap, "-option", ""], check=False)
def _setup_window(self) -> None:
"""Configure the window for fullscreen lock."""
screen_w = self.root.winfo_screenwidth()
screen_h = self.root.winfo_screenheight()
self.root.overrideredirect(boolean=True)
self.root.geometry(f"{screen_w}x{screen_h}+0+0")
self.root.attributes(fullscreen=True)
self.root.attributes(topmost=True)
self.root.configure(bg="#1a1a1a", cursor="arrow")
if not self.demo_mode:
self._disable_vt_switching()
def _setup_verify_window(self) -> None:
"""Configure window for post-sick-day workout verification."""
self.root.geometry("600x400")
self.root.configure(bg="#1a1a1a", cursor="arrow")
self.root.protocol("WM_DELETE_WINDOW", self.close)
def _is_sick_day_log(self) -> bool:
"""Check if today's workout log is a sick day (not yet verified)."""
if not self.log_file.exists():
@ -219,59 +183,6 @@ class ScreenLocker(
)
sys.exit(0)
def _get_local_time_minutes(self) -> int:
"""Return current local time as minutes from midnight."""
now = datetime.now(tz=timezone.utc).astimezone()
return now.hour * 60 + now.minute
def _is_early_bird_time(self) -> bool:
"""Return True if current local time is in the early bird window."""
minutes = self._get_local_time_minutes()
start = EARLY_BIRD_START_HOUR * 60
end = EARLY_BIRD_END_HOUR * 60 + EARLY_BIRD_END_MINUTE
return start <= minutes < end
def _is_early_bird_log(self) -> bool:
"""Check if today's workout log entry is an early_bird provisional entry."""
if not self.log_file.exists():
return False
try:
with self.log_file.open() as f:
logs = json.load(f)
except (OSError, json.JSONDecodeError):
return False
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
entry = logs.get(today)
if entry is None:
return False
return entry.get("workout_data", {}).get("type") == "early_bird"
def _save_early_bird_log(self) -> None:
"""Save an early_bird provisional entry to the workout log."""
self.workout_data = {"type": "early_bird"}
self.save_workout_log()
def _try_auto_upgrade_early_bird(self) -> bool:
"""Silently upgrade today's early_bird entry if phone shows a workout."""
try:
status, message = self._verify_phone_workout()
except (OSError, RuntimeError) as exc:
_logger.info("Early bird upgrade phone check failed: %s", exc)
return False
if status != "verified":
_logger.info(
"Early bird upgrade skipped (phone status=%s): %s",
status,
message,
)
return False
self.workout_data["type"] = "phone_verified"
self.workout_data["source"] = message
self.workout_data["after_early_bird"] = "true"
self._adjust_shutdown_time_later()
self.save_workout_log()
return True
def _try_auto_upgrade_sick_day(self) -> bool:
"""Silently upgrade today's sick_day entry if phone shows a workout."""
try:
@ -293,34 +204,6 @@ class ScreenLocker(
self.save_workout_log()
return True
def _setup_demo_close_button(self) -> None:
"""Add close button for demo mode."""
close_btn = tk.Button(
self.root,
text="✕ Close Demo",
font=("Arial", 12),
bg="#ff4444",
fg="white",
command=self.close,
cursor="hand2",
)
close_btn.place(x=10, y=10)
def _grab_input(self) -> None:
"""Force input focus to the locker window."""
self.root.update_idletasks()
self.root.focus_force()
if self.demo_mode:
with contextlib.suppress(tk.TclError):
self.root.grab_set()
else:
try:
self.root.grab_set_global()
except tk.TclError:
_logger.warning("Global grab failed, falling back to local grab")
with contextlib.suppress(tk.TclError):
self.root.grab_set()
def clear_container(self) -> None:
"""Remove all widgets from the main container."""
for widget in self.container.winfo_children():

View File

@ -70,10 +70,10 @@ def mock_subprocess_run() -> Generator[MagicMock]:
"""
with (
patch(
"python_pkg.screen_locker.screen_lock.shutil.which",
"python_pkg.screen_locker._window_setup.shutil.which",
return_value="/usr/bin/setxkbmap",
),
patch("python_pkg.screen_locker.screen_lock.subprocess.run") as mock,
patch("python_pkg.screen_locker._window_setup.subprocess.run") as mock,
):
yield mock

View File

@ -3,16 +3,12 @@
from __future__ import annotations
import datetime
import json
import sqlite3
import subprocess
import time
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import pytest
from python_pkg.screen_locker.screen_lock import STRONGLIFTS_DB_REMOTE
from python_pkg.screen_locker.tests.conftest import create_locker
@ -478,379 +474,3 @@ class TestCountTodayWorkouts:
conn.close()
assert locker._count_today_workouts(db_file) == 2
class TestGetTodayWorkoutDurationMinutes:
"""Tests for _get_today_workout_duration_minutes method."""
def test_returns_duration_for_today_workout(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns correct duration for a 60-minute workout."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
now_ms = int(time.time() * 1000)
duration_ms = 60 * 60 * 1000 # 60 minutes
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms + duration_ms),
)
conn.commit()
conn.close()
result = locker._get_today_workout_duration_minutes(db_file)
assert result == pytest.approx(60.0, abs=1.0)
def test_returns_zero_for_no_workouts(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0.0 when no workouts today."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
yesterday_ms = int((time.time() - 200000) * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", yesterday_ms, yesterday_ms + 3600000),
)
conn.commit()
conn.close()
assert not locker._get_today_workout_duration_minutes(db_file)
def test_sums_multiple_workouts(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test sums durations of multiple workouts today."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
now_ms = int(time.time() * 1000)
# 30 min + 25 min = 55 min total
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms + 30 * 60 * 1000),
)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w2", now_ms + 31 * 60 * 1000, now_ms + 56 * 60 * 1000),
)
conn.commit()
conn.close()
result = locker._get_today_workout_duration_minutes(db_file)
assert result == pytest.approx(55.0, abs=1.0)
def test_ignores_invalid_finish(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ignores workouts where finish <= start."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
now_ms = int(time.time() * 1000)
# finish == start (zero duration - should be excluded by WHERE)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms),
)
conn.commit()
conn.close()
assert not locker._get_today_workout_duration_minutes(db_file)
def test_invalid_db_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0.0 for invalid database file."""
locker = create_locker(mock_tk, tmp_path)
bad_file = tmp_path / "not_a_db.db"
bad_file.write_text("not a database")
assert not locker._get_today_workout_duration_minutes(bad_file)
def test_missing_table_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0.0 when workouts table doesn't exist."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "empty.db"
conn = sqlite3.connect(str(db_file))
conn.execute("CREATE TABLE other (id TEXT)")
conn.commit()
conn.close()
assert not locker._get_today_workout_duration_minutes(db_file)
class TestGetTodayExerciseCount:
"""Tests for _get_today_exercise_count method."""
def test_counts_exercises(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test counts distinct exercises in today's workouts."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER, exercises TEXT)",
)
now_ms = int(time.time() * 1000)
exercises_json = json.dumps(
[
{"id": "squat", "name": "Squat"},
{"id": "bench_press", "name": "Bench Press"},
{"id": "squat", "name": "Squat"},
{"category": "WARMUP"},
]
)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?, ?)",
("w1", now_ms, now_ms + 3600000, exercises_json),
)
conn.commit()
conn.close()
assert locker._get_today_exercise_count(db_file) == 2
def test_no_exercises_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 when no exercises exist."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER, exercises TEXT)",
)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?, ?)",
("w1", now_ms, now_ms + 3600000, "[]"),
)
conn.commit()
conn.close()
assert not locker._get_today_exercise_count(db_file)
def test_invalid_db_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 for invalid database file."""
locker = create_locker(mock_tk, tmp_path)
bad_file = tmp_path / "bad.db"
bad_file.write_text("not a db")
assert not locker._get_today_exercise_count(bad_file)
def test_missing_exercises_column_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 when workouts table has no exercises column."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "empty.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms + 3600000),
)
conn.commit()
conn.close()
assert not locker._get_today_exercise_count(db_file)
def test_null_exercises_json_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 when exercises JSON is NULL."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "null_ex.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER, exercises TEXT)",
)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?, ?)",
("w1", now_ms, now_ms + 3600000, None),
)
conn.commit()
conn.close()
assert not locker._get_today_exercise_count(db_file)
def test_malformed_exercises_json_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 when exercises JSON is malformed."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "bad_json.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER, exercises TEXT)",
)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?, ?)",
("w1", now_ms, now_ms + 3600000, "not valid json"),
)
conn.commit()
conn.close()
assert not locker._get_today_exercise_count(db_file)
class TestIsWorkoutFinishRecent:
"""Tests for _is_workout_finish_recent method."""
def test_recent_workout_returns_true(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns True for workout that finished recently."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
# Anchor to local noon to avoid midnight boundary issues: the SQL
# date() filter requires start and now to share the same local date.
local_noon = (
datetime.datetime.now(tz=datetime.timezone.utc)
.astimezone()
.replace(hour=12, minute=0, second=0, microsecond=0)
)
local_noon_ms = int(local_noon.timestamp() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", local_noon_ms, local_noon_ms + 3_600_000),
)
conn.commit()
conn.close()
assert locker._is_workout_finish_recent(db_file) is True
def test_old_workout_returns_false(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns False for workout that finished >24 hours ago."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
# Finished 25 hours ago (not "today" in local time either)
now_ms = int(time.time() * 1000)
old_finish = now_ms - 25 * 3600 * 1000 # beyond 24h window
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", old_finish - 3600000, old_finish),
)
conn.commit()
conn.close()
assert locker._is_workout_finish_recent(db_file) is False
def test_no_workouts_returns_false(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns False when no workouts exist."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
conn.commit()
conn.close()
assert locker._is_workout_finish_recent(db_file) is False
def test_invalid_db_returns_false(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns False for invalid database file."""
locker = create_locker(mock_tk, tmp_path)
bad_file = tmp_path / "bad.db"
bad_file.write_text("not a db")
assert locker._is_workout_finish_recent(bad_file) is False

View File

@ -0,0 +1,394 @@
"""Tests for ADB commands, phone connection, and database operations."""
# pylint: disable=protected-access,unused-argument
from __future__ import annotations
import datetime
import json
import sqlite3
import time
from typing import TYPE_CHECKING
import pytest
from python_pkg.screen_locker.tests.conftest import create_locker
if TYPE_CHECKING:
from pathlib import Path
from unittest.mock import MagicMock
class TestGetTodayWorkoutDurationMinutes:
"""Tests for _get_today_workout_duration_minutes method."""
def test_returns_duration_for_today_workout(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns correct duration for a 60-minute workout."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
now_ms = int(time.time() * 1000)
duration_ms = 60 * 60 * 1000 # 60 minutes
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms + duration_ms),
)
conn.commit()
conn.close()
result = locker._get_today_workout_duration_minutes(db_file)
assert result == pytest.approx(60.0, abs=1.0)
def test_returns_zero_for_no_workouts(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0.0 when no workouts today."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
yesterday_ms = int((time.time() - 200000) * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", yesterday_ms, yesterday_ms + 3600000),
)
conn.commit()
conn.close()
assert not locker._get_today_workout_duration_minutes(db_file)
def test_sums_multiple_workouts(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test sums durations of multiple workouts today."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
now_ms = int(time.time() * 1000)
# 30 min + 25 min = 55 min total
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms + 30 * 60 * 1000),
)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w2", now_ms + 31 * 60 * 1000, now_ms + 56 * 60 * 1000),
)
conn.commit()
conn.close()
result = locker._get_today_workout_duration_minutes(db_file)
assert result == pytest.approx(55.0, abs=1.0)
def test_ignores_invalid_finish(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ignores workouts where finish <= start."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
now_ms = int(time.time() * 1000)
# finish == start (zero duration - should be excluded by WHERE)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms),
)
conn.commit()
conn.close()
assert not locker._get_today_workout_duration_minutes(db_file)
def test_invalid_db_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0.0 for invalid database file."""
locker = create_locker(mock_tk, tmp_path)
bad_file = tmp_path / "not_a_db.db"
bad_file.write_text("not a database")
assert not locker._get_today_workout_duration_minutes(bad_file)
def test_missing_table_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0.0 when workouts table doesn't exist."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "empty.db"
conn = sqlite3.connect(str(db_file))
conn.execute("CREATE TABLE other (id TEXT)")
conn.commit()
conn.close()
assert not locker._get_today_workout_duration_minutes(db_file)
class TestGetTodayExerciseCount:
"""Tests for _get_today_exercise_count method."""
def test_counts_exercises(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test counts distinct exercises in today's workouts."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER, exercises TEXT)",
)
now_ms = int(time.time() * 1000)
exercises_json = json.dumps(
[
{"id": "squat", "name": "Squat"},
{"id": "bench_press", "name": "Bench Press"},
{"id": "squat", "name": "Squat"},
{"category": "WARMUP"},
]
)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?, ?)",
("w1", now_ms, now_ms + 3600000, exercises_json),
)
conn.commit()
conn.close()
assert locker._get_today_exercise_count(db_file) == 2
def test_no_exercises_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 when no exercises exist."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER, exercises TEXT)",
)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?, ?)",
("w1", now_ms, now_ms + 3600000, "[]"),
)
conn.commit()
conn.close()
assert not locker._get_today_exercise_count(db_file)
def test_invalid_db_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 for invalid database file."""
locker = create_locker(mock_tk, tmp_path)
bad_file = tmp_path / "bad.db"
bad_file.write_text("not a db")
assert not locker._get_today_exercise_count(bad_file)
def test_missing_exercises_column_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 when workouts table has no exercises column."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "empty.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms + 3600000),
)
conn.commit()
conn.close()
assert not locker._get_today_exercise_count(db_file)
def test_null_exercises_json_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 when exercises JSON is NULL."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "null_ex.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER, exercises TEXT)",
)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?, ?)",
("w1", now_ms, now_ms + 3600000, None),
)
conn.commit()
conn.close()
assert not locker._get_today_exercise_count(db_file)
def test_malformed_exercises_json_returns_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 when exercises JSON is malformed."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "bad_json.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER, exercises TEXT)",
)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?, ?)",
("w1", now_ms, now_ms + 3600000, "not valid json"),
)
conn.commit()
conn.close()
assert not locker._get_today_exercise_count(db_file)
class TestIsWorkoutFinishRecent:
"""Tests for _is_workout_finish_recent method."""
def test_recent_workout_returns_true(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns True for workout that finished recently."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
# Anchor to local noon to avoid midnight boundary issues: the SQL
# date() filter requires start and now to share the same local date.
local_noon = (
datetime.datetime.now(tz=datetime.timezone.utc)
.astimezone()
.replace(hour=12, minute=0, second=0, microsecond=0)
)
local_noon_ms = int(local_noon.timestamp() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", local_noon_ms, local_noon_ms + 3_600_000),
)
conn.commit()
conn.close()
assert locker._is_workout_finish_recent(db_file) is True
def test_old_workout_returns_false(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns False for workout that finished >24 hours ago."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
# Finished 25 hours ago (not "today" in local time either)
now_ms = int(time.time() * 1000)
old_finish = now_ms - 25 * 3600 * 1000 # beyond 24h window
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", old_finish - 3600000, old_finish),
)
conn.commit()
conn.close()
assert locker._is_workout_finish_recent(db_file) is False
def test_no_workouts_returns_false(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns False when no workouts exist."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
conn.commit()
conn.close()
assert locker._is_workout_finish_recent(db_file) is False
def test_invalid_db_returns_false(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns False for invalid database file."""
locker = create_locker(mock_tk, tmp_path)
bad_file = tmp_path / "bad.db"
bad_file.write_text("not a db")
assert locker._is_workout_finish_recent(bad_file) is False

View File

@ -10,7 +10,7 @@ from unittest.mock import MagicMock, patch
import pytest
from python_pkg.screen_locker.screen_lock import ScreenLocker, _assert_not_under_pytest
from python_pkg.screen_locker.screen_lock import _assert_not_under_pytest
from python_pkg.screen_locker.tests.conftest import create_locker
if TYPE_CHECKING:
@ -340,227 +340,3 @@ class TestSaveWorkoutLog:
):
# Should not raise, just log warning
locker.save_workout_log()
class TestRun:
"""Tests for run method."""
def test_run_starts_mainloop(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test run starts the tkinter mainloop."""
locker = create_locker(mock_tk, tmp_path)
locker.run()
locker.root.mainloop.assert_called_once()
class TestAutoUpgradeSickDay:
"""Tests for sick_day → phone_verified silent upgrade helpers."""
def test_upgrade_succeeds_when_phone_verified(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Verified phone workout overwrites today's sick_day entry."""
log_file = tmp_path / "workout_log.json"
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
with (
patch.object(
locker,
"_verify_phone_workout",
return_value=("verified", "Workout verified! (1 session)"),
),
patch.object(
locker,
"_adjust_shutdown_time_later",
return_value=True,
) as mock_adjust,
patch(
"python_pkg.screen_locker.screen_lock.compute_entry_hmac",
return_value="sig",
),
):
assert locker._try_auto_upgrade_sick_day() is True
mock_adjust.assert_called_once()
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
with log_file.open() as f:
data: dict[str, Any] = json.load(f)
assert data[today]["workout_data"]["type"] == "phone_verified"
assert data[today]["workout_data"]["after_sick_day"] == "true"
def test_upgrade_skipped_when_not_verified(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Non-verified statuses leave the sick_day entry untouched."""
locker = create_locker(mock_tk, tmp_path)
with patch.object(
locker,
"_verify_phone_workout",
return_value=("no_phone", "No phone connected"),
):
assert locker._try_auto_upgrade_sick_day() is False
assert locker.workout_data == {}
def test_upgrade_skipped_on_exception(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Transient OSError/RuntimeError during check is non-fatal."""
locker = create_locker(mock_tk, tmp_path)
with patch.object(
locker,
"_verify_phone_workout",
side_effect=OSError("transient"),
):
assert locker._try_auto_upgrade_sick_day() is False
def test_init_exits_when_sick_day_upgrade_succeeds(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Startup exits 0 after a successful silent sick_day upgrade."""
mock_sys_exit.side_effect = SystemExit(0)
with (
patch.object(
ScreenLocker,
"_try_auto_upgrade_sick_day",
return_value=True,
) as mock_upgrade,
pytest.raises(SystemExit),
):
create_locker(mock_tk, tmp_path, is_sick_day_log=True)
mock_upgrade.assert_called_once()
mock_sys_exit.assert_called_once_with(0)
class TestMainEntry:
"""Tests for main entry point."""
def test_main_demo_mode_default(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test main defaults to demo mode."""
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
assert locker.demo_mode is True
def test_main_production_mode_flag(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test main with --production flag."""
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
assert locker.demo_mode is False
class TestAdjustShutdownTimeLater:
"""Tests for _adjust_shutdown_time_later method."""
def test_adjust_shutdown_time_later_success(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later adds hours successfully."""
locker = create_locker(mock_tk, tmp_path)
object.__setattr__(
locker, "_read_shutdown_config", MagicMock(return_value=(21, 22, 8))
)
object.__setattr__(
locker, "_write_shutdown_config", MagicMock(return_value=True)
)
result = locker._adjust_shutdown_time_later()
assert result is True
locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True)
def test_adjust_shutdown_time_later_caps_at_23(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later caps hours at 23."""
locker = create_locker(mock_tk, tmp_path)
object.__setattr__(
locker, "_read_shutdown_config", MagicMock(return_value=(22, 23, 8))
)
object.__setattr__(
locker, "_write_shutdown_config", MagicMock(return_value=True)
)
result = locker._adjust_shutdown_time_later()
assert result is True
# 22+2=24 capped to 23, 23+2=25 capped to 23
locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True)
def test_adjust_shutdown_time_later_no_config(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later returns False if config missing."""
locker = create_locker(mock_tk, tmp_path)
object.__setattr__(
locker, "_read_shutdown_config", MagicMock(return_value=None)
)
result = locker._adjust_shutdown_time_later()
assert result is False
def test_adjust_shutdown_time_later_oserror(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later handles OSError."""
locker = create_locker(mock_tk, tmp_path)
object.__setattr__(
locker,
"_read_shutdown_config",
MagicMock(side_effect=OSError("permission denied")),
)
result = locker._adjust_shutdown_time_later()
assert result is False
class TestGrabInput:
"""Tests for _grab_input method."""
def test_production_global_grab_tcl_error(
self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path
) -> None:
"""Test production mode falls back when global grab fails."""
mock_tk.Tk.return_value.grab_set_global.side_effect = tk.TclError("grab failed")
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
assert locker.demo_mode is False

View File

@ -0,0 +1,241 @@
"""Tests for screen_locker initialization, logging, and basic operations."""
from __future__ import annotations
from datetime import datetime, timezone
import json
import tkinter as tk
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, patch
import pytest
from python_pkg.screen_locker.screen_lock import ScreenLocker
from python_pkg.screen_locker.tests.conftest import create_locker
if TYPE_CHECKING:
from pathlib import Path
class TestRun:
"""Tests for run method."""
def test_run_starts_mainloop(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test run starts the tkinter mainloop."""
locker = create_locker(mock_tk, tmp_path)
locker.run()
locker.root.mainloop.assert_called_once()
class TestAutoUpgradeSickDay:
"""Tests for sick_day → phone_verified silent upgrade helpers."""
def test_upgrade_succeeds_when_phone_verified(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Verified phone workout overwrites today's sick_day entry."""
log_file = tmp_path / "workout_log.json"
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
with (
patch.object(
locker,
"_verify_phone_workout",
return_value=("verified", "Workout verified! (1 session)"),
),
patch.object(
locker,
"_adjust_shutdown_time_later",
return_value=True,
) as mock_adjust,
patch(
"python_pkg.screen_locker.screen_lock.compute_entry_hmac",
return_value="sig",
),
):
assert locker._try_auto_upgrade_sick_day() is True
mock_adjust.assert_called_once()
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
with log_file.open() as f:
data: dict[str, Any] = json.load(f)
assert data[today]["workout_data"]["type"] == "phone_verified"
assert data[today]["workout_data"]["after_sick_day"] == "true"
def test_upgrade_skipped_when_not_verified(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Non-verified statuses leave the sick_day entry untouched."""
locker = create_locker(mock_tk, tmp_path)
with patch.object(
locker,
"_verify_phone_workout",
return_value=("no_phone", "No phone connected"),
):
assert locker._try_auto_upgrade_sick_day() is False
assert locker.workout_data == {}
def test_upgrade_skipped_on_exception(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Transient OSError/RuntimeError during check is non-fatal."""
locker = create_locker(mock_tk, tmp_path)
with patch.object(
locker,
"_verify_phone_workout",
side_effect=OSError("transient"),
):
assert locker._try_auto_upgrade_sick_day() is False
def test_init_exits_when_sick_day_upgrade_succeeds(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Startup exits 0 after a successful silent sick_day upgrade."""
mock_sys_exit.side_effect = SystemExit(0)
with (
patch.object(
ScreenLocker,
"_try_auto_upgrade_sick_day",
return_value=True,
) as mock_upgrade,
pytest.raises(SystemExit),
):
create_locker(mock_tk, tmp_path, is_sick_day_log=True)
mock_upgrade.assert_called_once()
mock_sys_exit.assert_called_once_with(0)
class TestMainEntry:
"""Tests for main entry point."""
def test_main_demo_mode_default(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test main defaults to demo mode."""
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
assert locker.demo_mode is True
def test_main_production_mode_flag(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test main with --production flag."""
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
assert locker.demo_mode is False
class TestAdjustShutdownTimeLater:
"""Tests for _adjust_shutdown_time_later method."""
def test_adjust_shutdown_time_later_success(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later adds hours successfully."""
locker = create_locker(mock_tk, tmp_path)
object.__setattr__(
locker, "_read_shutdown_config", MagicMock(return_value=(21, 22, 8))
)
object.__setattr__(
locker, "_write_shutdown_config", MagicMock(return_value=True)
)
result = locker._adjust_shutdown_time_later()
assert result is True
locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True)
def test_adjust_shutdown_time_later_caps_at_23(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later caps hours at 23."""
locker = create_locker(mock_tk, tmp_path)
object.__setattr__(
locker, "_read_shutdown_config", MagicMock(return_value=(22, 23, 8))
)
object.__setattr__(
locker, "_write_shutdown_config", MagicMock(return_value=True)
)
result = locker._adjust_shutdown_time_later()
assert result is True
# 22+2=24 capped to 23, 23+2=25 capped to 23
locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True)
def test_adjust_shutdown_time_later_no_config(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later returns False if config missing."""
locker = create_locker(mock_tk, tmp_path)
object.__setattr__(
locker, "_read_shutdown_config", MagicMock(return_value=None)
)
result = locker._adjust_shutdown_time_later()
assert result is False
def test_adjust_shutdown_time_later_oserror(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later handles OSError."""
locker = create_locker(mock_tk, tmp_path)
object.__setattr__(
locker,
"_read_shutdown_config",
MagicMock(side_effect=OSError("permission denied")),
)
result = locker._adjust_shutdown_time_later()
assert result is False
class TestGrabInput:
"""Tests for _grab_input method."""
def test_production_global_grab_tcl_error(
self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path
) -> None:
"""Test production mode falls back when global grab fails."""
mock_tk.Tk.return_value.grab_set_global.side_effect = tk.TclError("grab failed")
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
assert locker.demo_mode is False

View File

@ -6,11 +6,6 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from python_pkg.screen_locker._constants import NO_PHONE_EXTRA_LOCKOUT_SECONDS
from python_pkg.screen_locker.screen_lock import (
PHONE_PENALTY_DELAY_DEMO,
PHONE_PENALTY_DELAY_PRODUCTION,
)
from python_pkg.screen_locker.tests.conftest import create_locker
if TYPE_CHECKING:
@ -491,166 +486,3 @@ class TestStartPhoneCheck:
locker._handle_startup_phone_result.assert_called_once_with(
"no_phone", "No phone"
)
class TestShowPhonePenalty:
"""Tests for _show_phone_penalty and _update_phone_penalty methods."""
def test_show_phone_penalty_demo_delay(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test demo mode uses short penalty delay."""
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
object.__setattr__(locker, "clear_container", MagicMock())
locker._show_phone_penalty("test message")
# _update_phone_penalty is called once, decrementing by 1
assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_DEMO - 1
def test_show_phone_penalty_production_delay(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test production mode uses long penalty delay (base + no-phone bump)."""
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
object.__setattr__(locker, "clear_container", MagicMock())
locker._show_phone_penalty("test message")
expected = PHONE_PENALTY_DELAY_PRODUCTION + NO_PHONE_EXTRA_LOCKOUT_SECONDS - 1
assert locker.phone_penalty_remaining == expected
def test_update_phone_penalty_countdown(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test phone penalty countdown decrements."""
locker = create_locker(mock_tk, tmp_path)
locker.phone_penalty_remaining = 5
locker.phone_penalty_label = MagicMock()
locker._update_phone_penalty()
assert locker.phone_penalty_remaining == 4
locker.phone_penalty_label.config.assert_called_once_with(text="5")
locker.root.after.assert_called()
def test_update_phone_penalty_at_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test phone penalty calls done function when timer reaches zero."""
locker = create_locker(mock_tk, tmp_path)
locker.phone_penalty_remaining = 0
locker.phone_penalty_label = MagicMock()
mock_done = MagicMock()
locker._phone_penalty_done_fn = mock_done
locker._update_phone_penalty()
mock_done.assert_called_once()
def test_show_phone_penalty_default_callback_shows_retry(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test default phone penalty callback shows retry+sick screen."""
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
object.__setattr__(locker, "clear_container", MagicMock())
object.__setattr__(locker, "_show_retry_and_sick", MagicMock())
locker._show_phone_penalty("No phone connected")
# Simulate timer reaching zero by calling the done function
locker._phone_penalty_done_fn()
locker._show_retry_and_sick.assert_called_once_with("No phone connected")
class TestUnlockScreenShutdownAdjustment:
"""Tests for unlock_screen shutdown time adjustment."""
def test_unlock_screen_adjusts_for_phone_verified(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen adjusts shutdown for phone-verified workout."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "phone_verified"}
object.__setattr__(
locker, "_adjust_shutdown_time_later", MagicMock(return_value=True)
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_called_once()
def test_unlock_screen_skips_adjustment_for_sick_day(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen does not adjust for sick day."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "sick_day"}
object.__setattr__(
locker, "_adjust_shutdown_time_later", MagicMock(return_value=True)
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_not_called()
def test_unlock_screen_skips_adjustment_no_type(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen does not adjust when no workout type."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {}
object.__setattr__(
locker, "_adjust_shutdown_time_later", MagicMock(return_value=True)
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_not_called()
def test_unlock_screen_handles_adjustment_failure(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen continues when adjustment fails."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "phone_verified"}
object.__setattr__(
locker, "_adjust_shutdown_time_later", MagicMock(return_value=False)
)
# Should not raise, should continue with unlock
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_called_once()
locker.root.after.assert_called()

View File

@ -0,0 +1,180 @@
"""Tests for phone workout verification, phone check, and unlock operations."""
# pylint: disable=protected-access,unused-argument
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
from python_pkg.screen_locker._constants import NO_PHONE_EXTRA_LOCKOUT_SECONDS
from python_pkg.screen_locker.screen_lock import (
PHONE_PENALTY_DELAY_DEMO,
PHONE_PENALTY_DELAY_PRODUCTION,
)
from python_pkg.screen_locker.tests.conftest import create_locker
if TYPE_CHECKING:
from pathlib import Path
class TestShowPhonePenalty:
"""Tests for _show_phone_penalty and _update_phone_penalty methods."""
def test_show_phone_penalty_demo_delay(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test demo mode uses short penalty delay."""
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
object.__setattr__(locker, "clear_container", MagicMock())
locker._show_phone_penalty("test message")
# _update_phone_penalty is called once, decrementing by 1
assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_DEMO - 1
def test_show_phone_penalty_production_delay(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test production mode uses long penalty delay (base + no-phone bump)."""
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
object.__setattr__(locker, "clear_container", MagicMock())
locker._show_phone_penalty("test message")
expected = PHONE_PENALTY_DELAY_PRODUCTION + NO_PHONE_EXTRA_LOCKOUT_SECONDS - 1
assert locker.phone_penalty_remaining == expected
def test_update_phone_penalty_countdown(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test phone penalty countdown decrements."""
locker = create_locker(mock_tk, tmp_path)
locker.phone_penalty_remaining = 5
locker.phone_penalty_label = MagicMock()
locker._update_phone_penalty()
assert locker.phone_penalty_remaining == 4
locker.phone_penalty_label.config.assert_called_once_with(text="5")
locker.root.after.assert_called()
def test_update_phone_penalty_at_zero(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test phone penalty calls done function when timer reaches zero."""
locker = create_locker(mock_tk, tmp_path)
locker.phone_penalty_remaining = 0
locker.phone_penalty_label = MagicMock()
mock_done = MagicMock()
locker._phone_penalty_done_fn = mock_done
locker._update_phone_penalty()
mock_done.assert_called_once()
def test_show_phone_penalty_default_callback_shows_retry(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test default phone penalty callback shows retry+sick screen."""
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
object.__setattr__(locker, "clear_container", MagicMock())
object.__setattr__(locker, "_show_retry_and_sick", MagicMock())
locker._show_phone_penalty("No phone connected")
# Simulate timer reaching zero by calling the done function
locker._phone_penalty_done_fn()
locker._show_retry_and_sick.assert_called_once_with("No phone connected")
class TestUnlockScreenShutdownAdjustment:
"""Tests for unlock_screen shutdown time adjustment."""
def test_unlock_screen_adjusts_for_phone_verified(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen adjusts shutdown for phone-verified workout."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "phone_verified"}
object.__setattr__(
locker, "_adjust_shutdown_time_later", MagicMock(return_value=True)
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_called_once()
def test_unlock_screen_skips_adjustment_for_sick_day(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen does not adjust for sick day."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "sick_day"}
object.__setattr__(
locker, "_adjust_shutdown_time_later", MagicMock(return_value=True)
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_not_called()
def test_unlock_screen_skips_adjustment_no_type(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen does not adjust when no workout type."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {}
object.__setattr__(
locker, "_adjust_shutdown_time_later", MagicMock(return_value=True)
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_not_called()
def test_unlock_screen_handles_adjustment_failure(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen continues when adjustment fails."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "phone_verified"}
object.__setattr__(
locker, "_adjust_shutdown_time_later", MagicMock(return_value=False)
)
# Should not raise, should continue with unlock
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_called_once()
locker.root.after.assert_called()

View File

@ -109,7 +109,7 @@ class TestVTSwitching:
) -> None:
"""No crash and no subprocess call when setxkbmap is not installed."""
with patch(
"python_pkg.screen_locker.screen_lock.shutil.which",
"python_pkg.screen_locker._window_setup.shutil.which",
return_value=None,
):
create_locker(mock_tk, tmp_path, demo_mode=False)
@ -128,7 +128,7 @@ class TestVTSwitching:
mock_subprocess_run.reset_mock()
with patch(
"python_pkg.screen_locker.screen_lock.shutil.which",
"python_pkg.screen_locker._window_setup.shutil.which",
return_value=None,
):
locker.close()

View File

@ -5,6 +5,10 @@ from __future__ import annotations
import logging
from python_pkg.steam_backlog_enforcer._enforce_loop import get_all_owned_app_ids
from python_pkg.steam_backlog_enforcer._scanning_confidence import (
_confidence_fail_reasons,
_refresh_candidate_confidence,
)
from python_pkg.steam_backlog_enforcer.config import Config, State, load_snapshot
from python_pkg.steam_backlog_enforcer.enforcer import (
enforce_allowed_game,
@ -26,9 +30,7 @@ from python_pkg.steam_backlog_enforcer.hltb import (
)
from python_pkg.steam_backlog_enforcer.library_hider import hide_other_games
from python_pkg.steam_backlog_enforcer.scanning import (
_confidence_fail_reasons,
_pick_next_shortest_candidate,
_refresh_candidate_confidence,
pick_next_game,
)
from python_pkg.steam_backlog_enforcer.steam_api import GameInfo, SteamAPIClient

View File

@ -0,0 +1,471 @@
"""Internal HLTB search helpers: URL discovery, auth, matching, and batch fetch."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from http import HTTPStatus
import json
import logging
import re
import time
from typing import Any
import aiohttp
from howlongtobeatpy.HTMLRequests import HTMLRequests
from python_pkg.steam_backlog_enforcer._hltb_detail import (
_fetch_leisure_times,
)
from python_pkg.steam_backlog_enforcer._hltb_types import (
_SAVE_INTERVAL,
_SUBSET_SUFFIXES,
MAX_CONCURRENT,
MIN_SIMILARITY,
HLTBResult,
ProgressCb,
_AuthInfo,
save_hltb_cache,
)
logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────────────────────
# HLTB API setup (done once, not per-request like the library)
# ──────────────────────────────────────────────────────────────
def _get_hltb_search_url() -> str:
"""Discover the current HLTB search API endpoint.
Scrapes the homepage for JS bundles containing the fetch URL.
Falls back to ``/api/finder`` if extraction fails.
"""
try:
search_info = HTMLRequests.send_website_request_getcode(
parse_all_scripts=False,
)
if search_info is None:
search_info = HTMLRequests.send_website_request_getcode(
parse_all_scripts=True,
)
if search_info and search_info.search_url:
url: str = HTMLRequests.BASE_URL + search_info.search_url
return url
except (OSError, RuntimeError, ValueError, TypeError):
logger.debug("Failed to discover HLTB search URL, using default")
return "https://howlongtobeat.com/api/finder"
async def _get_auth_info(
search_url: str,
session: aiohttp.ClientSession,
) -> _AuthInfo | None:
"""Fetch the HLTB auth token and honeypot key/val (one GET request)."""
init_url = search_url + "/init"
ts = int(time.time() * 1000)
headers = {
"User-Agent": (
"Mozilla/5.0 (X11; Linux x86_64; rv:136.0) Gecko/20100101 Firefox/136.0"
),
"referer": "https://howlongtobeat.com/",
}
try:
async with session.get(
init_url,
params={"t": ts},
headers=headers,
) as resp:
if resp.status == HTTPStatus.OK:
data = await resp.json()
token: str | None = data.get("token")
if token is None:
return None
return _AuthInfo(
token=token,
hp_key=data.get("hpKey", ""),
hp_val=data.get("hpVal", ""),
)
except (aiohttp.ClientError, asyncio.TimeoutError):
logger.warning("Failed to get HLTB auth token")
return None
def _similarity(a: str, b: str) -> float:
"""Case-insensitive SequenceMatcher ratio between two strings."""
return SequenceMatcher(None, a.lower(), b.lower()).ratio()
def _build_search_payload(game_name: str, auth: _AuthInfo | None = None) -> str:
"""Build the JSON POST body for an HLTB search."""
payload: dict[str, Any] = {
"searchType": "games",
"searchTerms": game_name.split(),
"searchPage": 1,
"size": 20,
"searchOptions": {
"games": {
"userId": 0,
"platform": "",
"sortCategory": "popular",
"rangeCategory": "main",
"rangeTime": {"min": 0, "max": 0},
"gameplay": {
"perspective": "",
"flow": "",
"genre": "",
"difficulty": "",
},
"rangeYear": {"max": "", "min": ""},
"modifier": "",
},
"users": {"sortCategory": "postcount"},
"lists": {"sortCategory": "follows"},
"filter": "",
"sort": 0,
"randomizer": 0,
},
"useCache": True,
}
if auth and auth.hp_key:
payload[auth.hp_key] = auth.hp_val
return json.dumps(payload)
def _build_search_variants(game_name: str) -> list[str]:
"""Return fallback search terms for one Steam game title."""
base = game_name.strip()
variants = [base]
no_year = re.sub(r"\s*\(\d{4}\)$", "", base).strip()
if no_year and no_year != base:
variants.append(no_year)
return variants
def _collect_candidates(
query_name: str,
data: dict[str, Any],
) -> list[tuple[dict[str, Any], float]]:
"""Build candidate list from one HLTB response payload."""
candidates: list[tuple[dict[str, Any], float]] = []
lower_name = query_name.lower()
for entry in data.get("data", []):
entry_name = entry.get("game_name", "")
entry_alias = entry.get("game_alias", "") or ""
is_dlc = str(entry.get("game_type", "")).lower() == "dlc"
sim = max(
_similarity(query_name, entry_name),
_similarity(query_name, entry_alias),
)
is_full_edition = (
(not is_dlc) and entry_name.lower().startswith(lower_name + ":")
) or ((not is_dlc) and entry_name.lower().startswith(lower_name + " -"))
if sim >= MIN_SIMILARITY or is_full_edition:
comp_100 = entry.get("comp_100", 0)
if comp_100 and comp_100 > 0:
candidates.append((entry, sim))
return candidates
def _build_result_from_best(
app_id: int,
original_name: str,
query_name: str,
best: tuple[dict[str, Any], float],
) -> HLTBResult:
"""Convert selected HLTB entry into HLTBResult."""
entry, sim = best
hours = round(entry["comp_100"] / 3600, 2)
logger.debug(
("HLTB match for '%s' via '%s': '%s' (id=%s, comp_100=%s, sim=%.3f)"),
original_name,
query_name,
entry.get("game_name"),
entry.get("game_id"),
entry.get("comp_100"),
sim,
)
return HLTBResult(
app_id=app_id,
game_name=original_name,
completionist_hours=hours,
similarity=sim,
hltb_game_id=entry.get("game_id", 0),
comp_100_count=int(entry.get("comp_100_count", 0) or 0),
count_comp=int(entry.get("count_comp", 0) or 0),
)
def _pick_best_hltb_entry(
search_name: str,
candidates: list[tuple[dict[str, Any], float]],
) -> tuple[dict[str, Any], float] | None:
"""Pick the best HLTB entry, preferring full editions over demos/chapters.
When a short name like "FAITH" matches both "FAITH" (demo) and
"FAITH: The Unholy Trinity" (full game), prefer the full game
since Steam often lists the full game under the shorter name.
When an exact match like "Timberman" (26 h) competes against an
unrelated subtitle entry like "Timberman: The Big Adventure" (2 h),
the exact match wins because it has more hours.
"""
if not candidates:
return None
# Prefer base games over DLC entries when both are present.
non_dlc = [c for c in candidates if str(c[0].get("game_type", "")).lower() != "dlc"]
usable = non_dlc or candidates
if len(usable) == 1:
return usable[0]
lower = search_name.lower()
best_exact = _find_exact_match(usable, lower)
best_extended = _find_best_extended(usable, lower)
return _resolve_exact_vs_extended(best_exact, best_extended, usable)
def _find_exact_match(
usable: list[tuple[dict[str, Any], float]],
lower: str,
) -> tuple[dict[str, Any], float] | None:
"""Find best exact name/alias match (highest comp_100)."""
return next(
(
(e, s)
for e, s in sorted(
usable,
key=lambda x: x[0].get("comp_100", 0),
reverse=True,
)
if (e.get("game_name") or "").lower() == lower
or (e.get("game_alias") or "").lower() == lower
),
None,
)
def _find_best_extended(
usable: list[tuple[dict[str, Any], float]],
lower: str,
) -> tuple[dict[str, Any], float] | None:
"""Find best extended entry ("Name: Subtitle" / "Name - Subtitle").
Skips subset entries (prologue, demo, etc.).
"""
best: tuple[dict[str, Any], float] | None = None
for entry, sim in usable:
game_type = str(entry.get("game_type", "")).lower()
if game_type not in ("", "game"):
continue
entry_name = (entry.get("game_name") or "").lower()
if entry_name.startswith((lower + ":", lower + " -")):
suffix = entry_name[len(lower) :].lstrip(" :-")
if not any(suffix.startswith(kw) for kw in _SUBSET_SUFFIXES) and (
best is None or entry.get("comp_100", 0) > best[0].get("comp_100", 0)
):
best = (entry, sim)
return best
def _resolve_exact_vs_extended(
best_exact: tuple[dict[str, Any], float] | None,
best_extended: tuple[dict[str, Any], float] | None,
usable: list[tuple[dict[str, Any], float]],
) -> tuple[dict[str, Any], float]:
"""Decide between exact match, extended entry, or highest similarity."""
if best_exact is not None and best_extended is not None:
exact_hours = best_exact[0].get("comp_100", 0)
extended_hours = best_extended[0].get("comp_100", 0)
exact_confidence = int(best_exact[0].get("comp_100_count", 0) or 0) + int(
best_exact[0].get("count_comp", 0) or 0
)
extended_confidence = int(best_extended[0].get("comp_100_count", 0) or 0) + int(
best_extended[0].get("count_comp", 0) or 0
)
# Prefer the extended entry only when it has strictly more hours
# than the exact match AND at least as much confidence.
# This lets "FAITH: The Unholy Trinity" (full game) beat
# a low-confidence exact demo while preventing low-confidence
# mods like "Celeste - Strawberry Jam" from beating
# the exact base game.
if extended_hours > exact_hours and extended_confidence >= exact_confidence:
return best_extended
return best_exact
if best_exact is not None:
return best_exact
if best_extended is not None:
return best_extended
# Fall back to highest similarity.
return max(usable, key=lambda x: x[1])
# ──────────────────────────────────────────────────────────────
# Async fetching with shared session & progress
# ──────────────────────────────────────────────────────────────
@dataclass
class _SearchCtx:
"""Shared context for HLTB search requests."""
session: aiohttp.ClientSession
search_url: str
headers: dict[str, str]
cache: dict[int, float]
polls: dict[int, int] = field(default_factory=dict)
count_comp: dict[int, int] = field(default_factory=dict)
auth: _AuthInfo | None = None
counter: dict[str, int] = field(default_factory=dict)
total: int = 0
progress_cb: ProgressCb | None = None
async def _search_one(
sem: asyncio.Semaphore,
ctx: _SearchCtx,
app_id: int,
name: str,
) -> HLTBResult | None:
"""Search HLTB for one game via direct POST, update cache."""
async with sem:
result: HLTBResult | None = None
for query_name in _build_search_variants(name):
payload = _build_search_payload(query_name, ctx.auth)
try:
async with ctx.session.post(
ctx.search_url,
headers=ctx.headers,
data=payload,
) as resp:
if resp.status != HTTPStatus.OK:
continue
data = await resp.json()
candidates = _collect_candidates(query_name, data)
best = _pick_best_hltb_entry(query_name, candidates)
if best is None:
continue
result = _build_result_from_best(app_id, name, query_name, best)
break
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
logger.debug("HLTB search failed for '%s': %s", query_name, exc)
# Update cache immediately (miss = -1).
if result is not None:
ctx.cache[app_id] = result.completionist_hours
ctx.polls[app_id] = result.comp_100_count
ctx.count_comp[app_id] = result.count_comp
ctx.counter["found"] += 1
else:
ctx.cache[app_id] = -1
ctx.polls[app_id] = 0
ctx.count_comp[app_id] = 0
ctx.counter["done"] += 1
done = ctx.counter["done"]
# Incremental save every _SAVE_INTERVAL lookups.
if not done % _SAVE_INTERVAL:
save_hltb_cache(ctx.cache, ctx.polls, ctx.count_comp)
# Report progress.
if ctx.progress_cb is not None:
ctx.progress_cb(done, ctx.total, ctx.counter["found"], name)
return result
async def _fetch_batch(
games: list[tuple[int, str]],
cache: dict[int, float],
polls: dict[int, int],
progress_cb: ProgressCb | None,
count_comp: dict[int, int] | None = None,
) -> list[HLTBResult]:
"""Fetch HLTB data for a batch of games using one shared session."""
# 1. Discover the search URL (sync, one-time).
search_url = _get_hltb_search_url()
logger.info("HLTB search URL: %s", search_url)
timeout = aiohttp.ClientTimeout(total=20, sock_read=15)
# 2. Get auth info (separate session — avoids reuse issues).
async with aiohttp.ClientSession(timeout=timeout) as init_session:
auth = await _get_auth_info(search_url, init_session)
if auth is None:
logger.warning("Could not get HLTB auth info, aborting fetch.")
return []
logger.info("HLTB auth token acquired.")
# 3. Build shared headers for all search requests.
headers: dict[str, str] = {
"content-type": "application/json",
"accept": "*/*",
"User-Agent": (
"Mozilla/5.0 (X11; Linux x86_64; rv:136.0) Gecko/20100101 Firefox/136.0"
),
"referer": "https://howlongtobeat.com/",
"x-auth-token": auth.token,
}
if auth.hp_key:
headers["x-hp-key"] = auth.hp_key
headers["x-hp-val"] = auth.hp_val
# 4. Fire all searches through a single persistent session.
sem = asyncio.Semaphore(MAX_CONCURRENT)
counter = {"done": 0, "found": 0}
total = len(games)
if count_comp is None:
count_comp = {}
connector = aiohttp.TCPConnector(
limit=MAX_CONCURRENT,
keepalive_timeout=30,
)
async with aiohttp.ClientSession(
timeout=timeout,
connector=connector,
) as session:
ctx = _SearchCtx(
session=session,
search_url=search_url,
headers=headers,
cache=cache,
polls=polls,
count_comp=count_comp,
auth=auth,
counter=counter,
total=total,
progress_cb=progress_cb,
)
tasks = [
_search_one(
sem,
ctx,
app_id,
name,
)
for app_id, name in games
]
results = await asyncio.gather(*tasks)
search_results = [r for r in results if r is not None]
# 5. Fetch leisure times + DLC from game detail pages.
logger.info(
"Fetching leisure times for %d games from detail pages...",
len(search_results),
)
await _fetch_leisure_times(
search_results,
cache,
polls,
progress_cb=None,
count_comp=count_comp,
)
return search_results

View File

@ -0,0 +1,249 @@
"""Confidence-checking and candidate-filtering helpers for scanning."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from python_pkg.steam_backlog_enforcer._hltb_types import (
load_hltb_cache,
load_hltb_count_comp_cache,
load_hltb_polls_cache,
save_hltb_cache,
)
from python_pkg.steam_backlog_enforcer.game_install import _echo
from python_pkg.steam_backlog_enforcer.hltb import fetch_hltb_confidence_cached
if TYPE_CHECKING:
from python_pkg.steam_backlog_enforcer.config import State
from python_pkg.steam_backlog_enforcer.steam_api import GameInfo
logger = logging.getLogger(__name__)
_MIN_COMP_100_POLLS = 3
_MIN_COUNT_COMP = 15
_MIN_CONFIDENCE_SUM = 18
def _apply_cached_confidence_to_candidates(candidates: list[GameInfo]) -> None:
"""Overlay cached confidence counters onto candidate game objects."""
polls_cache = load_hltb_polls_cache()
count_comp_cache = load_hltb_count_comp_cache()
for game in candidates:
if game.app_id in polls_cache:
game.comp_100_count = polls_cache[game.app_id]
if game.app_id in count_comp_cache:
game.count_comp = count_comp_cache[game.app_id]
def _confidence_fail_reasons(game: GameInfo) -> list[str]:
"""Return threshold-failure reasons for a game's HLTB confidence data."""
reasons: list[str] = []
if game.comp_100_count < _MIN_COMP_100_POLLS:
reasons.append(f"comp_100 polls {game.comp_100_count} < {_MIN_COMP_100_POLLS}")
if game.count_comp < _MIN_COUNT_COMP:
reasons.append(f"count_comp {game.count_comp} < {_MIN_COUNT_COMP}")
total = game.comp_100_count + game.count_comp
if total < _MIN_CONFIDENCE_SUM:
reasons.append(f"comp_100+count_comp {total} < {_MIN_CONFIDENCE_SUM}")
return reasons
def _refresh_candidate_confidence(game: GameInfo) -> None:
"""Refresh confidence metrics for one candidate when cache looks stale.
Only refreshes when both metrics are missing (0), which typically means
the game was cached before confidence fields were added.
"""
if game.comp_100_count > 0 or game.count_comp > 0:
return
_refresh_candidate_confidence_batch([game])
def _force_refresh_candidate_confidence(game: GameInfo) -> None:
"""Force-refresh one candidate's confidence metrics from HLTB."""
_refresh_candidate_confidence_batch([game], force=True)
def _refresh_candidate_confidence_batch(
candidates: list[GameInfo],
*,
force: bool = False,
) -> None:
"""Refresh missing confidence metrics for candidates in one HLTB batch.
This prevents O(N) one-game API loops when many snapshot entries predate
confidence fields and therefore have ``comp_100_count==0`` and
``count_comp==0``.
"""
missing = [
game
for game in candidates
if force or (game.comp_100_count == 0 and game.count_comp == 0)
]
if not missing:
return
refresh_slice = missing
if len(refresh_slice) == 1:
game = refresh_slice[0]
_echo(f" Refreshing HLTB confidence for {game.name} (AppID={game.app_id})...")
else:
_echo(f" Refreshing HLTB confidence for {len(refresh_slice)} candidate(s)...")
cache = load_hltb_cache()
polls = load_hltb_polls_cache()
count_comp = load_hltb_count_comp_cache()
app_ids = [game.app_id for game in refresh_slice]
names = [(game.app_id, game.name) for game in refresh_slice]
prior_hours = {aid: cache.get(aid, -1) for aid in app_ids}
for aid in app_ids:
cache.pop(aid, None)
polls.pop(aid, None)
count_comp.pop(aid, None)
save_hltb_cache(cache, polls, count_comp)
fetch_hltb_confidence_cached(names)
refreshed_hours = load_hltb_cache()
refreshed_polls = load_hltb_polls_cache()
refreshed_count_comp = load_hltb_count_comp_cache()
for aid, old_hours in prior_hours.items():
if old_hours > 0 and refreshed_hours.get(aid, -1) <= 0:
refreshed_hours[aid] = old_hours
save_hltb_cache(refreshed_hours, refreshed_polls, refreshed_count_comp)
for game in refresh_slice:
game.comp_100_count = refreshed_polls.get(game.app_id, 0)
game.count_comp = refreshed_count_comp.get(game.app_id, 0)
def _filter_hltb_confident_candidates(
candidates: list[GameInfo],
) -> list[GameInfo]:
"""Keep only candidates that satisfy HLTB confidence thresholds."""
_refresh_candidate_confidence_batch(candidates)
kept: list[GameInfo] = []
for game in candidates:
reasons = _confidence_fail_reasons(game)
if reasons:
_echo(
f" Skipping {game.name} (AppID={game.app_id}): "
f"HLTB confidence too low ({'; '.join(reasons)})"
)
continue
kept.append(game)
return kept
def _candidate_passes_hltb_confidence(game: GameInfo) -> bool:
"""Return True if candidate passes confidence with cache-first behavior.
Only refreshes when confidence fields are missing (both zero), which keeps
normal runs cache-friendly and avoids repeated refetches for known
low-confidence entries.
"""
reasons = _confidence_fail_reasons(game)
if not reasons:
return True
# Re-check once when confidence fields are missing in cache.
_refresh_candidate_confidence(game)
reasons = _confidence_fail_reasons(game)
if reasons:
_echo(
f" Skipping {game.name} (AppID={game.app_id}): "
f"HLTB confidence too low ({'; '.join(reasons)})"
)
return False
return True
def _backfill_polls_for_finished(
state: State,
games: list[GameInfo],
) -> dict[int, int]:
"""Lazily fetch poll counts for already-finished games missing them.
Reads the polls cache, identifies finished games whose poll count is
still ``0`` (typically because the cache predates the polls schema),
and triggers a one-shot HLTB search to backfill them. Returns the
refreshed polls cache.
"""
polls_cache = load_hltb_polls_cache()
name_by_id = {g.app_id: g.name for g in games}
missing = [
(aid, name_by_id[aid])
for aid in state.finished_app_ids
if aid in name_by_id and polls_cache.get(aid, 0) == 0
]
if not missing:
return polls_cache
logger.info(
"Backfilling HLTB poll counts for %d already-finished games...",
len(missing),
)
# Force a fresh search by removing the hours entries we want to refetch.
# (fetch_hltb_times_cached skips entries already in the hours cache.)
cache = load_hltb_cache()
preserved_hours = {aid: cache[aid] for aid, _ in missing if aid in cache}
for aid, _name in missing:
cache.pop(aid, None)
save_hltb_cache(cache, polls_cache)
fetch_hltb_confidence_cached(missing)
# Restore any previously-known hours that the refetch may have replaced
# with a worse match (we trust prior leisure+dlc estimates).
refreshed_hours = load_hltb_cache()
refreshed_polls = load_hltb_polls_cache()
for aid, prior_hours in preserved_hours.items():
if prior_hours > 0 and refreshed_hours.get(aid, -1) <= 0:
refreshed_hours[aid] = prior_hours
save_hltb_cache(refreshed_hours, refreshed_polls)
return refreshed_polls
def _report_poll_confidence(
chosen: GameInfo,
games: list[GameInfo],
state: State,
) -> None:
"""Print HLTB poll-count confidence info for the just-assigned game.
Shows the chosen game's ``comp_100_count`` (number of polled
completionist times on HowLongToBeat) and the historical minimum
among the user's previously-finished games. Marks a new historical
low so the user can be skeptical of unreliable estimates.
"""
polls_cache = _backfill_polls_for_finished(state, games)
chosen_polls = polls_cache.get(chosen.app_id, chosen.comp_100_count)
chosen.comp_100_count = chosen_polls
finished_polls = [
(polls_cache[aid], aid)
for aid in state.finished_app_ids
if polls_cache.get(aid, 0) > 0
]
if not finished_polls:
_echo(f" HLTB confidence: {chosen_polls} polled completionist times")
return
min_polls, min_aid = min(finished_polls)
name_by_id = {g.app_id: g.name for g in games}
min_name = name_by_id.get(min_aid, f"AppID={min_aid}")
warning = ""
if 0 < chosen_polls < min_polls:
warning = " ⚠ NEW LOW — estimate may be unreliable"
elif chosen_polls == 0:
warning = " ⚠ no polls recorded — estimate may be unreliable"
_echo(f" HLTB confidence: {chosen_polls} polled completionist times{warning}")
_echo(f" Historical min among finished: {min_polls} ({min_name})")

View File

@ -13,30 +13,23 @@ Fetches leisure completionist hour estimates from howlongtobeat.com with:
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from http import HTTPStatus
import json
import logging
import re
import time
from typing import Any
import aiohttp
from howlongtobeatpy.HTMLRequests import HTMLRequests
from python_pkg.steam_backlog_enforcer._hltb_detail import (
_fetch_leisure_times,
from python_pkg.steam_backlog_enforcer._hltb_search import (
_fetch_batch,
_get_auth_info,
_get_hltb_search_url,
_search_one,
_SearchCtx,
)
from python_pkg.steam_backlog_enforcer._hltb_types import (
_SAVE_INTERVAL,
_SUBSET_SUFFIXES,
HLTB_BASE_URL,
MAX_CONCURRENT,
MIN_SIMILARITY,
HLTBResult,
ProgressCb,
_AuthInfo,
load_hltb_cache,
load_hltb_count_comp_cache,
load_hltb_polls_cache,
@ -47,444 +40,8 @@ logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────────────────────
# HLTB API setup (done once, not per-request like the library)
# Confidence-only batch fetch (no leisure/DLC detail pages)
# ──────────────────────────────────────────────────────────────
def _get_hltb_search_url() -> str:
"""Discover the current HLTB search API endpoint.
Scrapes the homepage for JS bundles containing the fetch URL.
Falls back to ``/api/finder`` if extraction fails.
"""
try:
search_info = HTMLRequests.send_website_request_getcode(
parse_all_scripts=False,
)
if search_info is None:
search_info = HTMLRequests.send_website_request_getcode(
parse_all_scripts=True,
)
if search_info and search_info.search_url:
url: str = HTMLRequests.BASE_URL + search_info.search_url
return url
except (OSError, RuntimeError, ValueError, TypeError):
logger.debug("Failed to discover HLTB search URL, using default")
return "https://howlongtobeat.com/api/finder"
async def _get_auth_info(
search_url: str,
session: aiohttp.ClientSession,
) -> _AuthInfo | None:
"""Fetch the HLTB auth token and honeypot key/val (one GET request)."""
init_url = search_url + "/init"
ts = int(time.time() * 1000)
headers = {
"User-Agent": (
"Mozilla/5.0 (X11; Linux x86_64; rv:136.0) Gecko/20100101 Firefox/136.0"
),
"referer": "https://howlongtobeat.com/",
}
try:
async with session.get(
init_url,
params={"t": ts},
headers=headers,
) as resp:
if resp.status == HTTPStatus.OK:
data = await resp.json()
token: str | None = data.get("token")
if token is None:
return None
return _AuthInfo(
token=token,
hp_key=data.get("hpKey", ""),
hp_val=data.get("hpVal", ""),
)
except (aiohttp.ClientError, asyncio.TimeoutError):
logger.warning("Failed to get HLTB auth token")
return None
def _similarity(a: str, b: str) -> float:
"""Case-insensitive SequenceMatcher ratio between two strings."""
return SequenceMatcher(None, a.lower(), b.lower()).ratio()
def _build_search_payload(game_name: str, auth: _AuthInfo | None = None) -> str:
"""Build the JSON POST body for an HLTB search."""
payload: dict[str, Any] = {
"searchType": "games",
"searchTerms": game_name.split(),
"searchPage": 1,
"size": 20,
"searchOptions": {
"games": {
"userId": 0,
"platform": "",
"sortCategory": "popular",
"rangeCategory": "main",
"rangeTime": {"min": 0, "max": 0},
"gameplay": {
"perspective": "",
"flow": "",
"genre": "",
"difficulty": "",
},
"rangeYear": {"max": "", "min": ""},
"modifier": "",
},
"users": {"sortCategory": "postcount"},
"lists": {"sortCategory": "follows"},
"filter": "",
"sort": 0,
"randomizer": 0,
},
"useCache": True,
}
if auth and auth.hp_key:
payload[auth.hp_key] = auth.hp_val
return json.dumps(payload)
def _build_search_variants(game_name: str) -> list[str]:
"""Return fallback search terms for one Steam game title."""
base = game_name.strip()
variants = [base]
no_year = re.sub(r"\s*\(\d{4}\)$", "", base).strip()
if no_year and no_year != base:
variants.append(no_year)
return variants
def _collect_candidates(
query_name: str,
data: dict[str, Any],
) -> list[tuple[dict[str, Any], float]]:
"""Build candidate list from one HLTB response payload."""
candidates: list[tuple[dict[str, Any], float]] = []
lower_name = query_name.lower()
for entry in data.get("data", []):
entry_name = entry.get("game_name", "")
entry_alias = entry.get("game_alias", "") or ""
is_dlc = str(entry.get("game_type", "")).lower() == "dlc"
sim = max(
_similarity(query_name, entry_name),
_similarity(query_name, entry_alias),
)
is_full_edition = (
(not is_dlc) and entry_name.lower().startswith(lower_name + ":")
) or ((not is_dlc) and entry_name.lower().startswith(lower_name + " -"))
if sim >= MIN_SIMILARITY or is_full_edition:
comp_100 = entry.get("comp_100", 0)
if comp_100 and comp_100 > 0:
candidates.append((entry, sim))
return candidates
def _build_result_from_best(
app_id: int,
original_name: str,
query_name: str,
best: tuple[dict[str, Any], float],
) -> HLTBResult:
"""Convert selected HLTB entry into HLTBResult."""
entry, sim = best
hours = round(entry["comp_100"] / 3600, 2)
logger.debug(
("HLTB match for '%s' via '%s': '%s' (id=%s, comp_100=%s, sim=%.3f)"),
original_name,
query_name,
entry.get("game_name"),
entry.get("game_id"),
entry.get("comp_100"),
sim,
)
return HLTBResult(
app_id=app_id,
game_name=original_name,
completionist_hours=hours,
similarity=sim,
hltb_game_id=entry.get("game_id", 0),
comp_100_count=int(entry.get("comp_100_count", 0) or 0),
count_comp=int(entry.get("count_comp", 0) or 0),
)
def _pick_best_hltb_entry(
search_name: str,
candidates: list[tuple[dict[str, Any], float]],
) -> tuple[dict[str, Any], float] | None:
"""Pick the best HLTB entry, preferring full editions over demos/chapters.
When a short name like "FAITH" matches both "FAITH" (demo) and
"FAITH: The Unholy Trinity" (full game), prefer the full game
since Steam often lists the full game under the shorter name.
When an exact match like "Timberman" (26 h) competes against an
unrelated subtitle entry like "Timberman: The Big Adventure" (2 h),
the exact match wins because it has more hours.
"""
if not candidates:
return None
# Prefer base games over DLC entries when both are present.
non_dlc = [c for c in candidates if str(c[0].get("game_type", "")).lower() != "dlc"]
usable = non_dlc or candidates
if len(usable) == 1:
return usable[0]
lower = search_name.lower()
best_exact = _find_exact_match(usable, lower)
best_extended = _find_best_extended(usable, lower)
return _resolve_exact_vs_extended(best_exact, best_extended, usable)
def _find_exact_match(
usable: list[tuple[dict[str, Any], float]],
lower: str,
) -> tuple[dict[str, Any], float] | None:
"""Find best exact name/alias match (highest comp_100)."""
return next(
(
(e, s)
for e, s in sorted(
usable,
key=lambda x: x[0].get("comp_100", 0),
reverse=True,
)
if (e.get("game_name") or "").lower() == lower
or (e.get("game_alias") or "").lower() == lower
),
None,
)
def _find_best_extended(
usable: list[tuple[dict[str, Any], float]],
lower: str,
) -> tuple[dict[str, Any], float] | None:
"""Find best extended entry ("Name: Subtitle" / "Name - Subtitle").
Skips subset entries (prologue, demo, etc.).
"""
best: tuple[dict[str, Any], float] | None = None
for entry, sim in usable:
game_type = str(entry.get("game_type", "")).lower()
if game_type not in ("", "game"):
continue
entry_name = (entry.get("game_name") or "").lower()
if entry_name.startswith((lower + ":", lower + " -")):
suffix = entry_name[len(lower) :].lstrip(" :-")
if not any(suffix.startswith(kw) for kw in _SUBSET_SUFFIXES) and (
best is None or entry.get("comp_100", 0) > best[0].get("comp_100", 0)
):
best = (entry, sim)
return best
def _resolve_exact_vs_extended(
best_exact: tuple[dict[str, Any], float] | None,
best_extended: tuple[dict[str, Any], float] | None,
usable: list[tuple[dict[str, Any], float]],
) -> tuple[dict[str, Any], float]:
"""Decide between exact match, extended entry, or highest similarity."""
if best_exact is not None and best_extended is not None:
exact_hours = best_exact[0].get("comp_100", 0)
extended_hours = best_extended[0].get("comp_100", 0)
exact_confidence = int(best_exact[0].get("comp_100_count", 0) or 0) + int(
best_exact[0].get("count_comp", 0) or 0
)
extended_confidence = int(best_extended[0].get("comp_100_count", 0) or 0) + int(
best_extended[0].get("count_comp", 0) or 0
)
# Prefer the extended entry only when it has strictly more hours
# than the exact match AND at least as much confidence.
# This lets "FAITH: The Unholy Trinity" (full game) beat
# a low-confidence exact demo while preventing low-confidence
# mods like "Celeste - Strawberry Jam" from beating
# the exact base game.
if extended_hours > exact_hours and extended_confidence >= exact_confidence:
return best_extended
return best_exact
if best_exact is not None:
return best_exact
if best_extended is not None:
return best_extended
# Fall back to highest similarity.
return max(usable, key=lambda x: x[1])
# ──────────────────────────────────────────────────────────────
# Async fetching with shared session & progress
# ──────────────────────────────────────────────────────────────
@dataclass
class _SearchCtx:
"""Shared context for HLTB search requests."""
session: aiohttp.ClientSession
search_url: str
headers: dict[str, str]
cache: dict[int, float]
polls: dict[int, int] = field(default_factory=dict)
count_comp: dict[int, int] = field(default_factory=dict)
auth: _AuthInfo | None = None
counter: dict[str, int] = field(default_factory=dict)
total: int = 0
progress_cb: ProgressCb | None = None
async def _search_one(
sem: asyncio.Semaphore,
ctx: _SearchCtx,
app_id: int,
name: str,
) -> HLTBResult | None:
"""Search HLTB for one game via direct POST, update cache."""
async with sem:
result: HLTBResult | None = None
for query_name in _build_search_variants(name):
payload = _build_search_payload(query_name, ctx.auth)
try:
async with ctx.session.post(
ctx.search_url,
headers=ctx.headers,
data=payload,
) as resp:
if resp.status != HTTPStatus.OK:
continue
data = await resp.json()
candidates = _collect_candidates(query_name, data)
best = _pick_best_hltb_entry(query_name, candidates)
if best is None:
continue
result = _build_result_from_best(app_id, name, query_name, best)
break
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
logger.debug("HLTB search failed for '%s': %s", query_name, exc)
# Update cache immediately (miss = -1).
if result is not None:
ctx.cache[app_id] = result.completionist_hours
ctx.polls[app_id] = result.comp_100_count
ctx.count_comp[app_id] = result.count_comp
ctx.counter["found"] += 1
else:
ctx.cache[app_id] = -1
ctx.polls[app_id] = 0
ctx.count_comp[app_id] = 0
ctx.counter["done"] += 1
done = ctx.counter["done"]
# Incremental save every _SAVE_INTERVAL lookups.
if not done % _SAVE_INTERVAL:
save_hltb_cache(ctx.cache, ctx.polls, ctx.count_comp)
# Report progress.
if ctx.progress_cb is not None:
ctx.progress_cb(done, ctx.total, ctx.counter["found"], name)
return result
async def _fetch_batch(
games: list[tuple[int, str]],
cache: dict[int, float],
polls: dict[int, int],
progress_cb: ProgressCb | None,
count_comp: dict[int, int] | None = None,
) -> list[HLTBResult]:
"""Fetch HLTB data for a batch of games using one shared session."""
# 1. Discover the search URL (sync, one-time).
search_url = _get_hltb_search_url()
logger.info("HLTB search URL: %s", search_url)
timeout = aiohttp.ClientTimeout(total=20, sock_read=15)
# 2. Get auth info (separate session — avoids reuse issues).
async with aiohttp.ClientSession(timeout=timeout) as init_session:
auth = await _get_auth_info(search_url, init_session)
if auth is None:
logger.warning("Could not get HLTB auth info, aborting fetch.")
return []
logger.info("HLTB auth token acquired.")
# 3. Build shared headers for all search requests.
headers: dict[str, str] = {
"content-type": "application/json",
"accept": "*/*",
"User-Agent": (
"Mozilla/5.0 (X11; Linux x86_64; rv:136.0) Gecko/20100101 Firefox/136.0"
),
"referer": "https://howlongtobeat.com/",
"x-auth-token": auth.token,
}
if auth.hp_key:
headers["x-hp-key"] = auth.hp_key
headers["x-hp-val"] = auth.hp_val
# 4. Fire all searches through a single persistent session.
sem = asyncio.Semaphore(MAX_CONCURRENT)
counter = {"done": 0, "found": 0}
total = len(games)
if count_comp is None:
count_comp = {}
connector = aiohttp.TCPConnector(
limit=MAX_CONCURRENT,
keepalive_timeout=30,
)
async with aiohttp.ClientSession(
timeout=timeout,
connector=connector,
) as session:
ctx = _SearchCtx(
session=session,
search_url=search_url,
headers=headers,
cache=cache,
polls=polls,
count_comp=count_comp,
auth=auth,
counter=counter,
total=total,
progress_cb=progress_cb,
)
tasks = [
_search_one(
sem,
ctx,
app_id,
name,
)
for app_id, name in games
]
results = await asyncio.gather(*tasks)
search_results = [r for r in results if r is not None]
# 5. Fetch leisure times + DLC from game detail pages.
logger.info(
"Fetching leisure times for %d games from detail pages...",
len(search_results),
)
await _fetch_leisure_times(
search_results,
cache,
polls,
progress_cb=None,
count_comp=count_comp,
)
return search_results
async def _fetch_batch_confidence_only(
games: list[tuple[int, str]],
cache: dict[int, float],

View File

@ -12,6 +12,7 @@ from python_pkg.steam_backlog_enforcer._enforce_loop import (
do_enforce,
get_all_owned_app_ids,
)
from python_pkg.steam_backlog_enforcer._hltb_types import load_hltb_cache
from python_pkg.steam_backlog_enforcer._whitelist import (
WHITELIST_COOLDOWN_SECONDS,
add_pending_exception,
@ -40,6 +41,7 @@ from python_pkg.steam_backlog_enforcer.library_hider import (
from python_pkg.steam_backlog_enforcer.scanning import (
do_check,
do_scan,
pick_next_game,
)
from python_pkg.steam_backlog_enforcer.steam_api import GameInfo
from python_pkg.steam_backlog_enforcer.store_blocker import (
@ -355,6 +357,29 @@ def cmd_unhide(config: Config, _state: State) -> None:
_echo("Done!")
def cmd_pick(config: Config, state: State) -> None:
"""Manually pick a new game from the shortest-first candidate list."""
snapshot_data = load_snapshot()
if not snapshot_data:
_echo("No snapshot found. Run 'scan' first.")
return
games = [GameInfo.from_snapshot(d) for d in snapshot_data]
hltb_cache = load_hltb_cache()
for game in games:
if game.app_id in hltb_cache:
game.completionist_hours = hltb_cache[game.app_id]
pick_next_game(games, state, config)
if state.current_app_id is not None:
owned_ids = get_all_owned_app_ids(config)
if owned_ids:
hidden = hide_other_games(owned_ids, state.current_app_id)
if hidden > 0:
_echo(f"\n Library: hid {hidden} games")
COMMANDS: dict[str, tuple[str, Callable[[Config, State], object]]] = {
"scan": ("Scan library & assign a game", do_scan),
"check": ("Check assigned game completion", do_check),
@ -371,6 +396,7 @@ COMMANDS: dict[str, tuple[str, Callable[[Config, State], object]]] = {
"uninstall": ("Uninstall all non-assigned games", cmd_uninstall),
"setup": ("Run first-time setup", cmd_setup),
"done": ("Finish game, open HLTB, pick next", cmd_done),
"pick": ("Manually pick your next game from candidates", cmd_pick),
}
# Extra commands with non-standard arg handling (shown in help but not in COMMANDS).

View File

@ -7,10 +7,13 @@ import time
from typing import Any
from python_pkg.steam_backlog_enforcer._hltb_types import (
load_hltb_cache,
load_hltb_count_comp_cache,
load_hltb_polls_cache,
save_hltb_cache,
)
from python_pkg.steam_backlog_enforcer._scanning_confidence import (
_apply_cached_confidence_to_candidates,
_candidate_passes_hltb_confidence,
_report_poll_confidence,
)
from python_pkg.steam_backlog_enforcer.config import (
Config,
@ -28,7 +31,6 @@ from python_pkg.steam_backlog_enforcer.game_install import (
uninstall_other_games,
)
from python_pkg.steam_backlog_enforcer.hltb import (
fetch_hltb_confidence_cached,
fetch_hltb_times_cached,
)
from python_pkg.steam_backlog_enforcer.protondb import (
@ -40,10 +42,6 @@ from python_pkg.steam_backlog_enforcer.steam_api import GameInfo, SteamAPIClient
logger = logging.getLogger(__name__)
_TAMPER_CHECK_LIMIT = 3
_MIN_COMP_100_POLLS = 3
_MIN_COUNT_COMP = 15
_MIN_CONFIDENCE_SUM = 18
# ──────────────────────────────────────────────────────────────
# Scanning & game selection
@ -162,220 +160,131 @@ def _pick_playable_candidate(
return None
def pick_next_game(games: list[GameInfo], state: State, config: Config) -> None:
"""Select the next game: shortest completionist time first.
_PICK_LIST_SIZE = 10
Games with silver-or-worse ProtonDB ratings (or gold trending
downward) are automatically skipped as unplayable on Linux.
"""
skip = set(state.finished_app_ids)
candidates = [g for g in games if not g.is_complete and g.app_id not in skip]
if not candidates:
_echo(
_NO_CONF_MSG = (
"\nNo assignable games found "
"(HLTB confidence thresholds: comp_100 polls>=3, "
"count_comp>=15, sum>=18)."
)
state.current_app_id = None
state.current_game_name = ""
state.save()
return
# Sort: games with known HLTB time first (shortest), then unknown.
def sort_key(g: GameInfo) -> tuple[int, float]:
def _sort_key(g: GameInfo) -> tuple[int, float]:
"""Sort by known HLTB time (shortest first), then unknown games."""
if g.completionist_hours > 0:
return (0, g.completionist_hours)
return (1, g.name.lower().encode().hex().__hash__())
candidates.sort(key=sort_key)
_apply_cached_confidence_to_candidates(candidates)
chosen, confidence_skipped, linux_skipped = _pick_next_shortest_candidate(
candidates
)
if chosen is None:
if confidence_skipped > 0 and linux_skipped == 0:
_echo(
"\nNo assignable games found "
"(HLTB confidence thresholds: comp_100 polls>=3, "
"count_comp>=15, sum>=18)."
)
def _collect_qualified_candidates(
candidates: list[GameInfo],
) -> tuple[list[GameInfo], int, int]:
"""Collect up to _PICK_LIST_SIZE playable, HLTB-confident candidates."""
qualified: list[GameInfo] = []
confidence_skipped = 0
linux_skipped = 0
for game in candidates:
if len(qualified) >= _PICK_LIST_SIZE:
break
if not _candidate_passes_hltb_confidence(game):
confidence_skipped += 1
continue
playable = _pick_playable_candidate([game])
if playable is not None:
qualified.append(playable)
else:
_echo("\nNo playable games left (all have poor ProtonDB ratings)!")
state.current_app_id = None
state.current_game_name = ""
state.save()
return
linux_skipped += 1
return qualified, confidence_skipped, linux_skipped
def _prompt_user_pick(qualified: list[GameInfo]) -> int:
"""Present numbered list, return 0-based index of user's choice."""
for i, g in enumerate(qualified, 1):
hours_str = (
f" (~{g.completionist_hours:.1f}h)" if g.completionist_hours > 0 else ""
)
_echo(f" {i}. {g.name} (AppID={g.app_id}){hours_str}")
while True:
raw = input("Select game number: ")
try:
idx = int(raw)
except ValueError:
_echo(f"Invalid input: {raw!r}")
continue
if idx < 1 or idx > len(qualified):
_echo(f"Out of range: {idx}")
continue
return idx - 1
def _assign_chosen_game(
chosen: GameInfo,
games: list[GameInfo],
state: State,
config: Config,
) -> None:
"""Save assignment, announce it, and handle install/uninstall."""
state.current_app_id = chosen.app_id
state.current_game_name = chosen.name
state.save()
hours_str = ""
if chosen.completionist_hours > 0:
hours_str = f" (~{chosen.completionist_hours:.1f}h leisure+dlc)"
hours_str = (
f" (~{chosen.completionist_hours:.1f}h leisure+dlc)"
if chosen.completionist_hours > 0
else ""
)
_echo(f"\n>>> ASSIGNED: {chosen.name} (AppID={chosen.app_id}){hours_str}")
_echo(
f" Progress: {chosen.unlocked_achievements}/{chosen.total_achievements}"
f" ({chosen.completion_pct:.1f}%)"
)
_report_poll_confidence(chosen, games, state)
# Uninstall all other games first, then auto-install the assigned one.
if config.uninstall_other_games:
count = uninstall_other_games(chosen.app_id)
if count:
_echo(f"\n Uninstalled {count} non-assigned games")
if not is_game_installed(chosen.app_id):
_echo(f"\n Auto-installing {chosen.name}...")
install_game(
chosen.app_id,
chosen.name,
config.steam_id,
use_steam_protocol=True,
chosen.app_id, chosen.name, config.steam_id, use_steam_protocol=True
)
def _apply_cached_confidence_to_candidates(candidates: list[GameInfo]) -> None:
"""Overlay cached confidence counters onto candidate game objects."""
polls_cache = load_hltb_polls_cache()
count_comp_cache = load_hltb_count_comp_cache()
for game in candidates:
if game.app_id in polls_cache:
game.comp_100_count = polls_cache[game.app_id]
if game.app_id in count_comp_cache:
game.count_comp = count_comp_cache[game.app_id]
def pick_next_game(games: list[GameInfo], state: State, config: Config) -> None:
"""Present a ranked list of eligible games and let the user pick one.
def _confidence_fail_reasons(game: GameInfo) -> list[str]:
"""Return threshold-failure reasons for a game's HLTB confidence data."""
reasons: list[str] = []
if game.comp_100_count < _MIN_COMP_100_POLLS:
reasons.append(f"comp_100 polls {game.comp_100_count} < {_MIN_COMP_100_POLLS}")
if game.count_comp < _MIN_COUNT_COMP:
reasons.append(f"count_comp {game.count_comp} < {_MIN_COUNT_COMP}")
total = game.comp_100_count + game.count_comp
if total < _MIN_CONFIDENCE_SUM:
reasons.append(f"comp_100+count_comp {total} < {_MIN_CONFIDENCE_SUM}")
return reasons
def _refresh_candidate_confidence(game: GameInfo) -> None:
"""Refresh confidence metrics for one candidate when cache looks stale.
Only refreshes when both metrics are missing (0), which typically means
the game was cached before confidence fields were added.
Games are ranked by shortest completionist time first. Games with
silver-or-worse ProtonDB ratings (or gold trending downward) are
excluded as unplayable on Linux.
"""
if game.comp_100_count > 0 or game.count_comp > 0:
skip = set(state.finished_app_ids)
candidates = [g for g in games if not g.is_complete and g.app_id not in skip]
if not candidates:
_echo(_NO_CONF_MSG)
state.current_app_id = None
state.current_game_name = ""
state.save()
return
_refresh_candidate_confidence_batch([game])
candidates.sort(key=_sort_key)
_apply_cached_confidence_to_candidates(candidates)
qualified, confidence_skipped, linux_skipped = _collect_qualified_candidates(
candidates
)
def _force_refresh_candidate_confidence(game: GameInfo) -> None:
"""Force-refresh one candidate's confidence metrics from HLTB."""
_refresh_candidate_confidence_batch([game], force=True)
def _refresh_candidate_confidence_batch(
candidates: list[GameInfo],
*,
force: bool = False,
) -> None:
"""Refresh missing confidence metrics for candidates in one HLTB batch.
This prevents O(N) one-game API loops when many snapshot entries predate
confidence fields and therefore have ``comp_100_count==0`` and
``count_comp==0``.
"""
missing = [
game
for game in candidates
if force or (game.comp_100_count == 0 and game.count_comp == 0)
]
if not missing:
if not qualified:
_echo(
_NO_CONF_MSG
if confidence_skipped > 0 and linux_skipped == 0
else "\nNo playable games left (all have poor ProtonDB ratings)!"
)
state.current_app_id = None
state.current_game_name = ""
state.save()
return
refresh_slice = missing
if len(refresh_slice) == 1:
game = refresh_slice[0]
_echo(f" Refreshing HLTB confidence for {game.name} (AppID={game.app_id})...")
else:
_echo(f" Refreshing HLTB confidence for {len(refresh_slice)} candidate(s)...")
cache = load_hltb_cache()
polls = load_hltb_polls_cache()
count_comp = load_hltb_count_comp_cache()
app_ids = [game.app_id for game in refresh_slice]
names = [(game.app_id, game.name) for game in refresh_slice]
prior_hours = {aid: cache.get(aid, -1) for aid in app_ids}
for aid in app_ids:
cache.pop(aid, None)
polls.pop(aid, None)
count_comp.pop(aid, None)
save_hltb_cache(cache, polls, count_comp)
fetch_hltb_confidence_cached(names)
refreshed_hours = load_hltb_cache()
refreshed_polls = load_hltb_polls_cache()
refreshed_count_comp = load_hltb_count_comp_cache()
for aid, old_hours in prior_hours.items():
if old_hours > 0 and refreshed_hours.get(aid, -1) <= 0:
refreshed_hours[aid] = old_hours
save_hltb_cache(refreshed_hours, refreshed_polls, refreshed_count_comp)
for game in refresh_slice:
game.comp_100_count = refreshed_polls.get(game.app_id, 0)
game.count_comp = refreshed_count_comp.get(game.app_id, 0)
def _filter_hltb_confident_candidates(
candidates: list[GameInfo],
) -> list[GameInfo]:
"""Keep only candidates that satisfy HLTB confidence thresholds."""
_refresh_candidate_confidence_batch(candidates)
kept: list[GameInfo] = []
for game in candidates:
reasons = _confidence_fail_reasons(game)
if reasons:
_echo(
f" Skipping {game.name} (AppID={game.app_id}): "
f"HLTB confidence too low ({'; '.join(reasons)})"
)
continue
kept.append(game)
return kept
def _candidate_passes_hltb_confidence(game: GameInfo) -> bool:
"""Return True if candidate passes confidence with cache-first behavior.
Only refreshes when confidence fields are missing (both zero), which keeps
normal runs cache-friendly and avoids repeated refetches for known
low-confidence entries.
"""
reasons = _confidence_fail_reasons(game)
if not reasons:
return True
# Re-check once when confidence fields are missing in cache.
_refresh_candidate_confidence(game)
reasons = _confidence_fail_reasons(game)
if reasons:
_echo(
f" Skipping {game.name} (AppID={game.app_id}): "
f"HLTB confidence too low ({'; '.join(reasons)})"
)
return False
return True
idx = _prompt_user_pick(qualified)
_assign_chosen_game(qualified[idx], games, state, config)
def _pick_next_shortest_candidate(
@ -407,89 +316,32 @@ def _pick_next_shortest_candidate(
return None, confidence_skipped, linux_skipped
def _backfill_polls_for_finished(
state: State,
games: list[GameInfo],
) -> dict[int, int]:
"""Lazily fetch poll counts for already-finished games missing them.
def _collect_top_candidates(
candidates: list[GameInfo],
n: int = 3,
) -> tuple[list[GameInfo], int, int]:
"""Collect up to n candidates that pass the Linux compatibility gate.
Reads the polls cache, identifies finished games whose poll count is
still ``0`` (typically because the cache predates the polls schema),
and triggers a one-shot HLTB search to backfill them. Returns the
refreshed polls cache.
Args:
candidates: Pre-sorted list of candidate games.
n: Maximum number of qualified games to collect.
Returns:
Tuple of (qualified_list, conf_skipped, linux_skipped).
"""
polls_cache = load_hltb_polls_cache()
name_by_id = {g.app_id: g.name for g in games}
missing = [
(aid, name_by_id[aid])
for aid in state.finished_app_ids
if aid in name_by_id and polls_cache.get(aid, 0) == 0
]
if not missing:
return polls_cache
logger.info(
"Backfilling HLTB poll counts for %d already-finished games...",
len(missing),
)
# Force a fresh search by removing the hours entries we want to refetch.
# (fetch_hltb_times_cached skips entries already in the hours cache.)
cache = load_hltb_cache()
preserved_hours = {aid: cache[aid] for aid, _ in missing if aid in cache}
for aid, _name in missing:
cache.pop(aid, None)
save_hltb_cache(cache, polls_cache)
fetch_hltb_confidence_cached(missing)
# Restore any previously-known hours that the refetch may have replaced
# with a worse match (we trust prior leisure+dlc estimates).
refreshed_hours = load_hltb_cache()
refreshed_polls = load_hltb_polls_cache()
for aid, prior_hours in preserved_hours.items():
if prior_hours > 0 and refreshed_hours.get(aid, -1) <= 0:
refreshed_hours[aid] = prior_hours
save_hltb_cache(refreshed_hours, refreshed_polls)
return refreshed_polls
def _report_poll_confidence(
chosen: GameInfo,
games: list[GameInfo],
state: State,
) -> None:
"""Print HLTB poll-count confidence info for the just-assigned game.
Shows the chosen game's ``comp_100_count`` (number of polled
completionist times on HowLongToBeat) and the historical minimum
among the user's previously-finished games. Marks a new historical
low so the user can be skeptical of unreliable estimates.
"""
polls_cache = _backfill_polls_for_finished(state, games)
chosen_polls = polls_cache.get(chosen.app_id, chosen.comp_100_count)
chosen.comp_100_count = chosen_polls
finished_polls = [
(polls_cache[aid], aid)
for aid in state.finished_app_ids
if polls_cache.get(aid, 0) > 0
]
if not finished_polls:
_echo(f" HLTB confidence: {chosen_polls} polled completionist times")
return
min_polls, min_aid = min(finished_polls)
name_by_id = {g.app_id: g.name for g in games}
min_name = name_by_id.get(min_aid, f"AppID={min_aid}")
warning = ""
if 0 < chosen_polls < min_polls:
warning = " ⚠ NEW LOW — estimate may be unreliable"
elif chosen_polls == 0:
warning = " ⚠ no polls recorded — estimate may be unreliable"
_echo(f" HLTB confidence: {chosen_polls} polled completionist times{warning}")
_echo(f" Historical min among finished: {min_polls} ({min_name})")
qualified: list[GameInfo] = []
linux_skipped = 0
for game in candidates:
if len(qualified) >= n:
break
playable = _pick_playable_candidate([game])
if playable is not None:
qualified.append(playable)
else:
linux_skipped += 1
if linux_skipped > 0:
_echo(f" Skipped {linux_skipped} game(s) with poor Linux compatibility")
return qualified, 0, linux_skipped
# ──────────────────────────────────────────────────────────────

View File

@ -5,7 +5,6 @@ from __future__ import annotations
from unittest.mock import patch
from python_pkg.steam_backlog_enforcer._cmd_done import (
_should_reassign_candidate,
_try_reassign_shorter_game,
)
from python_pkg.steam_backlog_enforcer.config import Config, State
@ -446,186 +445,3 @@ class TestTryReassignShorterGame:
assert not result
mock_pick.assert_not_called()
def test_reassigns_when_current_hours_unknown(self) -> None:
"""If current game has unknown hours, allow a confident replacement."""
snap = [
_snap(app_id=1, name="Current", unlocked_achievements=5),
_snap(
app_id=2, name="Known", unlocked_achievements=5, completionist_hours=9.0
),
]
state = State(current_app_id=2, current_game_name="Known")
known_game = GameInfo(
app_id=2,
name="Known",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
completionist_hours=9.0,
comp_100_count=3,
count_comp=15,
)
with (
patch(f"{CMD_DONE_PKG}.load_snapshot", return_value=snap),
patch(
f"{CMD_DONE_PKG}._pick_next_shortest_candidate",
return_value=(known_game, 0, 0),
),
patch(f"{CMD_DONE_PKG}.pick_next_game"),
patch(f"{CMD_DONE_PKG}.get_all_owned_app_ids", return_value=[]),
patch(f"{CMD_DONE_PKG}.hide_other_games"),
):
result = _try_reassign_shorter_game(
{2: 9.0},
1,
-1.0,
state,
Config(),
)
assert result
def test_try_reassign_returns_false_when_playable_not_shorter(self) -> None:
"""_try_reassign_shorter_game should not reassign to longer candidates."""
snap = [
_snap(
app_id=1,
name="Current",
unlocked_achievements=5,
completionist_hours=8.0,
comp_100_count=10,
count_comp=40,
),
_snap(
app_id=2,
name="Longer",
unlocked_achievements=5,
completionist_hours=12.0,
comp_100_count=10,
count_comp=40,
),
]
longer = GameInfo(
app_id=2,
name="Longer",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
completionist_hours=12.0,
comp_100_count=10,
count_comp=40,
)
with (
patch(f"{CMD_DONE_PKG}.load_snapshot", return_value=snap),
patch(
f"{CMD_DONE_PKG}.load_hltb_polls_cache",
return_value={1: 10, 2: 10},
),
patch(
f"{CMD_DONE_PKG}.load_hltb_count_comp_cache",
return_value={1: 40, 2: 40},
),
patch(
f"{CMD_DONE_PKG}._pick_next_shortest_candidate",
return_value=(longer, 0, 0),
),
patch(f"{CMD_DONE_PKG}.pick_next_game") as mock_pick_next,
patch(f"{CMD_DONE_PKG}._echo"),
):
result = _try_reassign_shorter_game(
hltb_cache={1: 8.0, 2: 12.0},
app_id=1,
hours=8.0,
state=State(),
config=Config(),
)
assert not result
mock_pick_next.assert_not_called()
def test_try_reassign_stops_when_should_reassign_is_false(self) -> None:
"""Covers early return when policy says not to reassign."""
snap = [
_snap(
app_id=1,
name="Current",
unlocked_achievements=5,
completionist_hours=8.0,
comp_100_count=10,
count_comp=40,
),
_snap(
app_id=2,
name="Candidate",
unlocked_achievements=5,
completionist_hours=6.0,
comp_100_count=10,
count_comp=40,
),
]
candidate = GameInfo(
app_id=2,
name="Candidate",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
completionist_hours=6.0,
comp_100_count=10,
count_comp=40,
)
with (
patch(f"{CMD_DONE_PKG}.load_snapshot", return_value=snap),
patch(
f"{CMD_DONE_PKG}.load_hltb_polls_cache",
return_value={1: 10, 2: 10},
),
patch(
f"{CMD_DONE_PKG}.load_hltb_count_comp_cache",
return_value={1: 40, 2: 40},
),
patch(
f"{CMD_DONE_PKG}._pick_next_shortest_candidate",
return_value=(candidate, 0, 0),
),
patch(
f"{CMD_DONE_PKG}._should_reassign_candidate",
return_value=False,
),
patch(f"{CMD_DONE_PKG}.pick_next_game") as mock_pick_next,
patch(f"{CMD_DONE_PKG}._echo"),
):
result = _try_reassign_shorter_game(
hltb_cache={1: 8.0, 2: 6.0},
app_id=1,
hours=8.0,
state=State(),
config=Config(),
)
assert not result
mock_pick_next.assert_not_called()
class TestShouldReassignCandidate:
"""Tests for _should_reassign_candidate."""
def test_returns_false_when_candidate_not_shorter(self) -> None:
candidate = GameInfo(
app_id=2,
name="Candidate",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
completionist_hours=9.0,
comp_100_count=3,
count_comp=15,
)
should = _should_reassign_candidate(
candidate,
8.0,
force_reassign=False,
)
assert should is False

View File

@ -0,0 +1,217 @@
"""Tests for _cmd_done module (part 2)."""
from __future__ import annotations
from unittest.mock import patch
from python_pkg.steam_backlog_enforcer._cmd_done import (
_should_reassign_candidate,
_try_reassign_shorter_game,
)
from python_pkg.steam_backlog_enforcer.config import Config, State
from python_pkg.steam_backlog_enforcer.steam_api import GameInfo
CMD_DONE_PKG = "python_pkg.steam_backlog_enforcer._cmd_done"
def _snap(**overrides: object) -> dict[str, object]:
snapshot: dict[str, object] = {
"app_id": 1,
"name": "G",
"total_achievements": 10,
"unlocked_achievements": 0,
"playtime_minutes": 60,
"completionist_hours": -1,
"comp_100_count": 3,
"count_comp": 15,
}
snapshot["app_id"] = overrides.get("app_id", 1)
snapshot.update(overrides)
return snapshot
class TestTryReassignShorterGame2:
"""Tests for _try_reassign_shorter_game (continued)."""
def test_reassigns_when_current_hours_unknown(self) -> None:
"""If current game has unknown hours, allow a confident replacement."""
snap = [
_snap(app_id=1, name="Current", unlocked_achievements=5),
_snap(
app_id=2, name="Known", unlocked_achievements=5, completionist_hours=9.0
),
]
state = State(current_app_id=2, current_game_name="Known")
known_game = GameInfo(
app_id=2,
name="Known",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
completionist_hours=9.0,
comp_100_count=3,
count_comp=15,
)
with (
patch(f"{CMD_DONE_PKG}.load_snapshot", return_value=snap),
patch(
f"{CMD_DONE_PKG}._pick_next_shortest_candidate",
return_value=(known_game, 0, 0),
),
patch(f"{CMD_DONE_PKG}.pick_next_game"),
patch(f"{CMD_DONE_PKG}.get_all_owned_app_ids", return_value=[]),
patch(f"{CMD_DONE_PKG}.hide_other_games"),
):
result = _try_reassign_shorter_game(
{2: 9.0},
1,
-1.0,
state,
Config(),
)
assert result
def test_try_reassign_returns_false_when_playable_not_shorter(self) -> None:
"""_try_reassign_shorter_game should not reassign to longer candidates."""
snap = [
_snap(
app_id=1,
name="Current",
unlocked_achievements=5,
completionist_hours=8.0,
comp_100_count=10,
count_comp=40,
),
_snap(
app_id=2,
name="Longer",
unlocked_achievements=5,
completionist_hours=12.0,
comp_100_count=10,
count_comp=40,
),
]
longer = GameInfo(
app_id=2,
name="Longer",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
completionist_hours=12.0,
comp_100_count=10,
count_comp=40,
)
with (
patch(f"{CMD_DONE_PKG}.load_snapshot", return_value=snap),
patch(
f"{CMD_DONE_PKG}.load_hltb_polls_cache",
return_value={1: 10, 2: 10},
),
patch(
f"{CMD_DONE_PKG}.load_hltb_count_comp_cache",
return_value={1: 40, 2: 40},
),
patch(
f"{CMD_DONE_PKG}._pick_next_shortest_candidate",
return_value=(longer, 0, 0),
),
patch(f"{CMD_DONE_PKG}.pick_next_game") as mock_pick_next,
patch(f"{CMD_DONE_PKG}._echo"),
):
result = _try_reassign_shorter_game(
hltb_cache={1: 8.0, 2: 12.0},
app_id=1,
hours=8.0,
state=State(),
config=Config(),
)
assert not result
mock_pick_next.assert_not_called()
def test_try_reassign_stops_when_should_reassign_is_false(self) -> None:
"""Covers early return when policy says not to reassign."""
snap = [
_snap(
app_id=1,
name="Current",
unlocked_achievements=5,
completionist_hours=8.0,
comp_100_count=10,
count_comp=40,
),
_snap(
app_id=2,
name="Candidate",
unlocked_achievements=5,
completionist_hours=6.0,
comp_100_count=10,
count_comp=40,
),
]
candidate = GameInfo(
app_id=2,
name="Candidate",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
completionist_hours=6.0,
comp_100_count=10,
count_comp=40,
)
with (
patch(f"{CMD_DONE_PKG}.load_snapshot", return_value=snap),
patch(
f"{CMD_DONE_PKG}.load_hltb_polls_cache",
return_value={1: 10, 2: 10},
),
patch(
f"{CMD_DONE_PKG}.load_hltb_count_comp_cache",
return_value={1: 40, 2: 40},
),
patch(
f"{CMD_DONE_PKG}._pick_next_shortest_candidate",
return_value=(candidate, 0, 0),
),
patch(
f"{CMD_DONE_PKG}._should_reassign_candidate",
return_value=False,
),
patch(f"{CMD_DONE_PKG}.pick_next_game") as mock_pick_next,
patch(f"{CMD_DONE_PKG}._echo"),
):
result = _try_reassign_shorter_game(
hltb_cache={1: 8.0, 2: 6.0},
app_id=1,
hours=8.0,
state=State(),
config=Config(),
)
assert not result
mock_pick_next.assert_not_called()
class TestShouldReassignCandidate:
"""Tests for _should_reassign_candidate."""
def test_returns_false_when_candidate_not_shorter(self) -> None:
candidate = GameInfo(
app_id=2,
name="Candidate",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
completionist_hours=9.0,
comp_100_count=3,
count_comp=15,
)
should = _should_reassign_candidate(
candidate,
8.0,
force_reassign=False,
)
assert should is False

View File

@ -9,12 +9,10 @@ from unittest.mock import MagicMock, patch
from python_pkg.steam_backlog_enforcer._enforce_loop import (
_enforce_auto_install,
_enforce_hide_games,
_enforce_loop_iteration,
_enforce_setup,
_guard_installed_games,
_load_owned_app_ids_cache,
_save_owned_app_ids_cache,
do_enforce,
get_all_owned_app_ids,
)
from python_pkg.steam_backlog_enforcer.config import Config, State
@ -373,185 +371,3 @@ class TestEnforceHideGames:
):
_enforce_hide_games(Config(), state)
assert any("skipped" in str(c) for c in mock_echo.call_args_list)
class TestEnforceLoopIteration:
"""Tests for _enforce_loop_iteration."""
def test_kills_unauthorized(self) -> None:
config = Config(
kill_unauthorized_games=True,
uninstall_other_games=False,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(
f"{PKG}.enforce_allowed_game",
return_value=[(1234, 999)],
),
patch(f"{PKG}.send_notification"),
patch(f"{PKG}._echo"),
patch(f"{PKG}.is_game_installed", return_value=True),
):
_enforce_loop_iteration(config, state)
def test_no_kill(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=False,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}.enforce_allowed_game") as mock_enforce,
patch(f"{PKG}.is_game_installed", return_value=True),
):
_enforce_loop_iteration(config, state)
mock_enforce.assert_not_called()
def test_guards_installed(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=True,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}._guard_installed_games", return_value=1),
patch(f"{PKG}._echo"),
patch(f"{PKG}.is_game_installed", return_value=True),
):
_enforce_loop_iteration(config, state)
def test_guard_removes_zero(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=True,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}._guard_installed_games", return_value=0),
patch(f"{PKG}.is_game_installed", return_value=True),
):
_enforce_loop_iteration(config, state)
def test_reinstalls_missing(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=False,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}.is_game_installed", return_value=False),
patch(f"{PKG}.install_game") as mock_install,
):
_enforce_loop_iteration(config, state)
mock_install.assert_called_once()
def test_no_app_id_skip_reinstall(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=False,
)
state = State(current_app_id=None)
with (
patch(f"{PKG}.enforce_allowed_game") as mock_enforce,
patch(f"{PKG}._guard_installed_games") as mock_guard,
patch(f"{PKG}.is_game_installed") as mock_installed,
):
_enforce_loop_iteration(config, state)
mock_enforce.assert_not_called()
mock_guard.assert_not_called()
mock_installed.assert_not_called()
def test_promotes_newly_approved_exceptions(self) -> None:
"""Loop body at line 286 executes when promote returns non-empty list."""
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=False,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}.is_game_installed", return_value=True),
patch(
f"{PKG}.promote_pending_exceptions",
return_value=[440],
),
):
_enforce_loop_iteration(config, state)
class TestDoEnforce:
"""Tests for do_enforce."""
def test_no_game(self) -> None:
with patch(f"{PKG}._echo") as mock_echo:
do_enforce(Config(), State())
assert any("No game" in str(c) for c in mock_echo.call_args_list)
def test_keyboard_interrupt(self) -> None:
state = State(current_app_id=1, current_game_name="G")
config = Config()
fresh = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}._enforce_setup"),
patch(f"{PKG}._echo"),
patch.object(State, "load", return_value=fresh),
patch(
f"{PKG}._enforce_loop_iteration",
side_effect=KeyboardInterrupt,
),
patch(f"{PKG}.time.sleep"),
):
do_enforce(config, state)
def test_runs_iterations(self) -> None:
state = State(current_app_id=1, current_game_name="G")
config = Config()
fresh = State(current_app_id=1, current_game_name="G")
call_count = 0
def side_effect(*_args: object, **_kwargs: object) -> None:
nonlocal call_count
call_count += 1
if call_count >= 2:
raise KeyboardInterrupt
with (
patch(f"{PKG}._enforce_setup"),
patch(f"{PKG}._echo"),
patch.object(State, "load", return_value=fresh),
patch(
f"{PKG}._enforce_loop_iteration",
side_effect=side_effect,
),
patch(f"{PKG}.time.sleep"),
):
do_enforce(config, state)
assert call_count == 2
def test_state_load_failure_continues(self) -> None:
"""Corrupt state file should not crash the daemon."""
import json as json_mod
state = State(current_app_id=1, current_game_name="G")
config = Config()
call_count = 0
def load_side_effect() -> State:
nonlocal call_count
call_count += 1
if call_count == 1:
msg = "bad"
raise json_mod.JSONDecodeError(msg, "", 0)
if call_count == 2:
raise KeyboardInterrupt
return State(current_app_id=1) # pragma: no cover
with (
patch(f"{PKG}._enforce_setup"),
patch(f"{PKG}._echo"),
patch.object(State, "load", side_effect=load_side_effect),
patch(f"{PKG}._enforce_loop_iteration") as mock_iter,
patch(f"{PKG}.time.sleep"),
):
do_enforce(config, state)
mock_iter.assert_not_called()

View File

@ -0,0 +1,195 @@
"""Tests for _enforce_loop module (part 2)."""
from __future__ import annotations
from unittest.mock import patch
from python_pkg.steam_backlog_enforcer._enforce_loop import (
_enforce_loop_iteration,
do_enforce,
)
from python_pkg.steam_backlog_enforcer.config import Config, State
PKG = "python_pkg.steam_backlog_enforcer._enforce_loop"
class TestEnforceLoopIteration:
"""Tests for _enforce_loop_iteration."""
def test_kills_unauthorized(self) -> None:
config = Config(
kill_unauthorized_games=True,
uninstall_other_games=False,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(
f"{PKG}.enforce_allowed_game",
return_value=[(1234, 999)],
),
patch(f"{PKG}.send_notification"),
patch(f"{PKG}._echo"),
patch(f"{PKG}.is_game_installed", return_value=True),
):
_enforce_loop_iteration(config, state)
def test_no_kill(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=False,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}.enforce_allowed_game") as mock_enforce,
patch(f"{PKG}.is_game_installed", return_value=True),
):
_enforce_loop_iteration(config, state)
mock_enforce.assert_not_called()
def test_guards_installed(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=True,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}._guard_installed_games", return_value=1),
patch(f"{PKG}._echo"),
patch(f"{PKG}.is_game_installed", return_value=True),
):
_enforce_loop_iteration(config, state)
def test_guard_removes_zero(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=True,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}._guard_installed_games", return_value=0),
patch(f"{PKG}.is_game_installed", return_value=True),
):
_enforce_loop_iteration(config, state)
def test_reinstalls_missing(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=False,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}.is_game_installed", return_value=False),
patch(f"{PKG}.install_game") as mock_install,
):
_enforce_loop_iteration(config, state)
mock_install.assert_called_once()
def test_no_app_id_skip_reinstall(self) -> None:
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=False,
)
state = State(current_app_id=None)
with (
patch(f"{PKG}.enforce_allowed_game") as mock_enforce,
patch(f"{PKG}._guard_installed_games") as mock_guard,
patch(f"{PKG}.is_game_installed") as mock_installed,
):
_enforce_loop_iteration(config, state)
mock_enforce.assert_not_called()
mock_guard.assert_not_called()
mock_installed.assert_not_called()
def test_promotes_newly_approved_exceptions(self) -> None:
"""Loop body at line 286 executes when promote returns non-empty list."""
config = Config(
kill_unauthorized_games=False,
uninstall_other_games=False,
)
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}.is_game_installed", return_value=True),
patch(
f"{PKG}.promote_pending_exceptions",
return_value=[440],
),
):
_enforce_loop_iteration(config, state)
class TestDoEnforce:
"""Tests for do_enforce."""
def test_no_game(self) -> None:
with patch(f"{PKG}._echo") as mock_echo:
do_enforce(Config(), State())
assert any("No game" in str(c) for c in mock_echo.call_args_list)
def test_keyboard_interrupt(self) -> None:
state = State(current_app_id=1, current_game_name="G")
config = Config()
fresh = State(current_app_id=1, current_game_name="G")
with (
patch(f"{PKG}._enforce_setup"),
patch(f"{PKG}._echo"),
patch.object(State, "load", return_value=fresh),
patch(
f"{PKG}._enforce_loop_iteration",
side_effect=KeyboardInterrupt,
),
patch(f"{PKG}.time.sleep"),
):
do_enforce(config, state)
def test_runs_iterations(self) -> None:
state = State(current_app_id=1, current_game_name="G")
config = Config()
fresh = State(current_app_id=1, current_game_name="G")
call_count = 0
def side_effect(*_args: object, **_kwargs: object) -> None:
nonlocal call_count
call_count += 1
if call_count >= 2:
raise KeyboardInterrupt
with (
patch(f"{PKG}._enforce_setup"),
patch(f"{PKG}._echo"),
patch.object(State, "load", return_value=fresh),
patch(
f"{PKG}._enforce_loop_iteration",
side_effect=side_effect,
),
patch(f"{PKG}.time.sleep"),
):
do_enforce(config, state)
assert call_count == 2
def test_state_load_failure_continues(self) -> None:
"""Corrupt state file should not crash the daemon."""
import json as json_mod
state = State(current_app_id=1, current_game_name="G")
config = Config()
call_count = 0
def load_side_effect() -> State:
nonlocal call_count
call_count += 1
if call_count == 1:
msg = "bad"
raise json_mod.JSONDecodeError(msg, "", 0)
if call_count == 2:
raise KeyboardInterrupt
return State(current_app_id=1) # pragma: no cover
with (
patch(f"{PKG}._enforce_setup"),
patch(f"{PKG}._echo"),
patch.object(State, "load", side_effect=load_side_effect),
patch(f"{PKG}._enforce_loop_iteration") as mock_iter,
patch(f"{PKG}.time.sleep"),
):
do_enforce(config, state)
mock_iter.assert_not_called()

View File

@ -14,11 +14,7 @@ from python_pkg.steam_backlog_enforcer.game_install import (
_ensure_steam_running,
_get_real_user,
_get_uid_gid_for_user,
_read_install_dir,
_remove_manifest,
_trigger_steam_install,
get_installed_games,
install_game,
is_game_installed,
)
@ -282,247 +278,3 @@ class TestEnsureSteamRunning:
),
):
_ensure_steam_running()
class TestInstallGame:
"""Tests for install_game."""
def test_already_installed(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.touch()
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
assert install_game(440, "TF2", "steam123") is True
def test_use_steam_protocol_success(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._trigger_steam_install",
return_value=True,
),
):
assert install_game(440, "TF2", "s1", use_steam_protocol=True) is True
def test_use_steam_protocol_fallback(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._trigger_steam_install",
return_value=False,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=1000,
),
):
assert install_game(440, "TF2", "s1", use_steam_protocol=True) is True
assert (tmp_path / "appmanifest_440.acf").exists()
def test_manifest_write_as_root(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=0,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._get_real_user",
return_value="alice",
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._get_uid_gid_for_user",
return_value=(1001, 1001),
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.chown"
) as mock_chown,
):
assert install_game(440, "TF2", "s1") is True
mock_chown.assert_called_once()
def test_manifest_write_failure(self, tmp_path: Path) -> None:
# Make steamapps path not writable
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path / "nonexistent" / "deep",
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=1000,
),
):
assert install_game(440, "TF2", "s1") is False
def test_empty_game_name(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=1000,
),
):
assert install_game(440, "", "s1") is True
def test_manifest_not_root_no_chown(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=1000,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.chown"
) as mock_chown,
):
assert install_game(440, "TF2", "s1") is True
mock_chown.assert_not_called()
def test_root_user_is_root(self, tmp_path: Path) -> None:
"""When real user IS root, don't chown."""
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=0,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._get_real_user",
return_value="root",
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.chown"
) as mock_chown,
):
assert install_game(440, "TF2", "s1") is True
mock_chown.assert_not_called()
class TestGetInstalledGames:
"""Tests for get_installed_games."""
def test_parses_manifests(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"appid"\t\t"440"\n"name"\t\t"Team Fortress 2"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = get_installed_games()
assert result == [(440, "Team Fortress 2")]
def test_no_name(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"appid"\t\t"440"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = get_installed_games()
assert result == [(440, "Unknown (440)")]
def test_empty_dir(self, tmp_path: Path) -> None:
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = get_installed_games()
assert result == []
def test_no_appid_match(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"name"\t\t"NoAppId"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = get_installed_games()
assert result == []
class TestReadInstallDir:
"""Tests for _read_install_dir."""
def test_reads_dir(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"installdir"\t\t"Team Fortress 2"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = _read_install_dir(manifest)
assert result == tmp_path / "common" / "Team Fortress 2"
def test_no_match(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"appid"\t\t"440"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
assert _read_install_dir(manifest) is None
def test_missing_file(self, tmp_path: Path) -> None:
manifest = tmp_path / "nonexistent.acf"
assert _read_install_dir(manifest) is None
def test_os_error(self, tmp_path: Path) -> None:
manifest = MagicMock()
manifest.exists.return_value = True
manifest.read_text.side_effect = OSError
assert _read_install_dir(manifest) is None
class TestRemoveManifest:
"""Tests for _remove_manifest."""
def test_removes(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.touch()
assert _remove_manifest(manifest, "TF2", 440) is True
assert not manifest.exists()
def test_already_gone(self, tmp_path: Path) -> None:
manifest = tmp_path / "nonexistent.acf"
assert _remove_manifest(manifest, "TF2", 440) is True
def test_os_error(self) -> None:
manifest = MagicMock()
manifest.exists.return_value = True
manifest.unlink.side_effect = OSError
assert _remove_manifest(manifest, "TF2", 440) is False

View File

@ -0,0 +1,263 @@
"""Tests for game_install module (part 3 — install, get, read, remove)."""
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from python_pkg.steam_backlog_enforcer.game_install import (
_read_install_dir,
_remove_manifest,
get_installed_games,
install_game,
)
if TYPE_CHECKING:
from pathlib import Path
PKG = "python_pkg.steam_backlog_enforcer.game_install"
class TestInstallGame:
"""Tests for install_game."""
def test_already_installed(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.touch()
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
assert install_game(440, "TF2", "steam123") is True
def test_use_steam_protocol_success(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._trigger_steam_install",
return_value=True,
),
):
assert install_game(440, "TF2", "s1", use_steam_protocol=True) is True
def test_use_steam_protocol_fallback(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._trigger_steam_install",
return_value=False,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=1000,
),
):
assert install_game(440, "TF2", "s1", use_steam_protocol=True) is True
assert (tmp_path / "appmanifest_440.acf").exists()
def test_manifest_write_as_root(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=0,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._get_real_user",
return_value="alice",
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._get_uid_gid_for_user",
return_value=(1001, 1001),
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.chown"
) as mock_chown,
):
assert install_game(440, "TF2", "s1") is True
mock_chown.assert_called_once()
def test_manifest_write_failure(self, tmp_path: Path) -> None:
# Make steamapps path not writable
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path / "nonexistent" / "deep",
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=1000,
),
):
assert install_game(440, "TF2", "s1") is False
def test_empty_game_name(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=1000,
),
):
assert install_game(440, "", "s1") is True
def test_manifest_not_root_no_chown(self, tmp_path: Path) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=1000,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.chown"
) as mock_chown,
):
assert install_game(440, "TF2", "s1") is True
mock_chown.assert_not_called()
def test_root_user_is_root(self, tmp_path: Path) -> None:
"""When real user IS root, don't chown."""
with (
patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH",
tmp_path,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._ensure_steam_running"
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.geteuid",
return_value=0,
),
patch(
"python_pkg.steam_backlog_enforcer.game_install._get_real_user",
return_value="root",
),
patch(
"python_pkg.steam_backlog_enforcer.game_install.os.chown"
) as mock_chown,
):
assert install_game(440, "TF2", "s1") is True
mock_chown.assert_not_called()
class TestGetInstalledGames:
"""Tests for get_installed_games."""
def test_parses_manifests(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"appid"\t\t"440"\n"name"\t\t"Team Fortress 2"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = get_installed_games()
assert result == [(440, "Team Fortress 2")]
def test_no_name(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"appid"\t\t"440"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = get_installed_games()
assert result == [(440, "Unknown (440)")]
def test_empty_dir(self, tmp_path: Path) -> None:
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = get_installed_games()
assert result == []
def test_no_appid_match(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"name"\t\t"NoAppId"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = get_installed_games()
assert result == []
class TestReadInstallDir:
"""Tests for _read_install_dir."""
def test_reads_dir(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"installdir"\t\t"Team Fortress 2"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
result = _read_install_dir(manifest)
assert result == tmp_path / "common" / "Team Fortress 2"
def test_no_match(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.write_text('"appid"\t\t"440"\n')
with patch(
"python_pkg.steam_backlog_enforcer.game_install.STEAMAPPS_PATH", tmp_path
):
assert _read_install_dir(manifest) is None
def test_missing_file(self, tmp_path: Path) -> None:
manifest = tmp_path / "nonexistent.acf"
assert _read_install_dir(manifest) is None
def test_os_error(self, tmp_path: Path) -> None:
manifest = MagicMock()
manifest.exists.return_value = True
manifest.read_text.side_effect = OSError
assert _read_install_dir(manifest) is None
class TestRemoveManifest:
"""Tests for _remove_manifest."""
def test_removes(self, tmp_path: Path) -> None:
manifest = tmp_path / "appmanifest_440.acf"
manifest.touch()
assert _remove_manifest(manifest, "TF2", 440) is True
assert not manifest.exists()
def test_already_gone(self, tmp_path: Path) -> None:
manifest = tmp_path / "nonexistent.acf"
assert _remove_manifest(manifest, "TF2", 440) is True
def test_os_error(self) -> None:
manifest = MagicMock()
manifest.exists.return_value = True
manifest.unlink.side_effect = OSError
assert _remove_manifest(manifest, "TF2", 440) is False

View File

@ -9,13 +9,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
from python_pkg.steam_backlog_enforcer.hltb import (
from python_pkg.steam_backlog_enforcer._hltb_search import (
_AuthInfo,
_build_search_payload,
_get_auth_info,
_get_hltb_search_url,
_pick_best_hltb_entry,
_similarity,
)
from python_pkg.steam_backlog_enforcer.hltb import (
_get_auth_info,
load_hltb_cache,
save_hltb_cache,
)
@ -77,14 +79,18 @@ class TestGetHltbSearchUrl:
def test_discovers_url(self) -> None:
mock_info = MagicMock()
mock_info.search_url = "/api/search/abc"
with patch("python_pkg.steam_backlog_enforcer.hltb.HTMLRequests") as mock_html:
with patch(
"python_pkg.steam_backlog_enforcer._hltb_search.HTMLRequests"
) as mock_html:
mock_html.send_website_request_getcode.return_value = mock_info
mock_html.BASE_URL = "https://howlongtobeat.com"
url = _get_hltb_search_url()
assert url == "https://howlongtobeat.com/api/search/abc"
def test_fallback_url(self) -> None:
with patch("python_pkg.steam_backlog_enforcer.hltb.HTMLRequests") as mock_html:
with patch(
"python_pkg.steam_backlog_enforcer._hltb_search.HTMLRequests"
) as mock_html:
mock_html.send_website_request_getcode.return_value = None
url = _get_hltb_search_url()
assert url == "https://howlongtobeat.com/api/finder"
@ -92,14 +98,18 @@ class TestGetHltbSearchUrl:
def test_first_returns_none_second_returns_info(self) -> None:
mock_info = MagicMock()
mock_info.search_url = "/api/search/xyz"
with patch("python_pkg.steam_backlog_enforcer.hltb.HTMLRequests") as mock_html:
with patch(
"python_pkg.steam_backlog_enforcer._hltb_search.HTMLRequests"
) as mock_html:
mock_html.send_website_request_getcode.side_effect = [None, mock_info]
mock_html.BASE_URL = "https://howlongtobeat.com"
url = _get_hltb_search_url()
assert url == "https://howlongtobeat.com/api/search/xyz"
def test_exception_fallback(self) -> None:
with patch("python_pkg.steam_backlog_enforcer.hltb.HTMLRequests") as mock_html:
with patch(
"python_pkg.steam_backlog_enforcer._hltb_search.HTMLRequests"
) as mock_html:
mock_html.send_website_request_getcode.side_effect = RuntimeError
url = _get_hltb_search_url()
assert url == "https://howlongtobeat.com/api/finder"

View File

@ -7,10 +7,10 @@ from unittest.mock import MagicMock, patch
from typing_extensions import Self
from python_pkg.steam_backlog_enforcer._hltb_search import _AuthInfo
from python_pkg.steam_backlog_enforcer.hltb import (
HLTB_BASE_URL,
HLTBResult,
_AuthInfo,
_fetch_batch_confidence_only,
fetch_hltb_confidence,
fetch_hltb_confidence_cached,

View File

@ -3,26 +3,20 @@
from __future__ import annotations
import asyncio
import json
from typing import TYPE_CHECKING, Any
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
from typing_extensions import Self
from python_pkg.steam_backlog_enforcer._hltb_detail import (
_extract_leisure_hours,
_parse_game_page,
)
from python_pkg.steam_backlog_enforcer.hltb import (
_SAVE_INTERVAL,
HLTBResult,
_AuthInfo,
from python_pkg.steam_backlog_enforcer._hltb_search import (
_fetch_batch,
_pick_best_hltb_entry,
_search_one,
_SearchCtx,
)
from python_pkg.steam_backlog_enforcer._hltb_types import (
_SAVE_INTERVAL,
)
if TYPE_CHECKING:
from collections.abc import Callable
@ -246,7 +240,7 @@ class TestSearchOne:
ctx.counter["done"] = _SAVE_INTERVAL - 1
with patch(
"python_pkg.steam_backlog_enforcer.hltb.save_hltb_cache"
"python_pkg.steam_backlog_enforcer._hltb_search.save_hltb_cache"
) as mock_save:
asyncio.run(_search_one(asyncio.Semaphore(1), ctx, 440, "TF2"))
mock_save.assert_called_once()
@ -258,11 +252,11 @@ class TestFetchBatchHltb:
def test_no_auth(self) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.hltb._get_hltb_search_url",
"python_pkg.steam_backlog_enforcer._hltb_search._get_hltb_search_url",
return_value="https://example.com",
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._get_auth_info",
"python_pkg.steam_backlog_enforcer._hltb_search._get_auth_info",
new_callable=AsyncMock,
return_value=None,
),
@ -273,260 +267,3 @@ class TestFetchBatchHltb:
class TestPickBestEntry:
"""Tests for exact-vs-extended entry choice logic."""
def test_prefers_exact_over_low_confidence_modded_extended(self) -> None:
exact = (
{
"game_name": "Celeste",
"game_alias": "",
"game_type": "game",
"comp_100": 141105,
"comp_100_count": 899,
"count_comp": 14055,
},
1.0,
)
mod_extended = (
{
"game_name": "Celeste - Strawberry Jam",
"game_alias": "",
"game_type": "mod",
"comp_100": 952080,
"comp_100_count": 1,
"count_comp": 6,
},
0.9,
)
best = _pick_best_hltb_entry("Celeste", [exact, mod_extended])
assert best is not None
assert best[0]["game_name"] == "Celeste"
def test_prefers_extended_when_confident_and_longer(self) -> None:
exact_demo = (
{
"game_name": "FAITH",
"game_alias": "",
"game_type": "game",
"comp_100": 1800,
"comp_100_count": 1,
"count_comp": 1,
},
1.0,
)
full_extended = (
{
"game_name": "FAITH: The Unholy Trinity",
"game_alias": "",
"game_type": "game",
"comp_100": 25200,
"comp_100_count": 50,
"count_comp": 500,
},
0.9,
)
best = _pick_best_hltb_entry("FAITH", [exact_demo, full_extended])
assert best is not None
assert best[0]["game_name"] == "FAITH: The Unholy Trinity"
def test_with_auth(self) -> None:
auth = _AuthInfo("token123", "ign_x", "ff")
with (
patch(
"python_pkg.steam_backlog_enforcer.hltb._get_hltb_search_url",
return_value="https://example.com",
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._get_auth_info",
new_callable=AsyncMock,
return_value=auth,
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._search_one",
new_callable=AsyncMock,
return_value=HLTBResult(
app_id=440,
game_name="TF2",
completionist_hours=50.0,
similarity=1.0,
hltb_game_id=12345,
),
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._fetch_leisure_times",
new_callable=AsyncMock,
),
):
results = asyncio.run(_fetch_batch([(440, "TF2")], {}, {}, None))
assert len(results) == 1
def test_with_auth_no_hp(self) -> None:
auth = _AuthInfo("tok123")
with (
patch(
"python_pkg.steam_backlog_enforcer.hltb._get_hltb_search_url",
return_value="https://example.com",
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._get_auth_info",
new_callable=AsyncMock,
return_value=auth,
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._search_one",
new_callable=AsyncMock,
return_value=None,
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._fetch_leisure_times",
new_callable=AsyncMock,
),
):
results = asyncio.run(_fetch_batch([(440, "TF2")], {}, {}, None))
assert results == []
def test_filters_none_results(self) -> None:
auth = _AuthInfo("tok123")
with (
patch(
"python_pkg.steam_backlog_enforcer.hltb._get_hltb_search_url",
return_value="https://example.com",
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._get_auth_info",
new_callable=AsyncMock,
return_value=auth,
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._search_one",
new_callable=AsyncMock,
return_value=None,
),
patch(
"python_pkg.steam_backlog_enforcer.hltb._fetch_leisure_times",
new_callable=AsyncMock,
),
):
results = asyncio.run(_fetch_batch([(440, "TF2")], {}, {}, None))
assert results == []
class TestParseGamePage:
"""Tests for _parse_game_page."""
def test_valid_html(self) -> None:
game_data: dict[str, Any] = {
"game": [{"comp_100_h": 21243, "comp_100": 6800}],
"relationships": [],
}
next_data = {
"props": {"pageProps": {"game": {"data": game_data}}},
}
html = (
'<html><script id="__NEXT_DATA__" type="application/json">'
+ json.dumps(next_data)
+ "</script></html>"
)
assert _parse_game_page(html) == game_data
def test_no_script_tag(self) -> None:
assert _parse_game_page("<html></html>") is None
def test_bad_json(self) -> None:
html = '<script id="__NEXT_DATA__" type="application/json">{not json}</script>'
assert _parse_game_page(html) is None
def test_missing_keys(self) -> None:
html = (
'<script id="__NEXT_DATA__" type="application/json">{"props": {}}</script>'
)
assert _parse_game_page(html) is None
class TestExtractLeisureHours:
"""Tests for _extract_leisure_hours."""
def test_leisure_time_only(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 21243, "comp_100": 6800}],
"relationships": [],
}
assert _extract_leisure_hours(data) == round(21243 / 3600, 2)
def test_leisure_with_dlc(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 21243, "comp_100": 6800}],
"relationships": [
{"game_type": "dlc", "comp_100": 12298},
{"game_type": "dlc", "comp_100": 3600},
],
}
assert _extract_leisure_hours(data) == round((21243 + 12298 + 3600) / 3600, 2)
def test_fallback_to_comp_100(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100": 7200}],
"relationships": [],
}
assert _extract_leisure_hours(data) == round(7200 / 3600, 2)
def test_no_game_data(self) -> None:
assert _extract_leisure_hours({"game": [], "relationships": []}) == -1
def test_zero_leisure(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 0, "comp_100": 0}],
"relationships": [],
}
assert _extract_leisure_hours(data) == -1
def test_no_game_key(self) -> None:
assert _extract_leisure_hours({"relationships": []}) == -1
def test_non_dlc_relationship_ignored(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 3600}],
"relationships": [
{"game_type": "game", "comp_100": 9999},
{"game_type": "dlc", "comp_100": 1800},
],
}
assert _extract_leisure_hours(data) == round((3600 + 1800) / 3600, 2)
def test_dlc_zero_comp_100_skipped(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 3600}],
"relationships": [
{"game_type": "dlc", "comp_100": 0},
],
}
assert _extract_leisure_hours(data) == round(3600 / 3600, 2)
def test_negative_leisure(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": -1, "comp_100": -1}],
"relationships": [],
}
assert _extract_leisure_hours(data) == -1
def test_string_numeric_fields(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": "7200", "comp_100": "3600"}],
"relationships": [{"game_type": "dlc", "game_id": "1", "comp_100": "1800"}],
}
assert _extract_leisure_hours(data) == round((7200 + 1800) / 3600, 2)
def test_bad_string_falls_back_to_comp_100(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": "bad", "comp_100": "3600"}],
"relationships": [],
}
assert _extract_leisure_hours(data) == 1.0
def test_relationships_not_list(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 3600}],
"relationships": "not-a-list",
}
assert _extract_leisure_hours(data) == 1.0

View File

@ -0,0 +1,307 @@
"""Tests for HLTB search entry picking, page parsing, and leisure extraction."""
from __future__ import annotations
import asyncio
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from typing_extensions import Self
from python_pkg.steam_backlog_enforcer._hltb_detail import (
_extract_leisure_hours,
_parse_game_page,
)
from python_pkg.steam_backlog_enforcer._hltb_search import (
_fetch_batch,
_pick_best_hltb_entry,
)
from python_pkg.steam_backlog_enforcer._hltb_types import (
HLTBResult,
_AuthInfo,
)
class _FakeResponse:
"""Async context manager mimicking aiohttp response."""
def __init__(self, status: int, json_data: dict[str, Any] | None = None) -> None:
self.status = status
self._json_data = json_data or {}
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, *args: object) -> None:
pass
async def json(self) -> dict[str, Any]:
return self._json_data
def _make_session(resp: _FakeResponse) -> MagicMock:
session = MagicMock()
session.post.return_value = resp
return session
class TestPickBestEntry:
"""Tests for exact-vs-extended entry choice logic."""
def test_prefers_exact_over_low_confidence_modded_extended(self) -> None:
exact = (
{
"game_name": "Celeste",
"game_alias": "",
"game_type": "game",
"comp_100": 141105,
"comp_100_count": 899,
"count_comp": 14055,
},
1.0,
)
mod_extended = (
{
"game_name": "Celeste - Strawberry Jam",
"game_alias": "",
"game_type": "mod",
"comp_100": 952080,
"comp_100_count": 1,
"count_comp": 6,
},
0.9,
)
best = _pick_best_hltb_entry("Celeste", [exact, mod_extended])
assert best is not None
assert best[0]["game_name"] == "Celeste"
def test_prefers_extended_when_confident_and_longer(self) -> None:
exact_demo = (
{
"game_name": "FAITH",
"game_alias": "",
"game_type": "game",
"comp_100": 1800,
"comp_100_count": 1,
"count_comp": 1,
},
1.0,
)
full_extended = (
{
"game_name": "FAITH: The Unholy Trinity",
"game_alias": "",
"game_type": "game",
"comp_100": 25200,
"comp_100_count": 50,
"count_comp": 500,
},
0.9,
)
best = _pick_best_hltb_entry("FAITH", [exact_demo, full_extended])
assert best is not None
assert best[0]["game_name"] == "FAITH: The Unholy Trinity"
def test_with_auth(self) -> None:
auth = _AuthInfo("token123", "ign_x", "ff")
with (
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._get_hltb_search_url",
return_value="https://example.com",
),
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._get_auth_info",
new_callable=AsyncMock,
return_value=auth,
),
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._search_one",
new_callable=AsyncMock,
return_value=HLTBResult(
app_id=440,
game_name="TF2",
completionist_hours=50.0,
similarity=1.0,
hltb_game_id=12345,
),
),
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._fetch_leisure_times",
new_callable=AsyncMock,
),
):
results = asyncio.run(_fetch_batch([(440, "TF2")], {}, {}, None))
assert len(results) == 1
def test_with_auth_no_hp(self) -> None:
auth = _AuthInfo("tok123")
with (
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._get_hltb_search_url",
return_value="https://example.com",
),
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._get_auth_info",
new_callable=AsyncMock,
return_value=auth,
),
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._search_one",
new_callable=AsyncMock,
return_value=None,
),
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._fetch_leisure_times",
new_callable=AsyncMock,
),
):
results = asyncio.run(_fetch_batch([(440, "TF2")], {}, {}, None))
assert results == []
def test_filters_none_results(self) -> None:
auth = _AuthInfo("tok123")
with (
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._get_hltb_search_url",
return_value="https://example.com",
),
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._get_auth_info",
new_callable=AsyncMock,
return_value=auth,
),
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._search_one",
new_callable=AsyncMock,
return_value=None,
),
patch(
"python_pkg.steam_backlog_enforcer._hltb_search._fetch_leisure_times",
new_callable=AsyncMock,
),
):
results = asyncio.run(_fetch_batch([(440, "TF2")], {}, {}, None))
assert results == []
class TestParseGamePage:
"""Tests for _parse_game_page."""
def test_valid_html(self) -> None:
game_data: dict[str, Any] = {
"game": [{"comp_100_h": 21243, "comp_100": 6800}],
"relationships": [],
}
next_data = {
"props": {"pageProps": {"game": {"data": game_data}}},
}
html = (
'<html><script id="__NEXT_DATA__" type="application/json">'
+ json.dumps(next_data)
+ "</script></html>"
)
assert _parse_game_page(html) == game_data
def test_no_script_tag(self) -> None:
assert _parse_game_page("<html></html>") is None
def test_bad_json(self) -> None:
html = '<script id="__NEXT_DATA__" type="application/json">{not json}</script>'
assert _parse_game_page(html) is None
def test_missing_keys(self) -> None:
html = (
'<script id="__NEXT_DATA__" type="application/json">{"props": {}}</script>'
)
assert _parse_game_page(html) is None
class TestExtractLeisureHours:
"""Tests for _extract_leisure_hours."""
def test_leisure_time_only(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 21243, "comp_100": 6800}],
"relationships": [],
}
assert _extract_leisure_hours(data) == round(21243 / 3600, 2)
def test_leisure_with_dlc(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 21243, "comp_100": 6800}],
"relationships": [
{"game_type": "dlc", "comp_100": 12298},
{"game_type": "dlc", "comp_100": 3600},
],
}
assert _extract_leisure_hours(data) == round((21243 + 12298 + 3600) / 3600, 2)
def test_fallback_to_comp_100(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100": 7200}],
"relationships": [],
}
assert _extract_leisure_hours(data) == round(7200 / 3600, 2)
def test_no_game_data(self) -> None:
assert _extract_leisure_hours({"game": [], "relationships": []}) == -1
def test_zero_leisure(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 0, "comp_100": 0}],
"relationships": [],
}
assert _extract_leisure_hours(data) == -1
def test_no_game_key(self) -> None:
assert _extract_leisure_hours({"relationships": []}) == -1
def test_non_dlc_relationship_ignored(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 3600}],
"relationships": [
{"game_type": "game", "comp_100": 9999},
{"game_type": "dlc", "comp_100": 1800},
],
}
assert _extract_leisure_hours(data) == round((3600 + 1800) / 3600, 2)
def test_dlc_zero_comp_100_skipped(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 3600}],
"relationships": [
{"game_type": "dlc", "comp_100": 0},
],
}
assert _extract_leisure_hours(data) == round(3600 / 3600, 2)
def test_negative_leisure(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": -1, "comp_100": -1}],
"relationships": [],
}
assert _extract_leisure_hours(data) == -1
def test_string_numeric_fields(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": "7200", "comp_100": "3600"}],
"relationships": [{"game_type": "dlc", "game_id": "1", "comp_100": "1800"}],
}
assert _extract_leisure_hours(data) == round((7200 + 1800) / 3600, 2)
def test_bad_string_falls_back_to_comp_100(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": "bad", "comp_100": "3600"}],
"relationships": [],
}
assert _extract_leisure_hours(data) == 1.0
def test_relationships_not_list(self) -> None:
data: dict[str, Any] = {
"game": [{"comp_100_h": 3600}],
"relationships": "not-a-list",
}
assert _extract_leisure_hours(data) == 1.0

View File

@ -2,19 +2,15 @@
from __future__ import annotations
import sys
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from python_pkg.steam_backlog_enforcer._cmd_done import (
_enforce_on_done,
_finalize_completion,
cmd_done,
)
from python_pkg.steam_backlog_enforcer.config import Config, State
from python_pkg.steam_backlog_enforcer.main import main
from python_pkg.steam_backlog_enforcer.steam_api import GameInfo
CMD_DONE_PKG = "python_pkg.steam_backlog_enforcer._cmd_done"
@ -302,8 +298,6 @@ class TestEnforceOnDone:
_enforce_on_done(config, state)
mock_install.assert_called_once_with(1, "G", "s1", use_steam_protocol=True)
class TestCmdDone:
"""Tests for cmd_done."""
def test_no_game_assigned(self) -> None:
@ -425,54 +419,3 @@ class TestCmdDone:
patch(f"{CMD_DONE_PKG}._try_reassign_shorter_game", return_value=True),
):
cmd_done(Config(steam_api_key="k", steam_id="i"), state)
class TestMain:
"""Tests for main CLI entry point."""
def test_no_args_exits(self) -> None:
with (
patch.object(sys, "argv", ["prog"]),
patch(f"{PKG}._echo"),
pytest.raises(SystemExit, match="1"),
):
main()
def test_unknown_command_exits(self) -> None:
with (
patch.object(sys, "argv", ["prog", "bogus"]),
patch(f"{PKG}._echo"),
pytest.raises(SystemExit, match="1"),
):
main()
def test_valid_command_runs(self) -> None:
mock_cmd = MagicMock()
with (
patch.object(sys, "argv", ["prog", "status"]),
patch(f"{PKG}.Config.load", return_value=Config(steam_api_key="k")),
patch(f"{PKG}.State.load", return_value=State()),
patch.dict(f"{PKG}.COMMANDS", {"status": ("s", mock_cmd)}),
):
main()
mock_cmd.assert_called_once()
def test_setup_no_key_required(self) -> None:
mock_cmd = MagicMock()
with (
patch.object(sys, "argv", ["prog", "setup"]),
patch(f"{PKG}.Config.load", return_value=Config()),
patch(f"{PKG}.State.load", return_value=State()),
patch.dict(f"{PKG}.COMMANDS", {"setup": ("s", mock_cmd)}),
):
main()
mock_cmd.assert_called_once()
def test_no_api_key_exits(self) -> None:
with (
patch.object(sys, "argv", ["prog", "status"]),
patch(f"{PKG}.Config.load", return_value=Config()),
patch(f"{PKG}._echo"),
pytest.raises(SystemExit, match="1"),
):
main()

View File

@ -0,0 +1,309 @@
"""Tests for main CLI module — part 3 (cmd_done, main, cmd_pick)."""
from __future__ import annotations
import sys
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from python_pkg.steam_backlog_enforcer._cmd_done import (
cmd_done,
)
from python_pkg.steam_backlog_enforcer.config import Config, State
from python_pkg.steam_backlog_enforcer.main import cmd_pick, main
from python_pkg.steam_backlog_enforcer.steam_api import GameInfo
CMD_DONE_PKG = "python_pkg.steam_backlog_enforcer._cmd_done"
PKG = "python_pkg.steam_backlog_enforcer.main"
def _snap(
app_id: int,
name: str,
total: int,
unlocked: int,
hours: float,
) -> dict[str, Any]:
return {
"app_id": app_id,
"name": name,
"total_achievements": total,
"unlocked_achievements": unlocked,
"playtime_minutes": 0,
"completionist_hours": hours,
"achievements": [],
}
class TestCmdDone:
"""Tests for cmd_done."""
def test_no_game_assigned(self) -> None:
with patch(f"{CMD_DONE_PKG}._echo") as mock_echo:
cmd_done(Config(), State())
assert any("No game" in str(c) for c in mock_echo.call_args_list)
def test_fetch_fails(self) -> None:
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = None
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{CMD_DONE_PKG}.SteamAPIClient", return_value=mock_client),
patch(f"{CMD_DONE_PKG}._echo"),
):
cmd_done(Config(steam_api_key="k", steam_id="i"), state)
def test_not_complete_enforces(self) -> None:
game = GameInfo(
app_id=1,
name="G",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{CMD_DONE_PKG}.SteamAPIClient", return_value=mock_client),
patch(f"{CMD_DONE_PKG}._echo"),
patch(f"{CMD_DONE_PKG}.load_hltb_cache", return_value={1: 20.0}),
patch(f"{CMD_DONE_PKG}._try_reassign_shorter_game", return_value=False),
patch(f"{CMD_DONE_PKG}._enforce_on_done"),
):
cmd_done(Config(steam_api_key="k", steam_id="i"), state)
def test_complete_finalizes(self) -> None:
game = GameInfo(
app_id=1,
name="G",
total_achievements=10,
unlocked_achievements=10,
playtime_minutes=60,
)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{CMD_DONE_PKG}.SteamAPIClient", return_value=mock_client),
patch(f"{CMD_DONE_PKG}._echo"),
patch(f"{CMD_DONE_PKG}.load_hltb_cache", return_value={1: 10.0}),
patch(f"{CMD_DONE_PKG}._try_reassign_shorter_game", return_value=False),
patch(f"{CMD_DONE_PKG}._finalize_completion") as mock_final,
):
cmd_done(Config(steam_api_key="k", steam_id="i"), state)
mock_final.assert_called_once()
def test_hltb_cache_miss_fetches(self) -> None:
game = GameInfo(
app_id=1,
name="G",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{CMD_DONE_PKG}.SteamAPIClient", return_value=mock_client),
patch(f"{CMD_DONE_PKG}._echo"),
patch(f"{CMD_DONE_PKG}.load_hltb_cache", return_value={}),
patch(
f"{CMD_DONE_PKG}.fetch_hltb_times_cached",
return_value={1: 15.0},
),
patch(f"{CMD_DONE_PKG}._try_reassign_shorter_game", return_value=False),
patch(f"{CMD_DONE_PKG}._enforce_on_done"),
):
cmd_done(Config(steam_api_key="k", steam_id="i"), state)
def test_hltb_negative_no_display(self) -> None:
"""Covers the hours <= 0 branch (no HLTB estimate display)."""
game = GameInfo(
app_id=1,
name="G",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{CMD_DONE_PKG}.SteamAPIClient", return_value=mock_client),
patch(f"{CMD_DONE_PKG}._echo"),
patch(f"{CMD_DONE_PKG}.load_hltb_cache", return_value={1: -1.0}),
patch(f"{CMD_DONE_PKG}._try_reassign_shorter_game", return_value=False),
patch(f"{CMD_DONE_PKG}._enforce_on_done"),
):
cmd_done(Config(steam_api_key="k", steam_id="i"), state)
def test_reassign_returns_true(self) -> None:
game = GameInfo(
app_id=1,
name="G",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=60,
)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
state = State(current_app_id=1, current_game_name="G")
with (
patch(f"{CMD_DONE_PKG}.SteamAPIClient", return_value=mock_client),
patch(f"{CMD_DONE_PKG}._echo"),
patch(f"{CMD_DONE_PKG}.load_hltb_cache", return_value={1: 50.0}),
patch(f"{CMD_DONE_PKG}._try_reassign_shorter_game", return_value=True),
):
cmd_done(Config(steam_api_key="k", steam_id="i"), state)
class TestMain:
"""Tests for main CLI entry point."""
def test_no_args_exits(self) -> None:
with (
patch.object(sys, "argv", ["prog"]),
patch(f"{PKG}._echo"),
pytest.raises(SystemExit, match="1"),
):
main()
def test_unknown_command_exits(self) -> None:
with (
patch.object(sys, "argv", ["prog", "bogus"]),
patch(f"{PKG}._echo"),
pytest.raises(SystemExit, match="1"),
):
main()
def test_valid_command_runs(self) -> None:
mock_cmd = MagicMock()
with (
patch.object(sys, "argv", ["prog", "status"]),
patch(f"{PKG}.Config.load", return_value=Config(steam_api_key="k")),
patch(f"{PKG}.State.load", return_value=State()),
patch.dict(f"{PKG}.COMMANDS", {"status": ("s", mock_cmd)}),
):
main()
mock_cmd.assert_called_once()
def test_setup_no_key_required(self) -> None:
mock_cmd = MagicMock()
with (
patch.object(sys, "argv", ["prog", "setup"]),
patch(f"{PKG}.Config.load", return_value=Config()),
patch(f"{PKG}.State.load", return_value=State()),
patch.dict(f"{PKG}.COMMANDS", {"setup": ("s", mock_cmd)}),
):
main()
mock_cmd.assert_called_once()
def test_no_api_key_exits(self) -> None:
with (
patch.object(sys, "argv", ["prog", "status"]),
patch(f"{PKG}.Config.load", return_value=Config()),
patch(f"{PKG}._echo"),
pytest.raises(SystemExit, match="1"),
):
main()
class TestCmdPick:
"""Tests for cmd_pick."""
def test_no_snapshot_prints_message(self) -> None:
with (
patch(f"{PKG}.load_snapshot", return_value=[]),
patch(f"{PKG}._echo") as mock_echo,
):
cmd_pick(Config(steam_api_key="k", steam_id="i"), State())
mock_echo.assert_called_once_with("No snapshot found. Run 'scan' first.")
def test_calls_pick_next_game(self) -> None:
snap = [_snap(2, "NewGame", 10, 0, 5.0)]
with (
patch(f"{PKG}.load_snapshot", return_value=snap),
patch(f"{PKG}.load_hltb_cache", return_value={2: 5.0}),
patch(f"{PKG}.pick_next_game") as mock_pick,
patch(f"{PKG}.get_all_owned_app_ids", return_value=[]),
):
config = Config(steam_api_key="k", steam_id="i")
state = State()
cmd_pick(config, state)
mock_pick.assert_called_once()
def test_hides_games_after_pick(self) -> None:
snap = [_snap(2, "NewGame", 10, 0, 5.0)]
state = State(current_app_id=2, current_game_name="NewGame")
with (
patch(f"{PKG}.load_snapshot", return_value=snap),
patch(f"{PKG}.load_hltb_cache", return_value={2: 5.0}),
patch(f"{PKG}.pick_next_game"),
patch(f"{PKG}.get_all_owned_app_ids", return_value=[1, 2, 3]),
patch(f"{PKG}.hide_other_games", return_value=2) as mock_hide,
patch(f"{PKG}._echo"),
):
cmd_pick(Config(steam_api_key="k", steam_id="i"), state)
mock_hide.assert_called_once_with([1, 2, 3], 2)
def test_no_hide_message_when_none_hidden(self) -> None:
snap = [_snap(2, "NewGame", 10, 0, 5.0)]
state = State(current_app_id=2, current_game_name="NewGame")
with (
patch(f"{PKG}.load_snapshot", return_value=snap),
patch(f"{PKG}.load_hltb_cache", return_value={}),
patch(f"{PKG}.pick_next_game"),
patch(f"{PKG}.get_all_owned_app_ids", return_value=[1, 2, 3]),
patch(f"{PKG}.hide_other_games", return_value=0),
patch(f"{PKG}._echo") as mock_echo,
):
cmd_pick(Config(steam_api_key="k", steam_id="i"), state)
mock_echo.assert_not_called()
def test_no_hide_when_no_current_app(self) -> None:
snap = [_snap(2, "NewGame", 10, 0, 5.0)]
with (
patch(f"{PKG}.load_snapshot", return_value=snap),
patch(f"{PKG}.load_hltb_cache", return_value={}),
patch(f"{PKG}.pick_next_game"),
patch(f"{PKG}.get_all_owned_app_ids") as mock_owned,
):
cmd_pick(Config(steam_api_key="k", steam_id="i"), State())
mock_owned.assert_not_called()
def test_no_hide_when_owned_ids_empty(self) -> None:
snap = [_snap(2, "NewGame", 10, 0, 5.0)]
state = State(current_app_id=2, current_game_name="NewGame")
with (
patch(f"{PKG}.load_snapshot", return_value=snap),
patch(f"{PKG}.load_hltb_cache", return_value={}),
patch(f"{PKG}.pick_next_game"),
patch(f"{PKG}.get_all_owned_app_ids", return_value=[]),
patch(f"{PKG}.hide_other_games") as mock_hide,
):
cmd_pick(Config(steam_api_key="k", steam_id="i"), state)
mock_hide.assert_not_called()
def test_hltb_cache_applied_to_games(self) -> None:
snap = [_snap(2, "NewGame", 10, 0, -1.0)]
captured_games: list[list[GameInfo]] = []
config = Config(steam_api_key="k", steam_id="i")
state = State()
def capture_pick(games: list[GameInfo], *_args: object) -> None:
captured_games.append(list(games))
with (
patch(f"{PKG}.load_snapshot", return_value=snap),
patch(f"{PKG}.load_hltb_cache", return_value={2: 7.5}),
patch(f"{PKG}.pick_next_game", side_effect=capture_pick),
patch(f"{PKG}.get_all_owned_app_ids", return_value=[]),
):
cmd_pick(config, state)
assert len(captured_games) == 1
assert captured_games[0][0].completionist_hours == pytest.approx(7.5)

View File

@ -6,7 +6,7 @@ import json
from typing import TYPE_CHECKING
from unittest.mock import patch
from python_pkg.steam_backlog_enforcer import _cmd_done, scanning
from python_pkg.steam_backlog_enforcer import _cmd_done
from python_pkg.steam_backlog_enforcer._hltb_types import (
HLTBResult,
load_hltb_cache,
@ -350,380 +350,3 @@ class TestReportAssignedConfidence:
_cmd_done._report_assigned_confidence(1, _state([2], current=1))
assert not any("NEW LOW" in s for s in echoed)
assert not any("no polls recorded" in s for s in echoed)
class TestScanningPollsIntegration:
def test_do_scan_kept_assignment_reports(self) -> None:
# Targeted test for scanning's `else` branch that prints CURRENT.
echoed: list[str] = []
games = [
GameInfo(
app_id=1,
name="X",
total_achievements=10,
unlocked_achievements=2,
playtime_minutes=0,
completionist_hours=5.0,
comp_100_count=20,
)
]
state = _state([], current=1)
with (
patch(f"{_SCAN}._echo", side_effect=lambda *a, **_: echoed.append(a[0])),
patch(f"{_SCAN}._report_poll_confidence") as mock_report,
):
# Directly invoke just the kept-assignment branch.
current = next((g for g in games if g.app_id == state.current_app_id), None)
assert current is not None
scanning._echo(f"\n>>> CURRENT: {current.name} (AppID={current.app_id})")
scanning._report_poll_confidence(current, games, state)
assert any("CURRENT" in s for s in echoed)
mock_report.assert_called_once()
def test_report_poll_confidence_new_low(self) -> None:
echoed: list[str] = []
chosen = GameInfo(
app_id=1,
name="Chosen",
total_achievements=10,
unlocked_achievements=0,
playtime_minutes=0,
comp_100_count=0,
)
games = [
chosen,
GameInfo(
app_id=2,
name="Old",
total_achievements=10,
unlocked_achievements=10,
playtime_minutes=0,
),
]
with (
patch(
f"{_SCAN}._backfill_polls_for_finished",
return_value={1: 1, 2: 5},
),
patch(f"{_SCAN}._echo", side_effect=lambda *a, **_: echoed.append(a[0])),
):
scanning._report_poll_confidence(chosen, games, _state([2], current=1))
assert any("NEW LOW" in s for s in echoed)
assert chosen.comp_100_count == 1
def test_report_poll_confidence_no_history(self) -> None:
echoed: list[str] = []
chosen = GameInfo(
app_id=1,
name="Chosen",
total_achievements=10,
unlocked_achievements=0,
playtime_minutes=0,
comp_100_count=4,
)
with (
patch(f"{_SCAN}._backfill_polls_for_finished", return_value={1: 4}),
patch(f"{_SCAN}._echo", side_effect=lambda *a, **_: echoed.append(a[0])),
):
scanning._report_poll_confidence(chosen, [chosen], _state([], current=1))
# No "Historical min" line when no finished games have polls.
assert not any("Historical min" in s for s in echoed)
assert any("HLTB confidence: 4" in s for s in echoed)
def test_scanning_backfill_no_missing(self, tmp_path: Path) -> None:
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps({"2": {"hours": 1.0, "polls": 5}}), encoding="utf-8"
)
with patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file):
result = scanning._backfill_polls_for_finished(
_state([2]),
[
GameInfo(
app_id=2,
name="X",
total_achievements=0,
unlocked_achievements=0,
playtime_minutes=0,
)
],
)
assert result == {2: 5}
def test_scanning_backfill_with_missing(self, tmp_path: Path) -> None:
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps({"2": {"hours": 3.0, "polls": 0}}), encoding="utf-8"
)
def fake_fetch(games: list[tuple[int, str]]) -> dict[int, float]:
data = json.loads(cache_file.read_text(encoding="utf-8"))
for aid, _name in games:
data[str(aid)] = {"hours": 3.0, "polls": 8}
cache_file.write_text(json.dumps(data), encoding="utf-8")
return {aid: 3.0 for aid, _ in games}
with (
patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file),
patch(f"{_TYPES}.CONFIG_DIR", tmp_path),
patch(f"{_SCAN}.fetch_hltb_confidence_cached", side_effect=fake_fetch),
):
result = scanning._backfill_polls_for_finished(
_state([2]),
[
GameInfo(
app_id=2,
name="X",
total_achievements=0,
unlocked_achievements=0,
playtime_minutes=0,
)
],
)
assert result == {2: 8}
def test_scanning_backfill_preserves_hours_on_miss(self, tmp_path: Path) -> None:
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps({"2": {"hours": 9.0, "polls": 0}}), encoding="utf-8"
)
def fake_fetch(games: list[tuple[int, str]]) -> dict[int, float]:
data = json.loads(cache_file.read_text(encoding="utf-8"))
for aid, _name in games:
data[str(aid)] = {"hours": -1, "polls": 0}
cache_file.write_text(json.dumps(data), encoding="utf-8")
return {aid: -1 for aid, _ in games}
with (
patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file),
patch(f"{_TYPES}.CONFIG_DIR", tmp_path),
patch(f"{_SCAN}.fetch_hltb_confidence_cached", side_effect=fake_fetch),
):
scanning._backfill_polls_for_finished(
_state([2]),
[
GameInfo(
app_id=2,
name="X",
total_achievements=0,
unlocked_achievements=0,
playtime_minutes=0,
)
],
)
final = json.loads(cache_file.read_text(encoding="utf-8"))
assert final["2"]["hours"] == 9.0
def test_report_poll_confidence_chosen_zero_polls(self) -> None:
"""Covers scanning.py 301-302: 0-poll chosen with history yields warning."""
echoed: list[str] = []
chosen = GameInfo(
app_id=1,
name="Chosen",
total_achievements=10,
unlocked_achievements=0,
playtime_minutes=0,
comp_100_count=0,
)
old = GameInfo(
app_id=2,
name="Old",
total_achievements=10,
unlocked_achievements=10,
playtime_minutes=0,
)
with (
patch(
f"{_SCAN}._backfill_polls_for_finished",
return_value={1: 0, 2: 5},
),
patch(f"{_SCAN}._echo", side_effect=lambda *a, **_: echoed.append(a[0])),
):
scanning._report_poll_confidence(
chosen, [chosen, old], _state([2], current=1)
)
assert any("no polls recorded" in s for s in echoed)
def test_do_scan_kept_assignment_missing_game(self) -> None:
"""Covers scanning.py 110->116: current_app_id set but game absent."""
from python_pkg.steam_backlog_enforcer.config import Config
from python_pkg.steam_backlog_enforcer.scanning import do_scan
other = GameInfo(
app_id=999,
name="Other",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=0,
)
from unittest.mock import MagicMock
mock_client = MagicMock()
mock_client.build_game_list.return_value = [other]
with (
patch(f"{_SCAN}.SteamAPIClient", return_value=mock_client),
patch(f"{_SCAN}.fetch_hltb_times_cached", return_value={999: 10.0}),
patch(f"{_SCAN}.save_snapshot"),
patch(f"{_SCAN}.pick_next_game") as mock_pick,
patch(f"{_SCAN}._echo"),
patch(f"{_SCAN}._report_poll_confidence") as mock_report,
):
config = Config(steam_api_key="k", steam_id="i")
state = State(current_app_id=440) # not in games
do_scan(config, state)
mock_pick.assert_not_called()
mock_report.assert_not_called()
def test_cmd_done_no_finished_history_chosen_has_polls(self) -> None:
"""Covers _cmd_done.py 100->103: no finished history, chosen has >0 polls."""
echoed: list[str] = []
with (
patch(
f"{_CMD}._backfill_polls_for_finished",
return_value={1: 7},
),
patch(
f"{_CMD}.load_snapshot",
return_value=[
{"app_id": 1, "name": "Chosen"},
],
),
patch(f"{_CMD}._echo", side_effect=lambda *a, **_: echoed.append(a[0])),
):
_cmd_done._report_assigned_confidence(1, _state([], current=1))
assert any("HLTB confidence: 7" in s for s in echoed)
assert not any("NEW LOW" in s for s in echoed)
assert not any("no polls recorded" in s for s in echoed)
def test_report_poll_confidence_chosen_equals_min(self) -> None:
"""Covers scanning.py 301->304: chosen_polls >= min_polls, no warning."""
echoed: list[str] = []
chosen = GameInfo(
app_id=1,
name="Chosen",
total_achievements=10,
unlocked_achievements=0,
playtime_minutes=0,
comp_100_count=5,
)
old = GameInfo(
app_id=2,
name="Old",
total_achievements=10,
unlocked_achievements=10,
playtime_minutes=0,
)
with (
patch(
f"{_SCAN}._backfill_polls_for_finished",
return_value={1: 5, 2: 5},
),
patch(f"{_SCAN}._echo", side_effect=lambda *a, **_: echoed.append(a[0])),
):
scanning._report_poll_confidence(
chosen, [chosen, old], _state([2], current=1)
)
assert not any("NEW LOW" in s for s in echoed)
assert not any("no polls recorded" in s for s in echoed)
def test_refresh_candidate_confidence_noop_when_present(self) -> None:
game = GameInfo(
app_id=1,
name="Known",
total_achievements=10,
unlocked_achievements=1,
playtime_minutes=0,
comp_100_count=3,
count_comp=15,
)
with patch(f"{_SCAN}.fetch_hltb_confidence_cached") as mock_fetch:
scanning._refresh_candidate_confidence(game)
mock_fetch.assert_not_called()
def test_refresh_candidate_confidence_backfills_zeroes(
self, tmp_path: Path
) -> None:
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps({"1": {"hours": 4.0, "polls": 0, "count_comp": 0}}),
encoding="utf-8",
)
game = GameInfo(
app_id=1,
name="NeedsRefresh",
total_achievements=10,
unlocked_achievements=1,
playtime_minutes=0,
comp_100_count=0,
count_comp=0,
)
def fake_fetch(_games: list[tuple[int, str]]) -> dict[int, float]:
data = json.loads(cache_file.read_text(encoding="utf-8"))
data["1"] = {"hours": 4.0, "polls": 3, "count_comp": 15}
cache_file.write_text(json.dumps(data), encoding="utf-8")
return {1: 4.0}
with (
patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file),
patch(f"{_TYPES}.CONFIG_DIR", tmp_path),
patch(f"{_SCAN}.fetch_hltb_confidence_cached", side_effect=fake_fetch),
patch(f"{_SCAN}._echo"),
):
scanning._refresh_candidate_confidence(game)
assert game.comp_100_count == 3
assert game.count_comp == 15
def test_filter_hltb_confidence_batches_refreshes(self, tmp_path: Path) -> None:
"""Filtering refreshes missing confidence in one batched cache lookup."""
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps(
{
"1": {"hours": 4.0, "polls": 0, "count_comp": 0},
"2": {"hours": 5.0, "polls": 0, "count_comp": 0},
}
),
encoding="utf-8",
)
game_a = GameInfo(
app_id=1,
name="A",
total_achievements=10,
unlocked_achievements=1,
playtime_minutes=0,
comp_100_count=0,
count_comp=0,
)
game_b = GameInfo(
app_id=2,
name="B",
total_achievements=10,
unlocked_achievements=1,
playtime_minutes=0,
comp_100_count=0,
count_comp=0,
)
def fake_fetch(games: list[tuple[int, str]]) -> dict[int, float]:
assert sorted(games) == [(1, "A"), (2, "B")]
data = json.loads(cache_file.read_text(encoding="utf-8"))
data["1"] = {"hours": 4.0, "polls": 3, "count_comp": 15}
data["2"] = {"hours": 5.0, "polls": 3, "count_comp": 15}
cache_file.write_text(json.dumps(data), encoding="utf-8")
return {1: 4.0, 2: 5.0}
with (
patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file),
patch(f"{_TYPES}.CONFIG_DIR", tmp_path),
patch(
f"{_SCAN}.fetch_hltb_confidence_cached", side_effect=fake_fetch
) as mock_fetch,
patch(f"{_SCAN}._echo"),
):
kept = scanning._filter_hltb_confident_candidates([game_a, game_b])
assert [game.app_id for game in kept] == [1, 2]
mock_fetch.assert_called_once()

View File

@ -0,0 +1,417 @@
"""Tests for HLTB poll-count tracking — scanning integration (part 2)."""
from __future__ import annotations
import json
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from python_pkg.steam_backlog_enforcer import _cmd_done, _scanning_confidence, scanning
from python_pkg.steam_backlog_enforcer.config import State
from python_pkg.steam_backlog_enforcer.steam_api import GameInfo
if TYPE_CHECKING:
from pathlib import Path
_TYPES = "python_pkg.steam_backlog_enforcer._hltb_types"
_CMD = "python_pkg.steam_backlog_enforcer._cmd_done"
_SCAN = "python_pkg.steam_backlog_enforcer.scanning"
_SCANCONF = "python_pkg.steam_backlog_enforcer._scanning_confidence"
def _state(finished: list[int], current: int | None = None) -> State:
s = State()
s.finished_app_ids = list(finished)
s.current_app_id = current
s.current_game_name = ""
return s
class TestScanningPollsIntegration:
def test_do_scan_kept_assignment_reports(self) -> None:
# Targeted test for scanning's `else` branch that prints CURRENT.
echoed: list[str] = []
games = [
GameInfo(
app_id=1,
name="X",
total_achievements=10,
unlocked_achievements=2,
playtime_minutes=0,
completionist_hours=5.0,
comp_100_count=20,
)
]
state = _state([], current=1)
with (
patch(f"{_SCAN}._echo", side_effect=lambda *a, **_: echoed.append(a[0])),
patch(f"{_SCAN}._report_poll_confidence") as mock_report,
):
# Directly invoke just the kept-assignment branch.
current = next((g for g in games if g.app_id == state.current_app_id), None)
assert current is not None
scanning._echo(f"\n>>> CURRENT: {current.name} (AppID={current.app_id})")
scanning._report_poll_confidence(current, games, state)
assert any("CURRENT" in s for s in echoed)
mock_report.assert_called_once()
def test_report_poll_confidence_new_low(self) -> None:
echoed: list[str] = []
chosen = GameInfo(
app_id=1,
name="Chosen",
total_achievements=10,
unlocked_achievements=0,
playtime_minutes=0,
comp_100_count=0,
)
games = [
chosen,
GameInfo(
app_id=2,
name="Old",
total_achievements=10,
unlocked_achievements=10,
playtime_minutes=0,
),
]
with (
patch(
f"{_SCANCONF}._backfill_polls_for_finished",
return_value={1: 1, 2: 5},
),
patch(
f"{_SCANCONF}._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
):
scanning._report_poll_confidence(chosen, games, _state([2], current=1))
assert any("NEW LOW" in s for s in echoed)
assert chosen.comp_100_count == 1
def test_report_poll_confidence_no_history(self) -> None:
echoed: list[str] = []
chosen = GameInfo(
app_id=1,
name="Chosen",
total_achievements=10,
unlocked_achievements=0,
playtime_minutes=0,
comp_100_count=4,
)
with (
patch(f"{_SCANCONF}._backfill_polls_for_finished", return_value={1: 4}),
patch(
f"{_SCANCONF}._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
):
scanning._report_poll_confidence(chosen, [chosen], _state([], current=1))
# No "Historical min" line when no finished games have polls.
assert not any("Historical min" in s for s in echoed)
assert any("HLTB confidence: 4" in s for s in echoed)
def test_scanning_backfill_no_missing(self, tmp_path: Path) -> None:
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps({"2": {"hours": 1.0, "polls": 5}}), encoding="utf-8"
)
with patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file):
result = _scanning_confidence._backfill_polls_for_finished(
_state([2]),
[
GameInfo(
app_id=2,
name="X",
total_achievements=0,
unlocked_achievements=0,
playtime_minutes=0,
)
],
)
assert result == {2: 5}
def test_scanning_backfill_with_missing(self, tmp_path: Path) -> None:
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps({"2": {"hours": 3.0, "polls": 0}}), encoding="utf-8"
)
def fake_fetch(games: list[tuple[int, str]]) -> dict[int, float]:
data = json.loads(cache_file.read_text(encoding="utf-8"))
for aid, _name in games:
data[str(aid)] = {"hours": 3.0, "polls": 8}
cache_file.write_text(json.dumps(data), encoding="utf-8")
return {aid: 3.0 for aid, _ in games}
with (
patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file),
patch(f"{_TYPES}.CONFIG_DIR", tmp_path),
patch(f"{_SCANCONF}.fetch_hltb_confidence_cached", side_effect=fake_fetch),
):
result = _scanning_confidence._backfill_polls_for_finished(
_state([2]),
[
GameInfo(
app_id=2,
name="X",
total_achievements=0,
unlocked_achievements=0,
playtime_minutes=0,
)
],
)
assert result == {2: 8}
def test_scanning_backfill_preserves_hours_on_miss(self, tmp_path: Path) -> None:
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps({"2": {"hours": 9.0, "polls": 0}}), encoding="utf-8"
)
def fake_fetch(games: list[tuple[int, str]]) -> dict[int, float]:
data = json.loads(cache_file.read_text(encoding="utf-8"))
for aid, _name in games:
data[str(aid)] = {"hours": -1, "polls": 0}
cache_file.write_text(json.dumps(data), encoding="utf-8")
return {aid: -1 for aid, _ in games}
with (
patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file),
patch(f"{_TYPES}.CONFIG_DIR", tmp_path),
patch(f"{_SCANCONF}.fetch_hltb_confidence_cached", side_effect=fake_fetch),
):
_scanning_confidence._backfill_polls_for_finished(
_state([2]),
[
GameInfo(
app_id=2,
name="X",
total_achievements=0,
unlocked_achievements=0,
playtime_minutes=0,
)
],
)
final = json.loads(cache_file.read_text(encoding="utf-8"))
assert final["2"]["hours"] == 9.0
def test_report_poll_confidence_chosen_zero_polls(self) -> None:
"""Covers scanning.py 301-302: 0-poll chosen with history yields warning."""
echoed: list[str] = []
chosen = GameInfo(
app_id=1,
name="Chosen",
total_achievements=10,
unlocked_achievements=0,
playtime_minutes=0,
comp_100_count=0,
)
old = GameInfo(
app_id=2,
name="Old",
total_achievements=10,
unlocked_achievements=10,
playtime_minutes=0,
)
with (
patch(
f"{_SCANCONF}._backfill_polls_for_finished",
return_value={1: 0, 2: 5},
),
patch(
f"{_SCANCONF}._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
):
scanning._report_poll_confidence(
chosen, [chosen, old], _state([2], current=1)
)
assert any("no polls recorded" in s for s in echoed)
def test_do_scan_kept_assignment_missing_game(self) -> None:
"""Covers scanning.py 110->116: current_app_id set but game absent."""
from python_pkg.steam_backlog_enforcer.config import Config
from python_pkg.steam_backlog_enforcer.scanning import do_scan
other = GameInfo(
app_id=999,
name="Other",
total_achievements=10,
unlocked_achievements=5,
playtime_minutes=0,
)
mock_client = MagicMock()
mock_client.build_game_list.return_value = [other]
with (
patch(f"{_SCAN}.SteamAPIClient", return_value=mock_client),
patch(f"{_SCAN}.fetch_hltb_times_cached", return_value={999: 10.0}),
patch(f"{_SCAN}.save_snapshot"),
patch(f"{_SCAN}.pick_next_game") as mock_pick,
patch(f"{_SCAN}._echo"),
patch(f"{_SCAN}._report_poll_confidence") as mock_report,
):
config = Config(steam_api_key="k", steam_id="i")
state = State(current_app_id=440) # not in games
do_scan(config, state)
mock_pick.assert_not_called()
mock_report.assert_not_called()
def test_cmd_done_no_finished_history_chosen_has_polls(self) -> None:
"""Covers _cmd_done.py 100->103: no finished history, chosen has >0 polls."""
echoed: list[str] = []
with (
patch(
f"{_CMD}._backfill_polls_for_finished",
return_value={1: 7},
),
patch(
f"{_CMD}.load_snapshot",
return_value=[
{"app_id": 1, "name": "Chosen"},
],
),
patch(f"{_CMD}._echo", side_effect=lambda *a, **_: echoed.append(a[0])),
):
_cmd_done._report_assigned_confidence(1, _state([], current=1))
assert any("HLTB confidence: 7" in s for s in echoed)
assert not any("NEW LOW" in s for s in echoed)
assert not any("no polls recorded" in s for s in echoed)
def test_report_poll_confidence_chosen_equals_min(self) -> None:
"""Covers scanning.py 301->304: chosen_polls >= min_polls, no warning."""
echoed: list[str] = []
chosen = GameInfo(
app_id=1,
name="Chosen",
total_achievements=10,
unlocked_achievements=0,
playtime_minutes=0,
comp_100_count=5,
)
old = GameInfo(
app_id=2,
name="Old",
total_achievements=10,
unlocked_achievements=10,
playtime_minutes=0,
)
with (
patch(
f"{_SCANCONF}._backfill_polls_for_finished",
return_value={1: 5, 2: 5},
),
patch(
f"{_SCANCONF}._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
):
scanning._report_poll_confidence(
chosen, [chosen, old], _state([2], current=1)
)
assert not any("NEW LOW" in s for s in echoed)
assert not any("no polls recorded" in s for s in echoed)
def test_refresh_candidate_confidence_noop_when_present(self) -> None:
game = GameInfo(
app_id=1,
name="Known",
total_achievements=10,
unlocked_achievements=1,
playtime_minutes=0,
comp_100_count=3,
count_comp=15,
)
with patch(f"{_SCANCONF}.fetch_hltb_confidence_cached") as mock_fetch:
_scanning_confidence._refresh_candidate_confidence(game)
mock_fetch.assert_not_called()
def test_refresh_candidate_confidence_backfills_zeroes(
self, tmp_path: Path
) -> None:
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps({"1": {"hours": 4.0, "polls": 0, "count_comp": 0}}),
encoding="utf-8",
)
game = GameInfo(
app_id=1,
name="NeedsRefresh",
total_achievements=10,
unlocked_achievements=1,
playtime_minutes=0,
comp_100_count=0,
count_comp=0,
)
def fake_fetch(_games: list[tuple[int, str]]) -> dict[int, float]:
data = json.loads(cache_file.read_text(encoding="utf-8"))
data["1"] = {"hours": 4.0, "polls": 3, "count_comp": 15}
cache_file.write_text(json.dumps(data), encoding="utf-8")
return {1: 4.0}
with (
patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file),
patch(f"{_TYPES}.CONFIG_DIR", tmp_path),
patch(f"{_SCANCONF}.fetch_hltb_confidence_cached", side_effect=fake_fetch),
patch(f"{_SCANCONF}._echo"),
):
_scanning_confidence._refresh_candidate_confidence(game)
assert game.comp_100_count == 3
assert game.count_comp == 15
def test_filter_hltb_confidence_batches_refreshes(self, tmp_path: Path) -> None:
"""Filtering refreshes missing confidence in one batched cache lookup."""
cache_file = tmp_path / "hltb_cache.json"
cache_file.write_text(
json.dumps(
{
"1": {"hours": 4.0, "polls": 0, "count_comp": 0},
"2": {"hours": 5.0, "polls": 0, "count_comp": 0},
}
),
encoding="utf-8",
)
game_a = GameInfo(
app_id=1,
name="A",
total_achievements=10,
unlocked_achievements=1,
playtime_minutes=0,
comp_100_count=0,
count_comp=0,
)
game_b = GameInfo(
app_id=2,
name="B",
total_achievements=10,
unlocked_achievements=1,
playtime_minutes=0,
comp_100_count=0,
count_comp=0,
)
def fake_fetch(games: list[tuple[int, str]]) -> dict[int, float]:
assert sorted(games) == [(1, "A"), (2, "B")]
data = json.loads(cache_file.read_text(encoding="utf-8"))
data["1"] = {"hours": 4.0, "polls": 3, "count_comp": 15}
data["2"] = {"hours": 5.0, "polls": 3, "count_comp": 15}
cache_file.write_text(json.dumps(data), encoding="utf-8")
return {1: 4.0, 2: 5.0}
with (
patch(f"{_TYPES}.HLTB_CACHE_FILE", cache_file),
patch(f"{_TYPES}.CONFIG_DIR", tmp_path),
patch(
f"{_SCANCONF}.fetch_hltb_confidence_cached", side_effect=fake_fetch
) as mock_fetch,
patch(f"{_SCANCONF}._echo"),
):
kept = _scanning_confidence._filter_hltb_confident_candidates(
[game_a, game_b]
)
assert [game.app_id for game in kept] == [1, 2]
mock_fetch.assert_called_once()

View File

@ -8,12 +8,7 @@ from unittest.mock import MagicMock, patch
from python_pkg.steam_backlog_enforcer.config import Config, State
from python_pkg.steam_backlog_enforcer.protondb import ProtonDBRating
from python_pkg.steam_backlog_enforcer.scanning import (
_filter_hltb_confident_candidates,
_force_refresh_candidate_confidence,
_pick_next_shortest_candidate,
_pick_playable_candidate,
_refresh_candidate_confidence_batch,
do_check,
do_scan,
pick_next_game,
)
@ -223,14 +218,12 @@ class TestPickNextGame:
config = Config(steam_api_key="k", steam_id="i")
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._force_refresh_candidate_confidence"
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch("python_pkg.steam_backlog_enforcer._scanning_confidence._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
@ -239,6 +232,7 @@ class TestPickNextGame:
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="1"),
):
pick_next_game([g1, g2], state, config)
assert state.current_app_id == 2
@ -270,6 +264,7 @@ class TestPickNextGame:
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="1"),
):
pick_next_game([g1, g2], state, config)
assert state.current_app_id == 2
@ -293,14 +288,12 @@ class TestPickNextGame:
config = Config(steam_api_key="k", steam_id="i", uninstall_other_games=True)
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._force_refresh_candidate_confidence"
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch("python_pkg.steam_backlog_enforcer._scanning_confidence._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=2,
@ -309,6 +302,7 @@ class TestPickNextGame:
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch("builtins.input", return_value="1"),
):
pick_next_game([g1], state, config)
assert state.current_app_id == 1
@ -318,14 +312,12 @@ class TestPickNextGame:
config = Config(steam_api_key="k", steam_id="i", uninstall_other_games=False)
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._force_refresh_candidate_confidence"
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch("python_pkg.steam_backlog_enforcer._scanning_confidence._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=False,
@ -333,6 +325,7 @@ class TestPickNextGame:
patch(
"python_pkg.steam_backlog_enforcer.scanning.install_game"
) as mock_install,
patch("builtins.input", return_value="1"),
):
pick_next_game([g1], state, config)
mock_install.assert_called_once()
@ -356,6 +349,7 @@ class TestPickNextGame:
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="1"),
):
pick_next_game([g1, g2], state, config)
assert state.current_app_id == 2
@ -379,6 +373,7 @@ class TestPickNextGame:
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="1"),
):
pick_next_game([g1], state, config)
assert state.current_app_id == 1
@ -394,9 +389,6 @@ class TestPickNextGame:
config = Config(steam_api_key="k", steam_id="i")
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._force_refresh_candidate_confidence"
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
@ -405,6 +397,10 @@ class TestPickNextGame:
"python_pkg.steam_backlog_enforcer.scanning._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
@ -413,6 +409,7 @@ class TestPickNextGame:
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="1"),
):
pick_next_game([low, valid], state, config)
assert state.current_app_id == 2
@ -435,7 +432,8 @@ class TestPickNextGame:
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._force_refresh_candidate_confidence"
"python_pkg.steam_backlog_enforcer._scanning_confidence._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
@ -446,350 +444,3 @@ class TestPickNextGame:
assert state.current_app_id is None
mock_pick.assert_not_called()
assert any("No assignable games found" in line for line in echoed)
def test_zero_confidence_is_refreshed_before_skipping(self) -> None:
"""Missing confidence fields are refreshed once before final skip decision."""
stale = _game(app_id=1, name="Celeste", hours=1.0)
stale.comp_100_count = 0
stale.count_comp = 0
fallback = _game(app_id=2, name="Fallback", hours=2.0)
config = Config(steam_api_key="k", steam_id="i")
state = State()
echoed: list[str] = []
def refresh_side_effect(game: GameInfo) -> None:
if game.app_id == 1:
game.comp_100_count = 899
game.count_comp = 14055
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._refresh_candidate_confidence",
side_effect=refresh_side_effect,
) as mock_refresh,
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
):
pick_next_game([stale, fallback], state, config)
assert state.current_app_id == 1
mock_refresh.assert_called_once_with(stale)
assert not any("Skipping Celeste" in line for line in echoed)
def test_nonzero_low_confidence_does_not_force_refetch(self) -> None:
"""Non-zero low-confidence entries are skipped using cached values."""
low = _game(app_id=1, name="Low", hours=1.0)
low.comp_100_count = 1
low.count_comp = 8
fallback = _game(app_id=2, name="Fallback", hours=2.0)
config = Config(steam_api_key="k", steam_id="i")
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._refresh_candidate_confidence_batch"
) as mock_refresh_batch,
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
):
pick_next_game([low, fallback], state, config)
assert state.current_app_id == 2
mock_refresh_batch.assert_not_called()
def test_cached_confidence_overlay_avoids_refetch_for_zero_snapshot_fields(
self,
) -> None:
"""Use cached confidence before deciding whether refresh is needed."""
low = _game(app_id=1, name="Low", hours=1.0)
low.comp_100_count = 0
low.count_comp = 0
fallback = _game(app_id=2, name="Fallback", hours=2.0)
fallback.comp_100_count = 3
fallback.count_comp = 20
config = Config(steam_api_key="k", steam_id="i")
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.load_hltb_polls_cache",
return_value={1: 1, 2: 3},
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.load_hltb_count_comp_cache",
return_value={1: 8, 2: 20},
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._refresh_candidate_confidence_batch"
) as mock_refresh_batch,
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
):
pick_next_game([low, fallback], state, config)
assert state.current_app_id == 2
mock_refresh_batch.assert_not_called()
def test_stops_after_first_confident_assignment(self) -> None:
"""Only candidates up to the winning one are checked/skipped."""
low = _game(app_id=1, name="Low", hours=1.0)
low.comp_100_count = 1
low.count_comp = 2
good = _game(app_id=2, name="Good", hours=2.0)
good.comp_100_count = 10
good.count_comp = 50
never_checked = _game(app_id=3, name="NeverChecked", hours=3.0)
never_checked.comp_100_count = 0
never_checked.count_comp = 0
config = Config(steam_api_key="k", steam_id="i")
state = State()
echoed: list[str] = []
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._refresh_candidate_confidence"
) as mock_refresh,
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
):
pick_next_game([low, good, never_checked], state, config)
assert state.current_app_id == 2
mock_refresh.assert_called_once_with(low)
assert any("Skipping Low" in line for line in echoed)
assert not any("Skipping NeverChecked" in line for line in echoed)
class TestDoCheck:
"""Tests for do_check."""
def test_no_assignment(self) -> None:
with patch("python_pkg.steam_backlog_enforcer.scanning._echo") as mock_echo:
do_check(Config(), State())
mock_echo.assert_called()
def test_fetch_fails(self) -> None:
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = None
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.SteamAPIClient",
return_value=mock_client,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch("python_pkg.steam_backlog_enforcer.scanning.detect_tampering"),
):
state = State(current_app_id=440, current_game_name="TF2")
do_check(Config(steam_api_key="k", steam_id="i"), state)
class TestConfidenceHelpers:
"""Coverage-focused tests for scanning confidence helper branches."""
def test_force_refresh_candidate_confidence_delegates(self) -> None:
game = _game(app_id=10, name="A")
with patch(
"python_pkg.steam_backlog_enforcer.scanning._refresh_candidate_confidence_batch",
) as mock_batch:
_force_refresh_candidate_confidence(game)
mock_batch.assert_called_once_with([game], force=True)
def test_refresh_candidate_confidence_batch_no_missing_skips_fetch(self) -> None:
game = _game(app_id=20, name="B", hours=12.0)
game.comp_100_count = 3
game.count_comp = 15
with patch(
"python_pkg.steam_backlog_enforcer.scanning.fetch_hltb_confidence_cached",
) as mock_fetch:
_refresh_candidate_confidence_batch([game], force=False)
mock_fetch.assert_not_called()
def test_refresh_candidate_confidence_batch_preserves_existing_hours(self) -> None:
game = _game(app_id=30, name="C", hours=9.5)
game.comp_100_count = 0
game.count_comp = 0
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.load_hltb_cache",
side_effect=[{30: 9.5}, {30: -1.0}],
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.load_hltb_polls_cache",
return_value={30: 0},
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.load_hltb_count_comp_cache",
return_value={30: 0},
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.fetch_hltb_confidence_cached",
return_value={30: -1.0},
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.save_hltb_cache",
) as mock_save,
):
_refresh_candidate_confidence_batch([game], force=True)
assert game.completionist_hours == 9.5
saved_cache = mock_save.call_args.args[0]
assert saved_cache[30] == 9.5
def test_filter_hltb_confident_candidates_skips_low_confidence(self) -> None:
low = _game(app_id=40, name="Low", hours=2.0)
low.comp_100_count = 1
low.count_comp = 2
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._refresh_candidate_confidence_batch",
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo") as mock_echo,
):
result = _filter_hltb_confident_candidates([low])
assert result == []
assert mock_echo.called
def test_pick_next_shortest_candidate_logs_skipped_unplayable_batches(self) -> None:
bad = _game(app_id=50, name="Bad", hours=1.0)
good = _game(app_id=51, name="Good", hours=2.0)
bad.comp_100_count = 3
bad.count_comp = 15
good.comp_100_count = 3
good.count_comp = 15
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=[None, good],
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo") as mock_echo,
):
picked, skipped_low_conf, skipped_linux = _pick_next_shortest_candidate(
[bad, good],
)
assert picked is good
assert skipped_low_conf == 0
assert skipped_linux == 1
assert any(
"Skipped 1 game(s) with poor Linux compatibility" in str(call)
for call in mock_echo.call_args_list
)
def test_complete(self) -> None:
game = _game(app_id=440, name="TF2", total=5, unlocked=5)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
snap = [game.to_snapshot()]
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.SteamAPIClient",
return_value=mock_client,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.send_notification",
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.load_snapshot",
return_value=snap,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.pick_next_game",
),
patch("python_pkg.steam_backlog_enforcer.scanning.detect_tampering"),
):
state = State(current_app_id=440, current_game_name="TF2")
do_check(Config(steam_api_key="k", steam_id="i"), state)
assert 440 in state.finished_app_ids
def test_complete_no_snapshot(self) -> None:
game = _game(app_id=440, name="TF2", total=5, unlocked=5)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.SteamAPIClient",
return_value=mock_client,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.send_notification",
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.load_snapshot",
return_value=None,
),
patch("python_pkg.steam_backlog_enforcer.scanning.detect_tampering"),
):
state = State(current_app_id=440, current_game_name="TF2")
do_check(Config(steam_api_key="k", steam_id="i"), state)
def test_not_complete(self) -> None:
game = _game(app_id=440, name="TF2", total=10, unlocked=5)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.SteamAPIClient",
return_value=mock_client,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch("python_pkg.steam_backlog_enforcer.scanning.detect_tampering"),
):
state = State(current_app_id=440, current_game_name="TF2")
do_check(Config(steam_api_key="k", steam_id="i"), state)

View File

@ -0,0 +1,280 @@
"""Tests for scanning module (part 3): TestPickNextGame continued."""
from __future__ import annotations
from unittest.mock import patch
from python_pkg.steam_backlog_enforcer.config import Config, State
from python_pkg.steam_backlog_enforcer.scanning import pick_next_game
from python_pkg.steam_backlog_enforcer.steam_api import GameInfo
def _game(
app_id: int = 1,
name: str = "G",
total: int = 10,
unlocked: int = 0,
hours: float = -1,
) -> GameInfo:
return GameInfo(
app_id=app_id,
name=name,
total_achievements=total,
unlocked_achievements=unlocked,
playtime_minutes=60,
completionist_hours=hours,
comp_100_count=3,
count_comp=15,
)
class TestPickNextGame:
"""Tests for pick_next_game (continued from test_scanning.py)."""
def test_zero_confidence_is_refreshed_before_skipping(self) -> None:
"""Missing confidence fields are refreshed once before final skip decision."""
stale = _game(app_id=1, name="Celeste", hours=1.0)
stale.comp_100_count = 0
stale.count_comp = 0
fallback = _game(app_id=2, name="Fallback", hours=2.0)
config = Config(steam_api_key="k", steam_id="i")
state = State()
echoed: list[str] = []
def refresh_side_effect(game: GameInfo) -> None:
if game.app_id == 1:
game.comp_100_count = 899
game.count_comp = 14055
with (
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence._refresh_candidate_confidence",
side_effect=refresh_side_effect,
) as mock_refresh,
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="1"),
):
pick_next_game([stale, fallback], state, config)
assert state.current_app_id == 1
mock_refresh.assert_called_once_with(stale)
assert not any("Skipping Celeste" in line for line in echoed)
def test_nonzero_low_confidence_does_not_force_refetch(self) -> None:
"""Non-zero low-confidence entries are skipped using cached values."""
low = _game(app_id=1, name="Low", hours=1.0)
low.comp_100_count = 1
low.count_comp = 8
fallback = _game(app_id=2, name="Fallback", hours=2.0)
config = Config(steam_api_key="k", steam_id="i")
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence._refresh_candidate_confidence_batch"
) as mock_refresh_batch,
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="1"),
):
pick_next_game([low, fallback], state, config)
assert state.current_app_id == 2
mock_refresh_batch.assert_not_called()
def test_cached_confidence_overlay_avoids_refetch_for_zero_snapshot_fields(
self,
) -> None:
"""Use cached confidence before deciding whether refresh is needed."""
low = _game(app_id=1, name="Low", hours=1.0)
low.comp_100_count = 0
low.count_comp = 0
fallback = _game(app_id=2, name="Fallback", hours=2.0)
fallback.comp_100_count = 3
fallback.count_comp = 20
config = Config(steam_api_key="k", steam_id="i")
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence.load_hltb_polls_cache",
return_value={1: 1, 2: 3},
),
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence.load_hltb_count_comp_cache",
return_value={1: 8, 2: 20},
),
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence._refresh_candidate_confidence_batch"
) as mock_refresh_batch,
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="1"),
):
pick_next_game([low, fallback], state, config)
assert state.current_app_id == 2
mock_refresh_batch.assert_not_called()
def test_stops_collecting_after_n_qualified(self) -> None:
"""Collection stops once _PICK_LIST_SIZE candidates are qualified."""
# Create 11 games that all pass filters; only the first 10 should be
# presented and the 11th should never trigger a ProtonDB call.
games = [_game(app_id=i, name=f"G{i}", hours=float(i)) for i in range(1, 12)]
protondb_call_count = 0
def playable_side_effect(c: list[GameInfo]) -> GameInfo | None:
nonlocal protondb_call_count
protondb_call_count += 1
return c[0] if c else None
config = Config(steam_api_key="k", steam_id="i")
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=playable_side_effect,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="1"),
):
pick_next_game(games, state, config)
assert state.current_app_id == 1
assert protondb_call_count == 10
def test_user_picks_second_candidate(self) -> None:
"""User can select a game other than the shortest one."""
g1 = _game(app_id=1, name="Short", hours=5.0)
g2 = _game(app_id=2, name="Medium", hours=15.0)
config = Config(steam_api_key="k", steam_id="i")
state = State()
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", return_value="2"),
):
pick_next_game([g1, g2], state, config)
assert state.current_app_id == 2
def test_invalid_input_then_valid(self) -> None:
"""Non-numeric input prints error and loops until valid input."""
g1 = _game(app_id=1, name="G1", hours=5.0)
config = Config(steam_api_key="k", steam_id="i")
state = State()
echoed: list[str] = []
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", side_effect=["abc", "1"]),
):
pick_next_game([g1], state, config)
assert state.current_app_id == 1
assert any("Invalid input" in line for line in echoed)
def test_out_of_range_then_valid(self) -> None:
"""Out-of-range number prints error and loops until valid input."""
g1 = _game(app_id=1, name="G1", hours=5.0)
config = Config(steam_api_key="k", steam_id="i")
state = State()
echoed: list[str] = []
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning._echo",
side_effect=lambda *a, **_: echoed.append(a[0]),
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.is_game_installed",
return_value=True,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.uninstall_other_games",
return_value=0,
),
patch("builtins.input", side_effect=["99", "1"]),
):
pick_next_game([g1], state, config)
assert state.current_app_id == 1
assert any("Out of range" in line for line in echoed)

View File

@ -0,0 +1,328 @@
"""Scanning tests (part 4): collect_top_candidates, do_check, confidence."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from python_pkg.steam_backlog_enforcer._scanning_confidence import (
_filter_hltb_confident_candidates,
_force_refresh_candidate_confidence,
_refresh_candidate_confidence_batch,
)
from python_pkg.steam_backlog_enforcer.config import Config, State
from python_pkg.steam_backlog_enforcer.scanning import (
_collect_top_candidates,
_pick_next_shortest_candidate,
do_check,
)
from python_pkg.steam_backlog_enforcer.steam_api import GameInfo
def _game(
app_id: int = 1,
name: str = "G",
total: int = 10,
unlocked: int = 0,
hours: float = -1,
) -> GameInfo:
return GameInfo(
app_id=app_id,
name=name,
total_achievements=total,
unlocked_achievements=unlocked,
playtime_minutes=60,
completionist_hours=hours,
comp_100_count=3,
count_comp=15,
)
class TestCollectTopCandidates:
"""Tests for _collect_top_candidates."""
def test_collects_up_to_n(self) -> None:
"""Returns at most n qualified candidates."""
games = [_game(app_id=i, name=f"G{i}", hours=float(i)) for i in range(1, 6)]
with patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
):
qualified, conf_skip, linux_skip = _collect_top_candidates(games, n=3)
assert len(qualified) == 3
assert [g.app_id for g in qualified] == [1, 2, 3]
assert conf_skip == 0
assert linux_skip == 0
def test_skips_linux_incompatible(self) -> None:
"""Games failing ProtonDB are counted in linux_skipped."""
g1 = _game(app_id=1, name="Borked", hours=1.0)
g2 = _game(app_id=2, name="Good", hours=2.0)
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: None if c[0].app_id == 1 else c[0],
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
):
qualified, conf_skip, linux_skip = _collect_top_candidates([g1, g2], n=10)
assert [g.app_id for g in qualified] == [2]
assert linux_skip == 1
assert conf_skip == 0
def test_empty_candidates(self) -> None:
qualified, conf_skip, linux_skip = _collect_top_candidates([])
assert qualified == []
assert conf_skip == 0
assert linux_skip == 0
def test_no_linux_skip_message_when_zero(self) -> None:
"""No skip message is printed when linux_skipped is 0."""
g = _game(app_id=1, name="Good", hours=1.0)
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=lambda c: c[0] if c else None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo") as mock_echo,
):
_collect_top_candidates([g], n=10)
mock_echo.assert_not_called()
class TestDoCheck:
"""Tests for do_check."""
def test_no_assignment(self) -> None:
with patch("python_pkg.steam_backlog_enforcer.scanning._echo") as mock_echo:
do_check(Config(), State())
mock_echo.assert_called()
def test_fetch_fails(self) -> None:
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = None
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.SteamAPIClient",
return_value=mock_client,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch("python_pkg.steam_backlog_enforcer.scanning.detect_tampering"),
):
state = State(current_app_id=440, current_game_name="TF2")
do_check(Config(steam_api_key="k", steam_id="i"), state)
class TestConfidenceHelpers:
"""Coverage-focused tests for scanning confidence helper branches."""
def test_force_refresh_candidate_confidence_delegates(self) -> None:
game = _game(app_id=10, name="A")
with patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence._refresh_candidate_confidence_batch",
) as mock_batch:
_force_refresh_candidate_confidence(game)
mock_batch.assert_called_once_with([game], force=True)
def test_refresh_candidate_confidence_batch_no_missing_skips_fetch(self) -> None:
game = _game(app_id=20, name="B", hours=12.0)
game.comp_100_count = 3
game.count_comp = 15
with patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence.fetch_hltb_confidence_cached",
) as mock_fetch:
_refresh_candidate_confidence_batch([game], force=False)
mock_fetch.assert_not_called()
def test_refresh_candidate_confidence_batch_preserves_existing_hours(self) -> None:
game = _game(app_id=30, name="C", hours=9.5)
game.comp_100_count = 0
game.count_comp = 0
with (
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence.load_hltb_cache",
side_effect=[{30: 9.5}, {30: -1.0}],
),
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence.load_hltb_polls_cache",
return_value={30: 0},
),
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence.load_hltb_count_comp_cache",
return_value={30: 0},
),
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence.fetch_hltb_confidence_cached",
return_value={30: -1.0},
),
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence.save_hltb_cache",
) as mock_save,
):
_refresh_candidate_confidence_batch([game], force=True)
assert game.completionist_hours == 9.5
saved_cache = mock_save.call_args.args[0]
assert saved_cache[30] == 9.5
def test_filter_hltb_confident_candidates_skips_low_confidence(self) -> None:
low = _game(app_id=40, name="Low", hours=2.0)
low.comp_100_count = 1
low.count_comp = 2
with (
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence._refresh_candidate_confidence_batch",
),
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence._echo"
) as mock_echo,
):
result = _filter_hltb_confident_candidates([low])
assert result == []
assert mock_echo.called
def test_pick_next_shortest_candidate_logs_skipped_unplayable_batches(self) -> None:
bad = _game(app_id=50, name="Bad", hours=1.0)
good = _game(app_id=51, name="Good", hours=2.0)
bad.comp_100_count = 3
bad.count_comp = 15
good.comp_100_count = 3
good.count_comp = 15
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
side_effect=[None, good],
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo") as mock_echo,
):
picked, skipped_low_conf, skipped_linux = _pick_next_shortest_candidate(
[bad, good],
)
assert picked is good
assert skipped_low_conf == 0
assert skipped_linux == 1
assert any(
"Skipped 1 game(s) with poor Linux compatibility" in str(call)
for call in mock_echo.call_args_list
)
def test_pick_next_shortest_candidate_no_echo_when_linux_skipped_zero(
self,
) -> None:
"""Covers 419->423: no echo printed when linux_skipped == 0."""
good = _game(app_id=51, name="Good", hours=2.0)
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
return_value=good,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo") as mock_echo,
):
picked, _skipped_low_conf, skipped_linux = _pick_next_shortest_candidate(
[good],
)
assert picked is good
assert skipped_linux == 0
mock_echo.assert_not_called()
def test_pick_next_shortest_candidate_skips_low_confidence(self) -> None:
"""Covers lines 413-414: confidence_skipped += 1; continue."""
low_conf = _game(app_id=10, name="Low", hours=1.0)
low_conf.comp_100_count = 0
low_conf.count_comp = 0
with (
patch(
"python_pkg.steam_backlog_enforcer._scanning_confidence._refresh_candidate_confidence"
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
):
picked, skipped_low_conf, skipped_linux = _pick_next_shortest_candidate(
[low_conf],
)
assert picked is None
assert skipped_low_conf == 1
assert skipped_linux == 0
def test_pick_next_shortest_candidate_all_protondb_fail(self) -> None:
"""Covers lines 426-428: linux_skipped > 0 after loop, return None."""
g1 = _game(app_id=10, name="Borked", hours=1.0)
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning._pick_playable_candidate",
return_value=None,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo") as mock_echo,
):
picked, _skipped_low_conf, skipped_linux = _pick_next_shortest_candidate(
[g1],
)
assert picked is None
assert skipped_linux == 1
assert any(
"Skipped 1 game(s) with poor Linux compatibility" in str(call)
for call in mock_echo.call_args_list
)
game = _game(app_id=440, name="TF2", total=5, unlocked=5)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
snap = [game.to_snapshot()]
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.SteamAPIClient",
return_value=mock_client,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.send_notification",
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.load_snapshot",
return_value=snap,
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.pick_next_game",
),
patch("python_pkg.steam_backlog_enforcer.scanning.detect_tampering"),
):
state = State(current_app_id=440, current_game_name="TF2")
do_check(Config(steam_api_key="k", steam_id="i"), state)
assert 440 in state.finished_app_ids
def test_complete_no_snapshot(self) -> None:
game = _game(app_id=440, name="TF2", total=5, unlocked=5)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.SteamAPIClient",
return_value=mock_client,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch(
"python_pkg.steam_backlog_enforcer.scanning.send_notification",
),
patch(
"python_pkg.steam_backlog_enforcer.scanning.load_snapshot",
return_value=None,
),
patch("python_pkg.steam_backlog_enforcer.scanning.detect_tampering"),
):
state = State(current_app_id=440, current_game_name="TF2")
do_check(Config(steam_api_key="k", steam_id="i"), state)
def test_not_complete(self) -> None:
game = _game(app_id=440, name="TF2", total=10, unlocked=5)
mock_client = MagicMock()
mock_client.refresh_single_game.return_value = game
with (
patch(
"python_pkg.steam_backlog_enforcer.scanning.SteamAPIClient",
return_value=mock_client,
),
patch("python_pkg.steam_backlog_enforcer.scanning._echo"),
patch("python_pkg.steam_backlog_enforcer.scanning.detect_tampering"),
):
state = State(current_app_id=440, current_game_name="TF2")
do_check(Config(steam_api_key="k", steam_id="i"), state)

View File

@ -48,6 +48,31 @@ def _is_alarm_day() -> bool:
return datetime.now(tz=timezone.utc).weekday() in ALARM_DAYS
def _wake_display() -> None:
"""Force the display on and disable screensaver during alarm."""
xset = shutil.which("xset")
if xset is None:
return
for cmd in (
[xset, "dpms", "force", "on"],
[xset, "s", "off"],
):
subprocess.run(cmd, check=False, capture_output=True, timeout=5)
def _restore_display() -> None:
"""Re-enable screensaver after the alarm ends."""
xset = shutil.which("xset")
if xset is None:
return
subprocess.run(
[xset, "s", "on"],
check=False,
capture_output=True,
timeout=5,
)
def _beep_soft() -> None:
"""Play a soft system beep via terminal bell."""
sys.stdout.write("\a")
@ -119,6 +144,7 @@ class WakeAlarm:
self._stop_beep = threading.Event()
self._beep_thread: threading.Thread | None = None
self._alarm_start: float = time.monotonic()
self._active = True
self.root = tk.Tk()
self.root.title("Wake Alarm" + (" [DEMO]" if demo_mode else ""))
@ -213,6 +239,7 @@ class WakeAlarm:
def _dismiss_alarm(self, *, earned_skip: bool) -> None:
"""Dismiss the alarm and save state."""
self._active = False
self.dismissed = True
self._stop_beep.set()
now_iso = datetime.now(tz=timezone.utc).isoformat()
@ -241,11 +268,12 @@ class WakeAlarm:
def _close(self) -> None:
"""Close the alarm window."""
self._stop_beep.set()
_restore_display()
self.root.destroy()
def _schedule_code_refresh(self) -> None:
"""Refresh the dismiss code periodically."""
if self.dismissed:
if not self._active:
return
self._current_code = _generate_code()
self._code_label.configure(text=self._current_code)
@ -260,8 +288,9 @@ class WakeAlarm:
def _on_dismiss_window_expired(self) -> None:
"""Called when the dismiss window expires without valid dismissal."""
if self.dismissed:
if not self._active:
return
self._active = False
self._stop_beep.set()
save_wake_state(dismissed_at=None, skip_workout=False)
_logger.info("Dismiss window expired — no workout skip.")
@ -281,6 +310,7 @@ class WakeAlarm:
def _close_and_schedule_fallback(self) -> None:
"""Close the window and schedule the 1 PM fallback alarm."""
_restore_display()
self.root.destroy()
def _update_timer(self) -> None:
@ -349,6 +379,7 @@ def main() -> None:
return
demo_mode = "--demo" in sys.argv
_wake_display()
alarm = WakeAlarm(demo_mode=demo_mode)
alarm.run()

View File

@ -24,20 +24,29 @@ RTCWAKE_BIN="/usr/sbin/rtcwake"
echo "=== Weekend Wake Alarm Installer ==="
# 0. Install system dependencies
echo "[0/5] Checking system dependencies..."
if ! command -v speaker-test &>/dev/null; then
echo " Installing alsa-utils (required for speaker-test)..."
sudo pacman -S --noconfirm alsa-utils
else
echo " alsa-utils already installed"
fi
# 1. Install systemd user service
echo "[1/4] Installing systemd user service..."
echo "[1/5] Installing systemd user service..."
mkdir -p "$SYSTEMD_USER_DIR"
cp "$SERVICE_FILE" "$SYSTEMD_USER_DIR/wake-alarm.service"
systemctl --user daemon-reload
echo " Installed to $SYSTEMD_USER_DIR/wake-alarm.service"
# 2. Enable service
echo "[2/4] Enabling wake-alarm.service..."
echo "[2/5] Enabling wake-alarm.service..."
systemctl --user enable wake-alarm.service
echo " Service enabled (will start on next boot)"
# 3. Install systemd-sleep hook (restarts alarm after hibernate resume)
echo "[3/4] Installing systemd-sleep hook..."
echo "[3/5] Installing systemd-sleep hook..."
sudo cp "$SLEEP_HOOK_SRC" "$SLEEP_HOOK_DST"
sudo chmod 0755 "$SLEEP_HOOK_DST"
echo " Installed to $SLEEP_HOOK_DST"
@ -61,7 +70,6 @@ sudo chmod 0755 "$SHUTDOWN_WRAPPER_DST"
echo " Installed to $SHUTDOWN_WRAPPER_DST"
echo " 'shutdown now' will now hibernate (not poweroff) on alarm nights."
echo ""
echo "=== Installation complete ==="
echo "The wake alarm will activate on boot for alarm days (Mon, Fri, Sat, Sun)."
echo "After hibernate resume the sleep hook will restart the alarm service."

View File

@ -12,20 +12,18 @@ if TYPE_CHECKING:
from collections.abc import Generator
from python_pkg.wake_alarm._alarm import (
WakeAlarm,
_beep_loud,
_beep_medium,
_beep_soft,
_generate_code,
_is_alarm_day,
_restore_display,
_should_run_alarm,
_speaker_test_path,
main,
_wake_display,
)
from python_pkg.wake_alarm._constants import (
DISMISS_CODE_LENGTH,
PHASE_MEDIUM_END,
PHASE_SOFT_END,
)
# ---------------------------------------------------------------------------
@ -348,372 +346,29 @@ class TestShouldRunAlarm:
assert _should_run_alarm() is True
class TestWakeAlarmInit:
"""Tests for WakeAlarm initialization."""
class TestDisplayHelpers:
"""Tests for _wake_display and _restore_display when xset is absent."""
def test_demo_mode_sets_smaller_window(
self,
mock_tk_module: MagicMock,
) -> None:
"""Demo mode creates a smaller window."""
alarm = WakeAlarm(demo_mode=True)
assert alarm.demo_mode is True
assert alarm.dismissed is False
alarm._stop_beep.set() # Stop beep thread
def test_production_mode_fullscreen(
self,
mock_tk_module: MagicMock,
) -> None:
"""Production mode activates fullscreen."""
alarm = WakeAlarm(demo_mode=False)
assert alarm.demo_mode is False
mock_root = mock_tk_module.Tk.return_value
mock_root.overrideredirect.assert_called_once()
alarm._stop_beep.set()
class TestWakeAlarmDismiss:
"""Tests for alarm dismiss logic."""
def test_correct_code_dismisses(
self,
mock_tk_module: MagicMock,
) -> None:
"""Entering the correct code dismisses the alarm."""
alarm = WakeAlarm(demo_mode=True)
code = alarm._current_code
mock_entry = mock_tk_module.Entry.return_value
mock_entry.get.return_value = code
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
) as mock_save:
alarm._on_submit()
assert alarm.dismissed is True
mock_save.assert_called_once()
call_kwargs = mock_save.call_args[1]
assert call_kwargs["skip_workout"] is True
alarm._stop_beep.set()
def test_wrong_code_does_not_dismiss(
self,
mock_tk_module: MagicMock,
) -> None:
"""Entering the wrong code shows error without dismissing."""
alarm = WakeAlarm(demo_mode=True)
mock_entry = mock_tk_module.Entry.return_value
mock_entry.get.return_value = "000000"
# Ensure current code is different
alarm._current_code = "123456"
alarm._on_submit()
assert alarm.dismissed is False
alarm._stop_beep.set()
def test_dismiss_window_expired(
self,
mock_tk_module: MagicMock,
) -> None:
"""Window expiry saves state with no skip."""
alarm = WakeAlarm(demo_mode=True)
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
) as mock_save:
alarm._on_dismiss_window_expired()
assert alarm.dismissed is False
mock_save.assert_called_once_with(
dismissed_at=None,
skip_workout=False,
)
alarm._stop_beep.set()
def test_dismiss_window_expired_noop_if_already_dismissed(
self,
mock_tk_module: MagicMock,
) -> None:
"""Expiry is a no-op if already dismissed."""
alarm = WakeAlarm(demo_mode=True)
alarm.dismissed = True
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
) as mock_save:
alarm._on_dismiss_window_expired()
mock_save.assert_not_called()
alarm._stop_beep.set()
class TestMain:
"""Tests for the main() entry point."""
def test_exits_when_not_alarm_day(self) -> None:
"""main() returns early when not an alarm day."""
with patch(
"python_pkg.wake_alarm._alarm._should_run_alarm",
return_value=False,
):
main() # Should just return without error
def test_creates_alarm_when_should_run(
self,
mock_tk_module: MagicMock,
) -> None:
"""main() creates a WakeAlarm when conditions are met."""
def test_wake_display_skips_when_xset_missing(self) -> None:
"""_wake_display does nothing when xset is not on PATH."""
with (
patch(
"python_pkg.wake_alarm._alarm._should_run_alarm",
return_value=True,
"python_pkg.wake_alarm._alarm.shutil.which",
return_value=None,
),
patch(
"python_pkg.wake_alarm._alarm.sys",
) as mock_sys,
patch.object(WakeAlarm, "run") as mock_run,
patch.object(WakeAlarm, "__init__", return_value=None),
patch("python_pkg.wake_alarm._alarm.subprocess.run") as mock_run,
):
mock_sys.argv = []
main()
mock_run.assert_called_once()
class TestCodeRefreshAndTimer:
"""Tests for code refresh and timer update methods."""
def test_code_refresh_changes_code(
self,
mock_tk_module: MagicMock,
) -> None:
"""Code refresh generates a new code."""
alarm = WakeAlarm(demo_mode=True)
# Call refresh many times — at least one should differ
codes = set()
for _ in range(50):
alarm._schedule_code_refresh()
codes.add(alarm._current_code)
assert len(codes) > 1
alarm._stop_beep.set()
def test_code_refresh_noop_when_dismissed(
self,
mock_tk_module: MagicMock,
) -> None:
"""Code refresh is a no-op after dismissal."""
alarm = WakeAlarm(demo_mode=True)
alarm.dismissed = True
old_code = alarm._current_code
alarm._schedule_code_refresh()
# Code doesn't change because dismissed=True causes early return
assert alarm._current_code == old_code
alarm._stop_beep.set()
def test_update_timer_noop_when_dismissed(
self,
mock_tk_module: MagicMock,
) -> None:
"""Timer update is a no-op after dismissal."""
alarm = WakeAlarm(demo_mode=True)
alarm.dismissed = True
alarm._update_timer() # Should not raise
alarm._stop_beep.set()
class TestBeepLoop:
"""Tests for the beep loop thread."""
def test_beep_loop_stops_on_event(
self,
mock_tk_module: MagicMock,
) -> None:
"""Beep loop exits when stop event is set."""
alarm = WakeAlarm(demo_mode=True)
alarm._stop_beep.set()
# Loop should exit immediately
with patch(
"python_pkg.wake_alarm._alarm._beep_soft",
):
alarm._beep_loop()
alarm._stop_beep.set()
class TestCloseAndFallback:
"""Tests for close and fallback scheduling."""
def test_close_stops_beep_and_destroys(
self,
mock_tk_module: MagicMock,
) -> None:
"""_close sets stop event and destroys root."""
alarm = WakeAlarm(demo_mode=True)
alarm._close()
assert alarm._stop_beep.is_set()
alarm.root.destroy.assert_called()
def test_close_and_schedule_fallback(
self,
mock_tk_module: MagicMock,
) -> None:
"""_close_and_schedule_fallback destroys root."""
alarm = WakeAlarm(demo_mode=True)
alarm._close_and_schedule_fallback()
alarm.root.destroy.assert_called()
alarm._stop_beep.set()
class TestDismissWithoutSkip:
"""Tests for alarm dismiss without earning skip."""
def test_dismiss_without_skip_shows_no_skip_message(
self,
mock_tk_module: MagicMock,
) -> None:
"""Dismissing with earned_skip=False shows appropriate message."""
alarm = WakeAlarm(demo_mode=True)
# Simulate existing child widgets
mock_widget = MagicMock()
alarm._container.winfo_children.return_value = [mock_widget]
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
) as mock_save:
alarm._dismiss_alarm(earned_skip=False)
assert alarm.dismissed is True
mock_save.assert_called_once()
call_kwargs = mock_save.call_args[1]
assert call_kwargs["skip_workout"] is False
mock_widget.destroy.assert_called_once()
alarm._stop_beep.set()
class TestDismissWindowExpiredWidgets:
"""Tests for widget cleanup during dismiss window expiry."""
def test_expired_creates_label(
self,
mock_tk_module: MagicMock,
) -> None:
"""Expiry creates a 'Too late' label and destroys children."""
alarm = WakeAlarm(demo_mode=True)
mock_widget = MagicMock()
alarm._container.winfo_children.return_value = [mock_widget]
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
):
alarm._on_dismiss_window_expired()
mock_widget.destroy.assert_called_once()
mock_tk_module.Label.assert_called()
alarm._stop_beep.set()
class TestBeepLoopPhases:
"""Tests for different beep loop escalation phases."""
def test_medium_phase(
self,
mock_tk_module: MagicMock,
) -> None:
"""Beep loop enters medium phase after PHASE_SOFT_END minutes."""
alarm = WakeAlarm(demo_mode=True)
# Set alarm start to make elapsed > PHASE_SOFT_END minutes
import time as time_mod
alarm._alarm_start = time_mod.monotonic() - (PHASE_SOFT_END + 1) * 60
call_count = 0
def stop_after_one(*_args: object, **_kwargs: object) -> None:
nonlocal call_count
call_count += 1
if call_count >= 1:
alarm._stop_beep.set()
_wake_display()
mock_run.assert_not_called()
def test_restore_display_skips_when_xset_missing(self) -> None:
"""_restore_display does nothing when xset is not on PATH."""
with (
patch(
"python_pkg.wake_alarm._alarm._beep_medium",
side_effect=stop_after_one,
) as mock_beep,
"python_pkg.wake_alarm._alarm.shutil.which",
return_value=None,
),
patch("python_pkg.wake_alarm._alarm.subprocess.run") as mock_run,
):
alarm._beep_loop()
mock_beep.assert_called()
alarm._stop_beep.set()
def test_loud_phase(
self,
mock_tk_module: MagicMock,
) -> None:
"""Beep loop enters loud phase after PHASE_MEDIUM_END minutes."""
alarm = WakeAlarm(demo_mode=True)
import time as time_mod
alarm._alarm_start = time_mod.monotonic() - (PHASE_MEDIUM_END + 1) * 60
call_count = 0
def stop_after_one(*_args: object, **_kwargs: object) -> None:
nonlocal call_count
call_count += 1
if call_count >= 1:
alarm._stop_beep.set()
with (
patch(
"python_pkg.wake_alarm._alarm._beep_loud",
side_effect=stop_after_one,
) as mock_beep,
):
alarm._beep_loop()
mock_beep.assert_called()
alarm._stop_beep.set()
class TestRunMethod:
"""Tests for the run() method."""
def test_run_calls_mainloop(
self,
mock_tk_module: MagicMock,
) -> None:
"""run() calls root.mainloop()."""
alarm = WakeAlarm(demo_mode=True)
alarm.run()
alarm.root.mainloop.assert_called_once()
alarm._stop_beep.set()
class TestUpdateTimerActive:
"""Tests for timer update when alarm is active."""
def test_update_timer_shows_remaining(
self,
mock_tk_module: MagicMock,
) -> None:
"""Timer update shows remaining time when not dismissed."""
alarm = WakeAlarm(demo_mode=True)
alarm._update_timer()
alarm._timer_label.configure.assert_called()
alarm._stop_beep.set()
def test_update_timer_stops_at_zero(
self,
mock_tk_module: MagicMock,
) -> None:
"""Timer stops scheduling when remaining time reaches zero."""
import time as time_mod
alarm = WakeAlarm(demo_mode=True)
# Set alarm start far in the past so remaining = 0
alarm._alarm_start = time_mod.monotonic() - 60 * 60
alarm._update_timer()
# root.after should NOT be called for re-scheduling
# (configure is still called to show 00:00)
alarm._timer_label.configure.assert_called()
alarm._stop_beep.set()
_restore_display()
mock_run.assert_not_called()

View File

@ -0,0 +1,432 @@
"""Tests for _alarm.py — WakeAlarm init, dismiss, run, and beep phases (part 2)."""
from __future__ import annotations
import tkinter as tk
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import pytest
if TYPE_CHECKING:
from collections.abc import Generator
from python_pkg.wake_alarm._alarm import (
WakeAlarm,
main,
)
from python_pkg.wake_alarm._constants import (
PHASE_MEDIUM_END,
PHASE_SOFT_END,
)
# ---------------------------------------------------------------------------
# Helpers (duplicated from part 1 so this file is self-contained)
# ---------------------------------------------------------------------------
def _make_mock_tk() -> MagicMock:
"""Build a MagicMock that stands in for the tkinter module."""
mock = MagicMock()
mock_root = MagicMock()
mock_root.winfo_screenwidth.return_value = 1920
mock_root.winfo_screenheight.return_value = 1080
mock.Tk.return_value = mock_root
mock.Frame.return_value = MagicMock()
mock.Label.return_value = MagicMock()
mock.Entry.return_value = MagicMock()
mock.TclError = tk.TclError
mock.END = tk.END
return mock
@pytest.fixture(autouse=True)
def _block_real_tk() -> Generator[MagicMock]:
"""Prevent any real Tk windows in tests."""
mock = _make_mock_tk()
with patch("python_pkg.wake_alarm._alarm.tk", mock):
yield mock
@pytest.fixture
def mock_tk_module() -> Generator[MagicMock]:
"""Provide explicit access to the mocked tk module."""
mock = _make_mock_tk()
with patch("python_pkg.wake_alarm._alarm.tk", mock):
yield mock
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestWakeAlarmInit:
"""Tests for WakeAlarm initialization."""
def test_demo_mode_sets_smaller_window(
self,
mock_tk_module: MagicMock,
) -> None:
"""Demo mode creates a smaller window."""
alarm = WakeAlarm(demo_mode=True)
assert alarm.demo_mode is True
assert alarm.dismissed is False
alarm._stop_beep.set() # Stop beep thread
def test_production_mode_fullscreen(
self,
mock_tk_module: MagicMock,
) -> None:
"""Production mode activates fullscreen."""
alarm = WakeAlarm(demo_mode=False)
assert alarm.demo_mode is False
mock_root = mock_tk_module.Tk.return_value
mock_root.overrideredirect.assert_called_once()
alarm._stop_beep.set()
class TestWakeAlarmDismiss:
"""Tests for alarm dismiss logic."""
def test_correct_code_dismisses(
self,
mock_tk_module: MagicMock,
) -> None:
"""Entering the correct code dismisses the alarm."""
alarm = WakeAlarm(demo_mode=True)
code = alarm._current_code
mock_entry = mock_tk_module.Entry.return_value
mock_entry.get.return_value = code
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
) as mock_save:
alarm._on_submit()
assert alarm.dismissed is True
mock_save.assert_called_once()
call_kwargs = mock_save.call_args[1]
assert call_kwargs["skip_workout"] is True
alarm._stop_beep.set()
def test_wrong_code_does_not_dismiss(
self,
mock_tk_module: MagicMock,
) -> None:
"""Entering the wrong code shows error without dismissing."""
alarm = WakeAlarm(demo_mode=True)
mock_entry = mock_tk_module.Entry.return_value
mock_entry.get.return_value = "000000"
# Ensure current code is different
alarm._current_code = "123456"
alarm._on_submit()
assert alarm.dismissed is False
alarm._stop_beep.set()
def test_dismiss_window_expired(
self,
mock_tk_module: MagicMock,
) -> None:
"""Window expiry saves state with no skip."""
alarm = WakeAlarm(demo_mode=True)
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
) as mock_save:
alarm._on_dismiss_window_expired()
assert alarm.dismissed is False
mock_save.assert_called_once_with(
dismissed_at=None,
skip_workout=False,
)
alarm._stop_beep.set()
def test_dismiss_window_expired_noop_if_not_active(
self,
mock_tk_module: MagicMock,
) -> None:
"""Expiry is a no-op if alarm is no longer active."""
alarm = WakeAlarm(demo_mode=True)
alarm._active = False
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
) as mock_save:
alarm._on_dismiss_window_expired()
mock_save.assert_not_called()
alarm._stop_beep.set()
class TestMain:
"""Tests for the main() entry point."""
def test_exits_when_not_alarm_day(self) -> None:
"""main() returns early when not an alarm day."""
with patch(
"python_pkg.wake_alarm._alarm._should_run_alarm",
return_value=False,
):
main() # Should just return without error
def test_creates_alarm_when_should_run(
self,
mock_tk_module: MagicMock,
) -> None:
"""main() creates a WakeAlarm when conditions are met."""
with (
patch(
"python_pkg.wake_alarm._alarm._should_run_alarm",
return_value=True,
),
patch(
"python_pkg.wake_alarm._alarm.sys",
) as mock_sys,
patch.object(WakeAlarm, "run") as mock_run,
patch.object(WakeAlarm, "__init__", return_value=None),
):
mock_sys.argv = []
main()
mock_run.assert_called_once()
class TestCodeRefreshAndTimer:
"""Tests for code refresh and timer update methods."""
def test_code_refresh_changes_code(
self,
mock_tk_module: MagicMock,
) -> None:
"""Code refresh generates a new code."""
alarm = WakeAlarm(demo_mode=True)
# Call refresh many times — at least one should differ
codes = set()
for _ in range(50):
alarm._schedule_code_refresh()
codes.add(alarm._current_code)
assert len(codes) > 1
alarm._stop_beep.set()
def test_code_refresh_noop_when_not_active(
self,
mock_tk_module: MagicMock,
) -> None:
"""Code refresh is a no-op when alarm is no longer active."""
alarm = WakeAlarm(demo_mode=True)
alarm._active = False
old_code = alarm._current_code
alarm._schedule_code_refresh()
# Code doesn't change because _active=False causes early return
assert alarm._current_code == old_code
alarm._stop_beep.set()
def test_update_timer_noop_when_dismissed(
self,
mock_tk_module: MagicMock,
) -> None:
"""Timer update is a no-op after dismissal."""
alarm = WakeAlarm(demo_mode=True)
alarm.dismissed = True
alarm._update_timer() # Should not raise
alarm._stop_beep.set()
class TestBeepLoop:
"""Tests for the beep loop thread."""
def test_beep_loop_stops_on_event(
self,
mock_tk_module: MagicMock,
) -> None:
"""Beep loop exits when stop event is set."""
alarm = WakeAlarm(demo_mode=True)
alarm._stop_beep.set()
# Loop should exit immediately
with patch(
"python_pkg.wake_alarm._alarm._beep_soft",
):
alarm._beep_loop()
alarm._stop_beep.set()
class TestCloseAndFallback:
"""Tests for close and fallback scheduling."""
def test_close_stops_beep_and_destroys(
self,
mock_tk_module: MagicMock,
) -> None:
"""_close sets stop event and destroys root."""
alarm = WakeAlarm(demo_mode=True)
alarm._close()
assert alarm._stop_beep.is_set()
alarm.root.destroy.assert_called()
def test_close_and_schedule_fallback(
self,
mock_tk_module: MagicMock,
) -> None:
"""_close_and_schedule_fallback destroys root."""
alarm = WakeAlarm(demo_mode=True)
alarm._close_and_schedule_fallback()
alarm.root.destroy.assert_called()
alarm._stop_beep.set()
class TestDismissWithoutSkip:
"""Tests for alarm dismiss without earning skip."""
def test_dismiss_without_skip_shows_no_skip_message(
self,
mock_tk_module: MagicMock,
) -> None:
"""Dismissing with earned_skip=False shows appropriate message."""
alarm = WakeAlarm(demo_mode=True)
# Simulate existing child widgets
mock_widget = MagicMock()
alarm._container.winfo_children.return_value = [mock_widget]
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
) as mock_save:
alarm._dismiss_alarm(earned_skip=False)
assert alarm.dismissed is True
mock_save.assert_called_once()
call_kwargs = mock_save.call_args[1]
assert call_kwargs["skip_workout"] is False
mock_widget.destroy.assert_called_once()
alarm._stop_beep.set()
class TestDismissWindowExpiredWidgets:
"""Tests for widget cleanup during dismiss window expiry."""
def test_expired_creates_label(
self,
mock_tk_module: MagicMock,
) -> None:
"""Expiry creates a 'Too late' label and destroys children."""
alarm = WakeAlarm(demo_mode=True)
mock_widget = MagicMock()
alarm._container.winfo_children.return_value = [mock_widget]
with patch(
"python_pkg.wake_alarm._alarm.save_wake_state",
):
alarm._on_dismiss_window_expired()
mock_widget.destroy.assert_called_once()
mock_tk_module.Label.assert_called()
alarm._stop_beep.set()
class TestBeepLoopPhases:
"""Tests for different beep loop escalation phases."""
def test_medium_phase(
self,
mock_tk_module: MagicMock,
) -> None:
"""Beep loop enters medium phase after PHASE_SOFT_END minutes."""
alarm = WakeAlarm(demo_mode=True)
# Set alarm start to make elapsed > PHASE_SOFT_END minutes
import time as time_mod
alarm._alarm_start = time_mod.monotonic() - (PHASE_SOFT_END + 1) * 60
call_count = 0
def stop_after_one(*_args: object, **_kwargs: object) -> None:
nonlocal call_count
call_count += 1
if call_count >= 1:
alarm._stop_beep.set()
with (
patch(
"python_pkg.wake_alarm._alarm._beep_medium",
side_effect=stop_after_one,
) as mock_beep,
):
alarm._beep_loop()
mock_beep.assert_called()
alarm._stop_beep.set()
def test_loud_phase(
self,
mock_tk_module: MagicMock,
) -> None:
"""Beep loop enters loud phase after PHASE_MEDIUM_END minutes."""
alarm = WakeAlarm(demo_mode=True)
import time as time_mod
alarm._alarm_start = time_mod.monotonic() - (PHASE_MEDIUM_END + 1) * 60
call_count = 0
def stop_after_one(*_args: object, **_kwargs: object) -> None:
nonlocal call_count
call_count += 1
if call_count >= 1:
alarm._stop_beep.set()
with (
patch(
"python_pkg.wake_alarm._alarm._beep_loud",
side_effect=stop_after_one,
) as mock_beep,
):
alarm._beep_loop()
mock_beep.assert_called()
alarm._stop_beep.set()
class TestRunMethod:
"""Tests for the run() method."""
def test_run_calls_mainloop(
self,
mock_tk_module: MagicMock,
) -> None:
"""run() calls root.mainloop()."""
alarm = WakeAlarm(demo_mode=True)
alarm.run()
alarm.root.mainloop.assert_called_once()
alarm._stop_beep.set()
class TestUpdateTimerActive:
"""Tests for timer update when alarm is active."""
def test_update_timer_shows_remaining(
self,
mock_tk_module: MagicMock,
) -> None:
"""Timer update shows remaining time when not dismissed."""
alarm = WakeAlarm(demo_mode=True)
alarm._update_timer()
alarm._timer_label.configure.assert_called()
alarm._stop_beep.set()
def test_update_timer_stops_at_zero(
self,
mock_tk_module: MagicMock,
) -> None:
"""Timer stops scheduling when remaining time reaches zero."""
import time as time_mod
alarm = WakeAlarm(demo_mode=True)
# Set alarm start far in the past so remaining = 0
alarm._alarm_start = time_mod.monotonic() - 60 * 60
alarm._update_timer()
# root.after should NOT be called for re-scheduling
# (configure is still called to show 00:00)
alarm._timer_label.configure.assert_called()
alarm._stop_beep.set()