mirror of
https://github.com/kuhyx/steam-backlog-enforcer.git
synced 2026-07-04 15:43:09 +02:00
fix(steam_backlog_enforcer): prevent enforce daemon from deleting assigned game
- Guard enforce_allowed_game() and _guard_installed_games() against current_app_id=None so they never treat all games as unauthorized - Add early return in _enforce_loop_iteration when no game is assigned - Wrap State.load() in enforce loop with error handling for corrupt files - Switch all config/cache file writes to atomic (tmpfile + rename) - Add robust error handling to State.load() for corrupt JSON - Update tests for new behavior and add coverage for atomic writes
This commit is contained in:
parent
2ed98ce4db
commit
61096eded3
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
@ -64,6 +65,8 @@ def _guard_installed_games(allowed_app_id: int | None) -> int:
|
||||
|
||||
Returns number of games removed this pass.
|
||||
"""
|
||||
if allowed_app_id is None:
|
||||
return 0
|
||||
installed = get_installed_games()
|
||||
count = 0
|
||||
for app_id, name in installed:
|
||||
@ -165,6 +168,9 @@ def _enforce_loop_iteration(config: Config, state: State) -> None:
|
||||
config: Enforcer configuration.
|
||||
state: Current enforcer state.
|
||||
"""
|
||||
if state.current_app_id is None:
|
||||
return
|
||||
|
||||
# A) Kill unauthorized game processes.
|
||||
if config.kill_unauthorized_games:
|
||||
violations = enforce_allowed_game(
|
||||
@ -223,7 +229,12 @@ def do_enforce(config: Config, state: State) -> None:
|
||||
# Reload state from disk so CLI changes (e.g. new game
|
||||
# assignment via ``done`` / ``scan``) take effect immediately
|
||||
# without needing to restart the daemon.
|
||||
fresh = State.load()
|
||||
try:
|
||||
fresh = State.load()
|
||||
except (json.JSONDecodeError, OSError, ValueError) as exc:
|
||||
logger.warning("Failed to reload state: %s", exc)
|
||||
time.sleep(ENFORCE_INTERVAL)
|
||||
continue
|
||||
state.current_app_id = fresh.current_app_id
|
||||
state.current_game_name = fresh.current_game_name
|
||||
state.finished_app_ids = fresh.finished_app_ids
|
||||
|
||||
@ -2,10 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
CONFIG_DIR = Path.home() / ".config" / "steam_backlog_enforcer"
|
||||
@ -25,6 +29,25 @@ BLOCKED_DOMAINS = [
|
||||
|
||||
HOSTS_FILE = Path("/etc/hosts")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _atomic_write(path: Path, data: str) -> None:
|
||||
"""Write data to a file atomically via a temporary file + rename."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
|
||||
tmp_path = Path(tmp)
|
||||
try:
|
||||
os.write(fd, data.encode("utf-8"))
|
||||
os.close(fd)
|
||||
tmp_path.replace(path)
|
||||
except BaseException:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(fd)
|
||||
with contextlib.suppress(OSError):
|
||||
tmp_path.unlink()
|
||||
raise
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
@ -40,9 +63,9 @@ class Config:
|
||||
|
||||
def save(self) -> None:
|
||||
"""Persist config to disk."""
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CONFIG_FILE.write_text(
|
||||
json.dumps(self.__dict__, indent=2) + "\n", encoding="utf-8"
|
||||
_atomic_write(
|
||||
CONFIG_FILE,
|
||||
json.dumps(self.__dict__, indent=2) + "\n",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -66,16 +89,20 @@ class State:
|
||||
|
||||
def save(self) -> None:
|
||||
"""Persist state to disk."""
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
STATE_FILE.write_text(
|
||||
json.dumps(self.__dict__, indent=2) + "\n", encoding="utf-8"
|
||||
_atomic_write(
|
||||
STATE_FILE,
|
||||
json.dumps(self.__dict__, indent=2) + "\n",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls) -> State:
|
||||
"""Load state from disk, or return defaults."""
|
||||
if STATE_FILE.exists():
|
||||
data = json.loads(STATE_FILE.read_text(encoding="utf-8"))
|
||||
try:
|
||||
data = json.loads(STATE_FILE.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError, ValueError):
|
||||
logger.warning("Corrupt state file, using defaults.")
|
||||
return cls()
|
||||
return cls(
|
||||
**{k: v for k, v in data.items() if k in cls.__dataclass_fields__}
|
||||
)
|
||||
@ -84,8 +111,10 @@ class State:
|
||||
|
||||
def save_snapshot(data: list[dict[str, Any]]) -> None:
|
||||
"""Save an achievement snapshot to disk."""
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
SNAPSHOT_FILE.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")
|
||||
_atomic_write(
|
||||
SNAPSHOT_FILE,
|
||||
json.dumps(data, indent=2) + "\n",
|
||||
)
|
||||
|
||||
|
||||
def load_snapshot() -> list[dict[str, Any]] | None:
|
||||
|
||||
@ -47,6 +47,8 @@ def enforce_allowed_game(
|
||||
|
||||
Returns list of (pid, app_id) that were killed or detected.
|
||||
"""
|
||||
if allowed_app_id is None:
|
||||
return []
|
||||
running = get_running_steam_game_pids()
|
||||
violations: list[tuple[int, int]] = []
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ from typing import Any
|
||||
import aiohttp
|
||||
from howlongtobeatpy.HTMLRequests import HTMLRequests
|
||||
|
||||
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR
|
||||
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR, _atomic_write
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -71,11 +71,10 @@ def load_hltb_cache() -> dict[int, float]:
|
||||
|
||||
def save_hltb_cache(cache: dict[int, float]) -> None:
|
||||
"""Save the HLTB cache to disk."""
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
HLTB_CACHE_FILE.write_text(
|
||||
_atomic_write(
|
||||
HLTB_CACHE_FILE,
|
||||
json.dumps({str(k): v for k, v in cache.items()}, indent=2) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
except OSError:
|
||||
logger.exception("Failed to save HLTB cache")
|
||||
|
||||
@ -17,7 +17,7 @@ from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR
|
||||
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR, _atomic_write
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -92,8 +92,10 @@ def _load_cache() -> dict[str, Any]:
|
||||
|
||||
def _save_cache(cache: dict[str, Any]) -> None:
|
||||
"""Persist the ProtonDB cache."""
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
PROTONDB_CACHE_FILE.write_text(json.dumps(cache, indent=2) + "\n", encoding="utf-8")
|
||||
_atomic_write(
|
||||
PROTONDB_CACHE_FILE,
|
||||
json.dumps(cache, indent=2) + "\n",
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_one(
|
||||
|
||||
@ -11,6 +11,7 @@ import pytest
|
||||
from python_pkg.steam_backlog_enforcer.config import (
|
||||
Config,
|
||||
State,
|
||||
_atomic_write,
|
||||
interactive_setup,
|
||||
load_snapshot,
|
||||
save_snapshot,
|
||||
@ -20,6 +21,49 @@ if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestAtomicWrite:
|
||||
"""Tests for _atomic_write."""
|
||||
|
||||
def test_writes_file(self, tmp_path: Path) -> None:
|
||||
target = tmp_path / "out.json"
|
||||
_atomic_write(target, '{"key": "value"}\n')
|
||||
assert target.read_text(encoding="utf-8") == '{"key": "value"}\n'
|
||||
|
||||
def test_creates_parent_dirs(self, tmp_path: Path) -> None:
|
||||
target = tmp_path / "sub" / "deep" / "out.json"
|
||||
_atomic_write(target, "data")
|
||||
assert target.read_text(encoding="utf-8") == "data"
|
||||
|
||||
def test_cleanup_on_write_error(self, tmp_path: Path) -> None:
|
||||
target = tmp_path / "out.json"
|
||||
with (
|
||||
patch(
|
||||
"python_pkg.steam_backlog_enforcer.config.os.write",
|
||||
side_effect=OSError("disk full"),
|
||||
),
|
||||
pytest.raises(OSError, match="disk full"),
|
||||
):
|
||||
_atomic_write(target, "data")
|
||||
assert not target.exists()
|
||||
tmp_files = list(tmp_path.glob("*.tmp"))
|
||||
assert tmp_files == []
|
||||
|
||||
def test_cleanup_on_replace_error(self, tmp_path: Path) -> None:
|
||||
target = tmp_path / "out.json"
|
||||
with (
|
||||
patch.object(
|
||||
type(target),
|
||||
"replace",
|
||||
side_effect=OSError("no perm"),
|
||||
),
|
||||
pytest.raises(OSError, match="no perm"),
|
||||
):
|
||||
_atomic_write(target, "data")
|
||||
assert not target.exists()
|
||||
tmp_files = list(tmp_path.glob("*.tmp"))
|
||||
assert tmp_files == []
|
||||
|
||||
|
||||
class TestConfig:
|
||||
"""Tests for Config dataclass."""
|
||||
|
||||
@ -120,6 +164,14 @@ class TestState:
|
||||
st = State.load()
|
||||
assert st.current_app_id is None
|
||||
|
||||
def test_load_corrupt(self, tmp_path: Path) -> None:
|
||||
state_file = tmp_path / "state.json"
|
||||
state_file.write_text("not valid json{{", encoding="utf-8")
|
||||
with patch("python_pkg.steam_backlog_enforcer.config.STATE_FILE", state_file):
|
||||
st = State.load()
|
||||
assert st.current_app_id is None
|
||||
assert st.current_game_name == ""
|
||||
|
||||
|
||||
class TestSnapshot:
|
||||
"""Tests for snapshot save/load."""
|
||||
|
||||
@ -102,6 +102,9 @@ class TestGuardInstalledGames:
|
||||
):
|
||||
assert _guard_installed_games(440) == 0
|
||||
|
||||
def test_allowed_none_skips(self) -> None:
|
||||
assert _guard_installed_games(None) == 0
|
||||
|
||||
|
||||
class TestEnforceSetup:
|
||||
"""Tests for _enforce_setup."""
|
||||
@ -297,8 +300,14 @@ class TestEnforceLoopIteration:
|
||||
uninstall_other_games=False,
|
||||
)
|
||||
state = State(current_app_id=None)
|
||||
with patch(f"{PKG}.is_game_installed") as mock_installed:
|
||||
with (
|
||||
patch(f"{PKG}.enforce_allowed_game") as mock_enforce,
|
||||
patch(f"{PKG}._guard_installed_games") as mock_guard,
|
||||
patch(f"{PKG}.is_game_installed") as mock_installed,
|
||||
):
|
||||
_enforce_loop_iteration(config, state)
|
||||
mock_enforce.assert_not_called()
|
||||
mock_guard.assert_not_called()
|
||||
mock_installed.assert_not_called()
|
||||
|
||||
|
||||
@ -350,3 +359,31 @@ class TestDoEnforce:
|
||||
):
|
||||
do_enforce(config, state)
|
||||
assert call_count == 2
|
||||
|
||||
def test_state_load_failure_continues(self) -> None:
|
||||
"""Corrupt state file should not crash the daemon."""
|
||||
import json as json_mod
|
||||
|
||||
state = State(current_app_id=1, current_game_name="G")
|
||||
config = Config()
|
||||
call_count = 0
|
||||
|
||||
def load_side_effect() -> State:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
msg = "bad"
|
||||
raise json_mod.JSONDecodeError(msg, "", 0)
|
||||
if call_count == 2:
|
||||
raise KeyboardInterrupt
|
||||
return State(current_app_id=1) # pragma: no cover
|
||||
|
||||
with (
|
||||
patch(f"{PKG}._enforce_setup"),
|
||||
patch(f"{PKG}._echo"),
|
||||
patch.object(State, "load", side_effect=load_side_effect),
|
||||
patch(f"{PKG}._enforce_loop_iteration") as mock_iter,
|
||||
patch(f"{PKG}.time.sleep"),
|
||||
):
|
||||
do_enforce(config, state)
|
||||
mock_iter.assert_not_called()
|
||||
|
||||
@ -130,18 +130,8 @@ class TestEnforceAllowedGame:
|
||||
assert result == [(100, 570)]
|
||||
|
||||
def test_allowed_none(self) -> None:
|
||||
with (
|
||||
patch(
|
||||
"python_pkg.steam_backlog_enforcer.enforcer.get_running_steam_game_pids",
|
||||
return_value={100: 570},
|
||||
),
|
||||
patch(
|
||||
"python_pkg.steam_backlog_enforcer.enforcer.kill_process"
|
||||
) as mock_kill,
|
||||
):
|
||||
result = enforce_allowed_game(None, kill_unauthorized=True)
|
||||
assert result == [(100, 570)]
|
||||
mock_kill.assert_called_once_with(100, 570)
|
||||
result = enforce_allowed_game(None, kill_unauthorized=True)
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestKillProcess:
|
||||
|
||||
@ -66,14 +66,9 @@ class TestHltbCache:
|
||||
assert cache_file.exists()
|
||||
|
||||
def test_save_cache_os_error(self, tmp_path: Path) -> None:
|
||||
cache_file = MagicMock()
|
||||
cache_file.write_text = MagicMock(side_effect=OSError)
|
||||
with (
|
||||
patch("python_pkg.steam_backlog_enforcer.hltb.HLTB_CACHE_FILE", cache_file),
|
||||
patch(
|
||||
"python_pkg.steam_backlog_enforcer.hltb.CONFIG_DIR",
|
||||
MagicMock(mkdir=MagicMock()),
|
||||
),
|
||||
with patch(
|
||||
"python_pkg.steam_backlog_enforcer.hltb._atomic_write",
|
||||
side_effect=OSError("disk full"),
|
||||
):
|
||||
save_hltb_cache({440: 10.5}) # Should not raise
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user