fix(steam_backlog_enforcer): prevent enforce daemon from deleting assigned game

- Guard enforce_allowed_game() and _guard_installed_games() against
  current_app_id=None so they never treat all games as unauthorized
- Add early return in _enforce_loop_iteration when no game is assigned
- Wrap State.load() in enforce loop with error handling for corrupt files
- Switch all config/cache file writes to atomic (tmpfile + rename)
- Add robust error handling to State.load() for corrupt JSON
- Update tests for new behavior and add coverage for atomic writes
This commit is contained in:
Krzysztof kuhy Rudnicki 2026-03-25 19:19:52 +01:00
parent 2ed98ce4db
commit 61096eded3
9 changed files with 155 additions and 38 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import json
import logging
import time
@ -64,6 +65,8 @@ def _guard_installed_games(allowed_app_id: int | None) -> int:
Returns number of games removed this pass.
"""
if allowed_app_id is None:
return 0
installed = get_installed_games()
count = 0
for app_id, name in installed:
@ -165,6 +168,9 @@ def _enforce_loop_iteration(config: Config, state: State) -> None:
config: Enforcer configuration.
state: Current enforcer state.
"""
if state.current_app_id is None:
return
# A) Kill unauthorized game processes.
if config.kill_unauthorized_games:
violations = enforce_allowed_game(
@ -223,7 +229,12 @@ def do_enforce(config: Config, state: State) -> None:
# Reload state from disk so CLI changes (e.g. new game
# assignment via ``done`` / ``scan``) take effect immediately
# without needing to restart the daemon.
fresh = State.load()
try:
fresh = State.load()
except (json.JSONDecodeError, OSError, ValueError) as exc:
logger.warning("Failed to reload state: %s", exc)
time.sleep(ENFORCE_INTERVAL)
continue
state.current_app_id = fresh.current_app_id
state.current_game_name = fresh.current_game_name
state.finished_app_ids = fresh.finished_app_ids

View File

@ -2,10 +2,14 @@
from __future__ import annotations
import contextlib
from dataclasses import dataclass, field
import json
import logging
import os
from pathlib import Path
import sys
import tempfile
from typing import Any
CONFIG_DIR = Path.home() / ".config" / "steam_backlog_enforcer"
@ -25,6 +29,25 @@ BLOCKED_DOMAINS = [
HOSTS_FILE = Path("/etc/hosts")
logger = logging.getLogger(__name__)
def _atomic_write(path: Path, data: str) -> None:
"""Write data to a file atomically via a temporary file + rename."""
path.parent.mkdir(parents=True, exist_ok=True)
fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
tmp_path = Path(tmp)
try:
os.write(fd, data.encode("utf-8"))
os.close(fd)
tmp_path.replace(path)
except BaseException:
with contextlib.suppress(OSError):
os.close(fd)
with contextlib.suppress(OSError):
tmp_path.unlink()
raise
@dataclass
class Config:
@ -40,9 +63,9 @@ class Config:
def save(self) -> None:
"""Persist config to disk."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
CONFIG_FILE.write_text(
json.dumps(self.__dict__, indent=2) + "\n", encoding="utf-8"
_atomic_write(
CONFIG_FILE,
json.dumps(self.__dict__, indent=2) + "\n",
)
@classmethod
@ -66,16 +89,20 @@ class State:
def save(self) -> None:
"""Persist state to disk."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
STATE_FILE.write_text(
json.dumps(self.__dict__, indent=2) + "\n", encoding="utf-8"
_atomic_write(
STATE_FILE,
json.dumps(self.__dict__, indent=2) + "\n",
)
@classmethod
def load(cls) -> State:
"""Load state from disk, or return defaults."""
if STATE_FILE.exists():
data = json.loads(STATE_FILE.read_text(encoding="utf-8"))
try:
data = json.loads(STATE_FILE.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError, ValueError):
logger.warning("Corrupt state file, using defaults.")
return cls()
return cls(
**{k: v for k, v in data.items() if k in cls.__dataclass_fields__}
)
@ -84,8 +111,10 @@ class State:
def save_snapshot(data: list[dict[str, Any]]) -> None:
"""Save an achievement snapshot to disk."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
SNAPSHOT_FILE.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")
_atomic_write(
SNAPSHOT_FILE,
json.dumps(data, indent=2) + "\n",
)
def load_snapshot() -> list[dict[str, Any]] | None:

View File

@ -47,6 +47,8 @@ def enforce_allowed_game(
Returns list of (pid, app_id) that were killed or detected.
"""
if allowed_app_id is None:
return []
running = get_running_steam_game_pids()
violations: list[tuple[int, int]] = []

View File

@ -23,7 +23,7 @@ from typing import Any
import aiohttp
from howlongtobeatpy.HTMLRequests import HTMLRequests
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR, _atomic_write
logger = logging.getLogger(__name__)
@ -71,11 +71,10 @@ def load_hltb_cache() -> dict[int, float]:
def save_hltb_cache(cache: dict[int, float]) -> None:
"""Save the HLTB cache to disk."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
try:
HLTB_CACHE_FILE.write_text(
_atomic_write(
HLTB_CACHE_FILE,
json.dumps({str(k): v for k, v in cache.items()}, indent=2) + "\n",
encoding="utf-8",
)
except OSError:
logger.exception("Failed to save HLTB cache")

View File

@ -17,7 +17,7 @@ from typing import Any
import aiohttp
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR, _atomic_write
logger = logging.getLogger(__name__)
@ -92,8 +92,10 @@ def _load_cache() -> dict[str, Any]:
def _save_cache(cache: dict[str, Any]) -> None:
"""Persist the ProtonDB cache."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
PROTONDB_CACHE_FILE.write_text(json.dumps(cache, indent=2) + "\n", encoding="utf-8")
_atomic_write(
PROTONDB_CACHE_FILE,
json.dumps(cache, indent=2) + "\n",
)
async def _fetch_one(

View File

@ -11,6 +11,7 @@ import pytest
from python_pkg.steam_backlog_enforcer.config import (
Config,
State,
_atomic_write,
interactive_setup,
load_snapshot,
save_snapshot,
@ -20,6 +21,49 @@ if TYPE_CHECKING:
from pathlib import Path
class TestAtomicWrite:
"""Tests for _atomic_write."""
def test_writes_file(self, tmp_path: Path) -> None:
target = tmp_path / "out.json"
_atomic_write(target, '{"key": "value"}\n')
assert target.read_text(encoding="utf-8") == '{"key": "value"}\n'
def test_creates_parent_dirs(self, tmp_path: Path) -> None:
target = tmp_path / "sub" / "deep" / "out.json"
_atomic_write(target, "data")
assert target.read_text(encoding="utf-8") == "data"
def test_cleanup_on_write_error(self, tmp_path: Path) -> None:
target = tmp_path / "out.json"
with (
patch(
"python_pkg.steam_backlog_enforcer.config.os.write",
side_effect=OSError("disk full"),
),
pytest.raises(OSError, match="disk full"),
):
_atomic_write(target, "data")
assert not target.exists()
tmp_files = list(tmp_path.glob("*.tmp"))
assert tmp_files == []
def test_cleanup_on_replace_error(self, tmp_path: Path) -> None:
target = tmp_path / "out.json"
with (
patch.object(
type(target),
"replace",
side_effect=OSError("no perm"),
),
pytest.raises(OSError, match="no perm"),
):
_atomic_write(target, "data")
assert not target.exists()
tmp_files = list(tmp_path.glob("*.tmp"))
assert tmp_files == []
class TestConfig:
"""Tests for Config dataclass."""
@ -120,6 +164,14 @@ class TestState:
st = State.load()
assert st.current_app_id is None
def test_load_corrupt(self, tmp_path: Path) -> None:
state_file = tmp_path / "state.json"
state_file.write_text("not valid json{{", encoding="utf-8")
with patch("python_pkg.steam_backlog_enforcer.config.STATE_FILE", state_file):
st = State.load()
assert st.current_app_id is None
assert st.current_game_name == ""
class TestSnapshot:
"""Tests for snapshot save/load."""

View File

@ -102,6 +102,9 @@ class TestGuardInstalledGames:
):
assert _guard_installed_games(440) == 0
def test_allowed_none_skips(self) -> None:
assert _guard_installed_games(None) == 0
class TestEnforceSetup:
"""Tests for _enforce_setup."""
@ -297,8 +300,14 @@ class TestEnforceLoopIteration:
uninstall_other_games=False,
)
state = State(current_app_id=None)
with patch(f"{PKG}.is_game_installed") as mock_installed:
with (
patch(f"{PKG}.enforce_allowed_game") as mock_enforce,
patch(f"{PKG}._guard_installed_games") as mock_guard,
patch(f"{PKG}.is_game_installed") as mock_installed,
):
_enforce_loop_iteration(config, state)
mock_enforce.assert_not_called()
mock_guard.assert_not_called()
mock_installed.assert_not_called()
@ -350,3 +359,31 @@ class TestDoEnforce:
):
do_enforce(config, state)
assert call_count == 2
def test_state_load_failure_continues(self) -> None:
"""Corrupt state file should not crash the daemon."""
import json as json_mod
state = State(current_app_id=1, current_game_name="G")
config = Config()
call_count = 0
def load_side_effect() -> State:
nonlocal call_count
call_count += 1
if call_count == 1:
msg = "bad"
raise json_mod.JSONDecodeError(msg, "", 0)
if call_count == 2:
raise KeyboardInterrupt
return State(current_app_id=1) # pragma: no cover
with (
patch(f"{PKG}._enforce_setup"),
patch(f"{PKG}._echo"),
patch.object(State, "load", side_effect=load_side_effect),
patch(f"{PKG}._enforce_loop_iteration") as mock_iter,
patch(f"{PKG}.time.sleep"),
):
do_enforce(config, state)
mock_iter.assert_not_called()

View File

@ -130,18 +130,8 @@ class TestEnforceAllowedGame:
assert result == [(100, 570)]
def test_allowed_none(self) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.enforcer.get_running_steam_game_pids",
return_value={100: 570},
),
patch(
"python_pkg.steam_backlog_enforcer.enforcer.kill_process"
) as mock_kill,
):
result = enforce_allowed_game(None, kill_unauthorized=True)
assert result == [(100, 570)]
mock_kill.assert_called_once_with(100, 570)
result = enforce_allowed_game(None, kill_unauthorized=True)
assert result == []
class TestKillProcess:

View File

@ -66,14 +66,9 @@ class TestHltbCache:
assert cache_file.exists()
def test_save_cache_os_error(self, tmp_path: Path) -> None:
cache_file = MagicMock()
cache_file.write_text = MagicMock(side_effect=OSError)
with (
patch("python_pkg.steam_backlog_enforcer.hltb.HLTB_CACHE_FILE", cache_file),
patch(
"python_pkg.steam_backlog_enforcer.hltb.CONFIG_DIR",
MagicMock(mkdir=MagicMock()),
),
with patch(
"python_pkg.steam_backlog_enforcer.hltb._atomic_write",
side_effect=OSError("disk full"),
):
save_hltb_cache({440: 10.5}) # Should not raise