fix(steam_backlog_enforcer): prevent enforce daemon from deleting assigned game

- Guard enforce_allowed_game() and _guard_installed_games() against
  current_app_id=None so they never treat all games as unauthorized
- Add early return in _enforce_loop_iteration when no game is assigned
- Wrap State.load() in enforce loop with error handling for corrupt files
- Switch all config/cache file writes to atomic (tmpfile + rename)
- Add robust error handling to State.load() for corrupt JSON
- Update tests for new behavior and add coverage for atomic writes
This commit is contained in:
Krzysztof kuhy Rudnicki 2026-03-25 19:19:52 +01:00
parent 2ed98ce4db
commit 61096eded3
9 changed files with 155 additions and 38 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import time import time
@ -64,6 +65,8 @@ def _guard_installed_games(allowed_app_id: int | None) -> int:
Returns number of games removed this pass. Returns number of games removed this pass.
""" """
if allowed_app_id is None:
return 0
installed = get_installed_games() installed = get_installed_games()
count = 0 count = 0
for app_id, name in installed: for app_id, name in installed:
@ -165,6 +168,9 @@ def _enforce_loop_iteration(config: Config, state: State) -> None:
config: Enforcer configuration. config: Enforcer configuration.
state: Current enforcer state. state: Current enforcer state.
""" """
if state.current_app_id is None:
return
# A) Kill unauthorized game processes. # A) Kill unauthorized game processes.
if config.kill_unauthorized_games: if config.kill_unauthorized_games:
violations = enforce_allowed_game( violations = enforce_allowed_game(
@ -223,7 +229,12 @@ def do_enforce(config: Config, state: State) -> None:
# Reload state from disk so CLI changes (e.g. new game # Reload state from disk so CLI changes (e.g. new game
# assignment via ``done`` / ``scan``) take effect immediately # assignment via ``done`` / ``scan``) take effect immediately
# without needing to restart the daemon. # without needing to restart the daemon.
try:
fresh = State.load() fresh = State.load()
except (json.JSONDecodeError, OSError, ValueError) as exc:
logger.warning("Failed to reload state: %s", exc)
time.sleep(ENFORCE_INTERVAL)
continue
state.current_app_id = fresh.current_app_id state.current_app_id = fresh.current_app_id
state.current_game_name = fresh.current_game_name state.current_game_name = fresh.current_game_name
state.finished_app_ids = fresh.finished_app_ids state.finished_app_ids = fresh.finished_app_ids

View File

@ -2,10 +2,14 @@
from __future__ import annotations from __future__ import annotations
import contextlib
from dataclasses import dataclass, field from dataclasses import dataclass, field
import json import json
import logging
import os
from pathlib import Path from pathlib import Path
import sys import sys
import tempfile
from typing import Any from typing import Any
CONFIG_DIR = Path.home() / ".config" / "steam_backlog_enforcer" CONFIG_DIR = Path.home() / ".config" / "steam_backlog_enforcer"
@ -25,6 +29,25 @@ BLOCKED_DOMAINS = [
HOSTS_FILE = Path("/etc/hosts") HOSTS_FILE = Path("/etc/hosts")
logger = logging.getLogger(__name__)
def _atomic_write(path: Path, data: str) -> None:
"""Write data to a file atomically via a temporary file + rename."""
path.parent.mkdir(parents=True, exist_ok=True)
fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
tmp_path = Path(tmp)
try:
os.write(fd, data.encode("utf-8"))
os.close(fd)
tmp_path.replace(path)
except BaseException:
with contextlib.suppress(OSError):
os.close(fd)
with contextlib.suppress(OSError):
tmp_path.unlink()
raise
@dataclass @dataclass
class Config: class Config:
@ -40,9 +63,9 @@ class Config:
def save(self) -> None: def save(self) -> None:
"""Persist config to disk.""" """Persist config to disk."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True) _atomic_write(
CONFIG_FILE.write_text( CONFIG_FILE,
json.dumps(self.__dict__, indent=2) + "\n", encoding="utf-8" json.dumps(self.__dict__, indent=2) + "\n",
) )
@classmethod @classmethod
@ -66,16 +89,20 @@ class State:
def save(self) -> None: def save(self) -> None:
"""Persist state to disk.""" """Persist state to disk."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True) _atomic_write(
STATE_FILE.write_text( STATE_FILE,
json.dumps(self.__dict__, indent=2) + "\n", encoding="utf-8" json.dumps(self.__dict__, indent=2) + "\n",
) )
@classmethod @classmethod
def load(cls) -> State: def load(cls) -> State:
"""Load state from disk, or return defaults.""" """Load state from disk, or return defaults."""
if STATE_FILE.exists(): if STATE_FILE.exists():
try:
data = json.loads(STATE_FILE.read_text(encoding="utf-8")) data = json.loads(STATE_FILE.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError, ValueError):
logger.warning("Corrupt state file, using defaults.")
return cls()
return cls( return cls(
**{k: v for k, v in data.items() if k in cls.__dataclass_fields__} **{k: v for k, v in data.items() if k in cls.__dataclass_fields__}
) )
@ -84,8 +111,10 @@ class State:
def save_snapshot(data: list[dict[str, Any]]) -> None: def save_snapshot(data: list[dict[str, Any]]) -> None:
"""Save an achievement snapshot to disk.""" """Save an achievement snapshot to disk."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True) _atomic_write(
SNAPSHOT_FILE.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8") SNAPSHOT_FILE,
json.dumps(data, indent=2) + "\n",
)
def load_snapshot() -> list[dict[str, Any]] | None: def load_snapshot() -> list[dict[str, Any]] | None:

View File

@ -47,6 +47,8 @@ def enforce_allowed_game(
Returns list of (pid, app_id) that were killed or detected. Returns list of (pid, app_id) that were killed or detected.
""" """
if allowed_app_id is None:
return []
running = get_running_steam_game_pids() running = get_running_steam_game_pids()
violations: list[tuple[int, int]] = [] violations: list[tuple[int, int]] = []

View File

@ -23,7 +23,7 @@ from typing import Any
import aiohttp import aiohttp
from howlongtobeatpy.HTMLRequests import HTMLRequests from howlongtobeatpy.HTMLRequests import HTMLRequests
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR, _atomic_write
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -71,11 +71,10 @@ def load_hltb_cache() -> dict[int, float]:
def save_hltb_cache(cache: dict[int, float]) -> None: def save_hltb_cache(cache: dict[int, float]) -> None:
"""Save the HLTB cache to disk.""" """Save the HLTB cache to disk."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
try: try:
HLTB_CACHE_FILE.write_text( _atomic_write(
HLTB_CACHE_FILE,
json.dumps({str(k): v for k, v in cache.items()}, indent=2) + "\n", json.dumps({str(k): v for k, v in cache.items()}, indent=2) + "\n",
encoding="utf-8",
) )
except OSError: except OSError:
logger.exception("Failed to save HLTB cache") logger.exception("Failed to save HLTB cache")

View File

@ -17,7 +17,7 @@ from typing import Any
import aiohttp import aiohttp
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR, _atomic_write
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -92,8 +92,10 @@ def _load_cache() -> dict[str, Any]:
def _save_cache(cache: dict[str, Any]) -> None: def _save_cache(cache: dict[str, Any]) -> None:
"""Persist the ProtonDB cache.""" """Persist the ProtonDB cache."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True) _atomic_write(
PROTONDB_CACHE_FILE.write_text(json.dumps(cache, indent=2) + "\n", encoding="utf-8") PROTONDB_CACHE_FILE,
json.dumps(cache, indent=2) + "\n",
)
async def _fetch_one( async def _fetch_one(

View File

@ -11,6 +11,7 @@ import pytest
from python_pkg.steam_backlog_enforcer.config import ( from python_pkg.steam_backlog_enforcer.config import (
Config, Config,
State, State,
_atomic_write,
interactive_setup, interactive_setup,
load_snapshot, load_snapshot,
save_snapshot, save_snapshot,
@ -20,6 +21,49 @@ if TYPE_CHECKING:
from pathlib import Path from pathlib import Path
class TestAtomicWrite:
"""Tests for _atomic_write."""
def test_writes_file(self, tmp_path: Path) -> None:
target = tmp_path / "out.json"
_atomic_write(target, '{"key": "value"}\n')
assert target.read_text(encoding="utf-8") == '{"key": "value"}\n'
def test_creates_parent_dirs(self, tmp_path: Path) -> None:
target = tmp_path / "sub" / "deep" / "out.json"
_atomic_write(target, "data")
assert target.read_text(encoding="utf-8") == "data"
def test_cleanup_on_write_error(self, tmp_path: Path) -> None:
target = tmp_path / "out.json"
with (
patch(
"python_pkg.steam_backlog_enforcer.config.os.write",
side_effect=OSError("disk full"),
),
pytest.raises(OSError, match="disk full"),
):
_atomic_write(target, "data")
assert not target.exists()
tmp_files = list(tmp_path.glob("*.tmp"))
assert tmp_files == []
def test_cleanup_on_replace_error(self, tmp_path: Path) -> None:
target = tmp_path / "out.json"
with (
patch.object(
type(target),
"replace",
side_effect=OSError("no perm"),
),
pytest.raises(OSError, match="no perm"),
):
_atomic_write(target, "data")
assert not target.exists()
tmp_files = list(tmp_path.glob("*.tmp"))
assert tmp_files == []
class TestConfig: class TestConfig:
"""Tests for Config dataclass.""" """Tests for Config dataclass."""
@ -120,6 +164,14 @@ class TestState:
st = State.load() st = State.load()
assert st.current_app_id is None assert st.current_app_id is None
def test_load_corrupt(self, tmp_path: Path) -> None:
state_file = tmp_path / "state.json"
state_file.write_text("not valid json{{", encoding="utf-8")
with patch("python_pkg.steam_backlog_enforcer.config.STATE_FILE", state_file):
st = State.load()
assert st.current_app_id is None
assert st.current_game_name == ""
class TestSnapshot: class TestSnapshot:
"""Tests for snapshot save/load.""" """Tests for snapshot save/load."""

View File

@ -102,6 +102,9 @@ class TestGuardInstalledGames:
): ):
assert _guard_installed_games(440) == 0 assert _guard_installed_games(440) == 0
def test_allowed_none_skips(self) -> None:
assert _guard_installed_games(None) == 0
class TestEnforceSetup: class TestEnforceSetup:
"""Tests for _enforce_setup.""" """Tests for _enforce_setup."""
@ -297,8 +300,14 @@ class TestEnforceLoopIteration:
uninstall_other_games=False, uninstall_other_games=False,
) )
state = State(current_app_id=None) state = State(current_app_id=None)
with patch(f"{PKG}.is_game_installed") as mock_installed: with (
patch(f"{PKG}.enforce_allowed_game") as mock_enforce,
patch(f"{PKG}._guard_installed_games") as mock_guard,
patch(f"{PKG}.is_game_installed") as mock_installed,
):
_enforce_loop_iteration(config, state) _enforce_loop_iteration(config, state)
mock_enforce.assert_not_called()
mock_guard.assert_not_called()
mock_installed.assert_not_called() mock_installed.assert_not_called()
@ -350,3 +359,31 @@ class TestDoEnforce:
): ):
do_enforce(config, state) do_enforce(config, state)
assert call_count == 2 assert call_count == 2
def test_state_load_failure_continues(self) -> None:
"""Corrupt state file should not crash the daemon."""
import json as json_mod
state = State(current_app_id=1, current_game_name="G")
config = Config()
call_count = 0
def load_side_effect() -> State:
nonlocal call_count
call_count += 1
if call_count == 1:
msg = "bad"
raise json_mod.JSONDecodeError(msg, "", 0)
if call_count == 2:
raise KeyboardInterrupt
return State(current_app_id=1) # pragma: no cover
with (
patch(f"{PKG}._enforce_setup"),
patch(f"{PKG}._echo"),
patch.object(State, "load", side_effect=load_side_effect),
patch(f"{PKG}._enforce_loop_iteration") as mock_iter,
patch(f"{PKG}.time.sleep"),
):
do_enforce(config, state)
mock_iter.assert_not_called()

View File

@ -130,18 +130,8 @@ class TestEnforceAllowedGame:
assert result == [(100, 570)] assert result == [(100, 570)]
def test_allowed_none(self) -> None: def test_allowed_none(self) -> None:
with (
patch(
"python_pkg.steam_backlog_enforcer.enforcer.get_running_steam_game_pids",
return_value={100: 570},
),
patch(
"python_pkg.steam_backlog_enforcer.enforcer.kill_process"
) as mock_kill,
):
result = enforce_allowed_game(None, kill_unauthorized=True) result = enforce_allowed_game(None, kill_unauthorized=True)
assert result == [(100, 570)] assert result == []
mock_kill.assert_called_once_with(100, 570)
class TestKillProcess: class TestKillProcess:

View File

@ -66,14 +66,9 @@ class TestHltbCache:
assert cache_file.exists() assert cache_file.exists()
def test_save_cache_os_error(self, tmp_path: Path) -> None: def test_save_cache_os_error(self, tmp_path: Path) -> None:
cache_file = MagicMock() with patch(
cache_file.write_text = MagicMock(side_effect=OSError) "python_pkg.steam_backlog_enforcer.hltb._atomic_write",
with ( side_effect=OSError("disk full"),
patch("python_pkg.steam_backlog_enforcer.hltb.HLTB_CACHE_FILE", cache_file),
patch(
"python_pkg.steam_backlog_enforcer.hltb.CONFIG_DIR",
MagicMock(mkdir=MagicMock()),
),
): ):
save_hltb_cache({440: 10.5}) # Should not raise save_hltb_cache({440: 10.5}) # Should not raise