mirror of
https://github.com/kuhyx/steam-backlog-enforcer.git
synced 2026-07-04 15:23:05 +02:00
fix(steam_backlog_enforcer): prevent enforce daemon from deleting assigned game
- Guard enforce_allowed_game() and _guard_installed_games() against current_app_id=None so they never treat all games as unauthorized - Add early return in _enforce_loop_iteration when no game is assigned - Wrap State.load() in enforce loop with error handling for corrupt files - Switch all config/cache file writes to atomic (tmpfile + rename) - Add robust error handling to State.load() for corrupt JSON - Update tests for new behavior and add coverage for atomic writes
This commit is contained in:
parent
2ed98ce4db
commit
61096eded3
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@ -64,6 +65,8 @@ def _guard_installed_games(allowed_app_id: int | None) -> int:
|
|||||||
|
|
||||||
Returns number of games removed this pass.
|
Returns number of games removed this pass.
|
||||||
"""
|
"""
|
||||||
|
if allowed_app_id is None:
|
||||||
|
return 0
|
||||||
installed = get_installed_games()
|
installed = get_installed_games()
|
||||||
count = 0
|
count = 0
|
||||||
for app_id, name in installed:
|
for app_id, name in installed:
|
||||||
@ -165,6 +168,9 @@ def _enforce_loop_iteration(config: Config, state: State) -> None:
|
|||||||
config: Enforcer configuration.
|
config: Enforcer configuration.
|
||||||
state: Current enforcer state.
|
state: Current enforcer state.
|
||||||
"""
|
"""
|
||||||
|
if state.current_app_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
# A) Kill unauthorized game processes.
|
# A) Kill unauthorized game processes.
|
||||||
if config.kill_unauthorized_games:
|
if config.kill_unauthorized_games:
|
||||||
violations = enforce_allowed_game(
|
violations = enforce_allowed_game(
|
||||||
@ -223,7 +229,12 @@ def do_enforce(config: Config, state: State) -> None:
|
|||||||
# Reload state from disk so CLI changes (e.g. new game
|
# Reload state from disk so CLI changes (e.g. new game
|
||||||
# assignment via ``done`` / ``scan``) take effect immediately
|
# assignment via ``done`` / ``scan``) take effect immediately
|
||||||
# without needing to restart the daemon.
|
# without needing to restart the daemon.
|
||||||
fresh = State.load()
|
try:
|
||||||
|
fresh = State.load()
|
||||||
|
except (json.JSONDecodeError, OSError, ValueError) as exc:
|
||||||
|
logger.warning("Failed to reload state: %s", exc)
|
||||||
|
time.sleep(ENFORCE_INTERVAL)
|
||||||
|
continue
|
||||||
state.current_app_id = fresh.current_app_id
|
state.current_app_id = fresh.current_app_id
|
||||||
state.current_game_name = fresh.current_game_name
|
state.current_game_name = fresh.current_game_name
|
||||||
state.finished_app_ids = fresh.finished_app_ids
|
state.finished_app_ids = fresh.finished_app_ids
|
||||||
|
|||||||
@ -2,10 +2,14 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
CONFIG_DIR = Path.home() / ".config" / "steam_backlog_enforcer"
|
CONFIG_DIR = Path.home() / ".config" / "steam_backlog_enforcer"
|
||||||
@ -25,6 +29,25 @@ BLOCKED_DOMAINS = [
|
|||||||
|
|
||||||
HOSTS_FILE = Path("/etc/hosts")
|
HOSTS_FILE = Path("/etc/hosts")
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _atomic_write(path: Path, data: str) -> None:
|
||||||
|
"""Write data to a file atomically via a temporary file + rename."""
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
|
||||||
|
tmp_path = Path(tmp)
|
||||||
|
try:
|
||||||
|
os.write(fd, data.encode("utf-8"))
|
||||||
|
os.close(fd)
|
||||||
|
tmp_path.replace(path)
|
||||||
|
except BaseException:
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
os.close(fd)
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
tmp_path.unlink()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
@ -40,9 +63,9 @@ class Config:
|
|||||||
|
|
||||||
def save(self) -> None:
|
def save(self) -> None:
|
||||||
"""Persist config to disk."""
|
"""Persist config to disk."""
|
||||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
_atomic_write(
|
||||||
CONFIG_FILE.write_text(
|
CONFIG_FILE,
|
||||||
json.dumps(self.__dict__, indent=2) + "\n", encoding="utf-8"
|
json.dumps(self.__dict__, indent=2) + "\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -66,16 +89,20 @@ class State:
|
|||||||
|
|
||||||
def save(self) -> None:
|
def save(self) -> None:
|
||||||
"""Persist state to disk."""
|
"""Persist state to disk."""
|
||||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
_atomic_write(
|
||||||
STATE_FILE.write_text(
|
STATE_FILE,
|
||||||
json.dumps(self.__dict__, indent=2) + "\n", encoding="utf-8"
|
json.dumps(self.__dict__, indent=2) + "\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls) -> State:
|
def load(cls) -> State:
|
||||||
"""Load state from disk, or return defaults."""
|
"""Load state from disk, or return defaults."""
|
||||||
if STATE_FILE.exists():
|
if STATE_FILE.exists():
|
||||||
data = json.loads(STATE_FILE.read_text(encoding="utf-8"))
|
try:
|
||||||
|
data = json.loads(STATE_FILE.read_text(encoding="utf-8"))
|
||||||
|
except (json.JSONDecodeError, OSError, ValueError):
|
||||||
|
logger.warning("Corrupt state file, using defaults.")
|
||||||
|
return cls()
|
||||||
return cls(
|
return cls(
|
||||||
**{k: v for k, v in data.items() if k in cls.__dataclass_fields__}
|
**{k: v for k, v in data.items() if k in cls.__dataclass_fields__}
|
||||||
)
|
)
|
||||||
@ -84,8 +111,10 @@ class State:
|
|||||||
|
|
||||||
def save_snapshot(data: list[dict[str, Any]]) -> None:
|
def save_snapshot(data: list[dict[str, Any]]) -> None:
|
||||||
"""Save an achievement snapshot to disk."""
|
"""Save an achievement snapshot to disk."""
|
||||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
_atomic_write(
|
||||||
SNAPSHOT_FILE.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")
|
SNAPSHOT_FILE,
|
||||||
|
json.dumps(data, indent=2) + "\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_snapshot() -> list[dict[str, Any]] | None:
|
def load_snapshot() -> list[dict[str, Any]] | None:
|
||||||
|
|||||||
@ -47,6 +47,8 @@ def enforce_allowed_game(
|
|||||||
|
|
||||||
Returns list of (pid, app_id) that were killed or detected.
|
Returns list of (pid, app_id) that were killed or detected.
|
||||||
"""
|
"""
|
||||||
|
if allowed_app_id is None:
|
||||||
|
return []
|
||||||
running = get_running_steam_game_pids()
|
running = get_running_steam_game_pids()
|
||||||
violations: list[tuple[int, int]] = []
|
violations: list[tuple[int, int]] = []
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from typing import Any
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from howlongtobeatpy.HTMLRequests import HTMLRequests
|
from howlongtobeatpy.HTMLRequests import HTMLRequests
|
||||||
|
|
||||||
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR
|
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR, _atomic_write
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -71,11 +71,10 @@ def load_hltb_cache() -> dict[int, float]:
|
|||||||
|
|
||||||
def save_hltb_cache(cache: dict[int, float]) -> None:
|
def save_hltb_cache(cache: dict[int, float]) -> None:
|
||||||
"""Save the HLTB cache to disk."""
|
"""Save the HLTB cache to disk."""
|
||||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
try:
|
try:
|
||||||
HLTB_CACHE_FILE.write_text(
|
_atomic_write(
|
||||||
|
HLTB_CACHE_FILE,
|
||||||
json.dumps({str(k): v for k, v in cache.items()}, indent=2) + "\n",
|
json.dumps({str(k): v for k, v in cache.items()}, indent=2) + "\n",
|
||||||
encoding="utf-8",
|
|
||||||
)
|
)
|
||||||
except OSError:
|
except OSError:
|
||||||
logger.exception("Failed to save HLTB cache")
|
logger.exception("Failed to save HLTB cache")
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from typing import Any
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR
|
from python_pkg.steam_backlog_enforcer.config import CONFIG_DIR, _atomic_write
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -92,8 +92,10 @@ def _load_cache() -> dict[str, Any]:
|
|||||||
|
|
||||||
def _save_cache(cache: dict[str, Any]) -> None:
|
def _save_cache(cache: dict[str, Any]) -> None:
|
||||||
"""Persist the ProtonDB cache."""
|
"""Persist the ProtonDB cache."""
|
||||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
_atomic_write(
|
||||||
PROTONDB_CACHE_FILE.write_text(json.dumps(cache, indent=2) + "\n", encoding="utf-8")
|
PROTONDB_CACHE_FILE,
|
||||||
|
json.dumps(cache, indent=2) + "\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_one(
|
async def _fetch_one(
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import pytest
|
|||||||
from python_pkg.steam_backlog_enforcer.config import (
|
from python_pkg.steam_backlog_enforcer.config import (
|
||||||
Config,
|
Config,
|
||||||
State,
|
State,
|
||||||
|
_atomic_write,
|
||||||
interactive_setup,
|
interactive_setup,
|
||||||
load_snapshot,
|
load_snapshot,
|
||||||
save_snapshot,
|
save_snapshot,
|
||||||
@ -20,6 +21,49 @@ if TYPE_CHECKING:
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class TestAtomicWrite:
|
||||||
|
"""Tests for _atomic_write."""
|
||||||
|
|
||||||
|
def test_writes_file(self, tmp_path: Path) -> None:
|
||||||
|
target = tmp_path / "out.json"
|
||||||
|
_atomic_write(target, '{"key": "value"}\n')
|
||||||
|
assert target.read_text(encoding="utf-8") == '{"key": "value"}\n'
|
||||||
|
|
||||||
|
def test_creates_parent_dirs(self, tmp_path: Path) -> None:
|
||||||
|
target = tmp_path / "sub" / "deep" / "out.json"
|
||||||
|
_atomic_write(target, "data")
|
||||||
|
assert target.read_text(encoding="utf-8") == "data"
|
||||||
|
|
||||||
|
def test_cleanup_on_write_error(self, tmp_path: Path) -> None:
|
||||||
|
target = tmp_path / "out.json"
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"python_pkg.steam_backlog_enforcer.config.os.write",
|
||||||
|
side_effect=OSError("disk full"),
|
||||||
|
),
|
||||||
|
pytest.raises(OSError, match="disk full"),
|
||||||
|
):
|
||||||
|
_atomic_write(target, "data")
|
||||||
|
assert not target.exists()
|
||||||
|
tmp_files = list(tmp_path.glob("*.tmp"))
|
||||||
|
assert tmp_files == []
|
||||||
|
|
||||||
|
def test_cleanup_on_replace_error(self, tmp_path: Path) -> None:
|
||||||
|
target = tmp_path / "out.json"
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
type(target),
|
||||||
|
"replace",
|
||||||
|
side_effect=OSError("no perm"),
|
||||||
|
),
|
||||||
|
pytest.raises(OSError, match="no perm"),
|
||||||
|
):
|
||||||
|
_atomic_write(target, "data")
|
||||||
|
assert not target.exists()
|
||||||
|
tmp_files = list(tmp_path.glob("*.tmp"))
|
||||||
|
assert tmp_files == []
|
||||||
|
|
||||||
|
|
||||||
class TestConfig:
|
class TestConfig:
|
||||||
"""Tests for Config dataclass."""
|
"""Tests for Config dataclass."""
|
||||||
|
|
||||||
@ -120,6 +164,14 @@ class TestState:
|
|||||||
st = State.load()
|
st = State.load()
|
||||||
assert st.current_app_id is None
|
assert st.current_app_id is None
|
||||||
|
|
||||||
|
def test_load_corrupt(self, tmp_path: Path) -> None:
|
||||||
|
state_file = tmp_path / "state.json"
|
||||||
|
state_file.write_text("not valid json{{", encoding="utf-8")
|
||||||
|
with patch("python_pkg.steam_backlog_enforcer.config.STATE_FILE", state_file):
|
||||||
|
st = State.load()
|
||||||
|
assert st.current_app_id is None
|
||||||
|
assert st.current_game_name == ""
|
||||||
|
|
||||||
|
|
||||||
class TestSnapshot:
|
class TestSnapshot:
|
||||||
"""Tests for snapshot save/load."""
|
"""Tests for snapshot save/load."""
|
||||||
|
|||||||
@ -102,6 +102,9 @@ class TestGuardInstalledGames:
|
|||||||
):
|
):
|
||||||
assert _guard_installed_games(440) == 0
|
assert _guard_installed_games(440) == 0
|
||||||
|
|
||||||
|
def test_allowed_none_skips(self) -> None:
|
||||||
|
assert _guard_installed_games(None) == 0
|
||||||
|
|
||||||
|
|
||||||
class TestEnforceSetup:
|
class TestEnforceSetup:
|
||||||
"""Tests for _enforce_setup."""
|
"""Tests for _enforce_setup."""
|
||||||
@ -297,8 +300,14 @@ class TestEnforceLoopIteration:
|
|||||||
uninstall_other_games=False,
|
uninstall_other_games=False,
|
||||||
)
|
)
|
||||||
state = State(current_app_id=None)
|
state = State(current_app_id=None)
|
||||||
with patch(f"{PKG}.is_game_installed") as mock_installed:
|
with (
|
||||||
|
patch(f"{PKG}.enforce_allowed_game") as mock_enforce,
|
||||||
|
patch(f"{PKG}._guard_installed_games") as mock_guard,
|
||||||
|
patch(f"{PKG}.is_game_installed") as mock_installed,
|
||||||
|
):
|
||||||
_enforce_loop_iteration(config, state)
|
_enforce_loop_iteration(config, state)
|
||||||
|
mock_enforce.assert_not_called()
|
||||||
|
mock_guard.assert_not_called()
|
||||||
mock_installed.assert_not_called()
|
mock_installed.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@ -350,3 +359,31 @@ class TestDoEnforce:
|
|||||||
):
|
):
|
||||||
do_enforce(config, state)
|
do_enforce(config, state)
|
||||||
assert call_count == 2
|
assert call_count == 2
|
||||||
|
|
||||||
|
def test_state_load_failure_continues(self) -> None:
|
||||||
|
"""Corrupt state file should not crash the daemon."""
|
||||||
|
import json as json_mod
|
||||||
|
|
||||||
|
state = State(current_app_id=1, current_game_name="G")
|
||||||
|
config = Config()
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def load_side_effect() -> State:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
msg = "bad"
|
||||||
|
raise json_mod.JSONDecodeError(msg, "", 0)
|
||||||
|
if call_count == 2:
|
||||||
|
raise KeyboardInterrupt
|
||||||
|
return State(current_app_id=1) # pragma: no cover
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(f"{PKG}._enforce_setup"),
|
||||||
|
patch(f"{PKG}._echo"),
|
||||||
|
patch.object(State, "load", side_effect=load_side_effect),
|
||||||
|
patch(f"{PKG}._enforce_loop_iteration") as mock_iter,
|
||||||
|
patch(f"{PKG}.time.sleep"),
|
||||||
|
):
|
||||||
|
do_enforce(config, state)
|
||||||
|
mock_iter.assert_not_called()
|
||||||
|
|||||||
@ -130,18 +130,8 @@ class TestEnforceAllowedGame:
|
|||||||
assert result == [(100, 570)]
|
assert result == [(100, 570)]
|
||||||
|
|
||||||
def test_allowed_none(self) -> None:
|
def test_allowed_none(self) -> None:
|
||||||
with (
|
result = enforce_allowed_game(None, kill_unauthorized=True)
|
||||||
patch(
|
assert result == []
|
||||||
"python_pkg.steam_backlog_enforcer.enforcer.get_running_steam_game_pids",
|
|
||||||
return_value={100: 570},
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"python_pkg.steam_backlog_enforcer.enforcer.kill_process"
|
|
||||||
) as mock_kill,
|
|
||||||
):
|
|
||||||
result = enforce_allowed_game(None, kill_unauthorized=True)
|
|
||||||
assert result == [(100, 570)]
|
|
||||||
mock_kill.assert_called_once_with(100, 570)
|
|
||||||
|
|
||||||
|
|
||||||
class TestKillProcess:
|
class TestKillProcess:
|
||||||
|
|||||||
@ -66,14 +66,9 @@ class TestHltbCache:
|
|||||||
assert cache_file.exists()
|
assert cache_file.exists()
|
||||||
|
|
||||||
def test_save_cache_os_error(self, tmp_path: Path) -> None:
|
def test_save_cache_os_error(self, tmp_path: Path) -> None:
|
||||||
cache_file = MagicMock()
|
with patch(
|
||||||
cache_file.write_text = MagicMock(side_effect=OSError)
|
"python_pkg.steam_backlog_enforcer.hltb._atomic_write",
|
||||||
with (
|
side_effect=OSError("disk full"),
|
||||||
patch("python_pkg.steam_backlog_enforcer.hltb.HLTB_CACHE_FILE", cache_file),
|
|
||||||
patch(
|
|
||||||
"python_pkg.steam_backlog_enforcer.hltb.CONFIG_DIR",
|
|
||||||
MagicMock(mkdir=MagicMock()),
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
save_hltb_cache({440: 10.5}) # Should not raise
|
save_hltb_cache({440: 10.5}) # Should not raise
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user