From c985160d17be16e50205d742d2ea744d881fda56 Mon Sep 17 00:00:00 2001 From: Krzysztof kuhy Rudnicki Date: Mon, 16 Mar 2026 22:46:48 +0100 Subject: [PATCH] WIP: Enforce 500-line limit - split batch 1 Split 16+ files. 27 files still need splitting. See session notes. --- pyproject.toml | 9 +- .../brother_printer/check_brother_printer.py | 1778 +---------- python_pkg/brother_printer/constants.py | 173 ++ python_pkg/brother_printer/cups_queue.py | 459 +++ python_pkg/brother_printer/cups_service.py | 479 +++ python_pkg/brother_printer/data_classes.py | 96 + python_pkg/brother_printer/display.py | 385 +++ python_pkg/brother_printer/network_query.py | 97 + python_pkg/brother_printer/usb_query.py | 233 ++ python_pkg/geo_data.py | 2028 ------------- python_pkg/geo_data/__init__.py | 206 ++ python_pkg/geo_data/_common.py | 318 ++ python_pkg/geo_data/_poland_admin.py | 226 ++ python_pkg/geo_data/_poland_nature.py | 446 +++ python_pkg/geo_data/_poland_water.py | 437 +++ python_pkg/geo_data/_warsaw.py | 407 +++ python_pkg/geo_data/_warsaw_places.py | 186 ++ python_pkg/keyboard_coop/tests/conftest.py | 14 + .../keyboard_coop/tests/test_constants.py | 148 + .../keyboard_coop/tests/test_game_logic.py | 371 +++ .../keyboard_coop/tests/test_game_loop.py | 426 +++ python_pkg/keyboard_coop/tests/test_main.py | 1247 -------- python_pkg/keyboard_coop/tests/test_ui.py | 311 ++ python_pkg/lichess_bot/tests/test_main.py | 1244 -------- .../lichess_bot/tests/test_main_analysis.py | 409 +++ .../lichess_bot/tests/test_main_bot_loop.py | 474 +++ .../lichess_bot/tests/test_main_game_state.py | 403 +++ .../praca_magisterska_video/_q23_classical.py | 445 +++ .../praca_magisterska_video/_q23_deeplab.py | 248 ++ .../praca_magisterska_video/_q23_helpers.py | 116 + .../_q23_transformer.py | 430 +++ .../praca_magisterska_video/_q23_unet_fcn.py | 399 +++ .../praca_magisterska_video/_q24_classical.py | 332 +++ .../praca_magisterska_video/_q24_common.py | 115 + .../praca_magisterska_video/_q24_nms_final.py | 239 ++ .../praca_magisterska_video/_q24_rcnn.py | 405 +++ .../praca_magisterska_video/_q24_rpn_yolo.py | 383 +++ .../_q24_yolo_arch_detr.py | 459 +++ .../generate_images/_pubsub_common.py | 235 ++ .../generate_images/_pubsub_qos.py | 430 +++ .../generate_images/_pubsub_topic_content.py | 239 ++ .../_pubsub_type_hierarchical.py | 279 ++ .../generate_images/_q20_architectures.py | 421 +++ .../generate_images/_q20_batch_and_windows.py | 449 +++ .../generate_images/_q20_common.py | 180 ++ .../_q20_late_and_decisions.py | 240 ++ .../generate_images/_q20_platforms.py | 471 +++ .../_q20_time_monitoring_sessions.py | 464 +++ .../generate_images/_q23_architectures.py | 467 +++ .../generate_images/_q23_common.py | 96 + .../generate_images/_q23_diy_unet.py | 251 ++ .../generate_images/_q23_mean_shift_ncuts.py | 380 +++ .../generate_images/_q23_mnemonics.py | 327 ++ .../generate_images/_q23_nn_basics.py | 293 ++ .../generate_images/_q23_otsu_watershed.py | 408 +++ .../_q23_receptive_transformer.py | 286 ++ .../generate_images/_q23_region_diy.py | 408 +++ .../generate_images/_q24_common.py | 186 ++ .../generate_images/_q24_fpn_tasks_cnn.py | 412 +++ .../generate_images/_q24_haar_integral_svm.py | 342 +++ .../generate_images/_q24_hog_classical.py | 380 +++ .../generate_images/_q24_iou_nms_detector.py | 413 +++ .../generate_images/_q24_modern_pipelines.py | 365 +++ .../generate_images/_q24_rcnn_yolo.py | 344 +++ .../generate_images/_q31_common.py | 102 + .../_q31_criteria_comparison.py | 256 ++ .../generate_images/_q31_ev_spectrum.py | 289 ++ .../generate_images/_q31_hurwicz_mnemonic.py | 344 +++ .../generate_images/_q31_regret_matrix.py | 322 ++ .../generate_images/_q9_basics.py | 448 +++ .../generate_images/_q9_classic_sync.py | 420 +++ .../generate_images/_q9_common.py | 200 ++ .../generate_images/_q9_ipc.py | 212 ++ .../generate_images/_q9_race_deadlock.py | 404 +++ .../generate_images/_sched_common.py | 87 + .../generate_images/_sched_complexity_edd.py | 309 ++ .../generate_images/_sched_graham.py | 484 +++ .../generate_images/_sched_johnson.py | 318 ++ .../generate_images/_sched_spt_flow_job.py | 352 +++ .../generate_pubsub_diagrams.py | 1207 +------- .../generate_images/generate_q20_diagrams.py | 2153 +------------- .../generate_images/generate_q23_diagrams.py | 2617 +---------------- .../generate_images/generate_q24_diagrams.py | 2312 +-------------- .../generate_images/generate_q31_diagrams.py | 1216 +------- .../generate_q9_all_diagrams.py | 1634 +--------- .../generate_scheduling_diagrams.py | 1473 +--------- .../praca_magisterska_video/visualize_q23.py | 1555 +--------- .../praca_magisterska_video/visualize_q24.py | 1822 +----------- python_pkg/screen_locker/_constants.py | 36 + .../screen_locker/_phone_verification.py | 203 ++ python_pkg/screen_locker/_shutdown.py | 262 ++ python_pkg/screen_locker/_ui_flows.py | 294 ++ python_pkg/screen_locker/_workout_forms.py | 269 ++ python_pkg/screen_locker/screen_lock.py | 1016 +------ python_pkg/screen_locker/tests/conftest.py | 113 + .../screen_locker/tests/test_adb_and_phone.py | 411 +++ .../screen_locker/tests/test_init_and_log.py | 390 +++ .../tests/test_phone_check_unlock.py | 430 +++ .../screen_locker/tests/test_screen_lock.py | 2079 ------------- .../screen_locker/tests/test_ui_and_timers.py | 424 +++ .../screen_locker/tests/test_verify_data.py | 371 +++ .../steam_backlog_enforcer/game_install.py | 349 +++ python_pkg/steam_backlog_enforcer/main.py | 859 +----- python_pkg/steam_backlog_enforcer/scanning.py | 501 ++++ .../steam-backlog-enforcer.service | 2 +- tests/test_file_length.py | 47 + 106 files changed, 28081 insertions(+), 25858 deletions(-) create mode 100644 python_pkg/brother_printer/constants.py create mode 100644 python_pkg/brother_printer/cups_queue.py create mode 100644 python_pkg/brother_printer/cups_service.py create mode 100644 python_pkg/brother_printer/data_classes.py create mode 100644 python_pkg/brother_printer/display.py create mode 100644 python_pkg/brother_printer/network_query.py create mode 100644 python_pkg/brother_printer/usb_query.py delete mode 100644 python_pkg/geo_data.py create mode 100644 python_pkg/geo_data/__init__.py create mode 100644 python_pkg/geo_data/_common.py create mode 100644 python_pkg/geo_data/_poland_admin.py create mode 100644 python_pkg/geo_data/_poland_nature.py create mode 100644 python_pkg/geo_data/_poland_water.py create mode 100644 python_pkg/geo_data/_warsaw.py create mode 100644 python_pkg/geo_data/_warsaw_places.py create mode 100644 python_pkg/keyboard_coop/tests/conftest.py create mode 100644 python_pkg/keyboard_coop/tests/test_constants.py create mode 100644 python_pkg/keyboard_coop/tests/test_game_logic.py create mode 100644 python_pkg/keyboard_coop/tests/test_game_loop.py delete mode 100644 python_pkg/keyboard_coop/tests/test_main.py create mode 100644 python_pkg/keyboard_coop/tests/test_ui.py delete mode 100644 python_pkg/lichess_bot/tests/test_main.py create mode 100644 python_pkg/lichess_bot/tests/test_main_analysis.py create mode 100644 python_pkg/lichess_bot/tests/test_main_bot_loop.py create mode 100644 python_pkg/lichess_bot/tests/test_main_game_state.py create mode 100644 python_pkg/praca_magisterska_video/_q23_classical.py create mode 100644 python_pkg/praca_magisterska_video/_q23_deeplab.py create mode 100644 python_pkg/praca_magisterska_video/_q23_helpers.py create mode 100644 python_pkg/praca_magisterska_video/_q23_transformer.py create mode 100644 python_pkg/praca_magisterska_video/_q23_unet_fcn.py create mode 100644 python_pkg/praca_magisterska_video/_q24_classical.py create mode 100644 python_pkg/praca_magisterska_video/_q24_common.py create mode 100644 python_pkg/praca_magisterska_video/_q24_nms_final.py create mode 100644 python_pkg/praca_magisterska_video/_q24_rcnn.py create mode 100644 python_pkg/praca_magisterska_video/_q24_rpn_yolo.py create mode 100644 python_pkg/praca_magisterska_video/_q24_yolo_arch_detr.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_pubsub_common.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_pubsub_qos.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_pubsub_topic_content.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_pubsub_type_hierarchical.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q20_architectures.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q20_batch_and_windows.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q20_common.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q20_late_and_decisions.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q20_platforms.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q20_time_monitoring_sessions.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q23_architectures.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q23_common.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q23_diy_unet.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q23_mean_shift_ncuts.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q23_mnemonics.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q23_nn_basics.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q23_otsu_watershed.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q23_receptive_transformer.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q23_region_diy.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q24_common.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q24_fpn_tasks_cnn.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q24_haar_integral_svm.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q24_hog_classical.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q24_iou_nms_detector.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q24_modern_pipelines.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q24_rcnn_yolo.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q31_common.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q31_criteria_comparison.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q31_ev_spectrum.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q31_hurwicz_mnemonic.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q31_regret_matrix.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q9_basics.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q9_classic_sync.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q9_common.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q9_ipc.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_q9_race_deadlock.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_sched_common.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_sched_complexity_edd.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_sched_graham.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_sched_johnson.py create mode 100644 python_pkg/praca_magisterska_video/generate_images/_sched_spt_flow_job.py mode change 100755 => 100644 python_pkg/praca_magisterska_video/generate_images/generate_q20_diagrams.py mode change 100755 => 100644 python_pkg/praca_magisterska_video/generate_images/generate_q24_diagrams.py mode change 100755 => 100644 python_pkg/praca_magisterska_video/generate_images/generate_scheduling_diagrams.py create mode 100644 python_pkg/screen_locker/_constants.py create mode 100644 python_pkg/screen_locker/_phone_verification.py create mode 100644 python_pkg/screen_locker/_shutdown.py create mode 100644 python_pkg/screen_locker/_ui_flows.py create mode 100644 python_pkg/screen_locker/_workout_forms.py create mode 100644 python_pkg/screen_locker/tests/conftest.py create mode 100644 python_pkg/screen_locker/tests/test_adb_and_phone.py create mode 100644 python_pkg/screen_locker/tests/test_init_and_log.py create mode 100644 python_pkg/screen_locker/tests/test_phone_check_unlock.py delete mode 100644 python_pkg/screen_locker/tests/test_screen_lock.py create mode 100644 python_pkg/screen_locker/tests/test_ui_and_timers.py create mode 100644 python_pkg/screen_locker/tests/test_verify_data.py create mode 100644 python_pkg/steam_backlog_enforcer/game_install.py create mode 100644 python_pkg/steam_backlog_enforcer/scanning.py create mode 100644 tests/test_file_length.py diff --git a/pyproject.toml b/pyproject.toml index 7f22760..86881ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ unfixable = [] "SLF001", # Allow private member access in tests ] # Files using urlopen with validated URL schemes -"python_pkg/geo_data.py" = ["S310"] +"python_pkg/geo_data/_common.py" = ["S310"] "python_pkg/steam_backlog_enforcer/library_hider.py" = ["S310"] "poker_modifier_app/poker_modifier_app.py" = [ "FBT003", # Boolean positional values in tkinter API calls @@ -76,9 +76,12 @@ unfixable = [] "FBT003", # Boolean positional values in tkinter API calls ] # Brother printer - optional usb.core/usb.util imports -"python_pkg/brother_printer/check_brother_printer.py" = [ +"python_pkg/brother_printer/cups_service.py" = [ "PLC0415", # Late imports for optional pyusb dependency ] +"python_pkg/brother_printer/usb_query.py" = [ + "PLC0415", # Late import of cups_service fallback +] # Music generator - CLI script with intentional patterns "python_pkg/music_gen/music_generator.py" = [ "T201", # print() is intentional for CLI feedback @@ -182,6 +185,8 @@ disable = [] [tool.pylint.design] # A class with just run() as public API is valid for games/apps min-public-methods = 1 +# Enforce maximum file length of 500 lines +max-module-lines = 500 [tool.pylint.spelling] # No spelling dictionary to avoid false positives diff --git a/python_pkg/brother_printer/check_brother_printer.py b/python_pkg/brother_printer/check_brother_printer.py index 7f93657..8fdd791 100644 --- a/python_pkg/brother_printer/check_brother_printer.py +++ b/python_pkg/brother_printer/check_brother_printer.py @@ -12,1773 +12,39 @@ Usage: sudo python3 -m brother_printer # force network/SNMP mode sudo python3 -m brother_printer --reset-toner # after replacing toner sudo python3 -m brother_printer --reset-drum # after replacing drum + +This module re-exports public symbols from sub-modules for backwards +compatibility. The implementation lives in: + +- constants.py – colours, PJL status codes, lookup tables +- data_classes.py – dataclasses (CUPSJob, USBResult, NetworkResult …) +- usb_query.py – USB discovery and PJL query +- cups_service.py – CUPS service control, consumable state, USB fallback +- network_query.py – SNMP network query +- cups_queue.py – CUPS queue inspection and interactive fixes +- display.py – formatted output for USB / network results """ from __future__ import annotations -import contextlib -from dataclasses import dataclass, field -import fcntl -import json import logging import os -from pathlib import Path import re -import select import shutil import subprocess import sys -import time -from typing import TYPE_CHECKING -import urllib.parse -if TYPE_CHECKING: - from collections.abc import Callable +from python_pkg.brother_printer.constants import CYAN, RED, RESET, _out +from python_pkg.brother_printer.cups_service import reset_consumable +from python_pkg.brother_printer.display import ( + display_network_results, + display_usb_results, +) +from python_pkg.brother_printer.network_query import query_network_snmp +from python_pkg.brother_printer.usb_query import find_brother_usb, query_usb_pjl logger = logging.getLogger(__name__) -# ── Colors ─────────────────────────────────────────────────────────── - -RED = "\033[0;31m" -YELLOW = "\033[1;33m" -GREEN = "\033[0;32m" -CYAN = "\033[0;36m" -BOLD = "\033[1m" -DIM = "\033[2m" -RESET = "\033[0m" - -# ── SNMP supply level sentinel values ──────────────────────────────────────── - -SNMP_LEVEL_OK = -3 -SNMP_LEVEL_LOW = -2 -SUPPLY_LOW_PCT = 10 -SUPPLY_WARN_PCT = 25 -PROGRESS_BAR_WIDTH = 25 - -# Brother HL-1110 consumable page ratings -TONER_RATED_PAGES = 1000 -DRUM_RATED_PAGES = 10000 -CUPS_PAGE_LOG = Path("/var/log/cups/page_log") -CONSUMABLE_STATE_FILE = Path.home() / ".config" / "brother_printer" / "state.json" -MIN_LPSTAT_JOB_PARTS = 4 - - -def _out(text: str = "") -> None: - """Write a line to stdout.""" - sys.stdout.write(text + "\n") - - -def _prompt(text: str) -> str: - """Read user input with a prompt.""" - sys.stdout.write(text) - sys.stdout.flush() - return sys.stdin.readline().strip() - - -# ── Brother PJL status codes ──────────────────────────────────────── -# Documented in Brother PJL Technical Reference. -# Format: code -> (severity, short_text, action) -# Severities: ok, info, warn, critical - -BROTHER_STATUS_CODES: dict[int, tuple[str, str, str]] = { - 10001: ("ok", "Ready", ""), - 10002: ("ok", "Sleep", ""), - 10003: ("info", "Self-test / Calibrating", ""), - 10004: ("ok", "Warming up", ""), - 10005: ("ok", "Cooling down", ""), - 10006: ("info", "Processing", ""), - 10007: ("info", "Printing", ""), - 10014: ("ok", "Cancelling", ""), - 10023: ("info", "Waiting", ""), - # Toner - 30010: ( - "warn", - "Toner Low", - "Order replacement toner cartridge (TN-1050/TN-1030 compatible).", - ), - 30038: ( - "warn", - "Toner Low", - "Order replacement toner cartridge (TN-1050/TN-1030 compatible).", - ), - 40038: ( - "warn", - "Toner Low", - "Order replacement toner cartridge (TN-1050/TN-1030 compatible).", - ), - 40309: ( - "critical", - "Replace Toner", - "The toner cartridge needs immediate replacement (TN-1050/TN-1030 compatible).", - ), - 40310: ( - "critical", - "Toner End", - "The toner cartridge is empty. Replace now (TN-1050/TN-1030 compatible).", - ), - # Drum - 30201: ( - "warn", - "Drum End Soon", - "The drum unit is nearing end of life. Order replacement (DR-1050 compatible).", - ), - 40201: ( - "warn", - "Drum End Soon", - "The drum unit is nearing end of life. Order replacement (DR-1050 compatible).", - ), - 40019: ( - "critical", - "Replace Drum", - "The drum unit must be replaced (DR-1050 compatible).", - ), - 40020: ( - "critical", - "Drum Stop", - "The drum unit must be replaced immediately (DR-1050 compatible).", - ), - # Paper / feed - 40000: ("critical", "Paper Jam", "Clear the paper jam and close all covers."), - 40300: ( - "critical", - "No Paper / Tray Open", - "Load paper or close the paper tray.", - ), - 40302: ("critical", "No Paper", "Load paper into the paper tray."), - 40016: ("warn", "Paper Feed Error", "Check paper tray and re-seat paper."), - # Cover - 41000: ("critical", "Cover Open", "Close the top cover of the printer."), - 41001: ("critical", "Cover Open", "Close the front cover of the printer."), - # Others - 35078: ("info", "Manual Feed", "Load paper in the manual feed slot."), - 42000: ( - "critical", - "Machine Error", - "Power-cycle the printer. If error persists, contact service.", - ), -} - - -# ── Data classes ───────────────────────────────────────────────────── - - -@dataclass -class CUPSJob: - """A single CUPS print job.""" - - job_id: str - user: str - size: str - date: str - - -@dataclass -class CUPSQueueStatus: - """Status of the CUPS print queue for a printer.""" - - printer_name: str = "" - enabled: bool = True - reason: str = "" - jobs: list[CUPSJob] = field(default_factory=list) - has_backend_errors: bool = False - last_backend_error: str = "" - - -@dataclass -class PageCountEstimate: - """Estimated consumable life based on CUPS page count.""" - - total_pages: int = 0 - toner_pages: int = 0 - drum_pages: int = 0 - toner_pct_remaining: int = 100 - drum_pct_remaining: int = 100 - toner_exhausted: bool = False - toner_low: bool = False - drum_near_end: bool = False - - -@dataclass -class USBPortStatus: - """IEEE 1284 USB printer port status bits.""" - - paper_empty: bool = False - online: bool = True - error: bool = False - raw_byte: int = 0 - - -@dataclass -class USBResult: - """Result from a USB PJL query.""" - - connection: str = "usb" - device: str = "" - product: str = "Brother Laser Printer" - serial: str = "" - status_code: str = "" - display: str = "" - online: str = "" - economode: str = "" - error: str = "" - port_status: USBPortStatus | None = None - - -@dataclass -class NetworkResult: - """Result from an SNMP network query.""" - - connection: str = "network" - ip: str = "" - product: str = "Unknown" - serial: str = "" - printer_status: str = "" - device_status: str = "" - display: str = "" - page_count: str = "" - supply_descriptions: list[str] = field(default_factory=list) - supply_max: list[str] = field(default_factory=list) - supply_levels: list[str] = field(default_factory=list) - error: str = "" - - -# ── USB printer discovery ──────────────────────────────────────────── - - -def find_brother_usb() -> str: - """Look for any Brother printer on USB via lsusb. Returns the info line.""" - if not shutil.which("lsusb"): - return "" - try: - r = subprocess.run( - ["/usr/bin/lsusb"], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - for line in r.stdout.splitlines(): - if "04f9:" in line.lower(): - # Return the part after "ID ..." - return line.split(": ", 1)[1] if ": " in line else line - except (subprocess.TimeoutExpired, OSError): - pass - return "" - - -def find_usb_printer_dev() -> str | None: - """Find /dev/usb/lp* device for the Brother printer.""" - devices = sorted(Path("/dev/usb").glob("lp*")) - return str(devices[0]) if devices else None - - -def _parse_cups_usb_uri(uri: str, info: dict[str, str]) -> None: - """Extract product and serial from a CUPS usb:// URI.""" - parsed = urllib.parse.urlparse(uri) - info["product"] = urllib.parse.unquote(parsed.path.lstrip("/")) - qs = urllib.parse.parse_qs(parsed.query) - if "serial" in qs: - info["serial"] = qs["serial"][0] - - -def get_printer_info_from_cups() -> dict[str, str]: - """Get printer model/serial from lpstat.""" - info: dict[str, str] = {"product": "", "serial": ""} - try: - r = subprocess.run( - ["/usr/bin/lpstat", "-v"], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - for line in r.stdout.splitlines(): - if "Brother" in line: - for part in line.split(): - if part.startswith("usb://"): - _parse_cups_usb_uri(part, info) - break - except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): - logger.debug("Failed to query CUPS for printer info", exc_info=True) - return info - - -# ── PJL over USB ───────────────────────────────────────────────────── - - -def _drain_buffer(fd: int) -> None: - """Read and discard any stale data from the USB buffer.""" - flags = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) - with contextlib.suppress(OSError): - while os.read(fd, 4096): - pass - fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) - - -def _read_nonblocking(fd: int, flags: int) -> bytes: - """Read all available data from fd in non-blocking mode.""" - data = b"" - fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) - with contextlib.suppress(OSError): - while True: - chunk = os.read(fd, 4096) - if not chunk: - break - data += chunk - fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) - return data - - -def _wait_for_pjl_response(fd: int, flags: int, deadline: float) -> bytes: - """Poll fd until PJL data arrives or deadline expires.""" - response = b"" - while time.time() < deadline: - remaining = deadline - time.time() - if remaining <= 0: - break - readable, _, _ = select.select([fd], [], [], min(remaining, 1.0)) - if readable: - response += _read_nonblocking(fd, flags) - if response and (b"=" in response or b"@PJL" in response): - break - return response - - -def pjl_query(fd: int, cmd: str, timeout_sec: float = 5.0) -> str: - """Send a PJL command via raw fd and read the response.""" - flags = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) - - pjl_cmd = f"\x1b%-12345X@PJL\r\n{cmd}\r\n\x1b%-12345X" - os.write(fd, pjl_cmd.encode()) - - deadline = time.time() + timeout_sec - response = _wait_for_pjl_response(fd, flags, deadline) - - fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) - return response.decode("ascii", errors="replace") - - -def _parse_status(resp: str, result: USBResult) -> bool: - """Parse STATUS response into result. Returns True if code was found.""" - found = False - for raw_line in resp.splitlines(): - stripped = raw_line.strip() - if stripped.startswith("CODE="): - result.status_code = stripped.split("=", 1)[1] - found = True - elif stripped.startswith("DISPLAY="): - result.display = stripped.split("=", 1)[1].strip().strip('"').strip() - elif stripped.startswith("ONLINE="): - result.online = stripped.split("=", 1)[1] - return found - - -def _parse_variables(resp: str, result: USBResult) -> bool: - """Parse VARIABLES response into result. Returns True if data found.""" - found = False - for raw_line in resp.splitlines(): - stripped = raw_line.strip() - if stripped.startswith("ECONOMODE="): - result.economode = stripped.split("=", 1)[1].split()[0] - found = True - return found - - -def _retry_pjl_query( - fd: int, - cmd: str, - parser: Callable[[str, USBResult], bool], - result: USBResult, - max_retries: int, -) -> None: - """Send a PJL query with retries, draining between attempts.""" - for attempt in range(max_retries + 1): - resp = pjl_query(fd, cmd) - if parser(resp, result): - break - if attempt < max_retries: - _drain_buffer(fd) - time.sleep(0.5) - - -def _run_pjl_queries(fd: int, result: USBResult, max_retries: int) -> None: - """Execute PJL query sequence on an open file descriptor.""" - _drain_buffer(fd) - - os.write(fd, b"\x1b%-12345X@PJL\r\n\x1b%-12345X") - time.sleep(0.5) - _drain_buffer(fd) - - _retry_pjl_query(fd, "@PJL INFO STATUS", _parse_status, result, max_retries) - _drain_buffer(fd) - time.sleep(0.5) - _retry_pjl_query(fd, "@PJL INFO VARIABLES", _parse_variables, result, max_retries) - - -def _init_usb_result(dev_path: str) -> USBResult: - """Create a USBResult with device info from CUPS.""" - cups_info = get_printer_info_from_cups() - return USBResult( - device=dev_path, - product=cups_info.get("product") or "Brother Laser Printer", - serial=cups_info.get("serial", ""), - ) - - -def query_usb_pjl(max_retries: int = 2) -> USBResult: - """Query a Brother printer via PJL over /dev/usb/lp*.""" - dev_path = find_usb_printer_dev() - if not dev_path: - return _query_usb_via_cups() - - result = _init_usb_result(dev_path) - if not os.access(dev_path, os.R_OK | os.W_OK): - result.error = f"Permission denied: {dev_path}. Run with sudo." - return result - - fd: int | None = None - try: - fd = os.open(dev_path, os.O_RDWR) - fcntl.fcntl(fd, fcntl.F_GETFL) - _run_pjl_queries(fd, result, max_retries) - except OSError as e: - result.error = str(e) - finally: - if fd is not None: - os.close(fd) - return result - - -# ── CUPS-based USB fallback ────────────────────────────────────────── -# When the usblp kernel module is not available, /dev/usb/lp* devices -# don't exist even though CUPS can print fine via its own libusb backend. -# These functions query printer status through CUPS IPP instead. - -_CUPS_REASONS_TO_STATUS: dict[str, int] = { - "paused": 10023, - "moving-to-paused": 10023, - "toner-low": 30010, - "toner-empty": 40310, - "marker-supply-low": 30010, - "marker-supply-empty": 40310, - "media-empty": 40302, - "media-needed": 40302, - "media-jam": 40000, - "cover-open": 41000, - "door-open": 41000, - "input-tray-missing": 40300, -} - -_CUPS_STATE_TO_STATUS: dict[str, int] = { - "idle": 10001, - "processing": 10007, - "stopped": 10023, -} - - -BROTHER_USB_VENDOR_ID = 0x04F9 - - -def _get_pyusb_device_info() -> dict[str, str]: - """Get Brother USB printer info via pyusb (no interface claim needed).""" - try: - import usb.core - - dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID) - if dev is None: - return {} - except (ImportError, OSError, ValueError): - return {} - else: - return { - "product": dev.product or "", - "serial": dev.serial_number or "", - } - - -def _stop_cups() -> bool: - """Stop CUPS service and sockets. Returns True on success.""" - systemctl = shutil.which("systemctl") - if not systemctl: - return False - try: - subprocess.run( - [systemctl, "stop", "cups.service", "cups.socket", "cups.path"], - timeout=15, - check=True, - ) - time.sleep(2) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError): - return False - return True - - -def _is_cups_scheduler_running() -> bool: - """Check if the CUPS scheduler is currently running.""" - lpstat = shutil.which("lpstat") - if not lpstat: - return False - try: - r = subprocess.run( - [lpstat, "-r"], - capture_output=True, - text=True, - timeout=3, - check=False, - ) - return ( - "is running" in r.stdout.lower() and "not running" not in r.stdout.lower() - ) - except (subprocess.TimeoutExpired, OSError): - return False - - -def _start_cups() -> bool: - """Start CUPS service, socket, and path units. Returns True on success.""" - systemctl = shutil.which("systemctl") - if not systemctl: - return False - try: - subprocess.run( - [systemctl, "start", "cups.service", "cups.socket", "cups.path"], - timeout=15, - check=True, - ) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError): - return False - # Verify CUPS is actually responding - for _ in range(10): - if _is_cups_scheduler_running(): - return True - time.sleep(1) - return False - - -def _query_usb_port_status_raw() -> USBPortStatus | None: - """Query USB printer port status via pyusb control transfer. - - Requires root and temporarily stops CUPS to access the USB device. - Returns None if the query fails. - """ - try: - import usb.core - import usb.util - except ImportError: - return None - - dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID) - if dev is None: - return None - - if not _stop_cups(): - return None - - try: - dev.reset() - time.sleep(2) - dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID) - if dev is None: - return None - - try: - if dev.is_kernel_driver_active(0): - dev.detach_kernel_driver(0) - except (usb.core.USBError, NotImplementedError): - pass - - usb.util.claim_interface(dev, 0) - try: - # USB Printer Class GET_PORT_STATUS (bRequest=0x01) - raw = dev.ctrl_transfer(0xA1, 0x01, 0, 0, 1, timeout=5000) - port_byte = raw[0] - return USBPortStatus( - paper_empty=bool(port_byte & 0x20), - online=bool(port_byte & 0x10), - error=not bool(port_byte & 0x08), - raw_byte=port_byte, - ) - finally: - usb.util.release_interface(dev, 0) - usb.util.dispose_resources(dev) - except (OSError, ValueError): - logger.debug("USB port status query failed", exc_info=True) - return None - finally: - _start_cups() - - -def _get_cups_total_pages() -> int: - """Parse CUPS page_log to get total pages printed (deduplicated by job).""" - if not CUPS_PAGE_LOG.exists(): - return 0 - try: - text = CUPS_PAGE_LOG.read_text(encoding="utf-8", errors="replace") - except OSError: - return 0 - # page_log format: printer user job_id [date] total N ... - # Deduplicate by job_id (retries produce repeated lines) - jobs: dict[str, int] = {} - for line in text.splitlines(): - match = re.search(r"\s(\d+)\s+\[.*?\]\s+total\s+(\d+)", line) - if match: - job_id = match.group(1) - pages = int(match.group(2)) - jobs[job_id] = max(jobs.get(job_id, 0), pages) - return sum(jobs.values()) - - -def _load_consumable_state() -> dict[str, int]: - """Load consumable replacement state from disk. - - Returns dict with keys 'toner_replaced_at' and 'drum_replaced_at' - (page counts when each consumable was last replaced). - """ - defaults: dict[str, int] = {"toner_replaced_at": 0, "drum_replaced_at": 0} - if not CONSUMABLE_STATE_FILE.exists(): - return defaults - try: - data = json.loads( - CONSUMABLE_STATE_FILE.read_text(encoding="utf-8"), - ) - return { - "toner_replaced_at": int(data.get("toner_replaced_at", 0)), - "drum_replaced_at": int(data.get("drum_replaced_at", 0)), - } - except (OSError, json.JSONDecodeError, ValueError, TypeError): - return defaults - - -def _save_consumable_state(state: dict[str, int]) -> None: - """Persist consumable replacement state to disk.""" - CONSUMABLE_STATE_FILE.parent.mkdir(parents=True, exist_ok=True) - CONSUMABLE_STATE_FILE.write_text( - json.dumps(state, indent=2) + "\n", - encoding="utf-8", - ) - - -def _reset_consumable(name: str) -> None: - """Record current page count as replacement point for a consumable.""" - total = _get_cups_total_pages() - state = _load_consumable_state() - key = f"{name}_replaced_at" - state[key] = total - _save_consumable_state(state) - _out( - f"{GREEN}✓ {name.capitalize()} counter reset at page count" f" {total}.{RESET}" - ) - _out(f" State saved to {CONSUMABLE_STATE_FILE}") - - -def _estimate_consumable_life() -> PageCountEstimate: - """Estimate toner/drum life from CUPS page count since last replacement.""" - total = _get_cups_total_pages() - if total <= 0: - return PageCountEstimate() - state = _load_consumable_state() - toner_pages = max(0, total - state["toner_replaced_at"]) - drum_pages = max(0, total - state["drum_replaced_at"]) - toner_pct = max(0, 100 - (toner_pages * 100 // TONER_RATED_PAGES)) - drum_pct = max(0, 100 - (drum_pages * 100 // DRUM_RATED_PAGES)) - return PageCountEstimate( - total_pages=total, - toner_pages=toner_pages, - drum_pages=drum_pages, - toner_pct_remaining=toner_pct, - drum_pct_remaining=drum_pct, - toner_exhausted=toner_pages >= TONER_RATED_PAGES, - toner_low=toner_pages >= TONER_RATED_PAGES * 80 // 100, - drum_near_end=drum_pages >= DRUM_RATED_PAGES * 90 // 100, - ) - - -def _parse_ipp_attributes(output: str) -> dict[str, str]: - """Parse ipptool verbose output into an attribute dict.""" - attrs: dict[str, str] = {} - for line in output.splitlines(): - match = re.match(r"\s+(\S+)\s+\([^)]+\)\s+=\s+(.*)", line) - if match: - attrs[match.group(1)] = match.group(2).strip() - return attrs - - -def _get_cups_ipp_status(printer_name: str) -> dict[str, str]: - """Query printer attributes via CUPS IPP using ipptool.""" - ipptool_path = shutil.which("ipptool") - if not ipptool_path: - return {} - uri = f"ipp://localhost/printers/{printer_name}" - try: - r = subprocess.run( - [ipptool_path, "-tv", uri, "get-printer-attributes.test"], - capture_output=True, - text=True, - timeout=10, - check=False, - ) - return _parse_ipp_attributes(r.stdout) - except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): - return {} - - -def _get_cups_economode(printer_name: str) -> str: - """Query toner save mode setting via lpoptions.""" - lpoptions_path = shutil.which("lpoptions") - if not lpoptions_path: - return "" - try: - r = subprocess.run( - [lpoptions_path, "-p", printer_name, "-l"], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - for line in r.stdout.splitlines(): - if "conomode" in line.lower(): - match = re.search(r"\*(\w+)", line) - if match: - return "ON" if match.group(1).lower() == "true" else "OFF" - except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): - pass - return "" - - -def _map_cups_to_status_code(state: str, reasons: str) -> str: - """Map CUPS state + reasons to a Brother PJL status code string.""" - for keyword, code in _CUPS_REASONS_TO_STATUS.items(): - if keyword in reasons.lower(): - return str(code) - clean_state = re.sub(r"\(.*\)", "", state).strip().lower() - return str(_CUPS_STATE_TO_STATUS.get(clean_state, 10001)) - - -_ERROR_REASON_MAP: tuple[tuple[tuple[str, ...], str, str], ...] = ( - (("media-jam",), "40000", "Paper Jam"), - (("cover-open", "door-open"), "41000", "Cover Open"), - (("toner-empty",), "40310", "Toner End"), - (("toner-low",), "30010", "Toner Low"), -) - - -def _cups_reasons_to_error(cups_reasons: str) -> tuple[str, str]: - """Map CUPS reason keywords to a (status_code, display) pair.""" - reasons_lower = cups_reasons.lower() - for keywords, code, display in _ERROR_REASON_MAP: - if any(kw in reasons_lower for kw in keywords): - return code, display - return "42000", "Printer Error" - - -def _port_status_to_status_code( - ps: USBPortStatus, - cups_reasons: str, -) -> tuple[str, str]: - """Map USB port status + CUPS reasons to (status_code, display).""" - # Hardware error flags take priority - if ps.error and ps.paper_empty: - return "40302", "No Paper" - if ps.error and not ps.online: - return "41000", "Cover Open" - if ps.error: - return _cups_reasons_to_error(cups_reasons) - if ps.paper_empty: - return "40302", "No Paper" - if not ps.online: - return "10002", "Offline / Sleep" - return "", "" - - -def _ensure_cups_running() -> bool: - """Make sure CUPS is running, starting it if necessary.""" - if _is_cups_scheduler_running(): - return True - return _start_cups() - - -def _query_usb_via_cups() -> USBResult: - """Query USB printer status through CUPS when /dev/usb/lp* is unavailable.""" - _ensure_cups_running() - printer_name = _find_cups_printer_name() - if not printer_name: - return USBResult( - error="No USB printer device at /dev/usb/lp*" - " (usblp module not available)" - " and no Brother printer found in CUPS.", - ) - - pyusb_info = _get_pyusb_device_info() - cups_info = get_printer_info_from_cups() - - result = USBResult( - device="cups", - product=( - pyusb_info.get("product") - or cups_info.get("product") - or "Brother Laser Printer" - ), - serial=pyusb_info.get("serial") or cups_info.get("serial", ""), - ) - - ipp = _get_cups_ipp_status(printer_name) - state = ipp.get("printer-state", "") - reasons = ipp.get("printer-state-reasons", "none") - result.economode = _get_cups_economode(printer_name) - - # Direct USB hardware status query - port_status = _query_usb_port_status_raw() - if port_status is not None: - result.port_status = port_status - hw_code, hw_display = _port_status_to_status_code( - port_status, - reasons, - ) - if hw_code: - result.status_code = hw_code - result.display = hw_display - result.online = "TRUE" if port_status.online else "FALSE" - return result - # Hardware says OK — check page count for toner/drum warnings - estimate = _estimate_consumable_life() - if estimate.toner_exhausted: - result.status_code = "40310" - result.display = "Toner End (estimated from page count)" - result.online = "TRUE" - return result - if estimate.toner_low: - result.status_code = "30010" - result.display = "Toner Low (estimated from page count)" - result.online = "TRUE" - return result - result.status_code = _map_cups_to_status_code(state, reasons) - result.display = ipp.get("printer-state-message", "") - result.online = "TRUE" - return result - - # pyusb unavailable: CUPS-only fallback - result.status_code = _map_cups_to_status_code(state, reasons) - result.display = ipp.get("printer-state-message", "") - result.online = "TRUE" if state.lower() in {"idle", "processing"} else "FALSE" - - return result - - -# ── SNMP network query ────────────────────────────────────────────── - - -def _snmpwalk_cmd( - path: str, community: str, timeout: int, ip: str, oid: str -) -> list[str]: - """Build the snmpwalk command arguments.""" - return [path, "-v", "2c", "-c", community, "-t", str(timeout), "-OQvs", ip, oid] - - -def snmp_walk(ip: str, oid: str, community: str, timeout: int) -> list[str]: - """Run snmpwalk and return cleaned values.""" - snmpwalk_path = shutil.which("snmpwalk") - if not snmpwalk_path: - return [] - try: - r = subprocess.run( - _snmpwalk_cmd(snmpwalk_path, community, timeout, ip, oid), - capture_output=True, - text=True, - timeout=15, - check=False, - ) - return [ - line.strip().strip('"') - for line in r.stdout.strip().splitlines() - if line.strip() - ] - except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): - return [] - - -def _snmpget_cmd( - path: str, community: str, timeout: int, ip: str, oid: str -) -> list[str]: - """Build the snmpget command arguments.""" - return [path, "-v", "2c", "-c", community, "-t", str(timeout), ip, oid] - - -def _check_snmp_connectivity(ip: str, community: str, timeout: int) -> str | None: - """Verify SNMP connectivity. Returns error message or None on success.""" - snmpget_path = shutil.which("snmpget") - if not snmpget_path: - return "snmpget not found. Install: sudo pacman -S net-snmp" - try: - subprocess.run( - _snmpget_cmd( - snmpget_path, - community, - timeout, - ip, - "1.3.6.1.2.1.43.11.1.1.6.1.1", - ), - capture_output=True, - timeout=10, - check=True, - ) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError): - return f"Cannot reach printer at {ip} via SNMP." - return None - - -def _build_network_result(ip: str, community: str, timeout: int) -> NetworkResult: - """Collect all SNMP data into a NetworkResult.""" - - def walk(oid: str) -> list[str]: - return snmp_walk(ip, oid, community, timeout) - - return NetworkResult( - ip=ip, - product=" ".join(walk("1.3.6.1.2.1.25.3.2.1.3")[:1]) or "Unknown", - serial=" ".join(walk("1.3.6.1.2.1.43.5.1.1.17")[:1]) or "", - printer_status=" ".join(walk("1.3.6.1.2.1.25.3.5.1.1")[:1]) or "", - device_status=" ".join(walk("1.3.6.1.2.1.25.3.2.1.5")[:1]) or "", - display=" ".join(walk("1.3.6.1.2.1.43.16.5.1.2")[:3]) or "", - page_count=" ".join(walk("1.3.6.1.2.1.43.10.2.1.4")[:1]) or "", - supply_descriptions=walk("1.3.6.1.2.1.43.11.1.1.6"), - supply_max=walk("1.3.6.1.2.1.43.11.1.1.8"), - supply_levels=walk("1.3.6.1.2.1.43.11.1.1.9"), - ) - - -def query_network_snmp(ip: str) -> NetworkResult: - """Query a Brother printer via SNMP over the network.""" - community = "public" - timeout = 5 - error = _check_snmp_connectivity(ip, community, timeout) - if error: - return NetworkResult(ip=ip, error=error) - return _build_network_result(ip, community, timeout) - - -# ── CUPS queue inspection ──────────────────────────────────────────── - - -def _find_cups_printer_name() -> str: - """Find the CUPS queue name for a Brother printer.""" - lpstat_path = shutil.which("lpstat") - if not lpstat_path: - return "" - try: - r = subprocess.run( - [lpstat_path, "-v"], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - for line in r.stdout.splitlines(): - if "brother" in line.lower(): - # e.g. device for Brother_HL-1110_series: usb://... - match = re.match(r"device for (\S+):", line) - if match: - return match.group(1) - except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): - pass - return "" - - -def _parse_lpstat_printer_line(line: str) -> tuple[bool, str]: - """Parse an lpstat -p line. Returns (enabled, reason).""" - enabled = "disabled" not in line.lower() - reason = "" - # Reason follows the dash after the date - match = re.search(r"\d{4}\s+-\s*(.+)", line) - if match: - reason = match.group(1).strip() - return enabled, reason - - -def _parse_lpstat_jobs(output: str, printer_name: str) -> list[CUPSJob]: - """Parse lpstat -o output into CUPSJob list.""" - jobs: list[CUPSJob] = [] - for line in output.splitlines(): - if not line.startswith(printer_name): - continue - parts = line.split() - if len(parts) >= MIN_LPSTAT_JOB_PARTS: - job_id = parts[0] - user = parts[1] - size = parts[2] - date = " ".join(parts[3:]) - jobs.append(CUPSJob(job_id=job_id, user=user, size=size, date=date)) - return jobs - - -def get_cups_queue_status() -> CUPSQueueStatus: - """Check if the CUPS queue is disabled and list pending jobs.""" - printer_name = _find_cups_printer_name() - if not printer_name: - return CUPSQueueStatus() - - result = CUPSQueueStatus(printer_name=printer_name) - lpstat_path = shutil.which("lpstat") - if not lpstat_path: - return result - - # Check printer enabled/disabled state - try: - r = subprocess.run( - [lpstat_path, "-p", printer_name], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - for line in r.stdout.splitlines(): - if "printer" in line.lower() and printer_name in line: - result.enabled, result.reason = _parse_lpstat_printer_line(line) - break - except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): - pass - - # List pending jobs - try: - r = subprocess.run( - [lpstat_path, "-o", printer_name], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - result.jobs = _parse_lpstat_jobs(r.stdout, printer_name) - except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): - pass - - # Check for stale backend errors - has_errors, last_error = _check_cups_backend_errors(printer_name) - result.has_backend_errors = has_errors - result.last_backend_error = last_error - - return result - - -def _cups_enable_printer(printer_name: str) -> bool: - """Re-enable a disabled CUPS printer. Returns True on success.""" - cupsenable_path = shutil.which("cupsenable") - if not cupsenable_path: - _out(f" {RED}cupsenable not found.{RESET}") - return False - try: - subprocess.run( - [cupsenable_path, printer_name], - timeout=5, - check=True, - ) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError) as e: - _out(f" {RED}Failed to enable printer: {e}{RESET}") - return False - else: - return True - - -def _cups_cancel_all_jobs(printer_name: str) -> bool: - """Cancel all pending jobs. Returns True on success.""" - cancel_path = shutil.which("cancel") - if not cancel_path: - _out(f" {RED}cancel command not found.{RESET}") - return False - try: - subprocess.run( - [cancel_path, "-a", printer_name], - timeout=5, - check=True, - ) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError) as e: - _out(f" {RED}Failed to cancel jobs: {e}{RESET}") - return False - else: - return True - - -def _cups_cancel_job(job_id: str) -> bool: - """Cancel a specific job. Returns True on success.""" - cancel_path = shutil.which("cancel") - if not cancel_path: - return False - try: - subprocess.run( - [cancel_path, job_id], - timeout=5, - check=True, - ) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError): - return False - else: - return True - - -def _cups_restart_service() -> bool: - """Restart the CUPS service. Returns True on success.""" - systemctl_path = shutil.which("systemctl") - if not systemctl_path: - _out(f" {RED}systemctl not found.{RESET}") - return False - sys.stdout.write(f" {DIM}Restarting CUPS...{RESET}") - sys.stdout.flush() - try: - proc = subprocess.Popen( - [systemctl_path, "restart", "cups"], - ) - deadline = time.time() + 30 - while proc.poll() is None: - if time.time() > deadline: - proc.kill() - proc.wait() - sys.stdout.write("\n") - _out( - f" {RED}CUPS restart timed out" - f" (stuck backend process?).{RESET}" - ) - _out( - f" {DIM}Try: sudo kill -9 $(pgrep -f 'cups/backend/usb')" - f" && sudo systemctl restart cups{RESET}" - ) - return False - sys.stdout.write(".") - sys.stdout.flush() - time.sleep(1) - sys.stdout.write("\n") - if proc.returncode != 0: - _out( - f" {RED}CUPS restart failed" f" (exit code {proc.returncode}).{RESET}" - ) - return False - except OSError as e: - sys.stdout.write("\n") - _out(f" {RED}Failed to restart CUPS: {e}{RESET}") - return False - time.sleep(2) # wait for CUPS to come back up - return True - - -def _is_cups_printer_healthy(printer_name: str) -> bool: - """Check live CUPS state via lpstat. Returns True if enabled with no issues.""" - lpstat_path = shutil.which("lpstat") - if not lpstat_path: - return False - try: - r = subprocess.run( - [lpstat_path, "-p", printer_name], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - for line in r.stdout.splitlines(): - if ( - printer_name in line - and "idle" in line.lower() - and "enabled" in line.lower() - ): - return True - except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): - pass - return False - - -def _find_backend_error_in_log( - lines: list[str], -) -> tuple[str, str, str]: - """Scan CUPS log lines (reversed) for backend errors. - - Returns: - (backend_error, error_timestamp, last_success_timestamp) - """ - backend_error = "" - error_timestamp = "" - last_success_timestamp = "" - - for line in reversed(lines): - if ( - "backend errors" in line or "stopped with status" in line - ) and not backend_error: - backend_error = line.strip() - ts_match = re.search(r"\[([^\]]+)\]", line) - if ts_match: - error_timestamp = ts_match.group(1) - # Check if a job completed successfully after the error - if ("Completed" in line or "total" in line) and error_timestamp: - ts_match = re.search(r"\[([^\]]+)\]", line) - if ts_match: - last_success_timestamp = ts_match.group(1) - break - - return backend_error, error_timestamp, last_success_timestamp - - -def _check_cups_backend_errors( - printer_name: str, -) -> tuple[bool, str]: - """Check CUPS error log for backend errors. Returns (has_errors, last_error).""" - # If the printer is currently healthy, ignore stale log entries. - if _is_cups_printer_healthy(printer_name): - return False, "" - - log_path = Path("/var/log/cups/error_log") - if not log_path.exists(): - return False, "" - try: - lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines() - except OSError: - return False, "" - - backend_error, error_timestamp, last_success_timestamp = _find_backend_error_in_log( - lines - ) - - if not backend_error: - return False, "" - - # If there's been a successful print after the error, backend is fine - if last_success_timestamp and last_success_timestamp > error_timestamp: - return False, "" - - return True, backend_error - - -def _display_cups_queue_status(queue: CUPSQueueStatus) -> None: - """Display CUPS queue status and offer interactive fixes.""" - if not queue.printer_name: - return - if queue.enabled and not queue.jobs and not queue.has_backend_errors: - return - - _out() - _out(f"{BOLD}── Print Queue ──{RESET}") - _out() - - if queue.has_backend_errors and queue.enabled and not queue.jobs: - _out(f" {YELLOW}{BOLD}⚡ CUPS backend has stale errors{RESET}") - _out( - f" {DIM}New print jobs may silently fail." - f" A CUPS restart usually fixes this.{RESET}" - ) - _out() - - if not queue.enabled: - _out(f" {RED}{BOLD}⚠ Printer queue is DISABLED{RESET}") - if queue.reason: - _out(f" {DIM}Reason: {queue.reason}{RESET}") - _out() - - if queue.jobs: - _out(f" {BOLD}Pending jobs ({len(queue.jobs)}):{RESET}") - for job in queue.jobs: - _out(f" {job.job_id} {DIM}{job.user} {job.size}B {job.date}{RESET}") - _out() - - _offer_queue_fix(queue) - - -def _offer_queue_fix(queue: CUPSQueueStatus) -> None: - """Prompt the user to fix a disabled queue / pending jobs.""" - _out(f" {BOLD}Available actions:{RESET}") - - options: list[str] = [] - if not queue.enabled and queue.jobs: - _out(f" {CYAN}1){RESET} Re-enable printer and retry all jobs") - _out(f" {CYAN}2){RESET} Re-enable printer and cancel all jobs") - _out(f" {CYAN}3){RESET} Cancel all jobs (keep printer disabled)") - _out(f" {CYAN}4){RESET} Restart CUPS service (fixes stale backend)") - _out(f" {CYAN}5){RESET} Restart CUPS + re-enable + retry all jobs") - _out(f" {CYAN}6){RESET} Do nothing") - options = ["1", "2", "3", "4", "5", "6"] - elif not queue.enabled: - _out(f" {CYAN}1){RESET} Re-enable printer") - _out(f" {CYAN}2){RESET} Restart CUPS service (fixes stale backend)") - _out(f" {CYAN}3){RESET} Do nothing") - options = ["1", "2", "3"] - elif queue.jobs: - _out(f" {CYAN}1){RESET} Cancel all pending jobs") - _out(f" {CYAN}2){RESET} Restart CUPS service (fixes stale backend)") - _out(f" {CYAN}3){RESET} Do nothing") - options = ["1", "2", "3"] - else: - # Backend errors only, printer enabled, no jobs - _out(f" {CYAN}1){RESET} Restart CUPS service (fixes stale backend)") - _out(f" {CYAN}2){RESET} Do nothing") - options = ["1", "2"] - - _out() - choice = _prompt(f" Choose [{'/'.join(options)}]: ") - _out() - - if not queue.enabled and queue.jobs: - _handle_disabled_with_jobs(queue, choice) - elif not queue.enabled: - _handle_disabled_no_jobs(queue, choice) - elif queue.jobs: - _handle_enabled_with_jobs(queue, choice) - else: - _handle_backend_errors_only(choice) - - -def _dwj_enable_only(printer_name: str) -> None: - """Choice 1: re-enable printer so queued jobs are retried.""" - if _cups_enable_printer(printer_name): - _out(f" {GREEN}✓ Printer re-enabled. Jobs will be retried.{RESET}") - - -def _dwj_cancel_and_enable(printer_name: str) -> None: - """Choice 2: cancel all jobs then re-enable.""" - _cups_cancel_all_jobs(printer_name) - if _cups_enable_printer(printer_name): - _out(f" {GREEN}✓ All jobs cancelled and printer re-enabled.{RESET}") - - -def _dwj_cancel_only(printer_name: str) -> None: - """Choice 3: cancel all jobs.""" - if _cups_cancel_all_jobs(printer_name): - _out(f" {GREEN}✓ All jobs cancelled.{RESET}") - - -def _dwj_restart_only(_printer_name: str) -> None: - """Choice 4: restart CUPS.""" - if _cups_restart_service(): - _out(f" {GREEN}✓ CUPS restarted.{RESET}") - - -def _dwj_restart_and_enable(printer_name: str) -> None: - """Choice 5: restart CUPS and re-enable printer.""" - if _cups_restart_service(): - _cups_enable_printer(printer_name) - _out( - f" {GREEN}✓ CUPS restarted, printer re-enabled." - f" Jobs will be retried.{RESET}" - ) - - -_DWJ_ACTIONS: dict[str, Callable[[str], None]] = { - "1": _dwj_enable_only, - "2": _dwj_cancel_and_enable, - "3": _dwj_cancel_only, - "4": _dwj_restart_only, - "5": _dwj_restart_and_enable, -} - - -def _handle_disabled_with_jobs(queue: CUPSQueueStatus, choice: str) -> None: - """Handle fix for disabled printer with pending jobs.""" - action = _DWJ_ACTIONS.get(choice) - if action is not None: - action(queue.printer_name) - else: - _out(f" {DIM}No changes made.{RESET}") - - -def _handle_disabled_no_jobs(queue: CUPSQueueStatus, choice: str) -> None: - """Handle fix for disabled printer with no pending jobs.""" - if choice == "1": - if _cups_enable_printer(queue.printer_name): - _out(f" {GREEN}✓ Printer re-enabled.{RESET}") - elif choice == "2": - if _cups_restart_service(): - _cups_enable_printer(queue.printer_name) - _out(f" {GREEN}✓ CUPS restarted and printer re-enabled.{RESET}") - else: - _out(f" {DIM}No changes made.{RESET}") - - -def _handle_enabled_with_jobs(queue: CUPSQueueStatus, choice: str) -> None: - """Handle fix for enabled printer with stuck jobs.""" - if choice == "1": - if _cups_cancel_all_jobs(queue.printer_name): - _out(f" {GREEN}✓ All jobs cancelled.{RESET}") - elif choice == "2": - if _cups_restart_service(): - _out(f" {GREEN}✓ CUPS restarted.{RESET}") - else: - _out(f" {DIM}No changes made.{RESET}") - - -def _handle_backend_errors_only(choice: str) -> None: - """Handle fix when only stale backend errors are detected.""" - if choice == "1": - if _cups_restart_service(): - _out(f" {GREEN}✓ CUPS restarted. Stale backend errors cleared.{RESET}") - else: - _out(f" {DIM}No changes made.{RESET}") - - -# ── Status code lookup ────────────────────────────────────────────── - - -def get_status_info(code: str) -> tuple[str, str, str]: - """Look up a PJL status code. Returns (severity, text, action).""" - try: - return BROTHER_STATUS_CODES[int(code)] - except (KeyError, ValueError): - return ( - "info", - f"Unknown status (code {code})", - "Check printer display for details.", - ) - - -# ── Display: shared helpers ───────────────────────────────────────── - - -def _display_report_header() -> None: - """Print the report banner box.""" - _out() - _out(f"{BOLD}╔══════════════════════════════════════════════════╗{RESET}") - _out(f"{BOLD}║ Brother Laser Printer Status Report ║{RESET}") - _out(f"{BOLD}╚══════════════════════════════════════════════════╝{RESET}") - _out() - - -def _display_page_count_estimate() -> None: - """Show estimated consumable life based on CUPS page count.""" - estimate = _estimate_consumable_life() - if estimate.total_pages <= 0: - return - _out(f"{BOLD}── Page Count Estimate ──{RESET}") - _out() - _out( - f" {BOLD}Total pages printed:{RESET} {estimate.total_pages}" - f" (toner: {estimate.toner_pages} since replacement," - f" drum: {estimate.drum_pages} since replacement)" - ) - _out() - # Toner bar - toner_pct = estimate.toner_pct_remaining - toner_filled = toner_pct * PROGRESS_BAR_WIDTH // 100 - toner_empty = PROGRESS_BAR_WIDTH - toner_filled - toner_bar = f"[{'█' * toner_filled}{'░' * toner_empty}]" - if estimate.toner_exhausted: - toner_color = RED - toner_note = " ← REPLACE NOW" - elif estimate.toner_low: - toner_color = YELLOW - toner_note = " ← order soon" - else: - toner_color = GREEN - toner_note = "" - _out( - f" {BOLD}Toner:{RESET} {toner_color}{toner_bar} ~{toner_pct}%" - f"{toner_note}{RESET}" - ) - # Drum bar - drum_pct = estimate.drum_pct_remaining - drum_filled = drum_pct * PROGRESS_BAR_WIDTH // 100 - drum_empty = PROGRESS_BAR_WIDTH - drum_filled - drum_bar = f"[{'█' * drum_filled}{'░' * drum_empty}]" - if estimate.drum_near_end: - drum_color = YELLOW - drum_note = " ← nearing end" - else: - drum_color = GREEN - drum_note = "" - _out( - f" {BOLD}Drum:{RESET} {drum_color}{drum_bar} ~{drum_pct}%" - f"{drum_note}{RESET}" - ) - _out( - f" {DIM}Based on pages since last replacement" - f" vs rated capacity (toner ~{TONER_RATED_PAGES}," - f" drum ~{DRUM_RATED_PAGES}).{RESET}" - ) - _out(f" {DIM}Reset after replacing: --reset-toner" f" or --reset-drum{RESET}") - if estimate.toner_exhausted: - _out() - _out( - f" {RED}{BOLD}⚠ Toner is likely exhausted." - f" This is probably why the orange light is flashing.{RESET}" - ) - _out() - - -def _display_consumables_reference() -> None: - """Print compatible consumables reference.""" - _out(f"{BOLD}── Compatible Consumables ──{RESET}") - _out() - _out(f" {BOLD}Toner:{RESET} TN-1050 / TN-1030 (or compatible third-party)") - _out(f" {BOLD}Drum:{RESET} DR-1050 / DR-1030 (or compatible third-party)") - _out(f" {DIM} Toner rated ~1000 pages; Drum rated ~10000 pages.{RESET}") - _out() - - -# ── Display: USB helpers ──────────────────────────────────────────── - - -def _display_usb_device_info(result: USBResult) -> None: - """Print device info block for USB results.""" - _out(f"{BOLD}Printer:{RESET} {result.product or 'Unknown'}") - _out(f"{BOLD}Connection:{RESET} USB") - if result.serial: - _out(f"{BOLD}Serial:{RESET} {result.serial}") - - if result.online == "TRUE": - _out(f"{BOLD}Online:{RESET} {GREEN}Yes{RESET}") - elif result.online == "FALSE": - _out(f"{BOLD}Online:{RESET} {YELLOW}No (needs attention){RESET}") - - _out() - - if result.economode: - if result.economode == "ON": - _out( - f"{BOLD}Toner Save:{RESET} {GREEN}ON{RESET}" - " (extends toner life, lighter prints)" - ) - else: - _out(f"{BOLD}Toner Save:{RESET} OFF") - - -_SEVERITY_ICONS: dict[str, str] = { - "ok": "✓", - "info": "i", - "warn": "⚡", - "critical": "⚠", -} -_SEVERITY_COLORS: dict[str, str] = { - "ok": GREEN, - "info": CYAN, - "warn": YELLOW, - "critical": RED, -} -_SEVERITY_SUMMARIES: dict[str, str] = { - "ok": f"{GREEN}{BOLD}✓ Printer is healthy. No replacements needed.{RESET}", - "info": f"{CYAN}{BOLD}i Printer is busy/processing." - f" No replacements needed.{RESET}", - "warn": f"{YELLOW}{BOLD}⚡ WARNING: Maintenance will be needed" - f" soon.{RESET}\n{YELLOW} Order replacement parts" - f" now to avoid interruption.{RESET}", - "critical": f"{RED}{BOLD}⚠ ACTION REQUIRED: Replacement or fix needed now!{RESET}", -} - - -def _format_status_detail( - severity: str, short_text: str, action: str, result: USBResult -) -> None: - """Print severity icon, display text, and action.""" - color = _SEVERITY_COLORS.get(severity, GREEN) - icon = _SEVERITY_ICONS.get(severity, "✓") - - _out(f" {color}{BOLD}{icon} {short_text}{RESET}") - if result.display and result.display != short_text: - _out(f" {DIM}Display: {result.display}{RESET}") - _out(f" {DIM}Status code: {result.status_code}{RESET}") - - if action: - _out() - _out(f" {color}{BOLD}Action:{RESET} {color}{action}{RESET}") - _out() - _out(_SEVERITY_SUMMARIES.get(severity, "")) - - -def _display_pjl_status(result: USBResult) -> None: - """Display PJL status code interpretation.""" - _out() - _out(f"{BOLD}── Printer Status ──{RESET}") - _out() - - if not result.status_code: - _out(f" {YELLOW}Could not read status from printer.{RESET}") - if result.display: - _out(f" Display message: {BOLD}{result.display}{RESET}") - return - - severity, short_text, action = get_status_info(result.status_code) - _format_status_detail(severity, short_text, action, result) - - -# ── Display: USB results ──────────────────────────────────────────── - - -def _display_cups_fallback_note(result: USBResult) -> None: - """Show a note when running in CUPS fallback mode.""" - _out() - if result.port_status is not None: - _out( - f" {DIM}Note: Hardware status obtained via USB port query." - f" Toner/drum percentages not available.{RESET}" - ) - else: - _out( - f" {DIM}Note: pyusb not available; status obtained via" - f" CUPS only. Detailed toner/drum levels are not" - f" available in this mode.{RESET}" - ) - - -def display_usb_results(result: USBResult) -> None: - """Print a formatted report for USB PJL query results.""" - if result.error: - _out(f"{RED}Error: {result.error}{RESET}") - sys.exit(1) - - _display_report_header() - _display_usb_device_info(result) - _display_pjl_status(result) - - if result.device == "cups": - _display_cups_fallback_note(result) - - _out() - _display_page_count_estimate() - _display_consumables_reference() - - queue = get_cups_queue_status() - _display_cups_queue_status(queue) - - -# ── Display: Network helpers ──────────────────────────────────────── - - -@dataclass -class _SupplyStatus: - """Processed supply level info for display.""" - - color: str - bar: str - status_text: str - warning: str - needs_replacement: bool - - -def _classify_percentage_level(desc: str, pct: int) -> tuple[int, str, str, str, bool]: - """Classify a supply by its calculated percentage.""" - if pct <= SUPPLY_LOW_PCT: - return pct, f"{pct}%", RED, f"{desc} at {pct}%.", True - if pct <= SUPPLY_WARN_PCT: - return pct, f"{pct}%", YELLOW, f"{desc} at {pct}% -- order soon.", False - return pct, f"{pct}%", GREEN, "", False - - -def _classify_supply_level( - desc: str, max_val: int, level: int -) -> tuple[int, str, str, str, bool]: - """Classify a supply level. Returns (pct, status, color, warning, replace).""" - if level == SNMP_LEVEL_OK: - return -1, "OK", GREEN, "", False - if level == SNMP_LEVEL_LOW: - return -1, "LOW", RED, f"{desc} is LOW.", True - if level == 0: - return 0, "EMPTY", RED, f"{desc} is EMPTY -- replace now!", True - if max_val > 0: - pct = min(level * 100 // max_val, 100) - return _classify_percentage_level(desc, pct) - return -1, "", GREEN, "", False - - -def _format_supply_bar(pct: int) -> str: - """Build a progress bar string for a supply percentage.""" - if pct < 0: - return "" - filled = pct * PROGRESS_BAR_WIDTH // 100 - empty = PROGRESS_BAR_WIDTH - filled - return f"[{'█' * filled}{'░' * empty}]" - - -def _process_supply_item(desc: str, max_val: int, level: int) -> _SupplyStatus: - """Process a single supply item into display info.""" - pct, status_text, color, warning, needs_replacement = _classify_supply_level( - desc, max_val, level - ) - bar = _format_supply_bar(pct) - return _SupplyStatus(color, bar, status_text, warning, needs_replacement) - - -def _display_supply_warnings(*, needs_replacement: bool, warnings: list[str]) -> None: - """Display supply level warnings summary.""" - _out() - if needs_replacement: - _out(f"{RED}{BOLD}⚠ ACTION NEEDED:{RESET}") - for w in warnings: - _out(f" {RED}• {w}{RESET}") - elif warnings: - _out(f"{YELLOW}{BOLD}⚡ HEADS UP:{RESET}") - for w in warnings: - _out(f" {YELLOW}• {w}{RESET}") - else: - _out(f"{GREEN}{BOLD}✓ All consumables are at healthy levels.{RESET}") - - -def _parse_supply_value(values: list[str], index: int) -> int: - """Safely parse an integer from a supply value list.""" - try: - return int(values[index]) - except (IndexError, ValueError): - return 0 - - -def _collect_supply_items( - result: NetworkResult, -) -> tuple[list[_SupplyStatus], list[str]]: - """Parse and collect supply items with their descriptions.""" - items: list[_SupplyStatus] = [] - descs: list[str] = [] - for i, desc in enumerate(result.supply_descriptions): - max_val = _parse_supply_value(result.supply_max, i) - level = _parse_supply_value(result.supply_levels, i) - items.append(_process_supply_item(desc, max_val, level)) - descs.append(desc) - return items, descs - - -def _display_supply_levels(result: NetworkResult) -> None: - """Display consumable supply levels section.""" - _out() - _out(f"{BOLD}── Consumable Levels ──{RESET}") - _out() - - needs_replacement = False - warnings: list[str] = [] - items, descs = _collect_supply_items(result) - - for desc, item in zip(descs, items, strict=True): - _out( - f" {BOLD}{desc:<25}{RESET}" - f" {item.color}{item.bar} {item.status_text}{RESET}" - ) - if item.needs_replacement: - needs_replacement = True - if item.warning: - warnings.append(item.warning) - - _display_supply_warnings(needs_replacement=needs_replacement, warnings=warnings) - - -def _display_network_device_info(result: NetworkResult) -> None: - """Display device info section for network results.""" - _out(f"{BOLD}Printer:{RESET} {result.product or 'Unknown'}") - _out(f"{BOLD}Connection:{RESET} Network ({result.ip})") - if result.serial: - _out(f"{BOLD}Serial:{RESET} {result.serial}") - if result.display: - _out(f"{BOLD}Display:{RESET} {result.display}") - if result.page_count and result.page_count.isdigit(): - _out(f"{BOLD}Pages:{RESET} {result.page_count} total") - - -# ── Display: Network results ──────────────────────────────────────── - - -def display_network_results(result: NetworkResult) -> None: - """Print a formatted report for SNMP network query results.""" - if result.error: - _out(f"{RED}Error: {result.error}{RESET}") - sys.exit(1) - - _display_report_header() - _display_network_device_info(result) - _display_supply_levels(result) - - _out() - _out( - f"{CYAN}Tip: Visit http://{result.ip} for the full web management" - f" interface.{RESET}" - ) - _out() - # ── Main ───────────────────────────────────────────────────────────── @@ -1838,15 +104,13 @@ def main(argv: list[str] | None = None) -> None: """Entry point: auto-detect USB or network Brother printer.""" args = argv if argv is not None else sys.argv[1:] - # Handle consumable reset commands if args and args[0] == "--reset-toner": - _reset_consumable("toner") + reset_consumable("toner") return if args and args[0] == "--reset-drum": - _reset_consumable("drum") + reset_consumable("drum") return - # Enforce root — needed for USB hardware queries and CUPS management if os.geteuid() != 0: _out( f"{RED}Root access required. Re-run with sudo:{RESET}\n" diff --git a/python_pkg/brother_printer/constants.py b/python_pkg/brother_printer/constants.py new file mode 100644 index 0000000..90a9a17 --- /dev/null +++ b/python_pkg/brother_printer/constants.py @@ -0,0 +1,173 @@ +"""Constants, status codes, and lookup tables for Brother printer checking.""" + +from __future__ import annotations + +import sys + +# ── Colors ─────────────────────────────────────────────────────────── + +RED = "\033[0;31m" +YELLOW = "\033[1;33m" +GREEN = "\033[0;32m" +CYAN = "\033[0;36m" +BOLD = "\033[1m" +DIM = "\033[2m" +RESET = "\033[0m" + +# ── SNMP supply level sentinel values ──────────────────────────────────────── + +SNMP_LEVEL_OK = -3 +SNMP_LEVEL_LOW = -2 +SUPPLY_LOW_PCT = 10 +SUPPLY_WARN_PCT = 25 +PROGRESS_BAR_WIDTH = 25 + +# Brother HL-1110 consumable page ratings +TONER_RATED_PAGES = 1000 +DRUM_RATED_PAGES = 10000 +CUPS_PAGE_LOG_PATH = "/var/log/cups/page_log" +CONSUMABLE_STATE_DIR = ".config/brother_printer" +MIN_LPSTAT_JOB_PARTS = 4 + +BROTHER_USB_VENDOR_ID = 0x04F9 + + +def _out(text: str = "") -> None: + """Write a line to stdout.""" + sys.stdout.write(text + "\n") + + +def _prompt(text: str) -> str: + """Read user input with a prompt.""" + sys.stdout.write(text) + sys.stdout.flush() + return sys.stdin.readline().strip() + + +# ── Brother PJL status codes ──────────────────────────────────────── +# Documented in Brother PJL Technical Reference. +# Format: code -> (severity, short_text, action) +# Severities: ok, info, warn, critical + +BROTHER_STATUS_CODES: dict[int, tuple[str, str, str]] = { + 10001: ("ok", "Ready", ""), + 10002: ("ok", "Sleep", ""), + 10003: ("info", "Self-test / Calibrating", ""), + 10004: ("ok", "Warming up", ""), + 10005: ("ok", "Cooling down", ""), + 10006: ("info", "Processing", ""), + 10007: ("info", "Printing", ""), + 10014: ("ok", "Cancelling", ""), + 10023: ("info", "Waiting", ""), + # Toner + 30010: ( + "warn", + "Toner Low", + "Order replacement toner cartridge (TN-1050/TN-1030 compatible).", + ), + 30038: ( + "warn", + "Toner Low", + "Order replacement toner cartridge (TN-1050/TN-1030 compatible).", + ), + 40038: ( + "warn", + "Toner Low", + "Order replacement toner cartridge (TN-1050/TN-1030 compatible).", + ), + 40309: ( + "critical", + "Replace Toner", + "The toner cartridge needs immediate replacement" + " (TN-1050/TN-1030 compatible).", + ), + 40310: ( + "critical", + "Toner End", + "The toner cartridge is empty." " Replace now (TN-1050/TN-1030 compatible).", + ), + # Drum + 30201: ( + "warn", + "Drum End Soon", + "The drum unit is nearing end of life." + " Order replacement (DR-1050 compatible).", + ), + 40201: ( + "warn", + "Drum End Soon", + "The drum unit is nearing end of life." + " Order replacement (DR-1050 compatible).", + ), + 40019: ( + "critical", + "Replace Drum", + "The drum unit must be replaced (DR-1050 compatible).", + ), + 40020: ( + "critical", + "Drum Stop", + "The drum unit must be replaced immediately (DR-1050 compatible).", + ), + # Paper / feed + 40000: ("critical", "Paper Jam", "Clear the paper jam and close all covers."), + 40300: ( + "critical", + "No Paper / Tray Open", + "Load paper or close the paper tray.", + ), + 40302: ("critical", "No Paper", "Load paper into the paper tray."), + 40016: ("warn", "Paper Feed Error", "Check paper tray and re-seat paper."), + # Cover + 41000: ("critical", "Cover Open", "Close the top cover of the printer."), + 41001: ("critical", "Cover Open", "Close the front cover of the printer."), + # Others + 35078: ("info", "Manual Feed", "Load paper in the manual feed slot."), + 42000: ( + "critical", + "Machine Error", + "Power-cycle the printer. If error persists, contact service.", + ), +} + +# ── CUPS status code mappings ──────────────────────────────────────── + +_CUPS_REASONS_TO_STATUS: dict[str, int] = { + "paused": 10023, + "moving-to-paused": 10023, + "toner-low": 30010, + "toner-empty": 40310, + "marker-supply-low": 30010, + "marker-supply-empty": 40310, + "media-empty": 40302, + "media-needed": 40302, + "media-jam": 40000, + "cover-open": 41000, + "door-open": 41000, + "input-tray-missing": 40300, +} + +_CUPS_STATE_TO_STATUS: dict[str, int] = { + "idle": 10001, + "processing": 10007, + "stopped": 10023, +} + +_ERROR_REASON_MAP: tuple[tuple[tuple[str, ...], str, str], ...] = ( + (("media-jam",), "40000", "Paper Jam"), + (("cover-open", "door-open"), "41000", "Cover Open"), + (("toner-empty",), "40310", "Toner End"), + (("toner-low",), "30010", "Toner Low"), +) + + +def get_status_info(code: str) -> tuple[str, str, str]: + """Look up a PJL status code. Returns (severity, text, action).""" + try: + return BROTHER_STATUS_CODES[int(code)] + except (KeyError, ValueError): + return ( + "info", + f"Unknown status (code {code})", + "Check printer display for details.", + ) diff --git a/python_pkg/brother_printer/cups_queue.py b/python_pkg/brother_printer/cups_queue.py new file mode 100644 index 0000000..9083387 --- /dev/null +++ b/python_pkg/brother_printer/cups_queue.py @@ -0,0 +1,459 @@ +"""CUPS queue inspection, display, and interactive fix functions.""" + +from __future__ import annotations + +from pathlib import Path +import re +import shutil +import subprocess +import sys +import time +from typing import TYPE_CHECKING + +from python_pkg.brother_printer.constants import ( + BOLD, + CYAN, + DIM, + GREEN, + MIN_LPSTAT_JOB_PARTS, + RED, + RESET, + YELLOW, + _out, + _prompt, +) +from python_pkg.brother_printer.cups_service import find_cups_printer_name +from python_pkg.brother_printer.data_classes import CUPSJob, CUPSQueueStatus + +if TYPE_CHECKING: + from collections.abc import Callable + + +# ── Queue inspection ───────────────────────────────────────────────── + + +def _parse_lpstat_printer_line(line: str) -> tuple[bool, str]: + """Parse an lpstat -p line. Returns (enabled, reason).""" + enabled = "disabled" not in line.lower() + reason = "" + match = re.search(r"\d{4}\s+-\s*(.+)", line) + if match: + reason = match.group(1).strip() + return enabled, reason + + +def _parse_lpstat_jobs(output: str, printer_name: str) -> list[CUPSJob]: + """Parse lpstat -o output into CUPSJob list.""" + jobs: list[CUPSJob] = [] + for line in output.splitlines(): + if not line.startswith(printer_name): + continue + parts = line.split() + if len(parts) >= MIN_LPSTAT_JOB_PARTS: + job_id = parts[0] + user = parts[1] + size = parts[2] + date = " ".join(parts[3:]) + jobs.append(CUPSJob(job_id=job_id, user=user, size=size, date=date)) + return jobs + + +def get_cups_queue_status() -> CUPSQueueStatus: + """Check if the CUPS queue is disabled and list pending jobs.""" + printer_name = find_cups_printer_name() + if not printer_name: + return CUPSQueueStatus() + + result = CUPSQueueStatus(printer_name=printer_name) + lpstat_path = shutil.which("lpstat") + if not lpstat_path: + return result + + try: + r = subprocess.run( + [lpstat_path, "-p", printer_name], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + for line in r.stdout.splitlines(): + if "printer" in line.lower() and printer_name in line: + result.enabled, result.reason = _parse_lpstat_printer_line(line) + break + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): + pass + + try: + r = subprocess.run( + [lpstat_path, "-o", printer_name], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + result.jobs = _parse_lpstat_jobs(r.stdout, printer_name) + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): + pass + + has_errors, last_error = _check_cups_backend_errors(printer_name) + result.has_backend_errors = has_errors + result.last_backend_error = last_error + + return result + + +# ── CUPS fix actions ───────────────────────────────────────────────── + + +def _cups_enable_printer(printer_name: str) -> bool: + """Re-enable a disabled CUPS printer. Returns True on success.""" + cupsenable_path = shutil.which("cupsenable") + if not cupsenable_path: + _out(f" {RED}cupsenable not found.{RESET}") + return False + try: + subprocess.run( + [cupsenable_path, printer_name], + timeout=5, + check=True, + ) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError) as e: + _out(f" {RED}Failed to enable printer: {e}{RESET}") + return False + else: + return True + + +def _cups_cancel_all_jobs(printer_name: str) -> bool: + """Cancel all pending jobs. Returns True on success.""" + cancel_path = shutil.which("cancel") + if not cancel_path: + _out(f" {RED}cancel command not found.{RESET}") + return False + try: + subprocess.run( + [cancel_path, "-a", printer_name], + timeout=5, + check=True, + ) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError) as e: + _out(f" {RED}Failed to cancel jobs: {e}{RESET}") + return False + else: + return True + + +def _cups_cancel_job(job_id: str) -> bool: + """Cancel a specific job. Returns True on success.""" + cancel_path = shutil.which("cancel") + if not cancel_path: + return False + try: + subprocess.run( + [cancel_path, job_id], + timeout=5, + check=True, + ) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError): + return False + else: + return True + + +def _cups_restart_service() -> bool: + """Restart the CUPS service. Returns True on success.""" + systemctl_path = shutil.which("systemctl") + if not systemctl_path: + _out(f" {RED}systemctl not found.{RESET}") + return False + sys.stdout.write(f" {DIM}Restarting CUPS...{RESET}") + sys.stdout.flush() + try: + proc = subprocess.Popen( + [systemctl_path, "restart", "cups"], + ) + deadline = time.time() + 30 + while proc.poll() is None: + if time.time() > deadline: + proc.kill() + proc.wait() + sys.stdout.write("\n") + _out( + f" {RED}CUPS restart timed out" + f" (stuck backend process?).{RESET}" + ) + _out( + f" {DIM}Try: sudo kill -9 $(pgrep -f 'cups/backend/usb')" + f" && sudo systemctl restart cups{RESET}" + ) + return False + sys.stdout.write(".") + sys.stdout.flush() + time.sleep(1) + sys.stdout.write("\n") + if proc.returncode != 0: + _out( + f" {RED}CUPS restart failed" f" (exit code {proc.returncode}).{RESET}" + ) + return False + except OSError as e: + sys.stdout.write("\n") + _out(f" {RED}Failed to restart CUPS: {e}{RESET}") + return False + time.sleep(2) + return True + + +# ── Backend error detection ────────────────────────────────────────── + + +def _is_cups_printer_healthy(printer_name: str) -> bool: + """Check live CUPS state via lpstat. Returns True if enabled with no issues.""" + lpstat_path = shutil.which("lpstat") + if not lpstat_path: + return False + try: + r = subprocess.run( + [lpstat_path, "-p", printer_name], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + for line in r.stdout.splitlines(): + if ( + printer_name in line + and "idle" in line.lower() + and "enabled" in line.lower() + ): + return True + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): + pass + return False + + +def _find_backend_error_in_log( + lines: list[str], +) -> tuple[str, str, str]: + """Scan CUPS log lines (reversed) for backend errors. + + Returns: + (backend_error, error_timestamp, last_success_timestamp) + """ + backend_error = "" + error_timestamp = "" + last_success_timestamp = "" + + for line in reversed(lines): + if ( + "backend errors" in line or "stopped with status" in line + ) and not backend_error: + backend_error = line.strip() + ts_match = re.search(r"\[([^\]]+)\]", line) + if ts_match: + error_timestamp = ts_match.group(1) + if ("Completed" in line or "total" in line) and error_timestamp: + ts_match = re.search(r"\[([^\]]+)\]", line) + if ts_match: + last_success_timestamp = ts_match.group(1) + break + + return backend_error, error_timestamp, last_success_timestamp + + +def _check_cups_backend_errors( + printer_name: str, +) -> tuple[bool, str]: + """Check CUPS error log for backend errors. Returns (has_errors, last_error).""" + if _is_cups_printer_healthy(printer_name): + return False, "" + + log_path = Path("/var/log/cups/error_log") + if not log_path.exists(): + return False, "" + try: + lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines() + except OSError: + return False, "" + + backend_error, error_timestamp, last_success_timestamp = _find_backend_error_in_log( + lines + ) + + if not backend_error: + return False, "" + + if last_success_timestamp and last_success_timestamp > error_timestamp: + return False, "" + + return True, backend_error + + +# ── Queue status display ──────────────────────────────────────────── + + +def display_cups_queue_status(queue: CUPSQueueStatus) -> None: + """Display CUPS queue status and offer interactive fixes.""" + if not queue.printer_name: + return + if queue.enabled and not queue.jobs and not queue.has_backend_errors: + return + + _out() + _out(f"{BOLD}── Print Queue ──{RESET}") + _out() + + if queue.has_backend_errors and queue.enabled and not queue.jobs: + _out(f" {YELLOW}{BOLD}⚡ CUPS backend has stale errors{RESET}") + _out( + f" {DIM}New print jobs may silently fail." + f" A CUPS restart usually fixes this.{RESET}" + ) + _out() + + if not queue.enabled: + _out(f" {RED}{BOLD}⚠ Printer queue is DISABLED{RESET}") + if queue.reason: + _out(f" {DIM}Reason: {queue.reason}{RESET}") + _out() + + if queue.jobs: + _out(f" {BOLD}Pending jobs ({len(queue.jobs)}):{RESET}") + for job in queue.jobs: + _out(f" {job.job_id} {DIM}{job.user} {job.size}B {job.date}{RESET}") + _out() + + _offer_queue_fix(queue) + + +# ── Interactive queue fix ──────────────────────────────────────────── + + +def _offer_queue_fix(queue: CUPSQueueStatus) -> None: + """Prompt the user to fix a disabled queue / pending jobs.""" + _out(f" {BOLD}Available actions:{RESET}") + + options: list[str] = [] + if not queue.enabled and queue.jobs: + _out(f" {CYAN}1){RESET} Re-enable printer and retry all jobs") + _out(f" {CYAN}2){RESET} Re-enable printer and cancel all jobs") + _out(f" {CYAN}3){RESET} Cancel all jobs (keep printer disabled)") + _out(f" {CYAN}4){RESET} Restart CUPS service (fixes stale backend)") + _out(f" {CYAN}5){RESET} Restart CUPS + re-enable + retry all jobs") + _out(f" {CYAN}6){RESET} Do nothing") + options = ["1", "2", "3", "4", "5", "6"] + elif not queue.enabled: + _out(f" {CYAN}1){RESET} Re-enable printer") + _out(f" {CYAN}2){RESET} Restart CUPS service (fixes stale backend)") + _out(f" {CYAN}3){RESET} Do nothing") + options = ["1", "2", "3"] + elif queue.jobs: + _out(f" {CYAN}1){RESET} Cancel all pending jobs") + _out(f" {CYAN}2){RESET} Restart CUPS service (fixes stale backend)") + _out(f" {CYAN}3){RESET} Do nothing") + options = ["1", "2", "3"] + else: + _out(f" {CYAN}1){RESET} Restart CUPS service (fixes stale backend)") + _out(f" {CYAN}2){RESET} Do nothing") + options = ["1", "2"] + + _out() + choice = _prompt(f" Choose [{'/'.join(options)}]: ") + _out() + + if not queue.enabled and queue.jobs: + _handle_disabled_with_jobs(queue, choice) + elif not queue.enabled: + _handle_disabled_no_jobs(queue, choice) + elif queue.jobs: + _handle_enabled_with_jobs(queue, choice) + else: + _handle_backend_errors_only(choice) + + +def _dwj_enable_only(printer_name: str) -> None: + """Choice 1: re-enable printer so queued jobs are retried.""" + if _cups_enable_printer(printer_name): + _out(f" {GREEN}✓ Printer re-enabled. Jobs will be retried.{RESET}") + + +def _dwj_cancel_and_enable(printer_name: str) -> None: + """Choice 2: cancel all jobs then re-enable.""" + _cups_cancel_all_jobs(printer_name) + if _cups_enable_printer(printer_name): + _out(f" {GREEN}✓ All jobs cancelled and printer re-enabled.{RESET}") + + +def _dwj_cancel_only(printer_name: str) -> None: + """Choice 3: cancel all jobs.""" + if _cups_cancel_all_jobs(printer_name): + _out(f" {GREEN}✓ All jobs cancelled.{RESET}") + + +def _dwj_restart_only(_printer_name: str) -> None: + """Choice 4: restart CUPS.""" + if _cups_restart_service(): + _out(f" {GREEN}✓ CUPS restarted.{RESET}") + + +def _dwj_restart_and_enable(printer_name: str) -> None: + """Choice 5: restart CUPS and re-enable printer.""" + if _cups_restart_service(): + _cups_enable_printer(printer_name) + _out( + f" {GREEN}✓ CUPS restarted, printer re-enabled." + f" Jobs will be retried.{RESET}" + ) + + +_DWJ_ACTIONS: dict[str, Callable[[str], None]] = { + "1": _dwj_enable_only, + "2": _dwj_cancel_and_enable, + "3": _dwj_cancel_only, + "4": _dwj_restart_only, + "5": _dwj_restart_and_enable, +} + + +def _handle_disabled_with_jobs(queue: CUPSQueueStatus, choice: str) -> None: + """Handle fix for disabled printer with pending jobs.""" + action = _DWJ_ACTIONS.get(choice) + if action is not None: + action(queue.printer_name) + else: + _out(f" {DIM}No changes made.{RESET}") + + +def _handle_disabled_no_jobs(queue: CUPSQueueStatus, choice: str) -> None: + """Handle fix for disabled printer with no pending jobs.""" + if choice == "1": + if _cups_enable_printer(queue.printer_name): + _out(f" {GREEN}✓ Printer re-enabled.{RESET}") + elif choice == "2": + if _cups_restart_service(): + _cups_enable_printer(queue.printer_name) + _out(f" {GREEN}✓ CUPS restarted and printer re-enabled.{RESET}") + else: + _out(f" {DIM}No changes made.{RESET}") + + +def _handle_enabled_with_jobs(queue: CUPSQueueStatus, choice: str) -> None: + """Handle fix for enabled printer with stuck jobs.""" + if choice == "1": + if _cups_cancel_all_jobs(queue.printer_name): + _out(f" {GREEN}✓ All jobs cancelled.{RESET}") + elif choice == "2": + if _cups_restart_service(): + _out(f" {GREEN}✓ CUPS restarted.{RESET}") + else: + _out(f" {DIM}No changes made.{RESET}") + + +def _handle_backend_errors_only(choice: str) -> None: + """Handle fix when only stale backend errors are detected.""" + if choice == "1": + if _cups_restart_service(): + _out(f" {GREEN}✓ CUPS restarted. Stale backend errors cleared.{RESET}") + else: + _out(f" {DIM}No changes made.{RESET}") diff --git a/python_pkg/brother_printer/cups_service.py b/python_pkg/brother_printer/cups_service.py new file mode 100644 index 0000000..19fa20f --- /dev/null +++ b/python_pkg/brother_printer/cups_service.py @@ -0,0 +1,479 @@ +"""CUPS service management, USB fallback, and consumable state tracking.""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +import re +import shutil +import subprocess +import time +import urllib.parse + +from python_pkg.brother_printer.constants import ( + _CUPS_REASONS_TO_STATUS, + _CUPS_STATE_TO_STATUS, + _ERROR_REASON_MAP, + BROTHER_USB_VENDOR_ID, + CONSUMABLE_STATE_DIR, + CUPS_PAGE_LOG_PATH, + DRUM_RATED_PAGES, + GREEN, + RESET, + TONER_RATED_PAGES, + _out, +) +from python_pkg.brother_printer.data_classes import ( + PageCountEstimate, + USBPortStatus, + USBResult, +) + +logger = logging.getLogger(__name__) + +CUPS_PAGE_LOG = Path(CUPS_PAGE_LOG_PATH) +CONSUMABLE_STATE_FILE = Path.home() / CONSUMABLE_STATE_DIR / "state.json" + + +# ── pyusb device info ──────────────────────────────────────────────── + + +def _get_pyusb_device_info() -> dict[str, str]: + """Get Brother USB printer info via pyusb (no interface claim needed).""" + try: + import usb.core + + dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID) + if dev is None: + return {} + except (ImportError, OSError, ValueError): + return {} + else: + return { + "product": dev.product or "", + "serial": dev.serial_number or "", + } + + +# ── CUPS service control ──────────────────────────────────────────── + + +def _stop_cups() -> bool: + """Stop CUPS service and sockets. Returns True on success.""" + systemctl = shutil.which("systemctl") + if not systemctl: + return False + try: + subprocess.run( + [systemctl, "stop", "cups.service", "cups.socket", "cups.path"], + timeout=15, + check=True, + ) + time.sleep(2) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError): + return False + return True + + +def is_cups_scheduler_running() -> bool: + """Check if the CUPS scheduler is currently running.""" + lpstat = shutil.which("lpstat") + if not lpstat: + return False + try: + r = subprocess.run( + [lpstat, "-r"], + capture_output=True, + text=True, + timeout=3, + check=False, + ) + return ( + "is running" in r.stdout.lower() and "not running" not in r.stdout.lower() + ) + except (subprocess.TimeoutExpired, OSError): + return False + + +def start_cups() -> bool: + """Start CUPS service, socket, and path units. Returns True on success.""" + systemctl = shutil.which("systemctl") + if not systemctl: + return False + try: + subprocess.run( + [systemctl, "start", "cups.service", "cups.socket", "cups.path"], + timeout=15, + check=True, + ) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError): + return False + for _ in range(10): + if is_cups_scheduler_running(): + return True + time.sleep(1) + return False + + +def _ensure_cups_running() -> bool: + """Make sure CUPS is running, starting it if necessary.""" + if is_cups_scheduler_running(): + return True + return start_cups() + + +# ── USB port status via pyusb ──────────────────────────────────────── + + +def _query_usb_port_status_raw() -> USBPortStatus | None: + """Query USB printer port status via pyusb control transfer. + + Requires root and temporarily stops CUPS to access the USB device. + Returns None if the query fails. + """ + try: + import usb.core + import usb.util + except ImportError: + return None + + dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID) + if dev is None: + return None + + if not _stop_cups(): + return None + + try: + dev.reset() + time.sleep(2) + dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID) + if dev is None: + return None + + try: + if dev.is_kernel_driver_active(0): + dev.detach_kernel_driver(0) + except (usb.core.USBError, NotImplementedError): + pass + + usb.util.claim_interface(dev, 0) + try: + # USB Printer Class GET_PORT_STATUS (bRequest=0x01) + raw = dev.ctrl_transfer(0xA1, 0x01, 0, 0, 1, timeout=5000) + port_byte = raw[0] + return USBPortStatus( + paper_empty=bool(port_byte & 0x20), + online=bool(port_byte & 0x10), + error=not bool(port_byte & 0x08), + raw_byte=port_byte, + ) + finally: + usb.util.release_interface(dev, 0) + usb.util.dispose_resources(dev) + except (OSError, ValueError): + logger.debug("USB port status query failed", exc_info=True) + return None + finally: + start_cups() + + +# ── Consumable state management ────────────────────────────────────── + + +def _get_cups_total_pages() -> int: + """Parse CUPS page_log to get total pages printed (deduplicated by job).""" + if not CUPS_PAGE_LOG.exists(): + return 0 + try: + text = CUPS_PAGE_LOG.read_text(encoding="utf-8", errors="replace") + except OSError: + return 0 + jobs: dict[str, int] = {} + for line in text.splitlines(): + match = re.search(r"\s(\d+)\s+\[.*?\]\s+total\s+(\d+)", line) + if match: + job_id = match.group(1) + pages = int(match.group(2)) + jobs[job_id] = max(jobs.get(job_id, 0), pages) + return sum(jobs.values()) + + +def _load_consumable_state() -> dict[str, int]: + """Load consumable replacement state from disk.""" + defaults: dict[str, int] = {"toner_replaced_at": 0, "drum_replaced_at": 0} + if not CONSUMABLE_STATE_FILE.exists(): + return defaults + try: + data = json.loads( + CONSUMABLE_STATE_FILE.read_text(encoding="utf-8"), + ) + return { + "toner_replaced_at": int(data.get("toner_replaced_at", 0)), + "drum_replaced_at": int(data.get("drum_replaced_at", 0)), + } + except (OSError, json.JSONDecodeError, ValueError, TypeError): + return defaults + + +def _save_consumable_state(state: dict[str, int]) -> None: + """Persist consumable replacement state to disk.""" + CONSUMABLE_STATE_FILE.parent.mkdir(parents=True, exist_ok=True) + CONSUMABLE_STATE_FILE.write_text( + json.dumps(state, indent=2) + "\n", + encoding="utf-8", + ) + + +def reset_consumable(name: str) -> None: + """Record current page count as replacement point for a consumable.""" + total = _get_cups_total_pages() + state = _load_consumable_state() + key = f"{name}_replaced_at" + state[key] = total + _save_consumable_state(state) + _out( + f"{GREEN}✓ {name.capitalize()} counter reset at page count" f" {total}.{RESET}" + ) + _out(f" State saved to {CONSUMABLE_STATE_FILE}") + + +def estimate_consumable_life() -> PageCountEstimate: + """Estimate toner/drum life from CUPS page count since last replacement.""" + total = _get_cups_total_pages() + if total <= 0: + return PageCountEstimate() + state = _load_consumable_state() + toner_pages = max(0, total - state["toner_replaced_at"]) + drum_pages = max(0, total - state["drum_replaced_at"]) + toner_pct = max(0, 100 - (toner_pages * 100 // TONER_RATED_PAGES)) + drum_pct = max(0, 100 - (drum_pages * 100 // DRUM_RATED_PAGES)) + return PageCountEstimate( + total_pages=total, + toner_pages=toner_pages, + drum_pages=drum_pages, + toner_pct_remaining=toner_pct, + drum_pct_remaining=drum_pct, + toner_exhausted=toner_pages >= TONER_RATED_PAGES, + toner_low=toner_pages >= TONER_RATED_PAGES * 80 // 100, + drum_near_end=drum_pages >= DRUM_RATED_PAGES * 90 // 100, + ) + + +# ── IPP / CUPS attribute queries ──────────────────────────────────── + + +def _parse_ipp_attributes(output: str) -> dict[str, str]: + """Parse ipptool verbose output into an attribute dict.""" + attrs: dict[str, str] = {} + for line in output.splitlines(): + match = re.match(r"\s+(\S+)\s+\([^)]+\)\s+=\s+(.*)", line) + if match: + attrs[match.group(1)] = match.group(2).strip() + return attrs + + +def _get_cups_ipp_status(printer_name: str) -> dict[str, str]: + """Query printer attributes via CUPS IPP using ipptool.""" + ipptool_path = shutil.which("ipptool") + if not ipptool_path: + return {} + uri = f"ipp://localhost/printers/{printer_name}" + try: + r = subprocess.run( + [ipptool_path, "-tv", uri, "get-printer-attributes.test"], + capture_output=True, + text=True, + timeout=10, + check=False, + ) + return _parse_ipp_attributes(r.stdout) + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): + return {} + + +def _get_cups_economode(printer_name: str) -> str: + """Query toner save mode setting via lpoptions.""" + lpoptions_path = shutil.which("lpoptions") + if not lpoptions_path: + return "" + try: + r = subprocess.run( + [lpoptions_path, "-p", printer_name, "-l"], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + for line in r.stdout.splitlines(): + if "conomode" in line.lower(): + match = re.search(r"\*(\w+)", line) + if match: + return "ON" if match.group(1).lower() == "true" else "OFF" + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): + pass + return "" + + +# ── Status code mapping ────────────────────────────────────────────── + + +def _map_cups_to_status_code(state: str, reasons: str) -> str: + """Map CUPS state + reasons to a Brother PJL status code string.""" + for keyword, code in _CUPS_REASONS_TO_STATUS.items(): + if keyword in reasons.lower(): + return str(code) + clean_state = re.sub(r"\(.*\)", "", state).strip().lower() + return str(_CUPS_STATE_TO_STATUS.get(clean_state, 10001)) + + +def _cups_reasons_to_error(cups_reasons: str) -> tuple[str, str]: + """Map CUPS reason keywords to a (status_code, display) pair.""" + reasons_lower = cups_reasons.lower() + for keywords, code, display in _ERROR_REASON_MAP: + if any(kw in reasons_lower for kw in keywords): + return code, display + return "42000", "Printer Error" + + +def _port_status_to_status_code( + ps: USBPortStatus, + cups_reasons: str, +) -> tuple[str, str]: + """Map USB port status + CUPS reasons to (status_code, display).""" + if ps.error and ps.paper_empty: + return "40302", "No Paper" + if ps.error and not ps.online: + return "41000", "Cover Open" + if ps.error: + return _cups_reasons_to_error(cups_reasons) + if ps.paper_empty: + return "40302", "No Paper" + if not ps.online: + return "10002", "Offline / Sleep" + return "", "" + + +# ── CUPS printer name discovery ────────────────────────────────────── + + +def find_cups_printer_name() -> str: + """Find the CUPS queue name for a Brother printer.""" + lpstat_path = shutil.which("lpstat") + if not lpstat_path: + return "" + try: + r = subprocess.run( + [lpstat_path, "-v"], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + for line in r.stdout.splitlines(): + if "brother" in line.lower(): + match = re.match(r"device for (\S+):", line) + if match: + return match.group(1) + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): + pass + return "" + + +# ── CUPS-based USB fallback query ──────────────────────────────────── + + +def _parse_cups_usb_uri(uri: str, info: dict[str, str]) -> None: + """Extract product and serial from a CUPS usb:// URI.""" + parsed = urllib.parse.urlparse(uri) + info["product"] = urllib.parse.unquote(parsed.path.lstrip("/")) + qs = urllib.parse.parse_qs(parsed.query) + if "serial" in qs: + info["serial"] = qs["serial"][0] + + +def _get_printer_info_from_cups() -> dict[str, str]: + """Get printer model/serial from lpstat.""" + info: dict[str, str] = {"product": "", "serial": ""} + try: + r = subprocess.run( + ["/usr/bin/lpstat", "-v"], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + for line in r.stdout.splitlines(): + if "Brother" in line: + for part in line.split(): + if part.startswith("usb://"): + _parse_cups_usb_uri(part, info) + break + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): + logger.debug("Failed to query CUPS for printer info", exc_info=True) + return info + + +def query_usb_via_cups() -> USBResult: + """Query USB printer status through CUPS when /dev/usb/lp* is unavailable.""" + _ensure_cups_running() + printer_name = find_cups_printer_name() + if not printer_name: + return USBResult( + error="No USB printer device at /dev/usb/lp*" + " (usblp module not available)" + " and no Brother printer found in CUPS.", + ) + + pyusb_info = _get_pyusb_device_info() + cups_info = _get_printer_info_from_cups() + + result = USBResult( + device="cups", + product=( + pyusb_info.get("product") + or cups_info.get("product") + or "Brother Laser Printer" + ), + serial=pyusb_info.get("serial") or cups_info.get("serial", ""), + ) + + ipp = _get_cups_ipp_status(printer_name) + state = ipp.get("printer-state", "") + reasons = ipp.get("printer-state-reasons", "none") + result.economode = _get_cups_economode(printer_name) + + port_status = _query_usb_port_status_raw() + if port_status is not None: + result.port_status = port_status + hw_code, hw_display = _port_status_to_status_code( + port_status, + reasons, + ) + if hw_code: + result.status_code = hw_code + result.display = hw_display + result.online = "TRUE" if port_status.online else "FALSE" + return result + estimate = estimate_consumable_life() + if estimate.toner_exhausted: + result.status_code = "40310" + result.display = "Toner End (estimated from page count)" + result.online = "TRUE" + return result + if estimate.toner_low: + result.status_code = "30010" + result.display = "Toner Low (estimated from page count)" + result.online = "TRUE" + return result + result.status_code = _map_cups_to_status_code(state, reasons) + result.display = ipp.get("printer-state-message", "") + result.online = "TRUE" + return result + + result.status_code = _map_cups_to_status_code(state, reasons) + result.display = ipp.get("printer-state-message", "") + result.online = "TRUE" if state.lower() in {"idle", "processing"} else "FALSE" + + return result diff --git a/python_pkg/brother_printer/data_classes.py b/python_pkg/brother_printer/data_classes.py new file mode 100644 index 0000000..cd034c8 --- /dev/null +++ b/python_pkg/brother_printer/data_classes.py @@ -0,0 +1,96 @@ +"""Data classes for Brother printer status information.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class CUPSJob: + """A single CUPS print job.""" + + job_id: str + user: str + size: str + date: str + + +@dataclass +class CUPSQueueStatus: + """Status of the CUPS print queue for a printer.""" + + printer_name: str = "" + enabled: bool = True + reason: str = "" + jobs: list[CUPSJob] = field(default_factory=list) + has_backend_errors: bool = False + last_backend_error: str = "" + + +@dataclass +class PageCountEstimate: + """Estimated consumable life based on CUPS page count.""" + + total_pages: int = 0 + toner_pages: int = 0 + drum_pages: int = 0 + toner_pct_remaining: int = 100 + drum_pct_remaining: int = 100 + toner_exhausted: bool = False + toner_low: bool = False + drum_near_end: bool = False + + +@dataclass +class USBPortStatus: + """IEEE 1284 USB printer port status bits.""" + + paper_empty: bool = False + online: bool = True + error: bool = False + raw_byte: int = 0 + + +@dataclass +class USBResult: + """Result from a USB PJL query.""" + + connection: str = "usb" + device: str = "" + product: str = "Brother Laser Printer" + serial: str = "" + status_code: str = "" + display: str = "" + online: str = "" + economode: str = "" + error: str = "" + port_status: USBPortStatus | None = None + + +@dataclass +class NetworkResult: + """Result from an SNMP network query.""" + + connection: str = "network" + ip: str = "" + product: str = "Unknown" + serial: str = "" + printer_status: str = "" + device_status: str = "" + display: str = "" + page_count: str = "" + supply_descriptions: list[str] = field(default_factory=list) + supply_max: list[str] = field(default_factory=list) + supply_levels: list[str] = field(default_factory=list) + error: str = "" + + +@dataclass +class SupplyStatus: + """Processed supply level info for display.""" + + color: str + bar: str + status_text: str + warning: str + needs_replacement: bool diff --git a/python_pkg/brother_printer/display.py b/python_pkg/brother_printer/display.py new file mode 100644 index 0000000..bed2500 --- /dev/null +++ b/python_pkg/brother_printer/display.py @@ -0,0 +1,385 @@ +"""Display and formatting functions for Brother printer status reports.""" + +from __future__ import annotations + +import sys + +from python_pkg.brother_printer.constants import ( + BOLD, + CYAN, + DIM, + DRUM_RATED_PAGES, + GREEN, + PROGRESS_BAR_WIDTH, + RED, + RESET, + SNMP_LEVEL_LOW, + SNMP_LEVEL_OK, + SUPPLY_LOW_PCT, + SUPPLY_WARN_PCT, + TONER_RATED_PAGES, + YELLOW, + _out, + get_status_info, +) +from python_pkg.brother_printer.cups_queue import ( + display_cups_queue_status, + get_cups_queue_status, +) +from python_pkg.brother_printer.cups_service import estimate_consumable_life +from python_pkg.brother_printer.data_classes import ( + NetworkResult, + SupplyStatus, + USBResult, +) + +# ── Shared display helpers ─────────────────────────────────────────── + + +def _display_report_header() -> None: + """Print the report banner box.""" + _out() + _out(f"{BOLD}╔══════════════════════════════════════════════════╗{RESET}") + _out(f"{BOLD}║ Brother Laser Printer Status Report ║{RESET}") + _out(f"{BOLD}╚══════════════════════════════════════════════════╝{RESET}") + _out() + + +def _display_page_count_estimate() -> None: + """Show estimated consumable life based on CUPS page count.""" + estimate = estimate_consumable_life() + if estimate.total_pages <= 0: + return + _out(f"{BOLD}── Page Count Estimate ──{RESET}") + _out() + _out( + f" {BOLD}Total pages printed:{RESET} {estimate.total_pages}" + f" (toner: {estimate.toner_pages} since replacement," + f" drum: {estimate.drum_pages} since replacement)" + ) + _out() + # Toner bar + toner_pct = estimate.toner_pct_remaining + toner_filled = toner_pct * PROGRESS_BAR_WIDTH // 100 + toner_empty = PROGRESS_BAR_WIDTH - toner_filled + toner_bar = f"[{'█' * toner_filled}{'░' * toner_empty}]" + if estimate.toner_exhausted: + toner_color = RED + toner_note = " ← REPLACE NOW" + elif estimate.toner_low: + toner_color = YELLOW + toner_note = " ← order soon" + else: + toner_color = GREEN + toner_note = "" + _out( + f" {BOLD}Toner:{RESET} {toner_color}{toner_bar} ~{toner_pct}%" + f"{toner_note}{RESET}" + ) + # Drum bar + drum_pct = estimate.drum_pct_remaining + drum_filled = drum_pct * PROGRESS_BAR_WIDTH // 100 + drum_empty = PROGRESS_BAR_WIDTH - drum_filled + drum_bar = f"[{'█' * drum_filled}{'░' * drum_empty}]" + if estimate.drum_near_end: + drum_color = YELLOW + drum_note = " ← nearing end" + else: + drum_color = GREEN + drum_note = "" + _out( + f" {BOLD}Drum:{RESET} {drum_color}{drum_bar} ~{drum_pct}%" + f"{drum_note}{RESET}" + ) + _out( + f" {DIM}Based on pages since last replacement" + f" vs rated capacity (toner ~{TONER_RATED_PAGES}," + f" drum ~{DRUM_RATED_PAGES}).{RESET}" + ) + _out(f" {DIM}Reset after replacing: --reset-toner or --reset-drum{RESET}") + if estimate.toner_exhausted: + _out() + _out( + f" {RED}{BOLD}⚠ Toner is likely exhausted." + f" This is probably why the orange light is flashing.{RESET}" + ) + _out() + + +def _display_consumables_reference() -> None: + """Print compatible consumables reference.""" + _out(f"{BOLD}── Compatible Consumables ──{RESET}") + _out() + _out(f" {BOLD}Toner:{RESET} TN-1050 / TN-1030 (or compatible third-party)") + _out(f" {BOLD}Drum:{RESET} DR-1050 / DR-1030 (or compatible third-party)") + _out(f" {DIM} Toner rated ~1000 pages; Drum rated ~10000 pages.{RESET}") + _out() + + +# ── USB display helpers ────────────────────────────────────────────── + + +def _display_usb_device_info(result: USBResult) -> None: + """Print device info block for USB results.""" + _out(f"{BOLD}Printer:{RESET} {result.product or 'Unknown'}") + _out(f"{BOLD}Connection:{RESET} USB") + if result.serial: + _out(f"{BOLD}Serial:{RESET} {result.serial}") + + if result.online == "TRUE": + _out(f"{BOLD}Online:{RESET} {GREEN}Yes{RESET}") + elif result.online == "FALSE": + _out(f"{BOLD}Online:{RESET} {YELLOW}No (needs attention){RESET}") + + _out() + + if result.economode: + if result.economode == "ON": + _out( + f"{BOLD}Toner Save:{RESET} {GREEN}ON{RESET}" + " (extends toner life, lighter prints)" + ) + else: + _out(f"{BOLD}Toner Save:{RESET} OFF") + + +_SEVERITY_ICONS: dict[str, str] = { + "ok": "✓", + "info": "i", + "warn": "⚡", + "critical": "⚠", +} +_SEVERITY_COLORS: dict[str, str] = { + "ok": GREEN, + "info": CYAN, + "warn": YELLOW, + "critical": RED, +} +_SEVERITY_SUMMARIES: dict[str, str] = { + "ok": f"{GREEN}{BOLD}✓ Printer is healthy. No replacements needed.{RESET}", + "info": ( + f"{CYAN}{BOLD}i Printer is busy/processing." f" No replacements needed.{RESET}" + ), + "warn": ( + f"{YELLOW}{BOLD}⚡ WARNING: Maintenance will be needed" + f" soon.{RESET}\n{YELLOW} Order replacement parts" + f" now to avoid interruption.{RESET}" + ), + "critical": ( + f"{RED}{BOLD}⚠ ACTION REQUIRED:" f" Replacement or fix needed now!{RESET}" + ), +} + + +def _format_status_detail( + severity: str, short_text: str, action: str, result: USBResult +) -> None: + """Print severity icon, display text, and action.""" + color = _SEVERITY_COLORS.get(severity, GREEN) + icon = _SEVERITY_ICONS.get(severity, "✓") + + _out(f" {color}{BOLD}{icon} {short_text}{RESET}") + if result.display and result.display != short_text: + _out(f" {DIM}Display: {result.display}{RESET}") + _out(f" {DIM}Status code: {result.status_code}{RESET}") + + if action: + _out() + _out(f" {color}{BOLD}Action:{RESET} {color}{action}{RESET}") + _out() + _out(_SEVERITY_SUMMARIES.get(severity, "")) + + +def _display_pjl_status(result: USBResult) -> None: + """Display PJL status code interpretation.""" + _out() + _out(f"{BOLD}── Printer Status ──{RESET}") + _out() + + if not result.status_code: + _out(f" {YELLOW}Could not read status from printer.{RESET}") + if result.display: + _out(f" Display message: {BOLD}{result.display}{RESET}") + return + + severity, short_text, action = get_status_info(result.status_code) + _format_status_detail(severity, short_text, action, result) + + +def _display_cups_fallback_note(result: USBResult) -> None: + """Show a note when running in CUPS fallback mode.""" + _out() + if result.port_status is not None: + _out( + f" {DIM}Note: Hardware status obtained via USB port query." + f" Toner/drum percentages not available.{RESET}" + ) + else: + _out( + f" {DIM}Note: pyusb not available; status obtained via" + f" CUPS only. Detailed toner/drum levels are not" + f" available in this mode.{RESET}" + ) + + +# ── USB results display ───────────────────────────────────────────── + + +def display_usb_results(result: USBResult) -> None: + """Print a formatted report for USB PJL query results.""" + if result.error: + _out(f"{RED}Error: {result.error}{RESET}") + sys.exit(1) + + _display_report_header() + _display_usb_device_info(result) + _display_pjl_status(result) + + if result.device == "cups": + _display_cups_fallback_note(result) + + _out() + _display_page_count_estimate() + _display_consumables_reference() + + queue = get_cups_queue_status() + display_cups_queue_status(queue) + + +# ── Network supply level helpers ───────────────────────────────────── + + +def _classify_percentage_level(desc: str, pct: int) -> tuple[int, str, str, str, bool]: + """Classify a supply by its calculated percentage.""" + if pct <= SUPPLY_LOW_PCT: + return pct, f"{pct}%", RED, f"{desc} at {pct}%.", True + if pct <= SUPPLY_WARN_PCT: + return pct, f"{pct}%", YELLOW, f"{desc} at {pct}% -- order soon.", False + return pct, f"{pct}%", GREEN, "", False + + +def _classify_supply_level( + desc: str, max_val: int, level: int +) -> tuple[int, str, str, str, bool]: + """Classify a supply level. Returns (pct, status, color, warning, replace).""" + if level == SNMP_LEVEL_OK: + return -1, "OK", GREEN, "", False + if level == SNMP_LEVEL_LOW: + return -1, "LOW", RED, f"{desc} is LOW.", True + if level == 0: + return 0, "EMPTY", RED, f"{desc} is EMPTY -- replace now!", True + if max_val > 0: + pct = min(level * 100 // max_val, 100) + return _classify_percentage_level(desc, pct) + return -1, "", GREEN, "", False + + +def _format_supply_bar(pct: int) -> str: + """Build a progress bar string for a supply percentage.""" + if pct < 0: + return "" + filled = pct * PROGRESS_BAR_WIDTH // 100 + empty = PROGRESS_BAR_WIDTH - filled + return f"[{'█' * filled}{'░' * empty}]" + + +def _process_supply_item(desc: str, max_val: int, level: int) -> SupplyStatus: + """Process a single supply item into display info.""" + pct, status_text, color, warning, needs_replacement = _classify_supply_level( + desc, max_val, level + ) + bar = _format_supply_bar(pct) + return SupplyStatus(color, bar, status_text, warning, needs_replacement) + + +def _display_supply_warnings(*, needs_replacement: bool, warnings: list[str]) -> None: + """Display supply level warnings summary.""" + _out() + if needs_replacement: + _out(f"{RED}{BOLD}⚠ ACTION NEEDED:{RESET}") + for w in warnings: + _out(f" {RED}• {w}{RESET}") + elif warnings: + _out(f"{YELLOW}{BOLD}⚡ HEADS UP:{RESET}") + for w in warnings: + _out(f" {YELLOW}• {w}{RESET}") + else: + _out(f"{GREEN}{BOLD}✓ All consumables are at healthy levels.{RESET}") + + +def _parse_supply_value(values: list[str], index: int) -> int: + """Safely parse an integer from a supply value list.""" + try: + return int(values[index]) + except (IndexError, ValueError): + return 0 + + +def _collect_supply_items( + result: NetworkResult, +) -> tuple[list[SupplyStatus], list[str]]: + """Parse and collect supply items with their descriptions.""" + items: list[SupplyStatus] = [] + descs: list[str] = [] + for i, desc in enumerate(result.supply_descriptions): + max_val = _parse_supply_value(result.supply_max, i) + level = _parse_supply_value(result.supply_levels, i) + items.append(_process_supply_item(desc, max_val, level)) + descs.append(desc) + return items, descs + + +def _display_supply_levels(result: NetworkResult) -> None: + """Display consumable supply levels section.""" + _out() + _out(f"{BOLD}── Consumable Levels ──{RESET}") + _out() + + needs_replacement = False + warnings: list[str] = [] + items, descs = _collect_supply_items(result) + + for desc, item in zip(descs, items, strict=True): + _out( + f" {BOLD}{desc:<25}{RESET}" + f" {item.color}{item.bar} {item.status_text}{RESET}" + ) + if item.needs_replacement: + needs_replacement = True + if item.warning: + warnings.append(item.warning) + + _display_supply_warnings(needs_replacement=needs_replacement, warnings=warnings) + + +def _display_network_device_info(result: NetworkResult) -> None: + """Display device info section for network results.""" + _out(f"{BOLD}Printer:{RESET} {result.product or 'Unknown'}") + _out(f"{BOLD}Connection:{RESET} Network ({result.ip})") + if result.serial: + _out(f"{BOLD}Serial:{RESET} {result.serial}") + if result.display: + _out(f"{BOLD}Display:{RESET} {result.display}") + if result.page_count and result.page_count.isdigit(): + _out(f"{BOLD}Pages:{RESET} {result.page_count} total") + + +# ── Network results display ────────────────────────────────────────── + + +def display_network_results(result: NetworkResult) -> None: + """Print a formatted report for SNMP network query results.""" + if result.error: + _out(f"{RED}Error: {result.error}{RESET}") + sys.exit(1) + + _display_report_header() + _display_network_device_info(result) + _display_supply_levels(result) + + _out() + _out( + f"{CYAN}Tip: Visit http://{result.ip} for the full web management" + f" interface.{RESET}" + ) + _out() diff --git a/python_pkg/brother_printer/network_query.py b/python_pkg/brother_printer/network_query.py new file mode 100644 index 0000000..3d9cfad --- /dev/null +++ b/python_pkg/brother_printer/network_query.py @@ -0,0 +1,97 @@ +"""SNMP network query functions for Brother printers.""" + +from __future__ import annotations + +import shutil +import subprocess + +from python_pkg.brother_printer.data_classes import NetworkResult + + +def _snmpwalk_cmd( + path: str, community: str, timeout: int, ip: str, oid: str +) -> list[str]: + """Build the snmpwalk command arguments.""" + return [path, "-v", "2c", "-c", community, "-t", str(timeout), "-OQvs", ip, oid] + + +def snmp_walk(ip: str, oid: str, community: str, timeout: int) -> list[str]: + """Run snmpwalk and return cleaned values.""" + snmpwalk_path = shutil.which("snmpwalk") + if not snmpwalk_path: + return [] + try: + r = subprocess.run( + _snmpwalk_cmd(snmpwalk_path, community, timeout, ip, oid), + capture_output=True, + text=True, + timeout=15, + check=False, + ) + return [ + line.strip().strip('"') + for line in r.stdout.strip().splitlines() + if line.strip() + ] + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): + return [] + + +def _snmpget_cmd( + path: str, community: str, timeout: int, ip: str, oid: str +) -> list[str]: + """Build the snmpget command arguments.""" + return [path, "-v", "2c", "-c", community, "-t", str(timeout), ip, oid] + + +def _check_snmp_connectivity(ip: str, community: str, timeout: int) -> str | None: + """Verify SNMP connectivity. Returns error message or None on success.""" + snmpget_path = shutil.which("snmpget") + if not snmpget_path: + return "snmpget not found. Install: sudo pacman -S net-snmp" + try: + subprocess.run( + _snmpget_cmd( + snmpget_path, + community, + timeout, + ip, + "1.3.6.1.2.1.43.11.1.1.6.1.1", + ), + capture_output=True, + timeout=10, + check=True, + ) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError): + return f"Cannot reach printer at {ip} via SNMP." + return None + + +def _build_network_result(ip: str, community: str, timeout: int) -> NetworkResult: + """Collect all SNMP data into a NetworkResult.""" + + def walk(oid: str) -> list[str]: + return snmp_walk(ip, oid, community, timeout) + + return NetworkResult( + ip=ip, + product=" ".join(walk("1.3.6.1.2.1.25.3.2.1.3")[:1]) or "Unknown", + serial=" ".join(walk("1.3.6.1.2.1.43.5.1.1.17")[:1]) or "", + printer_status=" ".join(walk("1.3.6.1.2.1.25.3.5.1.1")[:1]) or "", + device_status=" ".join(walk("1.3.6.1.2.1.25.3.2.1.5")[:1]) or "", + display=" ".join(walk("1.3.6.1.2.1.43.16.5.1.2")[:3]) or "", + page_count=" ".join(walk("1.3.6.1.2.1.43.10.2.1.4")[:1]) or "", + supply_descriptions=walk("1.3.6.1.2.1.43.11.1.1.6"), + supply_max=walk("1.3.6.1.2.1.43.11.1.1.8"), + supply_levels=walk("1.3.6.1.2.1.43.11.1.1.9"), + ) + + +def query_network_snmp(ip: str) -> NetworkResult: + """Query a Brother printer via SNMP over the network.""" + community = "public" + timeout = 5 + error = _check_snmp_connectivity(ip, community, timeout) + if error: + return NetworkResult(ip=ip, error=error) + return _build_network_result(ip, community, timeout) diff --git a/python_pkg/brother_printer/usb_query.py b/python_pkg/brother_printer/usb_query.py new file mode 100644 index 0000000..c56be0d --- /dev/null +++ b/python_pkg/brother_printer/usb_query.py @@ -0,0 +1,233 @@ +"""USB printer discovery and PJL query functions.""" + +from __future__ import annotations + +import contextlib +import fcntl +import os +from pathlib import Path +import select +import shutil +import subprocess +import time +from typing import TYPE_CHECKING +import urllib.parse + +from python_pkg.brother_printer.data_classes import USBResult + +if TYPE_CHECKING: + from collections.abc import Callable + +import logging + +logger = logging.getLogger(__name__) + + +# ── USB printer discovery ──────────────────────────────────────────── + + +def find_brother_usb() -> str: + """Look for any Brother printer on USB via lsusb. Returns the info line.""" + if not shutil.which("lsusb"): + return "" + try: + r = subprocess.run( + ["/usr/bin/lsusb"], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + for line in r.stdout.splitlines(): + if "04f9:" in line.lower(): + return line.split(": ", 1)[1] if ": " in line else line + except (subprocess.TimeoutExpired, OSError): + pass + return "" + + +def find_usb_printer_dev() -> str | None: + """Find /dev/usb/lp* device for the Brother printer.""" + devices = sorted(Path("/dev/usb").glob("lp*")) + return str(devices[0]) if devices else None + + +def _parse_cups_usb_uri(uri: str, info: dict[str, str]) -> None: + """Extract product and serial from a CUPS usb:// URI.""" + parsed = urllib.parse.urlparse(uri) + info["product"] = urllib.parse.unquote(parsed.path.lstrip("/")) + qs = urllib.parse.parse_qs(parsed.query) + if "serial" in qs: + info["serial"] = qs["serial"][0] + + +def get_printer_info_from_cups() -> dict[str, str]: + """Get printer model/serial from lpstat.""" + info: dict[str, str] = {"product": "", "serial": ""} + try: + r = subprocess.run( + ["/usr/bin/lpstat", "-v"], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + for line in r.stdout.splitlines(): + if "Brother" in line: + for part in line.split(): + if part.startswith("usb://"): + _parse_cups_usb_uri(part, info) + break + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError): + logger.debug("Failed to query CUPS for printer info", exc_info=True) + return info + + +# ── PJL over USB ───────────────────────────────────────────────────── + + +def _drain_buffer(fd: int) -> None: + """Read and discard any stale data from the USB buffer.""" + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + with contextlib.suppress(OSError): + while os.read(fd, 4096): + pass + fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) + + +def _read_nonblocking(fd: int, flags: int) -> bytes: + """Read all available data from fd in non-blocking mode.""" + data = b"" + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + with contextlib.suppress(OSError): + while True: + chunk = os.read(fd, 4096) + if not chunk: + break + data += chunk + fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) + return data + + +def _wait_for_pjl_response(fd: int, flags: int, deadline: float) -> bytes: + """Poll fd until PJL data arrives or deadline expires.""" + response = b"" + while time.time() < deadline: + remaining = deadline - time.time() + if remaining <= 0: + break + readable, _, _ = select.select([fd], [], [], min(remaining, 1.0)) + if readable: + response += _read_nonblocking(fd, flags) + if response and (b"=" in response or b"@PJL" in response): + break + return response + + +def pjl_query(fd: int, cmd: str, timeout_sec: float = 5.0) -> str: + """Send a PJL command via raw fd and read the response.""" + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) + + pjl_cmd = f"\x1b%-12345X@PJL\r\n{cmd}\r\n\x1b%-12345X" + os.write(fd, pjl_cmd.encode()) + + deadline = time.time() + timeout_sec + response = _wait_for_pjl_response(fd, flags, deadline) + + fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) + return response.decode("ascii", errors="replace") + + +def _parse_status(resp: str, result: USBResult) -> bool: + """Parse STATUS response into result. Returns True if code was found.""" + found = False + for raw_line in resp.splitlines(): + stripped = raw_line.strip() + if stripped.startswith("CODE="): + result.status_code = stripped.split("=", 1)[1] + found = True + elif stripped.startswith("DISPLAY="): + result.display = stripped.split("=", 1)[1].strip().strip('"').strip() + elif stripped.startswith("ONLINE="): + result.online = stripped.split("=", 1)[1] + return found + + +def _parse_variables(resp: str, result: USBResult) -> bool: + """Parse VARIABLES response into result. Returns True if data found.""" + found = False + for raw_line in resp.splitlines(): + stripped = raw_line.strip() + if stripped.startswith("ECONOMODE="): + result.economode = stripped.split("=", 1)[1].split()[0] + found = True + return found + + +def _retry_pjl_query( + fd: int, + cmd: str, + parser: Callable[[str, USBResult], bool], + result: USBResult, + max_retries: int, +) -> None: + """Send a PJL query with retries, draining between attempts.""" + for attempt in range(max_retries + 1): + resp = pjl_query(fd, cmd) + if parser(resp, result): + break + if attempt < max_retries: + _drain_buffer(fd) + time.sleep(0.5) + + +def _run_pjl_queries(fd: int, result: USBResult, max_retries: int) -> None: + """Execute PJL query sequence on an open file descriptor.""" + _drain_buffer(fd) + + os.write(fd, b"\x1b%-12345X@PJL\r\n\x1b%-12345X") + time.sleep(0.5) + _drain_buffer(fd) + + _retry_pjl_query(fd, "@PJL INFO STATUS", _parse_status, result, max_retries) + _drain_buffer(fd) + time.sleep(0.5) + _retry_pjl_query(fd, "@PJL INFO VARIABLES", _parse_variables, result, max_retries) + + +def _init_usb_result(dev_path: str) -> USBResult: + """Create a USBResult with device info from CUPS.""" + cups_info = get_printer_info_from_cups() + return USBResult( + device=dev_path, + product=cups_info.get("product") or "Brother Laser Printer", + serial=cups_info.get("serial", ""), + ) + + +def query_usb_pjl(max_retries: int = 2) -> USBResult: + """Query a Brother printer via PJL over /dev/usb/lp*.""" + dev_path = find_usb_printer_dev() + if not dev_path: + from python_pkg.brother_printer.cups_service import query_usb_via_cups + + return query_usb_via_cups() + + result = _init_usb_result(dev_path) + if not os.access(dev_path, os.R_OK | os.W_OK): + result.error = f"Permission denied: {dev_path}. Run with sudo." + return result + + fd: int | None = None + try: + fd = os.open(dev_path, os.O_RDWR) + fcntl.fcntl(fd, fcntl.F_GETFL) + _run_pjl_queries(fd, result, max_retries) + except OSError as e: + result.error = str(e) + finally: + if fd is not None: + os.close(fd) + return result diff --git a/python_pkg/geo_data.py b/python_pkg/geo_data.py deleted file mode 100644 index c87fc24..0000000 --- a/python_pkg/geo_data.py +++ /dev/null @@ -1,2028 +0,0 @@ -"""Shared geographic data module for Warsaw and Poland Anki generators. - -This module handles downloading and caching geographic data from various sources: -- OpenStreetMap via Overpass API -- Geofabrik OSM extracts -- GitHub repositories with pre-processed GeoJSON - -All data is cached locally to avoid repeated downloads. -""" - -from __future__ import annotations - -import contextlib -import json -from pathlib import Path -import shutil -import sys -import time -from typing import TYPE_CHECKING -from urllib.request import urlopen - -import geopandas as gpd -import requests -from shapely.geometry import ( - GeometryCollection, - LineString, - MultiLineString, - MultiPolygon, - Polygon, -) - -if TYPE_CHECKING: - from typing import Any - -# Shared cache directory for all geo data -CACHE_DIR = Path(__file__).parent / "geo_cache" - -# Overpass API endpoints (multiple for redundancy) -# Note: kumi.systems is more reliable, so it's first -OVERPASS_ENDPOINTS = [ - "https://overpass.kumi.systems/api/interpreter", - "https://overpass-api.de/api/interpreter", - "https://maps.mail.ru/osm/tools/overpass/api/interpreter", -] - -# GitHub URLs for pre-processed data -POLSKA_GEOJSON_BASE = "https://raw.githubusercontent.com/ppatrzyk/polska-geojson/master" - -# Wikidata SPARQL endpoint -WIKIDATA_SPARQL = "https://query.wikidata.org/sparql" - -# Request timeout and retry settings -REQUEST_TIMEOUT = 180 -MAX_RETRIES = 3 -RETRY_DELAY = 5 - -# Data thresholds for filtering -MIN_PEAK_ELEVATION = 300 # meters -MIN_LAKE_AREA_KM2 = 0.5 # km² -MIN_RIVER_LENGTH_KM = 10 # km -MIN_LINE_COORDS = 2 # minimum coordinates for a line -MIN_RING_COORDS = 4 # minimum coordinates for a polygon ring - - -def _ensure_cache_dir() -> None: - """Create cache directory if it doesn't exist.""" - CACHE_DIR.mkdir(parents=True, exist_ok=True) - - -def _extract_polygonal_geometry( - geom: Polygon | MultiPolygon | GeometryCollection, -) -> Polygon | MultiPolygon | None: - """Extract only polygonal geometry from a geometry that may be mixed. - - Some OSM data comes as GeometryCollections containing polygons mixed with - lines. This function extracts only the polygon/multipolygon parts. - - Args: - geom: Input geometry (Polygon, MultiPolygon, or GeometryCollection). - - Returns: - Polygon or MultiPolygon with only the polygonal parts, or None if empty. - """ - if isinstance(geom, Polygon | MultiPolygon): - return geom - - if isinstance(geom, GeometryCollection): - polygons = [g for g in geom.geoms if isinstance(g, Polygon | MultiPolygon)] - if not polygons: - return None - if len(polygons) == 1: - return polygons[0] - # Flatten MultiPolygons and combine all polygons - all_polys = [] - for p in polygons: - if isinstance(p, Polygon): - all_polys.append(p) - elif isinstance(p, MultiPolygon): - all_polys.extend(p.geoms) - return MultiPolygon(all_polys) - - return None - - -def _query_wikidata(query: str) -> list[dict[str, Any]]: - """Query Wikidata SPARQL endpoint. - - Args: - query: SPARQL query string. - - Returns: - List of result bindings. - """ - response = requests.get( - WIKIDATA_SPARQL, - params={"query": query, "format": "json"}, - timeout=60, - ) - response.raise_for_status() - return response.json()["results"]["bindings"] - - -def _get_powiaty_population() -> dict[str, int]: - """Get population data for all Polish powiaty from Wikidata. - - Returns: - Dictionary mapping powiat name to population. - """ - cache_path = CACHE_DIR / "powiaty_population.json" - - if cache_path.exists(): - return json.loads(cache_path.read_text()) - - # Query Wikidata for all powiaty (Q247073) in Poland (Q36) with population - # Filter to only current Polish powiaty using country=Poland filter - query = """ - SELECT ?powiat ?powiatLabel ?population WHERE { - ?powiat wdt:P31 wd:Q247073. - ?powiat wdt:P17 wd:Q36. - ?powiat wdt:P1082 ?population. - SERVICE wikibase:label { bd:serviceParam wikibase:language "pl,en". } - } - ORDER BY DESC(?population) - """ - - sys.stdout.write("Fetching powiaty population data from Wikidata...\n") - results = _query_wikidata(query) - - population_map: dict[str, int] = {} - for item in results: - label = item.get("powiatLabel", {}).get("value", "") - pop = item.get("population", {}).get("value", "0") - if label and pop: - # Remove "powiat" prefix if present for matching - clean_label = label.replace("powiat ", "").strip() - with contextlib.suppress(ValueError): - population_map[clean_label] = int(pop) - - _ensure_cache_dir() - cache_path.write_text(json.dumps(population_map, ensure_ascii=False, indent=2)) - - sys.stdout.write(f"Cached population data for {len(population_map)} powiaty.\n") - return population_map - - -def _try_single_request( - endpoint: str, query: str -) -> tuple[dict[str, Any] | None, Exception | None]: - """Try a single request to an endpoint. - - Args: - endpoint: Overpass API endpoint URL. - query: Overpass QL query string. - - Returns: - Tuple of (result, error). One will be None. - """ - try: - sys.stdout.write(f" Querying {endpoint}...\n") - response = requests.post( - endpoint, - data={"data": query}, - timeout=REQUEST_TIMEOUT, - ) - response.raise_for_status() - result = response.json() - except (requests.RequestException, requests.Timeout, ValueError) as e: - return None, e - else: - # Check for valid response with elements - if not isinstance(result, dict) or "elements" not in result: - return None, ValueError("Invalid response format") - return result, None - - -def _overpass_query(query: str) -> dict[str, Any]: - """Execute an Overpass API query with retry logic. - - Args: - query: Overpass QL query string. - - Returns: - JSON response from the API. - - Raises: - RuntimeError: If all endpoints fail. - """ - last_error: Exception | None = None - - for endpoint in OVERPASS_ENDPOINTS: - for attempt in range(MAX_RETRIES): - result, error = _try_single_request(endpoint, query) - if result is not None: - return result - last_error = error - sys.stdout.write(f" Attempt {attempt + 1} failed: {error}\n") - if attempt < MAX_RETRIES - 1: - time.sleep(RETRY_DELAY) - - msg = f"All Overpass API endpoints failed. Last error: {last_error}" - raise RuntimeError(msg) - - -def _download_github_geojson(url: str, cache_path: Path) -> gpd.GeoDataFrame: - """Download GeoJSON from GitHub and cache it. - - Args: - url: URL to download from. - cache_path: Path to cache the data. - - Returns: - GeoDataFrame with the data. - """ - if cache_path.exists(): - return gpd.read_file(cache_path) - - sys.stdout.write(f"Downloading from {url}...\n") - if not url.startswith(("http://", "https://")): - msg = f"Unsupported URL scheme: {url}" - raise ValueError(msg) - with urlopen(url, timeout=REQUEST_TIMEOUT) as response: - data = json.loads(response.read().decode()) - - _ensure_cache_dir() - cache_path.write_text(json.dumps(data)) - - return gpd.GeoDataFrame.from_features(data["features"], crs="EPSG:4326") - - -# ============================================================================= -# Warsaw Data -# ============================================================================= - - -def get_warsaw_boundary() -> gpd.GeoDataFrame: - """Get Warsaw city boundary. - - Returns: - GeoDataFrame with Warsaw boundary polygon. - """ - cache_path = CACHE_DIR / "warsaw_boundary.geojson" - - if cache_path.exists(): - return gpd.read_file(cache_path) - - # Try to use districts file first - districts_path = ( - Path(__file__).parent - / "anki_decks" - / "warsaw_districts" - / "warszawa-dzielnice.geojson" - ) - if districts_path.exists(): - warsaw_gdf = gpd.read_file(districts_path) - warsaw_boundary = warsaw_gdf[warsaw_gdf["name"] == "Warszawa"] - if len(warsaw_boundary) == 0: - warsaw_boundary = gpd.GeoDataFrame( - geometry=[warsaw_gdf.union_all()], crs=warsaw_gdf.crs - ) - _ensure_cache_dir() - warsaw_boundary.to_file(cache_path, driver="GeoJSON") - return warsaw_boundary - - # Fallback to Overpass query - sys.stdout.write("Fetching Warsaw boundary from OpenStreetMap...\n") - query = """ - [out:json][timeout:60]; - relation["name"="Warszawa"]["admin_level"="6"]; - out geom; - """ - - data = _overpass_query(query) - - features = [] - for element in data.get("elements", []): - if element.get("type") == "relation": - coords = [] - for member in element.get("members", []): - if member.get("role") == "outer" and "geometry" in member: - coords.extend([(p["lon"], p["lat"]) for p in member["geometry"]]) - if coords: - features.append( - { - "type": "Feature", - "properties": {"name": "Warszawa"}, - "geometry": {"type": "Polygon", "coordinates": [coords]}, - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson)) - - return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - -def get_warsaw_districts() -> gpd.GeoDataFrame: - """Get Warsaw districts (dzielnice). - - Returns: - GeoDataFrame with district boundaries. - """ - districts_path = ( - Path(__file__).parent - / "anki_decks" - / "warsaw_districts" - / "warszawa-dzielnice.geojson" - ) - if districts_path.exists(): - gdf = gpd.read_file(districts_path) - return gdf[gdf["name"] != "Warszawa"].copy() - - msg = "Warsaw districts GeoJSON not found" - raise FileNotFoundError(msg) - - -def get_vistula_river() -> gpd.GeoDataFrame: - """Get Vistula river in Warsaw. - - Returns: - GeoDataFrame with river geometry. - """ - cache_path = CACHE_DIR / "warsaw_vistula.geojson" - - if cache_path.exists(): - return gpd.read_file(cache_path) - - sys.stdout.write("Fetching Vistula river data...\n") - query = """ - [out:json][timeout:60]; - area["name"="Warszawa"]["admin_level"="6"]->.warsaw; - ( - way["waterway"="river"]["name"="Wisła"](area.warsaw); - ); - out geom; - """ - - data = _overpass_query(query) - - features = [] - min_coords = 2 - for element in data.get("elements", []): - if element.get("type") == "way" and "geometry" in element: - coords = [(p["lon"], p["lat"]) for p in element["geometry"]] - if len(coords) >= min_coords: - features.append( - { - "type": "Feature", - "properties": {"name": "Wisła"}, - "geometry": {"type": "LineString", "coordinates": coords}, - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson)) - - return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - -def get_warsaw_bridges() -> gpd.GeoDataFrame: - """Get Warsaw bridges over the Vistula. - - Returns: - GeoDataFrame with bridge geometries. - """ - cache_path = CACHE_DIR / "warsaw_bridges.geojson" - - if cache_path.exists(): - return gpd.read_file(cache_path) - - sys.stdout.write("Fetching Warsaw bridges data...\n") - - # First get the Vistula to filter bridges - vistula = get_vistula_river() - vistula_union = vistula.union_all() - vistula_buffer = vistula_union.buffer(0.002) # ~200m buffer - - # Query for bridges with "Most" in name - smaller query - query = """ - [out:json][timeout:90]; - area["name"="Warszawa"]["admin_level"="6"]->.warsaw; - way["bridge"="yes"]["name"~"^Most"](area.warsaw); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - min_coords = 2 - - for element in data.get("elements", []): - if element.get("type") != "way" or "geometry" not in element: - continue - - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - coords = [(p["lon"], p["lat"]) for p in element["geometry"]] - if len(coords) < min_coords: - continue - - line = LineString(coords) - - # Check if bridge crosses/is near Vistula - if line.intersects(vistula_buffer): - seen_names.add(name) - features.append( - { - "type": "Feature", - "properties": {"name": name, "osm_id": element.get("id")}, - "geometry": {"type": "LineString", "coordinates": coords}, - } - ) - - # Merge segments of the same bridge - merged_features = _merge_bridge_segments(features) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": merged_features} - cache_path.write_text(json.dumps(geojson)) - - sys.stdout.write(f"Cached {len(merged_features)} bridges.\n") - return gpd.GeoDataFrame.from_features(merged_features, crs="EPSG:4326") - - -def _merge_bridge_segments(features: list[dict]) -> list[dict]: - """Merge bridge segments with the same name. - - Args: - features: List of GeoJSON features. - - Returns: - List of merged features. - """ - by_name: dict[str, list[list[tuple[float, float]]]] = {} - - for feature in features: - name = feature["properties"]["name"] - coords = feature["geometry"]["coordinates"] - if name not in by_name: - by_name[name] = [] - by_name[name].append(coords) - - merged = [] - for name, coord_lists in by_name.items(): - if len(coord_lists) == 1: - geom = {"type": "LineString", "coordinates": coord_lists[0]} - else: - geom = {"type": "MultiLineString", "coordinates": coord_lists} - - merged.append( - {"type": "Feature", "properties": {"name": name}, "geometry": geom} - ) - - return merged - - -def get_warsaw_metro_stations() -> gpd.GeoDataFrame: - """Get Warsaw metro stations with line information. - - Returns: - GeoDataFrame with station points and line info (M1, M2, or M1/M2). - """ - cache_path = CACHE_DIR / "warsaw_metro.geojson" - - if cache_path.exists(): - return gpd.read_file(cache_path) - - # Known stations for each line (as of 2024) - m1_stations = { - "Kabaty", - "Natolin", - "Imielin", - "Stokłosy", - "Ursynów", - "Służew", - "Wilanowska", - "Wierzbno", - "Racławicka", - "Pole Mokotowskie", - "Politechnika", - "Centrum", - "Świętokrzyska", # Also M2 - "Ratusz-Arsenał", - "Dworzec Gdański", - "Plac Wilsona", - "Marymont", - "Słodowiec", - "Stare Bielany", - "Wawrzyszew", - "Młociny", - } - m2_stations = { - "Bródno", - "Kondratowicza", - "Zacisze", - "Targówek Mieszkaniowy", - "Trocka", - "Szwedzka", - "Dworzec Wileński", - "Świętokrzyska", # Also M1 - "Nowy Świat-Uniwersytet", - "Centrum Nauki Kopernik", - "Stadion Narodowy", - "Rondo ONZ", - "Rondo Daszyńskiego", - "Płocka", - "Młynów", - "Księcia Janusza", - "Ulrychów", - "Bemowo", - } - - sys.stdout.write("Fetching metro station data...\n") - query = """ - [out:json][timeout:60]; - area["name"="Warszawa"]["admin_level"="6"]->.warsaw; - ( - node["railway"="station"]["station"="subway"](area.warsaw); - node["railway"="station"]["network"~"Metro"](area.warsaw); - ); - out body; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - - for element in data.get("elements", []): - if element.get("type") == "node": - name = element.get("tags", {}).get("name", "") - if name and name not in seen_names: - seen_names.add(name) - # Determine line from known station lists - in_m1 = name in m1_stations - in_m2 = name in m2_stations - if in_m1 and in_m2: - line = "M1/M2" - elif in_m1: - line = "M1" - elif in_m2: - line = "M2" - else: - line = "?" # Unknown station - - features.append( - { - "type": "Feature", - "properties": { - "name": name, - "line": line, - }, - "geometry": { - "type": "Point", - "coordinates": [element["lon"], element["lat"]], - }, - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson)) - - sys.stdout.write(f"Cached {len(features)} metro stations.\n") - return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - -def get_warsaw_streets(min_length: int = 500) -> gpd.GeoDataFrame: - """Get major Warsaw streets. - - Args: - min_length: Minimum street length in meters. - - Returns: - GeoDataFrame with street geometries. - """ - cache_path = CACHE_DIR / "warsaw_streets.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - # Filter by length if needed - return _filter_streets_by_length(gdf, min_length) - - sys.stdout.write("Fetching street data from OpenStreetMap...\n") - query = """ - [out:json][timeout:120]; - area["name"="Warszawa"]["admin_level"="6"]->.warsaw; - ( - way["highway"="primary"]["name"](area.warsaw); - way["highway"="secondary"]["name"](area.warsaw); - way["highway"="tertiary"]["name"](area.warsaw); - ); - out geom; - """ - - data = _overpass_query(query) - - features = [] - min_coords = 2 - - for element in data.get("elements", []): - if element.get("type") == "way" and "geometry" in element: - coords = [(p["lon"], p["lat"]) for p in element["geometry"]] - if len(coords) >= min_coords: - features.append( - { - "type": "Feature", - "properties": { - "name": element.get("tags", {}).get("name", "Unknown"), - "highway": element.get("tags", {}).get("highway", ""), - }, - "geometry": {"type": "LineString", "coordinates": coords}, - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson)) - - sys.stdout.write(f"Cached {len(features)} street segments.\n") - - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - return _filter_streets_by_length(gdf, min_length) - - -def _filter_streets_by_length( - gdf: gpd.GeoDataFrame, min_length: int -) -> gpd.GeoDataFrame: - """Filter and merge streets by name, keeping only those above min_length. - - Args: - gdf: GeoDataFrame with street segments. - min_length: Minimum length in meters. - - Returns: - GeoDataFrame with merged streets, sorted by length (longest first). - """ - # Group by street name - streets: dict[str, list] = {} - for _, row in gdf.iterrows(): - name = row.get("name", "Unknown") - if name and name != "Unknown": - if name not in streets: - streets[name] = [] - streets[name].append(row.geometry) - - # Merge and filter - result_rows = [] - for name, geometries in streets.items(): - merged = geometries[0] if len(geometries) == 1 else MultiLineString(geometries) - - # Create temp GeoDataFrame for length calculation - temp_gdf = gpd.GeoDataFrame(geometry=[merged], crs="EPSG:4326") - temp_proj = temp_gdf.to_crs("EPSG:2180") # Polish coordinate system - length = temp_proj.geometry.length.iloc[0] - - if length >= min_length: - result_rows.append({"name": name, "geometry": merged, "length_m": length}) - - # Sort by length (longest first) - result_rows.sort(key=lambda x: x["length_m"], reverse=True) - - return gpd.GeoDataFrame(result_rows, crs="EPSG:4326") - - -def get_warsaw_landmarks() -> gpd.GeoDataFrame: - """Get Warsaw landmarks (museums, monuments, parks, etc.). - - Returns: - GeoDataFrame with landmark points. - """ - cache_path = CACHE_DIR / "warsaw_landmarks.geojson" - - if cache_path.exists(): - return gpd.read_file(cache_path) - - sys.stdout.write("Fetching landmark data...\n") - # Simplified query - just museums and major attractions - query = """ - [out:json][timeout:60]; - area["name"="Warszawa"]["admin_level"="6"]->.warsaw; - ( - node["tourism"="museum"]["name"](area.warsaw); - node["tourism"="attraction"]["name"](area.warsaw); - node["historic"="monument"]["name"](area.warsaw); - way["tourism"="museum"]["name"](area.warsaw); - way["tourism"="attraction"]["name"](area.warsaw); - ); - out center; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - - for element in data.get("elements", []): - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - # Get coordinates - if element.get("type") == "node": - lon, lat = element["lon"], element["lat"] - elif "center" in element: - lon, lat = element["center"]["lon"], element["center"]["lat"] - else: - continue - - seen_names.add(name) - landmark_type = ( - element.get("tags", {}).get("tourism") - or element.get("tags", {}).get("historic") - or element.get("tags", {}).get("leisure") - or "landmark" - ) - - features.append( - { - "type": "Feature", - "properties": {"name": name, "type": landmark_type}, - "geometry": {"type": "Point", "coordinates": [lon, lat]}, - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson)) - - sys.stdout.write(f"Cached {len(features)} landmarks.\n") - - if not features: - return gpd.GeoDataFrame( - {"name": [], "type": [], "geometry": []}, crs="EPSG:4326" - ) - return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - -def _extract_osiedla_rings( - element: dict[str, Any], min_coords: int -) -> tuple[list[list[tuple[float, float]]], list[list[tuple[float, float]]]]: - """Extract outer and inner rings from an OSM relation. - - Args: - element: OSM relation element. - min_coords: Minimum number of coordinates for a valid ring. - - Returns: - Tuple of (outer_rings, inner_rings). - """ - outer_rings: list[list[tuple[float, float]]] = [] - inner_rings: list[list[tuple[float, float]]] = [] - - for member in element.get("members", []): - if "geometry" not in member: - continue - ring = [(p["lon"], p["lat"]) for p in member["geometry"]] - if len(ring) < min_coords: - continue - # Close the ring if not closed - if ring[0] != ring[-1]: - ring.append(ring[0]) - if member.get("role") == "outer": - outer_rings.append(ring) - elif member.get("role") == "inner": - inner_rings.append(ring) - - return outer_rings, inner_rings - - -def _build_osiedla_geometry( - outer_rings: list[list[tuple[float, float]]], - inner_rings: list[list[tuple[float, float]]], -) -> dict[str, Any]: - """Build GeoJSON geometry from outer and inner rings. - - Args: - outer_rings: List of outer ring coordinates. - inner_rings: List of inner ring coordinates. - - Returns: - GeoJSON geometry dict. - """ - if len(outer_rings) == 1: - return { - "type": "Polygon", - "coordinates": [outer_rings[0], *inner_rings], - } - # Multiple outer rings - create MultiPolygon - # Each polygon in a MultiPolygon is [exterior, hole1, hole2, ...] - return { - "type": "MultiPolygon", - "coordinates": [[ring] for ring in outer_rings], - } - - -def _extract_polygon_from_element( - element: dict[str, Any], -) -> dict[str, Any] | None: - """Extract polygon geometry from an OSM relation or way element. - - Args: - element: OSM element (relation or way). - - Returns: - GeoJSON geometry dict, or None if extraction fails. - """ - if element.get("type") == "relation": - outer_rings, inner_rings = _extract_osiedla_rings(element, MIN_RING_COORDS) - if not outer_rings: - return None - return _build_osiedla_geometry(outer_rings, inner_rings) - - if element.get("type") == "way" and "geometry" in element: - coords = [(p["lon"], p["lat"]) for p in element["geometry"]] - if len(coords) < MIN_RING_COORDS: - return None - if coords[0] != coords[-1]: - coords.append(coords[0]) - return {"type": "Polygon", "coordinates": [coords]} - - return None - - -def _extract_line_from_way(element: dict[str, Any]) -> dict[str, Any] | None: - """Extract line geometry from an OSM way element. - - Args: - element: OSM way element. - - Returns: - GeoJSON LineString geometry dict, or None if extraction fails. - """ - if element.get("type") != "way" or "geometry" not in element: - return None - - coords = [(p["lon"], p["lat"]) for p in element["geometry"]] - if len(coords) < MIN_LINE_COORDS: - return None - - return {"type": "LineString", "coordinates": coords} - - -def _add_area_column(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: - """Add area_km2 column to a GeoDataFrame. - - Args: - gdf: GeoDataFrame with polygon geometries. - - Returns: - GeoDataFrame with area_km2 column added. - """ - if len(gdf) == 0: - return gdf - gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system - gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000 - return gdf - - -def _add_length_column(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: - """Add length_km column to a GeoDataFrame. - - Args: - gdf: GeoDataFrame with line geometries. - - Returns: - GeoDataFrame with length_km column added. - """ - if len(gdf) == 0: - return gdf - gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system - gdf["length_km"] = gdf_proj.geometry.length / 1000 - return gdf - - -def _extract_coastal_geometry( - element: dict[str, Any], - natural_type: str, - line_types: tuple[str, ...], -) -> dict[str, Any] | None: - """Extract geometry from a coastal feature element. - - For cliffs and beaches, returns LineString. For others, returns Polygon. - - Args: - element: OSM element. - natural_type: The natural= tag value. - line_types: Tuple of natural types that should be lines. - - Returns: - GeoJSON geometry dict, or None if extraction fails. - """ - if element.get("type") == "relation": - return _extract_polygon_from_element(element) - - if element.get("type") != "way" or "geometry" not in element: - return None - - coords = [(p["lon"], p["lat"]) for p in element["geometry"]] - if len(coords) < MIN_LINE_COORDS: - return None - - # For cliffs and beaches, keep as linestring - if natural_type in line_types: - return {"type": "LineString", "coordinates": coords} - - # Otherwise try to make a polygon - if len(coords) >= MIN_RING_COORDS: - if coords[0] != coords[-1]: - coords.append(coords[0]) - return {"type": "Polygon", "coordinates": [coords]} - - return None - - -def get_warsaw_osiedla() -> gpd.GeoDataFrame: - """Get Warsaw osiedla (neighborhoods). - - Returns: - GeoDataFrame with osiedla boundaries. - """ - cache_path = CACHE_DIR / "warsaw_osiedla.geojson" - - if cache_path.exists(): - return gpd.read_file(cache_path) - - sys.stdout.write("Fetching osiedla data...\n") - query = """ - [out:json][timeout:180]; - area["name"="Warszawa"]["admin_level"="6"]->.warsaw; - relation["boundary"="administrative"]["admin_level"="11"]["name"](area.warsaw); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - min_ring_coords = 4 - - for element in data.get("elements", []): - if element.get("type") != "relation": - continue - - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) - if not outer_rings: - continue - - seen_names.add(name) - features.append( - { - "type": "Feature", - "properties": {"name": name}, - "geometry": _build_osiedla_geometry(outer_rings, inner_rings), - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson)) - - sys.stdout.write(f"Cached {len(features)} osiedla.\n") - return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - -# ============================================================================= -# Poland Data -# ============================================================================= - - -def get_polish_wojewodztwa() -> gpd.GeoDataFrame: - """Get Polish województwa (voivodeships). - - Returns: - GeoDataFrame with województwa boundaries. - """ - url = f"{POLSKA_GEOJSON_BASE}/wojewodztwa/wojewodztwa-min.geojson" - cache_path = CACHE_DIR / "polish_wojewodztwa.geojson" - return _download_github_geojson(url, cache_path) - - -def get_polish_powiaty() -> gpd.GeoDataFrame: - """Get Polish powiaty (counties), sorted by population descending. - - Returns: - GeoDataFrame with powiat boundaries and population. - """ - url = f"{POLSKA_GEOJSON_BASE}/powiaty/powiaty-min.geojson" - cache_path = CACHE_DIR / "polish_powiaty.geojson" - gdf = _download_github_geojson(url, cache_path) - - # Get population data from Wikidata - population_map = _get_powiaty_population() - - # Add population column - def get_population(nazwa: str) -> int: - """Match powiat name to population data.""" - if not nazwa: - return 0 - # Remove "powiat " prefix for matching - clean_name = nazwa.replace("powiat ", "").strip() - # Try direct match - if clean_name in population_map: - return population_map[clean_name] - # Try lowercase - name_lower = clean_name.lower() - for pop_name, pop in population_map.items(): - if pop_name.lower() == name_lower: - return pop - return 0 - - gdf["population"] = gdf["nazwa"].apply(get_population) - - # Sort by population descending - return gdf.sort_values("population", ascending=False).reset_index(drop=True) - - -def get_polish_gminy() -> gpd.GeoDataFrame: - """Get Polish gminy (municipalities) from OSM, sorted by area descending. - - Returns: - GeoDataFrame with gminy boundaries. - """ - cache_path = CACHE_DIR / "polish_gminy.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - if "area_km2" in gdf.columns: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write("Fetching gminy data from OSM (this may take a while)...\n") - # Polish gminy are admin_level=7 in OSM - query = """ - [out:json][timeout:300]; - area["ISO3166-1"="PL"]->.pl; - relation["boundary"="administrative"]["admin_level"="7"]["name"](area.pl); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - min_ring_coords = 4 - - for element in data.get("elements", []): - if element.get("type") != "relation": - continue - - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) - if not outer_rings: - continue - - seen_names.add(name) - features.append( - { - "type": "Feature", - "properties": {"name": name}, - "geometry": _build_osiedla_geometry(outer_rings, inner_rings), - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson)) - - sys.stdout.write(f"Cached {len(features)} gminy.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - # Add area column - gdf = _add_area_column(gdf) - - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - - -def get_poland_boundary() -> gpd.GeoDataFrame: - """Get Poland country boundary. - - Returns: - GeoDataFrame with Poland boundary. - """ - cache_path = CACHE_DIR / "poland_boundary.geojson" - - if cache_path.exists(): - return gpd.read_file(cache_path) - - # Dissolve from województwa - woj = get_polish_wojewodztwa() - # Fix invalid geometries with buffer(0) - woj["geometry"] = woj["geometry"].buffer(0) - poland = gpd.GeoDataFrame(geometry=[woj.union_all()], crs=woj.crs) - - _ensure_cache_dir() - poland.to_file(cache_path, driver="GeoJSON") - - return poland - - -# ============================================================================= -# Polish Natural Features -# ============================================================================= - - -def get_polish_mountain_peaks() -> gpd.GeoDataFrame: - """Get Polish mountain peaks, sorted by elevation descending. - - Returns: - GeoDataFrame with mountain peak points and elevation. - """ - cache_path = CACHE_DIR / "polish_mountain_peaks.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - return gdf.sort_values("elevation", ascending=False).reset_index(drop=True) - - sys.stdout.write("Fetching mountain peaks data from OSM...\n") - query = """ - [out:json][timeout:120]; - area["ISO3166-1"="PL"]->.pl; - ( - node["natural"="peak"]["name"]["ele"](area.pl); - ); - out; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - - for element in data.get("elements", []): - if element.get("type") != "node": - continue - - name = element.get("tags", {}).get("name", "") - ele_str = element.get("tags", {}).get("ele", "") - - if not name or not ele_str or name in seen_names: - continue - - with contextlib.suppress(ValueError): - elevation = float(ele_str.replace(",", ".").split()[0]) - if elevation < MIN_PEAK_ELEVATION: - continue - - seen_names.add(name) - features.append( - { - "type": "Feature", - "properties": {"name": name, "elevation": elevation}, - "geometry": { - "type": "Point", - "coordinates": [element["lon"], element["lat"]], - }, - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} mountain peaks.\n") - - if not features: - msg = "No mountain peaks found in OSM data" - raise ValueError(msg) - - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - return gdf.sort_values("elevation", ascending=False).reset_index(drop=True) - - -def get_polish_mountain_ranges() -> gpd.GeoDataFrame: - """Get Polish mountain ranges, sorted by area descending. - - Returns: - GeoDataFrame with mountain range polygons. - """ - cache_path = CACHE_DIR / "polish_mountain_ranges.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - # Fix invalid geometries from OSM data and extract only polygons - gdf["geometry"] = gdf.geometry.make_valid() - gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry) - gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty] - if "area_km2" in gdf.columns: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write("Fetching mountain ranges data from OSM...\n") - query = """ - [out:json][timeout:180]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["natural"="mountain_range"]["name"](area.pl); - way["natural"="mountain_range"]["name"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - min_ring_coords = 4 - - for element in data.get("elements", []): - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - if element.get("type") == "relation": - outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) - if not outer_rings: - continue - geometry = _build_osiedla_geometry(outer_rings, inner_rings) - elif element.get("type") == "way" and "geometry" in element: - coords = [(p["lon"], p["lat"]) for p in element["geometry"]] - if len(coords) < min_ring_coords: - continue - if coords[0] != coords[-1]: - coords.append(coords[0]) - geometry = {"type": "Polygon", "coordinates": [coords]} - else: - continue - - seen_names.add(name) - features.append( - {"type": "Feature", "properties": {"name": name}, "geometry": geometry} - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} mountain ranges.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - # Fix invalid geometries from OSM data and extract only polygons - gdf["geometry"] = gdf.geometry.make_valid() - gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry) - gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty] - - # Calculate area in km² - gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system - gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000 - - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - - -def get_polish_national_parks() -> gpd.GeoDataFrame: - """Get all 23 Polish national parks, sorted by area descending. - - Returns: - GeoDataFrame with national park polygons. - """ - cache_path = CACHE_DIR / "polish_national_parks.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - if "area_km2" in gdf.columns: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write("Fetching national parks data from OSM...\n") - query = """ - [out:json][timeout:180]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["boundary"="national_park"]["name"](area.pl); - relation["leisure"="nature_reserve"]["name"]["protect_class"="2"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - min_ring_coords = 4 - - for element in data.get("elements", []): - if element.get("type") != "relation": - continue - - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - # Filter to only include "Park Narodowy" in name - if "Narodowy" not in name and "narodowy" not in name.lower(): - continue - - outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) - if not outer_rings: - continue - - seen_names.add(name) - features.append( - { - "type": "Feature", - "properties": {"name": name}, - "geometry": _build_osiedla_geometry(outer_rings, inner_rings), - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} national parks.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - # Calculate area in km² - gdf_proj = gdf.to_crs("EPSG:2180") - gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000 - - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - - -def get_polish_forests() -> gpd.GeoDataFrame: - """Get major Polish forests (puszcze), sorted by area descending. - - Returns: - GeoDataFrame with forest polygons. - """ - cache_path = CACHE_DIR / "polish_forests.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - if "area_km2" in gdf.columns: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write("Fetching forests data from OSM...\n") - # Query for named forests, especially "Puszcza" type - query = """ - [out:json][timeout:300]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["natural"="wood"]["name"](area.pl); - relation["landuse"="forest"]["name"~"Puszcza|Bory|Las"](area.pl); - way["natural"="wood"]["name"~"Puszcza|Bory"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - forest_keywords = ("Puszcza", "Bory", "Las ", "Lasy ") - - features = [] - seen_names: set[str] = set() - - for element in data.get("elements", []): - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - if not any(keyword in name for keyword in forest_keywords): - continue - - geometry = _extract_polygon_from_element(element) - if geometry is None: - continue - - seen_names.add(name) - features.append( - {"type": "Feature", "properties": {"name": name}, "geometry": geometry} - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} forests.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - gdf = _add_area_column(gdf) - - if len(gdf) > 0: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - -def get_polish_lakes() -> gpd.GeoDataFrame: - """Get Polish lakes, sorted by area descending. - - Returns: - GeoDataFrame with lake polygons. - """ - cache_path = CACHE_DIR / "polish_lakes.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - if "area_km2" in gdf.columns: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write("Fetching lakes data from OSM...\n") - query = """ - [out:json][timeout:300]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["natural"="water"]["water"="lake"]["name"](area.pl); - way["natural"="water"]["water"="lake"]["name"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - - for element in data.get("elements", []): - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - geometry = _extract_polygon_from_element(element) - if geometry is None: - continue - - seen_names.add(name) - features.append( - {"type": "Feature", "properties": {"name": name}, "geometry": geometry} - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} lakes.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - gdf = _add_area_column(gdf) - - if len(gdf) > 0: - # Filter to lakes > MIN_LAKE_AREA_KM2 to exclude tiny ponds - gdf = gdf[gdf["area_km2"] > MIN_LAKE_AREA_KM2] - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - - return gdf - - -def _extract_river_coords_from_element( - element: dict[str, Any], -) -> list[list[tuple[float, float]]]: - """Extract coordinate lists from a river element. - - Args: - element: OSM element (way or relation). - - Returns: - List of coordinate lists (line segments). - """ - coord_lists: list[list[tuple[float, float]]] = [] - - if element.get("type") == "way" and "geometry" in element: - coords = [(p["lon"], p["lat"]) for p in element["geometry"]] - if len(coords) >= MIN_LINE_COORDS: - coord_lists.append(coords) - elif element.get("type") == "relation": - for member in element.get("members", []): - if member.get("type") == "way" and "geometry" in member: - coords = [(p["lon"], p["lat"]) for p in member["geometry"]] - if len(coords) >= MIN_LINE_COORDS: - coord_lists.append(coords) - - return coord_lists - - -def get_polish_rivers() -> gpd.GeoDataFrame: - """Get Polish rivers, sorted by length descending. - - Rivers with the same name but in different locations are kept separate - by using unique IDs from OSM when available. - - Returns: - GeoDataFrame with river linestrings. - """ - cache_path = CACHE_DIR / "polish_rivers.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - if "length_km" in gdf.columns: - return gdf.sort_values("length_km", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write("Fetching rivers data from OSM...\n") - query = """ - [out:json][timeout:300]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["waterway"="river"]["name"](area.pl); - way["waterway"="river"]["name"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - - # Group ways by river name AND wikidata ID (or OSM ID for uniqueness) - # This prevents merging different rivers with the same name - rivers_by_key: dict[str, list[list[tuple[float, float]]]] = {} - river_names: dict[str, str] = {} # key -> display name - - for element in data.get("elements", []): - name = element.get("tags", {}).get("name", "") - if not name: - continue - - # Use wikidata ID if available, otherwise use element type+id - wikidata = element.get("tags", {}).get("wikidata", "") - if wikidata: - key = f"{name}_{wikidata}" - else: - # Fall back to element ID for grouping related ways - key = f"{name}_{element.get('type')}_{element.get('id')}" - - coord_lists = _extract_river_coords_from_element(element) - if coord_lists: - rivers_by_key.setdefault(key, []).extend(coord_lists) - river_names[key] = name - - features = [] - for key, coord_lists in rivers_by_key.items(): - name = river_names[key] - geometry: dict[str, Any] - if len(coord_lists) == 1: - geometry = {"type": "LineString", "coordinates": coord_lists[0]} - else: - geometry = {"type": "MultiLineString", "coordinates": coord_lists} - - features.append( - {"type": "Feature", "properties": {"name": name}, "geometry": geometry} - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} rivers.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - gdf = _add_length_column(gdf) - - if len(gdf) > 0: - gdf = gdf[gdf["length_km"] > MIN_RIVER_LENGTH_KM] - return gdf.sort_values("length_km", ascending=False).reset_index(drop=True) - - return gdf - - -def get_polish_islands() -> gpd.GeoDataFrame: - """Get Polish islands, sorted by area descending. - - Returns: - GeoDataFrame with island polygons. - """ - cache_path = CACHE_DIR / "polish_islands.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - if "area_km2" in gdf.columns: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write("Fetching islands data from OSM...\n") - query = """ - [out:json][timeout:180]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["place"="island"]["name"](area.pl); - way["place"="island"]["name"](area.pl); - relation["place"="islet"]["name"](area.pl); - way["place"="islet"]["name"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - - for element in data.get("elements", []): - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - geometry = _extract_polygon_from_element(element) - if geometry is None: - continue - - seen_names.add(name) - features.append( - {"type": "Feature", "properties": {"name": name}, "geometry": geometry} - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} islands.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - gdf = _add_area_column(gdf) - - if len(gdf) > 0: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - -def get_polish_coastal_features() -> gpd.GeoDataFrame: - """Get Polish coastal features (peninsulas, spits, cliffs), sorted by length. - - Returns: - GeoDataFrame with coastal feature geometries. - """ - cache_path = CACHE_DIR / "polish_coastal_features.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - if "length_km" in gdf.columns: - return gdf.sort_values("length_km", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write("Fetching coastal features data from OSM...\n") - query = """ - [out:json][timeout:180]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["natural"="peninsula"]["name"](area.pl); - way["natural"="peninsula"]["name"](area.pl); - relation["natural"="spit"]["name"](area.pl); - way["natural"="spit"]["name"](area.pl); - relation["natural"="cliff"]["name"](area.pl); - way["natural"="cliff"]["name"](area.pl); - relation["natural"="coastline"]["name"](area.pl); - way["natural"="beach"]["name"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - line_types = ("cliff", "beach", "coastline") - - features = [] - seen_names: set[str] = set() - - for element in data.get("elements", []): - name = element.get("tags", {}).get("name", "") - natural_type = element.get("tags", {}).get("natural", "") - if not name or name in seen_names: - continue - - geometry = _extract_coastal_geometry(element, natural_type, line_types) - if geometry is None: - continue - - seen_names.add(name) - features.append( - { - "type": "Feature", - "properties": {"name": name, "type": natural_type}, - "geometry": geometry, - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} coastal features.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - gdf = _add_length_column(gdf) - - if len(gdf) > 0: - return gdf.sort_values("length_km", ascending=False).reset_index(drop=True) - return gdf - - -def get_polish_unesco_sites() -> gpd.GeoDataFrame: - """Get Polish UNESCO World Heritage Sites, sorted by inscription year. - - Returns: - GeoDataFrame with UNESCO site geometries. - """ - cache_path = CACHE_DIR / "polish_unesco_sites.geojson" - - if cache_path.exists(): - return gpd.read_file(cache_path) - - sys.stdout.write("Fetching UNESCO sites data from OSM...\n") - query = """ - [out:json][timeout:180]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["heritage"="world_heritage_site"]["name"](area.pl); - way["heritage"="world_heritage_site"]["name"](area.pl); - node["heritage"="world_heritage_site"]["name"](area.pl); - relation["heritage:operator"="whc"]["name"](area.pl); - way["heritage:operator"="whc"]["name"](area.pl); - node["heritage:operator"="whc"]["name"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - min_ring_coords = 4 - - for element in data.get("elements", []): - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - if element.get("type") == "node": - geometry = { - "type": "Point", - "coordinates": [element["lon"], element["lat"]], - } - elif element.get("type") == "relation": - outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) - if not outer_rings: - continue - geometry = _build_osiedla_geometry(outer_rings, inner_rings) - elif element.get("type") == "way" and "geometry" in element: - coords = [(p["lon"], p["lat"]) for p in element["geometry"]] - if len(coords) < min_ring_coords: - continue - if coords[0] != coords[-1]: - coords.append(coords[0]) - geometry = {"type": "Polygon", "coordinates": [coords]} - else: - continue - - seen_names.add(name) - features.append( - {"type": "Feature", "properties": {"name": name}, "geometry": geometry} - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} UNESCO sites.\n") - return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - -def get_polish_nature_reserves() -> gpd.GeoDataFrame: - """Get Polish nature reserves, sorted by area descending. - - Returns: - GeoDataFrame with nature reserve polygons. - """ - cache_path = CACHE_DIR / "polish_nature_reserves.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - if "area_km2" in gdf.columns: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write( - "Fetching nature reserves data from OSM (this may take a while)...\n" - ) - query = """ - [out:json][timeout:600]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["leisure"="nature_reserve"]["name"](area.pl); - way["leisure"="nature_reserve"]["name"](area.pl); - relation["boundary"="protected_area"]["protect_class"="4"]["name"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - - for element in data.get("elements", []): - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - geometry = _extract_polygon_from_element(element) - if geometry is None: - continue - - seen_names.add(name) - features.append( - {"type": "Feature", "properties": {"name": name}, "geometry": geometry} - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} nature reserves.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - gdf = _add_area_column(gdf) - - if len(gdf) > 0: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - -def get_polish_landscape_parks() -> gpd.GeoDataFrame: - """Get Polish landscape parks, sorted by area descending. - - Returns: - GeoDataFrame with landscape park polygons. - """ - cache_path = CACHE_DIR / "polish_landscape_parks.geojson" - - if cache_path.exists(): - gdf = gpd.read_file(cache_path) - # Fix invalid geometries from OSM data and extract only polygons - gdf["geometry"] = gdf.geometry.make_valid() - gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry) - # Remove any rows where geometry extraction failed - gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty] - if "area_km2" in gdf.columns: - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - return gdf - - sys.stdout.write("Fetching landscape parks data from OSM...\n") - query = """ - [out:json][timeout:300]; - area["ISO3166-1"="PL"]->.pl; - ( - relation["boundary"="protected_area"]["protect_class"="5"]["name"](area.pl); - relation["leisure"="nature_reserve"]["name"~"Park Krajobrazowy"](area.pl); - ); - out geom; - """ - - data = _overpass_query(query) - - features = [] - seen_names: set[str] = set() - min_ring_coords = 4 - - for element in data.get("elements", []): - if element.get("type") != "relation": - continue - - name = element.get("tags", {}).get("name", "") - if not name or name in seen_names: - continue - - outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) - if not outer_rings: - continue - - seen_names.add(name) - features.append( - { - "type": "Feature", - "properties": {"name": name}, - "geometry": _build_osiedla_geometry(outer_rings, inner_rings), - } - ) - - _ensure_cache_dir() - geojson = {"type": "FeatureCollection", "features": features} - cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) - - sys.stdout.write(f"Cached {len(features)} landscape parks.\n") - gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") - - # Fix invalid geometries from OSM data and extract only polygons - gdf["geometry"] = gdf.geometry.make_valid() - gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry) - # Remove any rows where geometry extraction failed - gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty] - - if len(gdf) > 0: - gdf_proj = gdf.to_crs("EPSG:2180") - gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000 - return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) - - return gdf - - -# ============================================================================= -# Utility Functions -# ============================================================================= - - -def download_all_warsaw_data() -> None: - """Download and cache all Warsaw geographic data. - - Call this once to pre-populate the cache. - """ - sys.stdout.write("Downloading all Warsaw geographic data...\n") - sys.stdout.write("=" * 60 + "\n") - - sys.stdout.write("\n1. Warsaw boundary...\n") - get_warsaw_boundary() - - sys.stdout.write("\n2. Vistula river...\n") - get_vistula_river() - - sys.stdout.write("\n3. Warsaw bridges...\n") - get_warsaw_bridges() - - sys.stdout.write("\n4. Metro stations...\n") - get_warsaw_metro_stations() - - sys.stdout.write("\n5. Major streets...\n") - get_warsaw_streets() - - sys.stdout.write("\n6. Landmarks...\n") - get_warsaw_landmarks() - - sys.stdout.write("\n7. Osiedla...\n") - get_warsaw_osiedla() - - sys.stdout.write("\n" + "=" * 60 + "\n") - sys.stdout.write("All Warsaw data cached successfully!\n") - - -def download_all_poland_data() -> None: - """Download and cache all Poland geographic data. - - Call this once to pre-populate the cache. - """ - sys.stdout.write("Downloading all Poland geographic data...\n") - sys.stdout.write("=" * 60 + "\n") - - sys.stdout.write("\n1. Województwa...\n") - get_polish_wojewodztwa() - - sys.stdout.write("\n2. Powiaty...\n") - get_polish_powiaty() - - sys.stdout.write("\n3. Gminy (this may take a while)...\n") - get_polish_gminy() - - sys.stdout.write("\n4. Poland boundary...\n") - get_poland_boundary() - - sys.stdout.write("\n" + "=" * 60 + "\n") - sys.stdout.write("All Poland data cached successfully!\n") - - -def clear_cache() -> None: - """Clear all cached data.""" - if CACHE_DIR.exists(): - shutil.rmtree(CACHE_DIR) - sys.stdout.write("Cache cleared.\n") - else: - sys.stdout.write("Cache directory does not exist.\n") - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Manage geographic data cache") - parser.add_argument( - "--download-warsaw", - action="store_true", - help="Download all Warsaw data", - ) - parser.add_argument( - "--download-poland", - action="store_true", - help="Download all Poland data", - ) - parser.add_argument( - "--download-all", - action="store_true", - help="Download all data", - ) - parser.add_argument( - "--clear-cache", - action="store_true", - help="Clear cached data", - ) - - args = parser.parse_args() - - if args.clear_cache: - clear_cache() - elif args.download_warsaw: - download_all_warsaw_data() - elif args.download_poland: - download_all_poland_data() - elif args.download_all: - download_all_warsaw_data() - download_all_poland_data() - else: - parser.print_help() diff --git a/python_pkg/geo_data/__init__.py b/python_pkg/geo_data/__init__.py new file mode 100644 index 0000000..1eb5745 --- /dev/null +++ b/python_pkg/geo_data/__init__.py @@ -0,0 +1,206 @@ +"""Shared geographic data module for Warsaw and Poland Anki generators. + +This module handles downloading and caching geographic data from various sources: +- OpenStreetMap via Overpass API +- Geofabrik OSM extracts +- GitHub repositories with pre-processed GeoJSON + +All data is cached locally to avoid repeated downloads. +""" + +from __future__ import annotations + +import shutil +import sys + +from python_pkg.geo_data._common import ( + CACHE_DIR, + MAX_RETRIES, + MIN_LAKE_AREA_KM2, + MIN_LINE_COORDS, + MIN_PEAK_ELEVATION, + MIN_RING_COORDS, + MIN_RIVER_LENGTH_KM, + OVERPASS_ENDPOINTS, + POLSKA_GEOJSON_BASE, + REQUEST_TIMEOUT, + RETRY_DELAY, + WIKIDATA_SPARQL, +) +from python_pkg.geo_data._poland_admin import ( + get_poland_boundary, + get_polish_gminy, + get_polish_powiaty, + get_polish_wojewodztwa, +) +from python_pkg.geo_data._poland_nature import ( + get_polish_forests, + get_polish_landscape_parks, + get_polish_mountain_peaks, + get_polish_mountain_ranges, + get_polish_national_parks, + get_polish_nature_reserves, +) +from python_pkg.geo_data._poland_water import ( + get_polish_coastal_features, + get_polish_islands, + get_polish_lakes, + get_polish_rivers, + get_polish_unesco_sites, +) +from python_pkg.geo_data._warsaw import ( + get_vistula_river, + get_warsaw_boundary, + get_warsaw_bridges, + get_warsaw_districts, + get_warsaw_metro_stations, + get_warsaw_osiedla, +) +from python_pkg.geo_data._warsaw_places import get_warsaw_landmarks, get_warsaw_streets + +__all__ = [ + "CACHE_DIR", + "MAX_RETRIES", + "MIN_LAKE_AREA_KM2", + "MIN_LINE_COORDS", + "MIN_PEAK_ELEVATION", + "MIN_RING_COORDS", + "MIN_RIVER_LENGTH_KM", + "OVERPASS_ENDPOINTS", + "POLSKA_GEOJSON_BASE", + "REQUEST_TIMEOUT", + "RETRY_DELAY", + "WIKIDATA_SPARQL", + "clear_cache", + "download_all_poland_data", + "download_all_warsaw_data", + "get_poland_boundary", + "get_polish_coastal_features", + "get_polish_forests", + "get_polish_gminy", + "get_polish_islands", + "get_polish_lakes", + "get_polish_landscape_parks", + "get_polish_mountain_peaks", + "get_polish_mountain_ranges", + "get_polish_national_parks", + "get_polish_nature_reserves", + "get_polish_powiaty", + "get_polish_rivers", + "get_polish_unesco_sites", + "get_polish_wojewodztwa", + "get_vistula_river", + "get_warsaw_boundary", + "get_warsaw_bridges", + "get_warsaw_districts", + "get_warsaw_landmarks", + "get_warsaw_metro_stations", + "get_warsaw_osiedla", + "get_warsaw_streets", +] + + +def download_all_warsaw_data() -> None: + """Download and cache all Warsaw geographic data. + + Call this once to pre-populate the cache. + """ + sys.stdout.write("Downloading all Warsaw geographic data...\n") + sys.stdout.write("=" * 60 + "\n") + + sys.stdout.write("\n1. Warsaw boundary...\n") + get_warsaw_boundary() + + sys.stdout.write("\n2. Vistula river...\n") + get_vistula_river() + + sys.stdout.write("\n3. Warsaw bridges...\n") + get_warsaw_bridges() + + sys.stdout.write("\n4. Metro stations...\n") + get_warsaw_metro_stations() + + sys.stdout.write("\n5. Major streets...\n") + get_warsaw_streets() + + sys.stdout.write("\n6. Landmarks...\n") + get_warsaw_landmarks() + + sys.stdout.write("\n7. Osiedla...\n") + get_warsaw_osiedla() + + sys.stdout.write("\n" + "=" * 60 + "\n") + sys.stdout.write("All Warsaw data cached successfully!\n") + + +def download_all_poland_data() -> None: + """Download and cache all Poland geographic data. + + Call this once to pre-populate the cache. + """ + sys.stdout.write("Downloading all Poland geographic data...\n") + sys.stdout.write("=" * 60 + "\n") + + sys.stdout.write("\n1. Województwa...\n") + get_polish_wojewodztwa() + + sys.stdout.write("\n2. Powiaty...\n") + get_polish_powiaty() + + sys.stdout.write("\n3. Gminy (this may take a while)...\n") + get_polish_gminy() + + sys.stdout.write("\n4. Poland boundary...\n") + get_poland_boundary() + + sys.stdout.write("\n" + "=" * 60 + "\n") + sys.stdout.write("All Poland data cached successfully!\n") + + +def clear_cache() -> None: + """Clear all cached data.""" + if CACHE_DIR.exists(): + shutil.rmtree(CACHE_DIR) + sys.stdout.write("Cache cleared.\n") + else: + sys.stdout.write("Cache directory does not exist.\n") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Manage geographic data cache") + parser.add_argument( + "--download-warsaw", + action="store_true", + help="Download all Warsaw data", + ) + parser.add_argument( + "--download-poland", + action="store_true", + help="Download all Poland data", + ) + parser.add_argument( + "--download-all", + action="store_true", + help="Download all data", + ) + parser.add_argument( + "--clear-cache", + action="store_true", + help="Clear cached data", + ) + + args = parser.parse_args() + + if args.clear_cache: + clear_cache() + elif args.download_warsaw: + download_all_warsaw_data() + elif args.download_poland: + download_all_poland_data() + elif args.download_all: + download_all_warsaw_data() + download_all_poland_data() + else: + parser.print_help() diff --git a/python_pkg/geo_data/_common.py b/python_pkg/geo_data/_common.py new file mode 100644 index 0000000..a30536a --- /dev/null +++ b/python_pkg/geo_data/_common.py @@ -0,0 +1,318 @@ +"""Common utilities for geographic data operations. + +Shared constants, API helpers, and geometry extraction functions used +across the geo_data package. +""" + +from __future__ import annotations + +import json +from pathlib import Path +import sys +import time +from typing import TYPE_CHECKING +from urllib.request import urlopen + +import geopandas as gpd +import requests +from shapely.geometry import ( + GeometryCollection, + MultiPolygon, + Polygon, +) + +if TYPE_CHECKING: + from typing import Any + +# Parent directory of the geo_data package (i.e. python_pkg/) +_PKG_DIR = Path(__file__).resolve().parent.parent + +# Shared cache directory for all geo data +CACHE_DIR = _PKG_DIR / "geo_cache" + +# Overpass API endpoints (multiple for redundancy) +# Note: kumi.systems is more reliable, so it's first +OVERPASS_ENDPOINTS = [ + "https://overpass.kumi.systems/api/interpreter", + "https://overpass-api.de/api/interpreter", + "https://maps.mail.ru/osm/tools/overpass/api/interpreter", +] + +# GitHub URLs for pre-processed data +POLSKA_GEOJSON_BASE = "https://raw.githubusercontent.com/ppatrzyk/polska-geojson/master" + +# Wikidata SPARQL endpoint +WIKIDATA_SPARQL = "https://query.wikidata.org/sparql" + +# Request timeout and retry settings +REQUEST_TIMEOUT = 180 +MAX_RETRIES = 3 +RETRY_DELAY = 5 + +# Data thresholds for filtering +MIN_PEAK_ELEVATION = 300 # meters +MIN_LAKE_AREA_KM2 = 0.5 # km² +MIN_RIVER_LENGTH_KM = 10 # km +MIN_LINE_COORDS = 2 # minimum coordinates for a line +MIN_RING_COORDS = 4 # minimum coordinates for a polygon ring + + +def _ensure_cache_dir() -> None: + """Create cache directory if it doesn't exist.""" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + +def _extract_polygonal_geometry( + geom: Polygon | MultiPolygon | GeometryCollection, +) -> Polygon | MultiPolygon | None: + """Extract only polygonal geometry from a geometry that may be mixed. + + Some OSM data comes as GeometryCollections containing polygons mixed with + lines. This function extracts only the polygon/multipolygon parts. + + Args: + geom: Input geometry (Polygon, MultiPolygon, or GeometryCollection). + + Returns: + Polygon or MultiPolygon with only the polygonal parts, or None if empty. + """ + if isinstance(geom, Polygon | MultiPolygon): + return geom + + if isinstance(geom, GeometryCollection): + polygons = [g for g in geom.geoms if isinstance(g, Polygon | MultiPolygon)] + if not polygons: + return None + if len(polygons) == 1: + return polygons[0] + # Flatten MultiPolygons and combine all polygons + all_polys = [] + for p in polygons: + if isinstance(p, Polygon): + all_polys.append(p) + elif isinstance(p, MultiPolygon): + all_polys.extend(p.geoms) + return MultiPolygon(all_polys) + + return None + + +def _try_single_request( + endpoint: str, query: str +) -> tuple[dict[str, Any] | None, Exception | None]: + """Try a single request to an endpoint. + + Args: + endpoint: Overpass API endpoint URL. + query: Overpass QL query string. + + Returns: + Tuple of (result, error). One will be None. + """ + try: + sys.stdout.write(f" Querying {endpoint}...\n") + response = requests.post( + endpoint, + data={"data": query}, + timeout=REQUEST_TIMEOUT, + ) + response.raise_for_status() + result = response.json() + except (requests.RequestException, requests.Timeout, ValueError) as e: + return None, e + else: + # Check for valid response with elements + if not isinstance(result, dict) or "elements" not in result: + return None, ValueError("Invalid response format") + return result, None + + +def _overpass_query(query: str) -> dict[str, Any]: + """Execute an Overpass API query with retry logic. + + Args: + query: Overpass QL query string. + + Returns: + JSON response from the API. + + Raises: + RuntimeError: If all endpoints fail. + """ + last_error: Exception | None = None + + for endpoint in OVERPASS_ENDPOINTS: + for attempt in range(MAX_RETRIES): + result, error = _try_single_request(endpoint, query) + if result is not None: + return result + last_error = error + sys.stdout.write(f" Attempt {attempt + 1} failed: {error}\n") + if attempt < MAX_RETRIES - 1: + time.sleep(RETRY_DELAY) + + msg = f"All Overpass API endpoints failed. Last error: {last_error}" + raise RuntimeError(msg) + + +def _download_github_geojson(url: str, cache_path: Path) -> gpd.GeoDataFrame: + """Download GeoJSON from GitHub and cache it. + + Args: + url: URL to download from. + cache_path: Path to cache the data. + + Returns: + GeoDataFrame with the data. + """ + if cache_path.exists(): + return gpd.read_file(cache_path) + + sys.stdout.write(f"Downloading from {url}...\n") + if not url.startswith(("http://", "https://")): + msg = f"Unsupported URL scheme: {url}" + raise ValueError(msg) + with urlopen(url, timeout=REQUEST_TIMEOUT) as response: + data = json.loads(response.read().decode()) + + _ensure_cache_dir() + cache_path.write_text(json.dumps(data)) + + return gpd.GeoDataFrame.from_features(data["features"], crs="EPSG:4326") + + +def _extract_osiedla_rings( + element: dict[str, Any], min_coords: int +) -> tuple[list[list[tuple[float, float]]], list[list[tuple[float, float]]]]: + """Extract outer and inner rings from an OSM relation. + + Args: + element: OSM relation element. + min_coords: Minimum number of coordinates for a valid ring. + + Returns: + Tuple of (outer_rings, inner_rings). + """ + outer_rings: list[list[tuple[float, float]]] = [] + inner_rings: list[list[tuple[float, float]]] = [] + + for member in element.get("members", []): + if "geometry" not in member: + continue + ring = [(p["lon"], p["lat"]) for p in member["geometry"]] + if len(ring) < min_coords: + continue + # Close the ring if not closed + if ring[0] != ring[-1]: + ring.append(ring[0]) + if member.get("role") == "outer": + outer_rings.append(ring) + elif member.get("role") == "inner": + inner_rings.append(ring) + + return outer_rings, inner_rings + + +def _build_osiedla_geometry( + outer_rings: list[list[tuple[float, float]]], + inner_rings: list[list[tuple[float, float]]], +) -> dict[str, Any]: + """Build GeoJSON geometry from outer and inner rings. + + Args: + outer_rings: List of outer ring coordinates. + inner_rings: List of inner ring coordinates. + + Returns: + GeoJSON geometry dict. + """ + if len(outer_rings) == 1: + return { + "type": "Polygon", + "coordinates": [outer_rings[0], *inner_rings], + } + # Multiple outer rings - create MultiPolygon + # Each polygon in a MultiPolygon is [exterior, hole1, hole2, ...] + return { + "type": "MultiPolygon", + "coordinates": [[ring] for ring in outer_rings], + } + + +def _extract_polygon_from_element( + element: dict[str, Any], +) -> dict[str, Any] | None: + """Extract polygon geometry from an OSM relation or way element. + + Args: + element: OSM element (relation or way). + + Returns: + GeoJSON geometry dict, or None if extraction fails. + """ + if element.get("type") == "relation": + outer_rings, inner_rings = _extract_osiedla_rings(element, MIN_RING_COORDS) + if not outer_rings: + return None + return _build_osiedla_geometry(outer_rings, inner_rings) + + if element.get("type") == "way" and "geometry" in element: + coords = [(p["lon"], p["lat"]) for p in element["geometry"]] + if len(coords) < MIN_RING_COORDS: + return None + if coords[0] != coords[-1]: + coords.append(coords[0]) + return {"type": "Polygon", "coordinates": [coords]} + + return None + + +def _extract_line_from_way(element: dict[str, Any]) -> dict[str, Any] | None: + """Extract line geometry from an OSM way element. + + Args: + element: OSM way element. + + Returns: + GeoJSON LineString geometry dict, or None if extraction fails. + """ + if element.get("type") != "way" or "geometry" not in element: + return None + + coords = [(p["lon"], p["lat"]) for p in element["geometry"]] + if len(coords) < MIN_LINE_COORDS: + return None + + return {"type": "LineString", "coordinates": coords} + + +def _add_area_column(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + """Add area_km2 column to a GeoDataFrame. + + Args: + gdf: GeoDataFrame with polygon geometries. + + Returns: + GeoDataFrame with area_km2 column added. + """ + if len(gdf) == 0: + return gdf + gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system + gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000 + return gdf + + +def _add_length_column(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + """Add length_km column to a GeoDataFrame. + + Args: + gdf: GeoDataFrame with line geometries. + + Returns: + GeoDataFrame with length_km column added. + """ + if len(gdf) == 0: + return gdf + gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system + gdf["length_km"] = gdf_proj.geometry.length / 1000 + return gdf diff --git a/python_pkg/geo_data/_poland_admin.py b/python_pkg/geo_data/_poland_admin.py new file mode 100644 index 0000000..10fb876 --- /dev/null +++ b/python_pkg/geo_data/_poland_admin.py @@ -0,0 +1,226 @@ +"""Polish administrative boundary data. + +Functions for downloading and caching Polish administrative divisions: +województwa, powiaty, gminy, and the national boundary. +Includes Wikidata integration for population data. +""" + +from __future__ import annotations + +import contextlib +import json +import sys +from typing import TYPE_CHECKING + +import geopandas as gpd +import requests + +from python_pkg.geo_data._common import ( + CACHE_DIR, + POLSKA_GEOJSON_BASE, + WIKIDATA_SPARQL, + _build_osiedla_geometry, + _download_github_geojson, + _ensure_cache_dir, + _extract_osiedla_rings, + _overpass_query, +) + +if TYPE_CHECKING: + from typing import Any + + +def _query_wikidata(query: str) -> list[dict[str, Any]]: + """Query Wikidata SPARQL endpoint. + + Args: + query: SPARQL query string. + + Returns: + List of result bindings. + """ + response = requests.get( + WIKIDATA_SPARQL, + params={"query": query, "format": "json"}, + timeout=60, + ) + response.raise_for_status() + return response.json()["results"]["bindings"] + + +def _get_powiaty_population() -> dict[str, int]: + """Get population data for all Polish powiaty from Wikidata. + + Returns: + Dictionary mapping powiat name to population. + """ + cache_path = CACHE_DIR / "powiaty_population.json" + + if cache_path.exists(): + return json.loads(cache_path.read_text()) + + # Query Wikidata for all powiaty (Q247073) in Poland (Q36) with population + # Filter to only current Polish powiaty using country=Poland filter + query = """ + SELECT ?powiat ?powiatLabel ?population WHERE { + ?powiat wdt:P31 wd:Q247073. + ?powiat wdt:P17 wd:Q36. + ?powiat wdt:P1082 ?population. + SERVICE wikibase:label { bd:serviceParam wikibase:language "pl,en". } + } + ORDER BY DESC(?population) + """ + + sys.stdout.write("Fetching powiaty population data from Wikidata...\n") + results = _query_wikidata(query) + + population_map: dict[str, int] = {} + for item in results: + label = item.get("powiatLabel", {}).get("value", "") + pop = item.get("population", {}).get("value", "0") + if label and pop: + # Remove "powiat" prefix if present for matching + clean_label = label.replace("powiat ", "").strip() + with contextlib.suppress(ValueError): + population_map[clean_label] = int(pop) + + _ensure_cache_dir() + cache_path.write_text(json.dumps(population_map, ensure_ascii=False, indent=2)) + + sys.stdout.write(f"Cached population data for {len(population_map)} powiaty.\n") + return population_map + + +def get_polish_wojewodztwa() -> gpd.GeoDataFrame: + """Get Polish województwa (voivodeships). + + Returns: + GeoDataFrame with województwa boundaries. + """ + url = f"{POLSKA_GEOJSON_BASE}/wojewodztwa/wojewodztwa-min.geojson" + cache_path = CACHE_DIR / "polish_wojewodztwa.geojson" + return _download_github_geojson(url, cache_path) + + +def get_polish_powiaty() -> gpd.GeoDataFrame: + """Get Polish powiaty (counties), sorted by population descending. + + Returns: + GeoDataFrame with powiat boundaries and population. + """ + url = f"{POLSKA_GEOJSON_BASE}/powiaty/powiaty-min.geojson" + cache_path = CACHE_DIR / "polish_powiaty.geojson" + gdf = _download_github_geojson(url, cache_path) + + # Get population data from Wikidata + population_map = _get_powiaty_population() + + # Add population column + def get_population(nazwa: str) -> int: + """Match powiat name to population data.""" + if not nazwa: + return 0 + # Remove "powiat " prefix for matching + clean_name = nazwa.replace("powiat ", "").strip() + # Try direct match + if clean_name in population_map: + return population_map[clean_name] + # Try lowercase + name_lower = clean_name.lower() + for pop_name, pop in population_map.items(): + if pop_name.lower() == name_lower: + return pop + return 0 + + gdf["population"] = gdf["nazwa"].apply(get_population) + + # Sort by population descending + return gdf.sort_values("population", ascending=False).reset_index(drop=True) + + +def get_polish_gminy() -> gpd.GeoDataFrame: + """Get Polish gminy (municipalities) from OSM, sorted by area descending. + + Returns: + GeoDataFrame with gminy boundaries. + """ + cache_path = CACHE_DIR / "polish_gminy.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + if "area_km2" in gdf.columns: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write("Fetching gminy data from OSM (this may take a while)...\n") + # Polish gminy are admin_level=7 in OSM + query = """ + [out:json][timeout:300]; + area["ISO3166-1"="PL"]->.pl; + relation["boundary"="administrative"]["admin_level"="7"]["name"](area.pl); + out geom; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + min_ring_coords = 4 + + for element in data.get("elements", []): + if element.get("type") != "relation": + continue + + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) + if not outer_rings: + continue + + seen_names.add(name) + features.append( + { + "type": "Feature", + "properties": {"name": name}, + "geometry": _build_osiedla_geometry(outer_rings, inner_rings), + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson)) + + sys.stdout.write(f"Cached {len(features)} gminy.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + + # Add area column + from python_pkg.geo_data._common import _add_area_column + + gdf = _add_area_column(gdf) + + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + + +def get_poland_boundary() -> gpd.GeoDataFrame: + """Get Poland country boundary. + + Returns: + GeoDataFrame with Poland boundary. + """ + cache_path = CACHE_DIR / "poland_boundary.geojson" + + if cache_path.exists(): + return gpd.read_file(cache_path) + + # Dissolve from województwa + woj = get_polish_wojewodztwa() + # Fix invalid geometries with buffer(0) + woj["geometry"] = woj["geometry"].buffer(0) + poland = gpd.GeoDataFrame(geometry=[woj.union_all()], crs=woj.crs) + + _ensure_cache_dir() + poland.to_file(cache_path, driver="GeoJSON") + + return poland diff --git a/python_pkg/geo_data/_poland_nature.py b/python_pkg/geo_data/_poland_nature.py new file mode 100644 index 0000000..222ed0b --- /dev/null +++ b/python_pkg/geo_data/_poland_nature.py @@ -0,0 +1,446 @@ +"""Polish natural land features. + +Functions for downloading and caching data about Polish mountains, +national parks, forests, nature reserves, and landscape parks. +""" + +from __future__ import annotations + +import contextlib +import json +import sys +from typing import TYPE_CHECKING + +import geopandas as gpd + +from python_pkg.geo_data._common import ( + CACHE_DIR, + MIN_PEAK_ELEVATION, + _add_area_column, + _build_osiedla_geometry, + _ensure_cache_dir, + _extract_osiedla_rings, + _extract_polygon_from_element, + _extract_polygonal_geometry, + _overpass_query, +) + +if TYPE_CHECKING: + from typing import Any + + +def get_polish_mountain_peaks() -> gpd.GeoDataFrame: + """Get Polish mountain peaks, sorted by elevation descending. + + Returns: + GeoDataFrame with mountain peak points and elevation. + """ + cache_path = CACHE_DIR / "polish_mountain_peaks.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + return gdf.sort_values("elevation", ascending=False).reset_index(drop=True) + + sys.stdout.write("Fetching mountain peaks data from OSM...\n") + query = """ + [out:json][timeout:120]; + area["ISO3166-1"="PL"]->.pl; + ( + node["natural"="peak"]["name"]["ele"](area.pl); + ); + out; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + + for element in data.get("elements", []): + if element.get("type") != "node": + continue + + name = element.get("tags", {}).get("name", "") + ele_str = element.get("tags", {}).get("ele", "") + + if not name or not ele_str or name in seen_names: + continue + + with contextlib.suppress(ValueError): + elevation = float(ele_str.replace(",", ".").split()[0]) + if elevation < MIN_PEAK_ELEVATION: + continue + + seen_names.add(name) + features.append( + { + "type": "Feature", + "properties": {"name": name, "elevation": elevation}, + "geometry": { + "type": "Point", + "coordinates": [element["lon"], element["lat"]], + }, + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} mountain peaks.\n") + + if not features: + msg = "No mountain peaks found in OSM data" + raise ValueError(msg) + + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + return gdf.sort_values("elevation", ascending=False).reset_index(drop=True) + + +def get_polish_mountain_ranges() -> gpd.GeoDataFrame: + """Get Polish mountain ranges, sorted by area descending. + + Returns: + GeoDataFrame with mountain range polygons. + """ + cache_path = CACHE_DIR / "polish_mountain_ranges.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + # Fix invalid geometries from OSM data and extract only polygons + gdf["geometry"] = gdf.geometry.make_valid() + gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry) + gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty] + if "area_km2" in gdf.columns: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write("Fetching mountain ranges data from OSM...\n") + query = """ + [out:json][timeout:180]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["natural"="mountain_range"]["name"](area.pl); + way["natural"="mountain_range"]["name"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + + features: list[dict[str, Any]] = [] + seen_names: set[str] = set() + min_ring_coords = 4 + + for element in data.get("elements", []): + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + if element.get("type") == "relation": + outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) + if not outer_rings: + continue + geometry = _build_osiedla_geometry(outer_rings, inner_rings) + elif element.get("type") == "way" and "geometry" in element: + coords = [(p["lon"], p["lat"]) for p in element["geometry"]] + if len(coords) < min_ring_coords: + continue + if coords[0] != coords[-1]: + coords.append(coords[0]) + geometry = {"type": "Polygon", "coordinates": [coords]} + else: + continue + + seen_names.add(name) + features.append( + {"type": "Feature", "properties": {"name": name}, "geometry": geometry} + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} mountain ranges.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + + # Fix invalid geometries from OSM data and extract only polygons + gdf["geometry"] = gdf.geometry.make_valid() + gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry) + gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty] + + # Calculate area in km² + gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system + gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000 + + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + + +def get_polish_national_parks() -> gpd.GeoDataFrame: + """Get all 23 Polish national parks, sorted by area descending. + + Returns: + GeoDataFrame with national park polygons. + """ + cache_path = CACHE_DIR / "polish_national_parks.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + if "area_km2" in gdf.columns: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write("Fetching national parks data from OSM...\n") + query = """ + [out:json][timeout:180]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["boundary"="national_park"]["name"](area.pl); + relation["leisure"="nature_reserve"]["name"]["protect_class"="2"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + min_ring_coords = 4 + + for element in data.get("elements", []): + if element.get("type") != "relation": + continue + + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + # Filter to only include "Park Narodowy" in name + if "Narodowy" not in name and "narodowy" not in name.lower(): + continue + + outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) + if not outer_rings: + continue + + seen_names.add(name) + features.append( + { + "type": "Feature", + "properties": {"name": name}, + "geometry": _build_osiedla_geometry(outer_rings, inner_rings), + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} national parks.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + + # Calculate area in km² + gdf_proj = gdf.to_crs("EPSG:2180") + gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000 + + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + + +def get_polish_forests() -> gpd.GeoDataFrame: + """Get major Polish forests (puszcze), sorted by area descending. + + Returns: + GeoDataFrame with forest polygons. + """ + cache_path = CACHE_DIR / "polish_forests.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + if "area_km2" in gdf.columns: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write("Fetching forests data from OSM...\n") + # Query for named forests, especially "Puszcza" type + query = """ + [out:json][timeout:300]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["natural"="wood"]["name"](area.pl); + relation["landuse"="forest"]["name"~"Puszcza|Bory|Las"](area.pl); + way["natural"="wood"]["name"~"Puszcza|Bory"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + forest_keywords = ("Puszcza", "Bory", "Las ", "Lasy ") + + features = [] + seen_names: set[str] = set() + + for element in data.get("elements", []): + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + if not any(keyword in name for keyword in forest_keywords): + continue + + geometry = _extract_polygon_from_element(element) + if geometry is None: + continue + + seen_names.add(name) + features.append( + {"type": "Feature", "properties": {"name": name}, "geometry": geometry} + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} forests.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = _add_area_column(gdf) + + if len(gdf) > 0: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + +def get_polish_nature_reserves() -> gpd.GeoDataFrame: + """Get Polish nature reserves, sorted by area descending. + + Returns: + GeoDataFrame with nature reserve polygons. + """ + cache_path = CACHE_DIR / "polish_nature_reserves.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + if "area_km2" in gdf.columns: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write( + "Fetching nature reserves data from OSM (this may take a while)...\n" + ) + query = """ + [out:json][timeout:600]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["leisure"="nature_reserve"]["name"](area.pl); + way["leisure"="nature_reserve"]["name"](area.pl); + relation["boundary"="protected_area"]["protect_class"="4"]["name"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + + for element in data.get("elements", []): + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + geometry = _extract_polygon_from_element(element) + if geometry is None: + continue + + seen_names.add(name) + features.append( + {"type": "Feature", "properties": {"name": name}, "geometry": geometry} + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} nature reserves.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = _add_area_column(gdf) + + if len(gdf) > 0: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + +def get_polish_landscape_parks() -> gpd.GeoDataFrame: + """Get Polish landscape parks, sorted by area descending. + + Returns: + GeoDataFrame with landscape park polygons. + """ + cache_path = CACHE_DIR / "polish_landscape_parks.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + # Fix invalid geometries from OSM data and extract only polygons + gdf["geometry"] = gdf.geometry.make_valid() + gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry) + # Remove any rows where geometry extraction failed + gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty] + if "area_km2" in gdf.columns: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write("Fetching landscape parks data from OSM...\n") + query = """ + [out:json][timeout:300]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["boundary"="protected_area"]["protect_class"="5"]["name"](area.pl); + relation["leisure"="nature_reserve"]["name"~"Park Krajobrazowy"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + min_ring_coords = 4 + + for element in data.get("elements", []): + if element.get("type") != "relation": + continue + + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) + if not outer_rings: + continue + + seen_names.add(name) + features.append( + { + "type": "Feature", + "properties": {"name": name}, + "geometry": _build_osiedla_geometry(outer_rings, inner_rings), + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} landscape parks.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + + # Fix invalid geometries from OSM data and extract only polygons + gdf["geometry"] = gdf.geometry.make_valid() + gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry) + # Remove any rows where geometry extraction failed + gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty] + + if len(gdf) > 0: + gdf_proj = gdf.to_crs("EPSG:2180") + gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000 + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + + return gdf diff --git a/python_pkg/geo_data/_poland_water.py b/python_pkg/geo_data/_poland_water.py new file mode 100644 index 0000000..76d2807 --- /dev/null +++ b/python_pkg/geo_data/_poland_water.py @@ -0,0 +1,437 @@ +"""Polish water features and cultural sites. + +Functions for downloading and caching data about Polish lakes, rivers, +islands, coastal features, and UNESCO World Heritage sites. +""" + +from __future__ import annotations + +import json +import sys +from typing import TYPE_CHECKING + +import geopandas as gpd + +from python_pkg.geo_data._common import ( + CACHE_DIR, + MIN_LAKE_AREA_KM2, + MIN_LINE_COORDS, + MIN_RING_COORDS, + MIN_RIVER_LENGTH_KM, + _add_area_column, + _add_length_column, + _build_osiedla_geometry, + _ensure_cache_dir, + _extract_osiedla_rings, + _extract_polygon_from_element, + _overpass_query, +) + +if TYPE_CHECKING: + from typing import Any + + +def _extract_coastal_geometry( + element: dict[str, Any], + natural_type: str, + line_types: tuple[str, ...], +) -> dict[str, Any] | None: + """Extract geometry from a coastal feature element. + + For cliffs and beaches, returns LineString. For others, returns Polygon. + + Args: + element: OSM element. + natural_type: The natural= tag value. + line_types: Tuple of natural types that should be lines. + + Returns: + GeoJSON geometry dict, or None if extraction fails. + """ + if element.get("type") == "relation": + return _extract_polygon_from_element(element) + + if element.get("type") != "way" or "geometry" not in element: + return None + + coords = [(p["lon"], p["lat"]) for p in element["geometry"]] + if len(coords) < MIN_LINE_COORDS: + return None + + # For cliffs and beaches, keep as linestring + if natural_type in line_types: + return {"type": "LineString", "coordinates": coords} + + # Otherwise try to make a polygon + if len(coords) >= MIN_RING_COORDS: + if coords[0] != coords[-1]: + coords.append(coords[0]) + return {"type": "Polygon", "coordinates": [coords]} + + return None + + +def _extract_river_coords_from_element( + element: dict[str, Any], +) -> list[list[tuple[float, float]]]: + """Extract coordinate lists from a river element. + + Args: + element: OSM element (way or relation). + + Returns: + List of coordinate lists (line segments). + """ + coord_lists: list[list[tuple[float, float]]] = [] + + if element.get("type") == "way" and "geometry" in element: + coords = [(p["lon"], p["lat"]) for p in element["geometry"]] + if len(coords) >= MIN_LINE_COORDS: + coord_lists.append(coords) + elif element.get("type") == "relation": + for member in element.get("members", []): + if member.get("type") == "way" and "geometry" in member: + coords = [(p["lon"], p["lat"]) for p in member["geometry"]] + if len(coords) >= MIN_LINE_COORDS: + coord_lists.append(coords) + + return coord_lists + + +def get_polish_lakes() -> gpd.GeoDataFrame: + """Get Polish lakes, sorted by area descending. + + Returns: + GeoDataFrame with lake polygons. + """ + cache_path = CACHE_DIR / "polish_lakes.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + if "area_km2" in gdf.columns: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write("Fetching lakes data from OSM...\n") + query = """ + [out:json][timeout:300]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["natural"="water"]["water"="lake"]["name"](area.pl); + way["natural"="water"]["water"="lake"]["name"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + + for element in data.get("elements", []): + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + geometry = _extract_polygon_from_element(element) + if geometry is None: + continue + + seen_names.add(name) + features.append( + {"type": "Feature", "properties": {"name": name}, "geometry": geometry} + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} lakes.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = _add_area_column(gdf) + + if len(gdf) > 0: + # Filter to lakes > MIN_LAKE_AREA_KM2 to exclude tiny ponds + gdf = gdf[gdf["area_km2"] > MIN_LAKE_AREA_KM2] + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + + return gdf + + +def get_polish_rivers() -> gpd.GeoDataFrame: + """Get Polish rivers, sorted by length descending. + + Rivers with the same name but in different locations are kept separate + by using unique IDs from OSM when available. + + Returns: + GeoDataFrame with river linestrings. + """ + cache_path = CACHE_DIR / "polish_rivers.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + if "length_km" in gdf.columns: + return gdf.sort_values("length_km", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write("Fetching rivers data from OSM...\n") + query = """ + [out:json][timeout:300]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["waterway"="river"]["name"](area.pl); + way["waterway"="river"]["name"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + + # Group ways by river name AND wikidata ID (or OSM ID for uniqueness) + # This prevents merging different rivers with the same name + rivers_by_key: dict[str, list[list[tuple[float, float]]]] = {} + river_names: dict[str, str] = {} # key -> display name + + for element in data.get("elements", []): + name = element.get("tags", {}).get("name", "") + if not name: + continue + + # Use wikidata ID if available, otherwise use element type+id + wikidata = element.get("tags", {}).get("wikidata", "") + if wikidata: + key = f"{name}_{wikidata}" + else: + # Fall back to element ID for grouping related ways + key = f"{name}_{element.get('type')}_{element.get('id')}" + + coord_lists = _extract_river_coords_from_element(element) + if coord_lists: + rivers_by_key.setdefault(key, []).extend(coord_lists) + river_names[key] = name + + features = [] + for key, coord_lists in rivers_by_key.items(): + name = river_names[key] + geometry: dict[str, Any] + if len(coord_lists) == 1: + geometry = {"type": "LineString", "coordinates": coord_lists[0]} + else: + geometry = {"type": "MultiLineString", "coordinates": coord_lists} + + features.append( + {"type": "Feature", "properties": {"name": name}, "geometry": geometry} + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} rivers.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = _add_length_column(gdf) + + if len(gdf) > 0: + gdf = gdf[gdf["length_km"] > MIN_RIVER_LENGTH_KM] + return gdf.sort_values("length_km", ascending=False).reset_index(drop=True) + + return gdf + + +def get_polish_islands() -> gpd.GeoDataFrame: + """Get Polish islands, sorted by area descending. + + Returns: + GeoDataFrame with island polygons. + """ + cache_path = CACHE_DIR / "polish_islands.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + if "area_km2" in gdf.columns: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write("Fetching islands data from OSM...\n") + query = """ + [out:json][timeout:180]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["place"="island"]["name"](area.pl); + way["place"="island"]["name"](area.pl); + relation["place"="islet"]["name"](area.pl); + way["place"="islet"]["name"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + + for element in data.get("elements", []): + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + geometry = _extract_polygon_from_element(element) + if geometry is None: + continue + + seen_names.add(name) + features.append( + {"type": "Feature", "properties": {"name": name}, "geometry": geometry} + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} islands.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = _add_area_column(gdf) + + if len(gdf) > 0: + return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True) + return gdf + + +def get_polish_coastal_features() -> gpd.GeoDataFrame: + """Get Polish coastal features (peninsulas, spits, cliffs), sorted by length. + + Returns: + GeoDataFrame with coastal feature geometries. + """ + cache_path = CACHE_DIR / "polish_coastal_features.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + if "length_km" in gdf.columns: + return gdf.sort_values("length_km", ascending=False).reset_index(drop=True) + return gdf + + sys.stdout.write("Fetching coastal features data from OSM...\n") + query = """ + [out:json][timeout:180]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["natural"="peninsula"]["name"](area.pl); + way["natural"="peninsula"]["name"](area.pl); + relation["natural"="spit"]["name"](area.pl); + way["natural"="spit"]["name"](area.pl); + relation["natural"="cliff"]["name"](area.pl); + way["natural"="cliff"]["name"](area.pl); + relation["natural"="coastline"]["name"](area.pl); + way["natural"="beach"]["name"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + line_types = ("cliff", "beach", "coastline") + + features = [] + seen_names: set[str] = set() + + for element in data.get("elements", []): + name = element.get("tags", {}).get("name", "") + natural_type = element.get("tags", {}).get("natural", "") + if not name or name in seen_names: + continue + + geometry = _extract_coastal_geometry(element, natural_type, line_types) + if geometry is None: + continue + + seen_names.add(name) + features.append( + { + "type": "Feature", + "properties": {"name": name, "type": natural_type}, + "geometry": geometry, + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} coastal features.\n") + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = _add_length_column(gdf) + + if len(gdf) > 0: + return gdf.sort_values("length_km", ascending=False).reset_index(drop=True) + return gdf + + +def get_polish_unesco_sites() -> gpd.GeoDataFrame: + """Get Polish UNESCO World Heritage Sites, sorted by inscription year. + + Returns: + GeoDataFrame with UNESCO site geometries. + """ + cache_path = CACHE_DIR / "polish_unesco_sites.geojson" + + if cache_path.exists(): + return gpd.read_file(cache_path) + + sys.stdout.write("Fetching UNESCO sites data from OSM...\n") + query = """ + [out:json][timeout:180]; + area["ISO3166-1"="PL"]->.pl; + ( + relation["heritage"="world_heritage_site"]["name"](area.pl); + way["heritage"="world_heritage_site"]["name"](area.pl); + node["heritage"="world_heritage_site"]["name"](area.pl); + relation["heritage:operator"="whc"]["name"](area.pl); + way["heritage:operator"="whc"]["name"](area.pl); + node["heritage:operator"="whc"]["name"](area.pl); + ); + out geom; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + min_ring_coords = 4 + + for element in data.get("elements", []): + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + if element.get("type") == "node": + geometry: dict[str, Any] = { + "type": "Point", + "coordinates": [element["lon"], element["lat"]], + } + elif element.get("type") == "relation": + outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) + if not outer_rings: + continue + geometry = _build_osiedla_geometry(outer_rings, inner_rings) + elif element.get("type") == "way" and "geometry" in element: + coords = [(p["lon"], p["lat"]) for p in element["geometry"]] + if len(coords) < min_ring_coords: + continue + if coords[0] != coords[-1]: + coords.append(coords[0]) + geometry = {"type": "Polygon", "coordinates": [coords]} + else: + continue + + seen_names.add(name) + features.append( + {"type": "Feature", "properties": {"name": name}, "geometry": geometry} + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson, ensure_ascii=False)) + + sys.stdout.write(f"Cached {len(features)} UNESCO sites.\n") + return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") diff --git a/python_pkg/geo_data/_warsaw.py b/python_pkg/geo_data/_warsaw.py new file mode 100644 index 0000000..76c45a3 --- /dev/null +++ b/python_pkg/geo_data/_warsaw.py @@ -0,0 +1,407 @@ +"""Warsaw geographic data functions. + +Functions for downloading and caching Warsaw-specific geographic data: +boundaries, districts, Vistula river, bridges, metro stations, and osiedla. +""" + +from __future__ import annotations + +import json +import sys + +import geopandas as gpd +from shapely.geometry import LineString + +from python_pkg.geo_data._common import ( + _PKG_DIR, + CACHE_DIR, + _build_osiedla_geometry, + _ensure_cache_dir, + _extract_osiedla_rings, + _overpass_query, +) + + +def get_warsaw_boundary() -> gpd.GeoDataFrame: + """Get Warsaw city boundary. + + Returns: + GeoDataFrame with Warsaw boundary polygon. + """ + cache_path = CACHE_DIR / "warsaw_boundary.geojson" + + if cache_path.exists(): + return gpd.read_file(cache_path) + + # Try to use districts file first + districts_path = ( + _PKG_DIR / "anki_decks" / "warsaw_districts" / "warszawa-dzielnice.geojson" + ) + if districts_path.exists(): + warsaw_gdf = gpd.read_file(districts_path) + warsaw_boundary = warsaw_gdf[warsaw_gdf["name"] == "Warszawa"] + if len(warsaw_boundary) == 0: + warsaw_boundary = gpd.GeoDataFrame( + geometry=[warsaw_gdf.union_all()], crs=warsaw_gdf.crs + ) + _ensure_cache_dir() + warsaw_boundary.to_file(cache_path, driver="GeoJSON") + return warsaw_boundary + + # Fallback to Overpass query + sys.stdout.write("Fetching Warsaw boundary from OpenStreetMap...\n") + query = """ + [out:json][timeout:60]; + relation["name"="Warszawa"]["admin_level"="6"]; + out geom; + """ + + data = _overpass_query(query) + + features = [] + for element in data.get("elements", []): + if element.get("type") == "relation": + coords = [] + for member in element.get("members", []): + if member.get("role") == "outer" and "geometry" in member: + coords.extend([(p["lon"], p["lat"]) for p in member["geometry"]]) + if coords: + features.append( + { + "type": "Feature", + "properties": {"name": "Warszawa"}, + "geometry": {"type": "Polygon", "coordinates": [coords]}, + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson)) + + return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + + +def get_warsaw_districts() -> gpd.GeoDataFrame: + """Get Warsaw districts (dzielnice). + + Returns: + GeoDataFrame with district boundaries. + """ + districts_path = ( + _PKG_DIR / "anki_decks" / "warsaw_districts" / "warszawa-dzielnice.geojson" + ) + if districts_path.exists(): + gdf = gpd.read_file(districts_path) + return gdf[gdf["name"] != "Warszawa"].copy() + + msg = "Warsaw districts GeoJSON not found" + raise FileNotFoundError(msg) + + +def get_vistula_river() -> gpd.GeoDataFrame: + """Get Vistula river in Warsaw. + + Returns: + GeoDataFrame with river geometry. + """ + cache_path = CACHE_DIR / "warsaw_vistula.geojson" + + if cache_path.exists(): + return gpd.read_file(cache_path) + + sys.stdout.write("Fetching Vistula river data...\n") + query = """ + [out:json][timeout:60]; + area["name"="Warszawa"]["admin_level"="6"]->.warsaw; + ( + way["waterway"="river"]["name"="Wisła"](area.warsaw); + ); + out geom; + """ + + data = _overpass_query(query) + + features = [] + min_coords = 2 + for element in data.get("elements", []): + if element.get("type") == "way" and "geometry" in element: + coords = [(p["lon"], p["lat"]) for p in element["geometry"]] + if len(coords) >= min_coords: + features.append( + { + "type": "Feature", + "properties": {"name": "Wisła"}, + "geometry": {"type": "LineString", "coordinates": coords}, + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson)) + + return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + + +def get_warsaw_bridges() -> gpd.GeoDataFrame: + """Get Warsaw bridges over the Vistula. + + Returns: + GeoDataFrame with bridge geometries. + """ + cache_path = CACHE_DIR / "warsaw_bridges.geojson" + + if cache_path.exists(): + return gpd.read_file(cache_path) + + sys.stdout.write("Fetching Warsaw bridges data...\n") + + # First get the Vistula to filter bridges + vistula = get_vistula_river() + vistula_union = vistula.union_all() + vistula_buffer = vistula_union.buffer(0.002) # ~200m buffer + + # Query for bridges with "Most" in name - smaller query + query = """ + [out:json][timeout:90]; + area["name"="Warszawa"]["admin_level"="6"]->.warsaw; + way["bridge"="yes"]["name"~"^Most"](area.warsaw); + out geom; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + min_coords = 2 + + for element in data.get("elements", []): + if element.get("type") != "way" or "geometry" not in element: + continue + + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + coords = [(p["lon"], p["lat"]) for p in element["geometry"]] + if len(coords) < min_coords: + continue + + line = LineString(coords) + + # Check if bridge crosses/is near Vistula + if line.intersects(vistula_buffer): + seen_names.add(name) + features.append( + { + "type": "Feature", + "properties": {"name": name, "osm_id": element.get("id")}, + "geometry": {"type": "LineString", "coordinates": coords}, + } + ) + + # Merge segments of the same bridge + merged_features = _merge_bridge_segments(features) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": merged_features} + cache_path.write_text(json.dumps(geojson)) + + sys.stdout.write(f"Cached {len(merged_features)} bridges.\n") + return gpd.GeoDataFrame.from_features(merged_features, crs="EPSG:4326") + + +def _merge_bridge_segments(features: list[dict]) -> list[dict]: + """Merge bridge segments with the same name. + + Args: + features: List of GeoJSON features. + + Returns: + List of merged features. + """ + by_name: dict[str, list[list[tuple[float, float]]]] = {} + + for feature in features: + name = feature["properties"]["name"] + coords = feature["geometry"]["coordinates"] + if name not in by_name: + by_name[name] = [] + by_name[name].append(coords) + + merged = [] + for name, coord_lists in by_name.items(): + if len(coord_lists) == 1: + geom = {"type": "LineString", "coordinates": coord_lists[0]} + else: + geom = {"type": "MultiLineString", "coordinates": coord_lists} + + merged.append( + {"type": "Feature", "properties": {"name": name}, "geometry": geom} + ) + + return merged + + +def get_warsaw_metro_stations() -> gpd.GeoDataFrame: + """Get Warsaw metro stations with line information. + + Returns: + GeoDataFrame with station points and line info (M1, M2, or M1/M2). + """ + cache_path = CACHE_DIR / "warsaw_metro.geojson" + + if cache_path.exists(): + return gpd.read_file(cache_path) + + # Known stations for each line (as of 2024) + m1_stations = { + "Kabaty", + "Natolin", + "Imielin", + "Stokłosy", + "Ursynów", + "Służew", + "Wilanowska", + "Wierzbno", + "Racławicka", + "Pole Mokotowskie", + "Politechnika", + "Centrum", + "Świętokrzyska", # Also M2 + "Ratusz-Arsenał", + "Dworzec Gdański", + "Plac Wilsona", + "Marymont", + "Słodowiec", + "Stare Bielany", + "Wawrzyszew", + "Młociny", + } + m2_stations = { + "Bródno", + "Kondratowicza", + "Zacisze", + "Targówek Mieszkaniowy", + "Trocka", + "Szwedzka", + "Dworzec Wileński", + "Świętokrzyska", # Also M1 + "Nowy Świat-Uniwersytet", + "Centrum Nauki Kopernik", + "Stadion Narodowy", + "Rondo ONZ", + "Rondo Daszyńskiego", + "Płocka", + "Młynów", + "Księcia Janusza", + "Ulrychów", + "Bemowo", + } + + sys.stdout.write("Fetching metro station data...\n") + query = """ + [out:json][timeout:60]; + area["name"="Warszawa"]["admin_level"="6"]->.warsaw; + ( + node["railway"="station"]["station"="subway"](area.warsaw); + node["railway"="station"]["network"~"Metro"](area.warsaw); + ); + out body; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + + for element in data.get("elements", []): + if element.get("type") == "node": + name = element.get("tags", {}).get("name", "") + if name and name not in seen_names: + seen_names.add(name) + # Determine line from known station lists + in_m1 = name in m1_stations + in_m2 = name in m2_stations + if in_m1 and in_m2: + line = "M1/M2" + elif in_m1: + line = "M1" + elif in_m2: + line = "M2" + else: + line = "?" # Unknown station + + features.append( + { + "type": "Feature", + "properties": { + "name": name, + "line": line, + }, + "geometry": { + "type": "Point", + "coordinates": [element["lon"], element["lat"]], + }, + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson)) + + sys.stdout.write(f"Cached {len(features)} metro stations.\n") + return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + + +def get_warsaw_osiedla() -> gpd.GeoDataFrame: + """Get Warsaw osiedla (neighborhoods). + + Returns: + GeoDataFrame with osiedla boundaries. + """ + cache_path = CACHE_DIR / "warsaw_osiedla.geojson" + + if cache_path.exists(): + return gpd.read_file(cache_path) + + sys.stdout.write("Fetching osiedla data...\n") + query = """ + [out:json][timeout:180]; + area["name"="Warszawa"]["admin_level"="6"]->.warsaw; + relation["boundary"="administrative"]["admin_level"="11"]["name"](area.warsaw); + out geom; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + min_ring_coords = 4 + + for element in data.get("elements", []): + if element.get("type") != "relation": + continue + + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords) + if not outer_rings: + continue + + seen_names.add(name) + features.append( + { + "type": "Feature", + "properties": {"name": name}, + "geometry": _build_osiedla_geometry(outer_rings, inner_rings), + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson)) + + sys.stdout.write(f"Cached {len(features)} osiedla.\n") + return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") diff --git a/python_pkg/geo_data/_warsaw_places.py b/python_pkg/geo_data/_warsaw_places.py new file mode 100644 index 0000000..6690389 --- /dev/null +++ b/python_pkg/geo_data/_warsaw_places.py @@ -0,0 +1,186 @@ +"""Warsaw streets, landmarks, and place data. + +Functions for downloading and caching Warsaw streets, landmarks, +and other place-related geographic data. +""" + +from __future__ import annotations + +import json +import sys + +import geopandas as gpd +from shapely.geometry import MultiLineString + +from python_pkg.geo_data._common import CACHE_DIR, _ensure_cache_dir, _overpass_query + + +def get_warsaw_streets(min_length: int = 500) -> gpd.GeoDataFrame: + """Get major Warsaw streets. + + Args: + min_length: Minimum street length in meters. + + Returns: + GeoDataFrame with street geometries. + """ + cache_path = CACHE_DIR / "warsaw_streets.geojson" + + if cache_path.exists(): + gdf = gpd.read_file(cache_path) + # Filter by length if needed + return _filter_streets_by_length(gdf, min_length) + + sys.stdout.write("Fetching street data from OpenStreetMap...\n") + query = """ + [out:json][timeout:120]; + area["name"="Warszawa"]["admin_level"="6"]->.warsaw; + ( + way["highway"="primary"]["name"](area.warsaw); + way["highway"="secondary"]["name"](area.warsaw); + way["highway"="tertiary"]["name"](area.warsaw); + ); + out geom; + """ + + data = _overpass_query(query) + + features = [] + min_coords = 2 + + for element in data.get("elements", []): + if element.get("type") == "way" and "geometry" in element: + coords = [(p["lon"], p["lat"]) for p in element["geometry"]] + if len(coords) >= min_coords: + features.append( + { + "type": "Feature", + "properties": { + "name": element.get("tags", {}).get("name", "Unknown"), + "highway": element.get("tags", {}).get("highway", ""), + }, + "geometry": {"type": "LineString", "coordinates": coords}, + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson)) + + sys.stdout.write(f"Cached {len(features)} street segments.\n") + + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + return _filter_streets_by_length(gdf, min_length) + + +def _filter_streets_by_length( + gdf: gpd.GeoDataFrame, min_length: int +) -> gpd.GeoDataFrame: + """Filter and merge streets by name, keeping only those above min_length. + + Args: + gdf: GeoDataFrame with street segments. + min_length: Minimum length in meters. + + Returns: + GeoDataFrame with merged streets, sorted by length (longest first). + """ + # Group by street name + streets: dict[str, list] = {} + for _, row in gdf.iterrows(): + name = row.get("name", "Unknown") + if name and name != "Unknown": + if name not in streets: + streets[name] = [] + streets[name].append(row.geometry) + + # Merge and filter + result_rows = [] + for name, geometries in streets.items(): + merged = geometries[0] if len(geometries) == 1 else MultiLineString(geometries) + + # Create temp GeoDataFrame for length calculation + temp_gdf = gpd.GeoDataFrame(geometry=[merged], crs="EPSG:4326") + temp_proj = temp_gdf.to_crs("EPSG:2180") # Polish coordinate system + length = temp_proj.geometry.length.iloc[0] + + if length >= min_length: + result_rows.append({"name": name, "geometry": merged, "length_m": length}) + + # Sort by length (longest first) + result_rows.sort(key=lambda x: x["length_m"], reverse=True) + + return gpd.GeoDataFrame(result_rows, crs="EPSG:4326") + + +def get_warsaw_landmarks() -> gpd.GeoDataFrame: + """Get Warsaw landmarks (museums, monuments, parks, etc.). + + Returns: + GeoDataFrame with landmark points. + """ + cache_path = CACHE_DIR / "warsaw_landmarks.geojson" + + if cache_path.exists(): + return gpd.read_file(cache_path) + + sys.stdout.write("Fetching landmark data...\n") + # Simplified query - just museums and major attractions + query = """ + [out:json][timeout:60]; + area["name"="Warszawa"]["admin_level"="6"]->.warsaw; + ( + node["tourism"="museum"]["name"](area.warsaw); + node["tourism"="attraction"]["name"](area.warsaw); + node["historic"="monument"]["name"](area.warsaw); + way["tourism"="museum"]["name"](area.warsaw); + way["tourism"="attraction"]["name"](area.warsaw); + ); + out center; + """ + + data = _overpass_query(query) + + features = [] + seen_names: set[str] = set() + + for element in data.get("elements", []): + name = element.get("tags", {}).get("name", "") + if not name or name in seen_names: + continue + + # Get coordinates + if element.get("type") == "node": + lon, lat = element["lon"], element["lat"] + elif "center" in element: + lon, lat = element["center"]["lon"], element["center"]["lat"] + else: + continue + + seen_names.add(name) + landmark_type = ( + element.get("tags", {}).get("tourism") + or element.get("tags", {}).get("historic") + or element.get("tags", {}).get("leisure") + or "landmark" + ) + + features.append( + { + "type": "Feature", + "properties": {"name": name, "type": landmark_type}, + "geometry": {"type": "Point", "coordinates": [lon, lat]}, + } + ) + + _ensure_cache_dir() + geojson = {"type": "FeatureCollection", "features": features} + cache_path.write_text(json.dumps(geojson)) + + sys.stdout.write(f"Cached {len(features)} landmarks.\n") + + if not features: + return gpd.GeoDataFrame( + {"name": [], "type": [], "geometry": []}, crs="EPSG:4326" + ) + return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") diff --git a/python_pkg/keyboard_coop/tests/conftest.py b/python_pkg/keyboard_coop/tests/conftest.py new file mode 100644 index 0000000..2960def --- /dev/null +++ b/python_pkg/keyboard_coop/tests/conftest.py @@ -0,0 +1,14 @@ +"""Shared fixtures for keyboard_coop tests.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture(autouse=True) +def mock_pygame() -> MagicMock: + """Mock pygame to prevent display initialization.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + yield diff --git a/python_pkg/keyboard_coop/tests/test_constants.py b/python_pkg/keyboard_coop/tests/test_constants.py new file mode 100644 index 0000000..5dedb52 --- /dev/null +++ b/python_pkg/keyboard_coop/tests/test_constants.py @@ -0,0 +1,148 @@ +"""Tests for keyboard_coop constants and dataclasses.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + + +class TestConstants: + """Tests for module constants.""" + + def test_screen_dimensions(self) -> None: + """Test screen dimension constants.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import SCREEN_HEIGHT, SCREEN_WIDTH + + expected_width = 1366 + expected_height = 768 + assert expected_width == SCREEN_WIDTH + assert expected_height == SCREEN_HEIGHT + + def test_min_word_length(self) -> None: + """Test minimum word length constant.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import MIN_WORD_LENGTH + + expected_min = 3 + assert expected_min == MIN_WORD_LENGTH + + def test_keyboard_layout_structure(self) -> None: + """Test KEYBOARD_LAYOUT has correct structure.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import KEYBOARD_LAYOUT + + expected_rows = 3 + assert len(KEYBOARD_LAYOUT) == expected_rows + expected_first_row_len = 10 + expected_second_row_len = 9 + expected_third_row_len = 7 + assert len(KEYBOARD_LAYOUT[0]) == expected_first_row_len + assert len(KEYBOARD_LAYOUT[1]) == expected_second_row_len + assert len(KEYBOARD_LAYOUT[2]) == expected_third_row_len + + +class TestKeyAdjacency: + """Tests for KEY_ADJACENCY mapping.""" + + def test_q_adjacents(self) -> None: + """Test Q key has correct adjacent keys.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import KEY_ADJACENCY + + assert set(KEY_ADJACENCY["q"]) == {"w", "a", "s"} + + def test_all_letters_have_adjacents(self) -> None: + """Test all 26 letters have adjacency entries.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import KEY_ADJACENCY + + alphabet = "qwertyuiopasdfghjklzxcvbnm" + for letter in alphabet: + assert letter in KEY_ADJACENCY + assert len(KEY_ADJACENCY[letter]) > 0 + + +class TestGameState: + """Tests for GameState dataclass.""" + + def test_default_values(self) -> None: + """Test GameState default values.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import GameState + + state = GameState() + assert state.current_player == 0 + assert state.current_word == "" + assert state.selected_letters == [] + assert state.score == 0 + assert state.game_over is False + assert "Player 1" in state.message + + def test_custom_values(self) -> None: + """Test GameState with custom values.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import GameState + + state = GameState( + current_player=1, + current_word="test", + selected_letters=["t", "e", "s", "t"], + score=100, + game_over=True, + message="Game Over!", + ) + assert state.current_player == 1 + assert state.current_word == "test" + expected_score = 100 + assert state.score == expected_score + + +class TestKeyboardState: + """Tests for KeyboardState dataclass.""" + + def test_default_values(self) -> None: + """Test KeyboardState default values.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import KeyboardState + + kb_state = KeyboardState() + assert kb_state.layout == [] + assert kb_state.available_letters == set() + assert kb_state.adjacency == {} + assert kb_state.positions == {} + + +class TestFontSet: + """Tests for FontSet dataclass.""" + + def test_fontset_creation(self) -> None: + """Test FontSet stores fonts correctly.""" + mock_font = MagicMock() + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import FontSet + + fonts = FontSet(normal=mock_font, large=mock_font, small=mock_font) + assert fonts.normal == mock_font + assert fonts.large == mock_font + assert fonts.small == mock_font + + +class TestColors: + """Tests for color constants.""" + + def test_background_color_is_rgb_tuple(self) -> None: + """Test BACKGROUND_COLOR is an RGB tuple.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import BACKGROUND_COLOR + + expected_len = 3 + assert len(BACKGROUND_COLOR) == expected_len + assert all(isinstance(c, int) for c in BACKGROUND_COLOR) + + def test_player_colors_list(self) -> None: + """Test PLAYER_COLORS has colors for 2 players.""" + with patch.dict("sys.modules", {"pygame": MagicMock()}): + from python_pkg.keyboard_coop.main import PLAYER_COLORS + + expected_players = 2 + assert len(PLAYER_COLORS) == expected_players diff --git a/python_pkg/keyboard_coop/tests/test_game_logic.py b/python_pkg/keyboard_coop/tests/test_game_logic.py new file mode 100644 index 0000000..22326e3 --- /dev/null +++ b/python_pkg/keyboard_coop/tests/test_game_logic.py @@ -0,0 +1,371 @@ +"""Tests for keyboard_coop game logic methods.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import pytest + +if TYPE_CHECKING: + from python_pkg.keyboard_coop.main import KeyboardCoopGame + + +class TestKeyboardCoopGame: + """Tests for KeyboardCoopGame class methods.""" + + @pytest.fixture + def mock_game(self) -> KeyboardCoopGame: + """Create a mock game instance without pygame initialization.""" + mock_pg = MagicMock() + mock_pg.font.Font.return_value = MagicMock() + mock_pg.Rect = MagicMock() + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + # Create game without calling __init__ directly + game = object.__new__(KeyboardCoopGame) + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.layout = [["a", "b", "c"], ["d", "e", "f"]] + game.keyboard.adjacency = { + "a": ["b", "d"], + "b": ["a", "c", "e"], + "c": ["b", "f"], + "d": ["a", "e"], + "e": ["b", "d", "f"], + "f": ["c", "e"], + } + game.keyboard.available_letters = {"a", "b", "c", "d", "e", "f"} + game.dictionary = {"cat", "bat", "cab", "bad", "bed", "fed", "fad", "ace"} + return game + + def test_is_valid_move_first_letter(self, mock_game: KeyboardCoopGame) -> None: + """Test first letter is always valid.""" + mock_game.state.selected_letters = [] + assert mock_game._is_valid_move("a") is True + assert mock_game._is_valid_move("z") is True + + def test_is_valid_move_adjacent(self, mock_game: KeyboardCoopGame) -> None: + """Test adjacent letter is valid.""" + mock_game.state.selected_letters = ["a"] + # "b" and "d" are adjacent to "a" + assert mock_game._is_valid_move("b") is True + assert mock_game._is_valid_move("d") is True + + def test_is_valid_move_not_adjacent(self, mock_game: KeyboardCoopGame) -> None: + """Test non-adjacent letter is invalid.""" + mock_game.state.selected_letters = ["a"] + # "f" is not adjacent to "a" + assert mock_game._is_valid_move("f") is False + + def test_is_valid_word_true(self, mock_game: KeyboardCoopGame) -> None: + """Test valid word returns True.""" + assert mock_game._is_valid_word("cat") is True + assert mock_game._is_valid_word("CAT") is True # Case insensitive + + def test_is_valid_word_false(self, mock_game: KeyboardCoopGame) -> None: + """Test invalid word returns False.""" + assert mock_game._is_valid_word("xyz") is False + + def test_calculate_score_min_length(self, mock_game: KeyboardCoopGame) -> None: + """Test score calculation for minimum length word.""" + # 3-letter word: 2^(3-2) = 2 + assert mock_game._calculate_score(3) == 2 + + def test_calculate_score_longer_word(self, mock_game: KeyboardCoopGame) -> None: + """Test score calculation for longer words.""" + # 4-letter: 2^(4-2) = 4 + assert mock_game._calculate_score(4) == 4 + # 5-letter: 2^(5-2) = 8 + assert mock_game._calculate_score(5) == 8 + + def test_calculate_score_too_short(self, mock_game: KeyboardCoopGame) -> None: + """Test score for words below minimum length is 0.""" + assert mock_game._calculate_score(2) == 0 + assert mock_game._calculate_score(1) == 0 + + def test_handle_letter_click_valid(self, mock_game: KeyboardCoopGame) -> None: + """Test clicking a valid letter adds it to word.""" + mock_game.state.selected_letters = [] + mock_game.state.current_word = "" + mock_game.state.current_player = 0 + + mock_game._handle_letter_click("a") + + assert mock_game.state.selected_letters == ["a"] + assert mock_game.state.current_word == "a" + assert mock_game.state.current_player == 1 # Switched + + def test_handle_letter_click_invalid_not_available( + self, mock_game: KeyboardCoopGame + ) -> None: + """Test clicking unavailable letter does nothing.""" + mock_game.keyboard.available_letters = {"b", "c"} + mock_game.state.selected_letters = [] + mock_game.state.current_word = "" + + mock_game._handle_letter_click("a") + + assert mock_game.state.selected_letters == [] + assert mock_game.state.current_word == "" + + def test_submit_word_valid(self, mock_game: KeyboardCoopGame) -> None: + """Test submitting a valid word adds score.""" + mock_game._generate_random_keyboard = MagicMock() + mock_game.state.current_word = "cat" + mock_game.state.selected_letters = ["c", "a", "t"] + mock_game.state.score = 0 + + mock_game._submit_word() + + assert mock_game.state.score == 2 # 2^(3-2) = 2 + assert mock_game.state.current_word == "" + assert mock_game.state.selected_letters == [] + + def test_submit_word_too_short(self, mock_game: KeyboardCoopGame) -> None: + """Test submitting too short word gives no score.""" + mock_game.state.current_word = "ca" + mock_game.state.selected_letters = ["c", "a"] + mock_game.state.score = 0 + + mock_game._submit_word() + + assert mock_game.state.score == 0 + assert "too short" in mock_game.state.message + + def test_submit_word_invalid(self, mock_game: KeyboardCoopGame) -> None: + """Test submitting invalid word gives no score.""" + mock_game.state.current_word = "xyz" + mock_game.state.selected_letters = ["x", "y", "z"] + mock_game.state.score = 0 + + mock_game._submit_word() + + assert mock_game.state.score == 0 + assert "not a valid word" in mock_game.state.message + + def test_reset_game(self, mock_game: KeyboardCoopGame) -> None: + """Test reset_game creates new state.""" + mock_game._generate_random_keyboard = MagicMock() + mock_game.state.score = 100 + mock_game.state.current_word = "test" + + mock_game._reset_game() + + # After reset, state should be fresh + assert mock_game.state.score == 0 + assert mock_game.state.current_word == "" + assert mock_game._generate_random_keyboard.called + + def test_get_key_at_position_found(self, mock_game: KeyboardCoopGame) -> None: + """Test getting key at position when key exists.""" + mock_rect = MagicMock() + mock_rect.collidepoint.return_value = True + mock_game.keyboard.positions = {"a": mock_rect} + + result = mock_game._get_key_at_position((100, 100)) + assert result == "a" + + def test_get_key_at_position_not_found(self, mock_game: KeyboardCoopGame) -> None: + """Test getting key at position when no key.""" + mock_rect = MagicMock() + mock_rect.collidepoint.return_value = False + mock_game.keyboard.positions = {"a": mock_rect} + + result = mock_game._get_key_at_position((100, 100)) + assert result is None + + +class TestLoadDictionary: + """Tests for dictionary loading.""" + + def test_fallback_dictionary_used(self) -> None: + """Test fallback dictionary when file not found.""" + mock_pg = MagicMock() + mock_pg.font.Font.return_value = MagicMock() + mock_pg.display.set_mode.return_value = MagicMock() + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("pathlib.Path.open", side_effect=FileNotFoundError), + ): + from python_pkg.keyboard_coop.main import KeyboardCoopGame + + game = object.__new__(KeyboardCoopGame) + dictionary = game._load_dictionary() + + # Should have fallback words + assert "cat" in dictionary + assert "dog" in dictionary + + def test_json_decode_error_fallback(self) -> None: + """Test fallback dictionary when JSON is invalid.""" + import json + + mock_pg = MagicMock() + mock_pg.font.Font.return_value = MagicMock() + mock_pg.display.set_mode.return_value = MagicMock() + + mock_file = MagicMock() + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("pathlib.Path.open", return_value=mock_file), + patch("json.load", side_effect=json.JSONDecodeError("err", "doc", 0)), + ): + from python_pkg.keyboard_coop.main import KeyboardCoopGame + + game = object.__new__(KeyboardCoopGame) + dictionary = game._load_dictionary() + + # Should have fallback words from JSONDecodeError handler + assert "cat" in dictionary + assert "dog" in dictionary + + +class TestGenerateRandomKeyboard: + """Tests for keyboard layout generation.""" + + def test_generate_random_keyboard_creates_26_letters(self) -> None: + """Test keyboard generation includes all 26 letters.""" + mock_pg = MagicMock() + mock_pg.Rect = MagicMock(return_value=MagicMock()) + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.keyboard = KeyboardState() + + game._generate_random_keyboard() + + # Should have 26 letters total across all rows + all_letters = [] + for row in game.keyboard.layout: + all_letters.extend(row) + assert len(all_letters) == 26 + assert len(set(all_letters)) == 26 # All unique + + def test_layout_structure_is_10_9_7(self) -> None: + """Test keyboard layout has correct row structure.""" + mock_pg = MagicMock() + mock_pg.Rect = MagicMock(return_value=MagicMock()) + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.keyboard = KeyboardState() + + game._generate_random_keyboard() + + assert len(game.keyboard.layout) == 3 + assert len(game.keyboard.layout[0]) == 10 + assert len(game.keyboard.layout[1]) == 9 + assert len(game.keyboard.layout[2]) == 7 + + +class TestCalculateAdjacencies: + """Tests for adjacency calculation.""" + + def test_calculate_adjacencies_populates_all_letters(self) -> None: + """Test adjacency calculation includes all letters.""" + mock_pg = MagicMock() + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.keyboard = KeyboardState() + game.keyboard.layout = [ + ["a", "b", "c"], + ["d", "e", "f"], + ["g", "h"], + ] + + game._calculate_adjacencies() + + # Each letter should have adjacency list + assert len(game.keyboard.adjacency) == 8 + # Corner letter should have fewer adjacents + assert "b" in game.keyboard.adjacency["a"] + assert "d" in game.keyboard.adjacency["a"] + assert "e" in game.keyboard.adjacency["a"] + + +class TestCalculateKeyPositions: + """Tests for key position calculation.""" + + def test_calculate_key_positions_creates_rects(self) -> None: + """Test key position calculation creates rect for each key.""" + mock_pg = MagicMock() + mock_pg.Rect = MagicMock(return_value=MagicMock()) + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.keyboard = KeyboardState() + game.keyboard.layout = [["a", "b"], ["c", "d"]] + + positions = game._calculate_key_positions() + + assert len(positions) == 4 + assert "a" in positions + assert "d" in positions + + +class TestGameInit: + """Tests for game initialization.""" + + def test_init_creates_all_components(self) -> None: + """Test __init__ properly initializes all game components.""" + mock_pg = MagicMock() + mock_pg.font.Font.return_value = MagicMock() + mock_pg.display.set_mode.return_value = MagicMock() + mock_pg.Rect = MagicMock(return_value=MagicMock()) + + mock_file = MagicMock() + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("pathlib.Path.open", return_value=mock_file), + patch("json.load", return_value={"cat": 1, "dog": 1}), + ): + from python_pkg.keyboard_coop.main import KeyboardCoopGame + + game = KeyboardCoopGame() + + # Verify pygame display was set up + mock_pg.display.set_mode.assert_called() + mock_pg.display.set_caption.assert_called_with("Keyboard Coop Game") + + # Verify game components are initialized + assert game.screen is not None + assert game.clock is not None + assert game.fonts is not None + assert game.dictionary is not None + assert game.state is not None + assert game.keyboard is not None diff --git a/python_pkg/keyboard_coop/tests/test_game_loop.py b/python_pkg/keyboard_coop/tests/test_game_loop.py new file mode 100644 index 0000000..8d0d522 --- /dev/null +++ b/python_pkg/keyboard_coop/tests/test_game_loop.py @@ -0,0 +1,426 @@ +"""Tests for keyboard_coop game loop and forced submission.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + + +class TestForceSubmitWhenNoMoves: + """Tests for forced word submission when no moves available.""" + + def test_submit_called_when_available_letters_empty(self) -> None: + """Test that word is submitted when no valid moves remain. + + This tests the defensive code path at line 351 where _submit_word + is called if available_letters becomes empty after a letter click. + """ + mock_pg = MagicMock() + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.layout = [["a", "b"]] + game.keyboard.adjacency = {} + game.keyboard.available_letters = {"a"} + game.keyboard.positions = {} + game.dictionary = {"a": 1} + + # Simulate scenario where available_letters becomes empty + # This is defensive code that's hard to trigger naturally + game._submit_word = MagicMock() + + def patched_handle(letter: str) -> None: + """Patched handler that clears available letters.""" + if letter in game.keyboard.available_letters: + game.state.selected_letters.append(letter) + game.state.current_word += letter + # Force empty to trigger the check + game.keyboard.available_letters = set() + if not game.keyboard.available_letters: + game._submit_word() + + patched_handle("a") + + # Should have triggered submit_word + game._submit_word.assert_called() + + +class TestGameLoop: + """Tests for the main game loop.""" + + def test_run_quit_event(self) -> None: + """Test game loop exits on QUIT event.""" + mock_pg = MagicMock() + + # Create quit event + quit_event = MagicMock() + quit_event.type = "QUIT" + mock_pg.QUIT = "QUIT" + mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" + mock_pg.KEYDOWN = "KEYDOWN" + mock_pg.event.get.return_value = [quit_event] + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("sys.exit") as mock_exit, + ): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.clock = MagicMock() + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + game._draw_keyboard = MagicMock() + game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) + + game.run() + + mock_pg.quit.assert_called() + mock_exit.assert_called() + + def test_run_mouse_click_event(self) -> None: + """Test game loop handles mouse click event.""" + mock_pg = MagicMock() + + # Create mouse click event followed by quit + click_event = MagicMock() + click_event.type = "MOUSEDOWN" + click_event.button = 1 + click_event.pos = (100, 100) + + quit_event = MagicMock() + quit_event.type = "QUIT" + + mock_pg.QUIT = "QUIT" + mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" + mock_pg.KEYDOWN = "KEYDOWN" + # Return click event first, then quit event + mock_pg.event.get.side_effect = [[click_event], [quit_event]] + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("sys.exit"), + ): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.clock = MagicMock() + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + game._draw_keyboard = MagicMock() + game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) + game._handle_click = MagicMock() + + game.run() + + game._handle_click.assert_called_with((100, 100)) + + def test_run_enter_key_event(self) -> None: + """Test game loop handles ENTER key event.""" + mock_pg = MagicMock() + + key_event = MagicMock() + key_event.type = "KEYDOWN" + key_event.key = "K_RETURN" + + quit_event = MagicMock() + quit_event.type = "QUIT" + + mock_pg.QUIT = "QUIT" + mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" + mock_pg.KEYDOWN = "KEYDOWN" + mock_pg.K_RETURN = "K_RETURN" + mock_pg.K_r = "K_r" + mock_pg.event.get.side_effect = [[key_event], [quit_event]] + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("sys.exit"), + ): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.clock = MagicMock() + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + game._draw_keyboard = MagicMock() + game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) + game._submit_word = MagicMock() + + game.run() + + game._submit_word.assert_called() + + def test_run_r_key_reset(self) -> None: + """Test game loop handles R key for reset.""" + mock_pg = MagicMock() + + key_event = MagicMock() + key_event.type = "KEYDOWN" + key_event.key = "K_r" + + quit_event = MagicMock() + quit_event.type = "QUIT" + + mock_pg.QUIT = "QUIT" + mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" + mock_pg.KEYDOWN = "KEYDOWN" + mock_pg.K_RETURN = "K_RETURN" + mock_pg.K_r = "K_r" + mock_pg.event.get.side_effect = [[key_event], [quit_event]] + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("sys.exit"), + ): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.clock = MagicMock() + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + game._draw_keyboard = MagicMock() + game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) + game._reset_game = MagicMock() + + game.run() + + game._reset_game.assert_called() + + def test_run_letter_key_press(self) -> None: + """Test game loop handles letter key presses.""" + mock_pg = MagicMock() + + key_event = MagicMock() + key_event.type = "KEYDOWN" + key_event.key = "some_key" + + quit_event = MagicMock() + quit_event.type = "QUIT" + + mock_pg.QUIT = "QUIT" + mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" + mock_pg.KEYDOWN = "KEYDOWN" + mock_pg.K_RETURN = "K_RETURN" + mock_pg.K_r = "K_r" + mock_pg.key.name.return_value = "a" # Single letter key + mock_pg.event.get.side_effect = [[key_event], [quit_event]] + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("sys.exit"), + ): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.clock = MagicMock() + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + game._draw_keyboard = MagicMock() + game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) + game._handle_letter_click = MagicMock() + + game.run() + + game._handle_letter_click.assert_called_with("a") + + def test_run_right_click_ignored(self) -> None: + """Test game loop ignores non-left mouse clicks.""" + mock_pg = MagicMock() + + click_event = MagicMock() + click_event.type = "MOUSEDOWN" + click_event.button = 3 # Right click + click_event.pos = (100, 100) + + quit_event = MagicMock() + quit_event.type = "QUIT" + + mock_pg.QUIT = "QUIT" + mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" + mock_pg.KEYDOWN = "KEYDOWN" + mock_pg.event.get.side_effect = [[click_event], [quit_event]] + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("sys.exit"), + ): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.clock = MagicMock() + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + game._draw_keyboard = MagicMock() + game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) + game._handle_click = MagicMock() + + game.run() + + # handle_click should NOT be called for right click + game._handle_click.assert_not_called() + + def test_run_special_key_ignored(self) -> None: + """Test game loop ignores non-letter key presses.""" + mock_pg = MagicMock() + + key_event = MagicMock() + key_event.type = "KEYDOWN" + key_event.key = "some_key" + + quit_event = MagicMock() + quit_event.type = "QUIT" + + mock_pg.QUIT = "QUIT" + mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" + mock_pg.KEYDOWN = "KEYDOWN" + mock_pg.K_RETURN = "K_RETURN" + mock_pg.K_r = "K_r" + mock_pg.key.name.return_value = "escape" # Multi-char, not a letter + mock_pg.event.get.side_effect = [[key_event], [quit_event]] + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("sys.exit"), + ): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.clock = MagicMock() + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + game._draw_keyboard = MagicMock() + game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) + game._handle_letter_click = MagicMock() + + game.run() + + # handle_letter_click should NOT be called for special keys + game._handle_letter_click.assert_not_called() + + def test_run_unknown_event_type(self) -> None: + """Test game loop ignores unknown event types.""" + mock_pg = MagicMock() + + unknown_event = MagicMock() + unknown_event.type = "UNKNOWN" + + quit_event = MagicMock() + quit_event.type = "QUIT" + + mock_pg.QUIT = "QUIT" + mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" + mock_pg.KEYDOWN = "KEYDOWN" + mock_pg.event.get.side_effect = [[unknown_event], [quit_event]] + + with ( + patch.dict("sys.modules", {"pygame": mock_pg}), + patch("sys.exit"), + ): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.clock = MagicMock() + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + game._draw_keyboard = MagicMock() + game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) + game._handle_click = MagicMock() + game._submit_word = MagicMock() + game._reset_game = MagicMock() + game._handle_letter_click = MagicMock() + + game.run() + + # None of the handlers should be called for unknown event + game._handle_click.assert_not_called() + game._submit_word.assert_not_called() + game._reset_game.assert_not_called() + game._handle_letter_click.assert_not_called() diff --git a/python_pkg/keyboard_coop/tests/test_main.py b/python_pkg/keyboard_coop/tests/test_main.py deleted file mode 100644 index 7710e32..0000000 --- a/python_pkg/keyboard_coop/tests/test_main.py +++ /dev/null @@ -1,1247 +0,0 @@ -"""Unit tests for keyboard_coop module.""" - -from typing import TYPE_CHECKING -from unittest.mock import MagicMock, patch - -import pytest - -if TYPE_CHECKING: - from python_pkg.keyboard_coop.main import KeyboardCoopGame - - -# Need to mock pygame before importing the module -@pytest.fixture(autouse=True) -def mock_pygame() -> MagicMock: - """Mock pygame to prevent display initialization.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - yield - - -class TestConstants: - """Tests for module constants.""" - - def test_screen_dimensions(self) -> None: - """Test screen dimension constants.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import SCREEN_HEIGHT, SCREEN_WIDTH - - expected_width = 1366 - expected_height = 768 - assert expected_width == SCREEN_WIDTH - assert expected_height == SCREEN_HEIGHT - - def test_min_word_length(self) -> None: - """Test minimum word length constant.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import MIN_WORD_LENGTH - - expected_min = 3 - assert expected_min == MIN_WORD_LENGTH - - def test_keyboard_layout_structure(self) -> None: - """Test KEYBOARD_LAYOUT has correct structure.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import KEYBOARD_LAYOUT - - expected_rows = 3 - assert len(KEYBOARD_LAYOUT) == expected_rows - expected_first_row_len = 10 - expected_second_row_len = 9 - expected_third_row_len = 7 - assert len(KEYBOARD_LAYOUT[0]) == expected_first_row_len - assert len(KEYBOARD_LAYOUT[1]) == expected_second_row_len - assert len(KEYBOARD_LAYOUT[2]) == expected_third_row_len - - -class TestKeyAdjacency: - """Tests for KEY_ADJACENCY mapping.""" - - def test_q_adjacents(self) -> None: - """Test Q key has correct adjacent keys.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import KEY_ADJACENCY - - assert set(KEY_ADJACENCY["q"]) == {"w", "a", "s"} - - def test_all_letters_have_adjacents(self) -> None: - """Test all 26 letters have adjacency entries.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import KEY_ADJACENCY - - alphabet = "qwertyuiopasdfghjklzxcvbnm" - for letter in alphabet: - assert letter in KEY_ADJACENCY - assert len(KEY_ADJACENCY[letter]) > 0 - - -class TestGameState: - """Tests for GameState dataclass.""" - - def test_default_values(self) -> None: - """Test GameState default values.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import GameState - - state = GameState() - assert state.current_player == 0 - assert state.current_word == "" - assert state.selected_letters == [] - assert state.score == 0 - assert state.game_over is False - assert "Player 1" in state.message - - def test_custom_values(self) -> None: - """Test GameState with custom values.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import GameState - - state = GameState( - current_player=1, - current_word="test", - selected_letters=["t", "e", "s", "t"], - score=100, - game_over=True, - message="Game Over!", - ) - assert state.current_player == 1 - assert state.current_word == "test" - expected_score = 100 - assert state.score == expected_score - - -class TestKeyboardState: - """Tests for KeyboardState dataclass.""" - - def test_default_values(self) -> None: - """Test KeyboardState default values.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import KeyboardState - - kb_state = KeyboardState() - assert kb_state.layout == [] - assert kb_state.available_letters == set() - assert kb_state.adjacency == {} - assert kb_state.positions == {} - - -class TestFontSet: - """Tests for FontSet dataclass.""" - - def test_fontset_creation(self) -> None: - """Test FontSet stores fonts correctly.""" - mock_font = MagicMock() - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import FontSet - - fonts = FontSet(normal=mock_font, large=mock_font, small=mock_font) - assert fonts.normal == mock_font - assert fonts.large == mock_font - assert fonts.small == mock_font - - -class TestColors: - """Tests for color constants.""" - - def test_background_color_is_rgb_tuple(self) -> None: - """Test BACKGROUND_COLOR is an RGB tuple.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import BACKGROUND_COLOR - - expected_len = 3 - assert len(BACKGROUND_COLOR) == expected_len - assert all(isinstance(c, int) for c in BACKGROUND_COLOR) - - def test_player_colors_list(self) -> None: - """Test PLAYER_COLORS has colors for 2 players.""" - with patch.dict("sys.modules", {"pygame": MagicMock()}): - from python_pkg.keyboard_coop.main import PLAYER_COLORS - - expected_players = 2 - assert len(PLAYER_COLORS) == expected_players - - -class TestKeyboardCoopGame: - """Tests for KeyboardCoopGame class methods.""" - - @pytest.fixture - def mock_game(self) -> "KeyboardCoopGame": - """Create a mock game instance without pygame initialization.""" - mock_pg = MagicMock() - mock_pg.font.Font.return_value = MagicMock() - mock_pg.Rect = MagicMock() - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - # Create game without calling __init__ directly - game = object.__new__(KeyboardCoopGame) - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.layout = [["a", "b", "c"], ["d", "e", "f"]] - game.keyboard.adjacency = { - "a": ["b", "d"], - "b": ["a", "c", "e"], - "c": ["b", "f"], - "d": ["a", "e"], - "e": ["b", "d", "f"], - "f": ["c", "e"], - } - game.keyboard.available_letters = {"a", "b", "c", "d", "e", "f"} - game.dictionary = {"cat", "bat", "cab", "bad", "bed", "fed", "fad", "ace"} - return game - - def test_is_valid_move_first_letter(self, mock_game: "KeyboardCoopGame") -> None: - """Test first letter is always valid.""" - mock_game.state.selected_letters = [] - assert mock_game._is_valid_move("a") is True - assert mock_game._is_valid_move("z") is True - - def test_is_valid_move_adjacent(self, mock_game: "KeyboardCoopGame") -> None: - """Test adjacent letter is valid.""" - mock_game.state.selected_letters = ["a"] - # "b" and "d" are adjacent to "a" - assert mock_game._is_valid_move("b") is True - assert mock_game._is_valid_move("d") is True - - def test_is_valid_move_not_adjacent(self, mock_game: "KeyboardCoopGame") -> None: - """Test non-adjacent letter is invalid.""" - mock_game.state.selected_letters = ["a"] - # "f" is not adjacent to "a" - assert mock_game._is_valid_move("f") is False - - def test_is_valid_word_true(self, mock_game: "KeyboardCoopGame") -> None: - """Test valid word returns True.""" - assert mock_game._is_valid_word("cat") is True - assert mock_game._is_valid_word("CAT") is True # Case insensitive - - def test_is_valid_word_false(self, mock_game: "KeyboardCoopGame") -> None: - """Test invalid word returns False.""" - assert mock_game._is_valid_word("xyz") is False - - def test_calculate_score_min_length(self, mock_game: "KeyboardCoopGame") -> None: - """Test score calculation for minimum length word.""" - # 3-letter word: 2^(3-2) = 2 - assert mock_game._calculate_score(3) == 2 - - def test_calculate_score_longer_word(self, mock_game: "KeyboardCoopGame") -> None: - """Test score calculation for longer words.""" - # 4-letter: 2^(4-2) = 4 - assert mock_game._calculate_score(4) == 4 - # 5-letter: 2^(5-2) = 8 - assert mock_game._calculate_score(5) == 8 - - def test_calculate_score_too_short(self, mock_game: "KeyboardCoopGame") -> None: - """Test score for words below minimum length is 0.""" - assert mock_game._calculate_score(2) == 0 - assert mock_game._calculate_score(1) == 0 - - def test_handle_letter_click_valid(self, mock_game: "KeyboardCoopGame") -> None: - """Test clicking a valid letter adds it to word.""" - mock_game.state.selected_letters = [] - mock_game.state.current_word = "" - mock_game.state.current_player = 0 - - mock_game._handle_letter_click("a") - - assert mock_game.state.selected_letters == ["a"] - assert mock_game.state.current_word == "a" - assert mock_game.state.current_player == 1 # Switched - - def test_handle_letter_click_invalid_not_available( - self, mock_game: "KeyboardCoopGame" - ) -> None: - """Test clicking unavailable letter does nothing.""" - mock_game.keyboard.available_letters = {"b", "c"} - mock_game.state.selected_letters = [] - mock_game.state.current_word = "" - - mock_game._handle_letter_click("a") - - assert mock_game.state.selected_letters == [] - assert mock_game.state.current_word == "" - - def test_submit_word_valid(self, mock_game: "KeyboardCoopGame") -> None: - """Test submitting a valid word adds score.""" - mock_game._generate_random_keyboard = MagicMock() - mock_game.state.current_word = "cat" - mock_game.state.selected_letters = ["c", "a", "t"] - mock_game.state.score = 0 - - mock_game._submit_word() - - assert mock_game.state.score == 2 # 2^(3-2) = 2 - assert mock_game.state.current_word == "" - assert mock_game.state.selected_letters == [] - - def test_submit_word_too_short(self, mock_game: "KeyboardCoopGame") -> None: - """Test submitting too short word gives no score.""" - mock_game.state.current_word = "ca" - mock_game.state.selected_letters = ["c", "a"] - mock_game.state.score = 0 - - mock_game._submit_word() - - assert mock_game.state.score == 0 - assert "too short" in mock_game.state.message - - def test_submit_word_invalid(self, mock_game: "KeyboardCoopGame") -> None: - """Test submitting invalid word gives no score.""" - mock_game.state.current_word = "xyz" - mock_game.state.selected_letters = ["x", "y", "z"] - mock_game.state.score = 0 - - mock_game._submit_word() - - assert mock_game.state.score == 0 - assert "not a valid word" in mock_game.state.message - - def test_reset_game(self, mock_game: "KeyboardCoopGame") -> None: - """Test reset_game creates new state.""" - mock_game._generate_random_keyboard = MagicMock() - mock_game.state.score = 100 - mock_game.state.current_word = "test" - - mock_game._reset_game() - - # After reset, state should be fresh - assert mock_game.state.score == 0 - assert mock_game.state.current_word == "" - assert mock_game._generate_random_keyboard.called - - def test_get_key_at_position_found(self, mock_game: "KeyboardCoopGame") -> None: - """Test getting key at position when key exists.""" - mock_rect = MagicMock() - mock_rect.collidepoint.return_value = True - mock_game.keyboard.positions = {"a": mock_rect} - - result = mock_game._get_key_at_position((100, 100)) - assert result == "a" - - def test_get_key_at_position_not_found(self, mock_game: "KeyboardCoopGame") -> None: - """Test getting key at position when no key.""" - mock_rect = MagicMock() - mock_rect.collidepoint.return_value = False - mock_game.keyboard.positions = {"a": mock_rect} - - result = mock_game._get_key_at_position((100, 100)) - assert result is None - - -class TestLoadDictionary: - """Tests for dictionary loading.""" - - def test_fallback_dictionary_used(self) -> None: - """Test fallback dictionary when file not found.""" - mock_pg = MagicMock() - mock_pg.font.Font.return_value = MagicMock() - mock_pg.display.set_mode.return_value = MagicMock() - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("pathlib.Path.open", side_effect=FileNotFoundError), - ): - from python_pkg.keyboard_coop.main import KeyboardCoopGame - - game = object.__new__(KeyboardCoopGame) - dictionary = game._load_dictionary() - - # Should have fallback words - assert "cat" in dictionary - assert "dog" in dictionary - - def test_json_decode_error_fallback(self) -> None: - """Test fallback dictionary when JSON is invalid.""" - import json - - mock_pg = MagicMock() - mock_pg.font.Font.return_value = MagicMock() - mock_pg.display.set_mode.return_value = MagicMock() - - mock_file = MagicMock() - mock_file.__enter__ = MagicMock(return_value=mock_file) - mock_file.__exit__ = MagicMock(return_value=False) - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("pathlib.Path.open", return_value=mock_file), - patch("json.load", side_effect=json.JSONDecodeError("err", "doc", 0)), - ): - from python_pkg.keyboard_coop.main import KeyboardCoopGame - - game = object.__new__(KeyboardCoopGame) - dictionary = game._load_dictionary() - - # Should have fallback words from JSONDecodeError handler - assert "cat" in dictionary - assert "dog" in dictionary - - -class TestGenerateRandomKeyboard: - """Tests for keyboard layout generation.""" - - def test_generate_random_keyboard_creates_26_letters(self) -> None: - """Test keyboard generation includes all 26 letters.""" - mock_pg = MagicMock() - mock_pg.Rect = MagicMock(return_value=MagicMock()) - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.keyboard = KeyboardState() - - game._generate_random_keyboard() - - # Should have 26 letters total across all rows - all_letters = [] - for row in game.keyboard.layout: - all_letters.extend(row) - assert len(all_letters) == 26 - assert len(set(all_letters)) == 26 # All unique - - def test_layout_structure_is_10_9_7(self) -> None: - """Test keyboard layout has correct row structure.""" - mock_pg = MagicMock() - mock_pg.Rect = MagicMock(return_value=MagicMock()) - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.keyboard = KeyboardState() - - game._generate_random_keyboard() - - assert len(game.keyboard.layout) == 3 - assert len(game.keyboard.layout[0]) == 10 - assert len(game.keyboard.layout[1]) == 9 - assert len(game.keyboard.layout[2]) == 7 - - -class TestCalculateAdjacencies: - """Tests for adjacency calculation.""" - - def test_calculate_adjacencies_populates_all_letters(self) -> None: - """Test adjacency calculation includes all letters.""" - mock_pg = MagicMock() - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.keyboard = KeyboardState() - game.keyboard.layout = [ - ["a", "b", "c"], - ["d", "e", "f"], - ["g", "h"], - ] - - game._calculate_adjacencies() - - # Each letter should have adjacency list - assert len(game.keyboard.adjacency) == 8 - # Corner letter should have fewer adjacents - assert "b" in game.keyboard.adjacency["a"] - assert "d" in game.keyboard.adjacency["a"] - assert "e" in game.keyboard.adjacency["a"] - - -class TestCalculateKeyPositions: - """Tests for key position calculation.""" - - def test_calculate_key_positions_creates_rects(self) -> None: - """Test key position calculation creates rect for each key.""" - mock_pg = MagicMock() - mock_pg.Rect = MagicMock(return_value=MagicMock()) - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.keyboard = KeyboardState() - game.keyboard.layout = [["a", "b"], ["c", "d"]] - - positions = game._calculate_key_positions() - - assert len(positions) == 4 - assert "a" in positions - assert "d" in positions - - -class TestGameInit: - """Tests for game initialization.""" - - def test_init_creates_all_components(self) -> None: - """Test __init__ properly initializes all game components.""" - mock_pg = MagicMock() - mock_pg.font.Font.return_value = MagicMock() - mock_pg.display.set_mode.return_value = MagicMock() - mock_pg.Rect = MagicMock(return_value=MagicMock()) - - mock_file = MagicMock() - mock_file.__enter__ = MagicMock(return_value=mock_file) - mock_file.__exit__ = MagicMock(return_value=False) - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("pathlib.Path.open", return_value=mock_file), - patch("json.load", return_value={"cat": 1, "dog": 1}), - ): - from python_pkg.keyboard_coop.main import KeyboardCoopGame - - game = KeyboardCoopGame() - - # Verify pygame display was set up - mock_pg.display.set_mode.assert_called() - mock_pg.display.set_caption.assert_called_with("Keyboard Coop Game") - - # Verify game components are initialized - assert game.screen is not None - assert game.clock is not None - assert game.fonts is not None - assert game.dictionary is not None - assert game.state is not None - assert game.keyboard is not None - - -class TestHandleClick: - """Tests for click handling.""" - - def test_handle_click_on_letter_key(self) -> None: - """Test clicking on a letter key triggers letter click handler.""" - mock_pg = MagicMock() - mock_pg.Rect = MagicMock() - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.available_letters = {"a"} - game.keyboard.adjacency = {"a": []} - - # Mock _get_key_at_position to return "a" - game._get_key_at_position = MagicMock(return_value="a") - game._handle_letter_click = MagicMock() - game._submit_word = MagicMock() - game._reset_game = MagicMock() - - # Create mock rects that don't collide - mock_enter_rect = MagicMock() - mock_enter_rect.collidepoint.return_value = False - mock_reset_rect = MagicMock() - mock_reset_rect.collidepoint.return_value = False - - # Patch pygame.Rect to return our mocks - mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect] - - game._handle_click((100, 100)) - - game._handle_letter_click.assert_called_with("a") - - def test_handle_click_on_enter_button(self) -> None: - """Test clicking ENTER button triggers word submission.""" - mock_pg = MagicMock() - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - - # Mock methods - game._get_key_at_position = MagicMock(return_value=None) - game._submit_word = MagicMock() - game._reset_game = MagicMock() - - # Mock enter button to collide, reset button not to - mock_enter_rect = MagicMock() - mock_enter_rect.collidepoint.return_value = True - mock_reset_rect = MagicMock() - mock_reset_rect.collidepoint.return_value = False - - mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect] - - game._handle_click((750, 200)) - - game._submit_word.assert_called() - game._reset_game.assert_not_called() - - def test_handle_click_on_reset_button(self) -> None: - """Test clicking RESET button triggers game reset.""" - mock_pg = MagicMock() - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - - # Mock methods - game._get_key_at_position = MagicMock(return_value=None) - game._submit_word = MagicMock() - game._reset_game = MagicMock() - - # Mock enter button not to collide, reset button to collide - mock_enter_rect = MagicMock() - mock_enter_rect.collidepoint.return_value = False - mock_reset_rect = MagicMock() - mock_reset_rect.collidepoint.return_value = True - - mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect] - - game._handle_click((900, 200)) - - game._reset_game.assert_called() - game._submit_word.assert_not_called() - - -class TestDrawingMethods: - """Tests for drawing methods.""" - - def test_draw_text_line(self) -> None: - """Test draw_text_line renders and blits text.""" - mock_pg = MagicMock() - mock_font = MagicMock() - mock_rendered = MagicMock() - mock_font.render.return_value = mock_rendered - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - KeyboardCoopGame, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - - game._draw_text_line("Test text", (10, 20), mock_font) - - mock_font.render.assert_called() - game.screen.blit.assert_called_with(mock_rendered, (10, 20)) - - def test_draw_button(self) -> None: - """Test draw_button draws rect and text.""" - mock_pg = MagicMock() - mock_pg.draw = MagicMock() - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - FontSet, - KeyboardCoopGame, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - - mock_rect = MagicMock() - mock_rect.center = (50, 50) - - game._draw_button(mock_rect, "Test") - - # Should have drawn rect twice (fill and border) - assert mock_pg.draw.rect.call_count == 2 - - -class TestDrawKeyboard: - """Tests for keyboard drawing.""" - - def test_draw_keyboard_draws_all_keys(self) -> None: - """Test draw_keyboard renders all key positions.""" - mock_pg = MagicMock() - mock_pg.draw = MagicMock() - mock_pg.mouse.get_pos.return_value = (0, 0) - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.state = GameState() - game.keyboard = KeyboardState() - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - - # Set up some positions - mock_rect_a = MagicMock() - mock_rect_a.collidepoint.return_value = False - mock_rect_a.center = (100, 100) - mock_rect_b = MagicMock() - mock_rect_b.collidepoint.return_value = False - mock_rect_b.center = (150, 100) - - game.keyboard.positions = {"a": mock_rect_a, "b": mock_rect_b} - game.keyboard.available_letters = {"a", "b"} - - game._draw_keyboard() - - # Should draw rect for each key (fill + border = 2 calls per key) - assert mock_pg.draw.rect.call_count >= 4 - - def test_draw_keyboard_selected_letter_color(self) -> None: - """Test selected letters get selected color.""" - mock_pg = MagicMock() - mock_pg.draw = MagicMock() - mock_pg.mouse.get_pos.return_value = (0, 0) - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - KEY_SELECTED_COLOR, - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.state = GameState() - game.state.selected_letters = ["a"] # 'a' is selected - game.keyboard = KeyboardState() - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - - mock_rect_a = MagicMock() - mock_rect_a.collidepoint.return_value = False - mock_rect_a.center = (100, 100) - - game.keyboard.positions = {"a": mock_rect_a} - game.keyboard.available_letters = {"a"} - - game._draw_keyboard() - - # Check that KEY_SELECTED_COLOR was used - calls = mock_pg.draw.rect.call_args_list - colors_used = [call[0][1] for call in calls] - assert KEY_SELECTED_COLOR in colors_used - - def test_draw_keyboard_unavailable_key_color(self) -> None: - """Test unavailable keys get default key color.""" - mock_pg = MagicMock() - mock_pg.draw = MagicMock() - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - KEY_COLOR, - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.state = GameState() - game.state.selected_letters = [] # Not selected - game.keyboard = KeyboardState() - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - - mock_rect_a = MagicMock() - mock_rect_a.center = (100, 100) - - game.keyboard.positions = {"a": mock_rect_a} - # Key is NOT available - should get KEY_COLOR - game.keyboard.available_letters = set() - - game._draw_keyboard() - - # Check that KEY_COLOR was used for unavailable key - calls = mock_pg.draw.rect.call_args_list - colors_used = [call[0][1] for call in calls] - assert KEY_COLOR in colors_used - - -class TestDrawUI: - """Tests for UI drawing.""" - - def test_draw_ui_returns_button_rects(self) -> None: - """Test draw_ui returns enter and reset button rects.""" - mock_pg = MagicMock() - mock_pg.draw = MagicMock() - mock_rect_instance = MagicMock() - mock_pg.Rect.return_value = mock_rect_instance - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.state = GameState() - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - - enter_rect, reset_rect = game._draw_ui() - - # Should return pygame.Rect instances - assert enter_rect is not None - assert reset_rect is not None - - -class TestForceSubmitWhenNoMoves: - """Tests for forced word submission when no moves available.""" - - def test_submit_called_when_available_letters_empty(self) -> None: - """Test that word is submitted when no valid moves remain. - - This tests the defensive code path at line 351 where _submit_word - is called if available_letters becomes empty after a letter click. - """ - mock_pg = MagicMock() - - with patch.dict("sys.modules", {"pygame": mock_pg}): - from python_pkg.keyboard_coop.main import ( - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.layout = [["a", "b"]] - game.keyboard.adjacency = {} - game.keyboard.available_letters = {"a"} - game.keyboard.positions = {} - game.dictionary = {"a": 1} - - # Simulate scenario where available_letters becomes empty - # This is defensive code that's hard to trigger naturally - game._submit_word = MagicMock() - - def patched_handle(letter: str) -> None: - """Patched handler that clears available letters.""" - if letter in game.keyboard.available_letters: - game.state.selected_letters.append(letter) - game.state.current_word += letter - # Force empty to trigger the check - game.keyboard.available_letters = set() - if not game.keyboard.available_letters: - game._submit_word() - - patched_handle("a") - - # Should have triggered submit_word - game._submit_word.assert_called() - - -class TestGameLoop: - """Tests for the main game loop.""" - - def test_run_quit_event(self) -> None: - """Test game loop exits on QUIT event.""" - mock_pg = MagicMock() - - # Create quit event - quit_event = MagicMock() - quit_event.type = "QUIT" - mock_pg.QUIT = "QUIT" - mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" - mock_pg.KEYDOWN = "KEYDOWN" - mock_pg.event.get.return_value = [quit_event] - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("sys.exit") as mock_exit, - ): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.clock = MagicMock() - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - game._draw_keyboard = MagicMock() - game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) - - game.run() - - mock_pg.quit.assert_called() - mock_exit.assert_called() - - def test_run_mouse_click_event(self) -> None: - """Test game loop handles mouse click event.""" - mock_pg = MagicMock() - - # Create mouse click event followed by quit - click_event = MagicMock() - click_event.type = "MOUSEDOWN" - click_event.button = 1 - click_event.pos = (100, 100) - - quit_event = MagicMock() - quit_event.type = "QUIT" - - mock_pg.QUIT = "QUIT" - mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" - mock_pg.KEYDOWN = "KEYDOWN" - # Return click event first, then quit event - mock_pg.event.get.side_effect = [[click_event], [quit_event]] - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("sys.exit"), - ): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.clock = MagicMock() - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - game._draw_keyboard = MagicMock() - game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) - game._handle_click = MagicMock() - - game.run() - - game._handle_click.assert_called_with((100, 100)) - - def test_run_enter_key_event(self) -> None: - """Test game loop handles ENTER key event.""" - mock_pg = MagicMock() - - key_event = MagicMock() - key_event.type = "KEYDOWN" - key_event.key = "K_RETURN" - - quit_event = MagicMock() - quit_event.type = "QUIT" - - mock_pg.QUIT = "QUIT" - mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" - mock_pg.KEYDOWN = "KEYDOWN" - mock_pg.K_RETURN = "K_RETURN" - mock_pg.K_r = "K_r" - mock_pg.event.get.side_effect = [[key_event], [quit_event]] - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("sys.exit"), - ): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.clock = MagicMock() - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - game._draw_keyboard = MagicMock() - game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) - game._submit_word = MagicMock() - - game.run() - - game._submit_word.assert_called() - - def test_run_r_key_reset(self) -> None: - """Test game loop handles R key for reset.""" - mock_pg = MagicMock() - - key_event = MagicMock() - key_event.type = "KEYDOWN" - key_event.key = "K_r" - - quit_event = MagicMock() - quit_event.type = "QUIT" - - mock_pg.QUIT = "QUIT" - mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" - mock_pg.KEYDOWN = "KEYDOWN" - mock_pg.K_RETURN = "K_RETURN" - mock_pg.K_r = "K_r" - mock_pg.event.get.side_effect = [[key_event], [quit_event]] - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("sys.exit"), - ): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.clock = MagicMock() - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - game._draw_keyboard = MagicMock() - game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) - game._reset_game = MagicMock() - - game.run() - - game._reset_game.assert_called() - - def test_run_letter_key_press(self) -> None: - """Test game loop handles letter key presses.""" - mock_pg = MagicMock() - - key_event = MagicMock() - key_event.type = "KEYDOWN" - key_event.key = "some_key" - - quit_event = MagicMock() - quit_event.type = "QUIT" - - mock_pg.QUIT = "QUIT" - mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" - mock_pg.KEYDOWN = "KEYDOWN" - mock_pg.K_RETURN = "K_RETURN" - mock_pg.K_r = "K_r" - mock_pg.key.name.return_value = "a" # Single letter key - mock_pg.event.get.side_effect = [[key_event], [quit_event]] - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("sys.exit"), - ): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.clock = MagicMock() - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - game._draw_keyboard = MagicMock() - game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) - game._handle_letter_click = MagicMock() - - game.run() - - game._handle_letter_click.assert_called_with("a") - - def test_run_right_click_ignored(self) -> None: - """Test game loop ignores non-left mouse clicks.""" - mock_pg = MagicMock() - - click_event = MagicMock() - click_event.type = "MOUSEDOWN" - click_event.button = 3 # Right click - click_event.pos = (100, 100) - - quit_event = MagicMock() - quit_event.type = "QUIT" - - mock_pg.QUIT = "QUIT" - mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" - mock_pg.KEYDOWN = "KEYDOWN" - mock_pg.event.get.side_effect = [[click_event], [quit_event]] - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("sys.exit"), - ): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.clock = MagicMock() - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - game._draw_keyboard = MagicMock() - game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) - game._handle_click = MagicMock() - - game.run() - - # handle_click should NOT be called for right click - game._handle_click.assert_not_called() - - def test_run_special_key_ignored(self) -> None: - """Test game loop ignores non-letter key presses.""" - mock_pg = MagicMock() - - key_event = MagicMock() - key_event.type = "KEYDOWN" - key_event.key = "some_key" - - quit_event = MagicMock() - quit_event.type = "QUIT" - - mock_pg.QUIT = "QUIT" - mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" - mock_pg.KEYDOWN = "KEYDOWN" - mock_pg.K_RETURN = "K_RETURN" - mock_pg.K_r = "K_r" - mock_pg.key.name.return_value = "escape" # Multi-char, not a letter - mock_pg.event.get.side_effect = [[key_event], [quit_event]] - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("sys.exit"), - ): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.clock = MagicMock() - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - game._draw_keyboard = MagicMock() - game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) - game._handle_letter_click = MagicMock() - - game.run() - - # handle_letter_click should NOT be called for special keys - game._handle_letter_click.assert_not_called() - - def test_run_unknown_event_type(self) -> None: - """Test game loop ignores unknown event types.""" - mock_pg = MagicMock() - - unknown_event = MagicMock() - unknown_event.type = "UNKNOWN" - - quit_event = MagicMock() - quit_event.type = "QUIT" - - mock_pg.QUIT = "QUIT" - mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN" - mock_pg.KEYDOWN = "KEYDOWN" - mock_pg.event.get.side_effect = [[unknown_event], [quit_event]] - - with ( - patch.dict("sys.modules", {"pygame": mock_pg}), - patch("sys.exit"), - ): - from python_pkg.keyboard_coop.main import ( - FontSet, - GameState, - KeyboardCoopGame, - KeyboardState, - ) - - game = object.__new__(KeyboardCoopGame) - game.screen = MagicMock() - game.clock = MagicMock() - game.state = GameState() - game.keyboard = KeyboardState() - game.keyboard.positions = {} - game.fonts = FontSet( - normal=MagicMock(), large=MagicMock(), small=MagicMock() - ) - game._draw_keyboard = MagicMock() - game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock())) - game._handle_click = MagicMock() - game._submit_word = MagicMock() - game._reset_game = MagicMock() - game._handle_letter_click = MagicMock() - - game.run() - - # None of the handlers should be called for unknown event - game._handle_click.assert_not_called() - game._submit_word.assert_not_called() - game._reset_game.assert_not_called() - game._handle_letter_click.assert_not_called() diff --git a/python_pkg/keyboard_coop/tests/test_ui.py b/python_pkg/keyboard_coop/tests/test_ui.py new file mode 100644 index 0000000..dd346e9 --- /dev/null +++ b/python_pkg/keyboard_coop/tests/test_ui.py @@ -0,0 +1,311 @@ +"""Tests for keyboard_coop UI drawing and click handling.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + + +class TestHandleClick: + """Tests for click handling.""" + + def test_handle_click_on_letter_key(self) -> None: + """Test clicking on a letter key triggers letter click handler.""" + mock_pg = MagicMock() + mock_pg.Rect = MagicMock() + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.available_letters = {"a"} + game.keyboard.adjacency = {"a": []} + + # Mock _get_key_at_position to return "a" + game._get_key_at_position = MagicMock(return_value="a") + game._handle_letter_click = MagicMock() + game._submit_word = MagicMock() + game._reset_game = MagicMock() + + # Create mock rects that don't collide + mock_enter_rect = MagicMock() + mock_enter_rect.collidepoint.return_value = False + mock_reset_rect = MagicMock() + mock_reset_rect.collidepoint.return_value = False + + # Patch pygame.Rect to return our mocks + mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect] + + game._handle_click((100, 100)) + + game._handle_letter_click.assert_called_with("a") + + def test_handle_click_on_enter_button(self) -> None: + """Test clicking ENTER button triggers word submission.""" + mock_pg = MagicMock() + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + + # Mock methods + game._get_key_at_position = MagicMock(return_value=None) + game._submit_word = MagicMock() + game._reset_game = MagicMock() + + # Mock enter button to collide, reset button not to + mock_enter_rect = MagicMock() + mock_enter_rect.collidepoint.return_value = True + mock_reset_rect = MagicMock() + mock_reset_rect.collidepoint.return_value = False + + mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect] + + game._handle_click((750, 200)) + + game._submit_word.assert_called() + game._reset_game.assert_not_called() + + def test_handle_click_on_reset_button(self) -> None: + """Test clicking RESET button triggers game reset.""" + mock_pg = MagicMock() + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.state = GameState() + game.keyboard = KeyboardState() + game.keyboard.positions = {} + + # Mock methods + game._get_key_at_position = MagicMock(return_value=None) + game._submit_word = MagicMock() + game._reset_game = MagicMock() + + # Mock enter button not to collide, reset button to collide + mock_enter_rect = MagicMock() + mock_enter_rect.collidepoint.return_value = False + mock_reset_rect = MagicMock() + mock_reset_rect.collidepoint.return_value = True + + mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect] + + game._handle_click((900, 200)) + + game._reset_game.assert_called() + game._submit_word.assert_not_called() + + +class TestDrawingMethods: + """Tests for drawing methods.""" + + def test_draw_text_line(self) -> None: + """Test draw_text_line renders and blits text.""" + mock_pg = MagicMock() + mock_font = MagicMock() + mock_rendered = MagicMock() + mock_font.render.return_value = mock_rendered + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + KeyboardCoopGame, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + + game._draw_text_line("Test text", (10, 20), mock_font) + + mock_font.render.assert_called() + game.screen.blit.assert_called_with(mock_rendered, (10, 20)) + + def test_draw_button(self) -> None: + """Test draw_button draws rect and text.""" + mock_pg = MagicMock() + mock_pg.draw = MagicMock() + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + FontSet, + KeyboardCoopGame, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + + mock_rect = MagicMock() + mock_rect.center = (50, 50) + + game._draw_button(mock_rect, "Test") + + # Should have drawn rect twice (fill and border) + assert mock_pg.draw.rect.call_count == 2 + + +class TestDrawKeyboard: + """Tests for keyboard drawing.""" + + def test_draw_keyboard_draws_all_keys(self) -> None: + """Test draw_keyboard renders all key positions.""" + mock_pg = MagicMock() + mock_pg.draw = MagicMock() + mock_pg.mouse.get_pos.return_value = (0, 0) + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.state = GameState() + game.keyboard = KeyboardState() + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + + # Set up some positions + mock_rect_a = MagicMock() + mock_rect_a.collidepoint.return_value = False + mock_rect_a.center = (100, 100) + mock_rect_b = MagicMock() + mock_rect_b.collidepoint.return_value = False + mock_rect_b.center = (150, 100) + + game.keyboard.positions = {"a": mock_rect_a, "b": mock_rect_b} + game.keyboard.available_letters = {"a", "b"} + + game._draw_keyboard() + + # Should draw rect for each key (fill + border = 2 calls per key) + assert mock_pg.draw.rect.call_count >= 4 + + def test_draw_keyboard_selected_letter_color(self) -> None: + """Test selected letters get selected color.""" + mock_pg = MagicMock() + mock_pg.draw = MagicMock() + mock_pg.mouse.get_pos.return_value = (0, 0) + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + KEY_SELECTED_COLOR, + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.state = GameState() + game.state.selected_letters = ["a"] # 'a' is selected + game.keyboard = KeyboardState() + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + + mock_rect_a = MagicMock() + mock_rect_a.collidepoint.return_value = False + mock_rect_a.center = (100, 100) + + game.keyboard.positions = {"a": mock_rect_a} + game.keyboard.available_letters = {"a"} + + game._draw_keyboard() + + # Check that KEY_SELECTED_COLOR was used + calls = mock_pg.draw.rect.call_args_list + colors_used = [call[0][1] for call in calls] + assert KEY_SELECTED_COLOR in colors_used + + def test_draw_keyboard_unavailable_key_color(self) -> None: + """Test unavailable keys get default key color.""" + mock_pg = MagicMock() + mock_pg.draw = MagicMock() + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + KEY_COLOR, + FontSet, + GameState, + KeyboardCoopGame, + KeyboardState, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.state = GameState() + game.state.selected_letters = [] # Not selected + game.keyboard = KeyboardState() + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + + mock_rect_a = MagicMock() + mock_rect_a.center = (100, 100) + + game.keyboard.positions = {"a": mock_rect_a} + # Key is NOT available - should get KEY_COLOR + game.keyboard.available_letters = set() + + game._draw_keyboard() + + # Check that KEY_COLOR was used for unavailable key + calls = mock_pg.draw.rect.call_args_list + colors_used = [call[0][1] for call in calls] + assert KEY_COLOR in colors_used + + +class TestDrawUI: + """Tests for UI drawing.""" + + def test_draw_ui_returns_button_rects(self) -> None: + """Test draw_ui returns enter and reset button rects.""" + mock_pg = MagicMock() + mock_pg.draw = MagicMock() + mock_rect_instance = MagicMock() + mock_pg.Rect.return_value = mock_rect_instance + + with patch.dict("sys.modules", {"pygame": mock_pg}): + from python_pkg.keyboard_coop.main import ( + FontSet, + GameState, + KeyboardCoopGame, + ) + + game = object.__new__(KeyboardCoopGame) + game.screen = MagicMock() + game.state = GameState() + game.fonts = FontSet( + normal=MagicMock(), large=MagicMock(), small=MagicMock() + ) + + enter_rect, reset_rect = game._draw_ui() + + # Should return pygame.Rect instances + assert enter_rect is not None + assert reset_rect is not None diff --git a/python_pkg/lichess_bot/tests/test_main.py b/python_pkg/lichess_bot/tests/test_main.py deleted file mode 100644 index d92991c..0000000 --- a/python_pkg/lichess_bot/tests/test_main.py +++ /dev/null @@ -1,1244 +0,0 @@ -"""Tests for lichess_bot main module.""" - -from __future__ import annotations - -import os -import threading -from typing import TYPE_CHECKING, Any -from unittest.mock import MagicMock, PropertyMock, patch - -import chess -import pytest -import requests - -from python_pkg.lichess_bot.main import ( - BotContext, - GameMeta, - GameState, - _apply_move_to_board, - _attempt_move, - _calculate_time_budget, - _collect_analysis_lines, - _extract_game_full_data, - _extract_game_state_data, - _extract_player_info, - _finalize_game, - _handle_challenge, - _handle_game, - _handle_move_if_needed, - _init_game_log, - _insert_analysis_into_log, - _is_my_turn, - _log_analysis_progress, - _log_move_to_file, - _process_analysis_output, - _process_bot_event, - _process_game_event, - _process_game_events_loop, - _rebuild_board_from_moves, - _run_analysis_subprocess, - _run_event_loop, - _run_event_loop_iteration, - _safe_event_loop_iteration, - _stream_bot_events, - _update_clocks_from_state, - _write_pgn_to_log, - main, - run_bot, -) - -if TYPE_CHECKING: - from pathlib import Path - -# Type alias to make mypy happy with test event dicts -Event = dict[str, Any] -GameThreads = dict[str, threading.Thread] - - -class TestApplyMoveToBoard: - """Tests for _apply_move_to_board.""" - - def test_apply_valid_move(self) -> None: - """Test applying a valid move.""" - board = chess.Board() - _apply_move_to_board(board, "e2e4", "game1") - assert board.fen() != chess.STARTING_FEN - - def test_apply_invalid_move(self) -> None: - """Test applying an invalid move logs debug.""" - board = chess.Board() - with patch("python_pkg.lichess_bot.main._logger") as mock_logger: - _apply_move_to_board(board, "invalid", "game1") - mock_logger.debug.assert_called_once() - - -class TestInitGameLog: - """Tests for _init_game_log.""" - - def test_init_game_log_success(self, tmp_path: Path) -> None: - """Test successful log initialization.""" - with patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path): - result = _init_game_log("game123", 42) - assert result is not None - assert result.exists() - content = result.read_text() - assert "game game123 started" in content - assert "bot_version v42" in content - - def test_init_game_log_oserror(self) -> None: - """Test log initialization with OSError.""" - with patch("python_pkg.lichess_bot.main.Path.cwd") as mock_cwd: - mock_path = MagicMock() - mock_path.__truediv__ = MagicMock(return_value=mock_path) - mock_path.open.side_effect = OSError("Permission denied") - mock_cwd.return_value = mock_path - result = _init_game_log("game123", 42) - assert result is None - - -class TestUpdateClocksFromState: - """Tests for _update_clocks_from_state.""" - - def test_update_clocks_white(self) -> None: - """Test clock update when playing as white.""" - state = GameState(color="white") - state_data: Event = {"wtime": 60000, "btime": 55000, "winc": 1000} - _update_clocks_from_state(state_data, state) - assert state.my_ms == 60000 - assert state.opp_ms == 55000 - assert state.inc_ms == 1000 - - def test_update_clocks_black(self) -> None: - """Test clock update when playing as black.""" - state = GameState(color="black") - state_data: Event = {"wtime": 60000, "btime": 55000, "binc": 2000} - _update_clocks_from_state(state_data, state) - assert state.my_ms == 55000 - assert state.opp_ms == 60000 - assert state.inc_ms == 2000 - - def test_update_clocks_float_values(self) -> None: - """Test clock update with float values.""" - state = GameState(color="white") - state_data: Event = {"wtime": 60000.5, "btime": 55000.5} - _update_clocks_from_state(state_data, state) - assert state.my_ms == 60000 - assert state.opp_ms == 55000 - - def test_update_clocks_none_values(self) -> None: - """Test clock update with None values.""" - state = GameState(color="white") - state_data: Event = {"wtime": None, "btime": None} - _update_clocks_from_state(state_data, state) - assert state.my_ms is None - assert state.opp_ms is None - - -class TestExtractPlayerInfo: - """Tests for _extract_player_info.""" - - def test_extract_player_info_white(self) -> None: - """Test extracting player info when bot is white.""" - api = MagicMock() - api.get_my_user_id.return_value = "mybot" - state = GameState() - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = { - "white": {"id": "mybot", "name": "MyBot"}, - "black": {"id": "opp", "name": "Opponent"}, - } - _extract_player_info(event, state, meta, api) - assert state.color == "white" - assert meta.white_name == "MyBot" - assert meta.black_name == "Opponent" - - def test_extract_player_info_black(self) -> None: - """Test extracting player info when bot is black.""" - api = MagicMock() - api.get_my_user_id.return_value = "mybot" - state = GameState() - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = { - "white": {"id": "opp", "name": "Opponent"}, - "black": {"id": "mybot", "name": "MyBot"}, - } - _extract_player_info(event, state, meta, api) - assert state.color == "black" - - def test_extract_player_info_invalid_data(self) -> None: - """Test extracting player info with invalid data.""" - api = MagicMock() - state = GameState() - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = {"white": "invalid", "black": "invalid"} - _extract_player_info(event, state, meta, api) - assert state.color is None - - def test_extract_player_info_missing_name(self) -> None: - """Test extracting player info with missing name uses id.""" - api = MagicMock() - api.get_my_user_id.return_value = "mybot" - state = GameState() - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = { - "white": {"id": "mybot"}, - "black": {"id": "opponent"}, - } - _extract_player_info(event, state, meta, api) - assert meta.white_name == "mybot" - assert meta.black_name == "opponent" - - -class TestExtractGameFullData: - """Tests for _extract_game_full_data.""" - - def test_extract_game_full_data(self) -> None: - """Test extracting gameFull data.""" - api = MagicMock() - api.get_my_user_id.return_value = "mybot" - state = GameState() - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = { - "state": {"moves": "e2e4 e7e5", "status": "started", "wtime": 60000}, - "white": {"id": "mybot"}, - "black": {"id": "opp"}, - "createdAt": 1609459200000, # 2021-01-01 - } - moves, status = _extract_game_full_data(event, state, meta, api) - assert moves == "e2e4 e7e5" - assert status == "started" - assert meta.site_url == "https://lichess.org/game1" - assert meta.date_iso == "2021.01.01" - - def test_extract_game_full_data_invalid_state(self) -> None: - """Test extracting gameFull data with invalid state.""" - api = MagicMock() - state = GameState() - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = {"state": "invalid"} - moves, status = _extract_game_full_data(event, state, meta, api) - assert moves == "" - assert status is None - - -class TestExtractGameStateData: - """Tests for _extract_game_state_data.""" - - def test_extract_game_state_as_white(self) -> None: - """Test extracting gameState data as white.""" - state = GameState(color="white", my_ms=60000) - event: Event = { - "moves": "e2e4", - "status": "started", - "wtime": 59000, - "btime": 60000, - } - moves, status = _extract_game_state_data(event, state) - assert moves == "e2e4" - assert status == "started" - assert state.my_ms == 59000 - assert state.opp_ms == 60000 - - def test_extract_game_state_as_black(self) -> None: - """Test extracting gameState data as black.""" - state = GameState(color="black") - event: Event = { - "moves": "e2e4 e7e5", - "wtime": 60000, - "btime": 59000, - "binc": 1000, - } - moves, __status = _extract_game_state_data(event, state) - assert moves == "e2e4 e7e5" - assert state.my_ms == 59000 - assert state.opp_ms == 60000 - assert state.inc_ms == 1000 - - -class TestCalculateTimeBudget: - """Tests for _calculate_time_budget.""" - - def test_calculate_time_budget_normal(self) -> None: - """Test time budget calculation.""" - state = GameState(my_ms=60000, inc_ms=1000) - board = chess.Board() - budget = _calculate_time_budget(state, board, 10.0) - assert 0.05 <= budget <= 10.0 - - def test_calculate_time_budget_low_time(self) -> None: - """Test time budget with low time.""" - state = GameState(my_ms=1000, inc_ms=0) - board = chess.Board() - budget = _calculate_time_budget(state, board, 10.0) - assert budget >= 0.05 - - -class TestLogMoveToFile: - """Tests for _log_move_to_file.""" - - def test_log_move_to_file(self, tmp_path: Path) -> None: - """Test logging a move to file.""" - log_path = tmp_path / "game.log" - log_path.write_text("header\n") - move = chess.Move.from_uci("e2e4") - _log_move_to_file(log_path, 1, move, "best move") - content = log_path.read_text() - assert "ply 1: e2e4" in content - assert "best move" in content - - def test_log_move_to_file_none_path(self) -> None: - """Test logging with None path does nothing.""" - move = chess.Move.from_uci("e2e4") - _log_move_to_file(None, 1, move, "reason") # Should not raise - - -class TestAttemptMove: - """Tests for _attempt_move.""" - - def test_attempt_move_success(self) -> None: - """Test successful move attempt.""" - api = MagicMock() - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = ( - chess.Move.from_uci("e2e4"), - "opening", - ) - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState(my_ms=60000) - meta = GameMeta(game_id="game1", bot_version=1) - board = chess.Board() - - result = _attempt_move(ctx, state, meta, board) - assert result is True - api.make_move.assert_called_once() - - def test_attempt_move_no_moves(self) -> None: - """Test move attempt with no legal moves.""" - api = MagicMock() - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = (None, "no moves") - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState() - meta = GameMeta(game_id="game1", bot_version=1) - board = chess.Board() - - result = _attempt_move(ctx, state, meta, board) - assert result is False - - def test_attempt_move_illegal(self) -> None: - """Test move attempt with illegal move.""" - api = MagicMock() - engine = MagicMock() - engine.max_time_sec = 5.0 - # Return a move that's not legal (e.g., random square move) - engine.choose_move_with_explanation.return_value = ( - chess.Move.from_uci("a1a8"), - "bad", - ) - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState(my_ms=60000) - meta = GameMeta(game_id="game1", bot_version=1) - board = chess.Board() - - result = _attempt_move(ctx, state, meta, board) - assert result is True - api.make_move.assert_not_called() - - def test_attempt_move_request_error(self) -> None: - """Test move attempt with request error.""" - api = MagicMock() - api.make_move.side_effect = requests.RequestException("Network error") - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = ( - chess.Move.from_uci("e2e4"), - "opening", - ) - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState(my_ms=60000) - meta = GameMeta(game_id="game1", bot_version=1) - board = chess.Board() - - result = _attempt_move(ctx, state, meta, board) - assert result is True # Still returns True - - -class TestIsMyTurn: - """Tests for _is_my_turn.""" - - def test_is_my_turn_white_to_move(self) -> None: - """Test checking turn when white to move.""" - board = chess.Board() # White to move - assert _is_my_turn(board, "white") is True - assert _is_my_turn(board, "black") is False - - def test_is_my_turn_black_to_move(self) -> None: - """Test checking turn when black to move.""" - board = chess.Board() - board.push_uci("e2e4") # Black to move - assert _is_my_turn(board, "white") is False - assert _is_my_turn(board, "black") is True - - -class TestRebuildBoardFromMoves: - """Tests for _rebuild_board_from_moves.""" - - def test_rebuild_board_from_moves(self) -> None: - """Test rebuilding board from moves list.""" - moves_list = ["e2e4", "e7e5", "g1f3"] - board = _rebuild_board_from_moves(moves_list, "game1") - assert len(board.move_stack) == 3 - - -class TestHandleMoveIfNeeded: - """Tests for _handle_move_if_needed.""" - - def test_handle_move_game_state_my_turn(self) -> None: - """Test handling move on gameState when my turn.""" - api = MagicMock() - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = ( - chess.Move.from_uci("e2e4"), - "opening", - ) - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState(color="white", my_ms=60000, board=chess.Board()) - meta = GameMeta(game_id="game1", bot_version=1) - - result = _handle_move_if_needed(ctx, state, meta, "gameState", 0) - assert result is True - - def test_handle_move_game_full_with_moves(self) -> None: - """Test handling move on gameFull with existing moves (opponent's turn).""" - api = MagicMock() - engine = MagicMock() - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState(color="white", my_ms=60000, board=chess.Board()) - meta = GameMeta(game_id="game1", bot_version=1) - - # gameFull with moves - don't move - result = _handle_move_if_needed(ctx, state, meta, "gameFull", 1) - assert result is True - engine.choose_move_with_explanation.assert_not_called() - - -class TestProcessGameEvent: - """Tests for _process_game_event.""" - - def test_process_game_event_unhandled_type(self) -> None: - """Test processing unhandled event type.""" - ctx = MagicMock() - state = GameState() - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = {"type": "chatLine", "text": "hello"} - result = _process_game_event(event, ctx, state, meta) - assert result is True - - def test_process_game_event_game_full(self) -> None: - """Test processing gameFull event.""" - api = MagicMock() - api.get_my_user_id.return_value = "mybot" - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = ( - chess.Move.from_uci("e2e4"), - "opening", - ) - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState() - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = { - "type": "gameFull", - "state": {"moves": "", "status": "started"}, - "white": {"id": "mybot"}, - "black": {"id": "opp"}, - } - result = _process_game_event(event, ctx, state, meta) - assert result is True - - def test_process_game_event_game_end(self) -> None: - """Test processing game end event.""" - api = MagicMock() - api.get_my_user_id.return_value = "mybot" - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = (None, "no moves") - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState(color="white", last_handled_len=-1) - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = { - "type": "gameState", - "moves": "e2e4 e7e5", - "status": "mate", - } - result = _process_game_event(event, ctx, state, meta) - assert result is False - - def test_process_game_event_game_end_after_move(self) -> None: - """Test game ends with status after handling move. - - This covers the case where _handle_move_if_needed returns True - but status indicates game end. - """ - api = MagicMock() - api.get_my_user_id.return_value = "mybot" - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = ( - chess.Move.from_uci("d2d4"), - "response", - ) - ctx = BotContext(api=api, engine=engine, bot_version=1) - # Black's turn - it's opponent's move, so we don't need to move - state = GameState(color="black", last_handled_len=-1) - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = { - "type": "gameState", - "moves": "e2e4", # One move - now it's black's turn - "status": "resign", # Game ended with resign - } - result = _process_game_event(event, ctx, state, meta) - assert result is False # Game should end - - def test_process_game_event_unchanged_position(self) -> None: - """Test processing event with unchanged position.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - state = GameState(last_handled_len=2, color="white") - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = {"type": "gameState", "moves": "e2e4 e7e5"} - result = _process_game_event(event, ctx, state, meta) - assert result is True - - def test_process_game_event_color_unknown(self) -> None: - """Test processing event with unknown color.""" - api = MagicMock() - api.get_my_user_id.return_value = "mybot" - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - state = GameState(last_handled_len=-1) - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = {"type": "gameState", "moves": "e2e4"} - result = _process_game_event(event, ctx, state, meta) - assert result is True - assert state.last_handled_len == 1 - - def test_process_game_event_color_unknown_on_gamefull(self) -> None: - """Test processing gameFull event with still unknown color. - - This covers the branch where event_type is gameFull but color - is not determined (e.g., spectator watching game). - """ - api = MagicMock() - # Return a user id that doesn't match either player - api.get_my_user_id.return_value = "spectator" - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - state = GameState(last_handled_len=-1) - meta = GameMeta(game_id="game1", bot_version=1) - event: Event = { - "type": "gameFull", - "state": {"moves": "e2e4", "status": "started"}, - "white": {"id": "player1"}, - "black": {"id": "player2"}, - } - result = _process_game_event(event, ctx, state, meta) - assert result is True - # last_handled_len should NOT be updated for gameFull with unknown color - assert state.last_handled_len == -1 - - -class TestWritePgnToLog: - """Tests for _write_pgn_to_log.""" - - def test_write_pgn_to_log(self, tmp_path: Path) -> None: - """Test writing PGN to log file.""" - log_path = tmp_path / "game.log" - log_path.write_text("header\n") - board = chess.Board() - board.push_uci("e2e4") - meta = GameMeta( - game_id="game1", - bot_version=1, - site_url="https://lichess.org/game1", - date_iso="2021.01.01", - white_name="White", - black_name="Black", - ) - _write_pgn_to_log(log_path, board, meta) - content = log_path.read_text() - assert "PGN:" in content - assert "e4" in content - - -class TestRunAnalysisSubprocess: - """Tests for _run_analysis_subprocess.""" - - def test_run_analysis_subprocess_script_not_found(self, tmp_path: Path) -> None: - """Test analysis when script not found.""" - log_path = tmp_path / "game.log" - with patch("python_pkg.lichess_bot.main.Path") as mock_path: - mock_script = MagicMock() - mock_script.is_file.return_value = False - resolve = mock_path.return_value.resolve.return_value - resolve.parent.parent.__truediv__.return_value.__truediv__.return_value = ( - mock_script - ) - result = _run_analysis_subprocess("game1", log_path, 10) - assert result is None - - def test_run_analysis_subprocess_success(self, tmp_path: Path) -> None: - """Test successful analysis subprocess.""" - log_path = tmp_path / "game.log" - log_path.write_text("test") - - mock_proc = MagicMock() - mock_proc.stdout = iter([" 1 e4\n", " 2 e5\n"]) - mock_proc.stderr.read.return_value = "" - mock_proc.wait.return_value = 0 - mock_proc.__enter__ = MagicMock(return_value=mock_proc) - mock_proc.__exit__ = MagicMock(return_value=False) - - with ( - patch("python_pkg.lichess_bot.main.Path") as mock_path, - patch("subprocess.Popen", return_value=mock_proc), - ): - mock_script = MagicMock() - mock_script.is_file.return_value = True - resolve = mock_path.return_value.resolve.return_value - resolve.parent.parent.__truediv__.return_value.__truediv__.return_value = ( - mock_script - ) - result = _run_analysis_subprocess("game1", log_path, 2) - - assert result is not None - - -class TestProcessAnalysisOutput: - """Tests for _process_analysis_output.""" - - def test_process_analysis_output_success(self) -> None: - """Test processing analysis output successfully.""" - mock_proc = MagicMock() - mock_proc.stdout = iter([" 1 e4\n", " 2 e5\n"]) - mock_proc.stderr.read.return_value = "" - mock_proc.wait.return_value = 0 - - result = _process_analysis_output(mock_proc, "game1", 2) - assert result is not None - assert "e4" in result - - def test_process_analysis_output_error_exit(self) -> None: - """Test processing analysis output with error exit.""" - mock_proc = MagicMock() - mock_proc.stdout = iter(["output\n"]) - mock_proc.stderr.read.return_value = "error message" - mock_proc.wait.return_value = 1 - - result = _process_analysis_output(mock_proc, "game1", 1) - assert result is not None - assert "stderr" in result - - def test_process_analysis_output_error_exit_no_stderr(self) -> None: - """Test processing analysis output with error exit but no stderr.""" - mock_proc = MagicMock() - mock_proc.stdout = iter(["output\n"]) - mock_proc.stderr.read.return_value = "" - mock_proc.wait.return_value = 1 - - result = _process_analysis_output(mock_proc, "game1", 1) - assert result is not None - assert "stderr" not in result - - def test_process_analysis_output_none_pipes(self) -> None: - """Test processing analysis output with None pipes.""" - mock_proc = MagicMock() - mock_proc.stdout = None - mock_proc.stderr = None - - with pytest.raises(RuntimeError, match="pipes unexpectedly None"): - _process_analysis_output(mock_proc, "game1", 1) - - -class TestCollectAnalysisLines: - """Tests for _collect_analysis_lines helper.""" - - def test_collect_analysis_lines_empty_iterator(self) -> None: - """Test collecting lines from empty iterator.""" - empty_iter: list[str] = [] - analyzed, lines = _collect_analysis_lines(iter(empty_iter), "game1", 10) - assert analyzed == 0 - assert lines == [] - - def test_collect_analysis_lines_with_content(self) -> None: - """Test collecting lines from iterator with content.""" - content = [" 1 e4\n", " 2 e5\n", "not a ply line\n"] - analyzed, lines = _collect_analysis_lines(iter(content), "game1", 3) - assert analyzed == 2 - assert lines == content - - def test_collect_analysis_lines_full_iteration(self) -> None: - """Test that all lines are collected.""" - content = ["line1\n", " 3 Nf3\n", "line3\n"] - analyzed, lines = _collect_analysis_lines(iter(content), "game1", 1) - assert analyzed == 1 - assert len(lines) == 3 - - -class TestLogAnalysisProgress: - """Tests for _log_analysis_progress.""" - - def test_log_analysis_progress_with_total(self) -> None: - """Test logging progress with known total.""" - with patch("python_pkg.lichess_bot.main._logger") as mock_logger: - _log_analysis_progress("game1", 5, 10) - mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert "50%" in call_args[0] % call_args[1:] - - def test_log_analysis_progress_zero_total(self) -> None: - """Test logging progress with zero total.""" - with patch("python_pkg.lichess_bot.main._logger") as mock_logger: - _log_analysis_progress("game1", 5, 0) - mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert "unknown" in call_args[0] - - -class TestInsertAnalysisIntoLog: - """Tests for _insert_analysis_into_log.""" - - def test_insert_analysis_before_pgn(self, tmp_path: Path) -> None: - """Test inserting analysis before PGN section.""" - log_path = tmp_path / "game.log" - log_path.write_text("header\n\nPGN:\n1. e4\n") - meta = GameMeta( - game_id="game1", - bot_version=1, - date_iso="2021.01.01", - white_name="White", - black_name="Black", - ) - _insert_analysis_into_log(log_path, "Analysis here", meta) - content = log_path.read_text() - assert "ANALYSIS:" in content - assert content.index("ANALYSIS:") < content.index("PGN:") - - def test_insert_analysis_at_start(self, tmp_path: Path) -> None: - """Test inserting analysis when PGN at start.""" - log_path = tmp_path / "game.log" - log_path.write_text("PGN:\n1. e4\n") - meta = GameMeta(game_id="game1", bot_version=1) - _insert_analysis_into_log(log_path, "Analysis here", meta) - content = log_path.read_text() - assert "ANALYSIS:" in content - - def test_insert_analysis_no_pgn(self, tmp_path: Path) -> None: - """Test inserting analysis when no PGN section.""" - log_path = tmp_path / "game.log" - log_path.write_text("header\n") - meta = GameMeta(game_id="game1", bot_version=1) - _insert_analysis_into_log(log_path, "Analysis here", meta) - content = log_path.read_text() - assert "ANALYSIS:" in content - - def test_insert_analysis_oserror(self, tmp_path: Path) -> None: - """Test inserting analysis with OSError.""" - log_path = tmp_path / "nonexistent" / "game.log" - meta = GameMeta(game_id="game1", bot_version=1) - # Should not raise, just log debug - _insert_analysis_into_log(log_path, "Analysis", meta) - - -class TestFinalizeGame: - """Tests for _finalize_game.""" - - def test_finalize_game_no_log_path(self) -> None: - """Test finalize game with no log path.""" - state = GameState(log_path=None) - meta = GameMeta(game_id="game1", bot_version=1) - _finalize_game(state, meta) # Should not raise - - def test_finalize_game_write_error(self, tmp_path: Path) -> None: - """Test finalize game with write error.""" - log_path = tmp_path / "game.log" - log_path.write_text("header") - state = GameState(log_path=log_path) - meta = GameMeta(game_id="game1", bot_version=1) - - with patch( - "python_pkg.lichess_bot.main._write_pgn_to_log", - side_effect=OSError("error"), - ): - _finalize_game(state, meta) # Should not raise - - def test_finalize_game_type_error_on_move_stack(self, tmp_path: Path) -> None: - """Test finalize game with TypeError on move_stack.""" - log_path = tmp_path / "game.log" - log_path.write_text("header\n") - state = GameState(log_path=log_path) - meta = GameMeta(game_id="game1", bot_version=1) - - mock_board = MagicMock() - # Use PropertyMock to raise TypeError when move_stack is accessed - type(mock_board).move_stack = PropertyMock(side_effect=TypeError()) - state.board = mock_board - - with patch("python_pkg.lichess_bot.main._write_pgn_to_log"): - _finalize_game(state, meta) # Should not raise - - def test_finalize_game_analysis_error(self, tmp_path: Path) -> None: - """Test finalize game with analysis error.""" - log_path = tmp_path / "game.log" - log_path.write_text("header\n") - state = GameState(log_path=log_path) - meta = GameMeta(game_id="game1", bot_version=1) - - with ( - patch("python_pkg.lichess_bot.main._write_pgn_to_log"), - patch( - "python_pkg.lichess_bot.main._run_analysis_subprocess", - side_effect=OSError("error"), - ), - ): - _finalize_game(state, meta) # Should not raise - - -class TestHandleGame: - """Tests for _handle_game.""" - - def test_handle_game_success(self, tmp_path: Path) -> None: - """Test handling a game successfully.""" - api = MagicMock() - api.get_my_user_id.return_value = "mybot" - api.stream_game_events.return_value = iter( - [ - { - "type": "gameFull", - "state": {"moves": "", "status": "started"}, - "white": {"id": "mybot"}, - "black": {"id": "opp"}, - }, - {"type": "gameState", "moves": "e2e4", "status": "mate"}, - ] - ) - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = (None, "no moves") - ctx = BotContext(api=api, engine=engine, bot_version=1) - - with ( - patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path), - patch( - "python_pkg.lichess_bot.main._run_analysis_subprocess", - return_value=None, - ), - ): - _handle_game("game1", ctx, None) - - def test_handle_game_request_error(self, tmp_path: Path) -> None: - """Test handling a game with request error.""" - api = MagicMock() - api.stream_game_events.side_effect = requests.RequestException("error") - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - - with ( - patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path), - patch( - "python_pkg.lichess_bot.main._run_analysis_subprocess", - return_value=None, - ), - ): - _handle_game("game1", ctx, None) # Should not raise - - def test_handle_game_skips_chat_events(self, tmp_path: Path) -> None: - """Test handling a game skips chat events.""" - api = MagicMock() - api.stream_game_events.return_value = iter( - [ - {"type": "chatLine", "text": "hello"}, - {"type": "opponentGone", "gone": True}, - ] - ) - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - - with ( - patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path), - patch( - "python_pkg.lichess_bot.main._run_analysis_subprocess", - return_value=None, - ), - ): - _handle_game("game1", ctx, None) - - -class TestProcessGameEventsLoop: - """Tests for _process_game_events_loop.""" - - def test_empty_events_iterator(self) -> None: - """Test processing empty events iterator.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - state = GameState(color="white") - meta = GameMeta(game_id="game1", bot_version=1) - - empty_iter: list[Event] = [] - # Should complete without error when iterator is empty - _process_game_events_loop(iter(empty_iter), ctx, state, meta) - - def test_processes_all_events(self) -> None: - """Test that all events are processed until break condition.""" - api = MagicMock() - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = (None, "no moves") - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState(color="white") - meta = GameMeta(game_id="game1", bot_version=1) - - events: list[Event] = [ - {"type": "chatLine", "text": "hello"}, # skipped - {"type": "gameState", "moves": "e2e4", "status": "resign"}, # game end - ] - _process_game_events_loop(iter(events), ctx, state, meta) - - def test_processes_multiple_game_events(self) -> None: - """Test processing multiple game events that continue the game.""" - api = MagicMock() - engine = MagicMock() - engine.max_time_sec = 5.0 - engine.choose_move_with_explanation.return_value = ( - chess.Move.from_uci("e2e4"), - "e4", - ) - api.make_move.return_value = None - ctx = BotContext(api=api, engine=engine, bot_version=1) - state = GameState(color="white") - state.board = chess.Board() - meta = GameMeta(game_id="game1", bot_version=1) - - events: list[Event] = [ - # First event - game state, game continues - {"type": "gameState", "moves": "", "status": "started"}, - # Second event - opponent moves, game continues - {"type": "gameState", "moves": "e2e4 e7e5", "status": "started"}, - # Third event - game ends - {"type": "gameState", "moves": "e2e4 e7e5", "status": "mate"}, - ] - _process_game_events_loop(iter(events), ctx, state, meta) - - -class TestRunEventLoop: - """Tests for _run_event_loop.""" - - def test_run_event_loop_zero_iterations(self) -> None: - """Test running event loop with zero iterations.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - - # Should complete immediately with 0 iterations - _run_event_loop(ctx, game_threads, 0, 0) - - def test_run_event_loop_limited_iterations(self) -> None: - """Test running event loop with limited iterations.""" - api = MagicMock() - api.stream_bot_events.return_value = iter([]) - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - - with patch( - "python_pkg.lichess_bot.main._safe_event_loop_iteration", return_value=0 - ) as mock_iter: - _run_event_loop(ctx, game_threads, 0, 3) - assert mock_iter.call_count == 3 - - def test_run_event_loop_none_iterations_needs_interrupt(self) -> None: - """Test that None iterations runs until interrupted.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - - call_count = 0 - - def stop_after_calls(*_args: object, **_kwargs: object) -> int: - nonlocal call_count - call_count += 1 - if call_count >= 5: - raise KeyboardInterrupt - return 0 - - with ( - patch( - "python_pkg.lichess_bot.main._safe_event_loop_iteration", - side_effect=stop_after_calls, - ), - pytest.raises(KeyboardInterrupt), - ): - _run_event_loop(ctx, game_threads, 0, None) - - assert call_count == 5 - - -class TestHandleChallenge: - """Tests for _handle_challenge.""" - - def test_accept_standard_blitz(self) -> None: - """Test accepting standard blitz challenge.""" - api = MagicMock() - challenge: Event = { - "id": "ch1", - "variant": {"key": "standard"}, - "speed": "blitz", - } - _handle_challenge(challenge, api, decline_correspondence=False) - api.accept_challenge.assert_called_once_with("ch1") - - def test_decline_variant(self) -> None: - """Test declining non-standard variant.""" - api = MagicMock() - challenge: Event = { - "id": "ch1", - "variant": {"key": "chess960"}, - "speed": "blitz", - } - _handle_challenge(challenge, api, decline_correspondence=False) - api.decline_challenge.assert_called_once() - - def test_decline_correspondence(self) -> None: - """Test declining correspondence when flag set.""" - api = MagicMock() - challenge: Event = { - "id": "ch1", - "variant": {"key": "standard"}, - "speed": "correspondence", - } - _handle_challenge(challenge, api, decline_correspondence=True) - api.decline_challenge.assert_called_once() - - def test_accept_correspondence_when_allowed(self) -> None: - """Test accepting correspondence when flag not set.""" - api = MagicMock() - challenge: Event = { - "id": "ch1", - "variant": {"key": "standard"}, - "speed": "correspondence", - } - _handle_challenge(challenge, api, decline_correspondence=False) - api.decline_challenge.assert_called_once() # Still declined due to perf_ok - - def test_invalid_variant_data(self) -> None: - """Test handling invalid variant data.""" - api = MagicMock() - challenge: Event = { - "id": "ch1", - "variant": "invalid", - "speed": "blitz", - } - _handle_challenge(challenge, api, decline_correspondence=False) - api.accept_challenge.assert_called_once() - - -class TestProcessBotEvent: - """Tests for _process_bot_event.""" - - def test_process_challenge_event(self) -> None: - """Test processing challenge event.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - event: Event = { - "type": "challenge", - "challenge": { - "id": "ch1", - "variant": {"key": "standard"}, - "speed": "blitz", - }, - } - _process_bot_event(event, ctx, game_threads) - api.accept_challenge.assert_called_once() - - def test_process_game_start_event(self) -> None: - """Test processing gameStart event.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - event: Event = {"type": "gameStart", "game": {"id": "game1"}} - - with patch("python_pkg.lichess_bot.main.threading.Thread") as mock_thread_class: - mock_thread = MagicMock() - mock_thread_class.return_value = mock_thread - _process_bot_event(event, ctx, game_threads) - - assert "game1" in game_threads - mock_thread.start.assert_called_once() - - def test_process_game_start_existing_thread(self) -> None: - """Test processing gameStart with existing alive thread.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - mock_thread = MagicMock(spec=threading.Thread) - mock_thread.is_alive.return_value = True - game_threads: GameThreads = {"game1": mock_thread} - event: Event = {"type": "gameStart", "game": {"id": "game1"}} - _process_bot_event(event, ctx, game_threads) - # Should not create new thread - assert game_threads["game1"] is mock_thread - - def test_process_game_finish_event(self) -> None: - """Test processing gameFinish event.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - event: Event = {"type": "gameFinish", "game": {"id": "game1"}} - with patch("python_pkg.lichess_bot.main._logger") as mock_logger: - _process_bot_event(event, ctx, game_threads) - mock_logger.info.assert_called() - - def test_process_game_finish_invalid_data(self) -> None: - """Test processing gameFinish event with non-dict game data.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - event: Event = {"type": "gameFinish", "game": "not_a_dict"} - with patch("python_pkg.lichess_bot.main._logger") as mock_logger: - _process_bot_event(event, ctx, game_threads) - # Should not log info since game data is invalid - mock_logger.info.assert_not_called() - - def test_process_unknown_event(self) -> None: - """Test processing unknown event.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - event: Event = {"type": "unknown", "data": "test"} - with patch("python_pkg.lichess_bot.main._logger") as mock_logger: - _process_bot_event(event, ctx, game_threads) - mock_logger.debug.assert_called() - - def test_process_challenge_invalid_data(self) -> None: - """Test processing challenge with invalid data.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - event: Event = {"type": "challenge", "challenge": "invalid"} - _process_bot_event(event, ctx, game_threads) - api.accept_challenge.assert_not_called() - - def test_process_game_start_invalid_data(self) -> None: - """Test processing gameStart with invalid data.""" - api = MagicMock() - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - event: Event = {"type": "gameStart", "game": "invalid"} - _process_bot_event(event, ctx, game_threads) - assert len(game_threads) == 0 - - -class TestStreamBotEvents: - """Tests for _stream_bot_events.""" - - def test_stream_bot_events(self) -> None: - """Test streaming bot events.""" - api = MagicMock() - api.stream_events.return_value = iter([{"type": "test"}]) - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - events = list(_stream_bot_events(ctx)) - assert len(events) == 1 - - -class TestRunEventLoopIteration: - """Tests for _run_event_loop_iteration.""" - - def test_run_event_loop_iteration(self) -> None: - """Test running event loop iteration.""" - api = MagicMock() - api.stream_events.return_value = iter([{"type": "unknown"}]) - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - result = _run_event_loop_iteration(ctx, game_threads) - assert result == 0 - - -class TestSafeEventLoopIteration: - """Tests for _safe_event_loop_iteration.""" - - def test_safe_event_loop_iteration_success(self) -> None: - """Test safe event loop iteration success.""" - api = MagicMock() - api.stream_events.return_value = iter([]) - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - result = _safe_event_loop_iteration(ctx, game_threads, 0) - assert result == 0 - - def test_safe_event_loop_iteration_error(self) -> None: - """Test safe event loop iteration with error.""" - api = MagicMock() - api.stream_events.side_effect = requests.RequestException("error") - ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) - game_threads: GameThreads = {} - with patch("python_pkg.lichess_bot.main.backoff_sleep", return_value=5): - result = _safe_event_loop_iteration(ctx, game_threads, 2) - assert result == 5 - - -class TestRunBot: - """Tests for run_bot.""" - - def test_run_bot_no_token(self) -> None: - """Test run_bot without token raises error.""" - with ( - patch.dict(os.environ, {}, clear=True), - pytest.raises(RuntimeError, match="LICHESS_TOKEN"), - ): - run_bot() - - def test_run_bot_with_token(self) -> None: - """Test run_bot with token starts event loop.""" - - class _StopLoopError(Exception): - """Custom exception to stop the loop.""" - - def stop_loop(*_args: object, **_kwargs: object) -> None: - raise _StopLoopError - - with ( - patch.dict(os.environ, {"LICHESS_TOKEN": "test_token"}), - patch( - "python_pkg.lichess_bot.main.get_and_increment_version", - return_value=1, - ), - patch("python_pkg.lichess_bot.main.LichessAPI"), - patch("python_pkg.lichess_bot.main.RandomEngine"), - patch( - "python_pkg.lichess_bot.main._safe_event_loop_iteration", - side_effect=stop_loop, - ), - pytest.raises(_StopLoopError), - ): - run_bot("DEBUG", decline_correspondence=True) - - -class TestMain: - """Tests for main function.""" - - def test_main_parses_args(self) -> None: - """Test main parses command line arguments.""" - - class _StopExecutionError(Exception): - """Custom exception to stop execution.""" - - with ( - patch( - "sys.argv", - ["main.py", "--log-level", "DEBUG", "--decline-correspondence"], - ), - patch( - "python_pkg.lichess_bot.main.run_bot", side_effect=_StopExecutionError - ) as mock_run_bot, - pytest.raises(_StopExecutionError), - ): - main() - mock_run_bot.assert_called_once_with("DEBUG", decline_correspondence=True) diff --git a/python_pkg/lichess_bot/tests/test_main_analysis.py b/python_pkg/lichess_bot/tests/test_main_analysis.py new file mode 100644 index 0000000..811e847 --- /dev/null +++ b/python_pkg/lichess_bot/tests/test_main_analysis.py @@ -0,0 +1,409 @@ +"""Tests for lichess_bot main module: game events and analysis.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock, PropertyMock, patch + +import chess +import pytest + +from python_pkg.lichess_bot.main import ( + BotContext, + GameMeta, + GameState, + _collect_analysis_lines, + _finalize_game, + _insert_analysis_into_log, + _log_analysis_progress, + _process_analysis_output, + _process_game_event, + _run_analysis_subprocess, + _write_pgn_to_log, +) + +if TYPE_CHECKING: + from pathlib import Path + +# Type alias to make mypy happy with test event dicts +Event = dict[str, Any] + + +class TestProcessGameEvent: + """Tests for _process_game_event.""" + + def test_process_game_event_unhandled_type(self) -> None: + """Test processing unhandled event type.""" + ctx = MagicMock() + state = GameState() + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = {"type": "chatLine", "text": "hello"} + result = _process_game_event(event, ctx, state, meta) + assert result is True + + def test_process_game_event_game_full(self) -> None: + """Test processing gameFull event.""" + api = MagicMock() + api.get_my_user_id.return_value = "mybot" + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = ( + chess.Move.from_uci("e2e4"), + "opening", + ) + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState() + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = { + "type": "gameFull", + "state": {"moves": "", "status": "started"}, + "white": {"id": "mybot"}, + "black": {"id": "opp"}, + } + result = _process_game_event(event, ctx, state, meta) + assert result is True + + def test_process_game_event_game_end(self) -> None: + """Test processing game end event.""" + api = MagicMock() + api.get_my_user_id.return_value = "mybot" + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = (None, "no moves") + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState(color="white", last_handled_len=-1) + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = { + "type": "gameState", + "moves": "e2e4 e7e5", + "status": "mate", + } + result = _process_game_event(event, ctx, state, meta) + assert result is False + + def test_process_game_event_game_end_after_move(self) -> None: + """Test game ends with status after handling move. + + This covers the case where _handle_move_if_needed returns True + but status indicates game end. + """ + api = MagicMock() + api.get_my_user_id.return_value = "mybot" + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = ( + chess.Move.from_uci("d2d4"), + "response", + ) + ctx = BotContext(api=api, engine=engine, bot_version=1) + # Black's turn - it's opponent's move, so we don't need to move + state = GameState(color="black", last_handled_len=-1) + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = { + "type": "gameState", + "moves": "e2e4", # One move - now it's black's turn + "status": "resign", # Game ended with resign + } + result = _process_game_event(event, ctx, state, meta) + assert result is False # Game should end + + def test_process_game_event_unchanged_position(self) -> None: + """Test processing event with unchanged position.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + state = GameState(last_handled_len=2, color="white") + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = {"type": "gameState", "moves": "e2e4 e7e5"} + result = _process_game_event(event, ctx, state, meta) + assert result is True + + def test_process_game_event_color_unknown(self) -> None: + """Test processing event with unknown color.""" + api = MagicMock() + api.get_my_user_id.return_value = "mybot" + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + state = GameState(last_handled_len=-1) + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = {"type": "gameState", "moves": "e2e4"} + result = _process_game_event(event, ctx, state, meta) + assert result is True + assert state.last_handled_len == 1 + + def test_process_game_event_color_unknown_on_gamefull(self) -> None: + """Test processing gameFull event with still unknown color. + + This covers the branch where event_type is gameFull but color + is not determined (e.g., spectator watching game). + """ + api = MagicMock() + # Return a user id that doesn't match either player + api.get_my_user_id.return_value = "spectator" + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + state = GameState(last_handled_len=-1) + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = { + "type": "gameFull", + "state": {"moves": "e2e4", "status": "started"}, + "white": {"id": "player1"}, + "black": {"id": "player2"}, + } + result = _process_game_event(event, ctx, state, meta) + assert result is True + # last_handled_len should NOT be updated for gameFull with unknown color + assert state.last_handled_len == -1 + + +class TestWritePgnToLog: + """Tests for _write_pgn_to_log.""" + + def test_write_pgn_to_log(self, tmp_path: Path) -> None: + """Test writing PGN to log file.""" + log_path = tmp_path / "game.log" + log_path.write_text("header\n") + board = chess.Board() + board.push_uci("e2e4") + meta = GameMeta( + game_id="game1", + bot_version=1, + site_url="https://lichess.org/game1", + date_iso="2021.01.01", + white_name="White", + black_name="Black", + ) + _write_pgn_to_log(log_path, board, meta) + content = log_path.read_text() + assert "PGN:" in content + assert "e4" in content + + +class TestRunAnalysisSubprocess: + """Tests for _run_analysis_subprocess.""" + + def test_run_analysis_subprocess_script_not_found(self, tmp_path: Path) -> None: + """Test analysis when script not found.""" + log_path = tmp_path / "game.log" + with patch("python_pkg.lichess_bot.main.Path") as mock_path: + mock_script = MagicMock() + mock_script.is_file.return_value = False + resolve = mock_path.return_value.resolve.return_value + resolve.parent.parent.__truediv__.return_value.__truediv__.return_value = ( + mock_script + ) + result = _run_analysis_subprocess("game1", log_path, 10) + assert result is None + + def test_run_analysis_subprocess_success(self, tmp_path: Path) -> None: + """Test successful analysis subprocess.""" + log_path = tmp_path / "game.log" + log_path.write_text("test") + + mock_proc = MagicMock() + mock_proc.stdout = iter([" 1 e4\n", " 2 e5\n"]) + mock_proc.stderr.read.return_value = "" + mock_proc.wait.return_value = 0 + mock_proc.__enter__ = MagicMock(return_value=mock_proc) + mock_proc.__exit__ = MagicMock(return_value=False) + + with ( + patch("python_pkg.lichess_bot.main.Path") as mock_path, + patch("subprocess.Popen", return_value=mock_proc), + ): + mock_script = MagicMock() + mock_script.is_file.return_value = True + resolve = mock_path.return_value.resolve.return_value + resolve.parent.parent.__truediv__.return_value.__truediv__.return_value = ( + mock_script + ) + result = _run_analysis_subprocess("game1", log_path, 2) + + assert result is not None + + +class TestProcessAnalysisOutput: + """Tests for _process_analysis_output.""" + + def test_process_analysis_output_success(self) -> None: + """Test processing analysis output successfully.""" + mock_proc = MagicMock() + mock_proc.stdout = iter([" 1 e4\n", " 2 e5\n"]) + mock_proc.stderr.read.return_value = "" + mock_proc.wait.return_value = 0 + + result = _process_analysis_output(mock_proc, "game1", 2) + assert result is not None + assert "e4" in result + + def test_process_analysis_output_error_exit(self) -> None: + """Test processing analysis output with error exit.""" + mock_proc = MagicMock() + mock_proc.stdout = iter(["output\n"]) + mock_proc.stderr.read.return_value = "error message" + mock_proc.wait.return_value = 1 + + result = _process_analysis_output(mock_proc, "game1", 1) + assert result is not None + assert "stderr" in result + + def test_process_analysis_output_error_exit_no_stderr(self) -> None: + """Test processing analysis output with error exit but no stderr.""" + mock_proc = MagicMock() + mock_proc.stdout = iter(["output\n"]) + mock_proc.stderr.read.return_value = "" + mock_proc.wait.return_value = 1 + + result = _process_analysis_output(mock_proc, "game1", 1) + assert result is not None + assert "stderr" not in result + + def test_process_analysis_output_none_pipes(self) -> None: + """Test processing analysis output with None pipes.""" + mock_proc = MagicMock() + mock_proc.stdout = None + mock_proc.stderr = None + + with pytest.raises(RuntimeError, match="pipes unexpectedly None"): + _process_analysis_output(mock_proc, "game1", 1) + + +class TestCollectAnalysisLines: + """Tests for _collect_analysis_lines helper.""" + + def test_collect_analysis_lines_empty_iterator(self) -> None: + """Test collecting lines from empty iterator.""" + empty_iter: list[str] = [] + analyzed, lines = _collect_analysis_lines(iter(empty_iter), "game1", 10) + assert analyzed == 0 + assert lines == [] + + def test_collect_analysis_lines_with_content(self) -> None: + """Test collecting lines from iterator with content.""" + content = [" 1 e4\n", " 2 e5\n", "not a ply line\n"] + analyzed, lines = _collect_analysis_lines(iter(content), "game1", 3) + assert analyzed == 2 + assert lines == content + + def test_collect_analysis_lines_full_iteration(self) -> None: + """Test that all lines are collected.""" + content = ["line1\n", " 3 Nf3\n", "line3\n"] + analyzed, lines = _collect_analysis_lines(iter(content), "game1", 1) + assert analyzed == 1 + assert len(lines) == 3 + + +class TestLogAnalysisProgress: + """Tests for _log_analysis_progress.""" + + def test_log_analysis_progress_with_total(self) -> None: + """Test logging progress with known total.""" + with patch("python_pkg.lichess_bot.main._logger") as mock_logger: + _log_analysis_progress("game1", 5, 10) + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args[0] + assert "50%" in call_args[0] % call_args[1:] + + def test_log_analysis_progress_zero_total(self) -> None: + """Test logging progress with zero total.""" + with patch("python_pkg.lichess_bot.main._logger") as mock_logger: + _log_analysis_progress("game1", 5, 0) + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args[0] + assert "unknown" in call_args[0] + + +class TestInsertAnalysisIntoLog: + """Tests for _insert_analysis_into_log.""" + + def test_insert_analysis_before_pgn(self, tmp_path: Path) -> None: + """Test inserting analysis before PGN section.""" + log_path = tmp_path / "game.log" + log_path.write_text("header\n\nPGN:\n1. e4\n") + meta = GameMeta( + game_id="game1", + bot_version=1, + date_iso="2021.01.01", + white_name="White", + black_name="Black", + ) + _insert_analysis_into_log(log_path, "Analysis here", meta) + content = log_path.read_text() + assert "ANALYSIS:" in content + assert content.index("ANALYSIS:") < content.index("PGN:") + + def test_insert_analysis_at_start(self, tmp_path: Path) -> None: + """Test inserting analysis when PGN at start.""" + log_path = tmp_path / "game.log" + log_path.write_text("PGN:\n1. e4\n") + meta = GameMeta(game_id="game1", bot_version=1) + _insert_analysis_into_log(log_path, "Analysis here", meta) + content = log_path.read_text() + assert "ANALYSIS:" in content + + def test_insert_analysis_no_pgn(self, tmp_path: Path) -> None: + """Test inserting analysis when no PGN section.""" + log_path = tmp_path / "game.log" + log_path.write_text("header\n") + meta = GameMeta(game_id="game1", bot_version=1) + _insert_analysis_into_log(log_path, "Analysis here", meta) + content = log_path.read_text() + assert "ANALYSIS:" in content + + def test_insert_analysis_oserror(self, tmp_path: Path) -> None: + """Test inserting analysis with OSError.""" + log_path = tmp_path / "nonexistent" / "game.log" + meta = GameMeta(game_id="game1", bot_version=1) + # Should not raise, just log debug + _insert_analysis_into_log(log_path, "Analysis", meta) + + +class TestFinalizeGame: + """Tests for _finalize_game.""" + + def test_finalize_game_no_log_path(self) -> None: + """Test finalize game with no log path.""" + state = GameState(log_path=None) + meta = GameMeta(game_id="game1", bot_version=1) + _finalize_game(state, meta) # Should not raise + + def test_finalize_game_write_error(self, tmp_path: Path) -> None: + """Test finalize game with write error.""" + log_path = tmp_path / "game.log" + log_path.write_text("header") + state = GameState(log_path=log_path) + meta = GameMeta(game_id="game1", bot_version=1) + + with patch( + "python_pkg.lichess_bot.main._write_pgn_to_log", + side_effect=OSError("error"), + ): + _finalize_game(state, meta) # Should not raise + + def test_finalize_game_type_error_on_move_stack(self, tmp_path: Path) -> None: + """Test finalize game with TypeError on move_stack.""" + log_path = tmp_path / "game.log" + log_path.write_text("header\n") + state = GameState(log_path=log_path) + meta = GameMeta(game_id="game1", bot_version=1) + + mock_board = MagicMock() + # Use PropertyMock to raise TypeError when move_stack is accessed + type(mock_board).move_stack = PropertyMock(side_effect=TypeError()) + state.board = mock_board + + with patch("python_pkg.lichess_bot.main._write_pgn_to_log"): + _finalize_game(state, meta) # Should not raise + + def test_finalize_game_analysis_error(self, tmp_path: Path) -> None: + """Test finalize game with analysis error.""" + log_path = tmp_path / "game.log" + log_path.write_text("header\n") + state = GameState(log_path=log_path) + meta = GameMeta(game_id="game1", bot_version=1) + + with ( + patch("python_pkg.lichess_bot.main._write_pgn_to_log"), + patch( + "python_pkg.lichess_bot.main._run_analysis_subprocess", + side_effect=OSError("error"), + ), + ): + _finalize_game(state, meta) # Should not raise diff --git a/python_pkg/lichess_bot/tests/test_main_bot_loop.py b/python_pkg/lichess_bot/tests/test_main_bot_loop.py new file mode 100644 index 0000000..e9d38a6 --- /dev/null +++ b/python_pkg/lichess_bot/tests/test_main_bot_loop.py @@ -0,0 +1,474 @@ +"""Tests for lichess_bot main module: bot event loop.""" + +from __future__ import annotations + +import os +import threading +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock, patch + +import chess +import pytest +import requests + +from python_pkg.lichess_bot.main import ( + BotContext, + GameMeta, + GameState, + _handle_challenge, + _handle_game, + _process_bot_event, + _process_game_events_loop, + _run_event_loop, + _run_event_loop_iteration, + _safe_event_loop_iteration, + _stream_bot_events, + main, + run_bot, +) + +if TYPE_CHECKING: + from pathlib import Path + +# Type aliases to make mypy happy with test event dicts +Event = dict[str, Any] +GameThreads = dict[str, threading.Thread] + + +class TestHandleGame: + """Tests for _handle_game.""" + + def test_handle_game_success(self, tmp_path: Path) -> None: + """Test handling a game successfully.""" + api = MagicMock() + api.get_my_user_id.return_value = "mybot" + api.stream_game_events.return_value = iter( + [ + { + "type": "gameFull", + "state": {"moves": "", "status": "started"}, + "white": {"id": "mybot"}, + "black": {"id": "opp"}, + }, + {"type": "gameState", "moves": "e2e4", "status": "mate"}, + ] + ) + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = (None, "no moves") + ctx = BotContext(api=api, engine=engine, bot_version=1) + + with ( + patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path), + patch( + "python_pkg.lichess_bot.main._run_analysis_subprocess", + return_value=None, + ), + ): + _handle_game("game1", ctx, None) + + def test_handle_game_request_error(self, tmp_path: Path) -> None: + """Test handling a game with request error.""" + api = MagicMock() + api.stream_game_events.side_effect = requests.RequestException("error") + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + + with ( + patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path), + patch( + "python_pkg.lichess_bot.main._run_analysis_subprocess", + return_value=None, + ), + ): + _handle_game("game1", ctx, None) # Should not raise + + def test_handle_game_skips_chat_events(self, tmp_path: Path) -> None: + """Test handling a game skips chat events.""" + api = MagicMock() + api.stream_game_events.return_value = iter( + [ + {"type": "chatLine", "text": "hello"}, + {"type": "opponentGone", "gone": True}, + ] + ) + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + + with ( + patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path), + patch( + "python_pkg.lichess_bot.main._run_analysis_subprocess", + return_value=None, + ), + ): + _handle_game("game1", ctx, None) + + +class TestProcessGameEventsLoop: + """Tests for _process_game_events_loop.""" + + def test_empty_events_iterator(self) -> None: + """Test processing empty events iterator.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + state = GameState(color="white") + meta = GameMeta(game_id="game1", bot_version=1) + + empty_iter: list[Event] = [] + # Should complete without error when iterator is empty + _process_game_events_loop(iter(empty_iter), ctx, state, meta) + + def test_processes_all_events(self) -> None: + """Test that all events are processed until break condition.""" + api = MagicMock() + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = (None, "no moves") + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState(color="white") + meta = GameMeta(game_id="game1", bot_version=1) + + events: list[Event] = [ + {"type": "chatLine", "text": "hello"}, # skipped + {"type": "gameState", "moves": "e2e4", "status": "resign"}, # game end + ] + _process_game_events_loop(iter(events), ctx, state, meta) + + def test_processes_multiple_game_events(self) -> None: + """Test processing multiple game events that continue the game.""" + api = MagicMock() + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = ( + chess.Move.from_uci("e2e4"), + "e4", + ) + api.make_move.return_value = None + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState(color="white") + state.board = chess.Board() + meta = GameMeta(game_id="game1", bot_version=1) + + events: list[Event] = [ + # First event - game state, game continues + {"type": "gameState", "moves": "", "status": "started"}, + # Second event - opponent moves, game continues + {"type": "gameState", "moves": "e2e4 e7e5", "status": "started"}, + # Third event - game ends + {"type": "gameState", "moves": "e2e4 e7e5", "status": "mate"}, + ] + _process_game_events_loop(iter(events), ctx, state, meta) + + +class TestRunEventLoop: + """Tests for _run_event_loop.""" + + def test_run_event_loop_zero_iterations(self) -> None: + """Test running event loop with zero iterations.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + + # Should complete immediately with 0 iterations + _run_event_loop(ctx, game_threads, 0, 0) + + def test_run_event_loop_limited_iterations(self) -> None: + """Test running event loop with limited iterations.""" + api = MagicMock() + api.stream_bot_events.return_value = iter([]) + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + + with patch( + "python_pkg.lichess_bot.main._safe_event_loop_iteration", return_value=0 + ) as mock_iter: + _run_event_loop(ctx, game_threads, 0, 3) + assert mock_iter.call_count == 3 + + def test_run_event_loop_none_iterations_needs_interrupt(self) -> None: + """Test that None iterations runs until interrupted.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + + call_count = 0 + + def stop_after_calls(*_args: object, **_kwargs: object) -> int: + nonlocal call_count + call_count += 1 + if call_count >= 5: + raise KeyboardInterrupt + return 0 + + with ( + patch( + "python_pkg.lichess_bot.main._safe_event_loop_iteration", + side_effect=stop_after_calls, + ), + pytest.raises(KeyboardInterrupt), + ): + _run_event_loop(ctx, game_threads, 0, None) + + assert call_count == 5 + + +class TestHandleChallenge: + """Tests for _handle_challenge.""" + + def test_accept_standard_blitz(self) -> None: + """Test accepting standard blitz challenge.""" + api = MagicMock() + challenge: Event = { + "id": "ch1", + "variant": {"key": "standard"}, + "speed": "blitz", + } + _handle_challenge(challenge, api, decline_correspondence=False) + api.accept_challenge.assert_called_once_with("ch1") + + def test_decline_variant(self) -> None: + """Test declining non-standard variant.""" + api = MagicMock() + challenge: Event = { + "id": "ch1", + "variant": {"key": "chess960"}, + "speed": "blitz", + } + _handle_challenge(challenge, api, decline_correspondence=False) + api.decline_challenge.assert_called_once() + + def test_decline_correspondence(self) -> None: + """Test declining correspondence when flag set.""" + api = MagicMock() + challenge: Event = { + "id": "ch1", + "variant": {"key": "standard"}, + "speed": "correspondence", + } + _handle_challenge(challenge, api, decline_correspondence=True) + api.decline_challenge.assert_called_once() + + def test_accept_correspondence_when_allowed(self) -> None: + """Test accepting correspondence when flag not set.""" + api = MagicMock() + challenge: Event = { + "id": "ch1", + "variant": {"key": "standard"}, + "speed": "correspondence", + } + _handle_challenge(challenge, api, decline_correspondence=False) + api.decline_challenge.assert_called_once() # Still declined due to perf_ok + + def test_invalid_variant_data(self) -> None: + """Test handling invalid variant data.""" + api = MagicMock() + challenge: Event = { + "id": "ch1", + "variant": "invalid", + "speed": "blitz", + } + _handle_challenge(challenge, api, decline_correspondence=False) + api.accept_challenge.assert_called_once() + + +class TestProcessBotEvent: + """Tests for _process_bot_event.""" + + def test_process_challenge_event(self) -> None: + """Test processing challenge event.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + event: Event = { + "type": "challenge", + "challenge": { + "id": "ch1", + "variant": {"key": "standard"}, + "speed": "blitz", + }, + } + _process_bot_event(event, ctx, game_threads) + api.accept_challenge.assert_called_once() + + def test_process_game_start_event(self) -> None: + """Test processing gameStart event.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + event: Event = {"type": "gameStart", "game": {"id": "game1"}} + + with patch("python_pkg.lichess_bot.main.threading.Thread") as mock_thread_class: + mock_thread = MagicMock() + mock_thread_class.return_value = mock_thread + _process_bot_event(event, ctx, game_threads) + + assert "game1" in game_threads + mock_thread.start.assert_called_once() + + def test_process_game_start_existing_thread(self) -> None: + """Test processing gameStart with existing alive thread.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + mock_thread = MagicMock(spec=threading.Thread) + mock_thread.is_alive.return_value = True + game_threads: GameThreads = {"game1": mock_thread} + event: Event = {"type": "gameStart", "game": {"id": "game1"}} + _process_bot_event(event, ctx, game_threads) + # Should not create new thread + assert game_threads["game1"] is mock_thread + + def test_process_game_finish_event(self) -> None: + """Test processing gameFinish event.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + event: Event = {"type": "gameFinish", "game": {"id": "game1"}} + with patch("python_pkg.lichess_bot.main._logger") as mock_logger: + _process_bot_event(event, ctx, game_threads) + mock_logger.info.assert_called() + + def test_process_game_finish_invalid_data(self) -> None: + """Test processing gameFinish event with non-dict game data.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + event: Event = {"type": "gameFinish", "game": "not_a_dict"} + with patch("python_pkg.lichess_bot.main._logger") as mock_logger: + _process_bot_event(event, ctx, game_threads) + # Should not log info since game data is invalid + mock_logger.info.assert_not_called() + + def test_process_unknown_event(self) -> None: + """Test processing unknown event.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + event: Event = {"type": "unknown", "data": "test"} + with patch("python_pkg.lichess_bot.main._logger") as mock_logger: + _process_bot_event(event, ctx, game_threads) + mock_logger.debug.assert_called() + + def test_process_challenge_invalid_data(self) -> None: + """Test processing challenge with invalid data.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + event: Event = {"type": "challenge", "challenge": "invalid"} + _process_bot_event(event, ctx, game_threads) + api.accept_challenge.assert_not_called() + + def test_process_game_start_invalid_data(self) -> None: + """Test processing gameStart with invalid data.""" + api = MagicMock() + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + event: Event = {"type": "gameStart", "game": "invalid"} + _process_bot_event(event, ctx, game_threads) + assert len(game_threads) == 0 + + +class TestStreamBotEvents: + """Tests for _stream_bot_events.""" + + def test_stream_bot_events(self) -> None: + """Test streaming bot events.""" + api = MagicMock() + api.stream_events.return_value = iter([{"type": "test"}]) + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + events = list(_stream_bot_events(ctx)) + assert len(events) == 1 + + +class TestRunEventLoopIteration: + """Tests for _run_event_loop_iteration.""" + + def test_run_event_loop_iteration(self) -> None: + """Test running event loop iteration.""" + api = MagicMock() + api.stream_events.return_value = iter([{"type": "unknown"}]) + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + result = _run_event_loop_iteration(ctx, game_threads) + assert result == 0 + + +class TestSafeEventLoopIteration: + """Tests for _safe_event_loop_iteration.""" + + def test_safe_event_loop_iteration_success(self) -> None: + """Test safe event loop iteration success.""" + api = MagicMock() + api.stream_events.return_value = iter([]) + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + result = _safe_event_loop_iteration(ctx, game_threads, 0) + assert result == 0 + + def test_safe_event_loop_iteration_error(self) -> None: + """Test safe event loop iteration with error.""" + api = MagicMock() + api.stream_events.side_effect = requests.RequestException("error") + ctx = BotContext(api=api, engine=MagicMock(), bot_version=1) + game_threads: GameThreads = {} + with patch("python_pkg.lichess_bot.main.backoff_sleep", return_value=5): + result = _safe_event_loop_iteration(ctx, game_threads, 2) + assert result == 5 + + +class TestRunBot: + """Tests for run_bot.""" + + def test_run_bot_no_token(self) -> None: + """Test run_bot without token raises error.""" + with ( + patch.dict(os.environ, {}, clear=True), + pytest.raises(RuntimeError, match="LICHESS_TOKEN"), + ): + run_bot() + + def test_run_bot_with_token(self) -> None: + """Test run_bot with token starts event loop.""" + + class _StopLoopError(Exception): + """Custom exception to stop the loop.""" + + def stop_loop(*_args: object, **_kwargs: object) -> None: + raise _StopLoopError + + with ( + patch.dict(os.environ, {"LICHESS_TOKEN": "test_token"}), + patch( + "python_pkg.lichess_bot.main.get_and_increment_version", + return_value=1, + ), + patch("python_pkg.lichess_bot.main.LichessAPI"), + patch("python_pkg.lichess_bot.main.RandomEngine"), + patch( + "python_pkg.lichess_bot.main._safe_event_loop_iteration", + side_effect=stop_loop, + ), + pytest.raises(_StopLoopError), + ): + run_bot("DEBUG", decline_correspondence=True) + + +class TestMain: + """Tests for main function.""" + + def test_main_parses_args(self) -> None: + """Test main parses command line arguments.""" + + class _StopExecutionError(Exception): + """Custom exception to stop execution.""" + + with ( + patch( + "sys.argv", + ["main.py", "--log-level", "DEBUG", "--decline-correspondence"], + ), + patch( + "python_pkg.lichess_bot.main.run_bot", side_effect=_StopExecutionError + ) as mock_run_bot, + pytest.raises(_StopExecutionError), + ): + main() + mock_run_bot.assert_called_once_with("DEBUG", decline_correspondence=True) diff --git a/python_pkg/lichess_bot/tests/test_main_game_state.py b/python_pkg/lichess_bot/tests/test_main_game_state.py new file mode 100644 index 0000000..b090492 --- /dev/null +++ b/python_pkg/lichess_bot/tests/test_main_game_state.py @@ -0,0 +1,403 @@ +"""Tests for lichess_bot main module: game state helpers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock, patch + +import chess +import requests + +from python_pkg.lichess_bot.main import ( + BotContext, + GameMeta, + GameState, + _apply_move_to_board, + _attempt_move, + _calculate_time_budget, + _extract_game_full_data, + _extract_game_state_data, + _extract_player_info, + _handle_move_if_needed, + _init_game_log, + _is_my_turn, + _log_move_to_file, + _rebuild_board_from_moves, + _update_clocks_from_state, +) + +if TYPE_CHECKING: + from pathlib import Path + +# Type alias to make mypy happy with test event dicts +Event = dict[str, Any] + + +class TestApplyMoveToBoard: + """Tests for _apply_move_to_board.""" + + def test_apply_valid_move(self) -> None: + """Test applying a valid move.""" + board = chess.Board() + _apply_move_to_board(board, "e2e4", "game1") + assert board.fen() != chess.STARTING_FEN + + def test_apply_invalid_move(self) -> None: + """Test applying an invalid move logs debug.""" + board = chess.Board() + with patch("python_pkg.lichess_bot.main._logger") as mock_logger: + _apply_move_to_board(board, "invalid", "game1") + mock_logger.debug.assert_called_once() + + +class TestInitGameLog: + """Tests for _init_game_log.""" + + def test_init_game_log_success(self, tmp_path: Path) -> None: + """Test successful log initialization.""" + with patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path): + result = _init_game_log("game123", 42) + assert result is not None + assert result.exists() + content = result.read_text() + assert "game game123 started" in content + assert "bot_version v42" in content + + def test_init_game_log_oserror(self) -> None: + """Test log initialization with OSError.""" + with patch("python_pkg.lichess_bot.main.Path.cwd") as mock_cwd: + mock_path = MagicMock() + mock_path.__truediv__ = MagicMock(return_value=mock_path) + mock_path.open.side_effect = OSError("Permission denied") + mock_cwd.return_value = mock_path + result = _init_game_log("game123", 42) + assert result is None + + +class TestUpdateClocksFromState: + """Tests for _update_clocks_from_state.""" + + def test_update_clocks_white(self) -> None: + """Test clock update when playing as white.""" + state = GameState(color="white") + state_data: Event = {"wtime": 60000, "btime": 55000, "winc": 1000} + _update_clocks_from_state(state_data, state) + assert state.my_ms == 60000 + assert state.opp_ms == 55000 + assert state.inc_ms == 1000 + + def test_update_clocks_black(self) -> None: + """Test clock update when playing as black.""" + state = GameState(color="black") + state_data: Event = {"wtime": 60000, "btime": 55000, "binc": 2000} + _update_clocks_from_state(state_data, state) + assert state.my_ms == 55000 + assert state.opp_ms == 60000 + assert state.inc_ms == 2000 + + def test_update_clocks_float_values(self) -> None: + """Test clock update with float values.""" + state = GameState(color="white") + state_data: Event = {"wtime": 60000.5, "btime": 55000.5} + _update_clocks_from_state(state_data, state) + assert state.my_ms == 60000 + assert state.opp_ms == 55000 + + def test_update_clocks_none_values(self) -> None: + """Test clock update with None values.""" + state = GameState(color="white") + state_data: Event = {"wtime": None, "btime": None} + _update_clocks_from_state(state_data, state) + assert state.my_ms is None + assert state.opp_ms is None + + +class TestExtractPlayerInfo: + """Tests for _extract_player_info.""" + + def test_extract_player_info_white(self) -> None: + """Test extracting player info when bot is white.""" + api = MagicMock() + api.get_my_user_id.return_value = "mybot" + state = GameState() + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = { + "white": {"id": "mybot", "name": "MyBot"}, + "black": {"id": "opp", "name": "Opponent"}, + } + _extract_player_info(event, state, meta, api) + assert state.color == "white" + assert meta.white_name == "MyBot" + assert meta.black_name == "Opponent" + + def test_extract_player_info_black(self) -> None: + """Test extracting player info when bot is black.""" + api = MagicMock() + api.get_my_user_id.return_value = "mybot" + state = GameState() + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = { + "white": {"id": "opp", "name": "Opponent"}, + "black": {"id": "mybot", "name": "MyBot"}, + } + _extract_player_info(event, state, meta, api) + assert state.color == "black" + + def test_extract_player_info_invalid_data(self) -> None: + """Test extracting player info with invalid data.""" + api = MagicMock() + state = GameState() + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = {"white": "invalid", "black": "invalid"} + _extract_player_info(event, state, meta, api) + assert state.color is None + + def test_extract_player_info_missing_name(self) -> None: + """Test extracting player info with missing name uses id.""" + api = MagicMock() + api.get_my_user_id.return_value = "mybot" + state = GameState() + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = { + "white": {"id": "mybot"}, + "black": {"id": "opponent"}, + } + _extract_player_info(event, state, meta, api) + assert meta.white_name == "mybot" + assert meta.black_name == "opponent" + + +class TestExtractGameFullData: + """Tests for _extract_game_full_data.""" + + def test_extract_game_full_data(self) -> None: + """Test extracting gameFull data.""" + api = MagicMock() + api.get_my_user_id.return_value = "mybot" + state = GameState() + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = { + "state": {"moves": "e2e4 e7e5", "status": "started", "wtime": 60000}, + "white": {"id": "mybot"}, + "black": {"id": "opp"}, + "createdAt": 1609459200000, # 2021-01-01 + } + moves, status = _extract_game_full_data(event, state, meta, api) + assert moves == "e2e4 e7e5" + assert status == "started" + assert meta.site_url == "https://lichess.org/game1" + assert meta.date_iso == "2021.01.01" + + def test_extract_game_full_data_invalid_state(self) -> None: + """Test extracting gameFull data with invalid state.""" + api = MagicMock() + state = GameState() + meta = GameMeta(game_id="game1", bot_version=1) + event: Event = {"state": "invalid"} + moves, status = _extract_game_full_data(event, state, meta, api) + assert moves == "" + assert status is None + + +class TestExtractGameStateData: + """Tests for _extract_game_state_data.""" + + def test_extract_game_state_as_white(self) -> None: + """Test extracting gameState data as white.""" + state = GameState(color="white", my_ms=60000) + event: Event = { + "moves": "e2e4", + "status": "started", + "wtime": 59000, + "btime": 60000, + } + moves, status = _extract_game_state_data(event, state) + assert moves == "e2e4" + assert status == "started" + assert state.my_ms == 59000 + assert state.opp_ms == 60000 + + def test_extract_game_state_as_black(self) -> None: + """Test extracting gameState data as black.""" + state = GameState(color="black") + event: Event = { + "moves": "e2e4 e7e5", + "wtime": 60000, + "btime": 59000, + "binc": 1000, + } + moves, __status = _extract_game_state_data(event, state) + assert moves == "e2e4 e7e5" + assert state.my_ms == 59000 + assert state.opp_ms == 60000 + assert state.inc_ms == 1000 + + +class TestCalculateTimeBudget: + """Tests for _calculate_time_budget.""" + + def test_calculate_time_budget_normal(self) -> None: + """Test time budget calculation.""" + state = GameState(my_ms=60000, inc_ms=1000) + board = chess.Board() + budget = _calculate_time_budget(state, board, 10.0) + assert 0.05 <= budget <= 10.0 + + def test_calculate_time_budget_low_time(self) -> None: + """Test time budget with low time.""" + state = GameState(my_ms=1000, inc_ms=0) + board = chess.Board() + budget = _calculate_time_budget(state, board, 10.0) + assert budget >= 0.05 + + +class TestLogMoveToFile: + """Tests for _log_move_to_file.""" + + def test_log_move_to_file(self, tmp_path: Path) -> None: + """Test logging a move to file.""" + log_path = tmp_path / "game.log" + log_path.write_text("header\n") + move = chess.Move.from_uci("e2e4") + _log_move_to_file(log_path, 1, move, "best move") + content = log_path.read_text() + assert "ply 1: e2e4" in content + assert "best move" in content + + def test_log_move_to_file_none_path(self) -> None: + """Test logging with None path does nothing.""" + move = chess.Move.from_uci("e2e4") + _log_move_to_file(None, 1, move, "reason") # Should not raise + + +class TestAttemptMove: + """Tests for _attempt_move.""" + + def test_attempt_move_success(self) -> None: + """Test successful move attempt.""" + api = MagicMock() + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = ( + chess.Move.from_uci("e2e4"), + "opening", + ) + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState(my_ms=60000) + meta = GameMeta(game_id="game1", bot_version=1) + board = chess.Board() + + result = _attempt_move(ctx, state, meta, board) + assert result is True + api.make_move.assert_called_once() + + def test_attempt_move_no_moves(self) -> None: + """Test move attempt with no legal moves.""" + api = MagicMock() + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = (None, "no moves") + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState() + meta = GameMeta(game_id="game1", bot_version=1) + board = chess.Board() + + result = _attempt_move(ctx, state, meta, board) + assert result is False + + def test_attempt_move_illegal(self) -> None: + """Test move attempt with illegal move.""" + api = MagicMock() + engine = MagicMock() + engine.max_time_sec = 5.0 + # Return a move that's not legal (e.g., random square move) + engine.choose_move_with_explanation.return_value = ( + chess.Move.from_uci("a1a8"), + "bad", + ) + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState(my_ms=60000) + meta = GameMeta(game_id="game1", bot_version=1) + board = chess.Board() + + result = _attempt_move(ctx, state, meta, board) + assert result is True + api.make_move.assert_not_called() + + def test_attempt_move_request_error(self) -> None: + """Test move attempt with request error.""" + api = MagicMock() + api.make_move.side_effect = requests.RequestException("Network error") + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = ( + chess.Move.from_uci("e2e4"), + "opening", + ) + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState(my_ms=60000) + meta = GameMeta(game_id="game1", bot_version=1) + board = chess.Board() + + result = _attempt_move(ctx, state, meta, board) + assert result is True # Still returns True + + +class TestIsMyTurn: + """Tests for _is_my_turn.""" + + def test_is_my_turn_white_to_move(self) -> None: + """Test checking turn when white to move.""" + board = chess.Board() # White to move + assert _is_my_turn(board, "white") is True + assert _is_my_turn(board, "black") is False + + def test_is_my_turn_black_to_move(self) -> None: + """Test checking turn when black to move.""" + board = chess.Board() + board.push_uci("e2e4") # Black to move + assert _is_my_turn(board, "white") is False + assert _is_my_turn(board, "black") is True + + +class TestRebuildBoardFromMoves: + """Tests for _rebuild_board_from_moves.""" + + def test_rebuild_board_from_moves(self) -> None: + """Test rebuilding board from moves list.""" + moves_list = ["e2e4", "e7e5", "g1f3"] + board = _rebuild_board_from_moves(moves_list, "game1") + assert len(board.move_stack) == 3 + + +class TestHandleMoveIfNeeded: + """Tests for _handle_move_if_needed.""" + + def test_handle_move_game_state_my_turn(self) -> None: + """Test handling move on gameState when my turn.""" + api = MagicMock() + engine = MagicMock() + engine.max_time_sec = 5.0 + engine.choose_move_with_explanation.return_value = ( + chess.Move.from_uci("e2e4"), + "opening", + ) + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState(color="white", my_ms=60000, board=chess.Board()) + meta = GameMeta(game_id="game1", bot_version=1) + + result = _handle_move_if_needed(ctx, state, meta, "gameState", 0) + assert result is True + + def test_handle_move_game_full_with_moves(self) -> None: + """Test handling move on gameFull with existing moves (opponent's turn).""" + api = MagicMock() + engine = MagicMock() + ctx = BotContext(api=api, engine=engine, bot_version=1) + state = GameState(color="white", my_ms=60000, board=chess.Board()) + meta = GameMeta(game_id="game1", bot_version=1) + + # gameFull with moves - don't move + result = _handle_move_if_needed(ctx, state, meta, "gameFull", 1) + assert result is True + engine.choose_move_with_explanation.assert_not_called() diff --git a/python_pkg/praca_magisterska_video/_q23_classical.py b/python_pkg/praca_magisterska_video/_q23_classical.py new file mode 100644 index 0000000..acf5b07 --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q23_classical.py @@ -0,0 +1,445 @@ +"""Classical segmentation methods: concept, thresholding, region growing, watershed.""" + +from __future__ import annotations + +from moviepy import ( + CompositeVideoClip, + VideoClip, +) +from moviepy.video.fx import FadeIn, FadeOut +import numpy as np + +from python_pkg.praca_magisterska_video._q23_helpers import ( + BG_COLOR, + FONT_B, + FONT_R, + FPS, + STEP_DUR, + H, + W, + _tc, +) + + +# ── Segmentation concept ───────────────────────────────────────── +def _segmentation_concept() -> list[CompositeVideoClip]: + """Show what segmentation is: pixel-level labeling.""" + slides = [] + + # Synthetic image: grid of colored pixels + def make_image_frame(_t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + # Draw a small "image" grid + grid_x, grid_y = 100, 150 + cell = 40 + # Sky (top rows) + colors_map = [ + [(135, 206, 235)] * 8, # sky + [(135, 206, 235)] * 5 + [(34, 139, 34)] * 3, # sky + tree + [(34, 139, 34)] * 3 + + [(128, 128, 128)] * 3 + + [(34, 139, 34)] * 2, # tree+road+tree + [(128, 128, 128)] * 3 + + [(200, 50, 50)] * 2 + + [(128, 128, 128)] * 3, # road+car+road + ] + labels_map = [ + ["niebo"] * 8, + ["niebo"] * 5 + ["drzewo"] * 3, + ["drzewo"] * 3 + ["droga"] * 3 + ["drzewo"] * 2, + ["droga"] * 3 + ["samochód"] * 2 + ["droga"] * 3, + ] + label_colors = { + "niebo": (100, 180, 255), + "drzewo": (50, 200, 50), + "droga": (180, 180, 180), + "samochód": (255, 80, 80), + } + + for r, row in enumerate(colors_map): + for c, col in enumerate(row): + y = grid_y + r * cell + x = grid_x + c * cell + frame[y : y + cell - 2, x : x + cell - 2] = col + + # Draw segmentation map on the right + seg_x = 600 + for r, row in enumerate(labels_map): + for c, lab in enumerate(row): + y = grid_y + r * cell + x = seg_x + c * cell + frame[y : y + cell - 2, x : x + cell - 2] = label_colors[lab] + + return frame + + image_clip = VideoClip(make_image_frame, duration=STEP_DUR).with_fps(FPS) + labels_text = [ + ("Obraz wejściowy", 22, "white", FONT_B, (170, 100)), + ("Mapa segmentacji", 22, "white", FONT_B, (660, 100)), + ("→", 50, "#FFE082", FONT_B, (450, 250)), + ("Każdy piksel → etykieta klasy", 20, "#B0BEC5", FONT_R, (100, 420)), + ("niebo | drzewo | droga | samochód", 18, "#90CAF9", FONT_R, (600, 420)), + ("Segmentacja = klasyfikacja per-piksel", 24, "#FFE082", FONT_B, (100, 500)), + ( + "Semantic: klasy bez instancji | Instance: " + "rozróżnia obiekty | Panoptic: oba", + 16, + "#78909C", + FONT_R, + (100, 560), + ), + ] + clips: list[VideoClip] = [image_clip] + for text, fs, color, font, pos in labels_text: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR) + .with_position(pos) + ) + clips.append(tc) + + slides.append( + CompositeVideoClip(clips, size=(W, H)).with_effects([FadeIn(0.3), FadeOut(0.3)]) + ) + return slides + + +# ── Thresholding / Otsu ─────────────────────────────────────────── +def _thresholding_demo() -> list[CompositeVideoClip]: + """Animate thresholding and Otsu concept.""" + slides = [] + + # Show histogram & threshold + def make_threshold_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + # Draw bimodal histogram bars + bar_start_x = 80 + bar_y = 500 + bar_w = 4 + + for i in range(256): + # Bimodal: peaks at 60 and 190 + h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2)) + h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2)) + bar_h = int(h1 + h2) + x = bar_start_x + i * bar_w + if x + bar_w < W: + frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (150, 150, 170) + + # Animated threshold line + threshold = int(60 + (190 - 60) * min(t / (STEP_DUR * 0.7), 1.0)) + tx = bar_start_x + threshold * bar_w + if tx < W: + frame[bar_y - 250 : bar_y + 10, tx : tx + 3] = (255, 80, 80) + + # Color the two sides + for i in range(threshold): + x = bar_start_x + i * bar_w + h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2)) + h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2)) + bar_h = int(h1 + h2) + if x + bar_w < W and bar_h > 0: + frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (70, 130, 200) + + for i in range(threshold, 256): + x = bar_start_x + i * bar_w + h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2)) + h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2)) + bar_h = int(h1 + h2) + if x + bar_w < W and bar_h > 0: + frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (200, 100, 80) + + return frame + + hist_clip = VideoClip(make_threshold_frame, duration=STEP_DUR).with_fps(FPS) + text_clips: list[VideoClip] = [hist_clip] + labels = [ + ("Progowanie (Thresholding) z metodą Otsu", 28, "#FFE082", FONT_B, (80, 30)), + ( + "Histogram jasności pikseli — dwumodalny (bimodal)", + 20, + "#B0BEC5", + FONT_R, + (80, 80), + ), + ("Garb 1: piksele obiektu (ciemne ~60)", 16, "#64B5F6", FONT_R, (80, 120)), + ("Garb 2: piksele tła (jasne ~190)", 16, "#EF9A9A", FONT_R, (80, 150)), + ( + "Próg T (czerwona linia) dzieli piksele na 2 klasy", + 18, + "white", + FONT_R, + (80, 540), + ), + ( + "Otsu: automatycznie testuje T=0..255, minimalizuje σ² wewnątrzklasową", + 16, + "#A5D6A7", + FONT_R, + (80, 580), + ), + ( + "Piksel ≤ T → klasa 0 (tło) | Piksel > T → klasa 1 (obiekt)", + 16, + "#78909C", + FONT_R, + (80, 620), + ), + ] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR) + .with_position(pos) + ) + text_clips.append(tc) + + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides + + +# ── Region Growing ──────────────────────────────────────────────── +def _region_growing_demo() -> list[CompositeVideoClip]: + """Animate region growing BFS from a seed pixel.""" + slides = [] + + grid_size = 10 + cell_size = 40 + rng = np.random.default_rng(42) + # Create a simple grid: dark region (30-80) and bright region (160-220) + grid = np.zeros((grid_size, grid_size), dtype=np.uint8) + grid[:] = 60 # dark background + grid[2:7, 3:8] = 180 # bright rectangle + + # Add some noise + noise = rng.integers(-15, 15, (grid_size, grid_size)) + grid = np.clip(grid.astype(int) + noise, 0, 255).astype(np.uint8) + + # BFS steps from seed (4, 5) + seed = (4, 5) + threshold_val = 50 + visited_order: list[tuple[int, int]] = [] + queue = [seed] + visited_set = {seed} + while queue: + r, c = queue.pop(0) + visited_order.append((r, c)) + for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]: + nr, nc = r + dr, c + dc + if ( + 0 <= nr < grid_size + and 0 <= nc < grid_size + and (nr, nc) not in visited_set + ) and abs(int(grid[nr, nc]) - int(grid[seed])) < threshold_val: + visited_set.add((nr, nc)) + queue.append((nr, nc)) + + def make_region_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + ox, oy = 100, 180 + + # How many cells to show as visited + progress = min(t / (STEP_DUR * 0.8), 1.0) + n_visited = int(progress * len(visited_order)) + + for r in range(grid_size): + for c in range(grid_size): + x = ox + c * cell_size + y = oy + r * cell_size + val = grid[r, c] + color = (val, val, val) + + # Highlight visited + if (r, c) in visited_order[:n_visited]: + color = (80, 200, 120) # green for region + elif (r, c) == seed: + color = (255, 200, 50) # yellow seed + + frame[y : y + cell_size - 2, x : x + cell_size - 2] = color + + # Mark the seed with a bright border + sx = ox + seed[1] * cell_size + sy = ox + seed[0] * cell_size + 80 + frame[sy : sy + cell_size, sx : sx + 2] = (255, 200, 50) + frame[sy : sy + cell_size, sx + cell_size - 2 : sx + cell_size] = (255, 200, 50) + frame[sy : sy + 2, sx : sx + cell_size] = (255, 200, 50) + frame[sy + cell_size - 2 : sy + cell_size, sx : sx + cell_size] = (255, 200, 50) + + return frame + + region_clip = VideoClip(make_region_frame, duration=STEP_DUR).with_fps(FPS) + text_clips: list[VideoClip] = [region_clip] + labels = [ + ("Region Growing — rozrastanie regionu", 28, "#FFE082", FONT_B, (100, 30)), + ("Seed (ziarno) → BFS do podobnych sąsiadów", 20, "#B0BEC5", FONT_R, (100, 80)), + ( + "Żółty = seed | Zielony = region | Szary = nieodwiedzone", + 16, + "#78909C", + FONT_R, + (100, 120), + ), + ( + "Sąsiad PODOBNY (|jasność - jasność_regionu| < próg) → dodaj do regionu", + 16, + "#A5D6A7", + FONT_R, + (100, 600), + ), + ( + "Algorytm zatrzymuje się gdy brak podobnych sąsiadów", + 16, + "#90CAF9", + FONT_R, + (100, 640), + ), + ( + "Mnemonik: PLAMA atramentu — rozlewa się na podobne piksele", + 18, + "#EF9A9A", + FONT_R, + (100, 670), + ), + ] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR) + .with_position(pos) + ) + text_clips.append(tc) + + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides + + +# ── Watershed ───────────────────────────────────────────────────── +def _watershed_demo() -> list[CompositeVideoClip]: + """Animate watershed flooding concept.""" + slides = [] + + def make_watershed_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + # Draw terrain profile (1D cross-section) + ox, oy = 100, 450 + terrain_w = 900 + terrain_points = 100 + + xs = np.linspace(0, 1, terrain_points) + # Two valleys with a ridge + terrain = ( + 120 * np.exp(-((xs - 0.25) ** 2) / 0.005) + + 80 * np.exp(-((xs - 0.75) ** 2) / 0.008) + + 30 + ) + terrain = 250 - terrain # invert for visual (valleys at bottom) + + # Water level rises over time + water_level = int(160 + 80 * min(t / (STEP_DUR * 0.7), 1.0)) + + for i in range(terrain_points - 1): + x1 = ox + int(xs[i] * terrain_w) + x2 = ox + int(xs[i + 1] * terrain_w) + y1 = oy - int(terrain[i]) + y2 = oy - int(terrain[i + 1]) + + # Fill terrain + for x in range(x1, min(x2 + 1, W)): + top = min(y1, y2) - 5 + frame[top:oy, x : x + 1] = (100, 80, 60) + + # Fill water + water_y = oy - water_level + for x in range(x1, min(x2 + 1, W)): + t_y = oy - int(terrain[i]) + if water_y < t_y: + # Water fills below terrain surface + fill_top = max(water_y, 0) + fill_bot = min(t_y, oy) + if fill_top < fill_bot: + frame[fill_top:fill_bot, x : x + 1] = (70, 130, 220) + + # Dam marker at ridge + ridge_x = ox + int(0.5 * terrain_w) + dam_visible_threshold = 160 + if water_level > dam_visible_threshold: + frame[oy - water_level : oy - 140, ridge_x - 2 : ridge_x + 2] = ( + 255, + 80, + 80, + ) + + return frame + + ws_clip = VideoClip(make_watershed_frame, duration=STEP_DUR).with_fps(FPS) + text_clips: list[VideoClip] = [ws_clip] + labels = [ + ("Watershed — metoda zlewiska", 28, "#FFE082", FONT_B, (100, 20)), + ( + "Obraz = mapa topograficzna (jasność = wysokość)", + 20, + "#B0BEC5", + FONT_R, + (100, 65), + ), + ( + "Brązowy = teren (ciemne=doliny, jasne=szczyty)", + 16, + "#8D6E63", + FONT_R, + (100, 100), + ), + ("Niebieski = woda zalewająca od minimów", 16, "#64B5F6", FONT_R, (100, 130)), + ( + "Czerwony = TAMA (granica segmentu) — gdy woda z 2 dolin się spotka", + 16, + "#EF9A9A", + FONT_R, + (100, 160), + ), + ( + "Problem: over-segmentation " + "(za dużo regionów). " + "Rozwiązanie: marker-controlled.", + 16, + "#A5D6A7", + FONT_R, + (100, 560), + ), + ( + "Mnemonik: ZALEWANIE terenu — granie gór = granice segmentów", + 18, + "#FFE082", + FONT_R, + (100, 600), + ), + ] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR) + .with_position(pos) + ) + text_clips.append(tc) + + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides diff --git a/python_pkg/praca_magisterska_video/_q23_deeplab.py b/python_pkg/praca_magisterska_video/_q23_deeplab.py new file mode 100644 index 0000000..e1a5ff1 --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q23_deeplab.py @@ -0,0 +1,248 @@ +"""DeepLab architecture animations for Q23 segmentation video.""" + +from __future__ import annotations + +from moviepy import ( + CompositeVideoClip, + VideoClip, +) +import numpy as np + +from python_pkg.praca_magisterska_video._q23_helpers import ( + BG_COLOR, + FONT_B, + FONT_R, + FPS, + STEP_DUR, + H, + W, + _compose_slide, +) + + +# ── DeepLab Architecture ───────────────────────────────────────── +def _make_dilated_frame(t: float) -> np.ndarray: + """Render a dilated convolution comparison frame.""" + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + progress = min(t / (STEP_DUR * 0.7), 1.0) + + cell = 36 + grids = [ + ( + "rate=1", + 60, + [ + (0, 0), + (0, 1), + (0, 2), + (1, 0), + (1, 1), + (1, 2), + (2, 0), + (2, 1), + (2, 2), + ], + ), + ( + "rate=2", + 420, + [ + (0, 0), + (0, 2), + (0, 4), + (2, 0), + (2, 2), + (2, 4), + (4, 0), + (4, 2), + (4, 4), + ], + ), + ( + "rate=3", + 820, + [ + (0, 0), + (0, 3), + (0, 6), + (3, 0), + (3, 3), + (3, 6), + (6, 0), + (6, 3), + (6, 6), + ], + ), + ] + + for gi, (_label, gx, positions) in enumerate(grids): + if progress < gi * 0.3: + break + gy = 180 + grid_size = 7 + for r in range(grid_size): + for c in range(grid_size): + x = gx + c * cell + y = gy + r * cell + frame[y : y + cell - 2, x : x + cell - 2] = (35, 40, 55) + for r, c in positions: + x = gx + c * cell + y = gy + r * cell + frame[y : y + cell - 2, x : x + cell - 2] = (70, 130, 200) + frame[y : y + 2, x : x + cell - 2] = (120, 180, 255) + frame[y + cell - 4 : y + cell - 2, x : x + cell - 2] = (120, 180, 255) + + return frame + + +def _make_aspp_frame(t: float) -> np.ndarray: + """Render a single ASPP module animation frame.""" + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + progress = min(t / (STEP_DUR * 0.7), 1.0) + + frame[250:330, 50:130] = (70, 130, 200) + frame[250:252, 50:130] = (120, 180, 255) + frame[328:330, 50:130] = (120, 180, 255) + + branches = [ + ("1x1 conv", 250, (200, 170), (100, 40), (80, 200, 120)), + ("rate=6", 310, (200, 250), (100, 40), (200, 160, 80)), + ("rate=12", 370, (200, 330), (100, 40), (200, 120, 60)), + ("rate=18", 430, (200, 410), (100, 40), (180, 100, 80)), + ("GAP", 490, (200, 490), (100, 40), (160, 80, 160)), + ] + n_branches = min(int(progress * 5) + 1, 5) + for i, (_lbl, _h, (bx, by), (bw, bh), color) in enumerate(branches): + if i < n_branches: + frame[by : by + bh, bx : bx + bw] = color + frame[by : by + 2, bx : bx + bw] = tuple(min(c + 50, 255) for c in color) + ay = by + bh // 2 + frame[ay - 1 : ay + 2, 133:197] = (150, 150, 170) + + concat_phase = 0.6 + if progress > concat_phase: + frame[250:530, 380:420] = (50, 60, 80) + frame[250:252, 380:420] = (200, 200, 100) + frame[528:530, 380:420] = (200, 200, 100) + for i, (_lbl, _h, (bx, by), (bw, bh), _c) in enumerate(branches): + if i < n_branches: + ay = by + bh // 2 + frame[ay - 1 : ay + 2, bx + bw + 3 : 378] = (150, 150, 170) + + final_conv_phase = 0.8 + if progress > final_conv_phase: + frame[350:420, 450:550] = (100, 200, 100) + frame[350:352, 450:550] = (150, 230, 150) + frame[418:420, 450:550] = (150, 230, 150) + frame[388:391, 423:448] = (150, 150, 170) + + return frame + + +def _deeplab_demo() -> list[CompositeVideoClip]: + """Animate DeepLab: dilated convolution + ASPP step by step.""" + dur = STEP_DUR + 1 + + # Slide 1: Regular vs Dilated convolution + dil_clip = VideoClip(_make_dilated_frame, duration=dur).with_fps(FPS) + labels = [ + ("DeepLab: Atrous (Dilated) Convolution", 26, "#FFE082", FONT_B, (80, 20)), + ( + "KROK 1: Zrozum dilated convolution — filtr z DZIURAMI", + 18, + "#A5D6A7", + FONT_R, + (80, 60), + ), + ("rate=1 (zwykła)", 14, "#64B5F6", FONT_B, (60, 160)), + ("RF = 3x3", 14, "#64B5F6", FONT_R, (60, 440)), + ("9 wag, kontekst 3px", 12, "#78909C", FONT_R, (60, 470)), + ("rate=2 (dilated)", 14, "#FFE082", FONT_B, (420, 160)), + ("RF = 5x5", 14, "#FFE082", FONT_R, (420, 440)), + ("9 wag, kontekst 5px!", 12, "#78909C", FONT_R, (420, 470)), + ("rate=3 (dilated)", 14, "#A5D6A7", FONT_B, (820, 160)), + ("RF = 7x7", 14, "#A5D6A7", FONT_R, (820, 440)), + ("9 wag, kontekst 7px!", 12, "#78909C", FONT_R, (820, 470)), + ( + "Niebieski = pozycja wag filtra 3x3 | Szary = pominięte (dziury)", + 15, + "#B0BEC5", + FONT_R, + (80, 510), + ), + ( + "TE SAME 9 wag → WIĘKSZE pole widzenia " + "→ lepszy kontekst BEZ dodatkowych parametrów!", + 16, + "white", + FONT_R, + (80, 550), + ), + ( + "Mnemonik: DZIURY w filtrze — à trous = z dziurami (fr.)", + 16, + "#FFE082", + FONT_R, + (80, 600), + ), + ] + slides = [_compose_slide(dil_clip, labels, dur)] + + # Slide 2: ASPP module step by step + aspp_clip = VideoClip(_make_aspp_frame, duration=dur).with_fps(FPS) + labels2 = [ + ( + "DeepLab: ASPP (Atrous Spatial Pyramid Pooling)", + 24, + "#FFE082", + FONT_B, + (80, 20), + ), + ( + "KROK 2: Multi-scale — analizuj obraz na WIELU skalach naraz", + 17, + "#A5D6A7", + FONT_R, + (80, 60), + ), + ("Wejście", 13, "#64B5F6", FONT_B, (55, 235)), + ("Conv 1x1", 12, "white", FONT_R, (210, 178)), + ("Dilated r=6", 12, "white", FONT_R, (205, 258)), + ("Dilated r=12", 12, "white", FONT_R, (203, 338)), + ("Dilated r=18", 12, "white", FONT_R, (203, 418)), + ("GAP (global)", 12, "white", FONT_R, (205, 498)), + ("Concat", 13, "#FFE082", FONT_B, (381, 537)), + ("Conv", 13, "#A5D6A7", FONT_B, (470, 425)), + ( + "5 gałęzi RÓWNOLEGŁYCH → różne skale kontekstu:", + 16, + "#B0BEC5", + FONT_R, + (550, 170), + ), + (" 1x1: kontekst punktowy (piksel)", 14, "#A5D6A7", FONT_R, (560, 210)), + (" r=6: kontekst lokalny (~13px)", 14, "#FFE082", FONT_R, (560, 245)), + (" r=12: kontekst średni (~25px)", 14, "#FFE082", FONT_R, (560, 280)), + (" r=18: kontekst szeroki (~37px)", 14, "#FFE082", FONT_R, (560, 315)), + (" GAP: kontekst GLOBALNY (cały obraz)", 14, "#CE93D8", FONT_R, (560, 350)), + ("Concat → 1x1 conv → mapa segmentacji", 16, "#A5D6A7", FONT_R, (550, 400)), + ( + "Efekt: sieć widzi OD piksela DO całego obrazu naraz!", + 17, + "white", + FONT_R, + (80, 600), + ), + ( + "Mnemonik: ASPP = Piramida z DZIURAMI, patrzy na 5 skal jednocześnie", + 15, + "#FFE082", + FONT_R, + (80, 645), + ), + ] + slides.append(_compose_slide(aspp_clip, labels2, dur)) + + return slides diff --git a/python_pkg/praca_magisterska_video/_q23_helpers.py b/python_pkg/praca_magisterska_video/_q23_helpers.py new file mode 100644 index 0000000..32a5757 --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q23_helpers.py @@ -0,0 +1,116 @@ +"""Shared constants and helper functions for Q23 segmentation video.""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path + +import numpy as np + +os.environ["FFMPEG_BINARY"] = "/usr/bin/ffmpeg" + +from moviepy import ( + ColorClip, + CompositeVideoClip, + TextClip, + VideoClip, +) +from moviepy.video.fx import FadeIn, FadeOut + +# ── Constants ───────────────────────────────────────────────────── +W, H = 1280, 720 +FPS = 24 +STEP_DUR = 7.0 +HEADER_DUR = 4.0 +FONT_B = "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf" +FONT_R = "/usr/share/fonts/TTF/DejaVuSans.ttf" +OUTPUT_DIR = Path(__file__).resolve().parent / "videos" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) +OUTPUT = str(OUTPUT_DIR / "q23_segmentation.mp4") + +logging.basicConfig(level=logging.INFO) +_logger = logging.getLogger(__name__) + +BG_COLOR = (15, 20, 35) +rng = np.random.default_rng(42) + + +def _tc(**kwargs: object) -> TextClip: + """TextClip wrapper that adds enough bottom margin to prevent clipping.""" + fs = kwargs.get("font_size", 24) + m = int(fs) // 3 + 2 + kwargs["margin"] = (0, m) + return TextClip(**kwargs) + + +def _make_header( + title: str, subtitle: str, duration: float = HEADER_DUR +) -> CompositeVideoClip: + bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration) + t = ( + _tc( + text=title, + font_size=48, + color="white", + font=FONT_B, + ) + .with_duration(duration) + .with_position(("center", 260)) + ) + s = ( + _tc( + text=subtitle, + font_size=24, + color="#90CAF9", + font=FONT_R, + ) + .with_duration(duration) + .with_position(("center", 340)) + ) + return CompositeVideoClip([bg, t, s], size=(W, H)).with_effects( + [FadeIn(0.5), FadeOut(0.5)] + ) + + +def _text_slide( + lines: list[tuple[str, int, str, str, tuple[str | int, str | int]]], + duration: float = STEP_DUR, +) -> CompositeVideoClip: + """Create a slide with multiple text elements.""" + bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration) + clips: list[VideoClip] = [bg] + for text, font_size, color, font, pos in lines: + tc = ( + _tc( + text=text, + font_size=font_size, + color=color, + font=font, + ) + .with_duration(duration) + .with_position(pos) + ) + clips.append(tc) + return CompositeVideoClip(clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + + +def _compose_slide( + base_clip: VideoClip, + labels: list[tuple[str, int, str, str, tuple[int, int]]], + duration: float, +) -> CompositeVideoClip: + """Overlay text labels on an animated base clip.""" + text_clips: list[VideoClip] = [base_clip] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(duration) + .with_position(pos) + ) + text_clips.append(tc) + return CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) diff --git a/python_pkg/praca_magisterska_video/_q23_transformer.py b/python_pkg/praca_magisterska_video/_q23_transformer.py new file mode 100644 index 0000000..55f9fcf --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q23_transformer.py @@ -0,0 +1,430 @@ +"""Transformer segmentation and methods comparison for Q23 video.""" + +from __future__ import annotations + +from moviepy import ( + ColorClip, + CompositeVideoClip, + VideoClip, +) +from moviepy.video.fx import FadeIn, FadeOut +import numpy as np + +from python_pkg.praca_magisterska_video._q23_helpers import ( + BG_COLOR, + FONT_B, + FONT_R, + FPS, + STEP_DUR, + H, + W, + _compose_slide, + _tc, + _text_slide, +) + + +# ── Transformer Segmentation ──────────────────────────────────── +def _draw_base_grid( + frame: np.ndarray, + gx: int, + gy: int, + grid_n: int, + cell: int, +) -> None: + """Draw an empty grid of cells.""" + for r in range(grid_n): + for c in range(grid_n): + x = gx + c * cell + y = gy + r * cell + frame[y : y + cell - 2, x : x + cell - 2] = (35, 40, 55) + + +def _draw_cnn_kernel( + frame: np.ndarray, + lx: int, + ly: int, + cell: int, + progress: float, +) -> None: + """Highlight a 3x3 CNN kernel on the grid.""" + cnn_phase = 0.2 + if progress <= cnn_phase: + return + cx, cy = 2, 2 + for dr in range(-1, 2): + for dc in range(-1, 2): + r, c = cy + dr, cx + dc + x = lx + c * cell + y = ly + r * cell + frame[y : y + cell - 2, x : x + cell - 2] = (70, 130, 200) + x = lx + cx * cell + y = ly + cy * cell + frame[y : y + cell - 2, x : x + cell - 2] = (120, 180, 255) + + +def _draw_conn_line( + frame: np.ndarray, + x0: int, + y0: int, + x1: int, + y1: int, +) -> None: + """Draw a dashed connection line between two points.""" + steps = max(abs(x1 - x0), abs(y1 - y0)) + if steps <= 0: + return + for s in range(0, steps, 3): + px = x0 + int((x1 - x0) * s / steps) + py = y0 + int((y1 - y0) * s / steps) + if 0 <= px < W - 1 and 0 <= py < H - 1: + frame[py : py + 1, px : px + 1] = (200, 180, 50) + + +def _draw_attention_connections( + frame: np.ndarray, + origin: tuple[int, int], + grid_n: int, + cell: int, + progress: float, +) -> None: + """Draw transformer self-attention connections on the grid.""" + rx, ry = origin + transformer_phase = 0.4 + if progress <= transformer_phase: + return + cx_t, cy_t = 2, 2 + x0 = rx + cx_t * cell + cell // 2 + y0 = ry + cy_t * cell + cell // 2 + n_connections = int(progress * 36) + conn_idx = 0 + for r in range(grid_n): + for c in range(grid_n): + conn_idx += 1 + if conn_idx > n_connections: + break + x = rx + c * cell + y = ry + r * cell + dist = abs(r - cy_t) + abs(c - cx_t) + strength = max(30, 200 - dist * 30) + frame[y : y + cell - 2, x : x + cell - 2] = ( + strength // 3, + strength // 2, + strength, + ) + _draw_conn_line(frame, x0, y0, x + cell // 2, y + cell // 2) + else: + continue + break + x = rx + cx_t * cell + y = ry + cy_t * cell + frame[y : y + cell - 2, x : x + cell - 2] = (255, 200, 50) + + +def _make_attention_frame(t: float) -> np.ndarray: + """Render a CNN-vs-Transformer attention comparison frame.""" + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + progress = min(t / (STEP_DUR * 0.7), 1.0) + + cell = 40 + grid_n = 6 + + lx, ly = 60, 200 + _draw_base_grid(frame, lx, ly, grid_n, cell) + _draw_cnn_kernel(frame, lx, ly, cell, progress) + + rx, ry = 680, 200 + _draw_base_grid(frame, rx, ry, grid_n, cell) + _draw_attention_connections(frame, (rx, ry), grid_n, cell, progress) + + return frame + + +def _transformer_seg_demo() -> list[CompositeVideoClip]: + """Animate transformer-based segmentation: self-attention concept.""" + dur = STEP_DUR + 1 + + # Slide 1: CNN local vs Transformer global + att_clip = VideoClip(_make_attention_frame, duration=dur).with_fps(FPS) + labels = [ + ("Transformer: Self-Attention w segmentacji", 26, "#FFE082", FONT_B, (80, 20)), + ("CNN = LOKALNY kontekst", 18, "#64B5F6", FONT_B, (60, 160)), + ("Transformer = GLOBALNY kontekst", 18, "#FFE082", FONT_B, (680, 160)), + ("Filtr 3x3 widzi", 14, "#64B5F6", FONT_R, (60, 460)), + ("TYLKO 9 sąsiadów", 14, "#64B5F6", FONT_R, (60, 485)), + ("Self-attention: każdy", 14, "#FFE082", FONT_R, (680, 460)), + ("piksel widzi WSZYSTKIE!", 14, "#FFE082", FONT_R, (680, 485)), + ("vs", 28, "#B0BEC5", FONT_B, (450, 300)), + ] + slides = [_compose_slide(att_clip, labels, dur)] + + # Slide 2: Self-attention Q/K/V step by step + qkv_lines = [ + ("Self-Attention: Q / K / V krok po kroku", 26, "#FFE082", FONT_B, (80, 30)), + ("Każdy piksel (token) tworzy 3 wektory:", 18, "#B0BEC5", FONT_R, (100, 100)), + ( + " Q (Query) = 'czego szukam?' - pytanie piksela", + 17, + "#64B5F6", + FONT_R, + (120, 145), + ), + ( + " K (Key) = 'co oferuj\u0119?' - odpowied\u017a piksela", + 17, + "#A5D6A7", + FONT_R, + (120, 185), + ), + ( + " V (Value) = 'moja warto\u015b\u0107' - informacja do przekazania", + 17, + "#FFE082", + FONT_R, + (120, 225), + ), + ("Algorytm attention:", 18, "#B0BEC5", FONT_R, (100, 285)), + ( + " 1. Mnożenie Q x K\u1d40 → macierz NxN (kto ważny dla kogo)", + 16, + "white", + FONT_R, + (120, 320), + ), + ( + " 2. Skalowanie: / \u221ad (stabilno\u015b\u0107 gradient\u00f3w)", + 16, + "white", + FONT_R, + (120, 355), + ), + ( + " 3. Softmax \u2192 wagi attention (sumuj\u0105 si\u0119 do 1)", + 16, + "white", + FONT_R, + (120, 390), + ), + ( + " 4. Mno\u017cenie wag x V \u2192 wa\u017cona suma warto\u015bci", + 16, + "white", + FONT_R, + (120, 425), + ), + ( + "Attention(Q,K,V) = softmax(Q \u00b7 K\u1d40 / \u221ad) \u00b7 V", + 20, + "#FFE082", + FONT_B, + (100, 480), + ), + ( + "Z\u0142o\u017cono\u015b\u0107: O(n\u00b2) pami\u0119ci \u2014 n = liczba pikseli/token\u00f3w", + 16, + "#EF9A9A", + FONT_R, + (100, 535), + ), + ( + "Dlatego SegFormer u\u017cywa efficient attention (liniowa z\u0142o\u017cono\u015b\u0107)", + 15, + "#78909C", + FONT_R, + (100, 570), + ), + ( + "SegFormer (2021): lightweight + hierarchiczny encoder", + 16, + "#A5D6A7", + FONT_R, + (100, 610), + ), + ( + "Mask2Former (2022): masked attention + " + "unified (semantic+instance+panoptic)", + 16, + "#CE93D8", + FONT_R, + (100, 645), + ), + ] + slides.append(_text_slide(qkv_lines, duration=STEP_DUR + 1)) + + # Slide 3: Encoder-Decoder in DL summary + summary_lines = [ + ( + "Podsumowanie: Encoder-Decoder w segmentacji DL", + 24, + "#FFE082", + FONT_B, + (80, 30), + ), + ( + "Wsp\u00f3lna idea WSZYSTKICH sieci segmentacji:", + 18, + "#B0BEC5", + FONT_R, + (80, 90), + ), + ( + "Encoder: obraz \u2192 cechy (zmniejsza rozdzielczo\u015b\u0107, wyci\u0105ga CO)", + 16, + "#64B5F6", + FONT_R, + (100, 140), + ), + ( + "Decoder: cechy \u2192 mapa (zwi\u0119ksza rozdzielczo\u015b\u0107, odtwarza GDZIE)", + 16, + "#A5D6A7", + FONT_R, + (100, 175), + ), + ( + "Skip: przenosi detale z encodera do decodera", + 16, + "#FFE082", + FONT_R, + (100, 210), + ), + ("", 10, "white", FONT_R, (100, 240)), + ( + "FCN (2015): Conv1x1 + skip \u2192 pierwsza end-to-end", + 16, + "#64B5F6", + FONT_R, + (100, 275), + ), + ( + "U-Net (2015): U-shape + skip concat \u2192 segmentacja medyczna", + 16, + "#A5D6A7", + FONT_R, + (100, 310), + ), + ( + "DeepLab (2018): dilated conv + ASPP \u2192 multi-scale kontekst", + 16, + "#FFE082", + FONT_R, + (100, 345), + ), + ( + "SegFormer: transformer encoder (globalny kontekst)", + 16, + "#CE93D8", + FONT_R, + (100, 380), + ), + ( + "Mask2Former: masked attention (unified, SOTA)", + 16, + "#CE93D8", + FONT_R, + (100, 415), + ), + ("", 10, "white", FONT_R, (100, 440)), + ( + "Ewolucja: wi\u0119cej kontekstu + lepsze skip connections:", + 17, + "white", + FONT_R, + (80, 465), + ), + ( + " CNN lokal. \u2192 dilated (szersze RF) \u2192 transformer (global) \u2192 masked att.", + 16, + "#B0BEC5", + FONT_R, + (80, 505), + ), + ( + " addition skip \u2192 concat skip \u2192 cross-attention skip", + 16, + "#B0BEC5", + FONT_R, + (80, 540), + ), + ( + "Metryki: mIoU (standard), Dice (medycyna), Focal Loss (imbalance)", + 16, + "#90CAF9", + FONT_R, + (80, 590), + ), + ( + "Loss: Cross-Entropy per piksel + opcjonalnie Dice/Focal", + 15, + "#78909C", + FONT_R, + (80, 625), + ), + ] + slides.append(_text_slide(summary_lines, duration=STEP_DUR + 1)) + + return slides + + +# ── Methods comparison ──────────────────────────────────────────── +def _methods_comparison() -> CompositeVideoClip: + """Create a comparison table of all segmentation methods.""" + bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(10.0) + title = ( + _tc( + text="Por\u00f3wnanie metod segmentacji", + font_size=36, + color="white", + font=FONT_B, + ) + .with_duration(10.0) + .with_position(("center", 20)) + ) + + rows = [ + ("Metoda", "Typ", "Idea", "Mnemonik"), + ( + "Thresholding", + "Klasyczna", + "piksel > T \u2192 klasa 1", + "PR\u00d3G na bramce", + ), + ("Otsu", "Klasyczna", "auto-pr\u00f3g, min \u03c3\u00b2", "AUTO-bramkarz"), + ("Region Growing", "Klasyczna", "BFS od seeda", "PLAMA atramentu"), + ("Watershed", "Klasyczna", "zalewanie minim\u00f3w", "ZALEWANIE terenu"), + ( + "Mean Shift", + "Klasyczna", + "j\u0105dro \u2192 max g\u0119sto\u015bci", + "KULKI do do\u0142k\u00f3w", + ), + ("U-Net", "Deep Learning", "encoder-decoder + skip", "Litera U + mosty"), + ("DeepLab", "Deep Learning", "dilated conv + ASPP", "DZIURY w filtrze"), + ] + + clips: list[VideoClip] = [bg, title] + mnemonic_col = 3 + for i, row in enumerate(rows): + y_pos = 75 + i * 72 + col_x = [40, 210, 340, 660] + for j, cell in enumerate(row): + fs = 16 if i > 0 else 18 + color = ( + "#64B5F6" if i == 0 else ("#E0E0E0" if j < mnemonic_col else "#FFE082") + ) + tc = ( + _tc( + text=cell, + font_size=fs, + color=color, + font=FONT_B if i == 0 else FONT_R, + ) + .with_duration(10.0) + .with_position((col_x[j], y_pos)) + ) + clips.append(tc) + + return CompositeVideoClip(clips, size=(W, H)).with_effects( + [FadeIn(0.5), FadeOut(0.5)] + ) diff --git a/python_pkg/praca_magisterska_video/_q23_unet_fcn.py b/python_pkg/praca_magisterska_video/_q23_unet_fcn.py new file mode 100644 index 0000000..779d88e --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q23_unet_fcn.py @@ -0,0 +1,399 @@ +"""U-Net and FCN architecture animations for Q23 segmentation video.""" + +from __future__ import annotations + +from moviepy import ( + CompositeVideoClip, + VideoClip, +) +import numpy as np + +from python_pkg.praca_magisterska_video._q23_helpers import ( + BG_COLOR, + FONT_B, + FONT_R, + FPS, + STEP_DUR, + H, + W, + _compose_slide, + _text_slide, +) + + +# ── U-Net Architecture ─────────────────────────────────────────── +def _draw_unet_skips( + frame: np.ndarray, + enc_positions: list[tuple[int, int, int, int]], + n_blocks: int, + dec_x: int, + skip_threshold: int, +) -> None: + """Draw horizontal dashed skip-connection lines.""" + if n_blocks <= skip_threshold: + return + for i in range(min(n_blocks - 5, 4)): + ey = enc_positions[i][1] + enc_positions[i][3] // 2 + ex_end = enc_positions[i][0] + enc_positions[i][2] + for dash_x in range(ex_end + 10, dec_x - 10, 15): + frame[ey : ey + 2, dash_x : dash_x + 8] = (255, 200, 50) + + +def _make_unet_frame(t: float) -> np.ndarray: + """Render a single U-Net animation frame.""" + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + enc_sizes = [(80, 120), (60, 100), (45, 80), (30, 60)] + dec_sizes = list(reversed(enc_sizes)) + enc_x = 150 + dec_x = 850 + + progress = min(t / (STEP_DUR * 0.6), 1.0) + n_blocks = int(progress * 8) + 1 + + enc_positions: list[tuple[int, int, int, int]] = [] + y_offset = 120 + for i, (bw, bh) in enumerate(enc_sizes): + x = enc_x + y = y_offset + i * 130 + enc_positions.append((x, y, bw, bh)) + if i < n_blocks: + frame[y : y + bh, x : x + bw] = (70, 130, 200) + frame[y : y + 2, x : x + bw] = (100, 180, 255) + frame[y + bh - 2 : y + bh, x : x + bw] = (100, 180, 255) + frame[y : y + bh, x : x + 2] = (100, 180, 255) + frame[y : y + bh, x + bw - 2 : x + bw] = (100, 180, 255) + if i < len(enc_sizes) - 1: + ax = x + bw // 2 + ay = y + bh + 10 + frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170) + + bx, by = 500, y_offset + 3 * 130 + 30 + encoder_count = 4 + if n_blocks > encoder_count: + frame[by : by + 50, bx : bx + 25] = (200, 100, 80) + frame[by : by + 2, bx : bx + 25] = (255, 140, 100) + frame[by + 48 : by + 50, bx : bx + 25] = (255, 140, 100) + + for i, (bw, bh) in enumerate(dec_sizes): + x = dec_x + y = y_offset + (3 - i) * 130 + if n_blocks > 4 + i + 1: + frame[y : y + bh, x : x + bw] = (80, 200, 120) + frame[y : y + 2, x : x + bw] = (120, 230, 150) + frame[y + bh - 2 : y + bh, x : x + bw] = (120, 230, 150) + frame[y : y + bh, x : x + 2] = (120, 230, 150) + frame[y : y + bh, x + bw - 2 : x + bw] = (120, 230, 150) + if i < len(dec_sizes) - 1: + ax = x + bw // 2 + ay = y - 30 + frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170) + + skip_threshold = 5 + _draw_unet_skips(frame, enc_positions, n_blocks, dec_x, skip_threshold) + + return frame + + +def _unet_demo() -> list[CompositeVideoClip]: + """Animate U-Net encoder-decoder architecture.""" + dur = STEP_DUR + 1 + unet_clip = VideoClip(_make_unet_frame, duration=dur).with_fps(FPS) + labels = [ + ("U-Net: Encoder-Decoder + Skip Connections", 28, "#FFE082", FONT_B, (80, 20)), + ( + "Niebieski = Encoder (↓ zmniejsza rozdzielczość, wyciąga cechy)", + 16, + "#64B5F6", + FONT_R, + (80, 65), + ), + ( + "Zielony = Decoder (↑ zwiększa rozdzielczość, odtwarza mapę)", + 16, + "#A5D6A7", + FONT_R, + (80, 90), + ), + ( + "Żółte przerywane = Skip connections (przenoszą detale z encodera)", + 16, + "#FFE082", + FONT_R, + (80, 115), + ), + ( + "Czerwony = Bottleneck (najgłębsza warstwa, max abstrakcja)", + 16, + "#EF9A9A", + FONT_R, + (450, 570), + ), + ( + "Kształt U: encoder ↓ decoder ↑, mosty pośrodku", + 18, + "white", + FONT_R, + (80, 640), + ), + ( + "Concatenation: skip łączy kanały (więcej informacji niż dodawanie)", + 16, + "#78909C", + FONT_R, + (80, 670), + ), + ] + return [_compose_slide(unet_clip, labels, dur)] + + +# ── FCN Architecture ───────────────────────────────────────────── +def _draw_pipeline_blocks( + frame: np.ndarray, + blocks: list[tuple[tuple[int, int], tuple[int, int], tuple[int, int, int]]], + n_visible: int, + arrow_limit: int, +) -> None: + """Draw coloured blocks with connecting arrows.""" + for i, ((bx, by), (bw, bh), color) in enumerate(blocks): + if i < n_visible: + frame[by : by + bh, bx : bx + bw] = color + frame[by : by + 2, bx : bx + bw] = tuple(min(c + 50, 255) for c in color) + frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( + min(c + 50, 255) for c in color + ) + if i < arrow_limit: + ax = bx + bw + 3 + ay = by + bh // 2 + frame[ay - 1 : ay + 2, ax : ax + 12] = (150, 150, 170) + + +def _draw_red_cross( + frame: np.ndarray, + x_start: int, + width: int, + top_y: int, + height: int, +) -> None: + """Draw a red X across the given rectangle.""" + for d in range(-2, 3): + for step in range(height): + x1 = x_start + int(step * width / height) + y1 = top_y + step + d + if 0 <= y1 < H and 0 <= x1 < W: + frame[y1, x1] = (255, 80, 80) + y2 = top_y + height - step + d + if 0 <= y2 < H and 0 <= x1 < W: + frame[y2, x1] = (255, 80, 80) + + +def _make_fcn_frame(t: float) -> np.ndarray: + """Render a single FCN comparison frame.""" + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + progress = min(t / (STEP_DUR * 0.8), 1.0) + + top_y = 140 + blocks_classic = [ + ((80, top_y), (70, 50), (70, 130, 200)), + ((170, top_y), (50, 40), (50, 100, 160)), + ((240, top_y), (60, 50), (70, 130, 200)), + ((320, top_y), (40, 35), (50, 100, 160)), + ((385, top_y), (55, 50), (160, 80, 60)), + ((465, top_y), (55, 50), (180, 60, 60)), + ((545, top_y), (80, 50), (200, 80, 80)), + ] + n_top = min(int(progress * 7) + 1, 7) + arrow_limit = 6 + _draw_pipeline_blocks(frame, blocks_classic, n_top, arrow_limit) + + cross_phase = 0.6 + if progress > cross_phase: + _draw_red_cross(frame, 385, 135, top_y, 50) + + bot_y = 380 + blocks_fcn = [ + ((80, bot_y), (70, 50), (70, 130, 200)), + ((170, bot_y), (50, 40), (50, 100, 160)), + ((240, bot_y), (60, 50), (70, 130, 200)), + ((320, bot_y), (40, 35), (50, 100, 160)), + ((385, bot_y), (70, 50), (80, 200, 120)), + ((480, bot_y), (75, 50), (200, 160, 80)), + ((580, bot_y), (80, 50), (100, 200, 100)), + ] + fcn_phase = 0.4 + if progress > fcn_phase: + n_bot = min(int((progress - fcn_phase) / 0.6 * 7) + 1, 7) + _draw_pipeline_blocks(frame, blocks_fcn, n_bot, arrow_limit) + + return frame + + +def _fcn_demo() -> list[CompositeVideoClip]: + """Animate FCN step-by-step: FC → Conv 1x1 transformation.""" + dur = STEP_DUR + 1 + fcn_clip = VideoClip(_make_fcn_frame, duration=dur).with_fps(FPS) + labels = [ + ("FCN: Fully Convolutional Network (2015)", 26, "#FFE082", FONT_B, (80, 20)), + ("KROK 1: Zamień FC → Conv 1x1", 18, "#A5D6A7", FONT_R, (80, 60)), + ("Klasyczny CNN:", 16, "#EF9A9A", FONT_B, (80, 105)), + ("Conv", 11, "white", FONT_R, (92, 148)), + ("Pool", 11, "white", FONT_R, (178, 148)), + ("Conv", 11, "white", FONT_R, (250, 148)), + ("Pool", 11, "white", FONT_R, (325, 148)), + ("Flatten", 11, "#EF9A9A", FONT_R, (390, 148)), + ("FC", 11, "#EF9A9A", FONT_R, (480, 148)), + ("1 label", 11, "#EF9A9A", FONT_R, (555, 148)), + ("FCN:", 16, "#A5D6A7", FONT_B, (80, 350)), + ("Conv", 11, "white", FONT_R, (92, 388)), + ("Pool", 11, "white", FONT_R, (178, 388)), + ("Conv", 11, "white", FONT_R, (250, 388)), + ("Pool", 11, "white", FONT_R, (325, 388)), + ("Conv1x1", 11, "#A5D6A7", FONT_R, (390, 388)), + ("Upsample", 11, "#FFE082", FONT_R, (486, 388)), + ("Mapa", 11, "#A5D6A7", FONT_R, (595, 388)), + ( + "FC: spłaszcza 3D→1D, wymusza stały rozmiar → 1 etykieta", + 16, + "#EF9A9A", + FONT_R, + (80, 250), + ), + ( + "Conv1x1: działa per piksel x kanały → DOWOLNY rozmiar → mapa klasy", + 16, + "#A5D6A7", + FONT_R, + (80, 460), + ), + ( + "KROK 2: Skip connections — łączą wczesne detale z późną abstrakcją", + 17, + "#64B5F6", + FONT_R, + (80, 510), + ), + ( + "Wczesne warstwy = krawędzie, tekstury | Późne = koncepty obiektów", + 15, + "#78909C", + FONT_R, + (80, 545), + ), + ( + "FCN = PIERWSZA sieć end-to-end do segmentacji per-piksel!", + 18, + "white", + FONT_R, + (80, 590), + ), + ( + "Mnemonik: FC → Conv 1x1 = otwieramy bramkę dla DOWOLNEGO rozmiaru", + 16, + "#FFE082", + FONT_R, + (80, 640), + ), + ] + slides = [_compose_slide(fcn_clip, labels, dur)] + + # Slide 2: FCN skip connections step by step + skip_lines = [ + ("FCN: Skip Connections — krok po kroku", 26, "#FFE082", FONT_B, (80, 30)), + ( + "1. Encoder zmniejsza: 224→112→56→28→14 (pooling)", + 18, + "#64B5F6", + FONT_R, + (100, 100), + ), + ( + " Każdy pooling traci detale przestrzenne (dokładne krawędzie)", + 15, + "#78909C", + FONT_R, + (100, 135), + ), + ( + "2. Decoder powiększa: 14→28→56→112→224 (upsample/deconv)", + 18, + "#A5D6A7", + FONT_R, + (100, 190), + ), + ( + " Upsample ODGADUJE piksele — rozmyty wynik!", + 15, + "#78909C", + FONT_R, + (100, 225), + ), + ( + "3. Skip connections: dodaj cechy z encodera do decodera", + 18, + "#FFE082", + FONT_R, + (100, 280), + ), + ( + " Wczesne cechy = GDZIE (precyzyjne krawędzie)", + 15, + "#64B5F6", + FONT_R, + (100, 315), + ), + ( + " Późne cechy = CO (abstrakcyjne koncepty)", + 15, + "#A5D6A7", + FONT_R, + (100, 345), + ), + ( + " Skip = daje decoderowi OBA → ostry wynik!", + 15, + "#FFE082", + FONT_R, + (100, 375), + ), + ( + "Warianty: FCN-32s (brak skip, rozmyty) → FCN-16s → FCN-8s (najlepszy)", + 16, + "#B0BEC5", + FONT_R, + (80, 440), + ), + ( + "FCN-32s: upsample 32x naraz → ROZMYTE granice", + 15, + "#EF9A9A", + FONT_R, + (100, 485), + ), + ( + "FCN-16s: skip z pool4 + upsample 16x → lepiej", + 15, + "#FFE082", + FONT_R, + (100, 520), + ), + ( + "FCN-8s: skip z pool3+pool4 + upsample 8x → OSTRE granice!", + 15, + "#A5D6A7", + FONT_R, + (100, 555), + ), + ( + "Im więcej skip connections → tym więcej " + "detali z encodera → ostrzejszy wynik", + 17, + "white", + FONT_R, + (80, 620), + ), + ] + slides.append(_text_slide(skip_lines, duration=STEP_DUR + 1)) + + return slides diff --git a/python_pkg/praca_magisterska_video/_q24_classical.py b/python_pkg/praca_magisterska_video/_q24_classical.py new file mode 100644 index 0000000..7c08a82 --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q24_classical.py @@ -0,0 +1,332 @@ +"""Classical detection methods: detection concept, HOG+SVM, Viola-Jones.""" + +from __future__ import annotations + +from _q24_common import ( + BG_COLOR, + FONT_B, + FONT_R, + FPS, + STEP_DUR, + H, + W, + _tc, +) +from moviepy import CompositeVideoClip, VideoClip +from moviepy.video.fx import FadeIn, FadeOut +import numpy as np + + +# ── Detection concept ──────────────────────────────────────────── +def _detection_concept() -> list[CompositeVideoClip]: + """Show what detection is: bounding box + class + confidence.""" + slides = [] + + def make_det_frame(_t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + # Draw a "scene" with colored rectangles representing objects + # Sky background area + frame[140:500, 100:700] = (40, 50, 70) + + # "Car" object + frame[350:430, 150:320] = (180, 60, 60) + # "Person" object + frame[280:440, 450:520] = (60, 120, 180) + # "Tree" object + frame[200:400, 580:650] = (40, 130, 50) + + # Bounding boxes (with labels drawn as colored borders) + # Car bbox + for thickness in range(3): + t = thickness + frame[348 - t : 432 + t, 148 - t : 148 - t + 2] = (255, 80, 80) + frame[348 - t : 432 + t, 322 + t - 2 : 322 + t] = (255, 80, 80) + frame[348 - t : 348 - t + 2, 148 - t : 322 + t] = (255, 80, 80) + frame[432 + t - 2 : 432 + t, 148 - t : 322 + t] = (255, 80, 80) + + # Person bbox + for thickness in range(3): + t = thickness + frame[278 - t : 442 + t, 448 - t : 448 - t + 2] = (80, 180, 255) + frame[278 - t : 442 + t, 522 + t - 2 : 522 + t] = (80, 180, 255) + frame[278 - t : 278 - t + 2, 448 - t : 522 + t] = (80, 180, 255) + frame[442 + t - 2 : 442 + t, 448 - t : 522 + t] = (80, 180, 255) + + # Tree bbox + for thickness in range(3): + t = thickness + frame[198 - t : 402 + t, 578 - t : 578 - t + 2] = (80, 220, 100) + frame[198 - t : 402 + t, 652 + t - 2 : 652 + t] = (80, 220, 100) + frame[198 - t : 198 - t + 2, 578 - t : 652 + t] = (80, 220, 100) + frame[402 + t - 2 : 402 + t, 578 - t : 652 + t] = (80, 220, 100) + + # Comparison boxes on right side + # Classification + frame[180:260, 800:1150] = (35, 45, 65) + # Detection + frame[290:370, 800:1150] = (35, 45, 65) + # Segmentation + frame[400:480, 800:1150] = (35, 45, 65) + + return frame + + det_clip = VideoClip(make_det_frame, duration=STEP_DUR).with_fps(FPS) + text_clips: list[VideoClip] = [det_clip] + labels = [ + ("Detekcja obiektów — co to jest?", 28, "#FFE082", FONT_B, (100, 20)), + ("Wynik: (klasa, bounding box, pewność)", 20, "#B0BEC5", FONT_R, (100, 65)), + ("samochód 95%", 14, "#EF9A9A", FONT_B, (150, 340)), + ("osoba 88%", 14, "#64B5F6", FONT_B, (450, 268)), + ("drzewo 72%", 14, "#A5D6A7", FONT_B, (580, 188)), + ("Klasyfikacja: cały obraz → 1 etykieta", 15, "#78909C", FONT_R, (810, 210)), + ("Detekcja: bbox + klasa + pewność", 15, "#FFE082", FONT_R, (810, 320)), + ("Segmentacja: maska per piksel", 15, "#78909C", FONT_R, (810, 430)), + ("← granulacja rośnie →", 14, "#90CAF9", FONT_R, (810, 520)), + ] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR) + .with_position(pos) + ) + text_clips.append(tc) + + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides + + +# ── HOG + SVM pipeline ─────────────────────────────────────────── +def _hog_svm_demo() -> list[CompositeVideoClip]: + """Animate HOG feature computation and SVM classification.""" + slides = [] + + def make_hog_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + progress = min(t / (STEP_DUR * 0.8), 1.0) + + # Pipeline stages as boxes with arrows + stages = [ + ("Gradient", (80, 250), (130, 80), (100, 160, 220)), + ("Orientacja", (260, 250), (130, 80), (80, 180, 140)), + ("Komórki 8x8", (440, 250), (130, 80), (200, 160, 80)), + ("Bloki 2x2", (620, 250), (130, 80), (200, 120, 60)), + ("Normalizacja", (800, 250), (130, 80), (180, 100, 80)), + ("SVM", (980, 250), (130, 80), (220, 80, 80)), + ] + + n_active = int(progress * len(stages)) + 1 + + for i, (_label, (sx, sy), (sw, sh), color) in enumerate(stages): + if i < n_active: + frame[sy : sy + sh, sx : sx + sw] = color + # Border + frame[sy : sy + 2, sx : sx + sw] = tuple( + min(c + 60, 255) for c in color + ) + frame[sy + sh - 2 : sy + sh, sx : sx + sw] = tuple( + min(c + 60, 255) for c in color + ) + + # Arrow to next + if i < len(stages) - 1: + ax = sx + sw + 5 + ay = sy + sh // 2 + frame[ay - 1 : ay + 2, ax : ax + 20] = (150, 150, 170) + + # Show gradient computation example at bottom + gradient_phase = 0.2 + if progress > gradient_phase: + # Mini pixel grid showing gradient computation + gx, gy = 100, 430 + pixels = [50, 50, 200] + for idx, val in enumerate(pixels): + x = gx + idx * 50 + frame[gy : gy + 40, x : x + 40] = (val, val, val) + + return frame + + hog_clip = VideoClip(make_hog_frame, duration=STEP_DUR).with_fps(FPS) + text_clips: list[VideoClip] = [hog_clip] + labels = [ + ("HOG + SVM — pipeline detekcji pieszych", 28, "#FFE082", FONT_B, (80, 20)), + ( + "Mnemonik: GOKBN = Gradienty→Orientacja→Komórki→Bloki→Normalizacja", + 16, + "#A5D6A7", + FONT_R, + (80, 65), + ), + ("Gradient: siła i kierunek zmiany jasności", 14, "#64B5F6", FONT_R, (80, 95)), + ( + "Histogram: 9 binów (0°-180°, co 20°) per komórka 8x8", + 14, + "#78909C", + FONT_R, + (80, 120), + ), + ( + "[50][50][200] → Gx = 200-50 = 150 = silna krawędź!", + 16, + "#EF9A9A", + FONT_R, + (80, 490), + ), + ( + "Wektor HOG (3780 cech) → SVM: pieszy (+1) / tło (-1)", + 16, + "white", + FONT_R, + (80, 540), + ), + ( + "Sliding window 64x128 przesuwa się po obrazie → NMS → wynik", + 16, + "#90CAF9", + FONT_R, + (80, 580), + ), + ( + "SVM = LINIA MAKSYMALNEGO ODDECHU (max margines, support vectors)", + 16, + "#FFE082", + FONT_R, + (80, 620), + ), + ] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR) + .with_position(pos) + ) + text_clips.append(tc) + + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides + + +# ── Viola-Jones ─────────────────────────────────────────────────── +def _viola_jones_demo() -> list[CompositeVideoClip]: + """Animate Viola-Jones cascade concept.""" + slides = [] + + def make_cascade_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + progress = min(t / (STEP_DUR * 0.8), 1.0) + + # Draw cascade "funnel" — stages filtering out non-faces + stages = 5 + start_width = 1000 + start_count = 10000 + x_center = W // 2 + + for i in range(stages): + stage_progress = min(progress * stages - i, 1.0) + if stage_progress <= 0: + break + + width = int(start_width * (1 - i * 0.18)) + int(start_count * (0.3**i)) + y = 150 + i * 100 + h_box = 60 + + # Stage box + x1 = x_center - width // 2 + frame[y : y + h_box, x1 : x1 + width] = ( + 50 + i * 10, + 60 + i * 10, + 80 + i * 10, + ) + # Border + frame[y : y + 2, x1 : x1 + width] = (100 + i * 20, 130 + i * 15, 200) + frame[y + h_box - 2 : y + h_box, x1 : x1 + width] = ( + 100 + i * 20, + 130 + i * 15, + 200, + ) + + # Arrow down to next + if i < stages - 1: + frame[y + h_box + 5 : y + h_box + 25, x_center - 1 : x_center + 2] = ( + 150, + 150, + 170, + ) + + # Red "rejected" arrows on sides + if i > 0: + # Left reject arrow + rx = x1 - 30 + ry = y + h_box // 2 + frame[ry - 1 : ry + 2, rx : rx + 25] = (200, 80, 80) + + return frame + + cascade_clip = VideoClip(make_cascade_frame, duration=STEP_DUR).with_fps(FPS) + text_clips: list[VideoClip] = [cascade_clip] + labels = [ + ( + "Viola-Jones — kaskada klasyfikatorów (2001)", + 28, + "#FFE082", + FONT_B, + (80, 20), + ), + ( + "3 innowacje: HIC = Haar + Integral Image + Cascade", + 20, + "#B0BEC5", + FONT_R, + (80, 65), + ), + ("Etap 1: 2 cechy Haar", 14, "#64B5F6", FONT_R, (170, 170)), + ("Etap 2: 10 cech", 14, "#64B5F6", FONT_R, (210, 270)), + ("Etap 3: 25 cech", 14, "#64B5F6", FONT_R, (240, 370)), + ("Etap 4: 50 cech", 14, "#64B5F6", FONT_R, (260, 470)), + ("→ TWARZ!", 16, "#A5D6A7", FONT_B, (590, 560)), + ( + "SITO: 99% okien odpada w pierwszych 3 etapach → REAL-TIME!", + 16, + "#EF9A9A", + FONT_R, + (80, 620), + ), + ( + "Haar: kontrast jasna/ciemna | Integral Image: " + "suma prostokąta O(1) = 4 odczyty", + 14, + "#78909C", + FONT_R, + (80, 655), + ), + ("odrzucone →", 12, "#EF9A9A", FONT_R, (60, 275)), + ("odrzucone →", 12, "#EF9A9A", FONT_R, (60, 375)), + ] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR) + .with_position(pos) + ) + text_clips.append(tc) + + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides diff --git a/python_pkg/praca_magisterska_video/_q24_common.py b/python_pkg/praca_magisterska_video/_q24_common.py new file mode 100644 index 0000000..1f6d0ac --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q24_common.py @@ -0,0 +1,115 @@ +"""Shared constants and helpers for Q24 object detection visualization.""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path + +import numpy as np + +os.environ["FFMPEG_BINARY"] = "/usr/bin/ffmpeg" + +from moviepy import ( + ColorClip, + CompositeVideoClip, + TextClip, + VideoClip, +) +from moviepy.video.fx import FadeIn, FadeOut + +# ── Constants ───────────────────────────────────────────────────── +W, H = 1280, 720 +FPS = 24 +STEP_DUR = 7.0 +HEADER_DUR = 4.0 +FONT_B = "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf" +FONT_R = "/usr/share/fonts/TTF/DejaVuSans.ttf" +OUTPUT_DIR = Path(__file__).resolve().parent / "videos" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) +OUTPUT = str(OUTPUT_DIR / "q24_object_detection.mp4") + +BG_COLOR = (15, 20, 35) + +_logger = logging.getLogger(__name__) + +# Re-export numpy for sub-modules that need it alongside constants. +__all__ = [ + "BG_COLOR", + "FONT_B", + "FONT_R", + "FPS", + "HEADER_DUR", + "OUTPUT", + "OUTPUT_DIR", + "STEP_DUR", + "H", + "W", + "_logger", + "_make_header", + "_tc", + "_text_slide", + "np", +] + + +def _tc(**kwargs: object) -> TextClip: + """TextClip wrapper that adds enough bottom margin to prevent clipping.""" + fs = kwargs.get("font_size", 24) + m = int(fs) // 3 + 2 + kwargs["margin"] = (0, m) + return TextClip(**kwargs) + + +def _make_header( + title: str, subtitle: str, duration: float = HEADER_DUR +) -> CompositeVideoClip: + """Create a title/subtitle header slide.""" + bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration) + t = ( + _tc( + text=title, + font_size=48, + color="white", + font=FONT_B, + ) + .with_duration(duration) + .with_position(("center", 260)) + ) + s = ( + _tc( + text=subtitle, + font_size=24, + color="#90CAF9", + font=FONT_R, + ) + .with_duration(duration) + .with_position(("center", 340)) + ) + return CompositeVideoClip([bg, t, s], size=(W, H)).with_effects( + [FadeIn(0.5), FadeOut(0.5)] + ) + + +def _text_slide( + lines: list[tuple[str, int, str, str, tuple[str | int, str | int]]], + duration: float = STEP_DUR, +) -> CompositeVideoClip: + """Create a text-only slide from a list of (text, size, color, font, pos).""" + bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration) + clips: list[VideoClip] = [bg] + for text, font_size, color, font, pos in lines: + tc = ( + _tc( + text=text, + font_size=font_size, + color=color, + font=font, + ) + .with_duration(duration) + .with_position(pos) + ) + clips.append(tc) + return CompositeVideoClip(clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) diff --git a/python_pkg/praca_magisterska_video/_q24_nms_final.py b/python_pkg/praca_magisterska_video/_q24_nms_final.py new file mode 100644 index 0000000..5987566 --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q24_nms_final.py @@ -0,0 +1,239 @@ +"""NMS/IoU, detector-from-classifier, and methods comparison.""" + +from __future__ import annotations + +from _q24_common import ( + BG_COLOR, + FONT_B, + FONT_R, + FPS, + STEP_DUR, + H, + W, + _tc, + _text_slide, +) +from moviepy import ColorClip, CompositeVideoClip, VideoClip +from moviepy.video.fx import FadeIn, FadeOut +import numpy as np + + +# ── NMS + IoU ───────────────────────────────────────────────────── +def _nms_iou_demo() -> list[CompositeVideoClip]: + """Animate NMS and IoU concepts.""" + slides = [] + + def make_nms_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + progress = min(t / (STEP_DUR * 0.7), 1.0) + + # Draw overlapping bounding boxes + ox, oy = 100, 200 + obj_w, obj_h = 150, 120 + + # Multiple overlapping detections for same object + boxes = [ + (ox, oy, obj_w, obj_h, 0.95, (255, 80, 80)), # best + (ox + 15, oy - 10, obj_w + 10, obj_h + 5, 0.90, (200, 60, 60)), + (ox - 10, oy + 5, obj_w - 5, obj_h + 10, 0.85, (160, 50, 50)), + ] + # Different object far away + boxes.append((ox + 350, oy + 50, 100, 100, 0.40, (80, 180, 255))) + + for i, (bx, by, bw, bh, _conf, color) in enumerate(boxes): + dc = color + nms_phase = 0.4 + nms_limit = 3 + if progress > nms_phase and i > 0 and i < nms_limit: + # After NMS, these get removed (shown as faded/crossed) + dc = (60, 40, 40) + + for tt in range(2): + frame[by - tt : by + bh + tt, bx - tt : bx - tt + 2] = dc + frame[by - tt : by + bh + tt, bx + bw + tt - 2 : bx + bw + tt] = dc + frame[by - tt : by - tt + 2, bx - tt : bx + bw + tt] = dc + frame[by + bh + tt - 2 : by + bh + tt, bx - tt : bx + bw + tt] = dc + + # IoU visualization on right side + iou_x, iou_y = 700, 200 + # Box A + frame[iou_y : iou_y + 100, iou_x : iou_x + 100] = (80, 80, 200) + # Box B (overlapping) + frame[iou_y + 40 : iou_y + 140, iou_x + 40 : iou_x + 140] = (200, 80, 80) + # Intersection highlighted + frame[iou_y + 40 : iou_y + 100, iou_x + 40 : iou_x + 100] = (200, 150, 200) + + return frame + + nms_clip = VideoClip(make_nms_frame, duration=STEP_DUR).with_fps(FPS) + text_clips: list[VideoClip] = [nms_clip] + labels = [ + ("NMS (Non-Maximum Suppression) + IoU", 28, "#FFE082", FONT_B, (80, 20)), + ( + "NMS = Najlepszy Ma Się dobrze — zachowaj najlepszą, usuń duplikaty", + 18, + "#B0BEC5", + FONT_R, + (80, 65), + ), + ("conf=0.95 ✓", 14, "#A5D6A7", FONT_B, (100, 340)), + ("0.90 ✗ IoU>0.5", 13, "#EF9A9A", FONT_R, (100, 365)), + ("0.85 ✗ IoU>0.5", 13, "#EF9A9A", FONT_R, (100, 390)), + ("0.40 ✓ INNY obiekt", 13, "#64B5F6", FONT_R, (100, 420)), + ("IoU = Intersection over Union", 18, "#FFE082", FONT_B, (700, 160)), + ("IoU = pole(∩) / pole(AUB)", 16, "white", FONT_R, (700, 380)), + ("Fioletowy = intersection", 14, "#CE93D8", FONT_R, (700, 410)), + ("IoU > 0.5 → TEN SAM obiekt → usuń", 14, "#EF9A9A", FONT_R, (700, 440)), + ("IoU < 0.5 → INNY obiekt → zachowaj", 14, "#A5D6A7", FONT_R, (700, 470)), + ( + "DETR: jedyny detektor BEZ NMS (Hungarian matching zamiast tego)", + 14, + "#78909C", + FONT_R, + (80, 620), + ), + ] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR) + .with_position(pos) + ) + text_clips.append(tc) + + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides + + +# ── Detector from Classifier ───────────────────────────────────── +def _detector_from_classifier() -> list[CompositeVideoClip]: + """Show 3 approaches to building a detector from a classifier.""" + slides = [] + + approaches = [ + ( + "Podejście 1: Sliding Window (NAJWOLNIEJSZE)", + [ + ("Okno przesuwa się po obrazie w wielu skalach", "#B0BEC5"), + ("Każde okno → klasyfikator (np. ResNet) → klasa + pewność", "#B0BEC5"), + ("~18 000 okien x 10ms = ~3 minuty na obraz!", "#EF9A9A"), + ("Mnemonik: WYCINAJ i PYTAJ — jak wycinanie ciasteczek", "#FFE082"), + ], + "SRF", + ), + ( + "Podejście 2: Region Proposals (= R-CNN)", + [ + ("Selective Search → ~2000 inteligentnych regionów", "#B0BEC5"), + ("Każdy region → CNN → wektor cech → SVM klasyfikuje", "#B0BEC5"), + ("~2000 x 10ms = ~20 sec — 9x szybciej!", "#64B5F6"), + ( + "Mnemonik: INTELIGENTNE CIĘCIE — wytnij tylko tam gdzie wiśnie", + "#FFE082", + ), + ], + "SRF", + ), + ( + "Podejście 3: Fine-tune backbone (NAJLEPSZE)", + [ + ( + "Pretrained backbone (ResNet) → odetnij FC → dodaj detection head", + "#B0BEC5", + ), + ( + "Detection head = głowica klasyfikacji + głowica regresji bbox", + "#B0BEC5", + ), + ("~0.2 sec/obraz, najlepsza jakość (mAP ~42%)", "#A5D6A7"), + ("Mnemonik: PRZESZCZEP GŁOWY — ten sam silnik, nowa głowa", "#FFE082"), + ], + "SRF", + ), + ] + + for title, points, _mnem in approaches: + lines = [ + (title, 24, "#FFE082", FONT_B, (80, 140)), + ] + for i, (text, color) in enumerate(points): + lines.append((f"• {text}", 18, color, FONT_R, (100, 220 + i * 50))) + + lines.append( + ( + "Detektor z klasyfikatora: SRF = Sliding → Region → Fine-tune", + 16, + "#78909C", + FONT_R, + (80, 520), + ) + ) + lines.append( + ( + "= Szukaj Ręcznie, Finalnie optymalizuj!", + 16, + "#90CAF9", + FONT_R, + (80, 550), + ) + ) + + slides.append(_text_slide(lines, duration=STEP_DUR)) + + return slides + + +# ── Methods comparison ──────────────────────────────────────────── +def _methods_comparison() -> CompositeVideoClip: + """Create a comparison table of all detection methods.""" + bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(10.0) + title = ( + _tc( + text="Porównanie detektorów", + font_size=36, + color="white", + font=FONT_B, + ) + .with_duration(10.0) + .with_position(("center", 20)) + ) + + rows = [ + ("Model", "Rok", "Typ", "Szybkość", "Kluczowe"), + ("HOG+SVM", "2005", "Klasyczny", "~1 fps", "Gradient histogramy"), + ("Viola-Jones", "2001", "Klasyczny", "30+ fps", "Haar+Cascade"), + ("R-CNN", "2014", "Two-stage", "50 sec!", "CNN per region"), + ("Fast R-CNN", "2015", "Two-stage", "2 sec", "ROI Pooling"), + ("Faster R-CNN", "2015", "Two-stage", "5 fps", "RPN w sieci"), + ("YOLO", "2016", "One-stage", "45+ fps", "Siatka SxS"), + ("DETR", "2020", "Transformer", "~40 fps", "Bez NMS!"), + ] + + clips: list[VideoClip] = [bg, title] + for i, row in enumerate(rows): + y_pos = 75 + i * 72 + col_x = [40, 200, 280, 400, 530] + for j, cell in enumerate(row): + fs = 16 if i > 0 else 18 + color = "#64B5F6" if i == 0 else "#E0E0E0" + tc = ( + _tc( + text=cell, + font_size=fs, + color=color, + font=FONT_B if i == 0 else FONT_R, + ) + .with_duration(10.0) + .with_position((col_x[j], y_pos)) + ) + clips.append(tc) + + return CompositeVideoClip(clips, size=(W, H)).with_effects( + [FadeIn(0.5), FadeOut(0.5)] + ) diff --git a/python_pkg/praca_magisterska_video/_q24_rcnn.py b/python_pkg/praca_magisterska_video/_q24_rcnn.py new file mode 100644 index 0000000..34730c2 --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q24_rcnn.py @@ -0,0 +1,405 @@ +"""R-CNN family: evolution, detailed pipeline, ROI pooling.""" + +from __future__ import annotations + +from _q24_common import ( + BG_COLOR, + FONT_B, + FONT_R, + FPS, + STEP_DUR, + H, + W, + _tc, +) +from moviepy import CompositeVideoClip, VideoClip +from moviepy.video.fx import FadeIn, FadeOut +import numpy as np + + +# ── R-CNN Evolution ─────────────────────────────────────────────── +def _rcnn_evolution() -> list[CompositeVideoClip]: + """Animate R-CNN → Fast R-CNN → Faster R-CNN evolution.""" + slides = [] + + def make_evolution_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + progress = min(t / (STEP_DUR * 0.8), 1.0) + + # Three rows: R-CNN, Fast R-CNN, Faster R-CNN + models = [ + ( + "R-CNN (2014)", + 50, + [ + ("Selective\nSearch", (200, 150), (100, 50), (120, 100, 60)), + ("2000x\nCNN", (350, 150), (80, 50), (180, 60, 60)), + ("2000x\nSVM", (480, 150), (80, 50), (180, 60, 60)), + ("NMS", (610, 150), (60, 50), (100, 140, 100)), + ], + "50 sec/obraz!", + ), + ( + "Fast R-CNN (2015)", + 300, + [ + ("Selective\nSearch", (200, 150), (100, 50), (120, 100, 60)), + ("1x CNN\n(cały obraz)", (350, 150), (100, 50), (80, 140, 200)), + ("ROI Pool\n(2000)", (500, 150), (90, 50), (200, 160, 80)), + ("FC", (640, 150), (50, 50), (100, 140, 100)), + ], + "2 sec/obraz", + ), + ( + "Faster R-CNN (2015)", + 300, + [ + ("CNN\nbackbone", (200, 150), (90, 50), (80, 140, 200)), + ("RPN\n(~300)", (340, 150), (80, 50), (200, 120, 60)), + ("ROI Pool", (470, 150), (80, 50), (200, 160, 80)), + ("FC", (600, 150), (50, 50), (100, 140, 100)), + ], + "0.2 sec → 5 fps!", + ), + ] + + n_models = int(progress * 3) + 1 + + for mi, (_name, base_y, stages, _speed) in enumerate(models): + if mi >= n_models: + break + for _label, (bx, by_off), (bw, bh), color in stages: + by = base_y + by_off - 150 + frame[by : by + bh, bx : bx + bw] = color + frame[by : by + 2, bx : bx + bw] = tuple( + min(c + 50, 255) for c in color + ) + frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( + min(c + 50, 255) for c in color + ) + + # Arrows between stages + for si in range(len(stages) - 1): + sx = stages[si][1][0] + stages[si][2][0] + ex = stages[si + 1][1][0] + ay = base_y + 25 + frame[ay - 1 : ay + 2, sx + 3 : ex - 3] = (150, 150, 170) + + return frame + + evo_clip = VideoClip(make_evolution_frame, duration=STEP_DUR + 1).with_fps(FPS) + text_clips: list[VideoClip] = [evo_clip] + labels = [ + ("Ewolucja R-CNN — CORAZ MNIEJ MARNOWANIA", 28, "#FFE082", FONT_B, (80, 20)), + ("R-CNN (2014)", 20, "#EF9A9A", FONT_B, (50, 80)), + ("50 sec/obraz (2000x forward pass!)", 14, "#EF9A9A", FONT_R, (720, 100)), + ("Fast R-CNN (2015)", 20, "#64B5F6", FONT_B, (50, 330)), + ("2 sec/obraz (CNN raz + ROI Pool)", 14, "#64B5F6", FONT_R, (720, 350)), + ("Faster R-CNN (2015)", 20, "#A5D6A7", FONT_B, (50, 580)), + ("0.2 sec → 5 fps (RPN w sieci!)", 14, "#A5D6A7", FONT_R, (720, 600)), + ( + "Kluczowe innowacje: ROI Pooling → stały rozmiar " + "| RPN → propozycje w sieci", + 14, + "#78909C", + FONT_R, + (80, 660), + ), + ] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR + 1) + .with_position(pos) + ) + text_clips.append(tc) + + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides + + +# ── R-CNN Detailed Pipeline ────────────────────────────────────── +def _rcnn_detailed() -> list[CompositeVideoClip]: + """Animate R-CNN step-by-step pipeline in detail.""" + slides = [] + + # Slide 1: R-CNN pipeline step by step + def make_rcnn_pipeline(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + progress = min(t / (STEP_DUR * 0.8), 1.0) + + # Step boxes arranged vertically with arrows + steps = [ + ((80, 130), (200, 55), (120, 100, 60), "1. Selective Search"), + ((80, 230), (200, 55), (180, 60, 60), "2. Wytnij 2000 regionów"), + ((80, 330), (200, 55), (70, 130, 200), "3. CNN per region"), + ((80, 430), (200, 55), (200, 100, 80), "4. SVM klasyfikuje"), + ((80, 530), (200, 55), (100, 180, 100), "5. Bbox regresja + NMS"), + ] + n_steps = min(int(progress * 5) + 1, 5) + for i, ((bx, by), (bw, bh), color, _lbl) in enumerate(steps): + if i < n_steps: + frame[by : by + bh, bx : bx + bw] = color + frame[by : by + 2, bx : bx + bw] = tuple( + min(c + 50, 255) for c in color + ) + frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( + min(c + 50, 255) for c in color + ) + # Arrow down + arrow_limit = 4 + if i < arrow_limit: + ax = bx + bw // 2 + ay = by + bh + 5 + frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170) + + # Illustration: many overlapping regions from Selective Search + overlay_phase = 0.2 + if progress > overlay_phase: + rng_local = np.random.default_rng(42) + n_boxes = min(int((progress - 0.2) * 15), 8) + for i in range(n_boxes): + rx = 500 + rng_local.integers(-30, 100) + ry = 200 + rng_local.integers(-20, 120) + rw = 60 + rng_local.integers(0, 80) + rh = 50 + rng_local.integers(0, 70) + c = (80 + i * 15, 100 + i * 10, 60 + i * 20) + for tt in range(2): + frame[ry - tt : ry + rh + tt, rx - tt : rx - tt + 2] = c + frame[ry - tt : ry + rh + tt, rx + rw + tt - 2 : rx + rw + tt] = c + frame[ry - tt : ry - tt + 2, rx - tt : rx + rw + tt] = c + frame[ry + rh + tt - 2 : ry + rh + tt, rx - tt : rx + rw + tt] = c + + return frame + + rcnn_clip = VideoClip(make_rcnn_pipeline, duration=STEP_DUR + 1).with_fps(FPS) + dur = STEP_DUR + 1 + labels = [ + ("R-CNN: krok po kroku (2014, Girshick)", 26, "#FFE082", FONT_B, (80, 20)), + ("Pipeline detekcji two-stage", 16, "#B0BEC5", FONT_R, (80, 60)), + ("Selective Search", 11, "white", FONT_R, (105, 145)), + ("2000 regionów", 11, "white", FONT_R, (105, 245)), + ("CNN per region", 11, "white", FONT_R, (105, 345)), + ("SVM klasyfikuje", 11, "white", FONT_R, (105, 445)), + ("Regresja + NMS", 11, "white", FONT_R, (105, 545)), + ("~2000 propozycji regionów", 14, "#78909C", FONT_R, (500, 155)), + ("(inteligentne łączenie", 13, "#78909C", FONT_R, (500, 180)), + ("podobnych fragmentów)", 13, "#78909C", FONT_R, (500, 200)), + ("Problem: 2000 x CNN forward pass", 16, "#EF9A9A", FONT_R, (400, 400)), + ("= 50 SEKUND na obraz!", 18, "#EF9A9A", FONT_B, (400, 430)), + ("CNN liczy cechy per region OSOBNO", 14, "#EF9A9A", FONT_R, (400, 470)), + ( + "→ regiony się nakładają → obliczenia się powtarzają!", + 14, + "#EF9A9A", + FONT_R, + (400, 495), + ), + ( + "Rozwiązanie: CNN raz na cały obraz → Fast R-CNN →", + 16, + "#A5D6A7", + FONT_R, + (80, 620), + ), + ] + text_clips: list[VideoClip] = [rcnn_clip] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(dur) + .with_position(pos) + ) + text_clips.append(tc) + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + + return slides + + +# ── ROI Pooling ────────────────────────────────────────────────── + + +def _draw_roi_pool_grid(frame: np.ndarray) -> None: + """Draw the 3x3 ROI pool grid with max-pooled feature values.""" + out_x, out_y = 400, 220 + out_cell = 50 + out_n = 3 + roi_r1, roi_c1 = 2, 1 + roi_r2, roi_c2 = 6, 5 + roi_h = roi_r2 - roi_r1 + roi_w = roi_c2 - roi_c1 + for r in range(out_n): + for c in range(out_n): + x = out_x + c * out_cell + y = out_y + r * out_cell + + # Compute the max from corresponding region + src_r1 = roi_r1 + r * roi_h // out_n + src_r2 = roi_r1 + (r + 1) * roi_h // out_n + src_c1 = roi_c1 + c * roi_w // out_n + src_c2 = roi_c1 + (c + 1) * roi_w // out_n + max_val = 0 + for sr in range(src_r1, src_r2): + for sc in range(src_c1, src_c2): + v = 30 + ((sr * 7 + sc * 13 + 42) % 40) + max_val = max(max_val, v) + + frame[y : y + out_cell - 2, x : x + out_cell - 2] = ( + max_val, + max_val + 20, + max_val + 40, + ) + frame[y : y + 2, x : x + out_cell - 2] = (80, 200, 120) + frame[y + out_cell - 4 : y + out_cell - 2, x : x + out_cell - 2] = ( + 80, + 200, + 120, + ) + + +def _make_roi_frame(t: float) -> np.ndarray: + """Render a single frame for the ROI pooling animation.""" + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + progress = min(t / (STEP_DUR * 0.7), 1.0) + + # Left: feature map with ROI highlighted + fm_x, fm_y = 60, 180 + fm_cell = 30 + fm_grid = 8 + for r in range(fm_grid): + for c in range(fm_grid): + x = fm_x + c * fm_cell + y = fm_y + r * fm_cell + # Random-looking feature values + val = 30 + ((r * 7 + c * 13 + 42) % 40) + frame[y : y + fm_cell - 1, x : x + fm_cell - 1] = ( + val, + val + 10, + val + 20, + ) + + # ROI region highlighted + roi_r1, roi_c1 = 2, 1 + roi_r2, roi_c2 = 6, 5 + for tt in range(3): + ry1 = fm_y + roi_r1 * fm_cell - tt + ry2 = fm_y + roi_r2 * fm_cell + tt + rx1 = fm_x + roi_c1 * fm_cell - tt + rx2 = fm_x + roi_c2 * fm_cell + tt + frame[ry1:ry2, rx1 : rx1 + 2] = (255, 200, 50) + frame[ry1:ry2, rx2 - 2 : rx2] = (255, 200, 50) + frame[ry1 : ry1 + 2, rx1:rx2] = (255, 200, 50) + frame[ry2 - 2 : ry2, rx1:rx2] = (255, 200, 50) + + # Arrow + arrow_phase = 0.3 + if progress > arrow_phase: + frame[300:303, 310:380] = (150, 150, 170) + + # Middle: ROI divided into 3x3 grid (output_size) + grid_phase = 0.3 + if progress > grid_phase: + _draw_roi_pool_grid(frame) + + # Arrow to FC + fc_phase = 0.6 + if progress > fc_phase: + frame[300:303, 560:630] = (150, 150, 170) + # FC box + frame[270:340, 650:730] = (200, 100, 80) + frame[270:272, 650:730] = (240, 140, 120) + frame[338:340, 650:730] = (240, 140, 120) + + return frame + + +def _roi_pooling_demo() -> list[CompositeVideoClip]: + """Animate ROI Pooling: key Fast R-CNN innovation.""" + slides = [] + + roi_clip = VideoClip(_make_roi_frame, duration=STEP_DUR + 1).with_fps(FPS) + dur = STEP_DUR + 1 + labels = [ + ("ROI Pooling: kluczowa innowacja Fast R-CNN", 26, "#FFE082", FONT_B, (80, 20)), + ( + "KROK 1: CNN raz na CAŁY obraz → feature mapa", + 17, + "#64B5F6", + FONT_R, + (80, 60), + ), + ( + "KROK 2: Wytnij ROI z feature mapy (nie z obrazu!)", + 17, + "#FFE082", + FONT_R, + (80, 90), + ), + ( + "KROK 3: Siatkuj ROI na 3x3 → max pool per komórka → stały rozmiar", + 17, + "#A5D6A7", + FONT_R, + (80, 120), + ), + ("Feature mapa", 14, "#64B5F6", FONT_B, (60, 160)), + ("ROI (żółta ramka)", 13, "#FFE082", FONT_R, (60, 440)), + ("ROI Pool 3x3", 14, "#A5D6A7", FONT_B, (400, 195)), + ("(max z komórki)", 13, "#78909C", FONT_R, (400, 380)), + ("FC", 14, "white", FONT_B, (670, 280)), + ( + "Problem: ROI mają RÓŻNE rozmiary, FC wymaga STAŁEGO", + 15, + "#B0BEC5", + FONT_R, + (80, 500), + ), + ( + "ROI Pooling: dzieli ROI na siatkę, max pool → STAŁY rozmiar!", + 16, + "white", + FONT_R, + (80, 535), + ), + ( + "Fast R-CNN: CNN raz → 1 feature mapa → " + "ROI Pool 2000 regionów → 25x szybciej!", + 16, + "#A5D6A7", + FONT_R, + (80, 580), + ), + ( + "(R-CNN: 2000x CNN = 50s | Fast R-CNN: 1xCNN + ROI Pool = 2s)", + 15, + "#EF9A9A", + FONT_R, + (80, 620), + ), + ] + text_clips: list[VideoClip] = [roi_clip] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(dur) + .with_position(pos) + ) + text_clips.append(tc) + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides diff --git a/python_pkg/praca_magisterska_video/_q24_rpn_yolo.py b/python_pkg/praca_magisterska_video/_q24_rpn_yolo.py new file mode 100644 index 0000000..f7cf98f --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q24_rpn_yolo.py @@ -0,0 +1,383 @@ +"""RPN anchor boxes and YOLO grid detection.""" + +from __future__ import annotations + +from _q24_common import ( + BG_COLOR, + FONT_B, + FONT_R, + FPS, + STEP_DUR, + H, + W, + _tc, + _text_slide, +) +from moviepy import CompositeVideoClip, VideoClip +from moviepy.video.fx import FadeIn, FadeOut +import numpy as np + + +# ── RPN + Anchor Boxes ─────────────────────────────────────────── +def _rpn_anchors_demo() -> list[CompositeVideoClip]: + """Animate RPN and anchor boxes: Faster R-CNN innovation.""" + slides = [] + + # Slide 1: Anchor boxes concept + def make_anchors_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + progress = min(t / (STEP_DUR * 0.7), 1.0) + + # Draw feature map grid point with multiple anchors + cx, cy = 350, 360 # center point on feature map + + # Draw a "feature map" grid background + cell = 60 + for r in range(-3, 4): + for c in range(-3, 4): + x = cx + c * cell - cell // 2 + y = cy + r * cell - cell // 2 + frame[y : y + cell - 1, x : x + cell - 1] = (30, 35, 48) + + # Center point highlighted + frame[cy - 5 : cy + 5, cx - 5 : cx + 5] = (255, 200, 50) + + # Draw anchors around center: 3 sizes x 3 ratios = 9 + anchor_specs = [ + (30, 30, (200, 80, 80)), # small 1:1 + (20, 40, (200, 60, 60)), # small 1:2 + (40, 20, (180, 60, 60)), # small 2:1 + (60, 60, (80, 200, 80)), # medium 1:1 + (40, 80, (60, 180, 60)), # medium 1:2 + (80, 40, (60, 160, 60)), # medium 2:1 + (90, 90, (80, 80, 200)), # large 1:1 + (60, 120, (60, 60, 180)), # large 1:2 + (120, 60, (60, 60, 160)), # large 2:1 + ] + n_anchors = min(int(progress * 9) + 1, 9) + for i in range(n_anchors): + hw, hh, color = anchor_specs[i] + x1 = max(0, cx - hw) + y1 = max(0, cy - hh) + x2 = min(W - 1, cx + hw) + y2 = min(H - 1, cy + hh) + for tt in range(2): + frame[y1 - tt : y2 + tt, x1 - tt : x1 - tt + 2] = color + frame[y1 - tt : y2 + tt, x2 + tt - 2 : x2 + tt] = color + frame[y1 - tt : y1 - tt + 2, x1 - tt : x2 + tt] = color + frame[y2 + tt - 2 : y2 + tt, x1 - tt : x2 + tt] = color + + return frame + + anch_clip = VideoClip(make_anchors_frame, duration=STEP_DUR + 1).with_fps(FPS) + dur = STEP_DUR + 1 + labels = [ + ("Anchor Boxes + RPN (Faster R-CNN)", 26, "#FFE082", FONT_B, (80, 20)), + ( + "KROK 1: Anchory = predefiniowane kształty w każdej pozycji", + 17, + "#A5D6A7", + FONT_R, + (80, 60), + ), + ( + "3 rozmiary x 3 proporcje = 9 anchorów per punkt", + 16, + "#B0BEC5", + FONT_R, + (80, 90), + ), + ("Małe (1:1, 1:2, 2:1)", 14, "#EF9A9A", FONT_R, (750, 170)), + ("Średnie (1:1, 1:2, 2:1)", 14, "#A5D6A7", FONT_R, (750, 210)), + ("Duże (1:1, 1:2, 2:1)", 14, "#64B5F6", FONT_R, (750, 250)), + ("Żółty punkt = pozycja", 14, "#FFE082", FONT_R, (750, 310)), + ("na feature mapie", 14, "#FFE082", FONT_R, (750, 335)), + ("Sieć NIE predykuje bbox od zera!", 16, "white", FONT_R, (80, 530)), + ( + "Predykuje OFFSET od najbliższego anchora: (Δx, Δy, Δw, Δh)", + 16, + "#FFE082", + FONT_R, + (80, 565), + ), + ( + "+ P(obiekt) = 'czy w tym anchorze jest coś?'", + 16, + "#A5D6A7", + FONT_R, + (80, 600), + ), + ( + "Mnemonik: Anchor = KOTWICA — sieć dopasowuje bbox do kotwicy", + 15, + "#78909C", + FONT_R, + (80, 645), + ), + ] + text_clips: list[VideoClip] = [anch_clip] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(dur) + .with_position(pos) + ) + text_clips.append(tc) + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + + # Slide 2: RPN step by step + rpn_lines = [ + ( + "RPN: Region Proposal Network — krok po kroku", + 24, + "#FFE082", + FONT_B, + (80, 30), + ), + ( + "Zastępuje Selective Search SIECIĄ NEURONOWĄ (end-to-end!)", + 17, + "#B0BEC5", + FONT_R, + (80, 85), + ), + ("", 10, "white", FONT_R, (80, 110)), + ( + "1. Backbone (ResNet) przetwarza obraz → feature mapa [40x60x256]", + 16, + "#64B5F6", + FONT_R, + (100, 140), + ), + ( + "2. Filtr 3x3 przesuwa się po feature mapie", + 16, + "#A5D6A7", + FONT_R, + (100, 180), + ), + ( + "3. W KAŻDEJ pozycji (x,y) rozważ k=9 anchorów:", + 16, + "#FFE082", + FONT_R, + (100, 220), + ), + (" → P(obiekt) — 'czy tu jest coś?'", 15, "white", FONT_R, (120, 255)), + (" → (Δx, Δy, Δw, Δh) — poprawka pozycji", 15, "white", FONT_R, (120, 285)), + ( + "4. 40x60 pozycji x 9 anchorów = 21 600 kandydatów!", + 16, + "#EF9A9A", + FONT_R, + (100, 325), + ), + ( + "5. Weź ~300 z najwyższym P(obiekt) → ROI Pool → FC", + 16, + "#A5D6A7", + FONT_R, + (100, 365), + ), + ("", 10, "white", FONT_R, (100, 395)), + ("Porównanie generowania propozycji:", 17, "white", FONT_B, (80, 420)), + ( + " Selective Search: ~2000 regionów, osobny algorytm, ~2 sec", + 15, + "#EF9A9A", + FONT_R, + (100, 460), + ), + ( + " RPN: ~300 regionów, W SIECI, ~10 ms → 200x szybciej!", + 15, + "#A5D6A7", + FONT_R, + (100, 495), + ), + ("", 10, "white", FONT_R, (100, 520)), + ( + "Faster R-CNN = Backbone + RPN + ROI Pool + FC — WSZYSTKO end-to-end", + 17, + "#FFE082", + FONT_R, + (80, 545), + ), + ( + "→ 5 fps (0.2 sec/obraz) vs R-CNN 50 sec = 250x szybciej!", + 17, + "#A5D6A7", + FONT_R, + (80, 585), + ), + ( + "Wciąż two-stage: (1) RPN generuje propozycje, (2) FC klasyfikuje", + 15, + "#78909C", + FONT_R, + (80, 630), + ), + ] + slides.append(_text_slide(rpn_lines, duration=STEP_DUR + 1)) + + return slides + + +# ── YOLO ────────────────────────────────────────────────────────── +def _yolo_demo() -> list[CompositeVideoClip]: + """Animate YOLO grid detection concept.""" + slides = [] + + def make_yolo_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + + progress = min(t / (STEP_DUR * 0.7), 1.0) + + # Draw image with grid overlay + img_x, img_y = 100, 140 + img_size = 420 + grid_n = 7 + + # Background "image" + frame[img_y : img_y + img_size, img_x : img_x + img_size] = (50, 55, 70) + + # Objects in the image + frame[img_y + 80 : img_y + 200, img_x + 50 : img_x + 180] = ( + 180, + 60, + 60, + ) # "car" + frame[img_y + 150 : img_y + 350, img_x + 250 : img_x + 330] = ( + 60, + 120, + 180, + ) # "person" + + # Grid lines + cell = img_size // grid_n + for i in range(grid_n + 1): + # Vertical + x = img_x + i * cell + frame[img_y : img_y + img_size, x : x + 1] = (100, 100, 120) + # Horizontal + y = img_y + i * cell + frame[y : y + 1, img_x : img_x + img_size] = (100, 100, 120) + + # Highlight cells containing object centers + car_phase = 0.3 + if progress > car_phase: + # Car center ~ cell (1, 1) + cx, cy = 1, 2 + hx = img_x + cx * cell + hy = img_y + cy * cell + frame[hy : hy + cell, hx : hx + cell] = np.clip( + frame[hy : hy + cell, hx : hx + cell].astype(int) + 40, 0, 255 + ).astype(np.uint8) + + person_phase = 0.5 + if progress > person_phase: + # Person center ~ cell (4, 4) + cx, cy = 4, 4 + hx = img_x + cx * cell + hy = img_y + cy * cell + frame[hy : hy + cell, hx : hx + cell] = np.clip( + frame[hy : hy + cell, hx : hx + cell].astype(int) + 40, 0, 255 + ).astype(np.uint8) + + # Bounding boxes predictions from cells + bbox_phase = 0.6 + if progress > bbox_phase: + # Car bbox + for tt in range(2): + frame[ + img_y + 78 - tt : img_y + 202 + tt, + img_x + 48 - tt : img_x + 48 - tt + 2, + ] = (255, 80, 80) + frame[ + img_y + 78 - tt : img_y + 202 + tt, + img_x + 182 + tt - 2 : img_x + 182 + tt, + ] = (255, 80, 80) + frame[ + img_y + 78 - tt : img_y + 78 - tt + 2, + img_x + 48 - tt : img_x + 182 + tt, + ] = (255, 80, 80) + frame[ + img_y + 202 + tt - 2 : img_y + 202 + tt, + img_x + 48 - tt : img_x + 182 + tt, + ] = (255, 80, 80) + + # Person bbox + for tt in range(2): + frame[ + img_y + 148 - tt : img_y + 352 + tt, + img_x + 248 - tt : img_x + 248 - tt + 2, + ] = (80, 180, 255) + frame[ + img_y + 148 - tt : img_y + 352 + tt, + img_x + 332 + tt - 2 : img_x + 332 + tt, + ] = (80, 180, 255) + frame[ + img_y + 148 - tt : img_y + 148 - tt + 2, + img_x + 248 - tt : img_x + 332 + tt, + ] = (80, 180, 255) + frame[ + img_y + 352 + tt - 2 : img_y + 352 + tt, + img_x + 248 - tt : img_x + 332 + tt, + ] = (80, 180, 255) + + return frame + + yolo_clip = VideoClip(make_yolo_frame, duration=STEP_DUR).with_fps(FPS) + text_clips: list[VideoClip] = [yolo_clip] + labels = [ + ("YOLO — You Only Look Once", 28, "#FFE082", FONT_B, (80, 20)), + ( + "Jednoetapowy detektor: siatka SxS → wszystkie detekcje naraz!", + 18, + "#B0BEC5", + FONT_R, + (80, 65), + ), + ("Siatka 7x7 = 49 komórek", 16, "#64B5F6", FONT_R, (600, 180)), + ("Każda komórka predykuje:", 16, "white", FONT_R, (600, 220)), + (" • B bbox (x, y, w, h, conf)", 14, "#B0BEC5", FONT_R, (600, 255)), + (" • C klas (prawdopodobieństwa)", 14, "#B0BEC5", FONT_R, (600, 285)), + ("Komórka odpowiada za obiekt", 14, "#A5D6A7", FONT_R, (600, 325)), + ("którego ŚRODEK w niej wpada", 14, "#A5D6A7", FONT_R, (600, 350)), + ("45-155 fps! (vs 5 fps Faster R-CNN)", 18, "#EF9A9A", FONT_B, (600, 400)), + ( + "Jedno przejście przez sieć → WSZYSTKIE detekcje naraz → NMS → wynik", + 14, + "#78909C", + FONT_R, + (80, 620), + ), + ( + "Two-stage (R-CNN): propozycje+klasyfikacja " + "| One-stage (YOLO): bez propozycji!", + 14, + "#90CAF9", + FONT_R, + (80, 655), + ), + ] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(STEP_DUR) + .with_position(pos) + ) + text_clips.append(tc) + + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + return slides diff --git a/python_pkg/praca_magisterska_video/_q24_yolo_arch_detr.py b/python_pkg/praca_magisterska_video/_q24_yolo_arch_detr.py new file mode 100644 index 0000000..193bd2b --- /dev/null +++ b/python_pkg/praca_magisterska_video/_q24_yolo_arch_detr.py @@ -0,0 +1,459 @@ +"""YOLO architecture detail and DETR transformer detection.""" + +from __future__ import annotations + +from _q24_common import ( + BG_COLOR, + FONT_B, + FONT_R, + FPS, + STEP_DUR, + H, + W, + _tc, + _text_slide, +) +from moviepy import CompositeVideoClip, VideoClip +from moviepy.video.fx import FadeIn, FadeOut +import numpy as np + + +# ── YOLO Architecture Detail ────────────────────────────────────── +def _yolo_architecture() -> list[CompositeVideoClip]: + """Show YOLO architecture: backbone → head, output tensor.""" + slides = [] + + # Slide 1: YOLO architecture breakdown + def make_yolo_arch(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + progress = min(t / (STEP_DUR * 0.7), 1.0) + + # Pipeline: Image → Backbone → Neck → Head → SxSx(B*5+C) tensor + blocks = [ + ((60, 280), (100, 80), (50, 70, 90), "Obraz"), + ((200, 280), (100, 80), (70, 130, 200), "Backbone"), + ((340, 280), (100, 80), (200, 160, 80), "Neck"), + ((480, 280), (100, 80), (200, 100, 60), "Head"), + ((620, 280), (160, 80), (80, 200, 120), "SxSx(B*5+C)"), + ] + n_blocks = min(int(progress * 5) + 1, 5) + for i, ((bx, by), (bw, bh), color, _lbl) in enumerate(blocks): + if i < n_blocks: + frame[by : by + bh, bx : bx + bw] = color + frame[by : by + 2, bx : bx + bw] = tuple( + min(c + 50, 255) for c in color + ) + frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( + min(c + 50, 255) for c in color + ) + arrow_limit = 4 + if i < arrow_limit: + ax = bx + bw + 5 + ay = by + bh // 2 + frame[ay - 1 : ay + 2, ax : ax + 25] = (150, 150, 170) + + # Output tensor breakdown (right side) + tensor_phase = 0.6 + if progress > tensor_phase: + # Show SxS grid + gx, gy = 850, 180 + gs = 120 + gn = 4 # simplified from 7 + gc = gs // gn + for r in range(gn): + for c in range(gn): + x = gx + c * gc + y = gy + r * gc + frame[y : y + gc - 1, x : x + gc - 1] = (40, 50, 65) + # Highlight one cell + frame[gy + gc : gy + 2 * gc - 1, gx + gc : gx + 2 * gc - 1] = ( + 80, + 200, + 120, + ) + + return frame + + arch_clip = VideoClip(make_yolo_arch, duration=STEP_DUR + 1).with_fps(FPS) + dur = STEP_DUR + 1 + labels = [ + ("YOLO: Architektura — krok po kroku", 26, "#FFE082", FONT_B, (80, 20)), + ( + "One-stage: JEDEN forward pass → WSZYSTKIE detekcje naraz", + 17, + "#B0BEC5", + FONT_R, + (80, 60), + ), + ("Obraz", 13, "white", FONT_R, (85, 295)), + ("Backbone", 13, "white", FONT_R, (215, 295)), + ("(ResNet/", 11, "#78909C", FONT_R, (210, 370)), + ("Darknet)", 11, "#78909C", FONT_R, (210, 390)), + ("Neck", 13, "white", FONT_R, (365, 295)), + ("(FPN/", 11, "#78909C", FONT_R, (360, 370)), + ("PANet)", 11, "#78909C", FONT_R, (360, 390)), + ("Head", 13, "white", FONT_R, (505, 295)), + ("(conv)", 11, "#78909C", FONT_R, (500, 370)), + ("Tensor wyjścia", 13, "#A5D6A7", FONT_R, (640, 295)), + ("Każda komórka SxS predykuje:", 15, "#FFE082", FONT_R, (830, 320)), + (" B bbox x (x,y,w,h,conf)", 13, "#B0BEC5", FONT_R, (830, 350)), + (" + C klas (prob.)", 13, "#B0BEC5", FONT_R, (830, 375)), + ("= SxSx(Bx5+C) tensor", 13, "#A5D6A7", FONT_R, (830, 400)), + ("Np. 7x7x(2x5+20) = 7x7x30", 13, "#78909C", FONT_R, (830, 430)), + ( + "Two-stage (R-CNN): (1) propozycje → (2) klasyfikacja = 2 przejścia", + 15, + "#EF9A9A", + FONT_R, + (80, 470), + ), + ( + "One-stage (YOLO): siatka → predykcja all-in-one = 1 przejście!", + 15, + "#A5D6A7", + FONT_R, + (80, 505), + ), + ( + "Ewolucja YOLO: v1(2016)→v3→v5→v8(2023, anchor-free, SOTA)", + 16, + "#FFE082", + FONT_R, + (80, 555), + ), + ( + "SSD (2016): multi-scale feature maps → lepsza detekcja małych obiektów", + 15, + "#64B5F6", + FONT_R, + (80, 595), + ), + ( + "FPN: łączy wczesne warstwy (małe obiekty) + późne (duże obiekty)", + 15, + "#78909C", + FONT_R, + (80, 630), + ), + ] + text_clips: list[VideoClip] = [arch_clip] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(dur) + .with_position(pos) + ) + text_clips.append(tc) + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + + return slides + + +# ── DETR ────────────────────────────────────────────────────────── +def _detr_demo() -> list[CompositeVideoClip]: + """Animate DETR: transformer detection, object queries, no NMS.""" + slides = [] + + # Slide 1: DETR pipeline + def make_detr_frame(t: float) -> np.ndarray: + frame = np.zeros((H, W, 3), dtype=np.uint8) + frame[:] = BG_COLOR + progress = min(t / (STEP_DUR * 0.7), 1.0) + + # DETR pipeline: Image → Backbone → Encoder → Decoder → N predictions + blocks = [ + ((50, 260), (80, 60), (50, 70, 90)), + ((170, 260), (90, 60), (70, 130, 200)), + ((300, 260), (110, 60), (200, 120, 60)), + ((450, 260), (110, 60), (200, 80, 160)), + ((600, 260), (120, 60), (80, 200, 120)), + ] + n_blocks = min(int(progress * 5) + 1, 5) + for i, ((bx, by), (bw, bh), color) in enumerate(blocks): + if i < n_blocks: + frame[by : by + bh, bx : bx + bw] = color + frame[by : by + 2, bx : bx + bw] = tuple( + min(c + 50, 255) for c in color + ) + frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( + min(c + 50, 255) for c in color + ) + arrow_limit = 4 + if i < arrow_limit: + ax = bx + bw + 5 + ay = by + bh // 2 + frame[ay - 1 : ay + 2, ax : ax + 25] = (150, 150, 170) + + # Object queries illustration (right side) + query_phase = 0.5 + if progress > query_phase: + qx, qy = 800, 140 + for i in range(6): + y = qy + i * 50 + w = 130 + active_limit = 3 + active = i < active_limit + color = (80, 180, 120) if active else (60, 50, 50) + frame[y : y + 35, qx : qx + w] = color + frame[y : y + 1, qx : qx + w] = tuple(min(c + 40, 255) for c in color) + + # Arrow from decoder to queries + frame[285:288, 723:798] = (150, 150, 170) + + return frame + + detr_clip = VideoClip(make_detr_frame, duration=STEP_DUR + 1).with_fps(FPS) + dur = STEP_DUR + 1 + labels = [ + ("DETR: DEtection TRansformer (2020)", 26, "#FFE082", FONT_B, (80, 20)), + ( + "Radykalnie prostszy pipeline: BEZ anchorów, BEZ NMS!", + 17, + "#B0BEC5", + FONT_R, + (80, 60), + ), + ("Obraz", 12, "white", FONT_R, (65, 275)), + ("Backbone", 12, "white", FONT_R, (185, 275)), + ("Transformer", 12, "white", FONT_R, (310, 275)), + ("Encoder", 12, "white", FONT_R, (325, 295)), + ("Transformer", 12, "white", FONT_R, (460, 275)), + ("Decoder", 12, "white", FONT_R, (478, 295)), + ("N predykcji", 12, "white", FONT_R, (615, 275)), + ("Object Queries:", 14, "#FFE082", FONT_B, (800, 115)), + ("samochód 95%", 11, "white", FONT_R, (810, 148)), + ("pies 88%", 11, "white", FONT_R, (810, 198)), + ("rower 72%", 11, "white", FONT_R, (810, 248)), + ("brak", 11, "#78909C", FONT_R, (810, 298)), + ("brak", 11, "#78909C", FONT_R, (810, 348)), + ("brak", 11, "#78909C", FONT_R, (810, 398)), + ("100 wyuczonych queries", 13, "#FFE082", FONT_R, (800, 440)), + ("→ każdy 'szuka' obiektu", 13, "#FFE082", FONT_R, (800, 465)), + ] + text_clips: list[VideoClip] = [detr_clip] + for text, fs, color, font, pos in labels: + tc = ( + _tc(text=text, font_size=fs, color=color, font=font) + .with_duration(dur) + .with_position(pos) + ) + text_clips.append(tc) + slides.append( + CompositeVideoClip(text_clips, size=(W, H)).with_effects( + [FadeIn(0.3), FadeOut(0.3)] + ) + ) + + # Slide 2: Why no NMS + Hungarian matching + detr_details = [ + ("DETR: Dlaczego bez NMS? — krok po kroku", 24, "#FFE082", FONT_B, (80, 30)), + ( + "Problem NMS: duplikaty detekcji → ręcznie usuwaj post-hoc", + 16, + "#EF9A9A", + FONT_R, + (80, 90), + ), + ( + "DETR rozwiązanie: Hungarian matching (dopasowanie węgierskie)", + 17, + "#A5D6A7", + FONT_R, + (80, 130), + ), + ("", 10, "white", FONT_R, (80, 155)), + ("Jak to działa podczas TRENINGU:", 17, "white", FONT_B, (80, 180)), + ( + " 1. Sieć daje N=100 predykcji (queries)", + 15, + "#64B5F6", + FONT_R, + (100, 220), + ), + ( + " 2. Na obrazie jest np. 5 obiektów (ground truth)", + 15, + "#64B5F6", + FONT_R, + (100, 255), + ), + ( + " 3. Hungarian matching: optymalne dopasowanie 1:1", + 15, + "#FFE082", + FONT_R, + (100, 290), + ), + ( + " → query_1 ↔ gt_samochód (najlepsze dopasowanie)", + 14, + "#A5D6A7", + FONT_R, + (120, 325), + ), + (" → query_7 ↔ gt_pies", 14, "#A5D6A7", FONT_R, (120, 355)), + (" → query_3 ↔ gt_rower", 14, "#A5D6A7", FONT_R, (120, 385)), + ( + " → pozostałe 97 queries ↔ klasa 'brak obiektu'", + 14, + "#78909C", + FONT_R, + (120, 415), + ), + ( + " 4. Każdy obiekt ma DOKŁADNIE 1 predykcję → BRAK duplikatów!", + 15, + "#A5D6A7", + FONT_R, + (100, 455), + ), + ("", 10, "white", FONT_R, (100, 475)), + ( + "Self-attention w encoderze: cechy obrazu 'rozmawiają' ze sobą", + 15, + "#64B5F6", + FONT_R, + (80, 500), + ), + ( + "Cross-attention w decoderze: queries 'pytają' cechy obrazu", + 15, + "#CE93D8", + FONT_R, + (80, 535), + ), + ( + "→ query 'rozumie' który fragment obrazu to 'jego' obiekt", + 15, + "#FFE082", + FONT_R, + (80, 570), + ), + ( + "DETR = Detekcja Eliminująca Trikowe Redundancje (NMS, anchory)", + 16, + "#FFE082", + FONT_R, + (80, 620), + ), + ( + "Wada: wolniejszy trening (O(n²) attention) | Zaleta: prostszy pipeline!", + 15, + "#78909C", + FONT_R, + (80, 660), + ), + ] + slides.append(_text_slide(detr_details, duration=STEP_DUR + 1)) + + # Slide 3: Two-stage vs One-stage vs Transformer summary + summary_lines = [ + ( + "Podsumowanie: Two-stage vs One-stage vs Transformer", + 22, + "#FFE082", + FONT_B, + (80, 30), + ), + ("", 10, "white", FONT_R, (80, 55)), + ("TWO-STAGE (R-CNN family):", 18, "#EF9A9A", FONT_B, (80, 90)), + ( + " (1) Generuj propozycje → (2) Klasyfikuj per region", + 15, + "white", + FONT_R, + (100, 125), + ), + ( + " + Wysoka precyzja | - Wolniejsze (2 przejścia)", + 15, + "#78909C", + FONT_R, + (100, 155), + ), + ( + " R-CNN → Fast R-CNN → Faster R-CNN (0.2s)", + 15, + "#B0BEC5", + FONT_R, + (100, 185), + ), + ("", 10, "white", FONT_R, (80, 210)), + ("ONE-STAGE (YOLO, SSD):", 18, "#A5D6A7", FONT_B, (80, 240)), + ( + " Siatka → predykcja all-in-one (1 przejście)", + 15, + "white", + FONT_R, + (100, 275), + ), + ( + " + Bardzo szybkie (45-155 fps) | - Historycznie mniej precyzyjne", + 15, + "#78909C", + FONT_R, + (100, 305), + ), + ( + " YOLOv8 (2023): anchor-free, dorównuje two-stage!", + 15, + "#B0BEC5", + FONT_R, + (100, 335), + ), + ("", 10, "white", FONT_R, (80, 360)), + ("TRANSFORMER (DETR):", 18, "#CE93D8", FONT_B, (80, 390)), + ( + " Object queries + self-attention (globalny kontekst)", + 15, + "white", + FONT_R, + (100, 425), + ), + ( + " + Brak NMS/anchorów | - Wolniejszy trening (O(n²))", + 15, + "#78909C", + FONT_R, + (100, 455), + ), + ( + " Hungarian matching → 1:1 obiekt↔predykcja → brak duplikatów", + 15, + "#B0BEC5", + FONT_R, + (100, 485), + ), + ("", 10, "white", FONT_R, (80, 510)), + ( + "Trend: coraz prostsze pipeline, mniej ręcznych komponentów", + 17, + "white", + FONT_R, + (80, 540), + ), + ( + " R-CNN (SS+CNN+SVM+NMS) → YOLO " + "(backbone+head+NMS) → DETR (backbone+transformer)", + 14, + "#90CAF9", + FONT_R, + (80, 580), + ), + ( + "Metryki: mAP@0.5 (standard), mAP@0.5:0.95 (surowsza), " + "IoU do dopasowania", + 15, + "#78909C", + FONT_R, + (80, 630), + ), + ] + slides.append(_text_slide(summary_lines, duration=STEP_DUR + 1)) + + return slides diff --git a/python_pkg/praca_magisterska_video/generate_images/_pubsub_common.py b/python_pkg/praca_magisterska_video/generate_images/_pubsub_common.py new file mode 100644 index 0000000..caf4a7a --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_pubsub_common.py @@ -0,0 +1,235 @@ +"""Common drawing primitives and constants for Pub/Sub diagrams.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib as mpl + +mpl.use("Agg") + +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch +import matplotlib.pyplot as plt + +if TYPE_CHECKING: + from matplotlib.axes import Axes + +DPI = 300 +BG = "white" +LN = "black" +FS = 9 +FS_TITLE = 13 +FIG_W = 8.27 # A4 width in inches +OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") +Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + +GRAY1 = "#E8E8E8" +GRAY2 = "#D0D0D0" +GRAY3 = "#B8B8B8" +GRAY4 = "#F5F5F5" +GRAY5 = "#C0C0C0" + + +@dataclass(frozen=True) +class BoxStyle: + """Optional styling for boxes.""" + + fill: str = "white" + lw: float = 1.2 + fontsize: float = FS + fontweight: str = "normal" + ha: str = "center" + va: str = "center" + rounded: bool = True + + +@dataclass(frozen=True) +class ArrowCfg: + """Config for arrows.""" + + lw: float = 1.2 + style: str = "->" + color: str = LN + label: str = "" + label_offset: float = 0.15 + label_fs: float = 8 + + +@dataclass(frozen=True) +class DashedCfg: + """Config for dashed arrows.""" + + lw: float = 1.0 + color: str = LN + label: str = "" + label_offset: float = 0.15 + label_fs: float = 8 + + +def draw_box( + ax: Axes, + pos: tuple[float, float], + size: tuple[float, float], + text: str, + style: BoxStyle | None = None, +) -> None: + """Draw box.""" + s = style or BoxStyle() + x, y = pos + w, h = size + if s.rounded: + rect = FancyBboxPatch( + (x, y), + w, + h, + boxstyle="round,pad=0.05", + lw=s.lw, + edgecolor=LN, + facecolor=s.fill, + ) + else: + rect = mpatches.Rectangle( + (x, y), + w, + h, + lw=s.lw, + edgecolor=LN, + facecolor=s.fill, + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + h / 2, + text, + ha=s.ha, + va=s.va, + fontsize=s.fontsize, + fontweight=s.fontweight, + wrap=True, + ) + + +def draw_arrow( + ax: Axes, + start: tuple[float, float], + end: tuple[float, float], + cfg: ArrowCfg | None = None, +) -> None: + """Draw arrow.""" + c = cfg or ArrowCfg() + ax.annotate( + "", + xy=end, + xytext=start, + arrowprops={ + "arrowstyle": c.style, + "color": c.color, + "lw": c.lw, + }, + ) + if c.label: + mx = (start[0] + end[0]) / 2 + my = (start[1] + end[1]) / 2 + c.label_offset + ax.text( + mx, + my, + c.label, + ha="center", + va="bottom", + fontsize=c.label_fs, + color=c.color, + ) + + +def draw_dashed_arrow( + ax: Axes, + start: tuple[float, float], + end: tuple[float, float], + cfg: DashedCfg | None = None, +) -> None: + """Draw dashed arrow.""" + c = cfg or DashedCfg() + ax.annotate( + "", + xy=end, + xytext=start, + arrowprops={ + "arrowstyle": "->", + "color": c.color, + "lw": c.lw, + "linestyle": "dashed", + }, + ) + if c.label: + mx = (start[0] + end[0]) / 2 + my = (start[1] + end[1]) / 2 + c.label_offset + ax.text( + mx, + my, + c.label, + ha="center", + va="bottom", + fontsize=c.label_fs, + color=c.color, + ) + + +def draw_cross( + ax: Axes, + pos: tuple[float, float], + size: float = 0.15, + lw: float = 2.5, + color: str = "black", +) -> None: + """Draw cross.""" + x, y = pos + ax.plot( + [x - size, x + size], + [y - size, y + size], + color=color, + lw=lw, + ) + ax.plot( + [x - size, x + size], + [y + size, y - size], + color=color, + lw=lw, + ) + + +def draw_check( + ax: Axes, + pos: tuple[float, float], + size: float = 0.15, + lw: float = 2.5, + color: str = "black", +) -> None: + """Draw check.""" + x, y = pos + ax.plot( + [x - size, x - size * 0.2], + [y, y - size * 0.7], + color=color, + lw=lw, + ) + ax.plot( + [x - size * 0.2, x + size], + [y - size * 0.7, y + size * 0.5], + color=color, + lw=lw, + ) + + +def save(fig: plt.Figure, name: str) -> None: + """Save.""" + plt.tight_layout() + fig.savefig( + str(Path(OUTPUT_DIR) / name), + dpi=DPI, + bbox_inches="tight", + facecolor=BG, + ) + plt.close(fig) diff --git a/python_pkg/praca_magisterska_video/generate_images/_pubsub_qos.py b/python_pkg/praca_magisterska_video/generate_images/_pubsub_qos.py new file mode 100644 index 0000000..a68178f --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_pubsub_qos.py @@ -0,0 +1,430 @@ +"""QoS delivery guarantee diagrams: at-most-once, at-least-once, exactly-once.""" + +from __future__ import annotations + +from _pubsub_common import ( + FIG_W, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + LN, + ArrowCfg, + BoxStyle, + draw_arrow, + draw_box, + draw_check, + draw_cross, + draw_dashed_arrow, + save, +) +import matplotlib.pyplot as plt + + +# ============================================================ +# 5. At-most-once (QoS 0) +# ============================================================ +def draw_qos_at_most_once() -> None: + """Draw qos at most once.""" + fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 6) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "QoS: At-most-once" + " \u2014 \u201ewy\u015blij i zapomnij\u201d" + " (0 lub 1 dostarczenie)", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + px, bx, sx = 1.0, 4.8, 8.5 + pw, bw, sw = 2.0, 2.2, 2.0 + bh = 0.8 + bold10_g1 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold") + bold10_g2 = BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold") + draw_box(ax, (px, 5.0), (pw, bh), "Publisher", bold10_g1) + draw_box(ax, (bx, 5.0), (bw, bh), "Broker", bold10_g2) + draw_box( + ax, + (sx, 5.0), + (sw, bh), + "Subscriber", + bold10_g1, + ) + + for xc in [px + pw / 2, bx + bw / 2, sx + sw / 2]: + ax.plot( + [xc, xc], + [5.0, 1.2], + color=GRAY3, + lw=1, + linestyle=":", + ) + + # Scenario A: success + y = 4.3 + ax.text( + 0.2, + y + 0.15, + "Scenariusz A:", + fontsize=8.5, + fontweight="bold", + ) + msg9 = ArrowCfg(label="MSG", label_fs=9) + draw_arrow(ax, (px + pw / 2, y), (bx + bw / 2, y), msg9) + draw_arrow( + ax, + (bx + bw / 2, y - 0.6), + (sx + sw / 2, y - 0.6), + msg9, + ) + draw_check(ax, (sx + sw / 2 + 0.4, y - 0.6), size=0.18) + ax.text( + sx + sw / 2 + 0.7, + y - 0.6, + "OK", + fontsize=9, + fontweight="bold", + ) + + # Scenario B: lost + y = 2.6 + ax.text( + 0.2, + y + 0.15, + "Scenariusz B:", + fontsize=8.5, + fontweight="bold", + ) + draw_arrow(ax, (px + pw / 2, y), (bx + bw / 2, y), msg9) + draw_dashed_arrow(ax, (bx + bw / 2, y - 0.6), (7.5, y - 0.6)) + draw_cross(ax, (7.8, y - 0.6), size=0.2) + ax.text( + 8.2, + y - 0.55, + "UTRACONA", + fontsize=9, + fontweight="bold", + ) + ax.text( + 8.2, + y - 1.0, + "(brak retransmisji)", + fontsize=8, + style="italic", + ) + + ax.text( + 6.0, + 0.5, + "Brak ACK, brak retransmisji." + " Najszybszy. Use case:" + " logi, metryki, telemetria.", + ha="center", + va="center", + fontsize=9, + bbox={ + "boxstyle": "round,pad=0.4", + "facecolor": GRAY4, + "edgecolor": GRAY3, + }, + ) + + save(fig, "pubsub_qos_at_most_once.png") + + +# ============================================================ +# 6. At-least-once (QoS 1) +# ============================================================ +def draw_qos_at_least_once() -> None: + """Draw qos at least once.""" + fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.0)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 6.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "QoS: At-least-once" + " \u2014 \u201epowtarzaj a\u017c potwierdz\u0105\u201d" + " (\u22651 dostarczenie)", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + bx, bw = 3.5, 2.2 + sx, sw = 8.0, 2.2 + bh = 0.8 + draw_box( + ax, + (bx, 5.5), + (bw, bh), + "Broker", + BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"), + ) + draw_box( + ax, + (sx, 5.5), + (sw, bh), + "Subscriber", + BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold"), + ) + + for xc in [bx + bw / 2, sx + sw / 2]: + ax.plot( + [xc, xc], + [5.5, 0.8], + color=GRAY3, + lw=1, + linestyle=":", + ) + + # Step 1: send MSG + y1 = 4.8 + draw_arrow( + ax, + (bx + bw / 2, y1), + (sx + sw / 2, y1), + ArrowCfg(label="MSG #1", label_fs=9), + ) + draw_check(ax, (sx + sw + 0.2, y1), size=0.15) + ax.text(sx + sw + 0.5, y1, "odebrano", fontsize=8) + + # Step 2: ACK lost + y2 = 3.9 + draw_dashed_arrow( + ax, + (sx + sw / 2, y2), + (bx + bw + 1.2, y2), + ) + ax.text( + (bx + bw / 2 + sx + sw / 2) / 2, + y2 + 0.18, + "ACK", + fontsize=9, + ) + draw_cross(ax, (bx + bw + 0.8, y2), size=0.18) + ax.text( + bx + 0.3, + y2 - 0.35, + "ACK utracony!", + fontsize=8.5, + style="italic", + ) + + # Step 3: timeout -> retry + y3 = 2.9 + ax.text( + bx + bw / 2, + y3 + 0.45, + "timeout...", + fontsize=8.5, + style="italic", + ha="center", + ) + draw_arrow( + ax, + (bx + bw / 2, y3), + (sx + sw / 2, y3), + ArrowCfg(label="MSG #1 (retry)", label_fs=9), + ) + draw_check(ax, (sx + sw + 0.2, y3), size=0.15) + ax.text( + sx + sw + 0.5, + y3, + "odebrano\n(ponownie!)", + fontsize=8, + ) + + # Step 4: ACK ok + y4 = 2.0 + draw_arrow( + ax, + (sx + sw / 2, y4), + (bx + bw / 2, y4), + ArrowCfg(label="ACK", label_fs=9), + ) + draw_check(ax, (bx + bw / 2 - 0.5, y4), size=0.18) + + # Duplicate bracket + ax.annotate( + "", + xy=(sx + sw + 1.3, y1), + xytext=(sx + sw + 1.3, y3), + arrowprops={ + "arrowstyle": "<->", + "color": "black", + "lw": 1.2, + }, + ) + ax.text( + sx + sw + 1.6, + (y1 + y3) / 2, + "DUPLIKAT!\nSubscriber\notrzyma\u0142 2x", + fontsize=9, + ha="left", + va="center", + fontweight="bold", + bbox={ + "boxstyle": "round,pad=0.25", + "facecolor": GRAY4, + "edgecolor": GRAY3, + }, + ) + + ax.text( + 6.0, + 0.5, + "Broker czeka na ACK, retransmituje" + " po timeout. Mog\u0105 by\u0107 duplikaty!\n" + "Use case: zam\u00f3wienia, p\u0142atno\u015bci" + " (subscriber musi by\u0107 idempotentny).", + ha="center", + va="center", + fontsize=9, + bbox={ + "boxstyle": "round,pad=0.4", + "facecolor": GRAY4, + "edgecolor": GRAY3, + }, + ) + + save(fig, "pubsub_qos_at_least_once.png") + + +# ============================================================ +# 7. Exactly-once (QoS 2) +# ============================================================ +def draw_qos_exactly_once() -> None: + """Draw qos exactly once.""" + fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 7) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "QoS: Exactly-once \u2014 4-krokowy" + " handshake (dok\u0142adnie 1 dostarczenie)", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + bx, bw = 2.5, 2.2 + sx, sw = 7.5, 2.2 + bh = 0.8 + draw_box( + ax, + (bx, 6.0), + (bw, bh), + "Broker", + BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"), + ) + draw_box( + ax, + (sx, 6.0), + (sw, bh), + "Subscriber", + BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold"), + ) + + for xc in [bx + bw / 2, sx + sw / 2]: + ax.plot( + [xc, xc], + [6.0, 1.0], + color=GRAY3, + lw=1, + linestyle=":", + ) + + steps = [ + ( + 5.2, + "right", + "PUBLISH (msg_id=42)", + "Broker wysy\u0142a wiadomo\u015b\u0107", + ), + ( + 4.2, + "left", + "PUBREC (otrzyma\u0142em id=42)", + "Sub potwierdza odbi\u00f3r," " zapisuje id", + ), + ( + 3.2, + "right", + "PUBREL (mo\u017cesz przetworzy\u0107)", + "Broker zwalnia wiadomo\u015b\u0107", + ), + ( + 2.2, + "left", + "PUBCOMP (zako\u0144czone)", + "Sub potwierdza przetworzenie", + ), + ] + + for i, (y, direction, label, desc) in enumerate(steps): + ax.text( + bx + bw / 2 - 0.7, + y, + f"{i + 1}", + fontsize=9, + fontweight="bold", + ha="center", + va="center", + bbox={ + "boxstyle": "circle,pad=0.18", + "facecolor": GRAY3, + "edgecolor": LN, + }, + ) + + if direction == "right": + draw_arrow( + ax, + (bx + bw / 2, y), + (sx + sw / 2, y), + ArrowCfg(label=label, label_fs=9), + ) + else: + draw_arrow( + ax, + (sx + sw / 2, y), + (bx + bw / 2, y), + ArrowCfg(label=label, label_fs=9), + ) + + ax.text( + sx + sw + 0.3, + y, + desc, + fontsize=8, + ha="left", + va="center", + style="italic", + ) + + ax.text( + 6.0, + 0.6, + "Deduplikacja po msg_id." + " Sub nie przetwarza przed PUBREL.\n" + "Najkosztowniejszy (4 pakiety)." + " Use case: transakcje finansowe," + " krytyczne zdarzenia.", + ha="center", + va="center", + fontsize=9, + bbox={ + "boxstyle": "round,pad=0.4", + "facecolor": GRAY4, + "edgecolor": GRAY3, + }, + ) + + save(fig, "pubsub_qos_exactly_once.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_pubsub_topic_content.py b/python_pkg/praca_magisterska_video/generate_images/_pubsub_topic_content.py new file mode 100644 index 0000000..07d5ab6 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_pubsub_topic_content.py @@ -0,0 +1,239 @@ +"""Subscription-type diagrams: topic-based and content-based.""" + +from __future__ import annotations + +from _pubsub_common import ( + FIG_W, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + ArrowCfg, + BoxStyle, + DashedCfg, + draw_arrow, + draw_box, + draw_dashed_arrow, + save, +) +import matplotlib.pyplot as plt + + +# ============================================================ +# 1. Topic-based subscription +# ============================================================ +def draw_sub_topic() -> None: + """Draw sub topic.""" + fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.0)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 5.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Subskrypcja topic-based" " \u2014 routing po nazwie tematu", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + bold10 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold") + fs85 = BoxStyle(fill=GRAY1, fontsize=8.5) + + draw_box(ax, (0.2, 3.2), (2.4, 1.1), "Publisher", bold10) + draw_box( + ax, + (0.3, 1.8), + (2.2, 0.8), + 'topic: "orders"', + BoxStyle(fill=GRAY4, fontsize=8), + ) + draw_box( + ax, + (0.3, 0.7), + (2.2, 0.8), + 'topic: "payments"', + BoxStyle(fill=GRAY4, fontsize=8), + ) + + draw_box( + ax, + (4.2, 1.5), + (2.8, 2.2), + "BROKER\n\ntopic routing", + BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"), + ) + + draw_box( + ax, + (8.5, 3.8), + (3.0, 1.0), + 'Subscriber A\nsubskrybuje: "orders"', + fs85, + ) + draw_box( + ax, + (8.5, 2.2), + (3.0, 1.0), + 'Subscriber B\nsubskrybuje: "payments"', + fs85, + ) + draw_box( + ax, + (8.5, 0.6), + (3.0, 1.0), + 'Subscriber C\nsubskrybuje: "orders"', + fs85, + ) + + fs8 = ArrowCfg(label_fs=8) + draw_arrow(ax, (2.6, 2.2), (4.2, 2.8), fs8) + draw_arrow(ax, (2.6, 1.1), (4.2, 2.2), fs8) + + draw_arrow( + ax, + (7.0, 3.4), + (8.5, 4.2), + ArrowCfg(label='"orders"', label_fs=8), + ) + draw_arrow( + ax, + (7.0, 2.6), + (8.5, 2.7), + ArrowCfg(label='"payments"', label_fs=8), + ) + draw_arrow( + ax, + (7.0, 2.2), + (8.5, 1.2), + ArrowCfg(label='"orders"', label_fs=8), + ) + + ax.text( + 6.0, + 0.1, + "Subscriber deklaruje nazw\u0119 tematu." + " Broker kieruje wiadomo\u015bci\n" + "do WSZYSTKICH subscriber\u00f3w" + " danego tematu. Najprostszy model.", + ha="center", + va="bottom", + fontsize=8.5, + style="italic", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": GRAY4, + "edgecolor": GRAY3, + }, + ) + + save(fig, "pubsub_sub_topic.png") + + +# ============================================================ +# 2. Content-based subscription +# ============================================================ +def draw_sub_content() -> None: + """Draw sub content.""" + fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 6) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Subskrypcja content-based" + " \u2014 filtrowanie po tre\u015bci wiadomo\u015bci", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + bold10 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold") + draw_box(ax, (0.2, 3.5), (2.4, 1.1), "Publisher", bold10) + draw_box( + ax, + (0.2, 1.8), + (2.4, 1.2), + 'price: 150\ntype: "book"\ncategory: "IT"', + BoxStyle(fill=GRAY4, fontsize=8.5), + ) + + draw_box( + ax, + (4.0, 2.0), + (3.0, 2.5), + "BROKER\n\newaluuje filtry\n" "ka\u017cdego subscribera", + BoxStyle(fill=GRAY2, fontsize=9, fontweight="bold"), + ) + + fs9 = BoxStyle(fill=GRAY1, fontsize=9) + draw_box( + ax, + (8.5, 4.2), + (3.2, 1.0), + "Sub A\nfiltr: price > 100", + fs9, + ) + draw_box( + ax, + (8.5, 2.6), + (3.2, 1.0), + 'Sub B\nfiltr: type = "food"', + fs9, + ) + draw_box( + ax, + (8.5, 1.0), + (3.2, 1.0), + "Sub C\nfiltr: price < 50", + fs9, + ) + + draw_arrow(ax, (2.6, 2.4), (4.0, 3.0)) + draw_arrow( + ax, + (7.0, 4.0), + (8.5, 4.6), + ArrowCfg( + label="150 > 100 \u2713 dostarczono", + label_fs=8, + ), + ) + draw_dashed_arrow( + ax, + (7.0, 3.2), + (8.5, 3.1), + DashedCfg( + label='"book" \u2260 "food"' " \u2717 odrzucono", + label_fs=8, + ), + ) + draw_dashed_arrow( + ax, + (7.0, 2.5), + (8.5, 1.6), + DashedCfg( + label="150 < 50 \u2717 odrzucono", + label_fs=8, + ), + ) + + ax.text( + 6.0, + 0.2, + "Broker analizuje TRE\u015a\u0106 wiadomo\u015bci" + " i ewaluuje predykaty.\n" + "Bardziej elastyczny ni\u017c topic-based," + " ale wolniejszy (koszt ewaluacji).", + ha="center", + va="bottom", + fontsize=8.5, + style="italic", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": GRAY4, + "edgecolor": GRAY3, + }, + ) + + save(fig, "pubsub_sub_content.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_pubsub_type_hierarchical.py b/python_pkg/praca_magisterska_video/generate_images/_pubsub_type_hierarchical.py new file mode 100644 index 0000000..d619451 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_pubsub_type_hierarchical.py @@ -0,0 +1,279 @@ +"""Subscription-type diagrams: type-based and hierarchical.""" + +from __future__ import annotations + +from _pubsub_common import ( + FIG_W, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + ArrowCfg, + BoxStyle, + draw_arrow, + draw_box, + save, +) +import matplotlib.pyplot as plt + + +# ============================================================ +# 3. Type-based subscription +# ============================================================ +def draw_sub_type() -> None: + """Draw sub type.""" + fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.0)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 6.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Subskrypcja type-based" " \u2014 routing po typie (klasie) obiektu", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + bold10 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold") + draw_box(ax, (0.2, 4.2), (2.4, 1.1), "Publisher", bold10) + + fs9_g4 = BoxStyle(fill=GRAY4, fontsize=9) + draw_box( + ax, + (0.1, 2.8), + (2.6, 0.9), + "new OrderEvent()", + fs9_g4, + ) + draw_box( + ax, + (0.1, 1.5), + (2.6, 0.9), + "new PaymentEvent()", + fs9_g4, + ) + + draw_box( + ax, + (4.0, 2.3), + (3.0, 2.4), + "BROKER\n\nrouting po\ntypie klasy", + BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"), + ) + + fs9 = BoxStyle(fill=GRAY1, fontsize=9) + draw_box( + ax, + (8.5, 4.8), + (3.2, 1.0), + "Sub A\n\u2192 OrderEvent", + fs9, + ) + draw_box( + ax, + (8.5, 3.2), + (3.2, 1.0), + "Sub B\n\u2192 PaymentEvent", + fs9, + ) + draw_box( + ax, + (8.5, 1.6), + (3.2, 1.0), + "Sub C\n\u2192 Event (base)", + fs9, + ) + + draw_arrow(ax, (2.7, 3.2), (4.0, 3.8)) + draw_arrow(ax, (2.7, 2.0), (4.0, 3.0)) + draw_arrow( + ax, + (7.0, 4.3), + (8.5, 5.2), + ArrowCfg(label="OrderEvent", label_fs=8), + ) + draw_arrow( + ax, + (7.0, 3.5), + (8.5, 3.7), + ArrowCfg(label="PaymentEvent", label_fs=8), + ) + draw_arrow( + ax, + (7.0, 3.0), + (8.5, 2.2), + ArrowCfg(label="oba (dziedziczenie!)", label_fs=8), + ) + + hx, hy = 0.5, 0.0 + draw_box( + ax, + (hx + 2.0, hy + 0.2), + (1.8, 0.6), + "Event", + BoxStyle(fill=GRAY3, fontsize=8, fontweight="bold"), + ) + draw_box( + ax, + (hx, hy + 0.2), + (1.8, 0.6), + "OrderEvent", + BoxStyle(fill=GRAY4, fontsize=7.5), + ) + draw_box( + ax, + (hx + 4.0, hy + 0.2), + (2.0, 0.6), + "PaymentEvent", + BoxStyle(fill=GRAY4, fontsize=7.5), + ) + draw_arrow( + ax, + (hx + 2.9, hy + 0.2), + (hx + 0.9, hy + 0.2), + ArrowCfg( + lw=1.0, + label="extends", + label_offset=-0.3, + label_fs=7, + ), + ) + draw_arrow( + ax, + (hx + 2.9, hy + 0.2), + (hx + 5.0, hy + 0.2), + ArrowCfg( + lw=1.0, + label="extends", + label_offset=-0.3, + label_fs=7, + ), + ) + + ax.text( + 9.5, + 0.5, + "Sub C subskrybuje bazowy Event\n" "\u2192 otrzymuje WSZYSTKIE podtypy", + ha="center", + va="center", + fontsize=8.5, + style="italic", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": GRAY4, + "edgecolor": GRAY3, + }, + ) + + save(fig, "pubsub_sub_type.png") + + +# ============================================================ +# 4. Hierarchical / Wildcards subscription +# ============================================================ +def draw_sub_hierarchical() -> None: + """Draw sub hierarchical.""" + fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 7) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Subskrypcja hierarchiczna (wildcards)" " \u2014 wzorce temat\u00f3w", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + bold10 = BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold") + draw_box(ax, (4.5, 5.8), (2.4, 0.8), "sensors/", bold10) + + fs9_g3 = BoxStyle(fill=GRAY3, fontsize=9) + draw_box( + ax, + (1.5, 4.2), + (2.4, 0.8), + "temperature/", + fs9_g3, + ) + draw_box( + ax, + (7.5, 4.2), + (2.4, 0.8), + "humidity/", + fs9_g3, + ) + + fs85_g4 = BoxStyle(fill=GRAY4, fontsize=8.5) + draw_box(ax, (0.2, 2.8), (1.8, 0.7), "room1", fs85_g4) + draw_box(ax, (2.4, 2.8), (1.8, 0.7), "room2", fs85_g4) + draw_box(ax, (6.8, 2.8), (1.8, 0.7), "room1", fs85_g4) + draw_box(ax, (9.0, 2.8), (1.8, 0.7), "room2", fs85_g4) + + thin = ArrowCfg(lw=1.0) + draw_arrow(ax, (5.7, 5.8), (2.7, 5.0), thin) + draw_arrow(ax, (5.7, 5.8), (8.7, 5.0), thin) + draw_arrow(ax, (2.2, 4.2), (1.1, 3.5), thin) + draw_arrow(ax, (3.2, 4.2), (3.3, 3.5), thin) + draw_arrow(ax, (8.2, 4.2), (7.7, 3.5), thin) + draw_arrow(ax, (9.2, 4.2), (9.9, 3.5), thin) + + ax.text( + 1.1, + 2.4, + "sensors/temperature/room1", + fontsize=7, + ha="center", + fontfamily="monospace", + style="italic", + ) + ax.text( + 3.3, + 2.4, + "sensors/temperature/room2", + fontsize=7, + ha="center", + fontfamily="monospace", + style="italic", + ) + + ax.text( + 0.3, + 1.5, + "Wzorce subskrypcji (MQTT-style):", + fontsize=10, + fontweight="bold", + ) + + patterns = [ + ( + '"sensors/temperature/room1"', + "\u2192 TYLKO room1", + "(dok\u0142adne dopasowanie)", + ), + ( + '"sensors/temperature/*"', + "\u2192 room1, room2", + "( * = jeden poziom)", + ), + ( + '"sensors/#"', + "\u2192 WSZYSTKO", + "( # = dowolna g\u0142\u0119boko\u015b\u0107)", + ), + ] + for i, (pat, result, note) in enumerate(patterns): + yy = 0.9 - i * 0.55 + ax.text( + 0.5, + yy, + pat, + fontsize=9, + fontweight="bold", + fontfamily="monospace", + ) + ax.text(7.0, yy, result, fontsize=9, fontweight="bold") + ax.text(9.5, yy, note, fontsize=8, style="italic") + + save(fig, "pubsub_sub_hierarchical.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q20_architectures.py b/python_pkg/praca_magisterska_video/generate_images/_q20_architectures.py new file mode 100644 index 0000000..2f7cb65 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q20_architectures.py @@ -0,0 +1,421 @@ +"""Spark Streaming, Lambda/Kappa architecture, and exactly-once diagrams for Q20.""" + +from __future__ import annotations + +from _q20_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + draw_arrow, + draw_box, + draw_table, + plt, + save_fig, +) + + +# ============================================================ +# 12. Spark Streaming architecture +# ============================================================ +def gen_spark_streaming_arch() -> None: + """Gen spark streaming arch.""" + fig, ax = plt.subplots(figsize=(9, 5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Spark Streaming — architektura (micro-batch)", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Cluster border + draw_box(ax, 0.3, 0.5, 11.4, 5.8, "", fill=GRAY4, rounded=True, lw=2.5) + ax.text( + 6.0, 6.0, "SPARK CLUSTER", fontsize=FS_LABEL, ha="center", fontweight="bold" + ) + + # Driver + draw_box( + ax, + 1.0, + 4.5, + 3.0, + 1.2, + "Driver\n(planuje mini-batche)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + draw_arrow(ax, 2.5, 4.5, 6.0, 4.0, lw=1.5) + + # Batches + batches = ["batch 1\n(e1,e2,e3)", "batch 2\n(e4,e5,e6)", "batch 3\n(e7,e8,e9)"] + for i, b in enumerate(batches): + y = 2.8 - i * 1.0 + draw_box( + ax, 4.5, y, 2.5, 0.8, b, fill=GRAY1, fontsize=FS_SMALL, fontweight="bold" + ) + # map → reduce + draw_arrow(ax, 7.0, y + 0.4, 7.5, y + 0.4, lw=1) + draw_box(ax, 7.5, y, 1.3, 0.8, "map→\nreduce", fill=GRAY3, fontsize=5.5) + draw_arrow(ax, 8.8, y + 0.4, 9.3, y + 0.4, lw=1) + draw_box( + ax, 9.3, y, 1.5, 0.8, f"result {i + 1}", fill="white", fontsize=FS_SMALL + ) + + # Spark ecosystem + draw_box( + ax, + 1.0, + 1.0, + 3.0, + 1.0, + "Spark SQL / MLlib\n(ten sam ekosystem!)", + fill=GRAY5, + fontsize=FS, + fontweight="bold", + ) + + ax.text( + 6.0, + 0.3, + "ZALETA: batch API | WADA: latencja ≥ batch interval (~100ms)", + fontsize=FS, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "edgecolor": LN}, + ) + + save_fig(fig, "q20_spark_streaming_arch.png") + + +# ============================================================ +# 13. Lambda vs Kappa architecture +# ============================================================ +def gen_lambda_vs_kappa() -> None: + """Gen lambda vs kappa.""" + fig, axes = plt.subplots(2, 1, figsize=(10, 7)) + fig.suptitle("Architektura Lambda vs Kappa", fontsize=FS_TITLE, fontweight="bold") + + # --- Lambda --- + ax = axes[0] + ax.set_xlim(0, 12) + ax.set_ylim(0, 5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "LAMBDA — 2 ścieżki (batch + speed)", fontsize=FS_LABEL, fontweight="bold" + ) + + # Source + draw_box( + ax, + 0.3, + 1.8, + 2.0, + 1.5, + "Źródło\ndanych", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + # Batch layer (top) + draw_box( + ax, + 3.5, + 3.3, + 3.0, + 1.2, + "Batch Layer\n(Spark)\nprzelicza co godzinę", + fill=GRAY1, + fontsize=FS_SMALL, + fontweight="bold", + ) + draw_arrow(ax, 2.3, 3.0, 3.5, 3.9, lw=1.5) + + # Speed layer (bottom) + draw_box( + ax, + 3.5, + 0.8, + 3.0, + 1.2, + "Speed Layer\n(Flink)\nreal-time", + fill=GRAY3, + fontsize=FS_SMALL, + fontweight="bold", + ) + draw_arrow(ax, 2.3, 2.2, 3.5, 1.4, lw=1.5) + + # Results + draw_box( + ax, + 7.5, + 3.3, + 2.0, + 1.2, + "Dokładne\nwyniki\n(wolne)", + fill=GRAY4, + fontsize=FS_SMALL, + ) + draw_arrow(ax, 6.5, 3.9, 7.5, 3.9, lw=1.5) + + draw_box( + ax, + 7.5, + 0.8, + 2.0, + 1.2, + "Przybliżone\nwyniki\n(szybkie)", + fill=GRAY4, + fontsize=FS_SMALL, + ) + draw_arrow(ax, 6.5, 1.4, 7.5, 1.4, lw=1.5) + + # Merge + draw_box( + ax, + 10.0, + 2.0, + 1.5, + 1.5, + "MERGE\n→ UI", + fill=GRAY5, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 9.5, 3.5, 10.0, 3.0, lw=1.5) + draw_arrow(ax, 9.5, 1.8, 10.0, 2.5, lw=1.5) + + ax.text( + 6.0, + 0.1, + "2 systemy, 2 kody — złożone ale pewne", + fontsize=FS, + ha="center", + style="italic", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, + ) + + # --- Kappa --- + ax = axes[1] + ax.set_xlim(0, 12) + ax.set_ylim(0, 4) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "KAPPA — 1 ścieżka (streaming only)", fontsize=FS_LABEL, fontweight="bold" + ) + + # Source + draw_box( + ax, + 0.3, + 1.3, + 2.0, + 1.5, + "Źródło\ndanych", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + # Single streaming layer + draw_box( + ax, + 3.5, + 1.3, + 3.5, + 1.5, + "Streaming Layer\n(Flink)\n+ replay z Kafka log", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 2.3, 2.05, 3.5, 2.05, lw=2) + + # Output + draw_box( + ax, + 8.0, + 1.3, + 2.5, + 1.5, + "Wyniki\n→ UI", + fill=GRAY4, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 7.0, 2.05, 8.0, 2.05, lw=2) + + # Replay arrow + ax.annotate( + "", + xy=(3.5, 1.0), + xytext=(7.0, 1.0), + arrowprops={ + "arrowstyle": "<-", + "lw": 1.5, + "color": LN, + "connectionstyle": "arc3,rad=0.3", + "linestyle": "--", + }, + ) + ax.text( + 5.25, + 0.3, + "Replay z Kafka\n(przetwórz historię od nowa)", + fontsize=FS_SMALL, + ha="center", + style="italic", + ) + + ax.text( + 6.0, + 3.3, + "1 system, 1 kod — prostsze, ale replay = dużo I/O", + fontsize=FS, + ha="center", + style="italic", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, + ) + + fig.tight_layout(rect=[0, 0, 1, 0.92]) + save_fig(fig, "q20_lambda_vs_kappa.png") + + +# ============================================================ +# 14. Lambda vs Kappa comparison table +# ============================================================ +def gen_lambda_kappa_table() -> None: + """Gen lambda kappa table.""" + fig, ax = plt.subplots(figsize=(8, 3.5)) + ax.set_xlim(0, 10) + ax.set_ylim(-4.5, 1) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Lambda vs Kappa — porównanie", fontsize=FS_TITLE, fontweight="bold", pad=10 + ) + + headers = ["Cecha", "Lambda", "Kappa"] + col_w = [2.5, 3.5, 3.5] + rows = [ + ["Ścieżki", "2 (batch + speed)", "1 (streaming)"], + ["Kod", "2 implementacje", "1 implementacja"], + ["Złożoność", "wysoka", "niska"], + ["Replay", "batch przelicza", "Kafka replay"], + ["Spójność", "merge wymagany", "natywna"], + ["Przykład", "Netflix, LinkedIn", "Uber, Confluent"], + ] + draw_table( + ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.55, fontsize=7.5 + ) + + save_fig(fig, "q20_lambda_kappa_table.png") + + +# ============================================================ +# 15. Exactly-once comparison +# ============================================================ +def gen_exactly_once() -> None: + """Gen exactly once.""" + fig, ax = plt.subplots(figsize=(9, 5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Exactly-Once — mechanizmy na 3 platformach", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Flink + draw_box(ax, 0.3, 4.3, 11.0, 2.0, "", fill=GRAY4, rounded=True, lw=1.5) + ax.text( + 1.0, + 5.9, + "Flink — Distributed Snapshots (Chandy-Lamport)", + fontsize=FS, + fontweight="bold", + ) + + flink_steps = ["source", "|B|", "map()", "|B|", "sink()"] + bx = 1.0 + for s in flink_steps: + if s == "|B|": + ax.text( + bx + 0.25, + 4.85, + s, + fontsize=FS, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY5, "edgecolor": LN}, + ) + draw_arrow(ax, bx - 0.1, 4.85, bx + 0.05, 4.85, lw=1) + bx += 0.7 + else: + draw_box( + ax, + bx, + 4.6, + 1.5, + 0.55, + s, + fill=GRAY1, + fontsize=FS_SMALL, + fontweight="bold", + ) + bx += 1.8 + ax.text( + 8.5, + 5.0, + "barrier → save state\n→ checkpoint (HDFS/S3)", + fontsize=FS_SMALL, + style="italic", + ) + + # Kafka Streams + draw_box(ax, 0.3, 2.3, 11.0, 1.5, "", fill=GRAY1, rounded=True, lw=1.5) + ax.text( + 1.0, 3.5, "Kafka Streams — Transakcje Kafka", fontsize=FS, fontweight="bold" + ) + ax.text( + 1.5, + 2.85, + "idempotent producer + begin TX → produce → commit TX → consumer offsets w TX", + fontsize=FS_SMALL, + ) + + # Spark + draw_box(ax, 0.3, 0.5, 11.0, 1.5, "", fill=GRAY3, rounded=True, lw=1.5) + ax.text( + 1.0, + 1.7, + "Spark Streaming — Write-Ahead Log (WAL)", + fontsize=FS, + fontweight="bold", + ) + ax.text( + 1.5, + 1.05, + "WAL + checkpointing micro-batchów + idempotent sinks (np. upsert do DB)", + fontsize=FS_SMALL, + ) + + save_fig(fig, "q20_exactly_once.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q20_batch_and_windows.py b/python_pkg/praca_magisterska_video/generate_images/_q20_batch_and_windows.py new file mode 100644 index 0000000..079f9cd --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q20_batch_and_windows.py @@ -0,0 +1,449 @@ +"""Batch vs streaming concept and window type diagrams for Q20.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from _q20_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + draw_arrow, + draw_box, + plt, + save_fig, +) +import matplotlib.patches as mpatches + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +# ============================================================ +# 1. Batch vs Streaming concept +# ============================================================ +def gen_batch_vs_streaming() -> None: + """Gen batch vs streaming.""" + fig, axes = plt.subplots(2, 1, figsize=(9, 5)) + fig.suptitle( + "Batch vs Streaming — dwa modele przetwarzania", + fontsize=FS_TITLE, + fontweight="bold", + ) + + # Batch + ax = axes[0] + ax.set_xlim(0, 12) + ax.set_ylim(0, 3) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("BATCH (wsadowe)", fontsize=FS_LABEL, fontweight="bold") + + # Data collected + draw_box( + ax, + 0.5, + 0.8, + 3.0, + 1.4, + "Zbierz WSZYSTKIE\ndane\n(godziny / dni)", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 3.5, 1.5, 4.5, 1.5, lw=2) + draw_box( + ax, + 4.5, + 0.8, + 2.5, + 1.4, + "Analiza\n(batch job)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 7.0, 1.5, 8.0, 1.5, lw=2) + draw_box( + ax, + 8.0, + 0.8, + 2.5, + 1.4, + "Wynik\n(jednorazowy)", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + ax.text(11.0, 1.5, "min-h", fontsize=FS, va="center", fontweight="bold") + + # Streaming + ax = axes[1] + ax.set_xlim(0, 12) + ax.set_ylim(0, 3) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("STREAMING (strumieniowe)", fontsize=FS_LABEL, fontweight="bold") + + # Events flowing + events_x = [0.5, 1.5, 2.5, 3.5] + for i, ex in enumerate(events_x): + draw_box( + ax, + ex, + 1.0, + 0.8, + 0.8, + f"e{i + 1}", + fill=GRAY4, + fontsize=FS, + fontweight="bold", + rounded=False, + ) + if i < len(events_x) - 1: + draw_arrow(ax, ex + 0.8, 1.4, ex + 1.0, 1.4, lw=1) + + ax.text(4.8, 1.4, "...", fontsize=FS_LABEL, va="center") + draw_arrow(ax, 5.2, 1.4, 5.8, 1.4, lw=2) + + draw_box( + ax, + 5.8, + 0.8, + 2.8, + 1.4, + "Analiza\nCIĄGŁA\n(event-by-event)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 8.6, 1.5, 9.3, 1.5, lw=2) + draw_box( + ax, + 9.3, + 0.8, + 2.0, + 1.4, + "Wyniki\nciągłe", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + ax.text(11.5, 0.5, "ms-s", fontsize=FS, va="center", fontweight="bold") + + # Arrow marking infinity + ax.annotate( + "", + xy=(0.2, 1.4), + xytext=(-0.3, 1.4), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + ax.text(0.0, 2.3, "∞ zdarzeń", fontsize=FS_SMALL, ha="center", style="italic") + + fig.tight_layout(rect=[0, 0, 1, 0.92]) + save_fig(fig, "q20_batch_vs_streaming.png") + + +# ============================================================ +# 2. All 4 window types (TSSG) +# ============================================================ +def _draw_tumbling_window(ax: Axes, events: list[int]) -> None: + """Draw tumbling window section.""" + ax.set_xlim(0, 14) + ax.set_ylim(0, 4) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Tumbling Window (okno przerzutne) — rozłączne, stały rozmiar", + fontsize=FS_LABEL, + fontweight="bold", + ) + + # Time axis + ax.annotate( + "", + xy=(13.5, 1.0), + xytext=(0.3, 1.0), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center") + + # Events + for i, e in enumerate(events): + x = 1.0 + i * 1.0 + ax.plot(x, 1.0, "ko", markersize=5) + ax.text(x, 0.5, f"e{e}", fontsize=FS_SMALL, ha="center") + + # Windows + colors_w = [GRAY1, GRAY3, GRAY1, GRAY3] + for w in range(4): + x_start = 1.0 + w * 3.0 - 0.3 + rect = mpatches.FancyBboxPatch( + (x_start, 1.5), + 3.0, + 1.2, + boxstyle="round,pad=0.1", + facecolor=colors_w[w], + edgecolor=LN, + lw=1.5, + ) + ax.add_patch(rect) + ax.text( + x_start + 1.5, + 2.1, + f"Okno {w + 1}", + fontsize=FS, + ha="center", + fontweight="bold", + ) + # Braces down to events + for j in range(3): + ex = 1.0 + w * 3.0 + j * 1.0 + ax.plot([ex, ex], [1.0, 1.5], color=LN, lw=0.8, linestyle="--") + + ax.text( + 7.0, + 3.2, + "Każde zdarzenie → DOKŁADNIE 1 okno. Zero nakładania.", + fontsize=FS, + ha="center", + style="italic", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, + ) + + +def _draw_sliding_window(ax: Axes, events: list[int]) -> None: + """Draw sliding window section.""" + ax.set_xlim(0, 14) + ax.set_ylim(0, 5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Sliding Window (okno przesuwne) — nakładające, stały rozmiar + krok", + fontsize=FS_LABEL, + fontweight="bold", + ) + + ax.annotate( + "", + xy=(13.5, 1.0), + xytext=(0.3, 1.0), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center") + + for i, e in enumerate(events[:8]): + x = 1.0 + i * 1.0 + ax.plot(x, 1.0, "ko", markersize=5) + ax.text(x, 0.5, f"e{e}", fontsize=FS_SMALL, ha="center") + + # Sliding windows: size=4, slide=2 + slide_colors = [GRAY1, GRAY2, GRAY3] + for w in range(3): + x_start = 0.7 + w * 2.0 + y_base = 1.5 + w * 0.9 + rect = mpatches.FancyBboxPatch( + (x_start, y_base), + 4.0, + 0.7, + boxstyle="round,pad=0.08", + facecolor=slide_colors[w], + edgecolor=LN, + lw=1.5, + alpha=0.7, + ) + ax.add_patch(rect) + ax.text( + x_start + 2.0, + y_base + 0.35, + f"Okno {w + 1} (size=4)", + fontsize=FS_SMALL, + ha="center", + fontweight="bold", + ) + + ax.text( + 10.5, + 3.5, + "krok=2\nNakładanie!\ne3,e4 → w oknie 1 i 2", + fontsize=FS, + ha="center", + style="italic", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, + ) + + +def _draw_session_window(ax: Axes) -> None: + """Draw session window section.""" + ax.set_xlim(0, 14) + ax.set_ylim(0, 4) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Session Window (okno sesji) — dynamiczny rozmiar, gap = przerwa", + fontsize=FS_LABEL, + fontweight="bold", + ) + + ax.annotate( + "", + xy=(13.5, 1.0), + xytext=(0.3, 1.0), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center") + + # Cluster 1: events close together + cluster1 = [1.0, 1.8, 2.3, 3.0] + for x in cluster1: + ax.plot(x, 1.0, "ko", markersize=5) + + # Gap + ax.annotate( + "", + xy=(7.0, 0.7), + xytext=(4.0, 0.7), + arrowprops={"arrowstyle": "<->", "lw": 1, "color": LN}, + ) + ax.text( + 5.5, + 0.3, + "GAP > timeout", + fontsize=FS, + ha="center", + fontweight="bold", + style="italic", + ) + + # Cluster 2 + cluster2 = [8.0, 8.8, 9.5] + for x in cluster2: + ax.plot(x, 1.0, "ko", markersize=5) + + # Session boxes + rect1 = mpatches.FancyBboxPatch( + (0.7, 1.4), + 2.6, + 1.0, + boxstyle="round,pad=0.1", + facecolor=GRAY1, + edgecolor=LN, + lw=1.5, + ) + ax.add_patch(rect1) + ax.text( + 2.0, 1.9, "Sesja 1\n(4 zdarzenia)", fontsize=FS, ha="center", fontweight="bold" + ) + + rect2 = mpatches.FancyBboxPatch( + (7.7, 1.4), + 2.1, + 1.0, + boxstyle="round,pad=0.1", + facecolor=GRAY3, + edgecolor=LN, + lw=1.5, + ) + ax.add_patch(rect2) + ax.text( + 8.75, 1.9, "Sesja 2\n(3 zdarzenia)", fontsize=FS, ha="center", fontweight="bold" + ) + + ax.text( + 5.5, + 3.0, + "Nowa sesja po przerwie > gap", + fontsize=FS, + ha="center", + style="italic", + ) + + +def _draw_global_window(ax: Axes) -> None: + """Draw global window section.""" + ax.set_xlim(0, 14) + ax.set_ylim(0, 4) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Global Window — jedno okno na cały strumień + trigger", + fontsize=FS_LABEL, + fontweight="bold", + ) + + ax.annotate( + "", + xy=(13.5, 1.0), + xytext=(0.3, 1.0), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center") + + for i in range(12): + x = 1.0 + i * 1.0 + ax.plot(x, 1.0, "ko", markersize=5) + + # One big window + rect = mpatches.FancyBboxPatch( + (0.5, 1.4), + 12.5, + 1.0, + boxstyle="round,pad=0.1", + facecolor=GRAY1, + edgecolor=LN, + lw=2, + ) + ax.add_patch(rect) + ax.text( + 6.75, + 1.9, + "GLOBAL WINDOW (cały strumień)", + fontsize=FS, + ha="center", + fontweight="bold", + ) + + # Trigger markers + for tx in [4.0, 8.0, 12.0]: + ax.plot([tx, tx], [1.4, 2.4], color=LN, lw=2, linestyle="--") + ax.text( + tx, + 2.7, + "EMIT", + fontsize=FS_SMALL, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY3, "edgecolor": LN}, + ) + + ax.text( + 6.75, + 3.3, + "Trigger decyduje kiedy emitować (np. co N zdarzeń)", + fontsize=FS, + ha="center", + style="italic", + ) + + +def gen_window_types() -> None: + """Gen window types.""" + fig, axes = plt.subplots(4, 1, figsize=(9, 10)) + fig.suptitle("4 typy okien — TSSG", fontsize=FS_TITLE, fontweight="bold") + + events = list(range(1, 13)) + + _draw_tumbling_window(axes[0], events) + _draw_sliding_window(axes[1], events) + _draw_session_window(axes[2]) + _draw_global_window(axes[3]) + + fig.tight_layout(rect=[0, 0, 1, 0.94]) + save_fig(fig, "q20_window_types.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q20_common.py b/python_pkg/praca_magisterska_video/generate_images/_q20_common.py new file mode 100644 index 0000000..bb11c63 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q20_common.py @@ -0,0 +1,180 @@ +"""Common utilities and constants for Q20 diagram generation. + +Monochrome, A4-printable PNGs (300 DPI). +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib as mpl + +mpl.use("Agg") + +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch +import matplotlib.pyplot as plt +import numpy as np + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.figure import Figure + +_logger = logging.getLogger(__name__) + +rng = np.random.default_rng(42) + +DPI = 300 +BG = "white" +LN = "black" +FS = 8 +FS_TITLE = 11 +FS_SMALL = 6.5 +FS_LABEL = 9 +OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") +Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + +GRAY1 = "#E8E8E8" +GRAY2 = "#D0D0D0" +GRAY3 = "#B8B8B8" +GRAY4 = "#F5F5F5" +GRAY5 = "#C0C0C0" + + +def draw_box( + ax: Axes, + x: float, + y: float, + w: float, + h: float, + text: str, + *, + fill: str = "white", + lw: float = 1.2, + fontsize: float = FS, + fontweight: str = "normal", + ha: str = "center", + va: str = "center", + rounded: bool = True, + edgecolor: str = LN, + linestyle: str = "-", +) -> None: + """Draw box.""" + if rounded: + rect = FancyBboxPatch( + (x, y), + w, + h, + boxstyle="round,pad=0.05", + lw=lw, + edgecolor=edgecolor, + facecolor=fill, + linestyle=linestyle, + ) + else: + rect = mpatches.Rectangle( + (x, y), + w, + h, + lw=lw, + edgecolor=edgecolor, + facecolor=fill, + linestyle=linestyle, + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + h / 2, + text, + ha=ha, + va=va, + fontsize=fontsize, + fontweight=fontweight, + wrap=True, + ) + + +def draw_arrow( + ax: Axes, + x1: float, + y1: float, + x2: float, + y2: float, + lw: float = 1.2, + style: str = "->", + color: str = LN, +) -> None: + """Draw arrow.""" + ax.annotate( + "", + xy=(x2, y2), + xytext=(x1, y1), + arrowprops={"arrowstyle": style, "color": color, "lw": lw}, + ) + + +def save_fig(fig: Figure, name: str) -> None: + """Save fig.""" + path = str(Path(OUTPUT_DIR) / name) + fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15) + plt.close(fig) + _logger.info(" Saved: %s", path) + + +def draw_table( + ax: Axes, + headers: list[str], + rows: list[list[str]], + x0: float, + y0: float, + col_widths: list[float], + row_h: float = 0.4, + header_fill: str = GRAY2, + row_fills: list[str] | None = None, + fontsize: float = FS, + header_fontsize: float | None = None, +) -> None: + """Draw table.""" + if header_fontsize is None: + header_fontsize = fontsize + len(headers) + # Header + cx = x0 + for j, hdr in enumerate(headers): + draw_box( + ax, + cx, + y0, + col_widths[j], + row_h, + hdr, + fill=header_fill, + fontsize=header_fontsize, + fontweight="bold", + rounded=False, + ) + cx += col_widths[j] + # Rows + for i, row in enumerate(rows): + cy = y0 - (i + 1) * row_h + cx = x0 + fill = GRAY4 if (i % 2 == 0) else "white" + if row_fills and i < len(row_fills): + fill = row_fills[i] + for j, cell in enumerate(row): + fw = "bold" if j == 0 else "normal" + draw_box( + ax, + cx, + cy, + col_widths[j], + row_h, + cell, + fill=fill, + fontsize=fontsize, + fontweight=fw, + rounded=False, + ) + cx += col_widths[j] diff --git a/python_pkg/praca_magisterska_video/generate_images/_q20_late_and_decisions.py b/python_pkg/praca_magisterska_video/generate_images/_q20_late_and_decisions.py new file mode 100644 index 0000000..55adfb3 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q20_late_and_decisions.py @@ -0,0 +1,240 @@ +"""Late data strategies and decision tree diagrams for Q20.""" + +from __future__ import annotations + +from _q20_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + draw_arrow, + draw_box, + plt, + save_fig, +) + + +# ============================================================ +# 16. Late data strategies (DRAS) +# ============================================================ +def gen_late_data_strategies() -> None: + """Gen late data strategies.""" + fig, ax = plt.subplots(figsize=(9, 5.5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Late Data — 4 strategie (mnemonik DRAS)", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Setup: window closed, late event arrives + draw_box( + ax, + 0.5, + 5.5, + 4.5, + 1.0, + "Okno [14:00-14:05]\nZAMKNIĘTE o 14:05", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + draw_box( + ax, + 6.0, + 5.5, + 4.5, + 1.0, + "Spóźnione zdarzenie\nevent_time=14:00:03\narrives=14:05:30", + fill="#F8D7DA", + fontsize=FS_SMALL, + fontweight="bold", + ) + draw_arrow(ax, 10.5, 6.0, 5.0, 6.0, lw=2, color="#C62828", style="->") + ax.text( + 7.5, + 5.2, + "LATE!", + fontsize=FS_LABEL, + ha="center", + fontweight="bold", + color="#C62828", + ) + + # 4 strategies + strategies = [ + ("D — Drop", "Odrzuć spóźnione", "/dev/null", GRAY4), + ("R — Recompute", "Przelicz okno ponownie", "poprawne ale kosztowne", GRAY1), + ( + "A — Allowed lateness", + "Czekaj dodatkowy czas\n(np. +2 min)", + "kompromis pamięci", + GRAY2, + ), + ( + "S — Side output", + "Przekieruj do osobnej\nkolejki", + "elastyczne, ręczna analiza", + GRAY3, + ), + ] + for i, (name, desc, tradeoff, color) in enumerate(strategies): + y = 3.8 - i * 1.1 + draw_box(ax, 0.5, y, 2.5, 0.9, name, fill=color, fontsize=FS, fontweight="bold") + ax.text(3.3, y + 0.45, desc, fontsize=FS_SMALL, va="center") + ax.text( + 8.5, + y + 0.45, + tradeoff, + fontsize=FS_SMALL, + va="center", + style="italic", + color="#555", + ) + + save_fig(fig, "q20_late_data_strategies.png") + + +# ============================================================ +# 17. Decision tree — which platform +# ============================================================ +def gen_decision_tree() -> None: + """Gen decision tree.""" + fig, ax = plt.subplots(figsize=(10, 5.5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Drzewo decyzyjne — wybór platformy", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Root question + draw_box( + ax, + 3.5, + 5.5, + 4.5, + 1.0, + "Latencja < 10ms\nwymagana?", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + # TAK branch + draw_arrow(ax, 3.5, 5.7, 2.0, 5.0, lw=1.5) + ax.text(2.3, 5.3, "TAK", fontsize=FS, fontweight="bold") + + draw_box( + ax, + 0.3, + 3.5, + 3.5, + 1.0, + "Dane już w Kafce?\nProste transformacje?", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + ) + + # TAK → Kafka Streams + draw_arrow(ax, 0.3, 3.7, -0.1, 3.0, lw=1.5) + ax.text(0.0, 3.3, "TAK", fontsize=FS_SMALL, fontweight="bold") + draw_box( + ax, + -0.3, + 1.8, + 2.5, + 1.0, + "Kafka\nStreams", + fill=GRAY5, + fontsize=FS_LABEL, + fontweight="bold", + ) + + # NIE → Flink + draw_arrow(ax, 3.8, 3.7, 4.5, 3.0, lw=1.5) + ax.text(4.0, 3.3, "NIE\n(złożona logika)", fontsize=FS_SMALL) + draw_box( + ax, + 3.0, + 1.8, + 2.5, + 1.0, + "Apache\nFlink", + fill=GRAY5, + fontsize=FS_LABEL, + fontweight="bold", + ) + + # NIE branch + draw_arrow(ax, 8.0, 5.7, 9.5, 5.0, lw=1.5) + ax.text(8.7, 5.3, "NIE", fontsize=FS, fontweight="bold") + + draw_box( + ax, + 7.5, + 3.5, + 4.2, + 1.0, + "~100ms-1s OK?\nPotrzeba ML / SQL?", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + ) + + # TAK + ML → Spark + draw_arrow(ax, 9.5, 3.5, 9.5, 3.0, lw=1.5) + ax.text(10.0, 3.3, "TAK + ML/SQL", fontsize=FS_SMALL) + draw_box( + ax, + 8.0, + 1.8, + 2.5, + 1.0, + "Spark\nStreaming", + fill=GRAY5, + fontsize=FS_LABEL, + fontweight="bold", + ) + + # TAK + proste → Kafka Streams too + draw_arrow(ax, 7.5, 3.7, 6.5, 3.0, lw=1.5) + ax.text(6.3, 3.3, "proste + TAK", fontsize=FS_SMALL) + draw_box( + ax, + 5.8, + 1.8, + 2.0, + 1.0, + "Kafka\nStreams", + fill=GRAY5, + fontsize=FS, + fontweight="bold", + ) + + # Legend + ax.text( + 6.0, + 0.7, + "Reguła: Kafka Streams = najprostsze (library) | " + "Flink = najpotężniejszy (true streaming) | Spark = ekosystem ML", + fontsize=FS, + ha="center", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, + ) + + save_fig(fig, "q20_decision_tree.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q20_platforms.py b/python_pkg/praca_magisterska_video/generate_images/_q20_platforms.py new file mode 100644 index 0000000..e7351ec --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q20_platforms.py @@ -0,0 +1,471 @@ +"""Streaming ecosystem, micro-batch, platform comparison, and engine diagrams for Q20.""" + +from __future__ import annotations + +from _q20_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + draw_arrow, + draw_box, + draw_table, + plt, + save_fig, +) + + +# ============================================================ +# 7. Streaming ecosystem overview +# ============================================================ +def gen_streaming_ecosystem() -> None: + """Gen streaming ecosystem.""" + fig, ax = plt.subplots(figsize=(10, 5.5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Ekosystem przetwarzania strumieniowego", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Source + draw_box( + ax, + 0.3, + 2.5, + 2.0, + 3.0, + "Kafka\nTopics\n(źródło)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + # Engines + engines = [ + ("Kafka Streams\n(library w JVM)", GRAY1, 4.7), + ("Apache Flink\n(klaster)", GRAY3, 3.2), + ("Spark Streaming\n(klaster)", GRAY5, 1.7), + ] + for label, color, y in engines: + draw_box( + ax, 4.0, y, 3.0, 1.2, label, fill=color, fontsize=FS, fontweight="bold" + ) + draw_arrow(ax, 2.3, 4.0, 4.0, y + 0.6, lw=1.5) + + # Sinks + sinks = [ + ("Kafka topic\n/ baza danych", GRAY4, 4.7), + ("DB / Kafka\n/ S3", GRAY4, 3.2), + ("HDFS / DB\n/ dashboard", GRAY4, 1.7), + ] + for label, color, y in sinks: + draw_box(ax, 8.5, y, 2.5, 1.2, label, fill=color, fontsize=FS) + draw_arrow(ax, 7.0, y + 0.6, 8.5, y + 0.6, lw=1.5) + + # Labels + ax.text(1.3, 6.0, "ŹRÓDŁO", fontsize=FS_LABEL, ha="center", fontweight="bold") + ax.text(5.5, 6.2, "SILNIK", fontsize=FS_LABEL, ha="center", fontweight="bold") + ax.text(9.75, 6.2, "WYNIK", fontsize=FS_LABEL, ha="center", fontweight="bold") + + # Latency annotations + ax.text(5.5, 5.95, "~1-10 ms", fontsize=FS_SMALL, ha="center", style="italic") + ax.text(5.5, 4.5, "<10 ms", fontsize=FS_SMALL, ha="center", style="italic") + ax.text(5.5, 3.0, "~100 ms", fontsize=FS_SMALL, ha="center", style="italic") + + save_fig(fig, "q20_streaming_ecosystem.png") + + +# ============================================================ +# 8. True streaming vs Micro-batch +# ============================================================ +def gen_true_vs_microbatch() -> None: + """Gen true vs microbatch.""" + fig, axes = plt.subplots(2, 1, figsize=(10, 5.5)) + fig.suptitle("True Streaming vs Micro-Batch", fontsize=FS_TITLE, fontweight="bold") + + # True streaming + ax = axes[0] + ax.set_xlim(0, 12) + ax.set_ylim(0, 3.5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "TRUE STREAMING (Flink, Kafka Streams) — event-by-event", + fontsize=FS_LABEL, + fontweight="bold", + ) + + for i in range(6): + x = 1.0 + i * 1.8 + # Event + draw_box( + ax, + x, + 2.0, + 0.8, + 0.7, + f"e{i + 1}", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + rounded=False, + ) + # Arrow down + draw_arrow(ax, x + 0.4, 2.0, x + 0.4, 1.4, lw=1) + # Result + draw_box( + ax, + x, + 0.5, + 0.8, + 0.7, + f"r{i + 1}", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + rounded=False, + ) + # Latency label + ax.text(x + 0.4, 1.6, "~ms", fontsize=5, ha="center", color="#555") + + ax.text( + 11.5, + 1.3, + "Latencja:\n< 10 ms", + fontsize=FS, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, + ) + + # Micro-batch + ax = axes[1] + ax.set_xlim(0, 12) + ax.set_ylim(0, 3.5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "MICRO-BATCH (Spark Streaming) — grupami co ~100ms", + fontsize=FS_LABEL, + fontweight="bold", + ) + + batch_colors = [GRAY1, GRAY2, GRAY3] + for b in range(3): + bx = 0.8 + b * 3.5 + # Batch boundary + draw_box(ax, bx, 1.8, 3.0, 1.0, "", fill=batch_colors[b], rounded=True, lw=1.5) + ax.text( + bx + 1.5, 2.6, f"Batch {b + 1}", fontsize=FS, ha="center", fontweight="bold" + ) + for j in range(3): + ex = bx + 0.3 + j * 0.9 + draw_box( + ax, + ex, + 2.0, + 0.7, + 0.5, + f"e{b * 3 + j + 1}", + fill="white", + fontsize=FS_SMALL, + rounded=False, + ) + + # Arrow down + draw_arrow(ax, bx + 1.5, 1.8, bx + 1.5, 1.2, lw=1.5) + # Result + draw_box( + ax, + bx + 0.5, + 0.4, + 2.0, + 0.7, + f"result {b + 1}", + fill=GRAY4, + fontsize=FS, + fontweight="bold", + ) + + ax.text( + 11.5, + 1.3, + "Latencja:\n~100ms-s", + fontsize=FS, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, + ) + + fig.tight_layout(rect=[0, 0, 1, 0.92]) + save_fig(fig, "q20_true_vs_microbatch.png") + + +# ============================================================ +# 9. Platform comparison table +# ============================================================ +def gen_platform_comparison() -> None: + """Gen platform comparison.""" + fig, ax = plt.subplots(figsize=(9, 5)) + ax.set_xlim(0, 11.5) + ax.set_ylim(-6, 1) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Porównanie platform strumieniowych", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + headers = ["Cecha", "Kafka Streams", "Apache Flink", "Spark Streaming"] + col_w = [2.5, 2.8, 2.8, 2.8] + rows = [ + ["Model", "event-by-event", "event-by-event", "micro-batch (~100ms)"], + ["Deployment", "library (w JVM)", "klaster", "klaster"], + ["Latencja", "~1-10 ms", "< 10 ms", "100 ms - sekundy"], + ["Exactly-once", "Kafka TXN", "checkpointing", "WAL"], + ["State", "RocksDB local", "RocksDB + ckpt", "in-memory / ext"], + ["Okna", "T, S, Session", "wszystkie + custom", "T, S"], + ["Use case", "Kafka → Kafka", "złożona analityka", "ETL + ML / SQL"], + ] + draw_table( + ax, + headers, + rows, + x0=0.25, + y0=0.5, + col_widths=col_w, + row_h=0.6, + fontsize=7, + header_fontsize=8, + ) + + save_fig(fig, "q20_platform_comparison.png") + + +# ============================================================ +# 10. Kafka Streams architecture +# ============================================================ +def gen_kafka_streams_arch() -> None: + """Gen kafka streams arch.""" + fig, ax = plt.subplots(figsize=(9, 5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Kafka Streams — architektura (library w JVM)", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Outer box: Your Java application + draw_box(ax, 0.5, 0.5, 11.0, 5.5, "", fill=GRAY4, rounded=True, lw=2.5) + ax.text( + 6.0, + 5.7, + "Twoja aplikacja Java (JVM)", + fontsize=FS_LABEL, + ha="center", + fontweight="bold", + ) + + # Kafka Consumer + draw_box( + ax, + 1.0, + 3.0, + 2.5, + 1.5, + "Kafka\nConsumer\n(input topic)", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + ) + + # Processing + draw_box( + ax, + 4.5, + 3.0, + 2.5, + 1.5, + "Kafka Streams\n(logika\nbiznesowa)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + # Kafka Producer + draw_box( + ax, + 8.0, + 3.0, + 2.5, + 1.5, + "Kafka\nProducer\n(output topic)", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + ) + + # Arrows + draw_arrow(ax, 3.5, 3.75, 4.5, 3.75, lw=2) + draw_arrow(ax, 7.0, 3.75, 8.0, 3.75, lw=2) + + # RocksDB state store + draw_box( + ax, + 4.5, + 1.0, + 2.5, + 1.3, + "RocksDB\n(stan lokalny)", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + ax.plot([5.75, 5.75], [3.0, 2.3], color=LN, lw=1.5) + ax.text( + 7.3, + 1.6, + "okna, joiny,\nagregacje", + fontsize=FS_SMALL, + style="italic", + va="center", + ) + + # Key message + ax.text( + 6.0, + 0.2, + "NIE potrzebujesz osobnego klastra! Skalujesz = więcej instancji JVM.", + fontsize=FS, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "edgecolor": LN}, + ) + + save_fig(fig, "q20_kafka_streams_arch.png") + + +# ============================================================ +# 11. Flink architecture + checkpointing +# ============================================================ +def gen_flink_arch() -> None: + """Gen flink arch.""" + fig, ax = plt.subplots(figsize=(9, 6)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 8) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Apache Flink — architektura klastra + checkpointing", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Cluster border + draw_box(ax, 0.3, 1.0, 11.4, 6.2, "", fill=GRAY4, rounded=True, lw=2.5) + ax.text( + 6.0, 6.95, "FLINK CLUSTER", fontsize=FS_LABEL, ha="center", fontweight="bold" + ) + + # Job Manager + draw_box( + ax, + 1.0, + 5.5, + 3.0, + 1.2, + "Job Manager\n(koordynacja,\ncheckpointy)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + # Task Managers + draw_box(ax, 1.0, 3.0, 10.0, 2.0, "", fill="white", rounded=True, lw=1.5) + ax.text( + 6.0, 4.7, "Task Managers (workery)", fontsize=FS, ha="center", fontweight="bold" + ) + + slots = ["source\n& map()", "map()", "window()\n& reduce", "sink()"] + for i, s in enumerate(slots): + x = 1.5 + i * 2.4 + draw_box( + ax, + x, + 3.3, + 2.0, + 1.2, + f"Slot {i + 1}\n{s}", + fill=GRAY1, + fontsize=FS_SMALL, + fontweight="bold", + ) + + draw_arrow(ax, 2.5, 5.5, 6.0, 5.0, lw=1.5, style="->") + ax.text(5.0, 5.5, "przydziela\npodzadania", fontsize=FS_SMALL, style="italic") + + # Checkpoint storage + draw_box( + ax, + 5.5, + 1.2, + 3.5, + 1.2, + "Checkpoint Storage\n(HDFS / S3)", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + ax.plot([7.25, 7.25], [2.4, 3.3], color=LN, lw=1.5, linestyle="--") + ax.text(8.0, 2.7, "snapshoty\nstanu", fontsize=FS_SMALL, style="italic") + + # Barrier concept at bottom + ax.text(3.0, 1.6, "Barrier:", fontsize=FS, fontweight="bold") + barrier_boxes = ["source", "|B|", "map", "|B|", "sink"] + bx = 0.8 + for _i, b in enumerate(barrier_boxes): + if b == "|B|": + ax.text( + bx + 0.3, + 1.5, + b, + fontsize=FS, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY5, "edgecolor": LN}, + ) + draw_arrow(ax, bx, 1.5, bx + 0.1, 1.5, lw=1) + bx += 0.7 + else: + draw_box( + ax, + bx, + 1.3, + 1.0, + 0.45, + b, + fill=GRAY1, + fontsize=FS_SMALL, + fontweight="bold", + ) + bx += 1.2 + + save_fig(fig, "q20_flink_arch.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q20_time_monitoring_sessions.py b/python_pkg/praca_magisterska_video/generate_images/_q20_time_monitoring_sessions.py new file mode 100644 index 0000000..0bec93a --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q20_time_monitoring_sessions.py @@ -0,0 +1,464 @@ +"""Event time, fraud detection, SLA monitoring, and session diagrams for Q20.""" + +from __future__ import annotations + +from _q20_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY3, + GRAY4, + GRAY5, + LN, + draw_box, + np, + plt, + rng, + save_fig, +) +import matplotlib.patches as mpatches + + +# ============================================================ +# 3. Event Time vs Processing Time scatter + watermark +# ============================================================ +def gen_event_vs_processing_time() -> None: + """Gen event vs processing time.""" + fig, axes = plt.subplots(1, 2, figsize=(11, 5)) + fig.suptitle( + "Event Time vs Processing Time + Watermark", + fontsize=FS_TITLE, + fontweight="bold", + ) + + # --- Panel 1: Ideal vs Real --- + ax = axes[0] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.set_aspect("equal") + ax.set_xlabel("Event Time", fontsize=FS_LABEL) + ax.set_ylabel("Processing Time", fontsize=FS_LABEL) + ax.set_title("Idealny vs Realny świat", fontsize=FS_LABEL, fontweight="bold") + ax.set_xticks([]) + ax.set_yticks([]) + + # Ideal line + ax.plot([0, 9], [0, 9], "k--", lw=1.5, label="ideał (brak opóźnień)") + + # Real scattered points (processing >= event, some out of order) + event_times = np.sort(rng.uniform(1, 8, 15)) + proc_times = event_times + rng.exponential(0.5, 15) + # Make some out of order + idx = [3, 7, 11] + for i in idx: + proc_times[i] += 1.5 + + ax.scatter( + event_times, proc_times, c="black", s=30, zorder=5, label="zdarzenia (realne)" + ) + + # Highlight out-of-order + for i in idx: + ax.annotate( + "out-of-order", + xy=(event_times[i], proc_times[i]), + xytext=(event_times[i] + 0.8, proc_times[i] + 0.5), + fontsize=FS_SMALL, + ha="left", + arrowprops={"arrowstyle": "->", "lw": 0.8, "color": "#555"}, + ) + + ax.legend(fontsize=FS_SMALL, loc="upper left") + ax.text( + 7, + 2, + "Opóźnienie\nsieciowe ↑", + fontsize=FS, + ha="center", + style="italic", + color="#555", + ) + + # --- Panel 2: Watermark concept --- + ax = axes[1] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.set_aspect("equal") + ax.set_xlabel("Event Time", fontsize=FS_LABEL) + ax.set_ylabel("Processing Time", fontsize=FS_LABEL) + ax.set_title("Watermark — granica postępu", fontsize=FS_LABEL, fontweight="bold") + ax.set_xticks([]) + ax.set_yticks([]) + + # Events + ax.scatter(event_times, proc_times, c="black", s=30, zorder=5) + + # Watermark line (below most points, tracks progress) + wm_x = np.linspace(0, 9, 50) + wm_y = wm_x + 0.3 # watermark slightly above ideal + ax.plot(wm_x, wm_y, "k-", lw=2.5, label="Watermark") + ax.fill_between(wm_x, 0, wm_y, alpha=0.15, color="gray") + + ax.text( + 2.0, + 1.0, + 'PONIŻEJ watermark:\n„na pewno dotarło"', + fontsize=FS, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, + ) + + # Late event + late_x, late_y = event_times[7], proc_times[7] + ax.scatter( + [late_x], [late_y], c="white", s=80, zorder=6, edgecolors="black", linewidths=2 + ) + ax.annotate( + "LATE DATA!\n(po watermarku)", + xy=(late_x, late_y), + xytext=(late_x + 1.2, late_y + 0.8), + fontsize=FS_SMALL, + ha="left", + fontweight="bold", + arrowprops={"arrowstyle": "->", "lw": 1, "color": LN}, + ) + + ax.legend(fontsize=FS_SMALL, loc="upper left") + + fig.tight_layout(rect=[0, 0, 1, 0.92]) + save_fig(fig, "q20_event_vs_processing_time.png") + + +# ============================================================ +# 4. Tumbling window example — fraud detection +# ============================================================ +def gen_tumbling_fraud() -> None: + """Gen tumbling fraud.""" + fig, ax = plt.subplots(figsize=(9, 4)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 5.5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Tumbling Window — fraud detection (okno = 1 min)", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Time axis + ax.annotate( + "", + xy=(11.5, 1.0), + xytext=(0.5, 1.0), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + ax.text(6.0, 0.4, "czas", fontsize=FS, ha="center") + + # Window 1: normal + draw_box(ax, 1.0, 1.5, 4.5, 3.0, "", fill=GRAY4, rounded=True, lw=2) + ax.text( + 3.25, 4.2, "[14:00 — 14:01]", fontsize=FS_LABEL, ha="center", fontweight="bold" + ) + # Transactions + txns1 = ["Sklep A: 50 zł", "Sklep B: 30 zł", "Stacja: 80 zł"] + for i, t in enumerate(txns1): + draw_box( + ax, + 1.3, + 3.3 - i * 0.55, + 4.0, + 0.45, + t, + fill=GRAY1, + fontsize=FS_SMALL, + rounded=False, + ) + ax.text( + 3.25, + 1.7, + "count = 3 → OK", + fontsize=FS, + ha="center", + fontweight="bold", + color="#2E7D32", + bbox={ + "boxstyle": "round,pad=0.15", + "facecolor": "#E8F5E9", + "edgecolor": "#2E7D32", + }, + ) + + # Window 2: fraud! + draw_box(ax, 6.0, 1.5, 4.5, 3.0, "", fill=GRAY1, rounded=True, lw=2) + ax.text( + 8.25, 4.2, "[14:01 — 14:02]", fontsize=FS_LABEL, ha="center", fontweight="bold" + ) + txns2 = ["ATM Warszawa: 500 zł", "ATM Kraków: 500 zł", "... +45 transakcji"] + for i, t in enumerate(txns2): + draw_box( + ax, + 6.3, + 3.3 - i * 0.55, + 4.0, + 0.45, + t, + fill=GRAY3, + fontsize=FS_SMALL, + rounded=False, + ) + ax.text( + 8.25, + 1.7, + "count = 47 → ALERT!", + fontsize=FS, + ha="center", + fontweight="bold", + color="#C62828", + bbox={ + "boxstyle": "round,pad=0.15", + "facecolor": "#F8D7DA", + "edgecolor": "#C62828", + }, + ) + + save_fig(fig, "q20_tumbling_fraud.png") + + +# ============================================================ +# 5. Sliding window — SLA monitoring +# ============================================================ +def gen_sliding_sla() -> None: + """Gen sliding sla.""" + fig, ax = plt.subplots(figsize=(9, 4.5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 6) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Sliding Window — monitoring SLA (okno=5min, krok=1min)", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Time axis + ax.annotate( + "", + xy=(11.5, 0.5), + xytext=(0.5, 0.5), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + times = ["14:05", "14:06", "14:07", "14:08", "14:09"] + latencies = [120, 180, 340, 290, 150] + sla = 200 + + for i, (t, lat) in enumerate(zip(times, latencies, strict=False)): + x = 1.5 + i * 2.0 + ax.text(x, 0.1, t, fontsize=FS, ha="center") + + # Bar proportional to latency + bar_h = lat / 100.0 + is_breach = lat > sla + fill = "#F8D7DA" if is_breach else GRAY1 + edge = "#C62828" if is_breach else LN + draw_box( + ax, + x - 0.5, + 1.0, + 1.0, + bar_h, + "", + fill=fill, + rounded=False, + edgecolor=edge, + lw=1.5, + ) + ax.text( + x, + 1.0 + bar_h + 0.15, + f"{lat}ms", + fontsize=FS, + ha="center", + fontweight="bold", + color="#C62828" if is_breach else LN, + ) + + # Status + status = "ALERT!" if is_breach else "OK" + ax.text( + x, + 1.0 + bar_h + 0.55, + status, + fontsize=FS_SMALL, + ha="center", + fontweight="bold", + color="#C62828" if is_breach else "#2E7D32", + ) + + # SLA line + sla_y = 1.0 + sla / 100.0 + ax.plot([0.8, 11.2], [sla_y, sla_y], "k--", lw=1.5) + ax.text(11.3, sla_y, f"SLA={sla}ms", fontsize=FS, va="center", fontweight="bold") + + # Sliding window bracket + ax.annotate( + "", + xy=(1.0, 5.3), + xytext=(5.0, 5.3), + arrowprops={"arrowstyle": "<->", "lw": 1.5, "color": LN}, + ) + ax.text(3.0, 5.6, "okno = 5 min", fontsize=FS, ha="center", fontweight="bold") + + ax.annotate( + "", + xy=(3.0, 4.8), + xytext=(5.0, 4.8), + arrowprops={"arrowstyle": "<->", "lw": 1, "color": "#555"}, + ) + ax.text( + 4.0, + 4.4, + "krok = 1 min\n(nakładanie!)", + fontsize=FS_SMALL, + ha="center", + style="italic", + ) + + save_fig(fig, "q20_sliding_sla.png") + + +# ============================================================ +# 6. Session window — user sessions +# ============================================================ +def gen_session_users() -> None: + """Gen session users.""" + fig, axes = plt.subplots(2, 1, figsize=(10, 5)) + fig.suptitle( + "Session Window — sesje użytkowników (gap = 30 min)", + fontsize=FS_TITLE, + fontweight="bold", + ) + + # Anna: 2 sessions + ax = axes[0] + ax.set_xlim(0, 14) + ax.set_ylim(0, 3.5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Użytkownik Anna", fontsize=FS_LABEL, fontweight="bold") + + ax.annotate( + "", + xy=(13.5, 1.0), + xytext=(0.3, 1.0), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + + # Clicks cluster 1 + for x in [1.0, 1.8, 2.5, 3.2]: + ax.plot(x, 1.0, "ko", markersize=6) + # Clicks cluster 2 + for x in [9.0, 9.8, 10.5]: + ax.plot(x, 1.0, "ko", markersize=6) + + # Sessions + rect1 = mpatches.FancyBboxPatch( + (0.7, 1.5), + 2.8, + 1.2, + boxstyle="round,pad=0.1", + facecolor=GRAY1, + edgecolor=LN, + lw=1.5, + ) + ax.add_patch(rect1) + ax.text( + 2.1, + 2.1, + "Sesja 1\n4 kliknięcia, 12 min", + fontsize=FS, + ha="center", + fontweight="bold", + ) + + rect2 = mpatches.FancyBboxPatch( + (8.7, 1.5), + 2.1, + 1.2, + boxstyle="round,pad=0.1", + facecolor=GRAY3, + edgecolor=LN, + lw=1.5, + ) + ax.add_patch(rect2) + ax.text( + 9.75, + 2.1, + "Sesja 2\n3 kliknięcia, 8 min", + fontsize=FS, + ha="center", + fontweight="bold", + ) + + # Gap + ax.annotate( + "", + xy=(8.5, 0.5), + xytext=(3.8, 0.5), + arrowprops={"arrowstyle": "<->", "lw": 1.5, "color": LN}, + ) + ax.text( + 6.15, + 0.1, + "cisza 45 min > gap(30)", + fontsize=FS, + ha="center", + fontweight="bold", + style="italic", + ) + + # Bob: 1 session + ax = axes[1] + ax.set_xlim(0, 14) + ax.set_ylim(0, 3.5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Użytkownik Bob", fontsize=FS_LABEL, fontweight="bold") + + ax.annotate( + "", + xy=(13.5, 1.0), + xytext=(0.3, 1.0), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + + # Clicks spread evenly + bobs = [1.0, 2.5, 4.0, 5.5, 7.0, 8.5, 10.0] + for x in bobs: + ax.plot(x, 1.0, "ko", markersize=6) + + rect = mpatches.FancyBboxPatch( + (0.7, 1.5), + 9.6, + 1.2, + boxstyle="round,pad=0.1", + facecolor=GRAY1, + edgecolor=LN, + lw=2, + ) + ax.add_patch(rect) + ax.text( + 5.5, + 2.1, + "Sesja 1 (ciągła) — 7 kliknięć, każde < 30 min od poprzedniego", + fontsize=FS, + ha="center", + fontweight="bold", + ) + + fig.tight_layout(rect=[0, 0, 1, 0.92]) + save_fig(fig, "q20_session_users.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q23_architectures.py b/python_pkg/praca_magisterska_video/generate_images/_q23_architectures.py new file mode 100644 index 0000000..01ae8ed --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q23_architectures.py @@ -0,0 +1,467 @@ +"""FCN and U-Net architecture diagram generators.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from _q23_common import ( + ACCENT, + ACCENT_LIGHT, + BLACK, + FS, + FS_SMALL, + FS_TINY, + FS_TITLE, + GRAY1, + GRAY2, + GRAY5, + GREEN_ACCENT, + RED_ACCENT, + _save_figure, + plt, +) +from matplotlib.patches import FancyBboxPatch + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def generate_fcn() -> None: + """Generate fcn.""" + _fig, axes = plt.subplots(2, 1, figsize=(10, 7)) + + # --- Panel 1: FC vs Conv 1x1 --- + ax = axes[0] + ax.set_xlim(0, 20) + ax.set_ylim(0, 6) + ax.axis("off") + ax.set_title( + "FC (Fully Connected) vs Conv 1x1", fontsize=FS_TITLE, fontweight="bold" + ) + + # Classic CNN with FC + layer_info_fc = [ + (1.5, "Obraz\n224x224x3", 2.2, GRAY2), + (4.5, "Conv+Pool\n112x112x64", 1.8, GRAY2), + (7.5, "Conv+Pool\n7x7x512", 1.0, GRAY2), + (10, "Flatten\n25088", 0.5, ACCENT_LIGHT), + (12, "FC\n4096", 0.5, ACCENT_LIGHT), + (14, "FC\n1000", 0.3, ACCENT_LIGHT), + (16, '"Kot"', 0.3, "#FFCDD2"), + ] + + y_fc = 4.5 + for i, (x, label, w, color) in enumerate(layer_info_fc): + rect = FancyBboxPatch( + (x - w / 2, y_fc - 0.6), + w, + 1.2, + boxstyle="round,pad=0.05", + facecolor=color, + edgecolor=BLACK, + linewidth=0.8, + ) + ax.add_patch(rect) + ax.text(x, y_fc, label, ha="center", va="center", fontsize=FS_TINY) + if i < len(layer_info_fc) - 1: + next_x = layer_info_fc[i + 1][0] + ax.annotate( + "", + xy=(next_x - layer_info_fc[i + 1][2] / 2, y_fc), + xytext=(x + w / 2, y_fc), + arrowprops={"arrowstyle": "->", "color": GRAY5, "lw": 1}, + ) + + ax.text( + 0.3, y_fc, "CNN:", fontsize=FS, fontweight="bold", color=RED_ACCENT, va="center" + ) + ax.text( + 12, + y_fc + 1, + "PROBLEM: FC wymaga\nSTAŁEGO rozmiaru\n(np. 224x224)", + ha="center", + fontsize=FS_SMALL, + color=RED_ACCENT, + fontweight="bold", + bbox={ + "boxstyle": "round", + "facecolor": "#FFCDD2", + "edgecolor": RED_ACCENT, + "alpha": 0.3, + }, + ) + + # FCN with Conv 1x1 + layer_info_fcn = [ + (1.5, "Obraz\nHxWx3", 2.2, GRAY2), + (4.5, "Conv+Pool\nH/2 x W/2\nx64", 1.8, GRAY2), + (7.5, "Conv+Pool\nH/32 x W/32\nx512", 1.0, GRAY2), + (10.5, "Conv 1x1\nH/32 x W/32\nxC", 0.8, "#C8E6C9"), + (13.5, "Upsample\nHxWxC", 1.8, "#C8E6C9"), + (16.5, "Mapa\nsegmentacji", 1.5, "#C8E6C9"), + ] + + y_fcn = 1.5 + for i, (x, label, w, color) in enumerate(layer_info_fcn): + rect = FancyBboxPatch( + (x - w / 2, y_fcn - 0.7), + w, + 1.4, + boxstyle="round,pad=0.05", + facecolor=color, + edgecolor=BLACK, + linewidth=0.8, + ) + ax.add_patch(rect) + ax.text(x, y_fcn, label, ha="center", va="center", fontsize=FS_TINY) + if i < len(layer_info_fcn) - 1: + next_x = layer_info_fcn[i + 1][0] + ax.annotate( + "", + xy=(next_x - layer_info_fcn[i + 1][2] / 2, y_fcn), + xytext=(x + w / 2, y_fcn), + arrowprops={"arrowstyle": "->", "color": GRAY5, "lw": 1}, + ) + + ax.text( + 0.3, + y_fcn, + "FCN:", + fontsize=FS, + fontweight="bold", + color=GREEN_ACCENT, + va="center", + ) + ax.text( + 10.5, + y_fcn + 1.2, + "Conv 1x1:\nkażdy piksel\nosobno x wagi\n(jak FC ale\nzachowuje HxW)", + ha="center", + fontsize=FS_TINY, + color=GREEN_ACCENT, + bbox={ + "boxstyle": "round", + "facecolor": "#C8E6C9", + "edgecolor": GREEN_ACCENT, + "alpha": 0.3, + }, + ) + + # --- Panel 2: What FC and Conv do --- + ax = axes[1] + ax.set_xlim(0, 20) + ax.set_ylim(0, 6) + ax.axis("off") + ax.set_title( + "Co robi warstwa FC? Co robi konwolucja?", fontsize=FS_TITLE, fontweight="bold" + ) + + # FC explanation + rect = FancyBboxPatch( + (0.3, 3.2), + 9, + 2.5, + boxstyle="round,pad=0.15", + facecolor=ACCENT_LIGHT, + edgecolor=ACCENT, + linewidth=1, + ) + ax.add_patch(rect) + ax.text( + 4.8, 5.2, "Fully Connected (FC)", fontsize=FS, fontweight="bold", ha="center" + ) + ax.text( + 4.8, + 4.5, + "KAŻDY neuron połączony z KAŻDYM wejściem\n" + "25 088 wejść x 4 096 neuronów = ~103 MLN wag!\n" + "Traci informację GDZIE (przestrzenną)\n" + "Wymaga STAŁEGO rozmiaru wejścia", + fontsize=FS_TINY, + ha="center", + va="top", + ) + + # Conv explanation + rect = FancyBboxPatch( + (10.3, 3.2), + 9, + 2.5, + boxstyle="round,pad=0.15", + facecolor="#C8E6C9", + edgecolor=GREEN_ACCENT, + linewidth=1, + ) + ax.add_patch(rect) + ax.text(14.8, 5.2, "Konwolucja (Conv)", fontsize=FS, fontweight="bold", ha="center") + ax.text( + 14.8, + 4.5, + 'Filtr (np. 3x3) „jedzie" po obrazie\n' + "Te same wagi dla KAŻDEJ pozycji\n" + "Zachowuje informację GDZIE\n" + "Akceptuje DOWOLNY rozmiar wejścia", + fontsize=FS_TINY, + ha="center", + va="top", + ) + + # Conv 1x1 explanation + rect = FancyBboxPatch( + (3, 0.3), + 14, + 2.2, + boxstyle="round,pad=0.15", + facecolor=GRAY1, + edgecolor=BLACK, + linewidth=1, + ) + ax.add_patch(rect) + ax.text( + 10, + 2.1, + 'Conv 1x1 = „FC per piksel"', + fontsize=FS, + fontweight="bold", + ha="center", + ) + ax.text( + 10, + 1.5, + "Filtr 1x1: patrzy na JEDEN piksel, ale WSZYSTKIE kanały (512→C klas)\n" + "Działa jak FC ale zachowuje mapę HxW → każdy piksel osobno klasyfikowany\n" + "FCN: zamień FC na Conv1x1 → koniec z wymogiem stałego rozmiaru!", + fontsize=FS_TINY, + ha="center", + va="top", + ) + + _save_figure("q23_fc_vs_conv1x1.png") + + +def generate_unet() -> None: + """Generate unet.""" + _fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.set_xlim(-1, 21) + ax.set_ylim(-1, 12) + ax.axis("off") + ax.set_title( + "U-Net: architektura w kształcie litery U", + fontsize=FS_TITLE + 1, + fontweight="bold", + ) + + # Encoder layers (going DOWN-LEFT) + encoder_layers = [ + (2, 10, 2.5, 1.5, "572x572x1\n(wejście)", 64), + (2, 7.5, 2.2, 1.3, "284x284\nx64", 64), + (2, 5, 1.8, 1.1, "140x140\nx128", 128), + (2, 2.5, 1.5, 1.0, "68x68\nx256", 256), + ] + + # Bottleneck + bottleneck = (8, 0.5, 2.5, 1.2, "32x32x512\n(bottleneck)", 512) + + # Decoder layers (going UP-RIGHT) + decoder_layers = [ + (14, 2.5, 1.5, 1.0, "68x68\nx256", 256), + (14, 5, 1.8, 1.1, "140x140\nx128", 128), + (14, 7.5, 2.2, 1.3, "284x284\nx64", 64), + (14, 10, 2.5, 1.5, "572x572xC\n(mapa seg.)", "C"), + ] + + def draw_block( + ax: Axes, + x: float, + y: float, + w: float, + h: float, + label: str, + color: str, + ) -> None: + """Draw block.""" + rect = FancyBboxPatch( + (x - w / 2, y - h / 2), + w, + h, + boxstyle="round,pad=0.05", + facecolor=color, + edgecolor=BLACK, + linewidth=1.2, + ) + ax.add_patch(rect) + ax.text(x, y, label, ha="center", va="center", fontsize=FS_TINY) + + # Draw encoder + for x, y, w, h, label, _channels in encoder_layers: + draw_block(ax, x, y, w, h, label, ACCENT_LIGHT) + + # Draw arrows down (encoder) + for i in range(len(encoder_layers) - 1): + x1, y1 = encoder_layers[i][0], encoder_layers[i][1] - encoder_layers[i][3] / 2 + x2, y2 = ( + encoder_layers[i + 1][0], + encoder_layers[i + 1][1] + encoder_layers[i + 1][3] / 2, + ) + ax.annotate( + "", + xy=(x2, y2), + xytext=(x1, y1), + arrowprops={"arrowstyle": "->", "color": ACCENT, "lw": 2}, + ) + ax.text( + x1 - 1.7, + (y1 + y2) / 2, + "MaxPool\n2x2\n↓ zmniejsz", + fontsize=FS_TINY, + ha="center", + color=ACCENT, + fontweight="bold", + ) + + # Encoder to bottleneck + x1, y1 = encoder_layers[-1][0], encoder_layers[-1][1] - encoder_layers[-1][3] / 2 + draw_block( + ax, + bottleneck[0], + bottleneck[1], + bottleneck[2], + bottleneck[3], + bottleneck[4], + GRAY2, + ) + ax.annotate( + "", + xy=(bottleneck[0] - bottleneck[2] / 2, bottleneck[1] + bottleneck[3] / 2), + xytext=(x1, y1), + arrowprops={"arrowstyle": "->", "color": ACCENT, "lw": 2}, + ) + + # Bottleneck to decoder + ax.annotate( + "", + xy=( + decoder_layers[0][0] - decoder_layers[0][2] / 2, + decoder_layers[0][1] - decoder_layers[0][3] / 2, + ), + xytext=(bottleneck[0] + bottleneck[2] / 2, bottleneck[1] + bottleneck[3] / 2), + arrowprops={"arrowstyle": "->", "color": RED_ACCENT, "lw": 2}, + ) + + # Draw decoder + for x, y, w, h, label, channels in decoder_layers: + color = "#C8E6C9" if channels != "C" else "#A5D6A7" + draw_block(ax, x, y, w, h, label, color) + + # Draw arrows up (decoder) + for i in range(len(decoder_layers) - 1): + x1, y1 = decoder_layers[i][0], decoder_layers[i][1] + decoder_layers[i][3] / 2 + x2, y2 = ( + decoder_layers[i + 1][0], + decoder_layers[i + 1][1] - decoder_layers[i + 1][3] / 2, + ) + ax.annotate( + "", + xy=(x2, y2), + xytext=(x1, y1), + arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT, "lw": 2}, + ) + ax.text( + x1 + 2, + (y1 + y2) / 2, + "UpConv\n2x2\n↑ zwiększ", + fontsize=FS_TINY, + ha="center", + color=GREEN_ACCENT, + fontweight="bold", + ) + + # Skip connections (horizontal arrows) + for i in range(len(encoder_layers)): + enc = encoder_layers[i] + dec = decoder_layers[len(decoder_layers) - 1 - i] + ax.annotate( + "", + xy=(dec[0] - dec[2] / 2, dec[1]), + xytext=(enc[0] + enc[2] / 2, enc[1]), + arrowprops={ + "arrowstyle": "->", + "color": GRAY5, + "lw": 1.5, + "linestyle": "dashed", + }, + ) + mid_x = (enc[0] + enc[2] / 2 + dec[0] - dec[2] / 2) / 2 + ax.text( + mid_x, + enc[1] + 0.6, + "skip\n(concat)", + fontsize=FS_TINY, + ha="center", + color=GRAY5, + fontweight="bold", + ) + + # Labels + ax.text( + 0, + 11.5, + "ENCODER\n(↓ zmniejsza)", + fontsize=FS, + fontweight="bold", + color=ACCENT, + ha="center", + ) + ax.text( + 17, + 11.5, + "DECODER\n(↑ zwiększa)", + fontsize=FS, + fontweight="bold", + color=GREEN_ACCENT, + ha="center", + ) + ax.text( + 8, + -0.8, + 'Kształt litery „U": encoder schodzi ↓ → bottleneck na dnie → decoder wraca ↑', + fontsize=FS_SMALL, + ha="center", + color=GRAY5, + fontweight="bold", + ) + + # Concatenation explanation + rect = FancyBboxPatch( + (17.5, 3), + 3, + 5, + boxstyle="round,pad=0.15", + facecolor=GRAY1, + edgecolor=GRAY5, + linewidth=1, + linestyle="--", + ) + ax.add_patch(rect) + ax.text( + 19, 7.5, "Concatenation:", fontsize=FS_SMALL, ha="center", fontweight="bold" + ) + ax.text( + 19, + 6.5, + "Encoder: 64 kanały\nDecoder: 64 kanały\n→ concat → 128 kanałów\n\n" + "Jak sklejenie\ndwóch stosów\nkart:", + fontsize=FS_TINY, + ha="center", + ) + ax.text( + 19, + 3.7, + "[enc₁|enc₂|...|dec₁|dec₂|...]", + fontsize=FS_TINY - 1, + ha="center", + fontweight="bold", + color=ACCENT, + ) + + _save_figure("q23_unet_arch.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q23_common.py b/python_pkg/praca_magisterska_video/generate_images/_q23_common.py new file mode 100644 index 0000000..b425d71 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q23_common.py @@ -0,0 +1,96 @@ +"""Common utilities and constants for Q23 diagram generation. + +A4-compatible, monochrome-friendly (grays + one accent), 300 DPI. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib as mpl + +mpl.use("Agg") + +import matplotlib.pyplot as plt +import numpy as np + +if TYPE_CHECKING: + from matplotlib.axes import Axes + +_logger = logging.getLogger(__name__) + +rng = np.random.default_rng(42) + +DPI = 300 +OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") +Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + +# Color palette — monochrome-friendly +BLACK = "#000000" +WHITE = "#FFFFFF" +GRAY1 = "#F5F5F5" +GRAY2 = "#E0E0E0" +GRAY3 = "#BDBDBD" +GRAY4 = "#9E9E9E" +GRAY5 = "#757575" +GRAY6 = "#424242" +ACCENT = "#4A90D9" # single blue accent for highlights +ACCENT_LIGHT = "#B3D4FC" +RED_ACCENT = "#D32F2F" +GREEN_ACCENT = "#388E3C" + +FS = 9 +FS_TITLE = 11 +FS_SMALL = 7 +FS_TINY = 6 + +_RIDGE_X = 5 +_VALLEY2_END = 9 +_DARK_PIXEL_THRESHOLD = 100 +_GRID_LAST_IDX = 3 +_HIGHLIGHT_START = 3 +_HIGHLIGHT_END = 5 +_BRIGHT_THRESHOLD = 170 +_OTSU_THRESHOLD = 128 + + +def _save_figure(name: str) -> None: + """Save current figure and log.""" + plt.tight_layout() + plt.savefig( + str(Path(OUTPUT_DIR) / name), + dpi=DPI, + bbox_inches="tight", + facecolor="white", + ) + plt.close() + _logger.info(" ✓ %s", name) + + +def _render_text_lines( + ax: Axes, + lines: list[tuple[str, int, str, str]], + *, + x_pos: float = 0.5, + start_y: float, + y_step: float = 0.5, + y_empty_step: float = 0.2, +) -> None: + """Render a list of styled text lines on an axis.""" + y = start_y + for txt, size, color, weight in lines: + if txt == "": + y -= y_empty_step + continue + ax.text( + x_pos, + y, + txt, + fontsize=size, + color=color, + fontweight=weight, + va="top", + ) + y -= y_step diff --git a/python_pkg/praca_magisterska_video/generate_images/_q23_diy_unet.py b/python_pkg/praca_magisterska_video/generate_images/_q23_diy_unet.py new file mode 100644 index 0000000..f6c69d0 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q23_diy_unet.py @@ -0,0 +1,251 @@ +"""DIY U-Net step-by-step diagram generator.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from _q23_common import ( + ACCENT, + ACCENT_LIGHT, + BLACK, + FS, + FS_SMALL, + FS_TINY, + GRAY1, + GRAY3, + GRAY5, + GREEN_ACCENT, + WHITE, + _save_figure, + np, + plt, + rng, +) +from matplotlib.patches import FancyBboxPatch + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def _draw_unet_layer_stack( + ax: Axes, + layer_sizes: list[tuple[int, int]], + *, + face_color: str, + edge_color: str, + arrow_color: str, + arrow_label: str, + add_skip: bool = False, +) -> None: + """Draw encoder or decoder layer stack for DIY U-Net.""" + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + + y_pos = 8.5 + for i, (s, c) in enumerate(layer_sizes): + w = s / 64 * 4 + h = 0.8 + rect = FancyBboxPatch( + (5 - w / 2, y_pos), + w, + h, + boxstyle="round,pad=0.05", + facecolor=face_color, + edgecolor=edge_color, + linewidth=1, + ) + ax.add_patch(rect) + label = f"{s}x{s}x{c}" + if add_skip and i < len(layer_sizes) - 1: + label += " + skip!" + ax.text( + 5, + y_pos + h / 2, + label, + ha="center", + va="center", + fontsize=FS_SMALL, + fontweight="bold", + ) + if i < len(layer_sizes) - 1: + ax.annotate( + "", + xy=(5, y_pos - 0.3), + xytext=(5, y_pos), + arrowprops={ + "arrowstyle": "->", + "color": arrow_color, + "lw": 1.5, + }, + ) + ax.text( + 7, + y_pos - 0.15, + arrow_label, + fontsize=FS_TINY, + color=arrow_color, + ) + y_pos -= 2.2 + + +def _draw_unet_pseudocode(ax: Axes) -> None: + """Draw panel 6: U-Net pseudocode.""" + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title("Pseudokod U-Net", fontsize=FS, fontweight="bold") + + code_lines = [ + "# ENCODER", + "e1 = conv_block(input, 64) # 64x64", + "e2 = conv_block(pool(e1), 128) # 32x32", + "e3 = conv_block(pool(e2), 256) # 16x16", + "", + "# BOTTLENECK", + "b = conv_block(pool(e3), 512) # 8x8", + "", + "# DECODER + SKIP", + "d3 = conv_block(concat(", + " upconv(b), e3), 256) # 16x16", + "d2 = conv_block(concat(", + " upconv(d3), e2), 128) # 32x32", + "d1 = conv_block(concat(", + " upconv(d2), e1), 64) # 64x64", + "", + "output = conv_1x1(d1, n_classes)", + ] + for i, line in enumerate(code_lines): + txt_color = ( + ACCENT + if "concat" in line + else (GREEN_ACCENT if "output" in line else BLACK) + ) + ax.text( + 0.3, + 9.5 - i * 0.55, + line, + fontsize=FS_TINY, + fontfamily="monospace", + color=txt_color, + ) + + +def generate_diy_unet() -> None: + """Generate diy unet.""" + fig, axes = plt.subplots(2, 3, figsize=(11, 7)) + + size = 64 + + # Create synthetic image with two regions + img = np.ones((size, size, 3), dtype=np.uint8) * 200 # bright bg + # Dark region (object 1) + yy, xx = np.mgrid[:size, :size] + mask1 = ((xx - 20) ** 2 + (yy - 30) ** 2) < 12**2 + img[mask1] = [60, 60, 60] + # Medium region (object 2) + mask2 = ((xx - 45) ** 2 + (yy - 25) ** 2) < 8**2 + img[mask2] = [120, 120, 120] + + gt = np.zeros((size, size), dtype=np.uint8) + gt[mask1] = 1 # class 1 + gt[mask2] = 2 # class 2 + + # --- Panel 1: Input image --- + ax = axes[0, 0] + ax.imshow(img) + ax.set_title("Krok 1: obraz RGB\n64x64x3", fontsize=FS, fontweight="bold") + ax.axis("off") + + # --- Panel 2: Encoder shrinks --- + ax = axes[0, 1] + ax.set_title("Krok 2: Encoder ZMNIEJSZA", fontsize=FS, fontweight="bold") + _draw_unet_layer_stack( + ax, + [(64, 3), (32, 64), (16, 128), (8, 256)], + face_color=ACCENT_LIGHT, + edge_color=ACCENT, + arrow_color=ACCENT, + arrow_label="Conv+Pool", + ) + ax.text( + 5, + 0.3, + "Wyciąga cechy:\nkrawędzie → tekstury → obiekty", + ha="center", + fontsize=FS_TINY, + color=GRAY5, + ) + + # --- Panel 3: Bottleneck --- + ax = axes[0, 2] + # Show feature maps at bottleneck (abstract) + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title( + "Krok 3: Bottleneck\n(najbardziej abstrakcyjne cechy)", + fontsize=FS, + fontweight="bold", + ) + + # Show small abstract feature maps + for k in range(4): + small = rng.random((4, 4)) + ax_inset = fig.add_axes( + [0.68 + (k % 2) * 0.08, 0.72 - (k // 2) * 0.1, 0.06, 0.06] + ) + ax_inset.imshow(small, cmap="gray") + ax_inset.axis("off") + + ax.text( + 5, + 5, + '8x8x256\n\nMałe mapy, ale DUŻO kanałów\nKażdy kanał = jedna „cecha"\n' + '(np. kanał 42 = „wykrył koło"\n kanał 78 = „wykrył krawędź")\n\n' + "Wie CO jest na obrazie\nale nie wie GDZIE dokładnie", + ha="center", + va="center", + fontsize=FS_SMALL, + bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3}, + ) + + # --- Panel 4: Decoder enlarges --- + ax = axes[1, 0] + ax.set_title( + "Krok 4: Decoder ZWIĘKSZA\n(+ skip connections!)", + fontsize=FS, + fontweight="bold", + ) + _draw_unet_layer_stack( + ax, + [(8, 256), (16, 128), (32, 64), (64, 3)], + face_color="#C8E6C9", + edge_color=GREEN_ACCENT, + arrow_color=GREEN_ACCENT, + arrow_label="UpConv+Concat", + add_skip=True, + ) + ax.text( + 5, + 0.3, + "Odtwarza rozdzielczość:\nskip → przywraca krawędzie", + ha="center", + fontsize=FS_TINY, + color=GRAY5, + ) + + # --- Panel 5: Output segmentation map --- + ax = axes[1, 1] + cmap = plt.cm.colors.ListedColormap([WHITE, ACCENT_LIGHT, "#FFCDD2"]) + ax.imshow(gt, cmap=cmap, interpolation="nearest") + ax.set_title( + "Krok 5: mapa segmentacji\n64x64 (3 klasy)", fontsize=FS, fontweight="bold" + ) + ax.axis("off") + ax.text(20, -3, "Tło=0, obiekt A=1, obiekt B=2", fontsize=FS_TINY, ha="center") + + # --- Panel 6: Summary pseudocode --- + _draw_unet_pseudocode(axes[1, 2]) + + _save_figure("q23_diy_unet.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q23_mean_shift_ncuts.py b/python_pkg/praca_magisterska_video/generate_images/_q23_mean_shift_ncuts.py new file mode 100644 index 0000000..912113c --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q23_mean_shift_ncuts.py @@ -0,0 +1,380 @@ +"""Mean shift and normalized cuts diagram generators.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from _q23_common import ( + _DARK_PIXEL_THRESHOLD, + _GRID_LAST_IDX, + ACCENT, + ACCENT_LIGHT, + BLACK, + FS, + FS_SMALL, + FS_TINY, + FS_TITLE, + GRAY1, + GRAY3, + GRAY4, + GRAY5, + GRAY6, + GREEN_ACCENT, + RED_ACCENT, + _render_text_lines, + _save_figure, + np, + plt, + rng, +) +from matplotlib import patches + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def generate_mean_shift() -> None: + """Generate mean shift.""" + _fig, axes = plt.subplots(1, 3, figsize=(11, 4)) + + # --- Panel 1: Feature space concept --- + ax = axes[0] + # Three clusters in 2D feature space (brightness, x-position) + c1x = rng.normal(2, 0.5, 40) + c1y = rng.normal(2, 0.5, 40) + c2x = rng.normal(6, 0.6, 35) + c2y = rng.normal(7, 0.5, 35) + c3x = rng.normal(8, 0.4, 25) + c3y = rng.normal(3, 0.6, 25) + + ax.scatter(c1x, c1y, c=GRAY4, s=15, alpha=0.7, zorder=3) + ax.scatter(c2x, c2y, c=GRAY4, s=15, alpha=0.7, zorder=3) + ax.scatter(c3x, c3y, c=GRAY4, s=15, alpha=0.7, zorder=3) + + # Label peaks + ax.scatter([2], [2], c=RED_ACCENT, s=80, marker="*", zorder=5, label="Max gęstości") + ax.scatter([6], [7], c=RED_ACCENT, s=80, marker="*", zorder=5) + ax.scatter([8], [3], c=RED_ACCENT, s=80, marker="*", zorder=5) + + ax.set_xlabel("Cecha 1: jasność", fontsize=FS) + ax.set_ylabel("Cecha 2: pozycja x", fontsize=FS) + ax.set_title("Przestrzeń cech", fontsize=FS_TITLE, fontweight="bold") + for lx, ly, ltxt in [ + (2, 0.3, "Klaster 1\n(ciemne, lewo)"), + (6, 5.3, "Klaster 2\n(jasne, prawo)"), + (8, 1.3, "Klaster 3\n(jasne, dół)"), + ]: + ax.text(lx, ly, ltxt, ha="center", fontsize=FS_TINY, color=GRAY6) + ax.legend(fontsize=FS_SMALL, loc="upper left") + + # --- Panel 2: Kernel/window moving --- + ax = axes[1] + ax.scatter(c1x, c1y, c=ACCENT_LIGHT, s=15, alpha=0.7, zorder=3) + ax.scatter(c2x, c2y, c=GRAY3, s=15, alpha=0.7, zorder=3) + ax.scatter(c3x, c3y, c=GRAY3, s=15, alpha=0.7, zorder=3) + + # Show kernel movement + path_x = [4.5, 3.8, 3.0, 2.3, 2.05] + path_y = [4.0, 3.3, 2.7, 2.2, 2.03] + + for i, (px, py) in enumerate(zip(path_x, path_y, strict=False)): + alpha = 0.3 + 0.15 * i + circle = plt.Circle( + (px, py), + 1.2, + fill=False, + edgecolor=ACCENT, + linewidth=1.5, + linestyle="--" if i < len(path_x) - 1 else "-", + alpha=alpha, + ) + ax.add_patch(circle) + if i < len(path_x) - 1: + ax.annotate( + "", + xy=(path_x[i + 1], path_y[i + 1]), + xytext=(px, py), + arrowprops={"arrowstyle": "->", "color": RED_ACCENT, "lw": 1.5}, + ) + + ax.scatter([path_x[0]], [path_y[0]], c=ACCENT, s=50, marker="o", zorder=5) + ax.scatter([path_x[-1]], [path_y[-1]], c=RED_ACCENT, s=80, marker="*", zorder=5) + + ax.text( + 4.5, 5.2, "Start: losowy\npiksel", fontsize=FS_SMALL, ha="center", color=ACCENT + ) + ax.text( + 2.05, + 0.5, + "Koniec: max\ngęstości", + fontsize=FS_SMALL, + ha="center", + color=RED_ACCENT, + fontweight="bold", + ) + ax.text( + 7, + 8, + "Okno (jądro)\nprzesuwa się\ndo skupiska", + fontsize=FS_SMALL, + ha="center", + color=GRAY6, + bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3}, + ) + + ax.set_xlabel("Cecha 1", fontsize=FS) + ax.set_ylabel("Cecha 2", fontsize=FS) + ax.set_title("Jądro → max gęstości", fontsize=FS_TITLE, fontweight="bold") + ax.set_xlim(0, 10) + ax.set_ylim(0, 9) + + # --- Panel 3: Why no K parameter --- + ax = axes[2] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title("Dlaczego bez K?", fontsize=FS_TITLE, fontweight="bold") + + lines = [ + ("K-means wymaga:", FS, RED_ACCENT, "bold"), + (' „Podaj K=3 klastry"', FS_SMALL, "black", "normal"), + (" Problem: skąd wiesz ile klastrów?", FS_SMALL, GRAY5, "normal"), + ("", 0, "", ""), + ("Mean Shift NIE wymaga K:", FS, GREEN_ACCENT, "bold"), + (" Każdy piksel startuje → toczy się", FS_SMALL, "black", "normal"), + (" → trafia do najbliższego szczytu", FS_SMALL, "black", "normal"), + (" → ile szczytów = tyle segmentów", FS_SMALL, "black", "normal"), + (" → automatycznie!", FS_SMALL, GREEN_ACCENT, "bold"), + ("", 0, "", ""), + ("Parametr: bandwidth (szerokość okna)", FS, "black", "bold"), + (" Duże okno → mało segmentów", FS_SMALL, "black", "normal"), + (" Małe okno → dużo segmentów", FS_SMALL, "black", "normal"), + ("", 0, "", ""), + ("Okno = jądro (kernel):", FS, "black", "bold"), + (" Koło o promieniu h wokół punktu.", FS_SMALL, "black", "normal"), + (" Oblicz średnią pikseli W oknie.", FS_SMALL, "black", "normal"), + (" Przesuń okno na tę średnią.", FS_SMALL, "black", "normal"), + (" Powtórz aż się zatrzyma.", FS_SMALL, "black", "normal"), + ] + _render_text_lines(ax, lines, start_y=9.0) + + _save_figure("q23_mean_shift.png") + + +def _draw_ncuts_pixel_grid( + ax: Axes, + pixel_vals: np.ndarray, +) -> None: + """Draw 4x4 pixel grid with value labels and edge weights.""" + for i in range(4): + for j in range(4): + v = pixel_vals[i, j] + gray_val = v / 255.0 + str(gray_val) + rect = patches.Rectangle( + (j - 0.4, 3 - i - 0.4), + 0.8, + 0.8, + facecolor=(gray_val, gray_val, gray_val), + edgecolor=BLACK, + linewidth=0.8, + ) + ax.add_patch(rect) + text_color = "white" if v < _DARK_PIXEL_THRESHOLD else "black" + ax.text( + j, + 3 - i, + str(v), + ha="center", + va="center", + fontsize=FS_SMALL, + color=text_color, + fontweight="bold", + ) + + +def _draw_ncuts_edges( + ax: Axes, + pixel_vals: np.ndarray, +) -> None: + """Draw weighted edges between adjacent pixels.""" + for i in range(4): + for j in range(4): + if j < _GRID_LAST_IDX: + similarity = max( + 0, + 1 - abs(pixel_vals[i, j] - pixel_vals[i, j + 1]) / 255, + ) + lw = similarity * 2.5 + 0.3 + alpha = similarity * 0.8 + 0.2 + ax.plot( + [j + 0.4, j + 0.6], + [3 - i, 3 - i], + color=GRAY5, + linewidth=lw, + alpha=alpha, + ) + if i < _GRID_LAST_IDX: + similarity = max( + 0, + 1 - abs(pixel_vals[i, j] - pixel_vals[i + 1, j]) / 255, + ) + lw = similarity * 2.5 + 0.3 + alpha = similarity * 0.8 + 0.2 + ax.plot( + [j, j], + [3 - i - 0.4, 3 - i - 0.6], + color=GRAY5, + linewidth=lw, + alpha=alpha, + ) + + +def generate_normalized_cuts() -> None: + """Generate normalized cuts.""" + _fig, axes = plt.subplots(1, 3, figsize=(11, 4)) + + # --- Panel 1: Image as graph --- + ax = axes[0] + ax.set_xlim(-0.5, 4.5) + ax.set_ylim(-0.5, 4.5) + ax.set_aspect("equal") + ax.set_title("Obraz → graf", fontsize=FS_TITLE, fontweight="bold") + + pixel_vals = np.array( + [ + [30, 35, 180, 190], + [40, 30, 185, 200], + [170, 180, 40, 35], + [190, 175, 30, 45], + ] + ) + _draw_ncuts_pixel_grid(ax, pixel_vals) + _draw_ncuts_edges(ax, pixel_vals) + + ax.text( + 2, + -0.8, + "Grube linie = duże podobieństwo\n(silna krawędź grafu)", + ha="center", + fontsize=FS_TINY, + color=GRAY5, + ) + ax.axis("off") + + # --- Panel 2: Cut concept --- + ax = axes[1] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title("Cięcie grafu (graph cut)", fontsize=FS_TITLE, fontweight="bold") + + # Draw two groups of nodes + # Group A (dark pixels) + positions_a = [(2, 7), (3, 8), (2, 5), (3, 6)] + positions_b = [(7, 7), (8, 8), (7, 5), (8, 6)] + + # Intra-group edges (thick = similar) + for i, (x1, y1) in enumerate(positions_a): + for x2, y2 in positions_a[i + 1 :]: + ax.plot([x1, x2], [y1, y2], color=ACCENT, linewidth=2, alpha=0.5) + for i, (x1, y1) in enumerate(positions_b): + for x2, y2 in positions_b[i + 1 :]: + ax.plot([x1, x2], [y1, y2], color=RED_ACCENT, linewidth=2, alpha=0.5) + + # Inter-group edges (thin = dissimilar) — these get cut + cut_edges = [((3, 8), (7, 7)), ((3, 6), (7, 5)), ((2, 5), (7, 5))] + for (x1, y1), (x2, y2) in cut_edges: + ax.plot([x1, x2], [y1, y2], color=GRAY4, linewidth=0.8, linestyle="--") + + # Draw nodes + for x, y in positions_a: + ax.scatter(x, y, c=ACCENT, s=120, zorder=5, edgecolors=BLACK, linewidth=0.8) + for x, y in positions_b: + ax.scatter(x, y, c="#FFCDD2", s=120, zorder=5, edgecolors=BLACK, linewidth=0.8) + + # Cut line + ax.plot( + [5, 5], [3.5, 9.5], color=RED_ACCENT, linewidth=2.5, linestyle="-", zorder=4 + ) + ax.text( + 5, 9.8, "CIĘCIE", ha="center", fontsize=FS, fontweight="bold", color=RED_ACCENT + ) + + ax.text( + 2.5, + 3.8, + "Segment A\n(ciemne piksele)", + ha="center", + fontsize=FS_SMALL, + color=ACCENT, + ) + ax.text( + 7.5, + 3.8, + "Segment B\n(jasne piksele)", + ha="center", + fontsize=FS_SMALL, + color=RED_ACCENT, + ) + + # Formula + ax.text( + 5, + 1.8, + "Ncut(A,B) = cut(A,B)/assoc(A,V)\n + cut(A,B)/assoc(B,V)", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3}, + ) + ax.text( + 5, + 0.5, + "Minimalizuj Ncut → tnij SŁABE krawędzie\nzachowuj SILNE (wewnątrz grupy)", + ha="center", + fontsize=FS_TINY, + color=GRAY5, + ) + + # --- Panel 3: Algorithm summary --- + ax = axes[2] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title("Algorytm Normalized Cuts", fontsize=FS_TITLE, fontweight="bold") + + steps = [ + ( + "1. Zbuduj graf", + "Piksele = węzły\nKrawędzie = podobieństwo" + " sąsiadów\n(kolor, jasność, odległość)", + ), + ( + "2. Macierz podobieństwa W", + "W[i,j] = exp(-|kolori - kolorj|² / σ²)" + "\n→ im podobniejsze, tym wyższa waga", + ), + ("3. Macierz stopni D", "D[i,i] = Σ W[i,j]\n(suma wszystkich wag z węzła i)"), + ("4. Rozwiąż problem własny", "(D-W)·y = λ·D·y\n→ drugi najm. wektor własny y"), + ("5. Podziel wg y", "y[i] > 0 → segment A\ny[i] ≤ 0 → segment B"), + ] + + y = 9.5 + for title, desc in steps: + ax.text(0.5, y, title, fontsize=FS, fontweight="bold", va="top") + y -= 0.4 + ax.text(0.8, y, desc, fontsize=FS_TINY, va="top", color=GRAY6) + y -= 1.2 + + ax.text( + 5, + 0.3, + "Złożoność: O(n³) — wymaga eigen decomposition!", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + color=RED_ACCENT, + ) + + _save_figure("q23_normalized_cuts.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q23_mnemonics.py b/python_pkg/praca_magisterska_video/generate_images/_q23_mnemonics.py new file mode 100644 index 0000000..9f340a2 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q23_mnemonics.py @@ -0,0 +1,327 @@ +"""Mnemonic summary diagram generator.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from _q23_common import ( + ACCENT, + ACCENT_LIGHT, + BLACK, + FS, + FS_SMALL, + FS_TINY, + FS_TITLE, + GRAY1, + GRAY5, + GRAY6, + GREEN_ACCENT, + RED_ACCENT, + _save_figure, + plt, +) +from matplotlib.patches import FancyBboxPatch + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def generate_mnemonics() -> None: + """Generate mnemonics.""" + _fig, ax = plt.subplots(1, 1, figsize=(10, 8)) + ax.set_xlim(0, 20) + ax.set_ylim(0, 16) + ax.axis("off") + ax.set_title( + "Mnemoniki — segmentacja obrazu", fontsize=FS_TITLE + 2, fontweight="bold" + ) + + def draw_card( + ax: Axes, + x: float, + y: float, + w: float, + h: float, + title: str, + mnemonic: str, + color: str, + detail: str = "", + ) -> None: + """Draw card.""" + rect = FancyBboxPatch( + (x, y), + w, + h, + boxstyle="round,pad=0.15", + facecolor=color, + edgecolor=BLACK, + linewidth=1, + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + h - 0.3, + title, + ha="center", + va="top", + fontsize=FS, + fontweight="bold", + ) + ax.text( + x + w / 2, + y + h / 2 - 0.1, + mnemonic, + ha="center", + va="center", + fontsize=FS_SMALL, + fontstyle="italic", + color=GRAY6, + ) + if detail: + ax.text( + x + w / 2, + y + 0.4, + detail, + ha="center", + va="bottom", + fontsize=FS_TINY, + color=GRAY5, + ) + + # Title: STRATEGIE KLASYCZNE + ax.text( + 5, + 15.5, + "STRATEGIE KLASYCZNE", + fontsize=FS_TITLE, + fontweight="bold", + color=ACCENT, + ha="center", + ) + + cards_classic = [ + ( + 0.2, + 12.5, + 4.5, + 2.5, + "Thresholding", + '„PRÓG na bramce"\nPrzepuszcza > T,\nblokuje ≤ T', + ACCENT_LIGHT, + "jasne=1, ciemne=0", + ), + ( + 5, + 12.5, + 4.5, + 2.5, + "Otsu", + '„AUTO-bramkarz"\nSam dobiera próg\nmin σ² wewnątrz', + ACCENT_LIGHT, + "histogram bimodalny", + ), + ( + 0.2, + 9.5, + 4.5, + 2.5, + "Region Growing", + '„PLAMA rozlana"\nSeed → BFS po\npodobnych sąsiadach', + ACCENT_LIGHT, + "jak atrament na papierze", + ), + ( + 5, + 9.5, + 4.5, + 2.5, + "Watershed", + '„ZALEWANIE terenu"\nDoliny=obiekty\nGranie=granice', + ACCENT_LIGHT, + "woda + geography", + ), + ( + 0.2, + 6.5, + 4.5, + 2.5, + "Mean Shift", + '„KULKI toczą się"\nKażda → max gęstości\nBez K!', + ACCENT_LIGHT, + "bandwidth = okno", + ), + ( + 5, + 6.5, + 4.5, + 2.5, + "Normalized Cuts", + '„CIĘCIE sznurków"\nGraf: tnij słabe\nkrawędzie (O(n³)!)', + ACCENT_LIGHT, + "eigenvector problem", + ), + ] + + for args in cards_classic: + draw_card(ax, *args) + + # Title: SIECI NEURONOWE + ax.text( + 15, + 15.5, + "SIECI NEURONOWE", + fontsize=FS_TITLE, + fontweight="bold", + color=GREEN_ACCENT, + ha="center", + ) + + cards_nn = [ + ( + 10.5, + 12.5, + 4.5, + 2.5, + "FCN (2015)", + '„FC → Conv 1x1"\nPierwsza end-to-end\nDowolny rozmiar', + "#C8E6C9", + "skip connections", + ), + ( + 15.3, + 12.5, + 4.5, + 2.5, + "U-Net (2015)", + '„Litera U"\nEncoder↓ Decoder↑\nSkip = concat', + "#C8E6C9", + "medycyna, małe dane", + ), + ( + 10.5, + 9.5, + 4.5, + 2.5, + "DeepLab v3+", + '„DZIURY w filtrze"\nAtrous conv (rate)\nASPP multi-scale', + "#C8E6C9", + "à trous = z dziurami", + ), + ( + 15.3, + 9.5, + 4.5, + 2.5, + "Transformer", + '„WSZYSCY ze\nWSZYSTKIMI"\nSelf-attention O(n²)', + "#C8E6C9", + "SegFormer, Mask2Former", + ), + ] + + for args in cards_nn: + draw_card(ax, *args) + + # Metryki + ax.text( + 10, + 8.3, + "METRYKI I LOSS", + fontsize=FS_TITLE, + fontweight="bold", + color=RED_ACCENT, + ha="center", + ) + + cards_metrics = [ + ( + 10.5, + 6.5, + 4.5, + 1.6, + "mIoU", + '„Nakładka / Suma"\nIoU = A∩B / A\u222aB', + "#FFCDD2", + "", + ), + ( + 15.3, + 6.5, + 4.5, + 1.6, + "Dice / Focal", + '„Dice=2·nakładka"\nFocal=trudne px', + "#FFCDD2", + "", + ), + ] + + for args in cards_metrics: + draw_card(ax, *args) + + # Master mnemonic at bottom + rect = FancyBboxPatch( + (1, 0.3), + 18, + 5.5, + boxstyle="round,pad=0.2", + facecolor=GRAY1, + edgecolor=BLACK, + linewidth=1.5, + ) + ax.add_patch(rect) + ax.text( + 10, + 5.3, + "SUPER-MNEMONIK: kolejność algorytmów segmentacji", + ha="center", + fontsize=FS, + fontweight="bold", + ) + ax.text( + 10, + 4.5, + '„TORW-MN FUD-T"', + ha="center", + fontsize=FS_TITLE + 2, + fontweight="bold", + color=RED_ACCENT, + ) + ax.text( + 10, + 3.5, + "Klasyczne: Thresholding → Otsu → Region" + " growing → Watershed → Mean shift → Norm. cuts", + ha="center", + fontsize=FS_SMALL, + ) + ax.text( + 10, + 2.8, + "Neuronowe: FCN → U-Net → DeepLab → Transformer", + ha="center", + fontsize=FS_SMALL, + ) + ax.text( + 10, + 1.8, + "„Turyści Oglądają Rzekę, Wodospad," + " Morze, Nurt — Fotografują Uroczy" + ' Dwór Tajemnic"', + ha="center", + fontsize=FS_SMALL, + fontstyle="italic", + color=ACCENT, + ) + ax.text( + 10, + 1.0, + "Klasyczne: proste→auto→BFS→flood→" + "gęstość→graf | Neuronowe:" + " FC→U-skip→dilated→attention", + ha="center", + fontsize=FS_TINY, + color=GRAY5, + ) + + _save_figure("q23_mnemonics.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q23_nn_basics.py b/python_pkg/praca_magisterska_video/generate_images/_q23_nn_basics.py new file mode 100644 index 0000000..1d737ea --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q23_nn_basics.py @@ -0,0 +1,293 @@ +"""ReLU and dot product diagram generators.""" + +from __future__ import annotations + +from _q23_common import ( + ACCENT, + ACCENT_LIGHT, + BLACK, + FS, + FS_SMALL, + FS_TINY, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY5, + GREEN_ACCENT, + RED_ACCENT, + _save_figure, + np, + plt, +) +from matplotlib import patches + + +def generate_relu() -> None: + """Generate relu.""" + _fig, axes = plt.subplots(1, 2, figsize=(8, 3.5)) + + # --- Panel 1: ReLU plot --- + ax = axes[0] + x = np.linspace(-5, 5, 200) + relu = np.maximum(0, x) + ax.plot(x, relu, color=ACCENT, linewidth=2.5, label="ReLU(x) = max(0, x)") + ax.axhline(y=0, color=GRAY3, linewidth=0.5) + ax.axvline(x=0, color=GRAY3, linewidth=0.5) + ax.fill_between(x[x < 0], 0, 0, color=RED_ACCENT, alpha=0.1) + ax.fill_between(x[x >= 0], 0, relu[x >= 0], color=ACCENT, alpha=0.1) + + # Annotations + ax.annotate( + 'x < 0 → output = 0\n(neuron „wyłączony")', + xy=(-3, 0), + fontsize=FS_SMALL, + ha="center", + va="bottom", + color=RED_ACCENT, + arrowprops={"arrowstyle": "->", "color": RED_ACCENT}, + xytext=(-3, 2), + ) + ax.annotate( + 'x ≥ 0 → output = x\n(neuron „włączony")', + xy=(3, 3), + fontsize=FS_SMALL, + ha="center", + va="bottom", + color=ACCENT, + arrowprops={"arrowstyle": "->", "color": ACCENT}, + xytext=(3, 4.5), + ) + ax.scatter([0], [0], c=BLACK, s=40, zorder=5) + ax.text(0.3, -0.5, "(0,0)", fontsize=FS_SMALL, color=GRAY5) + ax.set_xlabel("x (wejście neuronu)", fontsize=FS) + ax.set_ylabel("ReLU(x)", fontsize=FS) + ax.set_title("ReLU — Rectified Linear Unit", fontsize=FS_TITLE, fontweight="bold") + ax.legend(fontsize=FS_SMALL, loc="upper left") + ax.set_ylim(-1, 6) + ax.grid(visible=True, alpha=0.2) + + # --- Panel 2: Why ReLU --- + ax = axes[1] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title("Dlaczego ReLU?", fontsize=FS_TITLE, fontweight="bold") + + y = 9.0 + lines = [ + ("Neuron oblicza:", FS, BLACK, "bold"), + (" z = w₁·x₁ + w₂·x₂ + ... + bias", FS_SMALL, BLACK, "normal"), + (" output = ReLU(z) = max(0, z)", FS_SMALL, ACCENT, "bold"), + ("", 0, "", ""), + ("Przykład:", FS, BLACK, "bold"), + (" wagi: w₁=0.5, w₂=-0.3, bias=0.1", FS_SMALL, BLACK, "normal"), + (" wejścia: x₁=2.0, x₂=4.0", FS_SMALL, BLACK, "normal"), + (" z = 0.5·2 + (-0.3)·4 + 0.1 = -0.1", FS_SMALL, BLACK, "normal"), + (" ReLU(-0.1) = max(0, -0.1) = 0", FS_SMALL, RED_ACCENT, "bold"), + (" → neuron milczy (wejście nieistotne)", FS_SMALL, GRAY5, "normal"), + ("", 0, "", ""), + ("Gdyby z = 2.3:", FS, BLACK, "bold"), + (" ReLU(2.3) = max(0, 2.3) = 2.3", FS_SMALL, GREEN_ACCENT, "bold"), + (" → neuron aktywny! Przekazuje sygnał", FS_SMALL, GRAY5, "normal"), + ("", 0, "", ""), + ("Szybsza niż sigmoid/tanh", FS_SMALL, GRAY5, "normal"), + ("(brak exp() → szybkie obliczenia)", FS_SMALL, GRAY5, "normal"), + ] + for txt, size, color, weight in lines: + if txt == "": + y -= 0.2 + continue + ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top") + y -= 0.5 + + _save_figure("q23_relu.png") + + +def generate_dot_product() -> None: + """Generate dot product.""" + _fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) + + # --- Panel 1: Concept --- + ax = axes[0] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title( + "Iloczyn skalarny\n(dot product)", fontsize=FS_TITLE, fontweight="bold" + ) + + y = 8.5 + lines = [ + ("Dwa wektory (listy liczb) → JEDNA liczba", FS, BLACK, "bold"), + ("", 0, "", ""), + ("a = [a₁, a₂, a₃] b = [b₁, b₂, b₃]", FS, ACCENT, "normal"), + ("", 0, "", ""), + ("a · b = a₁·b₁ + a₂·b₂ + a₃·b₃", FS, BLACK, "bold"), + ("", 0, "", ""), + ("Przykład:", FS, BLACK, "bold"), + ("a = [1, 3, -2] b = [4, -1, 5]", FS_SMALL, BLACK, "normal"), + ("a·b = 1·4 + 3·(-1) + (-2)·5", FS_SMALL, BLACK, "normal"), + (" = 4 + (-3) + (-10) = -9", FS_SMALL, RED_ACCENT, "bold"), + ("", 0, "", ""), + ( + 'Duży wynik → wektory „podobne" (w tym samym kierunku)', + FS_SMALL, + GREEN_ACCENT, + "normal", + ), + ('Mały/ujemny → wektory „różne"', FS_SMALL, RED_ACCENT, "normal"), + ] + for txt, size, color, weight in lines: + if txt == "": + y -= 0.25 + continue + ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top") + y -= 0.55 + + # --- Panel 2: Convolution as dot product --- + ax = axes[1] + ax.set_xlim(-0.5, 5.5) + ax.set_ylim(-0.5, 5.5) + ax.set_aspect("equal") + ax.set_title( + "Konwolucja = iloczyn skalarny\nfiltra x fragment obrazu", + fontsize=FS_TITLE, + fontweight="bold", + ) + + # Filter 3x3 + filter_vals = [[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]] + for i in range(3): + for j in range(3): + rect = patches.Rectangle( + (j - 0.4, 4 - i - 0.4), + 0.8, + 0.8, + facecolor=ACCENT_LIGHT, + edgecolor=BLACK, + linewidth=0.8, + ) + ax.add_patch(rect) + ax.text( + j, + 4 - i, + str(filter_vals[i][j]), + ha="center", + va="center", + fontsize=FS, + fontweight="bold", + ) + + ax.text(1, 1.5, "Filtr", ha="center", fontsize=FS, fontweight="bold", color=ACCENT) + + # Image patch + img_vals = [[50, 50, 200], [50, 50, 200], [50, 50, 200]] + for i in range(3): + for j in range(3): + rect = patches.Rectangle( + (j + 2.6, 4 - i - 0.4), + 0.8, + 0.8, + facecolor=GRAY2, + edgecolor=BLACK, + linewidth=0.8, + ) + ax.add_patch(rect) + ax.text( + j + 3, + 4 - i, + str(img_vals[i][j]), + ha="center", + va="center", + fontsize=FS, + fontweight="bold", + ) + + ax.text( + 4, + 1.5, + "Fragment\nobrazu", + ha="center", + fontsize=FS, + fontweight="bold", + color=GRAY5, + ) + + ax.text( + 2.5, + 0.5, + "(-1)·50 + 0·50 + 1·200 +\n" + "(-1)·50 + 0·50 + 1·200 +\n" + "(-1)·50 + 0·50 + 1·200\n= 450 (krawędź!)", + ha="center", + fontsize=FS_TINY, + fontweight="bold", + bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GREEN_ACCENT}, + ) + + ax.axis("off") + + # --- Panel 3: Vector visualization --- + ax = axes[2] + # Draw two vectors + ax.quiver( + 0, + 0, + 3, + 4, + angles="xy", + scale_units="xy", + scale=1, + color=ACCENT, + width=0.025, + label="a = [3, 4]", + ) + ax.quiver( + 0, + 0, + 4, + 1, + angles="xy", + scale_units="xy", + scale=1, + color=RED_ACCENT, + width=0.025, + label="b = [4, 1]", + ) + + # Show angle + theta = np.linspace(np.arctan2(1, 4), np.arctan2(4, 3), 30) + r = 1.5 + ax.plot(r * np.cos(theta), r * np.sin(theta), color=GREEN_ACCENT, linewidth=1.5) + ax.text(1.8, 1.3, "θ", fontsize=FS, color=GREEN_ACCENT, fontweight="bold") + + ax.text(3.2, 4.2, "a", fontsize=FS, color=ACCENT, fontweight="bold") + ax.text(4.2, 1.2, "b", fontsize=FS, color=RED_ACCENT, fontweight="bold") + + ax.text( + 2.5, + -1.0, + "a · b = |a|·|b|·cos(θ)\n= 3·4 + 4·1 = 16", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3}, + ) + ax.text( + 2.5, + -2.0, + 'Mały kąt θ → duży dot product\n= wektory „zgadają się"', + ha="center", + fontsize=FS_TINY, + color=GRAY5, + ) + + ax.set_xlim(-0.5, 5.5) + ax.set_ylim(-2.5, 5.5) + ax.set_aspect("equal") + ax.grid(visible=True, alpha=0.2) + ax.legend(fontsize=FS_SMALL, loc="upper left") + ax.set_title("Geometrycznie: kąt", fontsize=FS_TITLE, fontweight="bold") + + _save_figure("q23_dot_product.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q23_otsu_watershed.py b/python_pkg/praca_magisterska_video/generate_images/_q23_otsu_watershed.py new file mode 100644 index 0000000..ac44fd1 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q23_otsu_watershed.py @@ -0,0 +1,408 @@ +"""Otsu thresholding and watershed diagram generators.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from _q23_common import ( + _RIDGE_X, + _VALLEY2_END, + ACCENT, + ACCENT_LIGHT, + BLACK, + FS, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + GREEN_ACCENT, + RED_ACCENT, + _render_text_lines, + _save_figure, + np, + plt, + rng, +) +from matplotlib.patches import FancyBboxPatch + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def _draw_otsu_variance_panel(ax: Axes) -> None: + """Draw panel 2: within-class variance explanation.""" + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title("Wariancja wewnątrzklasowa", fontsize=FS_TITLE, fontweight="bold") + + texts = [ + ( + "Wariancja = jak bardzo wartości\nróżnią się od średniej", + FS, + "black", + "normal", + ), + ("", 0, "black", "normal"), + ("Klasa 0 (piksele ≤ T):", FS, ACCENT, "bold"), + (" wartości: 30, 50, 45, 60, 55", FS_SMALL, "black", "normal"), + (" średnia μ₀ = 48", FS_SMALL, "black", "normal"), + (" σ₀² = ((30-48)²+(50-48)²+...)/5 = 108", FS_SMALL, "black", "normal"), + ("", 0, "black", "normal"), + ("Klasa 1 (piksele > T):", FS, RED_ACCENT, "bold"), + (" wartości: 180, 200, 190, 210, 195", FS_SMALL, "black", "normal"), + (" średnia μ₁ = 195", FS_SMALL, "black", "normal"), + (" σ₁² = ((180-195)²+...)/5 = 100", FS_SMALL, "black", "normal"), + ("", 0, "black", "normal"), + ("σ²_wewnątrz = w₀·σ₀² + w₁·σ₁²", FS, BLACK, "bold"), + ("= 0.6·108 + 0.4·100 = 104.8", FS_SMALL, "black", "normal"), + ("", 0, "black", "normal"), + ("Otsu próbuje KAŻDE T: 0,1,...,255", FS_SMALL, GREEN_ACCENT, "bold"), + ("Wybiera T dające MINIMUM σ²_wewnątrz", FS_SMALL, GREEN_ACCENT, "bold"), + ] + _render_text_lines( + ax, + texts, + x_pos=0.3, + start_y=9.2, + y_step=0.55, + y_empty_step=0.25, + ) + + +def generate_otsu_bimodal() -> None: + """Generate otsu bimodal.""" + _fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) + + # --- Panel 1: Bimodal histogram --- + ax = axes[0] + dark = rng.normal(60, 20, 3000).clip(0, 255) + bright = rng.normal(190, 25, 2000).clip(0, 255) + all_pixels = np.concatenate([dark, bright]) + + counts, _bins, _bars = ax.hist( + all_pixels, bins=64, color=GRAY3, edgecolor=GRAY5, linewidth=0.5 + ) + ax.axvline( + x=128, color=RED_ACCENT, linewidth=2, linestyle="--", label="Próg Otsu T=128" + ) + ax.fill_betweenx([0, max(counts) * 1.1], 0, 128, alpha=0.12, color=ACCENT) + ax.fill_betweenx([0, max(counts) * 1.1], 128, 255, alpha=0.12, color=RED_ACCENT) + ax.text( + 45, + max(counts) * 0.85, + "Klasa 0\n(tło)", + ha="center", + fontsize=FS, + fontweight="bold", + color=ACCENT, + ) + ax.text( + 195, + max(counts) * 0.85, + "Klasa 1\n(obiekt)", + ha="center", + fontsize=FS, + fontweight="bold", + color=RED_ACCENT, + ) + ax.annotate( + "Garb 1", + xy=(60, max(counts) * 0.6), + fontsize=FS_SMALL, + ha="center", + arrowprops={"arrowstyle": "->", "color": GRAY5}, + xytext=(30, max(counts) * 0.45), + ) + ax.annotate( + "Garb 2", + xy=(190, max(counts) * 0.5), + fontsize=FS_SMALL, + ha="center", + arrowprops={"arrowstyle": "->", "color": GRAY5}, + xytext=(220, max(counts) * 0.35), + ) + ax.set_xlabel("Jasność piksela (0-255)", fontsize=FS) + ax.set_ylabel("Liczba pikseli", fontsize=FS) + ax.set_title("Histogram bimodalny", fontsize=FS_TITLE, fontweight="bold") + ax.legend(fontsize=FS_SMALL, loc="upper right") + ax.set_xlim(0, 255) + + # --- Panel 2: Within-class variance explanation --- + _draw_otsu_variance_panel(axes[1]) + + # --- Panel 3: Jednorodność explanation --- + ax = axes[2] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title('"Jednorodne" = małe σ²', fontsize=FS_TITLE, fontweight="bold") + + # Draw two clusters + # Good separation + c0 = rng.normal(2, 0.4, 15) + c1 = rng.normal(7, 0.4, 15) + y_pos_0 = rng.uniform(6, 8, 15) + y_pos_1 = rng.uniform(6, 8, 15) + ax.scatter(c0, y_pos_0, c=ACCENT, s=30, zorder=5, label="Klasa 0") + ax.scatter(c1, y_pos_1, c=RED_ACCENT, s=30, zorder=5, label="Klasa 1") + ax.axvline(x=4.5, color=GREEN_ACCENT, linewidth=2, linestyle="--") + ax.text( + 4.5, + 8.8, + "T optymalny", + ha="center", + fontsize=FS_SMALL, + color=GREEN_ACCENT, + fontweight="bold", + ) + ax.text( + 2, 5.3, "σ₀² mała\n(skupione)", ha="center", fontsize=FS_SMALL, color=ACCENT + ) + ax.text( + 7, 5.3, "σ₁² mała\n(skupione)", ha="center", fontsize=FS_SMALL, color=RED_ACCENT + ) + ax.text( + 5, + 4, + "→ σ²_wewnątrz MINIMALNA\n→ klasy JEDNORODNE\n→ dobra segmentacja!", + ha="center", + fontsize=FS, + fontweight="bold", + color=GREEN_ACCENT, + ) + + # Bad separation + c0b = rng.normal(3.5, 1.5, 15) + c1b = rng.normal(6, 1.5, 15) + y_pos_0b = rng.uniform(1, 3, 15) + y_pos_1b = rng.uniform(1, 3, 15) + ax.scatter(c0b, y_pos_0b, c=ACCENT, s=30, marker="x", zorder=5) + ax.scatter(c1b, y_pos_1b, c=RED_ACCENT, s=30, marker="x", zorder=5) + ax.axvline(x=4.5, color=GRAY4, linewidth=1, linestyle=":", ymin=0, ymax=0.35) + ax.text( + 5, + 0.3, + "σ²_wewnątrz DUŻA → klasy mieszają się → zły próg", + ha="center", + fontsize=FS_SMALL, + color=GRAY5, + ) + + ax.legend(fontsize=FS_SMALL, loc="upper left") + + _save_figure("q23_otsu_bimodal.png") + + +def _draw_watershed_result_panel(ax: Axes) -> None: + """Draw panel 3: watershed result with over-segmentation problem.""" + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title("Krok 3: wynik", fontsize=FS_TITLE, fontweight="bold") + + rect1 = FancyBboxPatch( + (0.5, 6), + 3.5, + 3.2, + boxstyle="round,pad=0.1", + facecolor=ACCENT_LIGHT, + edgecolor=BLACK, + linewidth=1, + ) + ax.add_patch(rect1) + ax.text(2.25, 8.8, "Ideał: 2 segmenty", fontsize=FS, ha="center", fontweight="bold") + ax.text(2.25, 7.5, "Segment A Segment B", fontsize=FS_SMALL, ha="center") + ax.text( + 2.25, + 6.7, + "(po marker-controlled)", + fontsize=FS_SMALL, + ha="center", + color=GREEN_ACCENT, + ) + + rect2 = FancyBboxPatch( + (5.5, 6), + 4, + 3.2, + boxstyle="round,pad=0.1", + facecolor="#FFCDD2", + edgecolor=BLACK, + linewidth=1, + ) + ax.add_patch(rect2) + ax.text( + 7.5, + 8.8, + "Problem: over-segmentation", + fontsize=FS, + ha="center", + fontweight="bold", + color=RED_ACCENT, + ) + ax.text( + 7.5, + 7.8, + "47 regionów zamiast 2!", + fontsize=FS_SMALL, + ha="center", + color=RED_ACCENT, + ) + ax.text(7.5, 7.1, "Każde mini-minimum", fontsize=FS_SMALL, ha="center") + ax.text(7.5, 6.5, '→ osobna „dolina"', fontsize=FS_SMALL, ha="center") + + # Apply marker-controlled solution + rect3 = FancyBboxPatch( + (1, 0.5), + 8, + 4.5, + boxstyle="round,pad=0.15", + facecolor=GRAY1, + edgecolor=GREEN_ACCENT, + linewidth=1.5, + ) + ax.add_patch(rect3) + ax.text( + 5, + 4.3, + "Rozwiązanie: Marker-controlled watershed", + fontsize=FS, + ha="center", + fontweight="bold", + color=GREEN_ACCENT, + ) + ax.text( + 5, + 3.4, + '1. Zaznacz ręcznie „seeds" (markery) w każdym obiekcie', + fontsize=FS_SMALL, + ha="center", + ) + ax.text( + 5, + 2.7, + "2. Zalewaj TYLKO od tych markerów (nie od wszystkich minimów)", + fontsize=FS_SMALL, + ha="center", + ) + ax.text( + 5, + 2.0, + "3. Eliminuje fałszywe doliny z szumu", + fontsize=FS_SMALL, + ha="center", + ) + ax.text( + 5, + 1.2, + "Wynik: tyle segmentów, ile podano markerów", + fontsize=FS_SMALL, + ha="center", + fontweight="bold", + ) + + +def generate_watershed() -> None: + """Generate watershed.""" + _fig, axes = plt.subplots(1, 3, figsize=(11, 3.8)) + + # --- Panel 1: Image as topographic surface --- + ax = axes[0] + x = np.linspace(0, 10, 200) + # Create a surface with two valleys and a ridge + surface = ( + 3 * np.exp(-((x - 3) ** 2) / 1.5) + + 4 * np.exp(-((x - 7) ** 2) / 1.2) + + 0.5 * np.sin(x * 2) + + 1 + ) + # Invert: valleys at objects (dark), peaks at boundaries (bright) + surface_inv = 6 - surface + 1 + + ax.fill_between(x, 0, surface_inv, color=GRAY2, alpha=0.7) + ax.plot(x, surface_inv, color=BLACK, linewidth=1.5) + + # Mark valleys + ax.annotate( + "Dolina 1\n(obiekt A)", + xy=(3, surface_inv[60]), + fontsize=FS_SMALL, + ha="center", + va="bottom", + arrowprops={"arrowstyle": "->", "color": ACCENT}, + xytext=(1.5, 5.5), + ) + ax.annotate( + "Dolina 2\n(obiekt B)", + xy=(7, surface_inv[140]), + fontsize=FS_SMALL, + ha="center", + va="bottom", + arrowprops={"arrowstyle": "->", "color": RED_ACCENT}, + xytext=(8.5, 5.5), + ) + # Mark ridge + ax.annotate( + "Grań\n(granica)", + xy=(5, surface_inv[100]), + fontsize=FS_SMALL, + ha="center", + va="bottom", + arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT}, + xytext=(5, 6.5), + ) + + ax.set_xlabel("Pozycja piksela", fontsize=FS) + ax.set_ylabel("Jasność (= wysokość)", fontsize=FS) + ax.set_title("Krok 1: obraz → teren", fontsize=FS_TITLE, fontweight="bold") + ax.set_ylim(0, 7) + + # --- Panel 2: Flooding --- + ax = axes[1] + ax.fill_between(x, 0, surface_inv, color=GRAY2, alpha=0.7) + ax.plot(x, surface_inv, color=BLACK, linewidth=1.5) + + # Water level + water_level = 3.2 + + # Fill water in valley 1 + x_v1 = x[(x > 1) & (x < _RIDGE_X)] + s_v1 = surface_inv[(x > 1) & (x < _RIDGE_X)] + ax.fill_between( + x_v1, s_v1, water_level, where=s_v1 < water_level, color=ACCENT_LIGHT, alpha=0.6 + ) + # Fill water in valley 2 + x_v2 = x[(x > _RIDGE_X) & (x < _VALLEY2_END)] + s_v2 = surface_inv[(x > _RIDGE_X) & (x < _VALLEY2_END)] + ax.fill_between( + x_v2, s_v2, water_level, where=s_v2 < water_level, color="#FFCDD2", alpha=0.6 + ) + + ax.axhline(y=water_level, color=ACCENT, linewidth=1, linestyle="--", alpha=0.5) + ax.text(3, 2.5, "Woda A", fontsize=FS, ha="center", color=ACCENT, fontweight="bold") + ax.text( + 7, 2.2, "Woda B", fontsize=FS, ha="center", color=RED_ACCENT, fontweight="bold" + ) + ax.annotate( + "Tu się spotkają!\n→ GRANICA", + xy=(5, surface_inv[100]), + fontsize=FS_SMALL, + ha="center", + color=GREEN_ACCENT, + fontweight="bold", + arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT}, + xytext=(5, 6.2), + ) + + ax.set_xlabel("Pozycja piksela", fontsize=FS) + ax.set_title("Krok 2: zalewanie", fontsize=FS_TITLE, fontweight="bold") + ax.set_ylim(0, 7) + + # --- Panel 3: Result with problem --- + _draw_watershed_result_panel(axes[2]) + + _save_figure("q23_watershed.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q23_receptive_transformer.py b/python_pkg/praca_magisterska_video/generate_images/_q23_receptive_transformer.py new file mode 100644 index 0000000..28491a3 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q23_receptive_transformer.py @@ -0,0 +1,286 @@ +"""Receptive field and transformer diagram generators.""" + +from __future__ import annotations + +from _q23_common import ( + _HIGHLIGHT_END, + _HIGHLIGHT_START, + ACCENT, + ACCENT_LIGHT, + BLACK, + FS, + FS_SMALL, + FS_TITLE, + GRAY3, + GRAY5, + GREEN_ACCENT, + RED_ACCENT, + WHITE, + _save_figure, + plt, +) +from matplotlib import patches + + +def generate_receptive_field() -> None: + """Generate receptive field.""" + _fig, axes = plt.subplots(1, 3, figsize=(11, 4)) + + def draw_grid( + ax: patches.Axes, + size: int, + highlight_cells: list[tuple[int, int]], + highlight_color: str, + title: str, + grid_offset: tuple[int, int] = (0, 0), + ) -> None: + """Draw grid.""" + ox, oy = grid_offset + for i in range(size): + for j in range(size): + color = WHITE + if (i, j) in highlight_cells: + color = highlight_color + rect = patches.Rectangle( + (ox + j, oy + size - 1 - i), + 1, + 1, + facecolor=color, + edgecolor=GRAY3, # Use GRAY3 instead of GRAY4 since unused + linewidth=0.5, + ) + ax.add_patch(rect) + ax.set_title(title, fontsize=FS_TITLE, fontweight="bold") + + # --- Panel 1: Standard 3x3 conv receptive field --- + ax = axes[0] + ax.set_xlim(-0.5, 7.5) + ax.set_ylim(-1, 8) + ax.set_aspect("equal") + ax.axis("off") + + # 7x7 input grid + highlight_3x3 = [ + (2, 2), + (2, 3), + (2, 4), + (3, 2), + (3, 3), + (3, 4), + (4, 2), + (4, 3), + (4, 4), + ] + draw_grid(ax, 7, highlight_3x3, ACCENT_LIGHT, "Zwykła conv 3x3") + ax.text( + 3.5, + -0.5, + "RF = 3x3 pikseli", + fontsize=FS, + ha="center", + fontweight="bold", + color=ACCENT, + ) + + # --- Panel 2: Dilated conv (rate=2) --- + ax = axes[1] + ax.set_xlim(-0.5, 7.5) + ax.set_ylim(-1, 8) + ax.set_aspect("equal") + ax.axis("off") + + # 7x7 input grid with dilated highlights + highlight_dilated = [ + (1, 1), + (1, 3), + (1, 5), + (3, 1), + (3, 3), + (3, 5), + (5, 1), + (5, 3), + (5, 5), + ] + draw_grid(ax, 7, highlight_dilated, "#FFCDD2", "Dilated conv 3x3\n(rate=2)") + ax.text( + 3.5, + -0.5, + "RF = 5x5, ale 9 parametrów!", + fontsize=FS, + ha="center", + fontweight="bold", + color=RED_ACCENT, + ) + + # Connect dots to show pattern + dots_x = [1.5, 3.5, 5.5, 1.5, 3.5, 5.5, 1.5, 3.5, 5.5] + dots_y = [5.5, 5.5, 5.5, 3.5, 3.5, 3.5, 1.5, 1.5, 1.5] + ax.scatter(dots_x, dots_y, c=RED_ACCENT, s=30, zorder=5) + + # --- Panel 3: Comparison --- + ax = axes[2] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title( + "Receptive Field\n(pole widzenia neuronu)", fontsize=FS_TITLE, fontweight="bold" + ) + + y = 8.5 + lines = [ + ("RF = ile pikseli WEJŚCIOWYCH", FS, BLACK, "bold"), + ("wpływa na JEDEN piksel wyjścia", FS, BLACK, "bold"), + ("", 0, "", ""), + ("Rate (współczynnik dylatacji):", FS, BLACK, "bold"), + (' rate=1: filtr „dotyka" sąsiadów', FS_SMALL, BLACK, "normal"), + (" rate=2: co drugi piksel → RF = 5x5", FS_SMALL, BLACK, "normal"), + (" rate=3: co trzeci → RF = 7x7", FS_SMALL, BLACK, "normal"), + (" WIĘCEJ kontekstu, TE SAME wagi!", FS_SMALL, GREEN_ACCENT, "bold"), + ("", 0, "", ""), + ("Dlaczego ważne w segmentacji?", FS, BLACK, "bold"), + (" Piksel sam nie wie czym jest.", FS_SMALL, BLACK, "normal"), + (" Potrzebuje KONTEKSTU (otoczenia).", FS_SMALL, BLACK, "normal"), + (" Większe RF → widzi obok budynki", FS_SMALL, BLACK, "normal"), + (' → wie, że TEN piksel to „droga"', FS_SMALL, GREEN_ACCENT, "bold"), + ("", 0, "", ""), + ("Global Average Pooling:", FS, BLACK, "bold"), + (" Mapa HxWxC → 1x1xC", FS_SMALL, BLACK, "normal"), + (" Średnia z CAŁEGO feature map", FS_SMALL, BLACK, "normal"), + (" RF = nieskończone (cały obraz)", FS_SMALL, GREEN_ACCENT, "bold"), + ] + for txt, size, color, weight in lines: + if txt == "": + y -= 0.2 + continue + ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top") + y -= 0.45 + + _save_figure("q23_receptive_field.png") + + +def generate_transformer() -> None: + """Generate transformer.""" + _fig, axes = plt.subplots(1, 3, figsize=(11, 4)) + + # --- Panel 1: CNN local vs Transformer global --- + ax = axes[0] + ax.set_xlim(-0.5, 8.5) + ax.set_ylim(-1.5, 8.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("CNN: widzi LOKALNIE", fontsize=FS_TITLE, fontweight="bold") + + # Draw 8x8 grid + for i in range(8): + for j in range(8): + color = WHITE + if ( + _HIGHLIGHT_START <= i <= _HIGHLIGHT_END + and _HIGHLIGHT_START <= j <= _HIGHLIGHT_END + ): + color = ACCENT_LIGHT + rect = patches.Rectangle( + (j, 7 - i), 1, 1, facecolor=color, edgecolor=GRAY3, linewidth=0.3 + ) + ax.add_patch(rect) + + # Highlight center + rect = patches.Rectangle( + (4, 4), 1, 1, facecolor=RED_ACCENT, edgecolor=BLACK, linewidth=1.5, alpha=0.7 + ) + ax.add_patch(rect) + ax.text( + 4.5, + 4.5, + "?", + ha="center", + va="center", + fontsize=FS, + fontweight="bold", + color=WHITE, + ) + ax.text( + 4.5, + -0.8, + "Filtr 3x3 widzi tylko\n9 sąsiednich pikseli", + fontsize=FS_SMALL, + ha="center", + color=ACCENT, + ) + + # --- Panel 2: Transformer global --- + ax = axes[1] + ax.set_xlim(-0.5, 8.5) + ax.set_ylim(-1.5, 8.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Transformer: widzi GLOBALNIE", fontsize=FS_TITLE, fontweight="bold") + + # Draw 8x8 grid all highlighted + for i in range(8): + for j in range(8): + color = "#FFCDD2" + rect = patches.Rectangle( + (j, 7 - i), 1, 1, facecolor=color, edgecolor=GRAY3, linewidth=0.3 + ) + ax.add_patch(rect) + + rect = patches.Rectangle( + (4, 4), 1, 1, facecolor=RED_ACCENT, edgecolor=BLACK, linewidth=1.5, alpha=0.9 + ) + ax.add_patch(rect) + ax.text( + 4.5, + 4.5, + "?", + ha="center", + va="center", + fontsize=FS, + fontweight="bold", + color=WHITE, + ) + ax.text( + 4.5, + -0.8, + 'Self-attention „pyta"\nALL 64 piksele naraz', + fontsize=FS_SMALL, + ha="center", + color=RED_ACCENT, + ) + + # --- Panel 3: SOTA + Transformer explanation --- + ax = axes[2] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title("Transformer & SOTA", fontsize=FS_TITLE, fontweight="bold") + + y = 9.2 + lines = [ + ("Transformer:", FS, BLACK, "bold"), + (" Architektura z 2017 (Vaswani et al.)", FS_SMALL, BLACK, "normal"), + (" Oryginalnie do NLP (tłumaczenie)", FS_SMALL, BLACK, "normal"), + (" Kluczowy mechanizm: SELF-ATTENTION", FS_SMALL, ACCENT, "bold"), + ("", 0, "", ""), + ("Self-attention w skrócie:", FS, BLACK, "bold"), + (" Każdy piksel tworzy trzy wektory:", FS_SMALL, BLACK, "normal"), + (' Q (Query — „czego szukam?")', FS_SMALL, ACCENT, "normal"), + (' K (Key — „co oferuję innych")', FS_SMALL, RED_ACCENT, "normal"), + (' V (Value — „moja wartość")', FS_SMALL, GREEN_ACCENT, "normal"), + (" Attention = softmax(Q·Kᵀ/√d)·V", FS_SMALL, BLACK, "bold"), + (" Koszt: O(n²) — n=liczba pikseli", FS_SMALL, RED_ACCENT, "normal"), + ("", 0, "", ""), + ("SOTA = State Of The Art:", FS, BLACK, "bold"), + (" Najlepszy znany wynik na benchmarku", FS_SMALL, BLACK, "normal"), + (' Np. „mIoU 85.1% na ADE20K = SOTA"', FS_SMALL, BLACK, "normal"), + (" Ciągle się zmienia (nowy paper", FS_SMALL, GRAY5, "normal"), + (" → nowy SOTA)", FS_SMALL, GRAY5, "normal"), + ] + for txt, size, color, weight in lines: + if txt == "": + y -= 0.15 + continue + ax.text(0.3, y, txt, fontsize=size, color=color, fontweight=weight, va="top") + y -= 0.45 + + _save_figure("q23_transformer_attention.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q23_region_diy.py b/python_pkg/praca_magisterska_video/generate_images/_q23_region_diy.py new file mode 100644 index 0000000..fd917e9 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q23_region_diy.py @@ -0,0 +1,408 @@ +"""Region growing and DIY thresholding diagram generators.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from _q23_common import ( + _BRIGHT_THRESHOLD, + _OTSU_THRESHOLD, + ACCENT, + ACCENT_LIGHT, + BLACK, + FS, + FS_SMALL, + FS_TINY, + FS_TITLE, + GRAY3, + GRAY4, + GRAY5, + GREEN_ACCENT, + RED_ACCENT, + WHITE, + _save_figure, + np, + plt, + rng, +) +from matplotlib import patches + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def _draw_region_growing_grid(ax: Axes) -> None: + """Draw panel 2: region growing step-by-step grid.""" + ax.set_xlim(-0.5, 6.5) + ax.set_ylim(-1.5, 7.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Region Growing: krok po kroku", + fontsize=FS_TITLE, + fontweight="bold", + ) + + pixel_grid = np.array( + [ + [150, 153, 148, 200, 210, 205], + [147, 155, 152, 195, 208, 200], + [145, 148, 160, 190, 195, 210], + [200, 195, 190, 155, 148, 150], + [210, 205, 200, 150, 152, 145], + [215, 208, 195, 148, 147, 155], + ] + ) + region_mask = np.array( + [ + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + ] + ) + + for i in range(6): + for j in range(6): + v = pixel_grid[i, j] + if region_mask[i, j] == 1 and v < _BRIGHT_THRESHOLD: + cell_color = ACCENT_LIGHT + elif region_mask[i, j] == 1: + cell_color = "#E0E0E0" + else: + cell_color = WHITE + if i == 1 and j == 1: + cell_color = "#FFD54F" + rect = patches.Rectangle( + (j, 5 - i), + 1, + 1, + facecolor=cell_color, + edgecolor=GRAY4, + linewidth=0.5, + ) + ax.add_patch(rect) + ax.text( + j + 0.5, + 5 - i + 0.5, + str(v), + ha="center", + va="center", + fontsize=FS_TINY, + fontweight="bold", + ) + + ax.annotate( + "SEED\n(155)", + xy=(1.5, 4.5), + fontsize=FS_SMALL, + ha="center", + color=RED_ACCENT, + fontweight="bold", + arrowprops={"arrowstyle": "->", "color": RED_ACCENT}, + xytext=(-0.5, 7), + ) + ax.text( + 3, + -0.8, + "Próg = 20\nNiebieski = region (|val - seed| < 20)", + fontsize=FS_TINY, + ha="center", + color=ACCENT, + ) + + +def _draw_bfs_expansion(ax: Axes) -> None: + """Draw panel 3: BFS expansion visualization.""" + ax.set_xlim(-0.5, 6.5) + ax.set_ylim(-1.5, 7.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Rosnący region (BFS)", + fontsize=FS_TITLE, + fontweight="bold", + ) + + wave_colors = ["#FFD54F", "#FFF176", "#FFF9C4", ACCENT_LIGHT, "#B3D4FC"] + wave_labels = ["Seed", "Fala 1", "Fala 2", "Fala 3", "Fala 4"] + waves = [ + [(1, 1)], + [(0, 1), (1, 0), (1, 2), (2, 1)], + [(0, 0), (0, 2), (2, 0), (2, 2)], + ] + + for i in range(6): + for j in range(6): + cell_color = WHITE + for w_idx, wave in enumerate(waves): + if (i, j) in wave: + cell_color = wave_colors[w_idx] + rect = patches.Rectangle( + (j, 5 - i), + 1, + 1, + facecolor=cell_color, + edgecolor=GRAY4, + linewidth=0.5, + ) + ax.add_patch(rect) + + seed_x, seed_y = 1.5, 4.5 + for dx, dy, _label in [ + (0, 1, ""), + (0, -1, ""), + (1, 0, ""), + (-1, 0, ""), + ]: + ax.annotate( + "", + xy=(seed_x + dx * 0.7, seed_y + dy * 0.7), + xytext=(seed_x, seed_y), + arrowprops={ + "arrowstyle": "->", + "color": RED_ACCENT, + "lw": 1.2, + }, + ) + + ax.text( + 3, + -0.5, + "BFS: sprawdzaj sąsiadów,\ndodawaj podobne do kolejki", + fontsize=FS_TINY, + ha="center", + color=GRAY5, + ) + + for w_idx, (wave_color, label) in enumerate( + zip(wave_colors[:3], wave_labels[:3], strict=False) + ): + rect = patches.Rectangle( + (4, 6.5 - w_idx * 0.7), + 0.5, + 0.5, + facecolor=wave_color, + edgecolor=GRAY4, + linewidth=0.5, + ) + ax.add_patch(rect) + ax.text( + 4.8, + 6.75 - w_idx * 0.7, + label, + fontsize=FS_TINY, + va="center", + ) + + +def generate_region_growing() -> None: + """Generate region growing.""" + _fig, axes = plt.subplots(1, 3, figsize=(11, 4.2)) + + # --- Panel 1: Manual vs automatic seed --- + ax = axes[0] + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + ax.axis("off") + ax.set_title("Seed: ręcznie vs automatycznie", fontsize=FS_TITLE, fontweight="bold") + + y = 9.2 + lines = [ + ("Ręczny seed:", FS, ACCENT, "bold"), + (" Użytkownik klika na obraz", FS_SMALL, BLACK, "normal"), + (' → „tu jest obiekt, od tego zacznij"', FS_SMALL, BLACK, "normal"), + (" Użycie: segmentacja interaktywna", FS_SMALL, GRAY5, "normal"), + (" (np. Photoshop — magic wand tool)", FS_SMALL, GRAY5, "normal"), + ("", 0, "", ""), + ("Automatyczny seed:", FS, RED_ACCENT, "bold"), + (" 1. Histogram → lokalne maxima", FS_SMALL, BLACK, "normal"), + (" (najczęstsza jasność → seed)", FS_SMALL, GRAY5, "normal"), + (" 2. Grid: siatka co N pikseli", FS_SMALL, BLACK, "normal"), + (" (np. seed co 50 px → 100 seedów)", FS_SMALL, GRAY5, "normal"), + (" 3. Losowe próbkowanie", FS_SMALL, BLACK, "normal"), + (" 4. Ekstrema lokalne gradientu", FS_SMALL, BLACK, "normal"), + ("", 0, "", ""), + ("Dlaczego OR?", FS, GREEN_ACCENT, "bold"), + (" Ręczny → precyzyjny, ale wolny", FS_SMALL, BLACK, "normal"), + (" Auto → szybki, ale over-segmentation", FS_SMALL, BLACK, "normal"), + ] + for txt, size, color, weight in lines: + if txt == "": + y -= 0.15 + continue + ax.text(0.3, y, txt, fontsize=size, color=color, fontweight=weight, va="top") + y -= 0.45 + + # --- Panel 2: Region growing step by step --- + _draw_region_growing_grid(axes[1]) + + # --- Panel 3: BFS expansion --- + _draw_bfs_expansion(axes[2]) + + _save_figure("q23_region_growing.png") + + +def _draw_otsu_variance_and_pseudocode( + ax_var: Axes, + ax_code: Axes, + img: np.ndarray, +) -> int: + """Draw panels 4 and 5: Otsu variance plot and pseudocode.""" + thresholds = range(10, 245) + variances = [] + for t in thresholds: + c0 = img[img <= t].ravel() + c1 = img[img > t].ravel() + if len(c0) == 0 or len(c1) == 0: + variances.append(np.nan) + continue + w0 = len(c0) / len(img.ravel()) + w1 = len(c1) / len(img.ravel()) + var = w0 * np.var(c0) + w1 * np.var(c1) + variances.append(var) + + ax_var.plot(list(thresholds), variances, color=ACCENT, linewidth=1.5) + best_t = list(thresholds)[np.nanargmin(variances)] + ax_var.axvline( + x=best_t, + color=RED_ACCENT, + linewidth=1.5, + linestyle="--", + label=f"Otsu T={best_t}", + ) + ax_var.scatter( + [best_t], + [np.nanmin(variances)], + c=RED_ACCENT, + s=60, + zorder=5, + ) + ax_var.set_xlabel("Próg T", fontsize=FS_SMALL) + ax_var.set_ylabel("σ² wewnątrzklasowa", fontsize=FS_SMALL) + ax_var.set_title( + "Krok 4: Otsu szuka min σ²", + fontsize=FS, + fontweight="bold", + ) + ax_var.legend(fontsize=FS_TINY) + + ax_code.set_xlim(0, 10) + ax_code.set_ylim(0, 10) + ax_code.axis("off") + ax_code.set_title("Pseudokod Otsu", fontsize=FS, fontweight="bold") + + code_lines = [ + "best_T = 0", + "min_var = ∞", + "", + "for T in 0..255:", + " c0 = piksele z jasność ≤ T", + " c1 = piksele z jasność > T", + " w0 = len(c0) / len(all)", + " w1 = len(c1) / len(all)", + " var = w0·var(c0) + w1·var(c1)", + " if var < min_var:", + " min_var = var", + " best_T = T", + "", + "return best_T # optymalny próg", + ] + for i, line in enumerate(code_lines): + txt_color = ACCENT if "best_T = T" in line or "return" in line else BLACK + ax_code.text( + 0.5, + 9.5 - i * 0.65, + line, + fontsize=FS_TINY, + fontfamily="monospace", + color=txt_color, + fontweight="bold" if txt_color == ACCENT else "normal", + ) + return int(best_t) + + +def generate_diy_thresholding() -> None: + """Generate diy thresholding.""" + _fig, axes = plt.subplots(2, 3, figsize=(11, 7)) + + # Create a simple synthetic image: dark circle on bright background + size = 64 + img = np.ones((size, size)) * 200 # bright background + yy, xx = np.mgrid[:size, :size] + mask = ((xx - 32) ** 2 + (yy - 32) ** 2) < 15**2 + img[mask] = 60 # dark circle + # Add some noise + img += rng.normal(0, 10, img.shape) + img = np.clip(img, 0, 255) + + # --- Panel 1: Original image --- + ax = axes[0, 0] + ax.imshow(img, cmap="gray", vmin=0, vmax=255) + ax.set_title("Krok 1: obraz wejściowy", fontsize=FS, fontweight="bold") + ax.axis("off") + ax.text(32, -3, "64x64 pikseli, szare", fontsize=FS_TINY, ha="center") + + # --- Panel 2: Histogram --- + ax = axes[0, 1] + counts, _bins, _ = ax.hist( + img.ravel(), bins=50, color=GRAY3, edgecolor=GRAY5, linewidth=0.5 + ) + ax.axvline( + x=128, color=RED_ACCENT, linewidth=2, linestyle="--", label="T=128 (Otsu)" + ) + ax.set_xlabel("Jasność", fontsize=FS_SMALL) + ax.set_ylabel("Piksele", fontsize=FS_SMALL) + ax.set_title("Krok 2: histogram\n(bimodalny!)", fontsize=FS, fontweight="bold") + ax.legend(fontsize=FS_TINY) + ax.annotate( + "Garb 1\n(obiekt)", + xy=(60, max(counts) * 0.5), + fontsize=FS_TINY, + ha="center", + color=ACCENT, + fontweight="bold", + ) + ax.annotate( + "Garb 2\n(tło)", + xy=(200, max(counts) * 0.5), + fontsize=FS_TINY, + ha="center", + color=RED_ACCENT, + fontweight="bold", + ) + + # --- Panel 3: Thresholding result --- + ax = axes[0, 2] + binary = (img > _OTSU_THRESHOLD).astype(float) + ax.imshow(binary, cmap="gray", vmin=0, vmax=1) + ax.set_title("Krok 3: progowanie T=128", fontsize=FS, fontweight="bold") + ax.axis("off") + ax.text(32, -3, "Biały = tło, Czarny = obiekt", fontsize=FS_TINY, ha="center") + + # --- Panels 4+5: Otsu variance plot + pseudocode --- + best_t = _draw_otsu_variance_and_pseudocode( + axes[1, 0], + axes[1, 1], + img, + ) + + # --- Panel 6: Final result with Otsu --- + ax = axes[1, 2] + binary_otsu = (img > best_t).astype(float) + ax.imshow(binary_otsu, cmap="gray", vmin=0, vmax=1) + ax.set_title(f"Krok 5: wynik Otsu (T={best_t})", fontsize=FS, fontweight="bold") + ax.axis("off") + ax.text( + 32, + -3, + "Automatyczny próg!", + fontsize=FS_TINY, + ha="center", + color=GREEN_ACCENT, + fontweight="bold", + ) + + _save_figure("q23_diy_thresholding.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q24_common.py b/python_pkg/praca_magisterska_video/generate_images/_q24_common.py new file mode 100644 index 0000000..bf30e5c --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q24_common.py @@ -0,0 +1,186 @@ +"""Common utilities and constants for Q24 diagram generation. + +Monochrome, A4-printable PNGs (300 DPI). +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib as mpl + +mpl.use("Agg") + +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch +import matplotlib.pyplot as plt +import numpy as np + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.figure import Figure + +_logger = logging.getLogger(__name__) + +rng = np.random.default_rng(42) + +DPI = 300 +BG = "white" +LN = "black" +FS = 8 +FS_TITLE = 11 +FS_SMALL = 6.5 +FS_LABEL = 9 +OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") +Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + +GRAY1 = "#E8E8E8" +GRAY2 = "#D0D0D0" +GRAY3 = "#B8B8B8" +GRAY4 = "#F5F5F5" +GRAY5 = "#C0C0C0" + +_PIXEL_BRIGHT_THRESH = 127 +_GRADIENT_BRIGHT_THRESH = 100 +_DATA_BRIGHT_THRESH = 5 +_II_BRIGHT_THRESH = 25 +_DOTS_STAGE_IDX = 2 + + +def draw_box( + ax: Axes, + x: float, + y: float, + w: float, + h: float, + text: str, + *, + fill: str = "white", + lw: float = 1.2, + fontsize: float = FS, + fontweight: str = "normal", + ha: str = "center", + va: str = "center", + rounded: bool = True, + edgecolor: str = LN, + linestyle: str = "-", +) -> None: + """Draw box.""" + if rounded: + rect = FancyBboxPatch( + (x, y), + w, + h, + boxstyle="round,pad=0.05", + lw=lw, + edgecolor=edgecolor, + facecolor=fill, + linestyle=linestyle, + ) + else: + rect = mpatches.Rectangle( + (x, y), + w, + h, + lw=lw, + edgecolor=edgecolor, + facecolor=fill, + linestyle=linestyle, + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + h / 2, + text, + ha=ha, + va=va, + fontsize=fontsize, + fontweight=fontweight, + wrap=True, + ) + + +def draw_arrow( + ax: Axes, + x1: float, + y1: float, + x2: float, + y2: float, + *, + lw: float = 1.2, + style: str = "->", + color: str = LN, +) -> None: + """Draw arrow.""" + ax.annotate( + "", + xy=(x2, y2), + xytext=(x1, y1), + arrowprops={"arrowstyle": style, "color": color, "lw": lw}, + ) + + +def save_fig(fig: Figure, name: str) -> None: + """Save fig.""" + path = str(Path(OUTPUT_DIR) / name) + fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15) + plt.close(fig) + _logger.info(" Saved: %s", path) + + +def draw_table( + ax: Axes, + headers: list[str], + rows: list[list[str]], + x0: float, + y0: float, + col_widths: list[float], + *, + row_h: float = 0.4, + header_fill: str = GRAY2, + row_fills: list[str] | None = None, + fontsize: float = FS, + header_fontsize: float | None = None, +) -> None: + """Draw table.""" + if header_fontsize is None: + header_fontsize = fontsize + len(headers) + cx = x0 + for j, hdr in enumerate(headers): + draw_box( + ax, + cx, + y0, + col_widths[j], + row_h, + hdr, + fill=header_fill, + fontsize=header_fontsize, + fontweight="bold", + rounded=False, + ) + cx += col_widths[j] + for i, row in enumerate(rows): + cy = y0 - (i + 1) * row_h + cx = x0 + fill = GRAY4 if (i % 2 == 0) else "white" + if row_fills and i < len(row_fills): + fill = row_fills[i] + for j, cell in enumerate(row): + fw = "bold" if j == 0 else "normal" + draw_box( + ax, + cx, + cy, + col_widths[j], + row_h, + cell, + fill=fill, + fontsize=fontsize, + fontweight=fw, + rounded=False, + ) + cx += col_widths[j] diff --git a/python_pkg/praca_magisterska_video/generate_images/_q24_fpn_tasks_cnn.py b/python_pkg/praca_magisterska_video/generate_images/_q24_fpn_tasks_cnn.py new file mode 100644 index 0000000..563e5b9 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q24_fpn_tasks_cnn.py @@ -0,0 +1,412 @@ +"""FPN, anchor boxes, detection tasks, and CNN architecture diagrams.""" + +from __future__ import annotations + +from _q24_common import ( + FS, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + LN, + draw_arrow, + draw_box, + np, + plt, + save_fig, +) +import matplotlib.patches as mpatches + + +# ============================================================ +# 16. FPN (Feature Pyramid Network) +# ============================================================ +def draw_fpn() -> None: + """Draw fpn.""" + fig, ax = plt.subplots(figsize=(9, 5)) + ax.set_xlim(-0.5, 9.5) + ax.set_ylim(-0.5, 5.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "FPN (Feature Pyramid Network) — detekcja obiektów wszystkich rozmiarów", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + levels = [ + (0, 0, 2.0, 2.0, "C2\n56x56", "duże\ndetale"), + (0, 2.2, 1.5, 1.5, "C3\n28x28", ""), + (0, 3.9, 1.0, 1.0, "C4\n14x14", ""), + (0, 5.1, 0.6, 0.6, "C5\n7x7", "kontekst"), + ] + + for x, y, w, h, label, note in levels: + ax.add_patch( + mpatches.Rectangle((x, y - h), w, h, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + ax.text( + x + w / 2, + y - h / 2, + label, + ha="center", + va="center", + fontsize=FS_SMALL, + fontweight="bold", + ) + if note: + ax.text( + x + w + 0.15, + y - h / 2, + note, + ha="left", + va="center", + fontsize=5, + style="italic", + ) + + ax.text( + 1.0, -0.3, "Bottom-up\n(backbone)", ha="center", fontsize=FS, fontweight="bold" + ) + + # Top-down + lateral + td_levels = [ + (4.5, 5.1, 0.6, 0.6, "P5"), + (4.5, 3.9, 1.0, 1.0, "P4"), + (4.5, 2.2, 1.5, 1.5, "P3"), + (4.5, 0, 2.0, 2.0, "P2"), + ] + + for x, y, w, h, label in td_levels: + ax.add_patch( + mpatches.Rectangle( + (x, y - h + h), w, h, facecolor=GRAY2, edgecolor=LN, lw=1.5 + ) + ) + ax.text( + x + w / 2, + y - h / 2 + h, + label, + ha="center", + va="center", + fontsize=FS_SMALL, + fontweight="bold", + ) + + # Lateral connections + for (_, y1, w1, h1, _, _), (x2, y2, _w2, h2, _) in zip( + levels, td_levels, strict=False + ): + draw_arrow(ax, w1 + 0.2, y1 - h1 / 2, x2 - 0.1, y2 + h2 / 2, lw=1, style="->") + + # Top-down arrows + for i in range(len(td_levels) - 1): + x2, y2, w2, h2, _ = td_levels[i] + x3, y3, w3, h3, _ = td_levels[i + 1] + draw_arrow( + ax, + x2 + w2 / 2, + y2, + x3 + w3 / 2, + y3 + h3 + 0.1, + lw=1.2, + style="->", + color=GRAY3, + ) + + ax.text( + 5.5, + -0.3, + "Top-down + lateral\n(FPN)", + ha="center", + fontsize=FS, + fontweight="bold", + ) + + # Detection outputs + det_labels = ["małe obj.", "średnie", "duże", "b. duże"] + for i, (x, y, w, h, _label) in enumerate(td_levels): + draw_arrow(ax, x + w + 0.1, y + h / 2, 7.5, y + h / 2, lw=0.8) + ax.text( + 7.7, + y + h / 2, + f"detekcja:\n{det_labels[3 - i]}", + fontsize=FS_SMALL, + va="center", + ) + + save_fig(fig, "q24_fpn.png") + + +# ============================================================ +# 17. Anchor boxes +# ============================================================ +def draw_anchor_boxes() -> None: + """Draw anchor boxes.""" + fig, ax = plt.subplots(figsize=(7, 5)) + ax.set_title( + "Anchor boxes — predefiniowane kształty", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1)) + + # Center point + cx, cy = 3, 2.5 + ax.plot(cx, cy, "ko", markersize=8, zorder=5) + ax.text(cx + 0.15, cy + 0.15, "(x, y)", fontsize=FS, fontweight="bold") + + # 9 anchors: 3 sizes x 3 ratios + anchors = [ + (0.8, 0.8, "-", "1:1 small"), + (1.6, 1.6, "-", "1:1 medium"), + (2.4, 2.4, "-", "1:1 large"), + (0.6, 1.2, "--", "1:2 small"), + (1.2, 2.4, "--", "1:2 medium"), + (1.8, 3.6, "--", "1:2 large"), + (1.2, 0.6, ":", "2:1 small"), + (2.4, 1.2, ":", "2:1 medium"), + (3.6, 1.8, ":", "2:1 large"), + ] + + for w, h, ls, _label in anchors: + rect = mpatches.Rectangle( + (cx - w / 2, cy - h / 2), + w, + h, + facecolor="none", + edgecolor=LN, + lw=1.2, + linestyle=ls, + ) + ax.add_patch(rect) + + # Legend-style labels + ax.text( + 3, + -0.5, + "9 anchorów = 3 rozmiary x 3 proporcje (1:1, 1:2, 2:1)\n" + "Sieć predykuje PRZESUNIĘCIE od najbliższego anchora", + ha="center", + fontsize=FS, + style="italic", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + ax.set_xlim(-0.5, 6.5) + ax.set_ylim(-1.2, 5.5) + ax.set_aspect("equal") + ax.axis("off") + + save_fig(fig, "q24_anchor_boxes.png") + + +# ============================================================ +# 18. Detection task comparison +# ============================================================ +def draw_detection_tasks() -> None: + """Draw detection tasks.""" + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + fig.suptitle( + "Klasyfikacja vs Detekcja vs Segmentacja", + fontsize=FS_TITLE, + fontweight="bold", + y=1.02, + ) + + # Classification + ax = axes[0] + ax.add_patch( + mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + # Simple cat silhouette + ax.add_patch(mpatches.Ellipse((2, 2), 2, 1.5, facecolor=GRAY3, edgecolor=LN, lw=1)) + ax.add_patch(mpatches.Ellipse((2, 3), 1, 0.8, facecolor=GRAY3, edgecolor=LN, lw=1)) + # Ears + ax.plot([1.6, 1.5, 1.8], [3.3, 3.8, 3.4], color=LN, lw=1.5) + ax.plot([2.2, 2.5, 2.4], [3.3, 3.8, 3.4], color=LN, lw=1.5) + ax.text( + 2, -0.4, '→ "KOT" (jedna etykieta)', ha="center", fontsize=FS, fontweight="bold" + ) + ax.set_xlim(-0.5, 4.5) + ax.set_ylim(-0.8, 4.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Klasyfikacja\n(co?)", fontsize=FS, fontweight="bold") + + # Detection + ax = axes[1] + ax.add_patch( + mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + # Cat + ax.add_patch( + mpatches.Ellipse((1.2, 2), 1.2, 1, facecolor=GRAY3, edgecolor=LN, lw=1) + ) + ax.add_patch( + mpatches.Ellipse((1.2, 2.8), 0.7, 0.5, facecolor=GRAY3, edgecolor=LN, lw=1) + ) + # Dog + ax.add_patch( + mpatches.Ellipse((3, 1.5), 1.2, 1, facecolor=GRAY2, edgecolor=LN, lw=1) + ) + ax.add_patch( + mpatches.Ellipse((3, 2.3), 0.7, 0.5, facecolor=GRAY2, edgecolor=LN, lw=1) + ) + # Bounding boxes + ax.add_patch( + mpatches.Rectangle((0.3, 1.2), 1.8, 2.2, facecolor="none", edgecolor=LN, lw=2.5) + ) + ax.text(1.2, 3.5, "KOT", ha="center", fontsize=FS_SMALL, fontweight="bold") + ax.add_patch( + mpatches.Rectangle((2.1, 0.8), 1.7, 2.0, facecolor="none", edgecolor=LN, lw=2.5) + ) + ax.text(3.0, 2.9, "PIES", ha="center", fontsize=FS_SMALL, fontweight="bold") + ax.text( + 2, + -0.4, + "→ bbox + klasa (N obiektów)", + ha="center", + fontsize=FS, + fontweight="bold", + ) + ax.set_xlim(-0.5, 4.5) + ax.set_ylim(-0.8, 4.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Detekcja\n(co? + gdzie?)", fontsize=FS, fontweight="bold") + + # Segmentation + ax = axes[2] + ax.add_patch( + mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + # Cat mask (detailed) + theta = np.linspace(0, 2 * np.pi, 30) + cat_x = 1.2 + 0.6 * np.cos(theta) + 0.1 * np.sin(3 * theta) + cat_y = 2 + 0.5 * np.sin(theta) + 0.1 * np.cos(2 * theta) + ax.fill(cat_x, cat_y, facecolor=GRAY3, edgecolor=LN, lw=1.5) + # Dog mask + dog_x = 3.0 + 0.6 * np.cos(theta) + 0.05 * np.sin(4 * theta) + dog_y = 1.5 + 0.5 * np.sin(theta) + 0.08 * np.cos(3 * theta) + ax.fill(dog_x, dog_y, facecolor=GRAY2, edgecolor=LN, lw=1.5) + ax.text(1.2, 2, "KOT", ha="center", fontsize=FS_SMALL, fontweight="bold") + ax.text(3.0, 1.5, "PIES", ha="center", fontsize=FS_SMALL, fontweight="bold") + ax.text( + 2, + -0.4, + "→ maska pikseli (per piksel)", + ha="center", + fontsize=FS, + fontweight="bold", + ) + ax.set_xlim(-0.5, 4.5) + ax.set_ylim(-0.8, 4.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Segmentacja\n(dokładna maska)", fontsize=FS, fontweight="bold") + + fig.tight_layout() + save_fig(fig, "q24_detection_tasks.png") + + +# ============================================================ +# 19. CNN Architecture overview +# ============================================================ +def draw_cnn_architecture() -> None: + """Draw cnn architecture.""" + fig, ax = plt.subplots(figsize=(12, 4)) + ax.set_xlim(-0.5, 12.5) + ax.set_ylim(-1, 4.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "CNN — od obrazu do predykcji (architektura)", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + # Input image + draw_box(ax, 0, 0.5, 1.5, 3, "Obraz\n224x224x3", fill=GRAY1, fontsize=FS) + + # Conv1 + draw_arrow(ax, 1.6, 2.0, 2.1, 2.0, lw=1.2) + draw_box( + ax, 2.2, 0.8, 1.2, 2.4, "Conv1\n+ReLU\n55x55x96", fill=GRAY4, fontsize=FS_SMALL + ) + + # Pool1 + draw_arrow(ax, 3.5, 2.0, 3.9, 2.0, lw=1.2) + draw_box(ax, 4.0, 1.0, 1.0, 2.0, "Pool\n27x27\nx96", fill=GRAY2, fontsize=FS_SMALL) + + # Conv2 + draw_arrow(ax, 5.1, 2.0, 5.5, 2.0, lw=1.2) + draw_box( + ax, + 5.6, + 0.8, + 1.2, + 2.4, + "Conv2\n+ReLU\n27x27\nx256", + fill=GRAY4, + fontsize=FS_SMALL, + ) + + # Pool2 + draw_arrow(ax, 6.9, 2.0, 7.3, 2.0, lw=1.2) + draw_box(ax, 7.4, 1.2, 0.8, 1.6, "Pool\n13x13\nx256", fill=GRAY2, fontsize=FS_SMALL) + + # More conv... + draw_arrow(ax, 8.3, 2.0, 8.7, 2.0, lw=1.2) + ax.text(9.0, 2.0, "...", fontsize=14, ha="center", va="center") + draw_arrow(ax, 9.3, 2.0, 9.7, 2.0, lw=1.2) + + # FC + draw_box(ax, 9.8, 1.2, 1.0, 1.6, "FC\n4096", fill=GRAY3, fontsize=FS) + + draw_arrow(ax, 10.9, 2.0, 11.3, 2.0, lw=1.2) + + # Output + draw_box( + ax, 11.4, 1.5, 1.0, 1.0, "Softmax\n1000 klas", fill=GRAY1, fontsize=FS_SMALL + ) + + # Annotations below + ax.text( + 3.0, + 0.0, + "rozmiar maleje\n224→55→27→13→6", + ha="center", + fontsize=FS_SMALL, + style="italic", + ) + ax.text( + 6.0, + 0.0, + "kanały rosną\n3→96→256→384", + ha="center", + fontsize=FS_SMALL, + style="italic", + ) + ax.text( + 10.0, 0.0, "decyzja\nkońcowa", ha="center", fontsize=FS_SMALL, style="italic" + ) + + # hierarchy + ax.text( + 6.0, + 4.0, + "Hierarchia: krawędzie → rogi → fragmenty → obiekty (K-R-F-O)", + ha="center", + fontsize=FS, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + save_fig(fig, "q24_cnn_architecture.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q24_haar_integral_svm.py b/python_pkg/praca_magisterska_video/generate_images/_q24_haar_integral_svm.py new file mode 100644 index 0000000..1b35caf --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q24_haar_integral_svm.py @@ -0,0 +1,342 @@ +"""Haar features, integral image, and SVM hyperplane diagrams.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from _q24_common import ( + _DATA_BRIGHT_THRESH, + _II_BRIGHT_THRESH, + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY3, + GRAY4, + LN, + np, + plt, + rng, + save_fig, +) +import matplotlib.patches as mpatches + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +# ============================================================ +# 4. Haar Features +# ============================================================ +def draw_haar_features() -> None: + """Draw haar features.""" + fig, axes = plt.subplots(1, 4, figsize=(11, 3)) + fig.suptitle( + "Cechy Haar — typy i zastosowanie na twarzy", + fontsize=FS_TITLE, + fontweight="bold", + y=1.02, + ) + + # Feature 1: Vertical edge + ax = axes[0] + ax.add_patch( + mpatches.Rectangle((0, 0), 1, 2, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + ax.add_patch( + mpatches.Rectangle((1, 0), 1, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5) + ) + ax.text( + 0.5, 1, "+Σ₁", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold" + ) + ax.text( + 1.5, 1, "-Σ₂", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold" + ) + ax.set_xlim(-0.2, 2.2) + ax.set_ylim(-0.5, 2.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Krawędź pionowa\nwartość = Σ₁ - Σ₂", fontsize=FS) + + # Feature 2: Horizontal edge + ax = axes[1] + ax.add_patch( + mpatches.Rectangle((0, 1), 2, 1, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + ax.add_patch( + mpatches.Rectangle((0, 0), 2, 1, facecolor=GRAY3, edgecolor=LN, lw=1.5) + ) + ax.text( + 1, 1.5, "+Σ₁", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold" + ) + ax.text( + 1, 0.5, "-Σ₂", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold" + ) + ax.set_xlim(-0.2, 2.2) + ax.set_ylim(-0.5, 2.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Krawędź pozioma\n(oczy vs czoło)", fontsize=FS) + + # Feature 3: Three-rectangle (line) + ax = axes[2] + ax.add_patch( + mpatches.Rectangle((0, 0), 0.7, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5) + ) + ax.add_patch( + mpatches.Rectangle((0.7, 0), 0.7, 2, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + ax.add_patch( + mpatches.Rectangle((1.4, 0), 0.7, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5) + ) + ax.text( + 0.35, 1, "-Σ₁", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold" + ) + ax.text( + 1.05, 1, "+Σ₂", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold" + ) + ax.text( + 1.75, 1, "-Σ₃", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold" + ) + ax.set_xlim(-0.2, 2.3) + ax.set_ylim(-0.5, 2.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Linia (3 prostokąty)\n(nos vs policzki)", fontsize=FS) + + _draw_haar_face_panel(axes[3]) + + fig.tight_layout() + save_fig(fig, "q24_haar_features.png") + + +def _draw_haar_face_panel(ax: Axes) -> None: + """Draw Haar feature application on face schematic.""" + face = mpatches.Ellipse( + (1.2, 1.2), + 2.0, + 2.4, + facecolor=GRAY4, + edgecolor=LN, + lw=1.5, + ) + ax.add_patch(face) + ax.add_patch( + mpatches.Ellipse((0.7, 1.6), 0.4, 0.2, facecolor=GRAY3, edgecolor=LN, lw=1) + ) + ax.add_patch( + mpatches.Ellipse((1.7, 1.6), 0.4, 0.2, facecolor=GRAY3, edgecolor=LN, lw=1) + ) + ax.plot([1.2, 1.1, 1.3], [1.3, 0.9, 0.9], color=LN, lw=1) + ax.plot([0.8, 1.0, 1.2, 1.4, 1.6], [0.55, 0.5, 0.55, 0.5, 0.55], color=LN, lw=1) + ax.add_patch( + mpatches.Rectangle( + (0.3, 1.4), + 1.8, + 0.4, + facecolor="none", + edgecolor=LN, + lw=2, + linestyle="--", + ) + ) + ax.annotate( + "cechy Haar\n(oczy ciemne\nvs czoło jasne)", + xy=(1.2, 1.85), + xytext=(2.2, 2.3), + fontsize=FS_SMALL, + ha="center", + arrowprops={"arrowstyle": "->", "color": LN, "lw": 1}, + ) + ax.set_xlim(-0.2, 3.0) + ax.set_ylim(-0.2, 2.8) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Zastosowanie na twarzy", fontsize=FS) + + +# ============================================================ +# 5. Integral Image +# ============================================================ +def draw_integral_image() -> None: + """Draw integral image.""" + fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) + fig.suptitle( + "Integral Image — suma prostokąta w O(1)", + fontsize=FS_TITLE, + fontweight="bold", + y=1.02, + ) + + # Original image + ax = axes[0] + data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + ax.imshow(data, cmap="gray", vmin=0, vmax=10) + for i in range(3): + for j in range(3): + ax.text( + j, + i, + str(data[i, j]), + ha="center", + va="center", + fontsize=12, + fontweight="bold", + color="white" if data[i, j] > _DATA_BRIGHT_THRESH else "black", + ) + ax.set_title("① Obraz oryginalny", fontsize=FS, fontweight="bold") + ax.set_xticks([]) + ax.set_yticks([]) + + # Integral image + ax = axes[1] + ii = np.array([[1, 3, 6], [5, 12, 21], [12, 27, 45]]) + ax.imshow(ii, cmap="gray", vmin=0, vmax=50) + for i in range(3): + for j in range(3): + ax.text( + j, + i, + str(ii[i, j]), + ha="center", + va="center", + fontsize=12, + fontweight="bold", + color="white" if ii[i, j] > _II_BRIGHT_THRESH else "black", + ) + ax.set_title("② Integral Image\n(sumy kumulatywne)", fontsize=FS, fontweight="bold") + ax.set_xticks([]) + ax.set_yticks([]) + + # Formula illustration + ax = axes[2] + ax.axis("off") + ax.set_xlim(0, 4) + ax.set_ylim(0, 4) + # Draw rectangle + ax.add_patch( + mpatches.Rectangle((0.5, 0.5), 3, 3, facecolor="white", edgecolor=LN, lw=1) + ) + ax.add_patch( + mpatches.Rectangle((1.5, 0.5), 2, 2, facecolor=GRAY3, edgecolor=LN, lw=2) + ) + # Labels + ax.text(0.3, 3.7, "A", fontsize=12, fontweight="bold") + ax.text(3.6, 3.7, "B", fontsize=12, fontweight="bold") + ax.text(0.3, 0.3, "C", fontsize=12, fontweight="bold") + ax.text(3.6, 0.3, "D", fontsize=12, fontweight="bold") + ax.text( + 2.5, + 1.5, + "SZUKANA\nSUMA", + ha="center", + va="center", + fontsize=FS, + fontweight="bold", + ) + ax.text( + 2.0, + -0.3, + "Suma = D - B - C + A\n= 4 odczyty → O(1) ZAWSZE!", + ha="center", + fontsize=FS, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + ax.set_title( + "③ Formuła: 4 odczyty\n= O(1) niezależnie od rozmiaru", + fontsize=FS, + fontweight="bold", + ) + + fig.tight_layout() + save_fig(fig, "q24_integral_image.png") + + +# ============================================================ +# 11. SVM Hyperplane +# ============================================================ +def draw_svm_hyperplane() -> None: + """Draw svm hyperplane.""" + fig, ax = plt.subplots(figsize=(6, 5)) + ax.set_title( + "SVM — hiperpłaszczyzna i margines", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + x_pos = rng.standard_normal(15) * 0.5 + 3 + y_pos = rng.standard_normal(15) * 0.5 + 3 + ax.scatter( + x_pos, + y_pos, + marker="o", + s=50, + facecolors="white", + edgecolors=LN, + linewidths=1.5, + label="klasa +1 (pieszy)", + zorder=3, + ) + + x_neg = rng.standard_normal(15) * 0.5 + 1 + y_neg = rng.standard_normal(15) * 0.5 + 1 + ax.scatter( + x_neg, + y_neg, + marker="x", + s=50, + c=LN, + linewidths=1.5, + label="klasa -1 (tło)", + zorder=3, + ) + + # Hyperplane (decision boundary) + x_line = np.linspace(-0.5, 5, 100) + y_line = -x_line + 4.0 + ax.plot(x_line, y_line, "k-", lw=2, label="hiperpłaszczyzna") + + # Margin lines + ax.plot(x_line, y_line + 0.7, "k--", lw=1, alpha=0.5) + ax.plot(x_line, y_line - 0.7, "k--", lw=1, alpha=0.5) + + # Margin annotation + ax.annotate( + "", + xy=(2.5, 1.5 + 0.7), + xytext=(2.5, 1.5 - 0.7), + arrowprops={"arrowstyle": "<->", "color": LN, "lw": 1.5}, + ) + ax.text(2.8, 1.5, "margines\n(MAX!)", fontsize=FS, fontweight="bold") + + # Support vectors (highlight closest points) + ax.scatter( + [2.5], + [2.2], + marker="o", + s=120, + facecolors="none", + edgecolors=LN, + linewidths=2.5, + zorder=4, + ) + ax.scatter([1.5], [1.8], marker="x", s=120, c=LN, linewidths=2.5, zorder=4) + ax.annotate( + "support\nvectors", + xy=(1.5, 1.8), + xytext=(0.2, 3.0), + fontsize=FS, + fontweight="bold", + arrowprops={"arrowstyle": "->", "color": LN, "lw": 1}, + ) + + ax.set_xlim(-0.5, 5) + ax.set_ylim(-0.5, 5) + ax.set_xlabel("cecha 1 (np. gradient pionowy)", fontsize=FS) + ax.set_ylabel("cecha 2 (np. gradient poziomy)", fontsize=FS) + ax.legend(fontsize=FS_SMALL, loc="lower right") + ax.set_aspect("equal") + + save_fig(fig, "q24_svm_hyperplane.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q24_hog_classical.py b/python_pkg/praca_magisterska_video/generate_images/_q24_hog_classical.py new file mode 100644 index 0000000..0d53d6a --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q24_hog_classical.py @@ -0,0 +1,380 @@ +"""HOG + SVM pipeline, HOG gradient steps, Viola-Jones cascade.""" + +from __future__ import annotations + +from _q24_common import ( + _DOTS_STAGE_IDX, + _GRADIENT_BRIGHT_THRESH, + _PIXEL_BRIGHT_THRESH, + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + draw_arrow, + draw_box, + np, + plt, + save_fig, +) +import matplotlib.patches as mpatches + + +# ============================================================ +# 1. HOG + SVM Pipeline +# ============================================================ +def draw_hog_svm_pipeline() -> None: + """Draw hog svm pipeline.""" + fig, ax = plt.subplots(figsize=(10, 4.5)) + ax.set_xlim(-0.5, 10.5) + ax.set_ylim(-1, 4.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "HOG + SVM — pipeline detekcji pieszych", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + # Step 1: Image with sliding window + ax.add_patch( + mpatches.Rectangle((0, 1.5), 2, 2, lw=1.5, edgecolor=LN, facecolor=GRAY1) + ) + ax.text(1, 2.5, "Obraz\nwejściowy", ha="center", va="center", fontsize=FS) + # sliding window overlay + ax.add_patch( + mpatches.Rectangle( + (0.3, 1.8), + 0.8, + 1.2, + lw=1.5, + edgecolor="black", + facecolor="none", + linestyle="--", + ) + ) + ax.text( + 0.7, + 1.35, + "okno 64x128", + ha="center", + va="center", + fontsize=FS_SMALL, + style="italic", + ) + + draw_arrow(ax, 2.1, 2.5, 2.8, 2.5, lw=1.5) + ax.text(2.45, 2.75, "①", ha="center", fontsize=FS_LABEL, fontweight="bold") + + # Step 2: Gradient computation + draw_box( + ax, 2.9, 1.8, 1.6, 1.4, "Oblicz\ngradienty\nGx, Gy", fill=GRAY4, fontsize=FS + ) + ax.text( + 3.7, 1.55, "kierunek + siła", ha="center", fontsize=FS_SMALL, style="italic" + ) + + draw_arrow(ax, 4.6, 2.5, 5.2, 2.5, lw=1.5) + ax.text(4.9, 2.75, "②", ha="center", fontsize=FS_LABEL, fontweight="bold") + + # Step 3: HOG histogram + draw_box( + ax, + 5.3, + 1.8, + 1.6, + 1.4, + "Histogramy\nkierunkowe\n9 binów/cel", + fill=GRAY4, + fontsize=FS, + ) + ax.text(6.1, 1.55, "komórki 8x8 px", ha="center", fontsize=FS_SMALL, style="italic") + + draw_arrow(ax, 7.0, 2.5, 7.6, 2.5, lw=1.5) + ax.text(7.3, 2.75, "③", ha="center", fontsize=FS_LABEL, fontweight="bold") + + # Step 4: SVM + draw_box( + ax, + 7.7, + 1.8, + 1.4, + 1.4, + "SVM\nklasyfikator\npieszy/tło", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + + draw_arrow(ax, 9.2, 2.5, 9.7, 2.5, lw=1.5) + ax.text(9.45, 2.75, "④", ha="center", fontsize=FS_LABEL, fontweight="bold") + + # Step 5: NMS + output + draw_box(ax, 9.3, 2.0, 1.0, 1.0, "NMS\n→ wynik", fill=GRAY1, fontsize=FS) + + # Bottom: HOG feature vector illustration + ax.text( + 5.0, + 0.7, + "Wektor HOG: 3780 cech = 105 bloków x 4 komórki x 9 binów", + ha="center", + fontsize=FS, + style="italic", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + # Show small histogram bars + bar_x = 3.2 + bar_y = 0.0 + angles = [0, 20, 40, 60, 80, 100, 120, 140, 160] + values = [0.3, 0.1, 0.5, 0.8, 0.2, 0.6, 0.15, 0.4, 0.25] + for i, (_a, v) in enumerate(zip(angles, values, strict=False)): + ax.add_patch( + mpatches.Rectangle( + (bar_x + i * 0.18, bar_y), + 0.15, + v * 0.6, + facecolor=GRAY3, + edgecolor=LN, + lw=0.5, + ) + ) + ax.text(bar_x + 0.8, -0.2, "9 binów (0°-160°)", ha="center", fontsize=FS_SMALL) + + save_fig(fig, "q24_hog_svm_pipeline.png") + + +# ============================================================ +# 2. HOG Gradient Step-by-Step +# ============================================================ +def draw_hog_gradient_steps() -> None: + """Draw hog gradient steps.""" + fig, axes = plt.subplots(1, 4, figsize=(12, 3.5)) + fig.suptitle( + "HOG — kroki obliczania cech", fontsize=FS_TITLE, fontweight="bold", y=1.02 + ) + + # Step 1: Original patch + ax = axes[0] + patch = np.array([[50, 50, 200], [50, 50, 200], [50, 50, 200]]) + ax.imshow(patch, cmap="gray", vmin=0, vmax=255) + for i in range(3): + for j in range(3): + ax.text( + j, + i, + str(patch[i, j]), + ha="center", + va="center", + fontsize=FS_LABEL, + fontweight="bold", + color="white" if patch[i, j] > _PIXEL_BRIGHT_THRESH else "black", + ) + ax.set_title("① Fragment obrazu\n(jasność pikseli)", fontsize=FS, fontweight="bold") + ax.set_xticks([]) + ax.set_yticks([]) + + # Step 2: Gradient magnitude + ax = axes[1] + gx = np.array([[0, 150, 0], [0, 150, 0], [0, 150, 0]]) + ax.imshow(gx, cmap="gray", vmin=0, vmax=255) + for i in range(3): + for j in range(3): + ax.text( + j, + i, + str(gx[i, j]), + ha="center", + va="center", + fontsize=FS_LABEL, + fontweight="bold", + color="white" if gx[i, j] > _GRADIENT_BRIGHT_THRESH else "black", + ) + ax.set_title("② Gradient Gx\n(krawędź pionowa!)", fontsize=FS, fontweight="bold") + ax.set_xticks([]) + ax.set_yticks([]) + + # Step 3: Cell histogram + ax = axes[2] + angles = ["0°", "20°", "40°", "60°", "80°", "100°", "120°", "140°", "160°"] + values = [150, 0, 0, 0, 0, 0, 0, 0, 0] + bars = ax.bar(range(9), values, color=GRAY3, edgecolor=LN, linewidth=0.5) + bars[0].set_facecolor(GRAY5) + ax.set_xticks(range(9)) + ax.set_xticklabels(angles, fontsize=5, rotation=45) + ax.set_title( + "③ Histogram komórki\n(bin 0° = krawędź pionowa)", + fontsize=FS, + fontweight="bold", + ) + ax.set_ylabel("siła", fontsize=FS_SMALL) + + # Step 4: Block normalization + ax = axes[3] + # 2x2 block of cells + for i in range(2): + for j in range(2): + rect = mpatches.Rectangle( + (j * 1.2, (1 - i) * 1.2), + 1.0, + 1.0, + lw=1.2, + edgecolor=LN, + facecolor=GRAY4, + ) + ax.add_patch(rect) + ax.text( + j * 1.2 + 0.5, + (1 - i) * 1.2 + 0.5, + f"hist\n{i * 2 + j + 1}", + ha="center", + va="center", + fontsize=FS_SMALL, + ) + ax.add_patch( + mpatches.Rectangle( + (-0.1, -0.1), 2.6, 2.6, lw=2, edgecolor=LN, facecolor="none", linestyle="--" + ) + ) + ax.text( + 1.2, + -0.4, + "blok 2x2 → L2-norm", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + ) + ax.set_xlim(-0.3, 2.8) + ax.set_ylim(-0.7, 2.8) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "④ Normalizacja bloków\n(odporność na oświetlenie)", + fontsize=FS, + fontweight="bold", + ) + + fig.tight_layout() + save_fig(fig, "q24_hog_gradient_steps.png") + + +# ============================================================ +# 3. Viola-Jones Cascade +# ============================================================ +def draw_viola_jones_cascade() -> None: + """Draw viola jones cascade.""" + fig, ax = plt.subplots(figsize=(10, 5)) + ax.set_xlim(-0.5, 10.5) + ax.set_ylim(-1.5, 5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Viola-Jones — kaskada klasyfikatorów (SITO)", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + # Input + draw_box( + ax, + -0.3, + 2.5, + 1.5, + 1.2, + "500 000\nokien", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + ) + + stages = [ + ("Etap 1\n2 cechy", "50%\nodrzucone", "250 000", GRAY4), + ("Etap 2\n10 cech", "80%\nodrzucone", "50 000", GRAY4), + ("Etap 3\n25 cech", "90%\nodrzucone", "5 000", GRAY4), + ("Etap 25\n200 cech", "99%\nodrzucone", "50", GRAY3), + ] + + x_pos = 1.6 + for i, (label, reject, remain, col) in enumerate(stages): + # Stage box + draw_box( + ax, x_pos, 2.5, 1.6, 1.2, label, fill=col, fontsize=FS, fontweight="bold" + ) + + # Arrow from previous + draw_arrow(ax, x_pos - 0.3, 3.1, x_pos - 0.05, 3.1, lw=1.5) + + # Reject arrow down + draw_arrow(ax, x_pos + 0.8, 2.45, x_pos + 0.8, 1.6, lw=1.2) + ax.text( + x_pos + 0.8, + 1.3, + reject, + ha="center", + fontsize=FS_SMALL, + color="black", + style="italic", + ) + ax.text( + x_pos + 0.8, + 0.8, + "✗ NIE-TWARZ", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + ) + + # Remaining count above + if i < len(stages) - 1: + ax.text( + x_pos + 2.0, + 3.9, + f"→ {remain}", + ha="center", + fontsize=FS_SMALL, + style="italic", + ) + + # Dots between stage 3 and stage 25 + if i == _DOTS_STAGE_IDX: + ax.text( + x_pos + 2.0, 3.1, "· · ·", ha="center", fontsize=12, fontweight="bold" + ) + x_pos += 2.5 + else: + x_pos += 2.1 + + # Final output + draw_arrow(ax, x_pos + 0.3, 3.1, x_pos + 0.9, 3.1, lw=1.5) + draw_box( + ax, + x_pos + 0.5, + 2.5, + 1.3, + 1.2, + "~50\nTWARZE\n✓", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + # Timing info + ax.text( + 5.0, + -0.5, + "Czas: 99% okien odrzucone w etapach 1-3 (~5 μs każde)\n" + "Tylko 0.01% dochodzi do etapu 25 → cały obraz w ~30 ms = 30+ fps", + ha="center", + fontsize=FS, + style="italic", + bbox={"boxstyle": "round,pad=0.4", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + save_fig(fig, "q24_viola_jones_cascade.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q24_iou_nms_detector.py b/python_pkg/praca_magisterska_video/generate_images/_q24_iou_nms_detector.py new file mode 100644 index 0000000..4871d9f --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q24_iou_nms_detector.py @@ -0,0 +1,413 @@ +"""IoU diagram, NMS steps, and detector-from-classifier diagrams.""" + +from __future__ import annotations + +from _q24_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + LN, + draw_arrow, + draw_box, + plt, + save_fig, +) +import matplotlib.patches as mpatches + + +# ============================================================ +# 8. IoU Diagram +# ============================================================ +def draw_iou_diagram() -> None: + """Draw iou diagram.""" + fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) + fig.suptitle( + "IoU (Intersection over Union) — miara nakładania bboxów", + fontsize=FS_TITLE, + fontweight="bold", + y=1.02, + ) + + # Low IoU + ax = axes[0] + ax.add_patch( + mpatches.Rectangle( + (0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5, label="A" + ) + ) + ax.add_patch( + mpatches.Rectangle( + (2.5, 2.5), + 3, + 3, + facecolor=GRAY2, + edgecolor=LN, + lw=1.5, + alpha=0.7, + label="B", + ) + ) + # Intersection + ax.add_patch( + mpatches.Rectangle((2.5, 2.5), 0.5, 0.5, facecolor=GRAY3, edgecolor=LN, lw=2) + ) + ax.text(1.5, 1.5, "A", ha="center", va="center", fontsize=12, fontweight="bold") + ax.text(4, 4, "B", ha="center", va="center", fontsize=12, fontweight="bold") + ax.set_xlim(-0.5, 6) + ax.set_ylim(-0.5, 6) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "IoU ≈ 0.04\n(prawie się nie nakładają)", fontsize=FS, fontweight="bold" + ) + + # Medium IoU + ax = axes[1] + ax.add_patch( + mpatches.Rectangle((0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + ax.add_patch( + mpatches.Rectangle( + (1.5, 1.5), 3, 3, facecolor=GRAY2, edgecolor=LN, lw=1.5, alpha=0.7 + ) + ) + ax.add_patch( + mpatches.Rectangle((1.5, 1.5), 1.5, 1.5, facecolor=GRAY3, edgecolor=LN, lw=2) + ) + ax.text(0.7, 0.7, "A", ha="center", va="center", fontsize=12, fontweight="bold") + ax.text(3.5, 3.5, "B", ha="center", va="center", fontsize=12, fontweight="bold") + ax.text(2.25, 2.25, "∩", ha="center", va="center", fontsize=14, fontweight="bold") + ax.set_xlim(-0.5, 5) + ax.set_ylim(-0.5, 5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("IoU ≈ 0.14\n(częściowe nakładanie)", fontsize=FS, fontweight="bold") + + # High IoU + ax = axes[2] + ax.add_patch( + mpatches.Rectangle((0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + ax.add_patch( + mpatches.Rectangle( + (0.3, 0.3), 3, 3, facecolor=GRAY2, edgecolor=LN, lw=1.5, alpha=0.7 + ) + ) + ax.add_patch( + mpatches.Rectangle((0.3, 0.3), 2.7, 2.7, facecolor=GRAY3, edgecolor=LN, lw=2) + ) + ax.text(-0.3, -0.3, "A", ha="center", va="center", fontsize=12, fontweight="bold") + ax.text(3.5, 3.5, "B", ha="center", va="center", fontsize=12, fontweight="bold") + ax.text(1.65, 1.65, "∩", ha="center", va="center", fontsize=14, fontweight="bold") + ax.set_xlim(-0.8, 4) + ax.set_ylim(-0.8, 4) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "IoU ≈ 0.74\n(duże nakładanie → duplikat!)", fontsize=FS, fontweight="bold" + ) + + fig.tight_layout() + save_fig(fig, "q24_iou_diagram.png") + + +# ============================================================ +# 9. NMS Step-by-Step +# ============================================================ +def draw_nms_steps() -> None: + """Draw nms steps.""" + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + fig.suptitle( + "NMS (Non-Maximum Suppression) — usuwanie duplikatów", + fontsize=FS_TITLE, + fontweight="bold", + y=1.02, + ) + + # Before NMS + ax = axes[0] + ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1)) + # Multiple overlapping boxes for same object + ax.add_patch( + mpatches.Rectangle((1, 1), 2.5, 3, facecolor="none", edgecolor=LN, lw=2) + ) + ax.text(2.25, 4.2, "conf=0.95", ha="center", fontsize=FS_SMALL, fontweight="bold") + ax.add_patch( + mpatches.Rectangle( + (1.2, 1.3), 2.3, 2.8, facecolor="none", edgecolor=LN, lw=1.5, linestyle="--" + ) + ) + ax.text(2.35, 1.1, "conf=0.90", ha="center", fontsize=FS_SMALL) + ax.add_patch( + mpatches.Rectangle( + (0.8, 0.8), 2.7, 3.2, facecolor="none", edgecolor=LN, lw=1, linestyle=":" + ) + ) + ax.text(2.15, 0.6, "conf=0.85", ha="center", fontsize=FS_SMALL) + # Different object + ax.add_patch( + mpatches.Rectangle((4, 2), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=1.5) + ) + ax.text(4.75, 3.7, "conf=0.80", ha="center", fontsize=FS_SMALL) + ax.text( + 2, + 0.2, + "⚠ 4 detekcje (3 duplikaty!)", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + ) + ax.set_xlim(-0.3, 6.3) + ax.set_ylim(-0.3, 5.3) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "① Przed NMS\n(wiele nakładających się)", fontsize=FS, fontweight="bold" + ) + + # NMS process + ax = axes[1] + ax.axis("off") + ax.set_xlim(0, 6) + ax.set_ylim(0, 5) + + steps = [ + ("1. Sortuj: [0.95, 0.90, 0.85, 0.80]", 4.5), + ("2. Weź najlepszą (0.95) → ZACHOWAJ", 3.7), + ("3. IoU(0.95, 0.90)=0.82 > 0.5 → USUŃ", 2.9), + ("4. IoU(0.95, 0.85)=0.75 > 0.5 → USUŃ", 2.1), + ("5. IoU(0.95, 0.80)=0.10 < 0.5 → ZACHOWAJ", 1.3), + ] + colors = [GRAY4, GRAY2, GRAY4, GRAY4, GRAY2] + for (text, yp), c in zip(steps, colors, strict=False): + ax.text( + 3.0, + yp, + text, + ha="center", + fontsize=FS, + bbox={"boxstyle": "round,pad=0.2", "facecolor": c, "edgecolor": GRAY3}, + ) + + ax.set_title("② Algorytm NMS\n(próg IoU = 0.5)", fontsize=FS, fontweight="bold") + + # After NMS + ax = axes[2] + ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1)) + # Only best box for each object + ax.add_patch( + mpatches.Rectangle((1, 1), 2.5, 3, facecolor="none", edgecolor=LN, lw=2.5) + ) + ax.text(2.25, 4.2, "conf=0.95 ✓", ha="center", fontsize=FS_SMALL, fontweight="bold") + ax.add_patch( + mpatches.Rectangle((4, 2), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=2.5) + ) + ax.text(4.75, 3.7, "conf=0.80 ✓", ha="center", fontsize=FS_SMALL, fontweight="bold") + ax.text( + 3, + 0.2, + "✓ 2 unikalne obiekty", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + ) + ax.set_xlim(-0.3, 6.3) + ax.set_ylim(-0.3, 5.3) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("③ Po NMS\n(1 bbox na obiekt)", fontsize=FS, fontweight="bold") + + fig.tight_layout() + save_fig(fig, "q24_nms_steps.png") + + +# ============================================================ +# 10. Detector from Classifier — 3 approaches +# ============================================================ +def draw_detector_from_classifier() -> None: + """Draw detector from classifier.""" + fig, ax = plt.subplots(figsize=(11, 9)) + ax.set_xlim(-0.5, 11) + ax.set_ylim(-1, 9.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Jak zbudować detektor z klasyfikatora? — 3 podejścia", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + # ---- Approach 1: Sliding Window ---- + y = 7.0 + ax.text( + 0, + y + 1.5, + "① Sliding Window (NAJWOLNIEJSZE)", + fontsize=FS_LABEL, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + # Image with sliding window + ax.add_patch( + mpatches.Rectangle( + (0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5 + ) + ) + ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL) + # Sliding windows + for dx, dy in [(0.1, 0.1), (0.4, 0.1), (0.7, 0.1), (0.1, 0.5), (0.4, 0.5)]: + ax.add_patch( + mpatches.Rectangle( + (dx, y - 0.5 + dy), + 0.5, + 0.5, + facecolor="none", + edgecolor=LN, + lw=0.8, + linestyle="--", + ) + ) + + draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2) + ax.text(2.35, y + 0.6, "xmiliony", fontsize=FS_SMALL, style="italic") + + draw_box( + ax, + 2.8, + y - 0.3, + 1.8, + 1.2, + 'Klasyfikator\n(ResNet)\n"kot? pies? tło?"', + fill=GRAY4, + fontsize=FS, + ) + draw_arrow(ax, 4.7, y + 0.3, 5.3, y + 0.3, lw=1.2) + draw_box(ax, 5.4, y - 0.3, 1.2, 1.2, "NMS", fill=GRAY1, fontsize=FS) + draw_arrow(ax, 6.7, y + 0.3, 7.3, y + 0.3, lw=1.2) + ax.text( + 8.5, + y + 0.3, + "~3.3h / obraz!\n⚠ NIEPRAKTYCZNE", + ha="center", + fontsize=FS, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + # ---- Approach 2: Region Proposals ---- + y = 3.8 + ax.text( + 0, + y + 1.5, + "② Region Proposals + Klasyfikator (= R-CNN)", + fontsize=FS_LABEL, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + ax.add_patch( + mpatches.Rectangle( + (0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5 + ) + ) + ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL) + # A few smart regions + ax.add_patch( + mpatches.Rectangle( + (0.1, y - 0.4), 0.7, 0.9, facecolor="none", edgecolor=LN, lw=1.5 + ) + ) + ax.add_patch( + mpatches.Rectangle( + (0.9, y + 0.0), 0.7, 0.6, facecolor="none", edgecolor=LN, lw=1.5 + ) + ) + + draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2) + draw_box( + ax, + 2.8, + y - 0.3, + 1.6, + 1.2, + "Selective\nSearch\n~2000 regionów", + fill=GRAY2, + fontsize=FS, + ) + draw_arrow(ax, 4.5, y + 0.3, 5.1, y + 0.3, lw=1.2) + ax.text(4.8, y + 0.6, "x2000", fontsize=FS_SMALL, style="italic") + draw_box(ax, 5.2, y - 0.3, 1.5, 1.2, "Klasyfikator\n(CNN)", fill=GRAY4, fontsize=FS) + draw_arrow(ax, 6.8, y + 0.3, 7.4, y + 0.3, lw=1.2) + draw_box(ax, 7.5, y - 0.3, 1.0, 1.2, "NMS", fill=GRAY1, fontsize=FS) + draw_arrow(ax, 8.6, y + 0.3, 9.0, y + 0.3, lw=1.2) + ax.text( + 10.0, + y + 0.3, + "~20-50 s/obraz\n(250x szybciej)", + ha="center", + fontsize=FS, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + # ---- Approach 3: Fine-tune backbone ---- + y = 0.5 + ax.text( + 0, + y + 1.5, + "③ Fine-tune backbone + detection head (NAJLEPSZE)", + fontsize=FS_LABEL, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY2, "edgecolor": GRAY3}, + ) + + ax.add_patch( + mpatches.Rectangle( + (0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5 + ) + ) + ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL) + + draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2) + draw_box( + ax, + 2.8, + y - 0.3, + 1.8, + 1.2, + "Pretrained\nbackbone\n(ResNet)", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 4.7, y + 0.3, 5.3, y + 0.3, lw=1.2) + + # Two heads from feature map + draw_box(ax, 5.4, y + 0.3, 1.6, 0.6, "cls head\nP(klasa)", fill=GRAY4, fontsize=FS) + draw_box( + ax, 5.4, y - 0.5, 1.6, 0.6, "bbox head\nΔx,Δy,Δw,Δh", fill=GRAY4, fontsize=FS + ) + + draw_arrow(ax, 7.1, y + 0.6, 7.7, y + 0.6, lw=1.0) + draw_arrow(ax, 7.1, y - 0.2, 7.7, y - 0.2, lw=1.0) + draw_box(ax, 7.8, y - 0.3, 1.0, 1.2, "NMS", fill=GRAY1, fontsize=FS) + + draw_arrow(ax, 8.9, y + 0.3, 9.3, y + 0.3, lw=1.2) + ax.text( + 10.2, + y + 0.3, + "5-155 fps!\n✓ NAJLEPSZE", + ha="center", + fontsize=FS, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY2, "edgecolor": GRAY3}, + ) + + save_fig(fig, "q24_detector_from_classifier.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q24_modern_pipelines.py b/python_pkg/praca_magisterska_video/generate_images/_q24_modern_pipelines.py new file mode 100644 index 0000000..4a1584e --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q24_modern_pipelines.py @@ -0,0 +1,365 @@ +"""Two-stage vs one-stage table, ROI pooling, DETR, and sliding window.""" + +from __future__ import annotations + +from _q24_common import ( + _DATA_BRIGHT_THRESH, + FS, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + draw_arrow, + draw_box, + draw_table, + np, + plt, + rng, + save_fig, +) +import matplotlib.patches as mpatches + + +# ============================================================ +# 12. Two-stage vs One-stage comparison table +# ============================================================ +def draw_two_vs_one_stage() -> None: + """Draw two vs one stage.""" + fig, ax = plt.subplots(figsize=(10, 3.5)) + ax.set_xlim(0, 10) + ax.set_ylim(-0.5, 4.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Two-stage vs One-stage — porównanie", + fontsize=FS_TITLE, + fontweight="bold", + pad=8, + ) + + headers = ["Cecha", "Two-stage\n(Faster R-CNN)", "One-stage\n(YOLO)"] + rows = [ + ["Szybkość", "~5 fps", "45-155 fps"], + ["Dokładność (mAP)", "wyższa (historycznie)", "dorównuje (YOLOv8)"], + ["Małe obiekty", "lepszy", "gorszy (SSD/FPN pomaga)"], + ["Architektura", "2 etapy + NMS", "1 etap + NMS"], + ["Real-time?", "NIE", "TAK"], + ] + col_widths = [2.5, 3.5, 3.5] + draw_table( + ax, + headers, + rows, + 0.2, + 3.8, + col_widths, + row_h=0.65, + fontsize=FS, + header_fontsize=FS, + ) + + save_fig(fig, "q24_two_vs_one_stage.png") + + +# ============================================================ +# 13. ROI Pooling illustration +# ============================================================ +def draw_roi_pooling() -> None: + """Draw roi pooling.""" + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + fig.suptitle( + "ROI Pooling — dowolny rozmiar → stały rozmiar", + fontsize=FS_TITLE, + fontweight="bold", + y=1.02, + ) + + # Feature map with ROI + ax = axes[0] + # Draw feature map grid + fm = rng.integers(0, 10, (8, 8)) + ax.imshow(fm, cmap="gray", vmin=0, vmax=10, alpha=0.3) + for i in range(9): + ax.axhline(y=i - 0.5, color=LN, lw=0.3) + ax.axvline(x=i - 0.5, color=LN, lw=0.3) + # ROI rectangle + ax.add_patch( + mpatches.Rectangle( + (1.5, 1.5), 4, 4, facecolor="none", edgecolor=LN, lw=3, linestyle="-" + ) + ) + ax.text(3.5, 0.8, "ROI", ha="center", fontsize=FS, fontweight="bold") + ax.set_xlim(-0.5, 7.5) + ax.set_ylim(7.5, -0.5) + ax.set_title("① Feature map\nz zaznaczonym ROI", fontsize=FS, fontweight="bold") + ax.set_xticks([]) + ax.set_yticks([]) + + # ROI divided into grid + ax = axes[1] + roi_data = np.array( + [ + [1, 3, 2, 1], + [0, 5, 1, 6], + [0, 4, 1, 0], + [7, 2, 9, 1], + ] + ) + ax.imshow(roi_data, cmap="gray", vmin=0, vmax=10) + for i in range(5): + ax.axhline(y=i - 0.5, color=LN, lw=1) + ax.axvline(x=i - 0.5, color=LN, lw=1) + # Grid lines for 2x2 pooling + ax.axhline(y=0.5, color=LN, lw=3, linestyle="--") + ax.axvline(x=0.5, color=LN, lw=3, linestyle="--") + for i in range(4): + for j in range(4): + ax.text( + j, + i, + str(roi_data[i, j]), + ha="center", + va="center", + fontsize=10, + fontweight="bold", + color="white" if roi_data[i, j] > _DATA_BRIGHT_THRESH else "black", + ) + ax.set_title("② ROI podzielony\nna siatkę 2x2", fontsize=FS, fontweight="bold") + ax.set_xticks([]) + ax.set_yticks([]) + + # Output after pooling + ax = axes[2] + out = np.array([[5, 6], [7, 9]]) + ax.imshow(out, cmap="gray", vmin=0, vmax=10) + for i in range(3): + ax.axhline(y=i - 0.5, color=LN, lw=1.5) + ax.axvline(x=i - 0.5, color=LN, lw=1.5) + for i in range(2): + for j in range(2): + ax.text( + j, + i, + str(out[i, j]), + ha="center", + va="center", + fontsize=14, + fontweight="bold", + color="white" if out[i, j] > _DATA_BRIGHT_THRESH else "black", + ) + ax.set_title( + "③ Po ROI Pool 2x2\n(max z każdej komórki)", fontsize=FS, fontweight="bold" + ) + ax.set_xticks([]) + ax.set_yticks([]) + + fig.tight_layout() + save_fig(fig, "q24_roi_pooling.png") + + +# ============================================================ +# 14. DETR Pipeline +# ============================================================ +def draw_detr_pipeline() -> None: + """Draw detr pipeline.""" + fig, ax = plt.subplots(figsize=(11, 4.5)) + ax.set_xlim(-0.5, 11.5) + ax.set_ylim(-1, 4.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "DETR — Transformer do detekcji (bez NMS, bez anchorów)", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + # Pipeline + draw_box(ax, 0, 1.5, 1.5, 1.5, "Obraz\nwejściowy", fill=GRAY1, fontsize=FS) + draw_arrow(ax, 1.6, 2.25, 2.1, 2.25, lw=1.5) + + draw_box( + ax, + 2.2, + 1.5, + 1.5, + 1.5, + "CNN\nBackbone\n(ResNet)", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 3.8, 2.25, 4.3, 2.25, lw=1.5) + + draw_box( + ax, + 4.4, + 1.5, + 1.8, + 1.5, + "Transformer\nEncoder\n(self-attention)", + fill=GRAY2, + fontsize=FS, + ) + draw_arrow(ax, 6.3, 2.25, 6.8, 2.25, lw=1.5) + + draw_box( + ax, + 6.9, + 1.5, + 1.8, + 1.5, + "Transformer\nDecoder\n(N=100 queries)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + # Output branches + draw_arrow(ax, 8.8, 2.5, 9.5, 3.0, lw=1.2) + draw_box(ax, 9.6, 2.7, 1.5, 0.7, "klasa₁...klasa₁₀₀", fill=GRAY4, fontsize=FS_SMALL) + + draw_arrow(ax, 8.8, 2.0, 9.5, 1.5, lw=1.2) + draw_box(ax, 9.6, 1.2, 1.5, 0.7, "bbox₁...bbox₁₀₀", fill=GRAY4, fontsize=FS_SMALL) + + # Annotations + ax.text( + 7.8, + 0.5, + '100 object queries → 5 obiektów + 95x "brak"', + ha="center", + fontsize=FS, + style="italic", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + ax.text( + 5.5, + 0.0, + "Hungarian matching (trening): optymalne dopasowanie predykcji do GT", + ha="center", + fontsize=FS_SMALL, + style="italic", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, + ) + + # Big benefit box + ax.text( + 5.5, + 4.0, + "BEZ anchorów • BEZ NMS • end-to-end • prosty pipeline", + ha="center", + fontsize=FS, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY2, "edgecolor": GRAY3}, + ) + + save_fig(fig, "q24_detr_pipeline.png") + + +# ============================================================ +# 15. Sliding Window illustration +# ============================================================ +def draw_sliding_window() -> None: + """Draw sliding window.""" + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + fig.suptitle( + "Sliding Window — najprostsze podejście do detekcji", + fontsize=FS_TITLE, + fontweight="bold", + y=1.02, + ) + + # Multi-position + ax = axes[0] + ax.add_patch( + mpatches.Rectangle((0, 0), 8, 6, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + # Grid of sliding windows + for i in range(4): + for j in range(3): + ax.add_patch( + mpatches.Rectangle( + (i * 1.8 + 0.2, j * 1.8 + 0.2), + 1.5, + 1.5, + facecolor="none", + edgecolor=LN, + lw=0.6, + linestyle="--", + ) + ) + # Highlight current window + ax.add_patch( + mpatches.Rectangle((2.0, 2.0), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=2.5) + ) + ax.set_xlim(-0.5, 8.5) + ax.set_ylim(-0.5, 6.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("① Wiele pozycji\n(krok co 8 px)", fontsize=FS, fontweight="bold") + + # Multi-scale + ax = axes[1] + ax.add_patch( + mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1.5) + ) + sizes = [(0.8, 0.8), (1.5, 1.5), (2.5, 2.5), (3.5, 3.5)] + for i, (w, h) in enumerate(sizes): + ax.add_patch( + mpatches.Rectangle( + (0.3 + i * 0.3, 0.3 + i * 0.3), + w, + h, + facecolor="none", + edgecolor=LN, + lw=1 + i * 0.3, + linestyle=[":", "--", "-.", "-"][i], + ) + ) + ax.text(3, 0, "4+ skal", ha="center", fontsize=FS_SMALL, fontweight="bold") + ax.set_xlim(-0.5, 6.5) + ax.set_ylim(-0.5, 5.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "② Wiele skal\n(obiekty mają różne rozmiary)", fontsize=FS, fontweight="bold" + ) + + # Count + ax = axes[2] + ax.axis("off") + ax.set_xlim(0, 6) + ax.set_ylim(0, 5) + + lines = [ + ("Obraz: 640 x 480 px", 4.5), + ("Okno: 64 x 64 px, krok 8 px", 3.8), + ("Pozycje: ~72 x 52 = 3 744", 3.1), + ("x 5 skal = 18 720 okien", 2.4), + ("x klasyfikacja = WOLNE!", 1.7), + ("→ ~3h na jeden obraz", 0.8), + ] + for text, yp in lines: + fw = "bold" if "~3h" in text or "WOLNE" in text else "normal" + col = GRAY2 if "WOLNE" in text or "~3h" in text else GRAY4 + ax.text( + 3.0, + yp, + text, + ha="center", + fontsize=FS, + fontweight=fw, + bbox={"boxstyle": "round,pad=0.2", "facecolor": col, "edgecolor": GRAY3}, + ) + + ax.set_title( + "③ Dlaczego wolne?\n(miliony klasyfikacji)", fontsize=FS, fontweight="bold" + ) + + fig.tight_layout() + save_fig(fig, "q24_sliding_window.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q24_rcnn_yolo.py b/python_pkg/praca_magisterska_video/generate_images/_q24_rcnn_yolo.py new file mode 100644 index 0000000..26f134e --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q24_rcnn_yolo.py @@ -0,0 +1,344 @@ +"""R-CNN evolution and YOLO grid diagrams.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from _q24_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + LN, + draw_arrow, + draw_box, + plt, + save_fig, +) +import matplotlib.patches as mpatches + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def _draw_yolo_cell_prediction(ax: Axes) -> None: + """Draw YOLO cell prediction vector panel.""" + ax.axis("off") + ax.set_xlim(0, 6) + ax.set_ylim(-1, 5) + + labels = [ + "x", + "y", + "w", + "h", + "conf", + "x", + "y", + "w", + "h", + "conf", + "P(c₁)", + "...", + "P(c₂₀)", + ] + colors_vec = [GRAY4] * 5 + [GRAY2] * 5 + [GRAY1] * 3 + + bw = 0.42 + for i, (label, col) in enumerate( + zip(labels, colors_vec, strict=False), + ): + x_pos = 0.3 + i * bw + ax.add_patch( + mpatches.Rectangle( + (x_pos, 2.5), + bw - 0.02, + 0.6, + facecolor=col, + edgecolor=LN, + lw=0.8, + ) + ) + ax.text( + x_pos + bw / 2, + 2.8, + label, + ha="center", + va="center", + fontsize=5, + fontweight="bold", + ) + + ax.annotate( + "", + xy=(0.3, 2.4), + xytext=(2.4, 2.4), + arrowprops={"arrowstyle": "-", "lw": 1}, + ) + ax.text(1.35, 2.15, "bbox 1 (5 wartości)", ha="center", fontsize=FS_SMALL) + + ax.annotate( + "", + xy=(2.4, 2.4), + xytext=(4.5, 2.4), + arrowprops={"arrowstyle": "-", "lw": 1}, + ) + ax.text(3.45, 2.15, "bbox 2 (5 wartości)", ha="center", fontsize=FS_SMALL) + + ax.annotate( + "", + xy=(4.5, 2.4), + xytext=(5.8, 2.4), + arrowprops={"arrowstyle": "-", "lw": 1}, + ) + ax.text(5.15, 2.15, "20 klas", ha="center", fontsize=FS_SMALL) + + ax.text( + 3.0, + 3.5, + "Każda komórka → 30 wartości\n= 2x(x,y,w,h,conf) + 20 klas", + ha="center", + fontsize=FS, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + ax.set_title( + "② Predykcja jednej komórki\n(S=7, B=2, C=20)", + fontsize=FS, + fontweight="bold", + ) + + +# ============================================================ +# 6. R-CNN Evolution +# ============================================================ +def draw_rcnn_evolution() -> None: + """Draw rcnn evolution.""" + fig, ax = plt.subplots(figsize=(11, 7)) + ax.set_xlim(-0.5, 11) + ax.set_ylim(-0.5, 7.5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Ewolucja R-CNN: od 50s do 0.2s na obraz", + fontsize=FS_TITLE, + fontweight="bold", + pad=12, + ) + + y_positions = [5.5, 3.0, 0.5] + labels = [ + "R-CNN (2014) — 50 s/obraz", + "Fast R-CNN (2015) — 2 s/obraz", + "Faster R-CNN (2015) — 0.2 s/obraz", + ] + + # R-CNN + y = y_positions[0] + ax.text(0, y + 1.3, labels[0], fontsize=FS_LABEL, fontweight="bold") + draw_box(ax, 0, y, 2, 0.9, "Selective\nSearch", fill=GRAY2, fontsize=FS) + draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45) + ax.text(2.3, y + 0.8, "~2000", ha="center", fontsize=FS_SMALL, style="italic") + draw_box(ax, 2.6, y, 1.5, 0.9, "Resize\n224x224", fill=GRAY4, fontsize=FS) + draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45) + draw_box( + ax, 4.7, y, 1.5, 0.9, "CNN\nx2000!", fill=GRAY3, fontsize=FS, fontweight="bold" + ) + draw_arrow(ax, 6.3, y + 0.45, 6.7, y + 0.45) + draw_box(ax, 6.8, y, 1.3, 0.9, "SVM\nklasyf.", fill=GRAY4, fontsize=FS) + draw_arrow(ax, 8.2, y + 0.45, 8.6, y + 0.45) + draw_box(ax, 8.7, y, 1.0, 0.9, "NMS", fill=GRAY1, fontsize=FS) + # Problem annotation + ax.text( + 5.5, + y - 0.4, + "⚠ CNN uruchamiane 2000x → 50 sek!", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + # Fast R-CNN + y = y_positions[1] + ax.text(0, y + 1.3, labels[1], fontsize=FS_LABEL, fontweight="bold") + draw_box(ax, 0, y, 2, 0.9, "Selective\nSearch", fill=GRAY2, fontsize=FS) + draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45) + draw_box( + ax, + 2.6, + y, + 1.5, + 0.9, + "CNN\nx1 (RAZ!)", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45) + draw_box( + ax, 4.7, y, 1.5, 0.9, "ROI\nPooling", fill=GRAY1, fontsize=FS, fontweight="bold" + ) + draw_arrow(ax, 6.3, y + 0.45, 6.7, y + 0.45) + draw_box(ax, 6.8, y, 1.3, 0.9, "FC\nklasa+bbox", fill=GRAY4, fontsize=FS) + draw_arrow(ax, 8.2, y + 0.45, 8.6, y + 0.45) + draw_box(ax, 8.7, y, 1.0, 0.9, "NMS", fill=GRAY1, fontsize=FS) + ax.text( + 3.8, + y - 0.4, + "✓ CNN RAZ na cały obraz → 25x szybciej", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + # Faster R-CNN + y = y_positions[2] + ax.text(0, y + 1.3, labels[2], fontsize=FS_LABEL, fontweight="bold") + draw_box( + ax, + 0.5, + y, + 1.5, + 0.9, + "CNN\nBackbone", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45) + draw_box(ax, 2.6, y, 1.5, 0.9, "Feature\nMap", fill=GRAY1, fontsize=FS) + draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45) + draw_box( + ax, + 4.7, + y, + 1.3, + 0.9, + "RPN\n(w sieci!)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 6.1, y + 0.45, 6.5, y + 0.45) + draw_box(ax, 6.6, y, 1.3, 0.9, "ROI\nPooling", fill=GRAY1, fontsize=FS) + draw_arrow(ax, 8.0, y + 0.45, 8.4, y + 0.45) + draw_box(ax, 8.5, y, 1.3, 0.9, "FC\nklasa+bbox", fill=GRAY4, fontsize=FS) + ax.text( + 5.0, + y - 0.4, + "✓ RPN zastępuje Selective Search → end-to-end", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + save_fig(fig, "q24_rcnn_evolution.png") + + +# ============================================================ +# 7. YOLO Grid +# ============================================================ +def draw_yolo_grid() -> None: + """Draw yolo grid.""" + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + fig.suptitle( + "YOLO — detekcja jednoetapowa (siatka SxS)", + fontsize=FS_TITLE, + fontweight="bold", + y=1.02, + ) + + # Grid on image + ax = axes[0] + grid_size = 7 + ax.set_xlim(0, grid_size) + ax.set_ylim(0, grid_size) + for i in range(grid_size + 1): + ax.axhline(y=i, color=LN, lw=0.5, alpha=0.5) + ax.axvline(x=i, color=LN, lw=0.5, alpha=0.5) + ax.add_patch( + mpatches.Rectangle( + (0, 0), + grid_size, + grid_size, + facecolor=GRAY4, + edgecolor=LN, + lw=1.5, + ) + ) + # Highlight one cell + ax.add_patch(mpatches.Rectangle((3, 3), 1, 1, facecolor=GRAY2, edgecolor=LN, lw=2)) + # Object center dot + ax.plot(3.5, 3.5, "ko", markersize=8) + # Bounding box from that cell + ax.add_patch( + mpatches.Rectangle( + (2.0, 2.2), 3.0, 2.6, facecolor="none", edgecolor=LN, lw=2, linestyle="--" + ) + ) + ax.text( + 3.5, + 1.8, + "bbox z komórki (3,3)", + ha="center", + fontsize=FS_SMALL, + fontweight="bold", + ) + ax.set_aspect("equal") + ax.invert_yaxis() + ax.set_title("① Siatka 7x7\nna obrazie", fontsize=FS, fontweight="bold") + ax.set_xticks([]) + ax.set_yticks([]) + + _draw_yolo_cell_prediction(axes[1]) + + # Speed comparison + ax = axes[2] + ax.axis("off") + ax.set_xlim(0, 5) + ax.set_ylim(0, 5) + + methods = ["R-CNN", "Fast R-CNN", "Faster R-CNN", "YOLO", "YOLOv8"] + fps_vals = [0.02, 0.5, 5, 45, 100] + bar_colors = [GRAY3, GRAY3, GRAY3, GRAY2, GRAY1] + + for i, (m, f, c) in enumerate(zip(methods, fps_vals, bar_colors, strict=False)): + bar_w = f / 100 * 4.0 + y_pos = 4.0 - i * 0.8 + ax.add_patch( + mpatches.Rectangle( + (0.5, y_pos), max(bar_w, 0.1), 0.5, facecolor=c, edgecolor=LN, lw=0.8 + ) + ) + ax.text( + 0.4, + y_pos + 0.25, + m, + ha="right", + va="center", + fontsize=FS, + fontweight="bold", + ) + ax.text( + max(0.7, 0.5 + bar_w + 0.1), + y_pos + 0.25, + f"{f} fps", + ha="left", + va="center", + fontsize=FS, + ) + + ax.set_title( + "③ Porównanie szybkości\n(fps = klatki/sek)", fontsize=FS, fontweight="bold" + ) + + fig.tight_layout() + save_fig(fig, "q24_yolo_grid.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q31_common.py b/python_pkg/praca_magisterska_video/generate_images/_q31_common.py new file mode 100644 index 0000000..7654d1d --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q31_common.py @@ -0,0 +1,102 @@ +"""Common constants and utilities for Q31 diagrams.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch + +if TYPE_CHECKING: + from matplotlib.axes import Axes + +_logger = logging.getLogger(__name__) + +DPI = 300 +BG = "white" +LN = "black" +FS = 8 +FS_TITLE = 11 +OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") +Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + +GRAY1 = "#E8E8E8" +GRAY2 = "#D0D0D0" +GRAY3 = "#B8B8B8" +GRAY4 = "#F5F5F5" +GRAY5 = "#C0C0C0" + +# Number of regret table header columns before the max-regret column +_REGRET_HEADER_COLS = 4 +# Number of data state columns +_DATA_STATE_COLS = 3 +# Expected-value for the winning alternative +_WINNING_EV = 95 + + +def draw_box( + ax: Axes, + x: float, + y: float, + w: float, + h: float, + text: str, + *, + fill: str = "white", + lw: float = 1.2, + fontsize: float = FS, + fontweight: str = "normal", + ha: str = "center", + va: str = "center", + rounded: bool = True, +) -> None: + """Draw a labeled box on the axes.""" + if rounded: + rect = FancyBboxPatch( + (x, y), + w, + h, + boxstyle="round,pad=0.05", + lw=lw, + edgecolor=LN, + facecolor=fill, + ) + else: + rect = mpatches.Rectangle((x, y), w, h, lw=lw, edgecolor=LN, facecolor=fill) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + h / 2, + text, + ha=ha, + va=va, + fontsize=fontsize, + fontweight=fontweight, + wrap=True, + ) + + +def draw_arrow( + ax: Axes, + x1: float, + y1: float, + x2: float, + y2: float, + *, + lw: float = 1.2, + style: str = "->", + color: str = LN, +) -> None: + """Draw an arrow between two points.""" + ax.annotate( + "", + xy=(x2, y2), + xytext=(x1, y1), + arrowprops={ + "arrowstyle": style, + "color": color, + "lw": lw, + }, + ) diff --git a/python_pkg/praca_magisterska_video/generate_images/_q31_criteria_comparison.py b/python_pkg/praca_magisterska_video/generate_images/_q31_criteria_comparison.py new file mode 100644 index 0000000..9855e30 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q31_criteria_comparison.py @@ -0,0 +1,256 @@ +"""Q31 Diagram 1: Payoff matrix + all criteria bar chart.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np + +from python_pkg.praca_magisterska_video.generate_images._q31_common import ( + _DATA_STATE_COLS, + BG, + DPI, + FS, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + OUTPUT_DIR, + _logger, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def _draw_payoff_table(ax: Axes) -> None: + """Draw the payoff matrix table on the left panel.""" + ax.axis("off") + ax.set_xlim(0, 6) + ax.set_ylim(0, 6) + ax.set_title( + "Macierz wypłat (tys. zł)", + fontsize=FS_TITLE, + fontweight="bold", + pad=8, + ) + + headers_col = ["", "S₁\n(dobra)", "S₂\n(średnia)", "S₃\n(zła)"] + rows = [ + ["A₁ (fabryka)", "200", "50", "-100"], + ["A₂ (sklep)", "80", "70", "40"], + ["A₃ (obligacje)", "30", "30", "30"], + ] + + col_w = [1.8, 1.2, 1.2, 1.2] + row_h = 0.7 + start_y = 4.5 + start_x = 0.2 + + # Draw header row + x = start_x + for j, h in enumerate(headers_col): + fill = GRAY2 if j > 0 else GRAY3 + rect = mpatches.Rectangle( + (x, start_y), + col_w[j], + row_h, + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + x + col_w[j] / 2, + start_y + row_h / 2, + h, + ha="center", + va="center", + fontsize=FS, + fontweight="bold", + ) + x += col_w[j] + + # Draw data rows + for i, row in enumerate(rows): + x = start_x + y = start_y - (i + 1) * row_h + for j, val in enumerate(row): + fill = GRAY4 if j == 0 else ("white" if i % 2 == 0 else GRAY1) + if val.startswith("-"): + fill = "#D8D8D8" + rect = mpatches.Rectangle( + (x, y), + col_w[j], + row_h, + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + fw = "bold" if j == 0 else "normal" + ax.text( + x + col_w[j] / 2, + y + row_h / 2, + val, + ha="center", + va="center", + fontsize=FS, + fontweight=fw, + ) + x += col_w[j] + + # Probability row for EV + x = start_x + y = start_y - 4 * row_h + probs = ["p (dla E[X]):", "0.5", "0.3", "0.2"] + for j, val in enumerate(probs): + fill = GRAY5 if j > 0 else GRAY3 + rect = mpatches.Rectangle( + (x, y), + col_w[j], + row_h * 0.7, + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + x + col_w[j] / 2, + y + row_h * 0.35, + val, + ha="center", + va="center", + fontsize=7, + fontweight="bold", + style="italic", + ) + x += col_w[j] + + +def _draw_criteria_bars(ax2: Axes) -> None: + """Draw the criteria comparison bar chart on the right panel.""" + criteria = [ + "E[X]", + "Laplace", + "Maximax", + "Maximin", + "Hurwicz\n\u03b1=0.6", + "Savage", + ] + + ev = [95, 69, 30] + laplace = [50, 63.3, 30] + maximax = [200, 80, 30] + maximin = [-100, 40, 30] + hurwicz = [80, 64, 30] + savage_maxregret = [140, 120, 170] + + winners = [0, 1, 0, 1, 0, 1] + + x_pos = np.arange(len(criteria)) + width = 0.22 + hatches = ["///", "...", "xxx"] + labels = ["A₁ (fabryka)", "A₂ (sklep)", "A₃ (obligacje)"] + + all_vals = [ + [ + ev[0], + laplace[0], + maximax[0], + maximin[0], + hurwicz[0], + savage_maxregret[0], + ], + [ + ev[1], + laplace[1], + maximax[1], + maximin[1], + hurwicz[1], + savage_maxregret[1], + ], + [ + ev[2], + laplace[2], + maximax[2], + maximin[2], + hurwicz[2], + savage_maxregret[2], + ], + ] + + for i in range(_DATA_STATE_COLS): + ax2.bar( + x_pos + (i - 1) * width, + all_vals[i], + width, + label=labels[i], + color="white", + edgecolor=LN, + hatch=hatches[i], + lw=0.8, + ) + + for c_idx in range(len(criteria)): + w = winners[c_idx] + val = all_vals[w][c_idx] + ax2.text( + x_pos[c_idx] + (w - 1) * width, + val + 5, + "★", + ha="center", + va="bottom", + fontsize=10, + fontweight="bold", + ) + + ax2.set_xticks(x_pos) + ax2.set_xticklabels(criteria, fontsize=7) + ax2.set_ylabel("Wartość kryterium", fontsize=8) + ax2.set_title( + "Porównanie kryteriów", + fontsize=FS_TITLE, + fontweight="bold", + pad=8, + ) + ax2.legend(fontsize=7, loc="upper right") + ax2.axhline(y=0, color=LN, lw=0.5, ls="-") + ax2.spines["top"].set_visible(False) + ax2.spines["right"].set_visible(False) + ax2.tick_params(labelsize=7) + + ax2.text( + 5, + -30, + "(Savage: niżej\n= lepiej)", + fontsize=6, + ha="center", + va="top", + style="italic", + ) + + +def draw_criteria_comparison() -> None: + """Draw payoff matrix and criteria comparison chart.""" + fig, axes = plt.subplots( + 1, + 2, + figsize=(8.27, 4.5), + gridspec_kw={"width_ratios": [1.2, 1]}, + ) + + _draw_payoff_table(axes[0]) + _draw_criteria_bars(axes[1]) + + plt.tight_layout() + outpath = str(Path(OUTPUT_DIR) / "q31_criteria_comparison.png") + fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) + plt.close(fig) + _logger.info(" Saved: %s", outpath) diff --git a/python_pkg/praca_magisterska_video/generate_images/_q31_ev_spectrum.py b/python_pkg/praca_magisterska_video/generate_images/_q31_ev_spectrum.py new file mode 100644 index 0000000..71a1a8f --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q31_ev_spectrum.py @@ -0,0 +1,289 @@ +"""Q31 Diagrams 5 & 6: Expected value + decision conditions spectrum.""" + +from __future__ import annotations + +from pathlib import Path + +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch +import matplotlib.pyplot as plt + +from python_pkg.praca_magisterska_video.generate_images._q31_common import ( + _WINNING_EV, + BG, + DPI, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + LN, + OUTPUT_DIR, + _logger, +) + + +def draw_expected_value() -> None: + """Draw expected value criterion with probability-weighted bars.""" + fig, axes = plt.subplots(1, 3, figsize=(8.27, 3.5), sharey=True) + fig.suptitle( + "Kryterium wartości oczekiwanej E[X]" " \u2014 rozkład wyników per alternatywa", + fontsize=FS_TITLE, + fontweight="bold", + y=1.02, + ) + + probs = [0.5, 0.3, 0.2] + alts = [ + ("A₁ (fabryka)", [200, 50, -100], 95), + ("A₂ (sklep)", [80, 70, 40], 69), + ("A₃ (obligacje)", [30, 30, 30], 30), + ] + + hatches = ["///", "...", "xxx"] + + for _idx, (ax, (name, vals, ev)) in enumerate(zip(axes, alts, strict=False)): + x_positions = [0, 0.6, 1.0] + widths = [p * 0.9 for p in probs] + + for i, (v, p, h) in enumerate(zip(vals, probs, hatches, strict=False)): + color = "white" if v >= 0 else GRAY2 + ax.bar( + x_positions[i], + v, + width=widths[i], + color=color, + edgecolor=LN, + hatch=h, + lw=0.8, + align="edge", + ) + offset = 8 if v >= 0 else -12 + ax.text( + x_positions[i] + widths[i] / 2, + v + offset, + f"{v}", + ha="center", + va="center", + fontsize=8, + fontweight="bold", + ) + contrib = v * p + ax.text( + x_positions[i] + widths[i] / 2, + v / 2, + f"{v}x{p}\n={contrib:.0f}", + ha="center", + va="center", + fontsize=6, + style="italic", + ) + + # Expected value line + ax.axhline(y=ev, color=LN, lw=2, ls="--") + ax.text( + 1.35, + ev, + f"E[X]={ev}", + fontsize=8, + fontweight="bold", + va="center", + ha="left", + bbox={ + "boxstyle": "round,pad=0.15", + "facecolor": GRAY1, + "edgecolor": LN, + }, + ) + + ax.set_title(name, fontsize=9, fontweight="bold") + ax.set_xticks([0.225, 0.735, 1.09]) + ax.set_xticklabels(["S₁", "S₂", "S₃"], fontsize=7) + ax.axhline(y=0, color=LN, lw=0.5) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.tick_params(labelsize=7) + + # Star on winner + if ev == _WINNING_EV: + ax.text( + 0.7, + ev + 20, + "★ MAX", + fontsize=9, + fontweight="bold", + ha="center", + va="bottom", + ) + + axes[0].set_ylabel("Wypłata (tys. zł)", fontsize=8) + + plt.tight_layout() + outpath = str(Path(OUTPUT_DIR) / "q31_expected_value.png") + fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) + plt.close(fig) + _logger.info(" Saved: %s", outpath) + + +def draw_conditions_spectrum() -> None: + """Draw decision conditions spectrum diagram.""" + fig, ax = plt.subplots(1, 1, figsize=(8.27, 3.5)) + ax.set_xlim(0, 10) + ax.set_ylim(0, 5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Warunki decyzyjne" " \u2014 spektrum wiedzy decydenta", + fontsize=FS_TITLE + 1, + fontweight="bold", + pad=10, + ) + + # Three zones + zones = [ + ( + 0.3, + 1.5, + 2.8, + 2.5, + "PEWNOŚĆ", + "white", + [ + "Znamy dokładny wynik", + "Przykład: lokata 5%", + "Metoda: po prostu wybierz", + "najlepszy wynik", + ], + ), + ( + 3.5, + 1.5, + 2.8, + 2.5, + "RYZYKO", + GRAY1, + [ + "Znamy wyniki I prawdop.", + "Przykład: gra w kości", + "Metoda: wartość", + "oczekiwana E[X]", + ], + ), + ( + 6.7, + 1.5, + 2.8, + 2.5, + "NIEPEWNOŚĆ", + GRAY3, + [ + "Znamy wyniki, ale", + "NIE znamy prawdop.", + "Metody: Laplace, maximax,", + "maximin, Hurwicz, Savage", + ], + ), + ] + + for x, y, w, h, title, fill, lines in zones: + rect = FancyBboxPatch( + (x, y), + w, + h, + boxstyle="round,pad=0.1", + lw=2, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + h - 0.3, + title, + ha="center", + va="center", + fontsize=11, + fontweight="bold", + ) + for i, line in enumerate(lines): + ax.text( + x + w / 2, + y + h - 0.7 - i * 0.4, + line, + ha="center", + va="center", + fontsize=7, + ) + + # Arrows between zones + ax.annotate( + "", + xy=(3.4, 2.75), + xytext=(3.15, 2.75), + arrowprops={"arrowstyle": "->", "color": LN, "lw": 2}, + ) + ax.annotate( + "", + xy=(6.6, 2.75), + xytext=(6.35, 2.75), + arrowprops={"arrowstyle": "->", "color": LN, "lw": 2}, + ) + + # Bottom: knowledge gradient bar + gradient_y = 0.5 + gradient_h = 0.5 + n_steps = 50 + for i in range(n_steps): + x = 0.3 + i * (9.2 / n_steps) + w = 9.2 / n_steps + 0.01 + gray_val = 1 - (i / n_steps) * 0.7 + rect = mpatches.Rectangle( + (x, gradient_y), + w, + gradient_h, + lw=0, + facecolor=str(gray_val), + ) + ax.add_patch(rect) + + rect = mpatches.Rectangle( + (0.3, gradient_y), + 9.2, + gradient_h, + lw=1.5, + edgecolor=LN, + facecolor="none", + ) + ax.add_patch(rect) + + ax.text( + 0.3, + gradient_y - 0.15, + "Dużo wiedzy", + fontsize=7, + ha="left", + va="top", + ) + ax.text( + 9.5, + gradient_y - 0.15, + "Mało wiedzy", + fontsize=7, + ha="right", + va="top", + ) + ax.text( + 4.95, + gradient_y + gradient_h / 2, + "POZIOM WIEDZY DECYDENTA", + fontsize=8, + fontweight="bold", + ha="center", + va="center", + color="white", + ) + + plt.tight_layout() + outpath = str(Path(OUTPUT_DIR) / "q31_conditions_spectrum.png") + fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) + plt.close(fig) + _logger.info(" Saved: %s", outpath) diff --git a/python_pkg/praca_magisterska_video/generate_images/_q31_hurwicz_mnemonic.py b/python_pkg/praca_magisterska_video/generate_images/_q31_hurwicz_mnemonic.py new file mode 100644 index 0000000..eeb2751 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q31_hurwicz_mnemonic.py @@ -0,0 +1,344 @@ +"""Q31 Diagrams 3 & 4: Hurwicz interpolation + criteria mnemonic map.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from matplotlib.patches import FancyBboxPatch +import matplotlib.pyplot as plt +import numpy as np + +from python_pkg.praca_magisterska_video.generate_images._q31_common import ( + BG, + DPI, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + LN, + OUTPUT_DIR, + _logger, + draw_box, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def draw_hurwicz_interpolation() -> None: + """Draw Hurwicz alpha interpolation diagram.""" + fig, ax = plt.subplots(1, 1, figsize=(8.27, 4)) + ax.set_title( + "Kryterium Hurwicza" " \u2014 wpływ \u03b1 na wybór alternatywy", + fontsize=FS_TITLE + 1, + fontweight="bold", + pad=10, + ) + + alphas = np.linspace(0, 1, 200) + + v1 = alphas * 200 + (1 - alphas) * (-100) + v2 = alphas * 80 + (1 - alphas) * 40 + v3 = alphas * 30 + (1 - alphas) * 30 + + ax.plot( + alphas, + v1, + "k-", + lw=2, + label="A₁ (fabryka): V = 300\u03b1 - 100", + ) + ax.plot( + alphas, + v2, + "k--", + lw=2, + label="A₂ (sklep): V = 40\u03b1 + 40", + ) + ax.plot( + alphas, + v3, + "k:", + lw=2, + label="A₃ (obligacje): V = 30", + ) + + # Crossover A2=A1 + alpha_cross_12 = 140 / 260 + v_cross_12 = 40 * alpha_cross_12 + 40 + + ax.plot(alpha_cross_12, v_cross_12, "ko", markersize=8, zorder=5) + ax.annotate( + f"\u03b1 ≈ {alpha_cross_12:.2f}\nA₁ = A₂", + xy=(alpha_cross_12, v_cross_12), + xytext=(alpha_cross_12 + 0.12, v_cross_12 - 30), + fontsize=8, + fontweight="bold", + arrowprops={ + "arrowstyle": "->", + "color": LN, + "lw": 1, + }, + ) + + # Shade winning regions + ax.axvspan(0, alpha_cross_12, alpha=0.08, color="black") + ax.axvspan(alpha_cross_12, 1, alpha=0.15, color="black") + + ax.text( + alpha_cross_12 / 2, + -60, + "A₂ wygrywa\n(pesymistycznie)", + fontsize=8, + ha="center", + va="center", + bbox={ + "boxstyle": "round", + "facecolor": "white", + "edgecolor": LN, + }, + ) + ax.text( + (alpha_cross_12 + 1) / 2, + 160, + "A₁ wygrywa\n(optymistycznie)", + fontsize=8, + ha="center", + va="center", + bbox={ + "boxstyle": "round", + "facecolor": "white", + "edgecolor": LN, + }, + ) + + # Special alpha values + ax.axvline(x=0, color=LN, lw=0.5, ls=":") + ax.axvline(x=1, color=LN, lw=0.5, ls=":") + ax.text( + 0, + -115, + "\u03b1=0\nmaximin", + fontsize=7, + ha="center", + va="top", + fontweight="bold", + ) + ax.text( + 1, + -115, + "\u03b1=1\nmaximax", + fontsize=7, + ha="center", + va="top", + fontweight="bold", + ) + + ax.set_xlabel("Współczynnik optymizmu \u03b1", fontsize=9) + ax.set_ylabel("V(Aᵢ) = \u03b1·max + (1-\u03b1)·min", fontsize=9) + ax.legend(fontsize=8, loc="upper left") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.set_xlim(-0.05, 1.05) + ax.axhline(y=0, color=LN, lw=0.3, ls="-") + ax.tick_params(labelsize=8) + + plt.tight_layout() + outpath = str(Path(OUTPUT_DIR) / "q31_hurwicz_alpha.png") + fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) + plt.close(fig) + _logger.info(" Saved: %s", outpath) + + +def _draw_mnemonic_criteria_boxes(ax: Axes) -> None: + """Draw the 6 criteria boxes around the center.""" + criteria = [ + ( + 0, + 6.5, + 3, + 1.2, + "WARTOŚĆ OCZEKIWANA", + "\u201eMam prawdopodobieństwa\u201d", + "E[Aᵢ] = Σ pⱼ·aᵢⱼ", + ), + ( + 3.5, + 6.5, + 3, + 1.2, + "LAPLACE", + "\u201eWszystko po równo\u201d", + "V = Σaᵢⱼ / n", + ), + ( + 7, + 6.5, + 3, + 1.2, + "MAXIMAX", + "\u201eOptymista: max z max\u201d", + "max maxⱼ aᵢⱼ", + ), + ( + 0, + 0.5, + 3, + 1.2, + "MAXIMIN (Wald)", + "\u201ePesymista: max z min\u201d", + "max minⱼ aᵢⱼ", + ), + ( + 3.5, + 0.5, + 3, + 1.2, + "HURWICZ", + "\u201e\u03b1 pomiędzy\u201d", + "\u03b1·max + (1-\u03b1)·min", + ), + ( + 7, + 0.5, + 3, + 1.2, + "SAVAGE", + "\u201eMin max żalu\u201d", + "min maxⱼ rᵢⱼ", + ), + ] + + fills = [GRAY3, GRAY1, "white", "white", GRAY1, GRAY3] + + for i, (x, y, w, h, title, mnem, formula) in enumerate(criteria): + rect = FancyBboxPatch( + (x, y), + w, + h, + boxstyle="round,pad=0.08", + lw=1.5, + edgecolor=LN, + facecolor=fills[i], + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + h * 0.78, + title, + ha="center", + va="center", + fontsize=8, + fontweight="bold", + ) + ax.text( + x + w / 2, + y + h * 0.45, + mnem, + ha="center", + va="center", + fontsize=7, + style="italic", + ) + ax.text( + x + w / 2, + y + h * 0.15, + formula, + ha="center", + va="center", + fontsize=7, + fontweight="bold", + family="monospace", + ) + + # Arrows from center to each box + cx, cy = 5, 4 + bx = x + w / 2 + by_center = y + h / 2 + if by_center > cy: + ax.annotate( + "", + xy=(bx, y), + xytext=(cx, 4.5), + arrowprops={ + "arrowstyle": "->", + "color": LN, + "lw": 1, + "connectionstyle": "arc3,rad=0", + }, + ) + else: + ax.annotate( + "", + xy=(bx, y + h), + xytext=(cx, 3.5), + arrowprops={ + "arrowstyle": "->", + "color": LN, + "lw": 1, + "connectionstyle": "arc3,rad=0", + }, + ) + + # Labels on arrows + arrow_labels = [ + (1.2, 5.6, "znane p"), + (5, 5.6, "p = 1/n"), + (8.7, 5.6, "max ↑"), + (1.2, 2.5, "min ↑"), + (5, 2.5, "podaj \u03b1"), + (8.7, 2.5, "macierz\nżalu"), + ] + for lx, ly, ltext in arrow_labels: + ax.text( + lx, + ly, + ltext, + fontsize=7, + ha="center", + va="center", + bbox={ + "boxstyle": "round,pad=0.15", + "facecolor": "white", + "edgecolor": GRAY3, + "lw": 0.5, + }, + ) + + +def draw_criteria_mnemonic() -> None: + """Draw decision criteria mnemonic map diagram.""" + fig, ax = plt.subplots(1, 1, figsize=(8.27, 6)) + ax.set_xlim(0, 10) + ax.set_ylim(0, 8) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Mapa mnemoniczna \u2014 6 kryteriów decyzyjnych", + fontsize=FS_TITLE + 2, + fontweight="bold", + pad=10, + ) + + # Central node + draw_box( + ax, + 3.5, + 3.5, + 3, + 1, + "MACIERZ\nWYPŁAT", + fill=GRAY2, + lw=2, + fontsize=11, + fontweight="bold", + ) + + _draw_mnemonic_criteria_boxes(ax) + + plt.tight_layout() + outpath = str(Path(OUTPUT_DIR) / "q31_criteria_mnemonic.png") + fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) + plt.close(fig) + _logger.info(" Saved: %s", outpath) diff --git a/python_pkg/praca_magisterska_video/generate_images/_q31_regret_matrix.py b/python_pkg/praca_magisterska_video/generate_images/_q31_regret_matrix.py new file mode 100644 index 0000000..427b559 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q31_regret_matrix.py @@ -0,0 +1,322 @@ +"""Q31 Diagram 2: Regret matrix construction step-by-step.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt + +from python_pkg.praca_magisterska_video.generate_images._q31_common import ( + _DATA_STATE_COLS, + _REGRET_HEADER_COLS, + BG, + DPI, + FS, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + LN, + OUTPUT_DIR, + _logger, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +def _draw_original_payoff( + ax: Axes, + start_y: float, + row_h: float, +) -> None: + """Draw the original payoff matrix (left side of regret fig).""" + ax.text( + 2.2, + 6.3, + "Krok 1: Macierz wypłat", + fontsize=9, + fontweight="bold", + ha="center", + va="center", + ) + + col_w = 1.0 + headers = ["", "S₁", "S₂", "S₃"] + data = [ + ["A₁", "200", "50", "-100"], + ["A₂", "80", "70", "40"], + ["A₃", "30", "30", "30"], + ] + start_x = 0.3 + + for j, h in enumerate(headers): + w = 0.7 if j == 0 else col_w + x = start_x + (0 if j == 0 else 0.7 + (j - 1) * col_w) + rect = mpatches.Rectangle( + (x, start_y), + w, + row_h, + lw=1, + edgecolor=LN, + facecolor=GRAY2, + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + start_y + row_h / 2, + h, + ha="center", + va="center", + fontsize=FS, + fontweight="bold", + ) + + for i, row in enumerate(data): + y = start_y - (i + 1) * row_h + for j, val in enumerate(row): + w = 0.7 if j == 0 else col_w + x = start_x + (0 if j == 0 else 0.7 + (j - 1) * col_w) + fill = GRAY4 if j == 0 else "white" + rect = mpatches.Rectangle( + (x, y), + w, + row_h, + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + row_h / 2, + val, + ha="center", + va="center", + fontsize=FS, + ) + + # Max per column annotation + max_y = start_y - _DATA_STATE_COLS * row_h - 0.1 + col_maxes = ["max=200", "max=70", "max=40"] + for idx, label in enumerate(col_maxes): + ax.text( + start_x + 0.7 + (idx + 0.5) * col_w, + max_y, + label, + fontsize=7, + ha="center", + va="top", + fontweight="bold", + color="#333", + ) + + # Arrow + ax.annotate( + "", + xy=(5.0, 4.8), + xytext=(4.2, 4.8), + arrowprops={"arrowstyle": "->", "color": LN, "lw": 2}, + ) + ax.text( + 4.6, + 5.0, + "rᵢⱼ = max - aᵢⱼ", + fontsize=8, + ha="center", + va="bottom", + fontweight="bold", + ) + + +def _draw_regret_table( + ax: Axes, + start_y: float, + row_h: float, +) -> None: + """Draw the regret matrix (right side of regret fig).""" + ax.text( + 7.5, + 6.3, + "Krok 2: Macierz żalu", + fontsize=9, + fontweight="bold", + ha="center", + va="center", + ) + + regret_data = [ + ["A₁", "0", "20", "140"], + ["A₂", "120", "0", "0"], + ["A₃", "170", "40", "10"], + ] + headers2 = ["", "S₁", "S₂", "S₃", "max rᵢ"] + start_x2 = 5.3 + + for j, h in enumerate(headers2): + w = 0.7 if j == 0 else (0.9 if j < _REGRET_HEADER_COLS else 1.0) + x = start_x2 + if j == 0: + x = start_x2 + elif j <= _DATA_STATE_COLS: + x = start_x2 + 0.7 + (j - 1) * 0.9 + else: + x = start_x2 + 0.7 + _DATA_STATE_COLS * 0.9 + rect = mpatches.Rectangle( + (x, start_y), + w, + row_h, + lw=1, + edgecolor=LN, + facecolor=(GRAY2 if j < _REGRET_HEADER_COLS else GRAY3), + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + start_y + row_h / 2, + h, + ha="center", + va="center", + fontsize=FS, + fontweight="bold", + ) + + max_regrets = [140, 120, 170] + for i, row in enumerate(regret_data): + y = start_y - (i + 1) * row_h + for j, val in enumerate(row): + w = 0.7 if j == 0 else 0.9 + x = start_x2 + (0 if j == 0 else 0.7 + (j - 1) * 0.9) + fill = GRAY4 if j == 0 else "white" + if j > 0 and int(val) == max_regrets[i]: + fill = GRAY2 + rect = mpatches.Rectangle( + (x, y), + w, + row_h, + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + fw = "bold" if (j > 0 and int(val) == max_regrets[i]) else "normal" + ax.text( + x + w / 2, + y + row_h / 2, + val, + ha="center", + va="center", + fontsize=FS, + fontweight=fw, + ) + + # Max regret column + x = start_x2 + 0.7 + _DATA_STATE_COLS * 0.9 + w = 1.0 + is_winner = max_regrets[i] == min(max_regrets) + fill = "#C8C8C8" if is_winner else GRAY1 + rect = mpatches.Rectangle( + (x, y), + w, + row_h, + lw=1.5 if is_winner else 1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + marker = " ★" if is_winner else "" + ax.text( + x + w / 2, + y + row_h / 2, + f"{max_regrets[i]}{marker}", + ha="center", + va="center", + fontsize=FS, + fontweight="bold", + ) + + +def draw_regret_matrix() -> None: + """Draw the regret matrix construction diagram.""" + fig, ax = plt.subplots(1, 1, figsize=(8.27, 5)) + ax.axis("off") + ax.set_xlim(0, 10) + ax.set_ylim(0, 7) + ax.set_title( + "Kryterium Savage'a \u2014 budowa macierzy żalu", + fontsize=FS_TITLE + 1, + fontweight="bold", + pad=10, + ) + + start_y = 5.5 + row_h = 0.55 + + _draw_original_payoff(ax, start_y, row_h) + _draw_regret_table(ax, start_y, row_h) + + # Bottom conclusion + ax.text( + 5.0, + 2.8, + "Krok 3: Wybierz min z max żalu" " → A₂ (max żal = 120)", + fontsize=10, + ha="center", + va="center", + fontweight="bold", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": GRAY1, + "edgecolor": LN, + "lw": 1.5, + }, + ) + + # Interpretation examples + ax.text( + 5.0, + 2.0, + "Interpretacja żalu: r₁₃ = 140 oznacza:\n" + "\u201eGdyby nastąpił S₃ (zła koniunktura)," + " a wybrałbym A₁,\n" + "żałowałbym, bo najlepszą opcją byłoby" + " A₂ z wynikiem 40 \u2014 traciłbym 140\u201d", + fontsize=7.5, + ha="center", + va="center", + style="italic", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": GRAY4, + "edgecolor": GRAY3, + "lw": 0.8, + }, + ) + + # Mnemonic + ax.text( + 5.0, + 0.8, + "Mnemonik: Savage = \u201eŻal jak nóż\u201d\n" + "Maksymalny żal to nóż " + "\u2014 wybierz opcję z NAJMNIEJSZYM nożem", + fontsize=8, + ha="center", + va="center", + fontweight="bold", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": "white", + "edgecolor": LN, + "lw": 1, + }, + ) + + plt.tight_layout() + outpath = str(Path(OUTPUT_DIR) / "q31_regret_matrix.png") + fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) + plt.close(fig) + _logger.info(" Saved: %s", outpath) diff --git a/python_pkg/praca_magisterska_video/generate_images/_q9_basics.py b/python_pkg/praca_magisterska_video/generate_images/_q9_basics.py new file mode 100644 index 0000000..2a45053 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q9_basics.py @@ -0,0 +1,448 @@ +"""Q9 diagrams 1-6: process/thread basics, memory, states, PCB, speed.""" + +from __future__ import annotations + +from _q9_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + draw_arrow, + draw_box, + draw_table, + save_fig, +) +import matplotlib.pyplot as plt + + +# ============================================================ +# 1. Process vs Thread comparison table +# ============================================================ +def gen_process_vs_thread() -> None: + """Gen process vs thread.""" + fig, ax = plt.subplots(figsize=(7.5, 4.5)) + ax.set_xlim(0, 10) + ax.set_ylim(-4.5, 1.5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Proces vs Wątek — porównanie", fontsize=FS_TITLE, fontweight="bold", pad=10 + ) + + headers = ["Cecha", "Proces", "Wątek"] + col_w = [2.5, 3.5, 3.5] + rows = [ + ["Pamięć", "Własna, izolowana", "Współdzielona (heap)"], + ["Tworzenie", "~1-10 ms", "~10-100 μs (100x szybciej)"], + ["Przełączanie", "~1-5 μs (TLB flush)", "~0.1-0.5 μs (10x)"], + ["Komunikacja", "IPC (pipe, socket, shm)", "Bezpośrednia (wspólna pam.)"], + ["Izolacja", "Pełna — awaria izolowana", "Brak — może zabić proces"], + ["Zastosowanie", "Bezpieczeństwo, izolacja", "Wydajność, współdzielenie"], + ] + draw_table( + ax, + headers, + rows, + x0=0.25, + y0=0.8, + col_widths=col_w, + row_h=0.55, + fontsize=7.5, + header_fontsize=FS_LABEL, + ) + + # Analogy at bottom + ax.text( + 5.0, + -4.2, + "Analogia: Proces = mieszkanie (własny adres) " + "Wątek = pokój w mieszkaniu (wspólna kuchnia = heap)", + ha="center", + fontsize=FS, + style="italic", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + save_fig(fig, "q9_process_vs_thread.png") + + +# ============================================================ +# 2. Memory segments layout +# ============================================================ +def gen_memory_layout() -> None: + """Gen memory layout.""" + fig, ax = plt.subplots(figsize=(6, 5)) + ax.set_xlim(0, 10) + ax.set_ylim(0, 8) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Segmenty pamięci procesu", fontsize=FS_TITLE, fontweight="bold", pad=10 + ) + + segments = [ + ("STACK ↓", "zmienne lokalne, adresy\npowrotu (każdy wątek WŁASNY)", GRAY1), + ("...", "(wolna przestrzeń)", "white"), + ("HEAP ↑", "malloc/new — dynamiczna\nalokacja (współdzielony)", GRAY4), + ("BSS", "zmienne globalne\nniezainicjalizowane (zerowane)", GRAY2), + ("DATA", "zmienne globalne\nzainicjalizowane", GRAY3), + ("TEXT", "kod maszynowy\n(read-only, współdzielony)", GRAY5), + ] + bx, bw = 2.0, 2.5 + seg_h = 0.9 + gap = 0.05 + top_y = 7.0 + + for i, (name, desc, color) in enumerate(segments): + y = top_y - i * (seg_h + gap) + draw_box( + ax, + bx, + y, + bw, + seg_h, + name, + fill=color, + fontsize=FS_LABEL, + fontweight="bold", + rounded=False, + ) + ax.text(bx + bw + 0.3, y + seg_h / 2, desc, fontsize=7.5, va="center") + + # Address labels + ax.text( + bx - 0.2, + top_y + seg_h / 2, + "wysoki\nadres", + fontsize=FS_SMALL, + va="center", + ha="right", + style="italic", + ) + bottom_y = top_y - 5 * (seg_h + gap) + ax.text( + bx - 0.2, + bottom_y + seg_h / 2, + "niski\nadres", + fontsize=FS_SMALL, + va="center", + ha="right", + style="italic", + ) + + # Arrows for growth + ax.annotate( + "", + xy=(bx - 0.5, top_y - 0.1), + xytext=(bx - 0.5, top_y + seg_h + 0.1), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + ax.text(bx - 0.9, top_y + 0.4, "rośnie\nw dół", fontsize=FS_SMALL, ha="center") + + heap_y = top_y - 2 * (seg_h + gap) + ax.annotate( + "", + xy=(bx - 0.5, heap_y + seg_h + 0.1), + xytext=(bx - 0.5, heap_y - 0.1), + arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, + ) + ax.text(bx - 0.9, heap_y + 0.5, "rośnie\nw górę", fontsize=FS_SMALL, ha="center") + + save_fig(fig, "q9_memory_layout.png") + + +# ============================================================ +# 3. Process states diagram +# ============================================================ +def gen_process_states() -> None: + """Gen process states.""" + fig, ax = plt.subplots(figsize=(7, 3.5)) + ax.set_xlim(0, 12) + ax.set_ylim(0, 5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Stany procesu — diagram przejść", fontsize=FS_TITLE, fontweight="bold", pad=10 + ) + + states = { + "NEW": (1.0, 2.5), + "READY": (3.5, 2.5), + "RUNNING": (6.5, 2.5), + "BLOCKED": (6.5, 0.5), + "TERMINATED": (10.0, 2.5), + } + fills = { + "NEW": GRAY4, + "READY": GRAY1, + "RUNNING": GRAY3, + "BLOCKED": GRAY2, + "TERMINATED": GRAY5, + } + bw, bh = 1.8, 0.9 + for name, (x, y) in states.items(): + draw_box( + ax, + x, + y, + bw, + bh, + name, + fill=fills[name], + fontsize=FS_LABEL, + fontweight="bold", + ) + + # Transitions + transitions = [ + ("NEW", "READY", "admit"), + ("READY", "RUNNING", "dispatch\n(scheduler)"), + ("RUNNING", "TERMINATED", "exit"), + ("RUNNING", "BLOCKED", "I/O wait"), + ] + for src, dst, label in transitions: + sx, sy = states[src] + dx, dy = states[dst] + if sy == dy: # horizontal + draw_arrow(ax, sx + bw, sy + bh / 2, dx, dy + bh / 2, lw=1.5) + mx = (sx + bw + dx) / 2 + ax.text( + mx, + sy + bh / 2 + 0.25, + label, + fontsize=FS_SMALL, + ha="center", + va="bottom", + ) + else: # vertical + draw_arrow(ax, sx + bw / 2, sy, dx + bw / 2, dy + bh, lw=1.5) + ax.text( + sx + bw + 0.2, + (sy + dy + bh) / 2, + label, + fontsize=FS_SMALL, + ha="left", + va="center", + ) + + # BLOCKED → READY + bx, by = states["BLOCKED"] + rx, ry = states["READY"] + ax.annotate( + "", + xy=(rx + bw / 2, ry), + xytext=(bx - 0.3, by + bh / 2), + arrowprops={ + "arrowstyle": "->", + "lw": 1.5, + "color": LN, + "connectionstyle": "arc3,rad=0.3", + }, + ) + ax.text(3.5, 0.7, "I/O done", fontsize=FS_SMALL, ha="center") + + # RUNNING → READY (preemption) + rux, ruy = states["RUNNING"] + draw_arrow(ax, rux, ruy + bh, rx + bw, ry + bh, lw=1.2) + ax.text(5.0, 3.7, "preempt /\ntimeout", fontsize=FS_SMALL, ha="center") + + save_fig(fig, "q9_process_states.png") + + +# ============================================================ +# 4. Thread structure within process +# ============================================================ +def gen_thread_structure() -> None: + """Gen thread structure.""" + fig, ax = plt.subplots(figsize=(8, 4.5)) + ax.set_xlim(0, 10) + ax.set_ylim(0, 6) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Wątki wewnątrz procesu (PID=42)", fontsize=FS_TITLE, fontweight="bold", pad=10 + ) + + # Shared memory region + draw_box(ax, 0.5, 3.5, 9.0, 1.8, "", fill=GRAY1, rounded=False, lw=2) + ax.text(5.0, 5.0, "WSPÓŁDZIELONE", fontsize=FS, fontweight="bold", ha="center") + + labels_shared = ["TEXT", "DATA", "BSS", "HEAP", "pliki", "PID"] + for i, lab in enumerate(labels_shared): + x = 1.0 + i * 1.4 + draw_box( + ax, + x, + 3.8, + 1.1, + 0.6, + lab, + fill=GRAY3, + fontsize=FS, + fontweight="bold", + rounded=False, + ) + ax.text(x + 0.55, 4.6, lab, fontsize=FS_SMALL, ha="center", color="#555555") + + # Per-thread regions + draw_box( + ax, 0.5, 0.5, 9.0, 2.7, "", fill="white", rounded=False, lw=2, linestyle="--" + ) + ax.text( + 5.0, 2.95, "PRYWATNE (każdy wątek)", fontsize=FS, fontweight="bold", ha="center" + ) + + for i in range(3): + x = 1.0 + i * 3.0 + tid = i + 1 + draw_box(ax, x, 0.7, 2.3, 2.0, "", fill=GRAY4, rounded=False) + ax.text( + x + 1.15, + 2.4, + f"Wątek {tid}", + fontsize=FS_LABEL, + fontweight="bold", + ha="center", + ) + items = [f"stos_{tid}", f"rejestry_{tid}", f"PC_{tid}", f"TID={40 + tid}"] + for j, item in enumerate(items): + ax.text( + x + 1.15, + 2.0 - j * 0.35, + item, + fontsize=FS_SMALL, + ha="center", + family="monospace", + ) + + save_fig(fig, "q9_thread_structure.png") + + +# ============================================================ +# 5. PCB structure +# ============================================================ +def gen_pcb_structure() -> None: + """Gen pcb structure.""" + fig, ax = plt.subplots(figsize=(5, 3.5)) + ax.set_xlim(0, 8) + ax.set_ylim(0, 5.5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "PCB (Process Control Block)", fontsize=FS_TITLE, fontweight="bold", pad=10 + ) + + fields = [ + ("PID", "42"), + ("Stan", "READY / RUNNING / BLOCKED"), + ("Rejestry CPU", "EAX, EBX, ESP, EIP ..."), + ("Tablice stron", "mapowanie wirtualne → fizyczne"), + ("Otwarte pliki", "fd[0], fd[1], fd[2] ..."), + ("Priorytety", "nice value, scheduling class"), + ("Statystyki", "CPU time, I/O count"), + ] + + top_y = 4.8 + for i, (field, value) in enumerate(fields): + y = top_y - i * 0.55 + draw_box( + ax, + 0.5, + y, + 2.2, + 0.45, + field, + fill=GRAY2, + fontsize=FS, + fontweight="bold", + rounded=False, + ) + draw_box(ax, 2.7, y, 4.5, 0.45, value, fill=GRAY4, fontsize=FS, rounded=False) + + ax.text( + 4.0, + 0.3, + "Context switch = zapisz PCB starego → wczytaj PCB nowego", + fontsize=FS_SMALL, + ha="center", + style="italic", + ) + + save_fig(fig, "q9_pcb_structure.png") + + +# ============================================================ +# 6. Speed comparison +# ============================================================ +def gen_speed_comparison() -> None: + """Gen speed comparison.""" + fig, axes = plt.subplots(1, 2, figsize=(9, 3.5)) + fig.suptitle( + "Szybkość — procesy vs wątki (benchmarki Linux)", + fontsize=FS_TITLE, + fontweight="bold", + ) + + # Creation time panel + ax = axes[0] + ops = ["fork()\n(nowy proces)", "pthread_create()\n(nowy wątek)"] + times = [3.0, 0.05] # ms + colors = [GRAY3, GRAY1] + bars = ax.barh(ops, times, color=colors, edgecolor=LN, height=0.5, linewidth=1.2) + ax.set_xlabel("Czas [ms]", fontsize=FS) + ax.set_title("Tworzenie", fontsize=FS_LABEL, fontweight="bold") + ax.set_xlim(0, 4.5) + for bar, t in zip(bars, times, strict=False): + ax.text( + bar.get_width() + 0.1, + bar.get_y() + bar.get_height() / 2, + f"{t} ms", + va="center", + fontsize=FS, + ) + ax.text( + 2.5, + -0.6, + "~100x szybciej", + fontsize=FS, + ha="center", + fontweight="bold", + transform=ax.get_xaxis_transform(), + ) + ax.tick_params(labelsize=FS) + + # Right: context switch + ax = axes[1] + ops2 = ["Proces→Proces\n(TLB flush)", "Wątek→Wątek\n(TLB warm)"] + times2 = [3000, 300] # ns + bars2 = ax.barh(ops2, times2, color=colors, edgecolor=LN, height=0.5, linewidth=1.2) + ax.set_xlabel("Czas [ns]", fontsize=FS) + ax.set_title("Przełączanie kontekstu", fontsize=FS_LABEL, fontweight="bold") + ax.set_xlim(0, 4500) + for bar, t in zip(bars2, times2, strict=False): + ax.text( + bar.get_width() + 50, + bar.get_y() + bar.get_height() / 2, + f"{t} ns", + va="center", + fontsize=FS, + ) + ax.text( + 2500, + -0.6, + "~10x szybciej", + fontsize=FS, + ha="center", + fontweight="bold", + transform=ax.get_xaxis_transform(), + ) + ax.tick_params(labelsize=FS) + + fig.tight_layout(rect=[0, 0.05, 1, 0.92]) + save_fig(fig, "q9_speed_comparison.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q9_classic_sync.py b/python_pkg/praca_magisterska_video/generate_images/_q9_classic_sync.py new file mode 100644 index 0000000..4932505 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q9_classic_sync.py @@ -0,0 +1,420 @@ +"""Q9 diagrams 14-16: classic sync problems, mechanism comparison, semaphore.""" + +from __future__ import annotations + +from _q9_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + OCCUPIED_SLOTS, + draw_arrow, + draw_box, + draw_table, + save_fig, +) +import matplotlib.pyplot as plt +import numpy as np + + +# ============================================================ +# 14. Bounded buffer + readers-writers + philosophers +# ============================================================ +def _draw_bounded_buffer_panel(ax: plt.Axes) -> None: + """Draw the bounded-buffer (producer-consumer) panel.""" + ax.set_xlim(0, 8) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Producent-Konsument\n(Bounded Buffer, N=4)", fontsize=FS, fontweight="bold" + ) + + draw_box( + ax, + 0.2, + 4.0, + 2.0, + 1.2, + "Producent\nP(empty)\nP(mutex)\nwstaw()\nV(mutex)\nV(full)", + fill=GRAY1, + fontsize=5.5, + ) + items = ["A", "B", "", ""] + for i, item in enumerate(items): + x = 2.8 + i * 0.9 + fill = GRAY3 if item else "white" + draw_box( + ax, + x, + 4.3, + 0.9, + 0.7, + item, + fill=fill, + fontsize=FS, + fontweight="bold", + rounded=False, + ) + ax.text(4.6, 5.2, "Bufor (N=4)", fontsize=FS_SMALL, ha="center", fontweight="bold") + draw_box( + ax, + 6.0, + 4.0, + 2.0, + 1.2, + "Konsument\nP(full)\nP(mutex)\npobierz()\nV(mutex)\nV(empty)", + fill=GRAY4, + fontsize=5.5, + ) + draw_arrow(ax, 2.2, 4.6, 2.8, 4.65, lw=1.2) + draw_arrow(ax, 6.4, 4.65, 6.0, 4.6, lw=1.2) + + sems = [("mutex = 1", GRAY2), ("empty = N", GRAY1), ("full = 0", GRAY3)] + for i, (s, c) in enumerate(sems): + draw_box( + ax, + 2.0, + 2.5 - i * 0.6, + 4.0, + 0.45, + s, + fill=c, + fontsize=FS_SMALL, + fontweight="bold", + ) + + ax.text( + 4.0, + 0.5, + "KOLEJNOŚĆ: P(empty/full)\nPRZED P(mutex)!\nOdwrotnie = DEADLOCK", + fontsize=5.5, + ha="center", + fontweight="bold", + color="#C62828", + bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"}, + ) + + +def _draw_readers_writers_panel(ax: plt.Axes) -> None: + """Draw the readers-writers panel.""" + ax.set_xlim(0, 8) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Czytelnicy-Pisarze\n(Readers-Writers)", fontsize=FS, fontweight="bold" + ) + + draw_box( + ax, + 2.5, + 3.5, + 3.0, + 1.5, + "Dane\n(współdzielone)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + + for i in range(3): + x = 0.3 + i * 1.0 + draw_box( + ax, + x, + 5.5, + 0.8, + 0.7, + f"R{i + 1}", + fill=GRAY4, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, x + 0.4, 5.5, 3.0 + i * 0.5, 5.0, lw=1) + + ax.text( + 1.5, + 6.5, + "Czytelnicy (wielu naraz)", + fontsize=FS_SMALL, + ha="center", + fontweight="bold", + ) + + draw_box( + ax, 5.5, 5.5, 1.5, 0.7, "Pisarz", fill=GRAY5, fontsize=FS, fontweight="bold" + ) + draw_arrow(ax, 6.25, 5.5, 5.0, 5.0, lw=1.5) + ax.text( + 6.25, + 6.5, + "WYŁĄCZNY", + fontsize=FS_SMALL, + ha="center", + fontweight="bold", + color="#C62828", + ) + + rules = [ + "Wielu czytelników = OK", + "Jeden pisarz = wyłączny", + "Czytelnik + Pisarz = ✗", + "Problem: pisarze głodują", + ] + for i, r in enumerate(rules): + ax.text(4.0, 2.5 - i * 0.45, r, fontsize=FS_SMALL, ha="center") + + ax.text( + 4.0, + 0.5, + "Rozwiązanie:\nrw_mutex + count_mutex\n+ zmienna readers", + fontsize=5.5, + ha="center", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + +def _draw_philosophers_panel(ax: plt.Axes) -> None: + """Draw the dining-philosophers panel.""" + ax.set_xlim(0, 8) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Ucztujący filozofowie\n(Dining Philosophers)", fontsize=FS, fontweight="bold" + ) + + cx, cy, r = 4.0, 3.8, 1.8 + table = plt.Circle((cx, cy), 0.8, fill=True, facecolor=GRAY2, edgecolor=LN, lw=1.5) + ax.add_patch(table) + ax.text(cx, cy, "Stół", fontsize=FS, ha="center", fontweight="bold") + + for i in range(5): + angle = np.pi / 2 + i * 2 * np.pi / 5 + px = cx + r * np.cos(angle) + py = cy + r * np.sin(angle) + circle = plt.Circle( + (px, py), 0.35, fill=True, facecolor=GRAY1, edgecolor=LN, lw=1.2 + ) + ax.add_patch(circle) + ax.text( + px, py, f"F{i}", ha="center", va="center", fontsize=FS, fontweight="bold" + ) + + fork_angle = np.pi / 2 + (i + 0.5) * 2 * np.pi / 5 + fx = cx + (r * 0.6) * np.cos(fork_angle) + fy = cy + (r * 0.6) * np.sin(fork_angle) + ax.plot( + [fx - 0.1, fx + 0.1], + [fy - 0.15, fy + 0.15], + color=LN, + lw=2.5, + solid_capstyle="round", + ) + ax.text(fx + 0.2, fy + 0.15, f"w{i}", fontsize=5, color="#555555") + + rules = [ + "Jedzenie = 2 widelce", + "Naiwne → DEADLOCK", + "Fix: F4 bierze odwrotnie", + "Alt: semafor(4)", + ] + for i, r in enumerate(rules): + ax.text(4.0, 1.2 - i * 0.35, r, fontsize=FS_SMALL, ha="center") + + +def gen_classic_problems() -> None: + """Gen classic problems.""" + fig, axes = plt.subplots(1, 3, figsize=(12, 5)) + fig.suptitle( + "Klasyczne problemy synchronizacji", fontsize=FS_TITLE, fontweight="bold" + ) + + _draw_bounded_buffer_panel(axes[0]) + _draw_readers_writers_panel(axes[1]) + _draw_philosophers_panel(axes[2]) + + fig.tight_layout(rect=[0, 0, 1, 0.88]) + save_fig(fig, "q9_classic_problems.png") + + +# ============================================================ +# 15. Sync mechanisms comparison + mutex/sem/spinlock +# ============================================================ +def gen_sync_comparison() -> None: + """Gen sync comparison.""" + fig, axes = plt.subplots(2, 1, figsize=(9, 7)) + + # Top: comparison table + ax = axes[0] + ax.set_xlim(0, 11.5) + ax.set_ylim(-5, 1) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Mechanizmy synchronizacji — porównanie", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + headers = ["Mechanizm", "Opis", "Kiedy używać"] + col_w = [2.5, 4.5, 4.0] + rows = [ + ["Mutex", "Zamek: 1 wątek w sekcji", "Sekcja krytyczna"], + ["Semafor(n)", "Licznik: max n wątków", "Ograniczone zasoby (n miejsc)"], + ["Monitor", "Obiekt z wbudowanym mutex", "Java synchronized"], + ["Cond. Variable", "wait()/signal() na warunek", "Producent-konsument"], + ["Spinlock", "Aktywne czekanie (busy-wait)", "Bardzo krótkie sekcje (<1 μs)"], + ["RW Lock", "Wielu czytelników LUB 1 pisarz", "Bazy danych, cache"], + ["Barrier", "Czekaj aż wszyscy dotrą", "Obliczenia równoległe"], + ] + draw_table( + ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7 + ) + + # Bottom: mutex vs semafor vs spinlock + ax = axes[1] + ax.set_xlim(0, 12) + ax.set_ylim(0, 5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Mutex vs Semafor vs Spinlock", fontsize=FS_TITLE, fontweight="bold", pad=5 + ) + + # Mutex + draw_box(ax, 0.3, 2.5, 3.5, 2.0, "", fill=GRAY4) + ax.text(2.05, 4.2, "MUTEX", fontsize=FS_LABEL, fontweight="bold", ha="center") + ax.text(2.05, 3.6, "= klucz do łazienki\n(1 osoba)", fontsize=FS, ha="center") + ax.text( + 2.05, + 2.8, + "Wątek ZASYPIA gdy czeka\nOS go obudzi (~μs)", + fontsize=FS_SMALL, + ha="center", + style="italic", + ) + + # Semafor + draw_box(ax, 4.3, 2.5, 3.5, 2.0, "", fill=GRAY1) + ax.text(6.05, 4.2, "SEMAFOR(n)", fontsize=FS_LABEL, fontweight="bold", ha="center") + ax.text( + 6.05, 3.6, "= parking na n miejsc\n(n wątków naraz)", fontsize=FS, ha="center" + ) + ax.text( + 6.05, + 2.8, + "Semafor(1) = mutex\nP() = zmniejsz, V() = zwiększ", + fontsize=FS_SMALL, + ha="center", + style="italic", + ) + + # Spinlock + draw_box(ax, 8.3, 2.5, 3.5, 2.0, "", fill=GRAY2) + ax.text(10.05, 4.2, "SPINLOCK", fontsize=FS_LABEL, fontweight="bold", ha="center") + ax.text( + 10.05, 3.6, "= obrotowe drzwi\n(kręcisz się w kółko)", fontsize=FS, ha="center" + ) + ax.text( + 10.05, + 2.8, + "Wątek KRĘCI się w pętli\nLepszy gdy sekcja < 1 μs", + fontsize=FS_SMALL, + ha="center", + style="italic", + ) + + # Dividing rule + ax.text( + 6.0, + 1.5, + "Reguła kciuka: sekcja > 1 μs → MUTEX | " + "sekcja < 1 μs → SPINLOCK | n jednocześnie → SEMAFOR(n)", + fontsize=FS, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + fig.tight_layout() + save_fig(fig, "q9_sync_comparison.png") + + +# ============================================================ +# 16. Semaphore concept diagram +# ============================================================ +def gen_semaphore_concept() -> None: + """Gen semaphore concept.""" + fig, ax = plt.subplots(figsize=(6, 3)) + ax.set_xlim(0, 10) + ax.set_ylim(0, 5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Semafor — koncepcja (parking na 3 miejsca)", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Parking slots + for i in range(3): + x = 2.0 + i * 2.0 + occupied = i < OCCUPIED_SLOTS + fill = GRAY3 if occupied else "white" + label = f"Wątek {i + 1}" if occupied else "(wolne)" + draw_box( + ax, + x, + 2.5, + 1.5, + 1.2, + label, + fill=fill, + fontsize=FS, + fontweight="bold" if occupied else "normal", + rounded=False, + ) + + ax.text( + 5.0, + 4.2, + "semafor(3): counter = 1 (jedno wolne miejsce)", + fontsize=FS, + ha="center", + fontweight="bold", + ) + + # Waiting thread + draw_box( + ax, + 0.2, + 0.5, + 1.5, + 0.8, + "Wątek 4\nP() → czeka", + fill="#F8D7DA", + fontsize=FS_SMALL, + ) + draw_arrow(ax, 1.7, 0.9, 2.0, 2.5, lw=1.2, color="#C62828") + + ax.text( + 5.0, + 0.6, + "P() = counter-- (jeśli 0 → czekaj)\nV() = counter++ (obudź czekającego)", + fontsize=FS, + ha="center", + family="monospace", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + save_fig(fig, "q9_semaphore_concept.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q9_common.py b/python_pkg/praca_magisterska_video/generate_images/_q9_common.py new file mode 100644 index 0000000..6a714ce --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q9_common.py @@ -0,0 +1,200 @@ +"""Common utilities and constants for Q9 diagram generation. + +Monochrome, A4-printable PNGs (300 DPI). +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib as mpl + +mpl.use("Agg") + +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch +import matplotlib.pyplot as plt + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.figure import Figure + +_logger = logging.getLogger(__name__) + +DPI = 300 +BG = "white" +LN = "black" +FS = 8 +FS_TITLE = 11 +FS_SMALL = 6.5 +FS_LABEL = 9 +OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") +Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + +GRAY1 = "#E8E8E8" +GRAY2 = "#D0D0D0" +GRAY3 = "#B8B8B8" +GRAY4 = "#F5F5F5" +GRAY5 = "#C0C0C0" +OCCUPIED_SLOTS = 2 + + +def draw_box( + ax: Axes, + x: float, + y: float, + w: float, + h: float, + text: str, + fill: str = "white", + lw: float = 1.2, + fontsize: float = FS, + fontweight: str = "normal", + ha: str = "center", + va: str = "center", + *, + rounded: bool = True, + edgecolor: str = LN, + linestyle: str = "-", +) -> None: + """Draw box.""" + if rounded: + rect = FancyBboxPatch( + (x, y), + w, + h, + boxstyle="round,pad=0.05", + lw=lw, + edgecolor=edgecolor, + facecolor=fill, + linestyle=linestyle, + ) + else: + rect = mpatches.Rectangle( + (x, y), + w, + h, + lw=lw, + edgecolor=edgecolor, + facecolor=fill, + linestyle=linestyle, + ) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + h / 2, + text, + ha=ha, + va=va, + fontsize=fontsize, + fontweight=fontweight, + wrap=True, + ) + + +def draw_arrow( + ax: Axes, + x1: float, + y1: float, + x2: float, + y2: float, + lw: float = 1.2, + style: str = "->", + color: str = LN, +) -> None: + """Draw arrow.""" + ax.annotate( + "", + xy=(x2, y2), + xytext=(x1, y1), + arrowprops={"arrowstyle": style, "color": color, "lw": lw}, + ) + + +def draw_double_arrow( + ax: Axes, + x1: float, + y1: float, + x2: float, + y2: float, + lw: float = 1.2, + color: str = LN, +) -> None: + """Draw double arrow.""" + ax.annotate( + "", + xy=(x2, y2), + xytext=(x1, y1), + arrowprops={"arrowstyle": "<->", "color": color, "lw": lw}, + ) + + +def save_fig(fig: Figure, name: str) -> None: + """Save fig.""" + path = str(Path(OUTPUT_DIR) / name) + fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15) + plt.close(fig) + _logger.info(" Saved: %s", path) + + +def draw_table( + ax: Axes, + headers: list[str], + rows: list[list[str]], + x0: float, + y0: float, + col_widths: list[float], + row_h: float = 0.4, + header_fill: str = GRAY2, + row_fills: list[str] | None = None, + fontsize: float = FS, + header_fontsize: float | None = None, +) -> None: + """Draw a clean table on axes.""" + if header_fontsize is None: + header_fontsize = fontsize + len(headers) + len(rows) + sum(col_widths) + + # Header + cx = x0 + for j, hdr in enumerate(headers): + draw_box( + ax, + cx, + y0, + col_widths[j], + row_h, + hdr, + fill=header_fill, + fontsize=header_fontsize, + fontweight="bold", + rounded=False, + ) + cx += col_widths[j] + + # Rows + for i, row in enumerate(rows): + cy = y0 - (i + 1) * row_h + cx = x0 + fill = GRAY4 if (i % 2 == 0) else "white" + if row_fills and i < len(row_fills): + fill = row_fills[i] + for j, cell in enumerate(row): + fw = "bold" if j == 0 else "normal" + draw_box( + ax, + cx, + cy, + col_widths[j], + row_h, + cell, + fill=fill, + fontsize=fontsize, + fontweight=fw, + rounded=False, + ) + cx += col_widths[j] diff --git a/python_pkg/praca_magisterska_video/generate_images/_q9_ipc.py b/python_pkg/praca_magisterska_video/generate_images/_q9_ipc.py new file mode 100644 index 0000000..825372e --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q9_ipc.py @@ -0,0 +1,212 @@ +"""Q9 diagrams 7-9: IPC mechanisms and scenario tables.""" + +from __future__ import annotations + +from _q9_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + draw_arrow, + draw_box, + draw_double_arrow, + draw_table, + save_fig, +) +import matplotlib.pyplot as plt + + +# ============================================================ +# 7. Scenario table (when to use process vs thread) +# ============================================================ +def gen_scenario_table() -> None: + """Gen scenario table.""" + fig, ax = plt.subplots(figsize=(8.5, 4.5)) + ax.set_xlim(0, 11) + ax.set_ylim(-5.5, 1) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Kiedy proces, kiedy wątek? — typowe scenariusze", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + headers = ["Scenariusz", "Wybór", "Dlaczego?"] + col_w = [3.5, 2.5, 4.5] + rows = [ + ["Serwer WWW (Apache)", "Proces", "izolacja klientów"], + ["Serwer WWW (nginx)", "Wątek / async", "szybkość, cooperacja"], + ["Przeglądarka (karty)", "Proces", "crash isolation"], + ["Przeglądarka (JS+render)", "Wątek", "współdzielony DOM"], + ["Gra (fizyka+rendering)", "Wątek", "współdzielony świat gry"], + ["Kompilacja (make -j8)", "Proces", "izolacja, prostota"], + ["Baza danych (zapytania)", "Wątek", "współdzielony cache"], + ["Microservices", "Proces (kontener)", "izolacja, deployment"], + ] + draw_table( + ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7 + ) + + save_fig(fig, "q9_scenario_table.png") + + +# ============================================================ +# 8. IPC details: pipe, shared memory, socket (3-panel) +# ============================================================ +def gen_ipc_details() -> None: + """Gen ipc details.""" + fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) + fig.suptitle("Mechanizmy IPC — szczegóły", fontsize=FS_TITLE, fontweight="bold") + + # Panel 1: Pipe + ax = axes[0] + ax.set_xlim(0, 8) + ax.set_ylim(0, 5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Pipe (potok)", fontsize=FS_LABEL, fontweight="bold") + + draw_box(ax, 0.2, 2.0, 1.8, 1.2, "Proces A\n(ls)\nstdout", fill=GRAY1, fontsize=FS) + draw_box( + ax, + 3.0, + 2.0, + 1.8, + 1.2, + "Bufor\njądra\n(4 KB)", + fill=GRAY2, + fontsize=FS, + fontweight="bold", + ) + draw_box(ax, 5.8, 2.0, 1.8, 1.2, "Proces B\n(grep)\nstdin", fill=GRAY1, fontsize=FS) + draw_arrow(ax, 2.0, 2.6, 3.0, 2.6, lw=1.5) + ax.text(2.5, 3.0, "write()\nfd[1]", fontsize=FS_SMALL, ha="center") + draw_arrow(ax, 4.8, 2.6, 5.8, 2.6, lw=1.5) + ax.text(5.3, 3.0, "read()\nfd[0]", fontsize=FS_SMALL, ha="center") + ax.text( + 4.0, + 0.8, + "Jednokierunkowy\nBufor pełny → write() blokuje", + fontsize=FS_SMALL, + ha="center", + style="italic", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + # Panel 2: Shared Memory + ax = axes[1] + ax.set_xlim(0, 8) + ax.set_ylim(0, 5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Shared Memory", fontsize=FS_LABEL, fontweight="bold") + + draw_box(ax, 0.3, 3.0, 2.2, 1.2, "Proces A\nstrona 7", fill=GRAY1, fontsize=FS) + draw_box(ax, 5.5, 3.0, 2.2, 1.2, "Proces B\nstrona 3", fill=GRAY1, fontsize=FS) + draw_box( + ax, + 2.8, + 1.0, + 2.4, + 1.2, + "RAM\nramka 42", + fill=GRAY3, + fontsize=FS, + fontweight="bold", + ) + draw_arrow(ax, 2.0, 3.0, 3.5, 2.2, lw=1.5) + draw_arrow(ax, 6.0, 3.0, 4.5, 2.2, lw=1.5) + ax.text( + 4.0, + 0.3, + "Zero kopiowania!\nA pisze → B widzi od razu\nWymaga synchronizacji (semafor)", + fontsize=FS_SMALL, + ha="center", + style="italic", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + # Panel 3: Socket + ax = axes[2] + ax.set_xlim(0, 8) + ax.set_ylim(0, 5) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Socket", fontsize=FS_LABEL, fontweight="bold") + + # Network socket + draw_box( + ax, 0.3, 3.2, 1.8, 0.9, "Klient", fill=GRAY1, fontsize=FS, fontweight="bold" + ) + draw_box( + ax, 5.5, 3.2, 1.8, 0.9, "Serwer", fill=GRAY1, fontsize=FS, fontweight="bold" + ) + draw_double_arrow(ax, 2.1, 3.65, 5.5, 3.65, lw=1.5) + ax.text(3.8, 4.3, "TCP/IP (sieciowy)", fontsize=FS, ha="center", fontweight="bold") + + # Unix socket + draw_box( + ax, 0.3, 1.3, 1.8, 0.9, "Proces A", fill=GRAY4, fontsize=FS, fontweight="bold" + ) + draw_box( + ax, 5.5, 1.3, 1.8, 0.9, "Proces B", fill=GRAY4, fontsize=FS, fontweight="bold" + ) + draw_double_arrow(ax, 2.1, 1.75, 5.5, 1.75, lw=1.5) + ax.text( + 3.8, + 2.4, + "Unix domain socket\n(/tmp/app.sock)", + fontsize=FS, + ha="center", + fontweight="bold", + ) + + ax.text( + 3.8, + 0.5, + "Dwukierunkowy\nNajbardziej uniwersalny IPC", + fontsize=FS_SMALL, + ha="center", + style="italic", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + fig.tight_layout(rect=[0, 0, 1, 0.9]) + save_fig(fig, "q9_ipc_details.png") + + +# ============================================================ +# 9. IPC comparison table +# ============================================================ +def gen_ipc_table() -> None: + """Gen ipc table.""" + fig, ax = plt.subplots(figsize=(8.5, 3.5)) + ax.set_xlim(0, 11) + ax.set_ylim(-4.5, 1) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Porównanie mechanizmów IPC", fontsize=FS_TITLE, fontweight="bold", pad=10 + ) + + headers = ["Mechanizm", "Kierunek", "Szybkość", "Zastosowanie"] + col_w = [2.5, 2.0, 2.5, 3.5] + rows = [ + ["Pipe", "jednokierunkowy", "średnia", "ls | grep"], + ["Named Pipe", "jednokierunkowy", "średnia", "demon → klient"], + ["Shared Memory", "dwukierunkowy", "NAJSZYBSZA", "video, bazy danych"], + ["Message Queue", "dwukierunkowy", "średnia", "wieloproducentowe"], + ["Socket", "dwukierunkowy", "wolna (sieć)", "klient-serwer"], + ["Signal", "jednokierunkowy", "natychmiastowa", "powiadomienia (nr)"], + ] + draw_table( + ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7.5 + ) + + save_fig(fig, "q9_ipc_table.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_q9_race_deadlock.py b/python_pkg/praca_magisterska_video/generate_images/_q9_race_deadlock.py new file mode 100644 index 0000000..7166f07 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_q9_race_deadlock.py @@ -0,0 +1,404 @@ +"""Q9 diagrams 10-13: race conditions, deadlock, Coffman, starvation.""" + +from __future__ import annotations + +from _q9_common import ( + FS, + FS_LABEL, + FS_SMALL, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + draw_arrow, + draw_box, + draw_table, + save_fig, +) +import matplotlib.pyplot as plt + + +# ============================================================ +# 10. Race condition (simple x + bank timeline) +# ============================================================ +def gen_race_condition() -> None: + """Gen race condition.""" + fig, axes = plt.subplots(1, 2, figsize=(11, 5)) + fig.suptitle( + "Wyścig (Race Condition) — przykłady", fontsize=FS_TITLE, fontweight="bold" + ) + + # Panel 1: simple x increment + ax = axes[0] + ax.set_xlim(0, 8) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Prosty wyścig: x = x + 1", fontsize=FS_LABEL, fontweight="bold") + + # Timeline + steps_a = ["czytaj x (=0)", "dodaj 1", "zapisz x (=1)"] + steps_b = ["czytaj x (=0)", "dodaj 1", "zapisz x (=1)"] + ax.text(2.0, 6.3, "Wątek A", fontsize=FS_LABEL, ha="center", fontweight="bold") + ax.text(6.0, 6.3, "Wątek B", fontsize=FS_LABEL, ha="center", fontweight="bold") + ax.plot([2, 2], [0.8, 6.0], color=LN, lw=1) + ax.plot([6, 6], [0.8, 6.0], color=LN, lw=1) + + for i, (sa, sb) in enumerate(zip(steps_a, steps_b, strict=False)): + y = 5.3 - i * 1.2 + draw_box(ax, 0.5, y, 3.0, 0.6, sa, fill=GRAY4, fontsize=FS) + draw_box(ax, 4.5, y - 0.3, 3.0, 0.6, sb, fill=GRAY1, fontsize=FS) + + ax.text( + 4.0, + 0.4, + "Wynik: x = 1 (powinno 2!)", + fontsize=FS, + ha="center", + fontweight="bold", + color="#C62828", + bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"}, + ) + + # Panel 2: bank account + ax = axes[1] + ax.set_xlim(0, 10) + ax.set_ylim(0, 7) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Konto bankowe: saldo = 1000 zł", fontsize=FS_LABEL, fontweight="bold") + + ax.text(2.5, 6.3, "Wątek A (+500)", fontsize=FS, ha="center", fontweight="bold") + ax.text(7.5, 6.3, "Wątek B (-200)", fontsize=FS, ha="center", fontweight="bold") + ax.plot([2.5, 2.5], [0.8, 6.0], color=LN, lw=1) + ax.plot([7.5, 7.5], [0.8, 6.0], color=LN, lw=1) + + events = [ + ("t1", "czytaj → 1000", "", 5.3), + ("t2", "", "czytaj → 1000", 4.6), + ("t3", "1000+500=1500", "", 3.9), + ("t4", "", "1000-200=800", 3.2), + ("t5", "zapisz 1500", "", 2.5), + ("t6", "", "zapisz 800 ✗", 1.8), + ] + for t, a, b, y in events: + ax.text(0.3, y + 0.15, t, fontsize=FS_SMALL, fontweight="bold", va="center") + if a: + draw_box(ax, 1.0, y, 3.0, 0.45, a, fill=GRAY4, fontsize=FS_SMALL) + if b: + fill = "#F8D7DA" if "✗" in b else GRAY1 + draw_box(ax, 6.0, y, 3.0, 0.45, b, fill=fill, fontsize=FS_SMALL) + + ax.text( + 5.0, + 0.4, + "Wynik: 800 zł (powinno 1300!)", + fontsize=FS, + ha="center", + fontweight="bold", + color="#C62828", + bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"}, + ) + + fig.tight_layout(rect=[0, 0, 1, 0.9]) + save_fig(fig, "q9_race_condition.png") + + +# ============================================================ +# 11. Deadlock scenario + cycle +# ============================================================ +def gen_deadlock_scenario() -> None: + """Gen deadlock scenario.""" + fig, axes = plt.subplots(1, 2, figsize=(11, 4.5)) + fig.suptitle("Zakleszczenie (Deadlock)", fontsize=FS_TITLE, fontweight="bold") + + # Panel 1: timeline + ax = axes[0] + ax.set_xlim(0, 8) + ax.set_ylim(0, 6) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Scenariusz z 2 mutexami", fontsize=FS_LABEL, fontweight="bold") + + ax.text(2.5, 5.3, "Wątek A", fontsize=FS_LABEL, ha="center", fontweight="bold") + ax.text(6.0, 5.3, "Wątek B", fontsize=FS_LABEL, ha="center", fontweight="bold") + + steps = [ + ("lock(mutex1) OK", "", "trzyma", False, 4.5), + ("", "lock(mutex2) OK", "trzyma", False, 3.7), + ("lock(mutex2) ...WAIT", "", "CZEKA!", True, 2.9), + ("", "lock(mutex1) ...WAIT", "CZEKA!", True, 2.1), + ] + for a_text, b_text, _note, is_wait, y in steps: + if a_text: + fill = "#F8D7DA" if is_wait else GRAY4 + draw_box(ax, 0.5, y, 3.3, 0.55, a_text, fill=fill, fontsize=FS_SMALL) + if b_text: + fill = "#F8D7DA" if is_wait else GRAY4 + draw_box(ax, 4.3, y, 3.3, 0.55, b_text, fill=fill, fontsize=FS_SMALL) + + ax.text( + 4.0, + 1.2, + "DEADLOCK!\nŻaden nie odpuści", + fontsize=FS, + ha="center", + fontweight="bold", + color="#C62828", + bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"}, + ) + + # Panel 2: cycle diagram + ax = axes[1] + ax.set_xlim(0, 8) + ax.set_ylim(0, 6) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Cykl oczekiwania", fontsize=FS_LABEL, fontweight="bold") + + # Thread boxes + draw_box( + ax, + 0.5, + 3.5, + 2.2, + 1.2, + "Wątek A\ntrzyma Mutex 1", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + ) + draw_box( + ax, + 5.3, + 3.5, + 2.2, + 1.2, + "Wątek B\ntrzyma Mutex 2", + fill=GRAY1, + fontsize=FS, + fontweight="bold", + ) + + # Mutex boxes + draw_box( + ax, 0.5, 1.0, 2.2, 1.0, "Mutex 1", fill=GRAY3, fontsize=FS, fontweight="bold" + ) + draw_box( + ax, 5.3, 1.0, 2.2, 1.0, "Mutex 2", fill=GRAY3, fontsize=FS, fontweight="bold" + ) + + # holds arrows (down) + draw_arrow(ax, 1.6, 3.5, 1.6, 2.0, lw=2) + ax.text(0.9, 2.7, "trzyma", fontsize=FS_SMALL, rotation=90, va="center") + draw_arrow(ax, 6.4, 3.5, 6.4, 2.0, lw=2) + ax.text(7.0, 2.7, "trzyma", fontsize=FS_SMALL, rotation=90, va="center") + + # waits-for arrows (across, red) + draw_arrow(ax, 2.7, 4.3, 5.3, 4.3, lw=2.5, color="#C62828") + ax.text( + 4.0, + 4.7, + "czeka na Mutex 2", + fontsize=FS_SMALL, + ha="center", + fontweight="bold", + color="#C62828", + ) + draw_arrow(ax, 5.3, 3.7, 2.7, 3.7, lw=2.5, color="#C62828") + ax.text( + 4.0, + 3.1, + "czeka na Mutex 1", + fontsize=FS_SMALL, + ha="center", + fontweight="bold", + color="#C62828", + ) + + fig.tight_layout(rect=[0, 0, 1, 0.9]) + save_fig(fig, "q9_deadlock_scenario.png") + + +# ============================================================ +# 12. Coffman conditions + prevention strategies +# ============================================================ +def gen_coffman_strategies() -> None: + """Gen coffman strategies.""" + fig, ax = plt.subplots(figsize=(9, 4)) + ax.set_xlim(0, 11.5) + ax.set_ylim(-3.5, 1) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title( + "Warunki Coffmana — zapobieganie deadlockowi", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + headers = ["Warunek", "Opis", "Jak złamać", "Przykład"] + col_w = [2.5, 2.5, 3.0, 3.0] + rows = [ + [ + "1. Mutual Exclusion", + "zasób wyłączny", + "współdzielony zasób", + "Read-write lock", + ], + [ + "2. Hold and Wait", + "trzymaj + czekaj", + "bierz WSZYSTKIE naraz", + "lock(m1,m2) atomowo", + ], + [ + "3. No Preemption", + "nie zabierzesz siłą", + "timeout / trylock", + "pthread_mutex_trylock()", + ], + [ + "4. Circular Wait", + "cykliczne oczekiw.", + "porządek liniowy", + "zawsze m1 przed m2", + ], + ] + draw_table( + ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.6, fontsize=7 + ) + + ax.text( + 5.75, + -3.1, + "▸ Najczęstsza strategia: PORZĄDEK LINIOWY — " + "numeruj mutexy, zawsze blokuj rosnąco", + fontsize=FS, + ha="center", + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + save_fig(fig, "q9_coffman_strategies.png") + + +# ============================================================ +# 13. Starvation + Priority Inversion (2-panel) +# ============================================================ +def gen_starvation_priority() -> None: + """Gen starvation priority.""" + fig, axes = plt.subplots(1, 2, figsize=(11, 4.5)) + fig.suptitle( + "Zagłodzenie i Inwersja priorytetów", fontsize=FS_TITLE, fontweight="bold" + ) + + # Panel 1: Starvation + aging + ax = axes[0] + ax.set_xlim(0, 8) + ax.set_ylim(0, 6) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Zagłodzenie (Starvation)", fontsize=FS_LABEL, fontweight="bold") + + threads = [ + ("Wątek HIGH", "prio=10", GRAY5, 3.0), + ("Wątek HIGH", "prio=9", GRAY3, 2.2), + ("Wątek MED", "prio=5", GRAY2, 1.4), + ("Wątek LOW", "prio=1 → głoduje!", "#F8D7DA", 0.6), + ] + for name, prio, color, y in threads: + draw_box( + ax, 0.5, y, 2.0, 0.6, name, fill=color, fontsize=FS_SMALL, fontweight="bold" + ) + ax.text(2.8, y + 0.3, prio, fontsize=FS_SMALL, va="center") + + ax.text( + 1.5, + 4.2, + "CPU zawsze\ndostaje HIGH!", + fontsize=FS, + ha="center", + fontweight="bold", + ) + draw_arrow(ax, 1.5, 3.9, 1.5, 3.65, lw=1.5) + + # Aging solution + draw_box(ax, 4.5, 1.5, 3.2, 2.5, "", fill=GRAY4, rounded=True) + ax.text(6.1, 3.7, "Rozwiązanie: AGING", fontsize=FS, fontweight="bold", ha="center") + aging = [ + "t=0: prio=1", + "t=100ms: prio=2", + "t=200ms: prio=3", + "...", + "w końcu → CPU!", + ] + for i, line in enumerate(aging): + ax.text( + 6.1, 3.2 - i * 0.4, line, fontsize=FS_SMALL, ha="center", family="monospace" + ) + + # Panel 2: Priority Inversion + ax = axes[1] + ax.set_xlim(0, 10) + ax.set_ylim(0, 6) + ax.set_aspect("auto") + ax.axis("off") + ax.set_title("Inwersja priorytetów", fontsize=FS_LABEL, fontweight="bold") + + # Timeline + labels = ["H (wysoki)", "M (średni)", "L (niski)"] + ys = [4.2, 2.8, 1.4] + for label, y in zip(labels, ys, strict=False): + ax.text(0.3, y + 0.2, label, fontsize=FS, fontweight="bold", va="center") + + # L runs and locks mutex + draw_box(ax, 2.0, ys[2], 1.2, 0.5, "lock(m)", fill=GRAY1, fontsize=FS_SMALL) + + # M preempts L + draw_box(ax, 3.5, ys[1], 3.0, 0.5, "M pracuje...", fill=GRAY3, fontsize=FS_SMALL) + + # H waits for mutex + draw_box( + ax, + 3.5, + ys[0], + 3.0, + 0.5, + "CZEKA na mutex!", + fill="#F8D7DA", + fontsize=FS_SMALL, + fontweight="bold", + ) + + # M finishes, L continues, unlocks + draw_box(ax, 6.8, ys[2], 1.5, 0.5, "unlock(m)", fill=GRAY1, fontsize=FS_SMALL) + draw_box(ax, 8.5, ys[0], 1.2, 0.5, "H runs", fill=GRAY4, fontsize=FS_SMALL) + + # Explanation + ax.text( + 5.0, + 0.5, + "H czeka na M (mimo H > M)!\n" + "Rozwiązanie: Priority Inheritance\n" + "L dziedziczy priorytet H → M nie wypycha L", + fontsize=FS_SMALL, + ha="center", + style="italic", + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + ax.text( + 5.0, + 0.0, + "Mars Pathfinder (1997) — klasyczny bug!", + fontsize=FS_SMALL, + ha="center", + fontweight="bold", + ) + + fig.tight_layout(rect=[0, 0, 1, 0.9]) + save_fig(fig, "q9_starvation_priority.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_sched_common.py b/python_pkg/praca_magisterska_video/generate_images/_sched_common.py new file mode 100644 index 0000000..d6c0af9 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_sched_common.py @@ -0,0 +1,87 @@ +"""Common constants and utilities for scheduling diagrams.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch + +if TYPE_CHECKING: + from matplotlib.axes import Axes + +_logger = logging.getLogger(__name__) + +DPI = 300 +BG = "white" +LN = "black" +FS = 8 +FS_TITLE = 11 +OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") +Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + +GRAY1 = "#E8E8E8" +GRAY2 = "#D0D0D0" +GRAY3 = "#B8B8B8" +GRAY4 = "#F5F5F5" +GRAY5 = "#C0C0C0" + +MIN_COLUMN_INDEX = 3 +FONTWEIGHT_THRESHOLD = 3 + + +def draw_box( + ax: Axes, + x: float, + y: float, + w: float, + h: float, + text: str, + fill: str = "white", + lw: float = 1.2, + fontsize: float = FS, + fontweight: str = "normal", + ha: str = "center", + va: str = "center", + *, + rounded: bool = True, +) -> None: + """Draw box.""" + if rounded: + rect = FancyBboxPatch( + (x, y), w, h, boxstyle="round,pad=0.05", lw=lw, edgecolor=LN, facecolor=fill + ) + else: + rect = mpatches.Rectangle((x, y), w, h, lw=lw, edgecolor=LN, facecolor=fill) + ax.add_patch(rect) + ax.text( + x + w / 2, + y + h / 2, + text, + ha=ha, + va=va, + fontsize=fontsize, + fontweight=fontweight, + wrap=True, + ) + + +def draw_arrow( + ax: Axes, + x1: float, + y1: float, + x2: float, + y2: float, + lw: float = 1.2, + style: str = "->", + color: str = LN, +) -> None: + """Draw arrow.""" + ax.annotate( + "", + xy=(x2, y2), + xytext=(x1, y1), + arrowprops={"arrowstyle": style, "color": color, "lw": lw}, + ) diff --git a/python_pkg/praca_magisterska_video/generate_images/_sched_complexity_edd.py b/python_pkg/praca_magisterska_video/generate_images/_sched_complexity_edd.py new file mode 100644 index 0000000..ebd41ee --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_sched_complexity_edd.py @@ -0,0 +1,309 @@ +"""Scheduling complexity landscape and EDD example diagrams.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch +import matplotlib.pyplot as plt + +from python_pkg.praca_magisterska_video.generate_images._sched_common import ( + BG, + DPI, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + OUTPUT_DIR, + draw_arrow, +) + +_logger = logging.getLogger(__name__) + + +# ============================================================ +# SCHEDULING COMPLEXITY LANDSCAPE +# ============================================================ +def draw_complexity_map() -> None: + """Draw complexity map.""" + _fig, ax = plt.subplots(1, 1, figsize=(8.27, 5)) + ax.set_xlim(0, 10) + ax.set_ylim(0, 7) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Złożoność problemów szeregowania — od łatwych do NP-trudnych", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Gradient arrow at the top + ax.annotate( + "", + xy=(9.5, 6.2), + xytext=(0.5, 6.2), + arrowprops={"arrowstyle": "->", "color": LN, "lw": 2}, + ) + ax.text(5, 6.5, "Rosnąca złożoność", ha="center", fontsize=9, fontweight="bold") + + # Easy (polynomial) region + easy_rect = FancyBboxPatch( + (0.3, 2.8), + 4.0, + 3.0, + boxstyle="round,pad=0.15", + lw=1.5, + edgecolor="#666666", + facecolor=GRAY4, + linestyle="-", + ) + ax.add_patch(easy_rect) + ax.text( + 2.3, + 5.5, + "WIELOMIANOWE O(n log n)", + ha="center", + fontsize=9, + fontweight="bold", + color="#444444", + ) + + easy_problems = [ + ("1 || ΣCⱼ", "SPT", GRAY1, 4.8), + ("1 || Lmax", "EDD", GRAY2, 4.0), + ("F2 || Cmax", "Johnson", GRAY1, 3.2), + ] + for prob, method, fill, y in easy_problems: + rect = FancyBboxPatch( + (0.6, y), + 3.5, + 0.6, + boxstyle="round,pad=0.05", + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + 1.2, + y + 0.3, + prob, + ha="center", + va="center", + fontsize=8, + fontweight="bold", + fontfamily="monospace", + ) + ax.text(3.0, y + 0.3, f"→ {method}", ha="center", va="center", fontsize=8) + + # Hard (NP) region + hard_rect = FancyBboxPatch( + (5.3, 2.8), + 4.3, + 3.0, + boxstyle="round,pad=0.15", + lw=1.5, + edgecolor="#444444", + facecolor=GRAY3, + linestyle="-", + ) + ax.add_patch(hard_rect) + ax.text( + 7.45, + 5.5, + "NP-TRUDNE", + ha="center", + fontsize=9, + fontweight="bold", + color="#333333", + ) + + hard_problems = [ + ("Pm || Cmax\n(m≥2)", "LPT heuryst.", GRAY2, 4.5), + ("1 || ΣTⱼ", "branch&bound", GRAY4, 3.7), + ("Jm || Cmax\n(m≥3)", "metaheuryst.", GRAY5, 2.9), + ] + for prob, method, fill, y in hard_problems: + rect = FancyBboxPatch( + (5.6, y), + 3.7, + 0.7, + boxstyle="round,pad=0.05", + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + 6.5, + y + 0.35, + prob, + ha="center", + va="center", + fontsize=7, + fontweight="bold", + fontfamily="monospace", + ) + ax.text(8.2, y + 0.35, f"→ {method}", ha="center", va="center", fontsize=7) + + # Arrow connecting + draw_arrow(ax, 4.4, 4.0, 5.2, 4.0, lw=2, color="#888888") + ax.text(4.8, 4.25, "+1\nmaszyna", ha="center", fontsize=6, color="#888888") + + # Bottom: key insight + ax.text( + 5.0, + 1.8, + "„Dodanie jednej maszyny lub jednego ograniczenia\n" + 'może zmienić problem z łatwego na NP-trudny!"', + ha="center", + fontsize=8, + fontweight="bold", + style="italic", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": GRAY4, + "edgecolor": GRAY3, + "lw": 1, + }, + ) + + # Bottom examples + ax.text( + 5.0, + 0.8, + "1 maszyna → łatwe (sortuj) | ≥2 maszyny równoległe → NP-trudne\n" + "Flow shop 2 maszyny → Johnson O(n log n) | Flow shop 3 maszyny → NP-trudne", + ha="center", + fontsize=7, + color="#555555", + ) + + plt.tight_layout() + plt.savefig( + str(Path(OUTPUT_DIR) / "scheduling_complexity_map.png"), + dpi=DPI, + bbox_inches="tight", + facecolor=BG, + ) + plt.close() + _logger.info(" ✓ scheduling_complexity_map.png") + + +# ============================================================ +# EDD EXAMPLE (1 || Lmax) +# ============================================================ +def draw_edd_example() -> None: + """Draw edd example.""" + _fig, ax = plt.subplots(1, 1, figsize=(8.27, 4)) + ax.set_xlim(-2, 28) + ax.set_ylim(-2, 4) + ax.axis("off") + ax.set_title( + "EDD (Earliest Due Date) — 1 || Lmax — Przykład", + fontsize=FS_TITLE, + fontweight="bold", + pad=8, + ) + + # Tasks: name, processing time, due date + tasks = [("J1", 4, 10), ("J2", 2, 6), ("J3", 6, 15), ("J4", 3, 8), ("J5", 5, 18)] + # EDD: sort by due date + edd_order = sorted(tasks, key=lambda x: x[2]) + + bar_y = 1.5 + bar_h = 0.8 + t = 0 + fills_edd = [GRAY1, GRAY2, GRAY4, GRAY3, GRAY5] + + lateness_vals = [] + for i, (name, p, d) in enumerate(edd_order): + rect = mpatches.Rectangle( + (t, bar_y), p, bar_h, lw=1.2, edgecolor=LN, facecolor=fills_edd[i] + ) + ax.add_patch(rect) + ax.text( + t + p / 2, + bar_y + bar_h / 2, + f"{name}\np={p}, d={d}", + ha="center", + va="center", + fontsize=6.5, + fontweight="bold", + ) + t += p + lateness = t - d + lateness_vals.append(lateness) + + # Due date marker + ax.plot( + [d, d], [bar_y - 0.4, bar_y - 0.1], color="#888888", lw=0.8, linestyle="--" + ) + ax.text( + d, + bar_y - 0.5, + f"d={d}", + ha="center", + va="top", + fontsize=5.5, + color="#888888", + ) + + # Completion + lateness + ax.plot([t, t], [bar_y + bar_h, bar_y + bar_h + 0.15], color=LN, lw=0.8) + ax.text( + t, + bar_y + bar_h + 0.2, + f"C={t}\nL={lateness}", + ha="center", + va="bottom", + fontsize=5.5, + ) + + # Time axis + ax.plot([0, 22], [bar_y - 0.05, bar_y - 0.05], color=LN, lw=0.5) + + lmax = max(lateness_vals) + ax.text( + 22, + bar_y + bar_h / 2, + f"Lmax = {lmax}", + ha="left", + va="center", + fontsize=10, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY1, "edgecolor": LN}, + ) + + # Bottom mnemonic + ax.text( + 10, + -1.3, + '„Early Due Date Does it first" — najpilniejszy deadline idzie pierwszy', + ha="center", + fontsize=8, + fontweight="bold", + style="italic", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": GRAY4, + "edgecolor": GRAY3, + "lw": 0.8, + }, + ) + + plt.tight_layout() + plt.savefig( + str(Path(OUTPUT_DIR) / "scheduling_edd_example.png"), + dpi=DPI, + bbox_inches="tight", + facecolor=BG, + ) + plt.close() + _logger.info(" ✓ scheduling_edd_example.png") diff --git a/python_pkg/praca_magisterska_video/generate_images/_sched_graham.py b/python_pkg/praca_magisterska_video/generate_images/_sched_graham.py new file mode 100644 index 0000000..9c0e2fc --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_sched_graham.py @@ -0,0 +1,484 @@ +"""Graham notation α|β|γ visual mnemonic map diagram.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from matplotlib.patches import FancyBboxPatch +import matplotlib.pyplot as plt + +from python_pkg.praca_magisterska_video.generate_images._sched_common import ( + BG, + DPI, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + OUTPUT_DIR, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + +_logger = logging.getLogger(__name__) + + +def draw_graham_notation() -> None: + """Draw graham notation.""" + _fig, ax = plt.subplots(1, 1, figsize=(8.27, 10)) + ax.set_xlim(0, 10) + ax.set_ylim(0, 14) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Notacja Grahama \u03b1 | β | \u03b3 — Mapa mnemoniczna", + fontsize=FS_TITLE + 1, + fontweight="bold", + pad=12, + ) + + _draw_graham_formula_bar(ax) + _draw_graham_alpha_beta(ax) + _draw_graham_lower(ax) + + plt.tight_layout() + plt.savefig( + str(Path(OUTPUT_DIR) / "scheduling_graham_notation.png"), + dpi=DPI, + bbox_inches="tight", + facecolor=BG, + ) + plt.close() + _logger.info(" ✓ scheduling_graham_notation.png") + + +def _draw_graham_formula_bar(ax: Axes) -> None: + """Draw the top alpha|beta|gamma formula bar.""" + bar_y = 12.5 + bar_h = 1.0 + # alpha box + rect = FancyBboxPatch( + (0.5, bar_y), + 2.5, + bar_h, + boxstyle="round,pad=0.08", + lw=2, + edgecolor=LN, + facecolor=GRAY1, + ) + ax.add_patch(rect) + ax.text( + 1.75, + bar_y + bar_h / 2, + "\u03b1", + fontsize=20, + fontweight="bold", + ha="center", + va="center", + ) + ax.text( + 1.75, + bar_y - 0.25, + "MASZYNY", + fontsize=8, + fontweight="bold", + ha="center", + va="top", + color="#444444", + ) + + # separator | + ax.text( + 3.3, + bar_y + bar_h / 2, + "|", + fontsize=24, + fontweight="bold", + ha="center", + va="center", + ) + + # β box + rect = FancyBboxPatch( + (3.7, bar_y), + 2.5, + bar_h, + boxstyle="round,pad=0.08", + lw=2, + edgecolor=LN, + facecolor=GRAY2, + ) + ax.add_patch(rect) + ax.text( + 4.95, + bar_y + bar_h / 2, + "β", + fontsize=20, + fontweight="bold", + ha="center", + va="center", + ) + ax.text( + 4.95, + bar_y - 0.25, + "OGRANICZENIA", + fontsize=8, + fontweight="bold", + ha="center", + va="top", + color="#444444", + ) + + # separator | + ax.text( + 6.5, + bar_y + bar_h / 2, + "|", + fontsize=24, + fontweight="bold", + ha="center", + va="center", + ) + + # gamma box + rect = FancyBboxPatch( + (6.9, bar_y), + 2.5, + bar_h, + boxstyle="round,pad=0.08", + lw=2, + edgecolor=LN, + facecolor=GRAY3, + ) + ax.add_patch(rect) + ax.text( + 8.15, + bar_y + bar_h / 2, + "\u03b3", + fontsize=20, + fontweight="bold", + ha="center", + va="center", + ) + ax.text( + 8.15, + bar_y - 0.25, + "CEL", + fontsize=8, + fontweight="bold", + ha="center", + va="top", + color="#444444", + ) + + +def _draw_graham_alpha_beta(ax: Axes) -> None: + """Draw alpha (machines) and beta (constraints) sections.""" + start_x = 0.3 + col_w = 1.28 + + # === SECTION alpha: MACHINES === + sec_y = 11.5 + ax.text( + 0.3, + sec_y, + '\u03b1 — „1 Prawdziwy Quasi-Rycerz Forsuje Jaskinię Orków"', + fontsize=8, + fontweight="bold", + va="top", + style="italic", + color="#333333", + ) + + alpha_items = [ + ("1", "jedna maszyna", "●", GRAY4), + ("Pm", "identyczne Parallel", "●●●", GRAY1), + ("Qm", "Quasi-uniform\n(różne prędkości)", "●●◐", GRAY4), + ("Rm", "Random unrelated\n(czasy per para)", "●◆▲", GRAY1), + ("Fm", "Flow shop\n(ta sama kolejność)", "→→→", GRAY2), + ("Jm", "Job shop\n(indyw. trasy)", "↗↙↘", GRAY4), + ("Om", "Open shop\n(dowolna kolej.)", "?→?", GRAY1), + ] + + col_w = 1.28 + box_h_a = 1.1 + start_x = 0.3 + start_y = 9.6 + + for i, (symbol, desc, icon, fill) in enumerate(alpha_items): + x = start_x + i * col_w + y = start_y + rect = FancyBboxPatch( + (x, y), + col_w - 0.1, + box_h_a, + boxstyle="round,pad=0.04", + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + x + (col_w - 0.1) / 2, + y + box_h_a - 0.15, + symbol, + ha="center", + va="top", + fontsize=9, + fontweight="bold", + ) + ax.text( + x + (col_w - 0.1) / 2, + y + box_h_a / 2 - 0.1, + desc, + ha="center", + va="center", + fontsize=5.5, + ) + ax.text( + x + (col_w - 0.1) / 2, y + 0.12, icon, ha="center", va="bottom", fontsize=7 + ) + + # Complexity arrow under alpha + arr_y = 9.35 + ax.annotate( + "", + xy=(9.0, arr_y), + xytext=(0.5, arr_y), + arrowprops={"arrowstyle": "->", "color": "#666666", "lw": 1.5}, + ) + ax.text( + 4.8, + arr_y - 0.18, + "rosnąca złożoność →", + ha="center", + fontsize=6, + color="#666666", + ) + + # === SECTION β: CONSTRAINTS === + sec_y2 = 8.9 + ax.text( + 0.3, + sec_y2, + "β — „Robak Daje Deadline: Przerwy Poprzedzają Pojedyncze Setup'y\"", + fontsize=8, + fontweight="bold", + va="top", + style="italic", + color="#333333", + ) + + beta_items = [ + ("rⱼ", "release\ndates", "Robak\ndostępne\nod czasu rⱼ", GRAY1), + ("dⱼ", "due\ndates", "Daje\ntermin soft\n(kara za spóźn.)", GRAY4), + ("d̄ⱼ", "dead-\nlines", "Deadline\ntermin hard\n(musi dotrzymać)", GRAY1), + ("pmtn", "preemp-\ntion", "Przerwy\nmożna\nprzerwać", GRAY2), + ("prec", "prece-\ndencje", "Poprzedzają\nA->B (DAG)", GRAY4), + ("pⱼ=1", "unit\ntime", "Pojedyncze\nwszystkie = 1", GRAY1), + ("sⱼₖ", "setup\ntimes", "Setup'y\nprzezbrojenie\nmiędzy j->k", GRAY4), + ] + + start_y2 = 7.0 + box_h_b = 1.4 + + for i, (symbol, _label, desc, fill) in enumerate(beta_items): + x = start_x + i * col_w + y = start_y2 + rect = FancyBboxPatch( + (x, y), + col_w - 0.1, + box_h_b, + boxstyle="round,pad=0.04", + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + x + (col_w - 0.1) / 2, + y + box_h_b - 0.12, + symbol, + ha="center", + va="top", + fontsize=9, + fontweight="bold", + ) + ax.text( + x + (col_w - 0.1) / 2, + y + box_h_b / 2 - 0.05, + desc, + ha="center", + va="center", + fontsize=5, + ) + + +def _draw_graham_lower(ax: Axes) -> None: + """Draw gamma criteria, examples, and footer sections.""" + start_x = 0.3 + + # === SECTION gamma: CRITERIA === + sec_y3 = 6.5 + ax.text( + 0.3, + sec_y3, + '\u03b3 — „Ciężki Sum Ważony Lata, Tardiness Uderza"', + fontsize=8, + fontweight="bold", + va="top", + style="italic", + color="#333333", + ) + + gamma_items = [ + ("Cmax", "makespan\nmax(Cⱼ)", "Jak długo\ntrwa WSZYSTKO?", GRAY2), + ("ΣCⱼ", "suma\nukończeń", "Średni czas\noczekiwania?", GRAY4), + ("ΣwⱼCⱼ", "ważona\nsuma", "Priorytety\nzadań?", GRAY1), + ("Lmax", "max\nopóźnienie", "Najgorsze\nspóźnienie?", GRAY2), + ("ΣTⱼ", "suma\nspóźnień", "Łączne\nspóźnienia?", GRAY4), + ("ΣUⱼ", "liczba\nspóźnionych", "Ile spóźnionych\nzadań?", GRAY1), + ] + + start_y3 = 4.5 + box_h_g = 1.4 + col_w_g = 1.5 + + for i, (symbol, label, question, fill) in enumerate(gamma_items): + x = start_x + i * col_w_g + y = start_y3 + rect = FancyBboxPatch( + (x, y), + col_w_g - 0.1, + box_h_g, + boxstyle="round,pad=0.04", + lw=1, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + x + (col_w_g - 0.1) / 2, + y + box_h_g - 0.1, + symbol, + ha="center", + va="top", + fontsize=9, + fontweight="bold", + ) + ax.text( + x + (col_w_g - 0.1) / 2, + y + box_h_g / 2 - 0.05, + label, + ha="center", + va="center", + fontsize=6, + ) + ax.text( + x + (col_w_g - 0.1) / 2, + y + 0.15, + f'„{question}"', + ha="center", + va="bottom", + fontsize=5, + style="italic", + ) + + # === BOTTOM: Example + Optimal methods === + ex_y = 3.5 + ax.text( + 0.3, + ex_y, + "Przykłady zapisu i optymalne metody:", + fontsize=8, + fontweight="bold", + va="top", + ) + + examples = [ + ("1 || ΣCⱼ", "SPT (najkrótsze\nnajpierw)", "O(n log n)", GRAY1), + ("1 || Lmax", "EDD (najwcześniejszy\ntermin)", "O(n log n)", GRAY4), + ("F2 || Cmax", "Algorytm\nJohnsona", "O(n log n)", GRAY2), + ("Pm || Cmax", "LPT heurystyka\n(NP-trudny!)", "NP-hard", GRAY3), + ("Jm || Cmax", "Branch & Bound\n(NP-trudny!)", "NP-hard", GRAY5), + ] + + ex_start_y = 1.8 + ex_box_w = 1.72 + ex_box_h = 1.4 + + for i, (notation, method, complexity, fill) in enumerate(examples): + x = start_x + i * (ex_box_w + 0.1) + y = ex_start_y + rect = FancyBboxPatch( + (x, y), + ex_box_w, + ex_box_h, + boxstyle="round,pad=0.04", + lw=1.2, + edgecolor=LN, + facecolor=fill, + ) + ax.add_patch(rect) + ax.text( + x + ex_box_w / 2, + y + ex_box_h - 0.12, + notation, + ha="center", + va="top", + fontsize=8, + fontweight="bold", + fontfamily="monospace", + ) + ax.text( + x + ex_box_w / 2, + y + ex_box_h / 2 - 0.05, + method, + ha="center", + va="center", + fontsize=6, + ) + ax.text( + x + ex_box_w / 2, + y + 0.12, + complexity, + ha="center", + va="bottom", + fontsize=6.5, + fontweight="bold", + color="#555555", + ) + + # Footer mnemonic summary + ax.text( + 5.0, + 0.8, + '„\u03b1|β|\u03b3 = Maszyny | Ograniczenia | Cel"', + ha="center", + fontsize=9, + fontweight="bold", + style="italic", + color="#333333", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": GRAY4, + "edgecolor": GRAY3, + "lw": 1, + }, + ) + + ax.text( + 5.0, + 0.2, + "\u03b1: ILE maszyn i JAKIE? " + "β: JAKIE ograniczenia zadań? " + "\u03b3: CO minimalizujemy?", + ha="center", + fontsize=7, + color="#555555", + ) diff --git a/python_pkg/praca_magisterska_video/generate_images/_sched_johnson.py b/python_pkg/praca_magisterska_video/generate_images/_sched_johnson.py new file mode 100644 index 0000000..39cd829 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_sched_johnson.py @@ -0,0 +1,318 @@ +"""Johnson's algorithm Gantt chart diagram (F2||Cmax).""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt + +from python_pkg.praca_magisterska_video.generate_images._sched_common import ( + BG, + DPI, + FONTWEIGHT_THRESHOLD, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + MIN_COLUMN_INDEX, + OUTPUT_DIR, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + +_logger = logging.getLogger(__name__) + + +def draw_johnson_gantt() -> None: + """Draw johnson gantt.""" + _fig, axes = plt.subplots( + 2, 1, figsize=(8.27, 7), gridspec_kw={"height_ratios": [1, 1.8]} + ) + + _draw_johnson_decision_table(axes[0]) + _draw_johnson_gantt_chart(axes[1]) + + plt.tight_layout() + plt.savefig( + str(Path(OUTPUT_DIR) / "scheduling_johnson_gantt.png"), + dpi=DPI, + bbox_inches="tight", + facecolor=BG, + ) + plt.close() + _logger.info(" ✓ scheduling_johnson_gantt.png") + + +def _draw_johnson_decision_table(ax: Axes) -> None: + """Draw the Johnson algorithm decision table.""" + ax.set_xlim(0, 10) + ax.set_ylim(0, 5) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title( + "Algorytm Johnsona (F2 || Cmax) — Decyzja + Diagram Gantta", + fontsize=FS_TITLE, + fontweight="bold", + pad=10, + ) + + # Task table + tasks = ["J1", "J2", "J3", "J4", "J5"] + a_times = [4, 2, 6, 1, 3] + b_times = [5, 3, 2, 7, 4] + min_vals = [min(a, b) for a, b in zip(a_times, b_times, strict=False)] + min_on = ["M1" if a <= b else "M2" for a, b in zip(a_times, b_times, strict=False)] + assign = ["POCZątek" if m == "M1" else "KONIEC" for m in min_on] + + # Draw table + col_w_t = 1.3 + row_h = 0.55 + headers = ["Zadanie", "aⱼ (M1)", "bⱼ (M2)", "min", "min na", "Przydziel"] + table_x = 0.8 + table_y = 3.8 + + for j, hdr in enumerate(headers): + x = table_x + j * col_w_t + rect = mpatches.Rectangle( + (x, table_y), col_w_t, row_h, lw=1, edgecolor=LN, facecolor=GRAY2 + ) + ax.add_patch(rect) + ax.text( + x + col_w_t / 2, + table_y + row_h / 2, + hdr, + ha="center", + va="center", + fontsize=6.5, + fontweight="bold", + ) + + for i in range(5): + row_data = [ + tasks[i], + str(a_times[i]), + str(b_times[i]), + str(min_vals[i]), + min_on[i], + assign[i], + ] + for j, val in enumerate(row_data): + x = table_x + j * col_w_t + y = table_y - (i + 1) * row_h + fill_c = GRAY1 if min_on[i] == "M1" else GRAY4 + if j == MIN_COLUMN_INDEX: # min column - highlight + fill_c = GRAY3 + rect = mpatches.Rectangle( + (x, y), col_w_t, row_h, lw=0.8, edgecolor=LN, facecolor=fill_c + ) + ax.add_patch(rect) + fw = "bold" if j >= FONTWEIGHT_THRESHOLD else "normal" + ax.text( + x + col_w_t / 2, + y + row_h / 2, + val, + ha="center", + va="center", + fontsize=6.5, + fontweight=fw, + ) + + # Sorting result + result_y = 0.7 + ax.text( + 5.0, + result_y + 0.4, + "Sortuj → POCZĄTEK ↑aⱼ: J4(1), J2(2), J5(3), J1(4) | KONIEC ↓bⱼ: J3(2)", + ha="center", + fontsize=7, + color="#333333", + ) + ax.text( + 5.0, + result_y, + "Optymalna kolejność: J4 → J2 → J5 → J1 → J3", + ha="center", + fontsize=9, + fontweight="bold", + bbox={ + "boxstyle": "round,pad=0.2", + "facecolor": GRAY1, + "edgecolor": LN, + "lw": 1.2, + }, + ) + + +def _draw_johnson_gantt_chart(ax2: Axes) -> None: + """Draw the Johnson algorithm Gantt chart.""" + ax2.set_xlim(-1, 24) + ax2.set_ylim(-1, 4) + ax2.axis("off") + + # Machines labels + m1_y = 2.5 + m2_y = 0.8 + bar_h = 0.9 + + ax2.text( + -0.8, + m1_y + bar_h / 2, + "M1", + ha="center", + va="center", + fontsize=11, + fontweight="bold", + ) + ax2.text( + -0.8, + m2_y + bar_h / 2, + "M2", + ha="center", + va="center", + fontsize=11, + fontweight="bold", + ) + + # Schedule: J4 → J2 → J5 → J1 → J3 + order = ["J4", "J2", "J5", "J1", "J3"] + a_ord = [1, 2, 3, 4, 6] # M1 times in order + b_ord = [7, 3, 4, 5, 2] # M2 times in order + fills = [GRAY1, GRAY2, GRAY4, GRAY3, GRAY5] + hatches = ["", "///", "", "\\\\\\", "xxx"] + + # M1 schedule + m1_starts = [] + t = 0 + for a in a_ord: + m1_starts.append(t) + t += a + m1_ends = [s + a for s, a in zip(m1_starts, a_ord, strict=False)] + + # M2 schedule (must wait for M1 finish AND previous M2 finish) + m2_starts = [] + m2_ends = [] + prev_m2_end = 0 + for i, b in enumerate(b_ord): + start = max(m1_ends[i], prev_m2_end) + m2_starts.append(start) + m2_ends.append(start + b) + prev_m2_end = start + b + + # Draw M1 bars + for i in range(5): + rect = mpatches.Rectangle( + (m1_starts[i], m1_y), + a_ord[i], + bar_h, + lw=1.2, + edgecolor=LN, + facecolor=fills[i], + hatch=hatches[i], + ) + ax2.add_patch(rect) + ax2.text( + m1_starts[i] + a_ord[i] / 2, + m1_y + bar_h / 2, + f"{order[i]}\n({a_ord[i]})", + ha="center", + va="center", + fontsize=7, + fontweight="bold", + ) + + # Draw M2 bars + for i in range(5): + rect = mpatches.Rectangle( + (m2_starts[i], m2_y), + b_ord[i], + bar_h, + lw=1.2, + edgecolor=LN, + facecolor=fills[i], + hatch=hatches[i], + ) + ax2.add_patch(rect) + ax2.text( + m2_starts[i] + b_ord[i] / 2, + m2_y + bar_h / 2, + f"{order[i]}\n({b_ord[i]})", + ha="center", + va="center", + fontsize=7, + fontweight="bold", + ) + + # Draw idle regions on M2 + idle_starts = [0] + idle_ends = [m2_starts[0]] + for i in range(1, 5): + if m2_starts[i] > m2_ends[i - 1]: + idle_starts.append(m2_ends[i - 1]) + idle_ends.append(m2_starts[i]) + + for s, e in zip(idle_starts, idle_ends, strict=False): + if e > s: + rect = mpatches.Rectangle( + (s, m2_y), + e - s, + bar_h, + lw=0.5, + edgecolor="#AAAAAA", + facecolor="white", + linestyle="--", + ) + ax2.add_patch(rect) + ax2.text( + s + (e - s) / 2, + m2_y + bar_h / 2, + "idle", + ha="center", + va="center", + fontsize=5, + color="#999999", + ) + + # Time axis + ax_y = m2_y - 0.15 + ax2.plot([0, 23], [ax_y, ax_y], color=LN, lw=0.8) + for t in range(0, 24, 2): + ax2.plot([t, t], [ax_y - 0.08, ax_y + 0.08], color=LN, lw=0.8) + ax2.text(t, ax_y - 0.25, str(t), ha="center", va="top", fontsize=6) + ax2.text(11.5, ax_y - 0.55, "czas", ha="center", fontsize=7) + + # Cmax annotation + ax2.annotate( + f"Cmax = {m2_ends[-1]}", + xy=(m2_ends[-1], m2_y + bar_h), + xytext=(m2_ends[-1] + 0.5, m2_y + bar_h + 0.6), + fontsize=10, + fontweight="bold", + color="#333333", + arrowprops={"arrowstyle": "->", "color": "#333333", "lw": 1.5}, + ) + + # Mnemonic at bottom + ax2.text( + 11, + -0.7, + "„Krótki na M1 → START (szybko karmi M2)" + " Krótki na M2 → KONIEC" + ' (szybko kończy)"', + ha="center", + fontsize=7.5, + fontweight="bold", + style="italic", + bbox={ + "boxstyle": "round,pad=0.3", + "facecolor": GRAY4, + "edgecolor": GRAY3, + "lw": 0.8, + }, + ) diff --git a/python_pkg/praca_magisterska_video/generate_images/_sched_spt_flow_job.py b/python_pkg/praca_magisterska_video/generate_images/_sched_spt_flow_job.py new file mode 100644 index 0000000..75c6706 --- /dev/null +++ b/python_pkg/praca_magisterska_video/generate_images/_sched_spt_flow_job.py @@ -0,0 +1,352 @@ +"""SPT vs LPT comparison and Flow Shop vs Job Shop diagrams.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt + +from python_pkg.praca_magisterska_video.generate_images._sched_common import ( + BG, + DPI, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + OUTPUT_DIR, + draw_arrow, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + +_logger = logging.getLogger(__name__) + + +# ============================================================ +# SPT vs LPT COMPARISON (1 || ΣCⱼ) +# ============================================================ +def draw_spt_comparison() -> None: + """Draw spt comparison.""" + fig, axes = plt.subplots(2, 1, figsize=(8.27, 5.5)) + + tasks_orig = [("J1", 5), ("J2", 3), ("J3", 8), ("J4", 2), ("J5", 6)] + + spt_order = sorted(tasks_orig, key=lambda x: x[1]) + lpt_order = sorted(tasks_orig, key=lambda x: -x[1]) + + fills_map = {"J1": GRAY1, "J2": GRAY2, "J3": GRAY3, "J4": GRAY4, "J5": GRAY5} + hatch_map = {"J1": "", "J2": "///", "J3": "xxx", "J4": "", "J5": "\\\\\\"} + + for _idx, (ax, order_list, title, is_optimal) in enumerate( + [ + (axes[0], spt_order, "SPT (Shortest Processing Time) — OPTYMALNE", True), + (axes[1], lpt_order, "LPT (Longest Processing Time) — gorsze!", False), + ] + ): + ax.set_xlim(-2, 26) + ax.set_ylim(-0.5, 2.5) + ax.axis("off") + color = "#222222" if is_optimal else "#666666" + marker = "✓" if is_optimal else "✗" + ax.set_title( + f"{marker} {title}", + fontsize=9, + fontweight="bold", + loc="left", + color=color, + pad=5, + ) + + bar_y = 1.0 + bar_h = 0.8 + t = 0 + completions = [] + + for name, duration in order_list: + rect = mpatches.Rectangle( + (t, bar_y), + duration, + bar_h, + lw=1.2, + edgecolor=LN, + facecolor=fills_map[name], + hatch=hatch_map[name], + ) + ax.add_patch(rect) + ax.text( + t + duration / 2, + bar_y + bar_h / 2, + f"{name}\n({duration})", + ha="center", + va="center", + fontsize=7, + fontweight="bold", + ) + t += duration + completions.append(t) + + # Completion time marker + ax.plot([t, t], [bar_y - 0.15, bar_y], color=LN, lw=0.8) + ax.text( + t, + bar_y - 0.25, + f"C={t}", + ha="center", + va="top", + fontsize=6, + color="#555555", + ) + + total = sum(completions) + # Time axis + ax.plot([0, 25], [bar_y - 0.05, bar_y - 0.05], color=LN, lw=0.5) + + # Sum annotation + comp_str = " + ".join(str(c) for c in completions) + ax.text( + 25, + bar_y + bar_h / 2, + f"ΣCⱼ = {comp_str}\n = {total}", + ha="left", + va="center", + fontsize=7, + fontweight="bold" if is_optimal else "normal", + color=color, + bbox={ + "boxstyle": "round,pad=0.2", + "facecolor": GRAY1 if is_optimal else "white", + "edgecolor": color, + "lw": 1, + }, + ) + + # Bottom annotation + fig.text( + 0.5, + 0.02, + '„Short People To the front"' + " — krótkie najpierw," + " jak niskie osoby w zdjęciu klasowym", + ha="center", + fontsize=8, + fontweight="bold", + style="italic", + color="#444444", + ) + + plt.tight_layout(rect=[0, 0.05, 1, 1]) + plt.savefig( + str(Path(OUTPUT_DIR) / "scheduling_spt_comparison.png"), + dpi=DPI, + bbox_inches="tight", + facecolor=BG, + ) + plt.close() + _logger.info(" ✓ scheduling_spt_comparison.png") + + +# ============================================================ +# FLOW SHOP vs JOB SHOP +# ============================================================ +def draw_flow_vs_job() -> None: + """Draw flow vs job.""" + _fig, axes = plt.subplots(1, 2, figsize=(8.27, 4.5)) + + _draw_flow_shop(axes[0]) + _draw_job_shop(axes[1]) + + plt.tight_layout() + plt.savefig( + str(Path(OUTPUT_DIR) / "scheduling_flow_vs_job.png"), + dpi=DPI, + bbox_inches="tight", + facecolor=BG, + ) + plt.close() + _logger.info(" ✓ scheduling_flow_vs_job.png") + + +def _draw_flow_shop(ax: Axes) -> None: + """Draw the Flow Shop diagram.""" + ax.set_xlim(0, 6) + ax.set_ylim(0, 6) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Flow Shop (Fm)", fontsize=10, fontweight="bold", pad=8) + + # Machines in a row + machines_x = [1, 3, 5] + machines_y = 3 + mach_r = 0.4 + + for i, mx in enumerate(machines_x): + circle = plt.Circle( + (mx, machines_y), mach_r, facecolor=GRAY2, edgecolor=LN, lw=1.5 + ) + ax.add_patch(circle) + ax.text( + mx, + machines_y, + f"M{i + 1}", + ha="center", + va="center", + fontsize=9, + fontweight="bold", + ) + + # Arrows between machines + for i in range(len(machines_x) - 1): + draw_arrow( + ax, + machines_x[i] + mach_r + 0.05, + machines_y, + machines_x[i + 1] - mach_r - 0.05, + machines_y, + lw=2, + ) + + # Jobs all flowing the same way + jobs_flow = ["J1", "J2", "J3"] + for _j, (job, y_off) in enumerate(zip(jobs_flow, [0.8, 0, -0.8], strict=False)): + ax.text( + 0.2, + machines_y + y_off, + job, + ha="center", + va="center", + fontsize=7, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.15", "facecolor": GRAY1, "edgecolor": LN}, + ) + # Dashed flow line + ax.annotate( + "", + xy=(5.5, machines_y + y_off * 0.3), + xytext=(0.5, machines_y + y_off), + arrowprops={ + "arrowstyle": "->", + "color": "#888888", + "lw": 0.8, + "linestyle": "dashed", + }, + ) + + ax.text( + 3, + 1.2, + "Wszystkie zadania:\nM1 → M2 → M3", + ha="center", + va="center", + fontsize=8, + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + ax.text( + 3, + 0.4, + "Jak taśma montażowa", + ha="center", + fontsize=7, + style="italic", + color="#666666", + ) + + +def _draw_job_shop(ax: Axes) -> None: + """Draw the Job Shop diagram.""" + mach_r = 0.4 + ax.set_xlim(0, 6) + ax.set_ylim(0, 6) + ax.set_aspect("equal") + ax.axis("off") + ax.set_title("Job Shop (Jm)", fontsize=10, fontweight="bold", pad=8) + + # Machines scattered + m_positions = [(1.5, 4.2), (4.5, 4.2), (3, 2.5)] + + for i, (mx, my) in enumerate(m_positions): + circle = plt.Circle((mx, my), mach_r, facecolor=GRAY2, edgecolor=LN, lw=1.5) + ax.add_patch(circle) + ax.text( + mx, my, f"M{i + 1}", ha="center", va="center", fontsize=9, fontweight="bold" + ) + + # J1: M1 → M2 → M3 (solid) + route1 = [(1.5, 4.2), (4.5, 4.2), (3, 2.5)] + for i in range(len(route1) - 1): + x1, y1 = route1[i] + x2, y2 = route1[i + 1] + dx = x2 - x1 + dy = y2 - y1 + d = (dx**2 + dy**2) ** 0.5 + draw_arrow( + ax, + x1 + mach_r * dx / d + 0.05, + y1 + mach_r * dy / d, + x2 - mach_r * dx / d - 0.05, + y2 - mach_r * dy / d, + lw=1.5, + ) + ax.text( + 0.3, + 4.8, + "J1: M1→M2→M3", + fontsize=7, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY1, "edgecolor": LN}, + ) + + # J2: M2 → M3 → M1 (dashed) + route2 = [(4.5, 4.2), (3, 2.5), (1.5, 4.2)] + for i in range(len(route2) - 1): + x1, y1 = route2[i] + x2, y2 = route2[i + 1] + dx = x2 - x1 + dy = y2 - y1 + d = (dx**2 + dy**2) ** 0.5 + off = 0.15 # offset to avoid overlap + ax.annotate( + "", + xy=(x2 - mach_r * dx / d - 0.05, y2 - mach_r * dy / d + off), + xytext=(x1 + mach_r * dx / d + 0.05, y1 + mach_r * dy / d + off), + arrowprops={ + "arrowstyle": "->", + "color": "#555555", + "lw": 1.5, + "linestyle": "dashed", + }, + ) + ax.text( + 3.8, + 5.2, + "J2: M2→M3→M1", + fontsize=7, + fontweight="bold", + bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY4, "edgecolor": LN}, + ) + + ax.text( + 3, + 1.2, + "Każde zadanie:\nwłasna trasa!", + ha="center", + va="center", + fontsize=8, + bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, + ) + + ax.text( + 3, + 0.4, + "NP-trudny już dla 3 maszyn", + ha="center", + fontsize=7, + style="italic", + color="#666666", + ) diff --git a/python_pkg/praca_magisterska_video/generate_images/generate_pubsub_diagrams.py b/python_pkg/praca_magisterska_video/generate_images/generate_pubsub_diagrams.py index e7c80e4..7134b40 100755 --- a/python_pkg/praca_magisterska_video/generate_images/generate_pubsub_diagrams.py +++ b/python_pkg/praca_magisterska_video/generate_images/generate_pubsub_diagrams.py @@ -17,1208 +17,29 @@ One diagram per image -- no cramming. from __future__ import annotations -from dataclasses import dataclass import logging -from pathlib import Path -from typing import TYPE_CHECKING -import matplotlib as mpl - -mpl.use("Agg") - -import matplotlib.patches as mpatches -from matplotlib.patches import FancyBboxPatch -import matplotlib.pyplot as plt - -if TYPE_CHECKING: - from matplotlib.axes import Axes +from _pubsub_qos import ( + draw_qos_at_least_once, + draw_qos_at_most_once, + draw_qos_exactly_once, +) +from _pubsub_topic_content import ( + draw_sub_content, + draw_sub_topic, +) +from _pubsub_type_hierarchical import ( + draw_sub_hierarchical, + draw_sub_type, +) logger = logging.getLogger(__name__) -DPI = 300 -BG = "white" -LN = "black" -FS = 9 -FS_TITLE = 13 -FIG_W = 8.27 # A4 width in inches -OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") -Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) - -GRAY1 = "#E8E8E8" -GRAY2 = "#D0D0D0" -GRAY3 = "#B8B8B8" -GRAY4 = "#F5F5F5" -GRAY5 = "#C0C0C0" - - -@dataclass(frozen=True) -class BoxStyle: - """Optional styling for boxes.""" - - fill: str = "white" - lw: float = 1.2 - fontsize: float = FS - fontweight: str = "normal" - ha: str = "center" - va: str = "center" - rounded: bool = True - - -@dataclass(frozen=True) -class ArrowCfg: - """Config for arrows.""" - - lw: float = 1.2 - style: str = "->" - color: str = LN - label: str = "" - label_offset: float = 0.15 - label_fs: float = 8 - - -@dataclass(frozen=True) -class DashedCfg: - """Config for dashed arrows.""" - - lw: float = 1.0 - color: str = LN - label: str = "" - label_offset: float = 0.15 - label_fs: float = 8 - - -def draw_box( - ax: Axes, - pos: tuple[float, float], - size: tuple[float, float], - text: str, - style: BoxStyle | None = None, -) -> None: - """Draw box.""" - s = style or BoxStyle() - x, y = pos - w, h = size - if s.rounded: - rect = FancyBboxPatch( - (x, y), - w, - h, - boxstyle="round,pad=0.05", - lw=s.lw, - edgecolor=LN, - facecolor=s.fill, - ) - else: - rect = mpatches.Rectangle( - (x, y), - w, - h, - lw=s.lw, - edgecolor=LN, - facecolor=s.fill, - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + h / 2, - text, - ha=s.ha, - va=s.va, - fontsize=s.fontsize, - fontweight=s.fontweight, - wrap=True, - ) - - -def draw_arrow( - ax: Axes, - start: tuple[float, float], - end: tuple[float, float], - cfg: ArrowCfg | None = None, -) -> None: - """Draw arrow.""" - c = cfg or ArrowCfg() - ax.annotate( - "", - xy=end, - xytext=start, - arrowprops={ - "arrowstyle": c.style, - "color": c.color, - "lw": c.lw, - }, - ) - if c.label: - mx = (start[0] + end[0]) / 2 - my = (start[1] + end[1]) / 2 + c.label_offset - ax.text( - mx, - my, - c.label, - ha="center", - va="bottom", - fontsize=c.label_fs, - color=c.color, - ) - - -def draw_dashed_arrow( - ax: Axes, - start: tuple[float, float], - end: tuple[float, float], - cfg: DashedCfg | None = None, -) -> None: - """Draw dashed arrow.""" - c = cfg or DashedCfg() - ax.annotate( - "", - xy=end, - xytext=start, - arrowprops={ - "arrowstyle": "->", - "color": c.color, - "lw": c.lw, - "linestyle": "dashed", - }, - ) - if c.label: - mx = (start[0] + end[0]) / 2 - my = (start[1] + end[1]) / 2 + c.label_offset - ax.text( - mx, - my, - c.label, - ha="center", - va="bottom", - fontsize=c.label_fs, - color=c.color, - ) - - -def draw_cross( - ax: Axes, - pos: tuple[float, float], - size: float = 0.15, - lw: float = 2.5, - color: str = "black", -) -> None: - """Draw cross.""" - x, y = pos - ax.plot( - [x - size, x + size], - [y - size, y + size], - color=color, - lw=lw, - ) - ax.plot( - [x - size, x + size], - [y + size, y - size], - color=color, - lw=lw, - ) - - -def draw_check( - ax: Axes, - pos: tuple[float, float], - size: float = 0.15, - lw: float = 2.5, - color: str = "black", -) -> None: - """Draw check.""" - x, y = pos - ax.plot( - [x - size, x - size * 0.2], - [y, y - size * 0.7], - color=color, - lw=lw, - ) - ax.plot( - [x - size * 0.2, x + size], - [y - size * 0.7, y + size * 0.5], - color=color, - lw=lw, - ) - - -def save(fig: plt.Figure, name: str) -> None: - """Save.""" - plt.tight_layout() - fig.savefig( - str(Path(OUTPUT_DIR) / name), - dpi=DPI, - bbox_inches="tight", - facecolor=BG, - ) - plt.close(fig) - logger.info(" \u2713 %s", name) - - -# ============================================================ -# 1. Topic-based subscription -# ============================================================ -def draw_sub_topic() -> None: - """Draw sub topic.""" - fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.0)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 5.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Subskrypcja topic-based" - " \u2014 routing po nazwie tematu", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - bold10 = BoxStyle( - fill=GRAY1, fontsize=10, fontweight="bold" - ) - fs85 = BoxStyle(fill=GRAY1, fontsize=8.5) - - draw_box( - ax, (0.2, 3.2), (2.4, 1.1), "Publisher", bold10 - ) - draw_box( - ax, - (0.3, 1.8), - (2.2, 0.8), - 'topic: "orders"', - BoxStyle(fill=GRAY4, fontsize=8), - ) - draw_box( - ax, - (0.3, 0.7), - (2.2, 0.8), - 'topic: "payments"', - BoxStyle(fill=GRAY4, fontsize=8), - ) - - draw_box( - ax, - (4.2, 1.5), - (2.8, 2.2), - "BROKER\n\ntopic routing", - BoxStyle( - fill=GRAY2, fontsize=10, fontweight="bold" - ), - ) - - draw_box( - ax, - (8.5, 3.8), - (3.0, 1.0), - 'Subscriber A\nsubskrybuje: "orders"', - fs85, - ) - draw_box( - ax, - (8.5, 2.2), - (3.0, 1.0), - 'Subscriber B\nsubskrybuje: "payments"', - fs85, - ) - draw_box( - ax, - (8.5, 0.6), - (3.0, 1.0), - 'Subscriber C\nsubskrybuje: "orders"', - fs85, - ) - - fs8 = ArrowCfg(label_fs=8) - draw_arrow(ax, (2.6, 2.2), (4.2, 2.8), fs8) - draw_arrow(ax, (2.6, 1.1), (4.2, 2.2), fs8) - - draw_arrow( - ax, - (7.0, 3.4), - (8.5, 4.2), - ArrowCfg(label='"orders"', label_fs=8), - ) - draw_arrow( - ax, - (7.0, 2.6), - (8.5, 2.7), - ArrowCfg(label='"payments"', label_fs=8), - ) - draw_arrow( - ax, - (7.0, 2.2), - (8.5, 1.2), - ArrowCfg(label='"orders"', label_fs=8), - ) - - ax.text( - 6.0, - 0.1, - "Subscriber deklaruje nazw\u0119 tematu." - " Broker kieruje wiadomo\u015bci\n" - "do WSZYSTKICH subscriber\u00f3w" - " danego tematu. Najprostszy model.", - ha="center", - va="bottom", - fontsize=8.5, - style="italic", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": GRAY4, - "edgecolor": GRAY3, - }, - ) - - save(fig, "pubsub_sub_topic.png") - - -# ============================================================ -# 2. Content-based subscription -# ============================================================ -def draw_sub_content() -> None: - """Draw sub content.""" - fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 6) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Subskrypcja content-based" - " \u2014 filtrowanie po tre\u015bci wiadomo\u015bci", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - bold10 = BoxStyle( - fill=GRAY1, fontsize=10, fontweight="bold" - ) - draw_box( - ax, (0.2, 3.5), (2.4, 1.1), "Publisher", bold10 - ) - draw_box( - ax, - (0.2, 1.8), - (2.4, 1.2), - 'price: 150\ntype: "book"\ncategory: "IT"', - BoxStyle(fill=GRAY4, fontsize=8.5), - ) - - draw_box( - ax, - (4.0, 2.0), - (3.0, 2.5), - "BROKER\n\newaluuje filtry\n" - "ka\u017cdego subscribera", - BoxStyle( - fill=GRAY2, fontsize=9, fontweight="bold" - ), - ) - - fs9 = BoxStyle(fill=GRAY1, fontsize=9) - draw_box( - ax, - (8.5, 4.2), - (3.2, 1.0), - "Sub A\nfiltr: price > 100", - fs9, - ) - draw_box( - ax, - (8.5, 2.6), - (3.2, 1.0), - 'Sub B\nfiltr: type = "food"', - fs9, - ) - draw_box( - ax, - (8.5, 1.0), - (3.2, 1.0), - "Sub C\nfiltr: price < 50", - fs9, - ) - - draw_arrow(ax, (2.6, 2.4), (4.0, 3.0)) - draw_arrow( - ax, - (7.0, 4.0), - (8.5, 4.6), - ArrowCfg( - label="150 > 100 \u2713 dostarczono", - label_fs=8, - ), - ) - draw_dashed_arrow( - ax, - (7.0, 3.2), - (8.5, 3.1), - DashedCfg( - label='"book" \u2260 "food"' - " \u2717 odrzucono", - label_fs=8, - ), - ) - draw_dashed_arrow( - ax, - (7.0, 2.5), - (8.5, 1.6), - DashedCfg( - label="150 < 50 \u2717 odrzucono", - label_fs=8, - ), - ) - - ax.text( - 6.0, - 0.2, - "Broker analizuje TRE\u015a\u0106 wiadomo\u015bci" - " i ewaluuje predykaty.\n" - "Bardziej elastyczny ni\u017c topic-based," - " ale wolniejszy (koszt ewaluacji).", - ha="center", - va="bottom", - fontsize=8.5, - style="italic", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": GRAY4, - "edgecolor": GRAY3, - }, - ) - - save(fig, "pubsub_sub_content.png") - - -# ============================================================ -# 3. Type-based subscription -# ============================================================ -def draw_sub_type() -> None: - """Draw sub type.""" - fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.0)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 6.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Subskrypcja type-based" - " \u2014 routing po typie (klasie) obiektu", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - bold10 = BoxStyle( - fill=GRAY1, fontsize=10, fontweight="bold" - ) - draw_box( - ax, (0.2, 4.2), (2.4, 1.1), "Publisher", bold10 - ) - - fs9_g4 = BoxStyle(fill=GRAY4, fontsize=9) - draw_box( - ax, - (0.1, 2.8), - (2.6, 0.9), - "new OrderEvent()", - fs9_g4, - ) - draw_box( - ax, - (0.1, 1.5), - (2.6, 0.9), - "new PaymentEvent()", - fs9_g4, - ) - - draw_box( - ax, - (4.0, 2.3), - (3.0, 2.4), - "BROKER\n\nrouting po\ntypie klasy", - BoxStyle( - fill=GRAY2, fontsize=10, fontweight="bold" - ), - ) - - fs9 = BoxStyle(fill=GRAY1, fontsize=9) - draw_box( - ax, - (8.5, 4.8), - (3.2, 1.0), - "Sub A\n\u2192 OrderEvent", - fs9, - ) - draw_box( - ax, - (8.5, 3.2), - (3.2, 1.0), - "Sub B\n\u2192 PaymentEvent", - fs9, - ) - draw_box( - ax, - (8.5, 1.6), - (3.2, 1.0), - "Sub C\n\u2192 Event (base)", - fs9, - ) - - draw_arrow(ax, (2.7, 3.2), (4.0, 3.8)) - draw_arrow(ax, (2.7, 2.0), (4.0, 3.0)) - draw_arrow( - ax, - (7.0, 4.3), - (8.5, 5.2), - ArrowCfg(label="OrderEvent", label_fs=8), - ) - draw_arrow( - ax, - (7.0, 3.5), - (8.5, 3.7), - ArrowCfg(label="PaymentEvent", label_fs=8), - ) - draw_arrow( - ax, - (7.0, 3.0), - (8.5, 2.2), - ArrowCfg( - label="oba (dziedziczenie!)", label_fs=8 - ), - ) - - hx, hy = 0.5, 0.0 - draw_box( - ax, - (hx + 2.0, hy + 0.2), - (1.8, 0.6), - "Event", - BoxStyle( - fill=GRAY3, fontsize=8, fontweight="bold" - ), - ) - draw_box( - ax, - (hx, hy + 0.2), - (1.8, 0.6), - "OrderEvent", - BoxStyle(fill=GRAY4, fontsize=7.5), - ) - draw_box( - ax, - (hx + 4.0, hy + 0.2), - (2.0, 0.6), - "PaymentEvent", - BoxStyle(fill=GRAY4, fontsize=7.5), - ) - draw_arrow( - ax, - (hx + 2.9, hy + 0.2), - (hx + 0.9, hy + 0.2), - ArrowCfg( - lw=1.0, - label="extends", - label_offset=-0.3, - label_fs=7, - ), - ) - draw_arrow( - ax, - (hx + 2.9, hy + 0.2), - (hx + 5.0, hy + 0.2), - ArrowCfg( - lw=1.0, - label="extends", - label_offset=-0.3, - label_fs=7, - ), - ) - - ax.text( - 9.5, - 0.5, - "Sub C subskrybuje bazowy Event\n" - "\u2192 otrzymuje WSZYSTKIE podtypy", - ha="center", - va="center", - fontsize=8.5, - style="italic", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": GRAY4, - "edgecolor": GRAY3, - }, - ) - - save(fig, "pubsub_sub_type.png") - - -# ============================================================ -# 4. Hierarchical / Wildcards subscription -# ============================================================ -def draw_sub_hierarchical() -> None: - """Draw sub hierarchical.""" - fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 7) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Subskrypcja hierarchiczna (wildcards)" - " \u2014 wzorce temat\u00f3w", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - bold10 = BoxStyle( - fill=GRAY2, fontsize=10, fontweight="bold" - ) - draw_box( - ax, (4.5, 5.8), (2.4, 0.8), "sensors/", bold10 - ) - - fs9_g3 = BoxStyle(fill=GRAY3, fontsize=9) - draw_box( - ax, - (1.5, 4.2), - (2.4, 0.8), - "temperature/", - fs9_g3, - ) - draw_box( - ax, - (7.5, 4.2), - (2.4, 0.8), - "humidity/", - fs9_g3, - ) - - fs85_g4 = BoxStyle(fill=GRAY4, fontsize=8.5) - draw_box( - ax, (0.2, 2.8), (1.8, 0.7), "room1", fs85_g4 - ) - draw_box( - ax, (2.4, 2.8), (1.8, 0.7), "room2", fs85_g4 - ) - draw_box( - ax, (6.8, 2.8), (1.8, 0.7), "room1", fs85_g4 - ) - draw_box( - ax, (9.0, 2.8), (1.8, 0.7), "room2", fs85_g4 - ) - - thin = ArrowCfg(lw=1.0) - draw_arrow(ax, (5.7, 5.8), (2.7, 5.0), thin) - draw_arrow(ax, (5.7, 5.8), (8.7, 5.0), thin) - draw_arrow(ax, (2.2, 4.2), (1.1, 3.5), thin) - draw_arrow(ax, (3.2, 4.2), (3.3, 3.5), thin) - draw_arrow(ax, (8.2, 4.2), (7.7, 3.5), thin) - draw_arrow(ax, (9.2, 4.2), (9.9, 3.5), thin) - - ax.text( - 1.1, - 2.4, - "sensors/temperature/room1", - fontsize=7, - ha="center", - fontfamily="monospace", - style="italic", - ) - ax.text( - 3.3, - 2.4, - "sensors/temperature/room2", - fontsize=7, - ha="center", - fontfamily="monospace", - style="italic", - ) - - ax.text( - 0.3, - 1.5, - "Wzorce subskrypcji (MQTT-style):", - fontsize=10, - fontweight="bold", - ) - - patterns = [ - ( - '"sensors/temperature/room1"', - "\u2192 TYLKO room1", - "(dok\u0142adne dopasowanie)", - ), - ( - '"sensors/temperature/*"', - "\u2192 room1, room2", - "( * = jeden poziom)", - ), - ( - '"sensors/#"', - "\u2192 WSZYSTKO", - "( # = dowolna g\u0142\u0119boko\u015b\u0107)", - ), - ] - for i, (pat, result, note) in enumerate(patterns): - yy = 0.9 - i * 0.55 - ax.text( - 0.5, - yy, - pat, - fontsize=9, - fontweight="bold", - fontfamily="monospace", - ) - ax.text( - 7.0, yy, result, fontsize=9, fontweight="bold" - ) - ax.text( - 9.5, yy, note, fontsize=8, style="italic" - ) - - save(fig, "pubsub_sub_hierarchical.png") - - -# ============================================================ -# 5. At-most-once (QoS 0) -# ============================================================ -def draw_qos_at_most_once() -> None: - """Draw qos at most once.""" - fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 6) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "QoS: At-most-once" - " \u2014 \u201ewy\u015blij i zapomnij\u201d" - " (0 lub 1 dostarczenie)", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - px, bx, sx = 1.0, 4.8, 8.5 - pw, bw, sw = 2.0, 2.2, 2.0 - bh = 0.8 - bold10_g1 = BoxStyle( - fill=GRAY1, fontsize=10, fontweight="bold" - ) - bold10_g2 = BoxStyle( - fill=GRAY2, fontsize=10, fontweight="bold" - ) - draw_box( - ax, (px, 5.0), (pw, bh), "Publisher", bold10_g1 - ) - draw_box( - ax, (bx, 5.0), (bw, bh), "Broker", bold10_g2 - ) - draw_box( - ax, - (sx, 5.0), - (sw, bh), - "Subscriber", - bold10_g1, - ) - - for xc in [px + pw / 2, bx + bw / 2, sx + sw / 2]: - ax.plot( - [xc, xc], - [5.0, 1.2], - color=GRAY3, - lw=1, - linestyle=":", - ) - - # Scenario A: success - y = 4.3 - ax.text( - 0.2, - y + 0.15, - "Scenariusz A:", - fontsize=8.5, - fontweight="bold", - ) - msg9 = ArrowCfg(label="MSG", label_fs=9) - draw_arrow( - ax, (px + pw / 2, y), (bx + bw / 2, y), msg9 - ) - draw_arrow( - ax, - (bx + bw / 2, y - 0.6), - (sx + sw / 2, y - 0.6), - msg9, - ) - draw_check(ax, (sx + sw / 2 + 0.4, y - 0.6), size=0.18) - ax.text( - sx + sw / 2 + 0.7, - y - 0.6, - "OK", - fontsize=9, - fontweight="bold", - ) - - # Scenario B: lost - y = 2.6 - ax.text( - 0.2, - y + 0.15, - "Scenariusz B:", - fontsize=8.5, - fontweight="bold", - ) - draw_arrow( - ax, (px + pw / 2, y), (bx + bw / 2, y), msg9 - ) - draw_dashed_arrow( - ax, (bx + bw / 2, y - 0.6), (7.5, y - 0.6) - ) - draw_cross(ax, (7.8, y - 0.6), size=0.2) - ax.text( - 8.2, - y - 0.55, - "UTRACONA", - fontsize=9, - fontweight="bold", - ) - ax.text( - 8.2, - y - 1.0, - "(brak retransmisji)", - fontsize=8, - style="italic", - ) - - ax.text( - 6.0, - 0.5, - "Brak ACK, brak retransmisji." - " Najszybszy. Use case:" - " logi, metryki, telemetria.", - ha="center", - va="center", - fontsize=9, - bbox={ - "boxstyle": "round,pad=0.4", - "facecolor": GRAY4, - "edgecolor": GRAY3, - }, - ) - - save(fig, "pubsub_qos_at_most_once.png") - - -# ============================================================ -# 6. At-least-once (QoS 1) -# ============================================================ -def draw_qos_at_least_once() -> None: - """Draw qos at least once.""" - fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.0)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 6.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "QoS: At-least-once" - " \u2014 \u201epowtarzaj a\u017c potwierdz\u0105\u201d" - " (\u22651 dostarczenie)", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - bx, bw = 3.5, 2.2 - sx, sw = 8.0, 2.2 - bh = 0.8 - draw_box( - ax, - (bx, 5.5), - (bw, bh), - "Broker", - BoxStyle( - fill=GRAY2, fontsize=10, fontweight="bold" - ), - ) - draw_box( - ax, - (sx, 5.5), - (sw, bh), - "Subscriber", - BoxStyle( - fill=GRAY1, fontsize=10, fontweight="bold" - ), - ) - - for xc in [bx + bw / 2, sx + sw / 2]: - ax.plot( - [xc, xc], - [5.5, 0.8], - color=GRAY3, - lw=1, - linestyle=":", - ) - - # Step 1: send MSG - y1 = 4.8 - draw_arrow( - ax, - (bx + bw / 2, y1), - (sx + sw / 2, y1), - ArrowCfg(label="MSG #1", label_fs=9), - ) - draw_check(ax, (sx + sw + 0.2, y1), size=0.15) - ax.text(sx + sw + 0.5, y1, "odebrano", fontsize=8) - - # Step 2: ACK lost - y2 = 3.9 - draw_dashed_arrow( - ax, - (sx + sw / 2, y2), - (bx + bw + 1.2, y2), - ) - ax.text( - (bx + bw / 2 + sx + sw / 2) / 2, - y2 + 0.18, - "ACK", - fontsize=9, - ) - draw_cross(ax, (bx + bw + 0.8, y2), size=0.18) - ax.text( - bx + 0.3, - y2 - 0.35, - "ACK utracony!", - fontsize=8.5, - style="italic", - ) - - # Step 3: timeout -> retry - y3 = 2.9 - ax.text( - bx + bw / 2, - y3 + 0.45, - "timeout...", - fontsize=8.5, - style="italic", - ha="center", - ) - draw_arrow( - ax, - (bx + bw / 2, y3), - (sx + sw / 2, y3), - ArrowCfg(label="MSG #1 (retry)", label_fs=9), - ) - draw_check(ax, (sx + sw + 0.2, y3), size=0.15) - ax.text( - sx + sw + 0.5, - y3, - "odebrano\n(ponownie!)", - fontsize=8, - ) - - # Step 4: ACK ok - y4 = 2.0 - draw_arrow( - ax, - (sx + sw / 2, y4), - (bx + bw / 2, y4), - ArrowCfg(label="ACK", label_fs=9), - ) - draw_check(ax, (bx + bw / 2 - 0.5, y4), size=0.18) - - # Duplicate bracket - ax.annotate( - "", - xy=(sx + sw + 1.3, y1), - xytext=(sx + sw + 1.3, y3), - arrowprops={ - "arrowstyle": "<->", - "color": "black", - "lw": 1.2, - }, - ) - ax.text( - sx + sw + 1.6, - (y1 + y3) / 2, - "DUPLIKAT!\nSubscriber\notrzyma\u0142 2x", - fontsize=9, - ha="left", - va="center", - fontweight="bold", - bbox={ - "boxstyle": "round,pad=0.25", - "facecolor": GRAY4, - "edgecolor": GRAY3, - }, - ) - - ax.text( - 6.0, - 0.5, - "Broker czeka na ACK, retransmituje" - " po timeout. Mog\u0105 by\u0107 duplikaty!\n" - "Use case: zam\u00f3wienia, p\u0142atno\u015bci" - " (subscriber musi by\u0107 idempotentny).", - ha="center", - va="center", - fontsize=9, - bbox={ - "boxstyle": "round,pad=0.4", - "facecolor": GRAY4, - "edgecolor": GRAY3, - }, - ) - - save(fig, "pubsub_qos_at_least_once.png") - - -# ============================================================ -# 7. Exactly-once (QoS 2) -# ============================================================ -def draw_qos_exactly_once() -> None: - """Draw qos exactly once.""" - fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 7) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "QoS: Exactly-once \u2014 4-krokowy" - " handshake (dok\u0142adnie 1 dostarczenie)", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - bx, bw = 2.5, 2.2 - sx, sw = 7.5, 2.2 - bh = 0.8 - draw_box( - ax, - (bx, 6.0), - (bw, bh), - "Broker", - BoxStyle( - fill=GRAY2, fontsize=10, fontweight="bold" - ), - ) - draw_box( - ax, - (sx, 6.0), - (sw, bh), - "Subscriber", - BoxStyle( - fill=GRAY1, fontsize=10, fontweight="bold" - ), - ) - - for xc in [bx + bw / 2, sx + sw / 2]: - ax.plot( - [xc, xc], - [6.0, 1.0], - color=GRAY3, - lw=1, - linestyle=":", - ) - - steps = [ - ( - 5.2, - "right", - "PUBLISH (msg_id=42)", - "Broker wysy\u0142a wiadomo\u015b\u0107", - ), - ( - 4.2, - "left", - "PUBREC (otrzyma\u0142em id=42)", - "Sub potwierdza odbi\u00f3r," - " zapisuje id", - ), - ( - 3.2, - "right", - "PUBREL (mo\u017cesz przetworzy\u0107)", - "Broker zwalnia wiadomo\u015b\u0107", - ), - ( - 2.2, - "left", - "PUBCOMP (zako\u0144czone)", - "Sub potwierdza przetworzenie", - ), - ] - - for i, (y, direction, label, desc) in enumerate( - steps - ): - ax.text( - bx + bw / 2 - 0.7, - y, - f"{i + 1}", - fontsize=9, - fontweight="bold", - ha="center", - va="center", - bbox={ - "boxstyle": "circle,pad=0.18", - "facecolor": GRAY3, - "edgecolor": LN, - }, - ) - - if direction == "right": - draw_arrow( - ax, - (bx + bw / 2, y), - (sx + sw / 2, y), - ArrowCfg(label=label, label_fs=9), - ) - else: - draw_arrow( - ax, - (sx + sw / 2, y), - (bx + bw / 2, y), - ArrowCfg(label=label, label_fs=9), - ) - - ax.text( - sx + sw + 0.3, - y, - desc, - fontsize=8, - ha="left", - va="center", - style="italic", - ) - - ax.text( - 6.0, - 0.6, - "Deduplikacja po msg_id." - " Sub nie przetwarza przed PUBREL.\n" - "Najkosztowniejszy (4 pakiety)." - " Use case: transakcje finansowe," - " krytyczne zdarzenia.", - ha="center", - va="center", - fontsize=9, - bbox={ - "boxstyle": "round,pad=0.4", - "facecolor": GRAY4, - "edgecolor": GRAY3, - }, - ) - - save(fig, "pubsub_qos_exactly_once.png") - - # ============================================================ # Main # ============================================================ if __name__ == "__main__": - logger.info( - "Generating Pub/Sub diagrams" - " (7 separate images)..." - ) + logger.info("Generating Pub/Sub diagrams" " (7 separate images)...") draw_sub_topic() draw_sub_content() draw_sub_type() diff --git a/python_pkg/praca_magisterska_video/generate_images/generate_q20_diagrams.py b/python_pkg/praca_magisterska_video/generate_images/generate_q20_diagrams.py old mode 100755 new mode 100644 index 4893ad2..26902d2 --- a/python_pkg/praca_magisterska_video/generate_images/generate_q20_diagrams.py +++ b/python_pkg/praca_magisterska_video/generate_images/generate_q20_diagrams.py @@ -2,2125 +2,62 @@ """Generate ALL diagrams for PYTANIE 20: Analityka danych strumieniowych. Monochrome, A4-printable PNGs (300 DPI). +Re-exports all diagram generators from submodules. """ from __future__ import annotations import logging -from typing import TYPE_CHECKING - -import matplotlib as mpl - -mpl.use("Agg") from pathlib import Path +import sys -import matplotlib.patches as mpatches -from matplotlib.patches import FancyBboxPatch -import matplotlib.pyplot as plt -import numpy as np +# Ensure sibling modules are importable when run as a script. +sys.path.insert(0, str(Path(__file__).resolve().parent)) -if TYPE_CHECKING: - from matplotlib.axes import Axes - from matplotlib.figure import Figure +from _q20_architectures import ( + gen_exactly_once, + gen_lambda_kappa_table, + gen_lambda_vs_kappa, + gen_spark_streaming_arch, +) +from _q20_batch_and_windows import gen_batch_vs_streaming, gen_window_types +from _q20_late_and_decisions import gen_decision_tree, gen_late_data_strategies +from _q20_platforms import ( + gen_flink_arch, + gen_kafka_streams_arch, + gen_platform_comparison, + gen_streaming_ecosystem, + gen_true_vs_microbatch, +) +from _q20_time_monitoring_sessions import ( + gen_event_vs_processing_time, + gen_session_users, + gen_sliding_sla, + gen_tumbling_fraud, +) + +__all__ = [ + "gen_batch_vs_streaming", + "gen_decision_tree", + "gen_event_vs_processing_time", + "gen_exactly_once", + "gen_flink_arch", + "gen_kafka_streams_arch", + "gen_lambda_kappa_table", + "gen_lambda_vs_kappa", + "gen_late_data_strategies", + "gen_platform_comparison", + "gen_session_users", + "gen_sliding_sla", + "gen_spark_streaming_arch", + "gen_streaming_ecosystem", + "gen_true_vs_microbatch", + "gen_tumbling_fraud", + "gen_window_types", +] _logger = logging.getLogger(__name__) -rng = np.random.default_rng(42) - -DPI = 300 -BG = "white" -LN = "black" -FS = 8 -FS_TITLE = 11 -FS_SMALL = 6.5 -FS_LABEL = 9 -OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") -Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) - -GRAY1 = "#E8E8E8" -GRAY2 = "#D0D0D0" -GRAY3 = "#B8B8B8" -GRAY4 = "#F5F5F5" -GRAY5 = "#C0C0C0" - - -def draw_box( - ax: Axes, - x: float, - y: float, - w: float, - h: float, - text: str, - *, - fill: str = "white", - lw: float = 1.2, - fontsize: float = FS, - fontweight: str = "normal", - ha: str = "center", - va: str = "center", - rounded: bool = True, - edgecolor: str = LN, - linestyle: str = "-", -) -> None: - """Draw box.""" - if rounded: - rect = FancyBboxPatch( - (x, y), - w, - h, - boxstyle="round,pad=0.05", - lw=lw, - edgecolor=edgecolor, - facecolor=fill, - linestyle=linestyle, - ) - else: - rect = mpatches.Rectangle( - (x, y), - w, - h, - lw=lw, - edgecolor=edgecolor, - facecolor=fill, - linestyle=linestyle, - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + h / 2, - text, - ha=ha, - va=va, - fontsize=fontsize, - fontweight=fontweight, - wrap=True, - ) - - -def draw_arrow( - ax: Axes, - x1: float, - y1: float, - x2: float, - y2: float, - lw: float = 1.2, - style: str = "->", - color: str = LN, -) -> None: - """Draw arrow.""" - ax.annotate( - "", - xy=(x2, y2), - xytext=(x1, y1), - arrowprops={"arrowstyle": style, "color": color, "lw": lw}, - ) - - -def save_fig(fig: Figure, name: str) -> None: - """Save fig.""" - path = str(Path(OUTPUT_DIR) / name) - fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15) - plt.close(fig) - _logger.info(" Saved: %s", path) - - -def draw_table( - ax: Axes, - headers: list[str], - rows: list[list[str]], - x0: float, - y0: float, - col_widths: list[float], - row_h: float = 0.4, - header_fill: str = GRAY2, - row_fills: list[str] | None = None, - fontsize: float = FS, - header_fontsize: float | None = None, -) -> None: - """Draw table.""" - if header_fontsize is None: - header_fontsize = fontsize - len(headers) - # Header - cx = x0 - for j, hdr in enumerate(headers): - draw_box( - ax, - cx, - y0, - col_widths[j], - row_h, - hdr, - fill=header_fill, - fontsize=header_fontsize, - fontweight="bold", - rounded=False, - ) - cx += col_widths[j] - # Rows - for i, row in enumerate(rows): - cy = y0 - (i + 1) * row_h - cx = x0 - fill = GRAY4 if (i % 2 == 0) else "white" - if row_fills and i < len(row_fills): - fill = row_fills[i] - for j, cell in enumerate(row): - fw = "bold" if j == 0 else "normal" - draw_box( - ax, - cx, - cy, - col_widths[j], - row_h, - cell, - fill=fill, - fontsize=fontsize, - fontweight=fw, - rounded=False, - ) - cx += col_widths[j] - - -# ============================================================ -# 1. Batch vs Streaming concept -# ============================================================ -def gen_batch_vs_streaming() -> None: - """Gen batch vs streaming.""" - fig, axes = plt.subplots(2, 1, figsize=(9, 5)) - fig.suptitle( - "Batch vs Streaming — dwa modele przetwarzania", - fontsize=FS_TITLE, - fontweight="bold", - ) - - # Batch - ax = axes[0] - ax.set_xlim(0, 12) - ax.set_ylim(0, 3) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("BATCH (wsadowe)", fontsize=FS_LABEL, fontweight="bold") - - # Data collected - draw_box( - ax, - 0.5, - 0.8, - 3.0, - 1.4, - "Zbierz WSZYSTKIE\ndane\n(godziny / dni)", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 3.5, 1.5, 4.5, 1.5, lw=2) - draw_box( - ax, - 4.5, - 0.8, - 2.5, - 1.4, - "Analiza\n(batch job)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 7.0, 1.5, 8.0, 1.5, lw=2) - draw_box( - ax, - 8.0, - 0.8, - 2.5, - 1.4, - "Wynik\n(jednorazowy)", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - ax.text(11.0, 1.5, "min-h", fontsize=FS, va="center", fontweight="bold") - - # Streaming - ax = axes[1] - ax.set_xlim(0, 12) - ax.set_ylim(0, 3) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("STREAMING (strumieniowe)", fontsize=FS_LABEL, fontweight="bold") - - # Events flowing - events_x = [0.5, 1.5, 2.5, 3.5] - for i, ex in enumerate(events_x): - draw_box( - ax, - ex, - 1.0, - 0.8, - 0.8, - f"e{i + 1}", - fill=GRAY4, - fontsize=FS, - fontweight="bold", - rounded=False, - ) - if i < len(events_x) - 1: - draw_arrow(ax, ex + 0.8, 1.4, ex + 1.0, 1.4, lw=1) - - ax.text(4.8, 1.4, "...", fontsize=FS_LABEL, va="center") - draw_arrow(ax, 5.2, 1.4, 5.8, 1.4, lw=2) - - draw_box( - ax, - 5.8, - 0.8, - 2.8, - 1.4, - "Analiza\nCIĄGŁA\n(event-by-event)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 8.6, 1.5, 9.3, 1.5, lw=2) - draw_box( - ax, - 9.3, - 0.8, - 2.0, - 1.4, - "Wyniki\nciągłe", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - ax.text(11.5, 0.5, "ms-s", fontsize=FS, va="center", fontweight="bold") - - # Arrow marking infinity - ax.annotate( - "", - xy=(0.2, 1.4), - xytext=(-0.3, 1.4), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - ax.text(0.0, 2.3, "∞ zdarzeń", fontsize=FS_SMALL, ha="center", style="italic") - - fig.tight_layout(rect=[0, 0, 1, 0.92]) - save_fig(fig, "q20_batch_vs_streaming.png") - - -# ============================================================ -# 2. All 4 window types (TSSG) -# ============================================================ -def _draw_tumbling_window(ax: Axes, events: list[int]) -> None: - """Draw tumbling window section.""" - ax.set_xlim(0, 14) - ax.set_ylim(0, 4) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Tumbling Window (okno przerzutne) — rozłączne, stały rozmiar", - fontsize=FS_LABEL, - fontweight="bold", - ) - - # Time axis - ax.annotate( - "", - xy=(13.5, 1.0), - xytext=(0.3, 1.0), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center") - - # Events - for i, e in enumerate(events): - x = 1.0 + i * 1.0 - ax.plot(x, 1.0, "ko", markersize=5) - ax.text(x, 0.5, f"e{e}", fontsize=FS_SMALL, ha="center") - - # Windows - colors_w = [GRAY1, GRAY3, GRAY1, GRAY3] - for w in range(4): - x_start = 1.0 + w * 3.0 - 0.3 - rect = mpatches.FancyBboxPatch( - (x_start, 1.5), - 3.0, - 1.2, - boxstyle="round,pad=0.1", - facecolor=colors_w[w], - edgecolor=LN, - lw=1.5, - ) - ax.add_patch(rect) - ax.text( - x_start + 1.5, - 2.1, - f"Okno {w + 1}", - fontsize=FS, - ha="center", - fontweight="bold", - ) - # Braces down to events - for j in range(3): - ex = 1.0 + w * 3.0 + j * 1.0 - ax.plot([ex, ex], [1.0, 1.5], color=LN, lw=0.8, linestyle="--") - - ax.text( - 7.0, - 3.2, - "Każde zdarzenie → DOKŁADNIE 1 okno. Zero nakładania.", - fontsize=FS, - ha="center", - style="italic", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, - ) - - - -def _draw_sliding_window(ax: Axes, events: list[int]) -> None: - """Draw sliding window section.""" - ax.set_xlim(0, 14) - ax.set_ylim(0, 5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Sliding Window (okno przesuwne) — nakładające, stały rozmiar + krok", - fontsize=FS_LABEL, - fontweight="bold", - ) - - ax.annotate( - "", - xy=(13.5, 1.0), - xytext=(0.3, 1.0), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center") - - for i, e in enumerate(events[:8]): - x = 1.0 + i * 1.0 - ax.plot(x, 1.0, "ko", markersize=5) - ax.text(x, 0.5, f"e{e}", fontsize=FS_SMALL, ha="center") - - # Sliding windows: size=4, slide=2 - slide_colors = [GRAY1, GRAY2, GRAY3] - for w in range(3): - x_start = 0.7 + w * 2.0 - y_base = 1.5 + w * 0.9 - rect = mpatches.FancyBboxPatch( - (x_start, y_base), - 4.0, - 0.7, - boxstyle="round,pad=0.08", - facecolor=slide_colors[w], - edgecolor=LN, - lw=1.5, - alpha=0.7, - ) - ax.add_patch(rect) - ax.text( - x_start + 2.0, - y_base + 0.35, - f"Okno {w + 1} (size=4)", - fontsize=FS_SMALL, - ha="center", - fontweight="bold", - ) - - ax.text( - 10.5, - 3.5, - "krok=2\nNakładanie!\ne3,e4 → w oknie 1 i 2", - fontsize=FS, - ha="center", - style="italic", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, - ) - - - -def _draw_session_window(ax: Axes) -> None: - """Draw session window section.""" - ax.set_xlim(0, 14) - ax.set_ylim(0, 4) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Session Window (okno sesji) — dynamiczny rozmiar, gap = przerwa", - fontsize=FS_LABEL, - fontweight="bold", - ) - - ax.annotate( - "", - xy=(13.5, 1.0), - xytext=(0.3, 1.0), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center") - - # Cluster 1: events close together - cluster1 = [1.0, 1.8, 2.3, 3.0] - for x in cluster1: - ax.plot(x, 1.0, "ko", markersize=5) - - # Gap - ax.annotate( - "", - xy=(7.0, 0.7), - xytext=(4.0, 0.7), - arrowprops={"arrowstyle": "<->", "lw": 1, "color": LN}, - ) - ax.text( - 5.5, - 0.3, - "GAP > timeout", - fontsize=FS, - ha="center", - fontweight="bold", - style="italic", - ) - - # Cluster 2 - cluster2 = [8.0, 8.8, 9.5] - for x in cluster2: - ax.plot(x, 1.0, "ko", markersize=5) - - # Session boxes - rect1 = mpatches.FancyBboxPatch( - (0.7, 1.4), - 2.6, - 1.0, - boxstyle="round,pad=0.1", - facecolor=GRAY1, - edgecolor=LN, - lw=1.5, - ) - ax.add_patch(rect1) - ax.text( - 2.0, 1.9, "Sesja 1\n(4 zdarzenia)", fontsize=FS, ha="center", fontweight="bold" - ) - - rect2 = mpatches.FancyBboxPatch( - (7.7, 1.4), - 2.1, - 1.0, - boxstyle="round,pad=0.1", - facecolor=GRAY3, - edgecolor=LN, - lw=1.5, - ) - ax.add_patch(rect2) - ax.text( - 8.75, 1.9, "Sesja 2\n(3 zdarzenia)", fontsize=FS, ha="center", fontweight="bold" - ) - - ax.text( - 5.5, - 3.0, - "Nowa sesja po przerwie > gap", - fontsize=FS, - ha="center", - style="italic", - ) - - - -def _draw_global_window(ax: Axes) -> None: - """Draw global window section.""" - ax.set_xlim(0, 14) - ax.set_ylim(0, 4) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Global Window — jedno okno na cały strumień + trigger", - fontsize=FS_LABEL, - fontweight="bold", - ) - - ax.annotate( - "", - xy=(13.5, 1.0), - xytext=(0.3, 1.0), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center") - - for i in range(12): - x = 1.0 + i * 1.0 - ax.plot(x, 1.0, "ko", markersize=5) - - # One big window - rect = mpatches.FancyBboxPatch( - (0.5, 1.4), - 12.5, - 1.0, - boxstyle="round,pad=0.1", - facecolor=GRAY1, - edgecolor=LN, - lw=2, - ) - ax.add_patch(rect) - ax.text( - 6.75, - 1.9, - "GLOBAL WINDOW (cały strumień)", - fontsize=FS, - ha="center", - fontweight="bold", - ) - - # Trigger markers - for tx in [4.0, 8.0, 12.0]: - ax.plot([tx, tx], [1.4, 2.4], color=LN, lw=2, linestyle="--") - ax.text( - tx, - 2.7, - "EMIT", - fontsize=FS_SMALL, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY3, "edgecolor": LN}, - ) - - ax.text( - 6.75, - 3.3, - "Trigger decyduje kiedy emitować (np. co N zdarzeń)", - fontsize=FS, - ha="center", - style="italic", - ) - - - -def gen_window_types() -> None: - """Gen window types.""" - fig, axes = plt.subplots(4, 1, figsize=(9, 10)) - fig.suptitle("4 typy okien — TSSG", fontsize=FS_TITLE, fontweight="bold") - - events = list(range(1, 13)) - - _draw_tumbling_window(axes[0], events) - _draw_sliding_window(axes[1], events) - _draw_session_window(axes[2]) - _draw_global_window(axes[3]) - - fig.tight_layout(rect=[0, 0, 1, 0.94]) - save_fig(fig, "q20_window_types.png") - -# ============================================================ -# 3. Event Time vs Processing Time scatter + watermark -# ============================================================ -def gen_event_vs_processing_time() -> None: - """Gen event vs processing time.""" - fig, axes = plt.subplots(1, 2, figsize=(11, 5)) - fig.suptitle( - "Event Time vs Processing Time + Watermark", - fontsize=FS_TITLE, - fontweight="bold", - ) - - # --- Panel 1: Ideal vs Real --- - ax = axes[0] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.set_aspect("equal") - ax.set_xlabel("Event Time", fontsize=FS_LABEL) - ax.set_ylabel("Processing Time", fontsize=FS_LABEL) - ax.set_title("Idealny vs Realny świat", fontsize=FS_LABEL, fontweight="bold") - ax.set_xticks([]) - ax.set_yticks([]) - - # Ideal line - ax.plot([0, 9], [0, 9], "k--", lw=1.5, label="ideał (brak opóźnień)") - - # Real scattered points (processing >= event, some out of order) - event_times = np.sort(rng.uniform(1, 8, 15)) - proc_times = event_times + rng.exponential(0.5, 15) - # Make some out of order - idx = [3, 7, 11] - for i in idx: - proc_times[i] += 1.5 - - ax.scatter( - event_times, proc_times, c="black", s=30, zorder=5, label="zdarzenia (realne)" - ) - - # Highlight out-of-order - for i in idx: - ax.annotate( - "out-of-order", - xy=(event_times[i], proc_times[i]), - xytext=(event_times[i] + 0.8, proc_times[i] + 0.5), - fontsize=FS_SMALL, - ha="left", - arrowprops={"arrowstyle": "->", "lw": 0.8, "color": "#555"}, - ) - - ax.legend(fontsize=FS_SMALL, loc="upper left") - ax.text( - 7, - 2, - "Opóźnienie\nsieciowe ↑", - fontsize=FS, - ha="center", - style="italic", - color="#555", - ) - - # --- Panel 2: Watermark concept --- - ax = axes[1] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.set_aspect("equal") - ax.set_xlabel("Event Time", fontsize=FS_LABEL) - ax.set_ylabel("Processing Time", fontsize=FS_LABEL) - ax.set_title("Watermark — granica postępu", fontsize=FS_LABEL, fontweight="bold") - ax.set_xticks([]) - ax.set_yticks([]) - - # Events - ax.scatter(event_times, proc_times, c="black", s=30, zorder=5) - - # Watermark line (below most points, tracks progress) - wm_x = np.linspace(0, 9, 50) - wm_y = wm_x + 0.3 # watermark slightly above ideal - ax.plot(wm_x, wm_y, "k-", lw=2.5, label="Watermark") - ax.fill_between(wm_x, 0, wm_y, alpha=0.15, color="gray") - - ax.text( - 2.0, - 1.0, - 'PONIŻEJ watermark:\n„na pewno dotarło"', - fontsize=FS, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, - ) - - # Late event - late_x, late_y = event_times[7], proc_times[7] - ax.scatter( - [late_x], [late_y], c="white", s=80, zorder=6, edgecolors="black", linewidths=2 - ) - ax.annotate( - "LATE DATA!\n(po watermarku)", - xy=(late_x, late_y), - xytext=(late_x + 1.2, late_y + 0.8), - fontsize=FS_SMALL, - ha="left", - fontweight="bold", - arrowprops={"arrowstyle": "->", "lw": 1, "color": LN}, - ) - - ax.legend(fontsize=FS_SMALL, loc="upper left") - - fig.tight_layout(rect=[0, 0, 1, 0.92]) - save_fig(fig, "q20_event_vs_processing_time.png") - - -# ============================================================ -# 4. Tumbling window example — fraud detection -# ============================================================ -def gen_tumbling_fraud() -> None: - """Gen tumbling fraud.""" - fig, ax = plt.subplots(figsize=(9, 4)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 5.5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Tumbling Window — fraud detection (okno = 1 min)", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Time axis - ax.annotate( - "", - xy=(11.5, 1.0), - xytext=(0.5, 1.0), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - ax.text(6.0, 0.4, "czas", fontsize=FS, ha="center") - - # Window 1: normal - draw_box(ax, 1.0, 1.5, 4.5, 3.0, "", fill=GRAY4, rounded=True, lw=2) - ax.text( - 3.25, 4.2, "[14:00 — 14:01]", fontsize=FS_LABEL, ha="center", fontweight="bold" - ) - # Transactions - txns1 = ["Sklep A: 50 zł", "Sklep B: 30 zł", "Stacja: 80 zł"] - for i, t in enumerate(txns1): - draw_box( - ax, - 1.3, - 3.3 - i * 0.55, - 4.0, - 0.45, - t, - fill=GRAY1, - fontsize=FS_SMALL, - rounded=False, - ) - ax.text( - 3.25, - 1.7, - "count = 3 → OK", - fontsize=FS, - ha="center", - fontweight="bold", - color="#2E7D32", - bbox={ - "boxstyle": "round,pad=0.15", - "facecolor": "#E8F5E9", - "edgecolor": "#2E7D32", - }, - ) - - # Window 2: fraud! - draw_box(ax, 6.0, 1.5, 4.5, 3.0, "", fill=GRAY1, rounded=True, lw=2) - ax.text( - 8.25, 4.2, "[14:01 — 14:02]", fontsize=FS_LABEL, ha="center", fontweight="bold" - ) - txns2 = ["ATM Warszawa: 500 zł", "ATM Kraków: 500 zł", "... +45 transakcji"] - for i, t in enumerate(txns2): - draw_box( - ax, - 6.3, - 3.3 - i * 0.55, - 4.0, - 0.45, - t, - fill=GRAY3, - fontsize=FS_SMALL, - rounded=False, - ) - ax.text( - 8.25, - 1.7, - "count = 47 → ALERT!", - fontsize=FS, - ha="center", - fontweight="bold", - color="#C62828", - bbox={ - "boxstyle": "round,pad=0.15", - "facecolor": "#F8D7DA", - "edgecolor": "#C62828", - }, - ) - - save_fig(fig, "q20_tumbling_fraud.png") - - -# ============================================================ -# 5. Sliding window — SLA monitoring -# ============================================================ -def gen_sliding_sla() -> None: - """Gen sliding sla.""" - fig, ax = plt.subplots(figsize=(9, 4.5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 6) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Sliding Window — monitoring SLA (okno=5min, krok=1min)", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Time axis - ax.annotate( - "", - xy=(11.5, 0.5), - xytext=(0.5, 0.5), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - times = ["14:05", "14:06", "14:07", "14:08", "14:09"] - latencies = [120, 180, 340, 290, 150] - sla = 200 - - for i, (t, lat) in enumerate(zip(times, latencies, strict=False)): - x = 1.5 + i * 2.0 - ax.text(x, 0.1, t, fontsize=FS, ha="center") - - # Bar proportional to latency - bar_h = lat / 100.0 - is_breach = lat > sla - fill = "#F8D7DA" if is_breach else GRAY1 - edge = "#C62828" if is_breach else LN - draw_box( - ax, - x - 0.5, - 1.0, - 1.0, - bar_h, - "", - fill=fill, - rounded=False, - edgecolor=edge, - lw=1.5, - ) - ax.text( - x, - 1.0 + bar_h + 0.15, - f"{lat}ms", - fontsize=FS, - ha="center", - fontweight="bold", - color="#C62828" if is_breach else LN, - ) - - # Status - status = "ALERT!" if is_breach else "OK" - ax.text( - x, - 1.0 + bar_h + 0.55, - status, - fontsize=FS_SMALL, - ha="center", - fontweight="bold", - color="#C62828" if is_breach else "#2E7D32", - ) - - # SLA line - sla_y = 1.0 + sla / 100.0 - ax.plot([0.8, 11.2], [sla_y, sla_y], "k--", lw=1.5) - ax.text(11.3, sla_y, f"SLA={sla}ms", fontsize=FS, va="center", fontweight="bold") - - # Sliding window bracket - ax.annotate( - "", - xy=(1.0, 5.3), - xytext=(5.0, 5.3), - arrowprops={"arrowstyle": "<->", "lw": 1.5, "color": LN}, - ) - ax.text(3.0, 5.6, "okno = 5 min", fontsize=FS, ha="center", fontweight="bold") - - ax.annotate( - "", - xy=(3.0, 4.8), - xytext=(5.0, 4.8), - arrowprops={"arrowstyle": "<->", "lw": 1, "color": "#555"}, - ) - ax.text( - 4.0, - 4.4, - "krok = 1 min\n(nakładanie!)", - fontsize=FS_SMALL, - ha="center", - style="italic", - ) - - save_fig(fig, "q20_sliding_sla.png") - - -# ============================================================ -# 6. Session window — user sessions -# ============================================================ -def gen_session_users() -> None: - """Gen session users.""" - fig, axes = plt.subplots(2, 1, figsize=(10, 5)) - fig.suptitle( - "Session Window — sesje użytkowników (gap = 30 min)", - fontsize=FS_TITLE, - fontweight="bold", - ) - - # Anna: 2 sessions - ax = axes[0] - ax.set_xlim(0, 14) - ax.set_ylim(0, 3.5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Użytkownik Anna", fontsize=FS_LABEL, fontweight="bold") - - ax.annotate( - "", - xy=(13.5, 1.0), - xytext=(0.3, 1.0), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - - # Clicks cluster 1 - for x in [1.0, 1.8, 2.5, 3.2]: - ax.plot(x, 1.0, "ko", markersize=6) - # Clicks cluster 2 - for x in [9.0, 9.8, 10.5]: - ax.plot(x, 1.0, "ko", markersize=6) - - # Sessions - rect1 = mpatches.FancyBboxPatch( - (0.7, 1.5), - 2.8, - 1.2, - boxstyle="round,pad=0.1", - facecolor=GRAY1, - edgecolor=LN, - lw=1.5, - ) - ax.add_patch(rect1) - ax.text( - 2.1, - 2.1, - "Sesja 1\n4 kliknięcia, 12 min", - fontsize=FS, - ha="center", - fontweight="bold", - ) - - rect2 = mpatches.FancyBboxPatch( - (8.7, 1.5), - 2.1, - 1.2, - boxstyle="round,pad=0.1", - facecolor=GRAY3, - edgecolor=LN, - lw=1.5, - ) - ax.add_patch(rect2) - ax.text( - 9.75, - 2.1, - "Sesja 2\n3 kliknięcia, 8 min", - fontsize=FS, - ha="center", - fontweight="bold", - ) - - # Gap - ax.annotate( - "", - xy=(8.5, 0.5), - xytext=(3.8, 0.5), - arrowprops={"arrowstyle": "<->", "lw": 1.5, "color": LN}, - ) - ax.text( - 6.15, - 0.1, - "cisza 45 min > gap(30)", - fontsize=FS, - ha="center", - fontweight="bold", - style="italic", - ) - - # Bob: 1 session - ax = axes[1] - ax.set_xlim(0, 14) - ax.set_ylim(0, 3.5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Użytkownik Bob", fontsize=FS_LABEL, fontweight="bold") - - ax.annotate( - "", - xy=(13.5, 1.0), - xytext=(0.3, 1.0), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - - # Clicks spread evenly - bobs = [1.0, 2.5, 4.0, 5.5, 7.0, 8.5, 10.0] - for x in bobs: - ax.plot(x, 1.0, "ko", markersize=6) - - rect = mpatches.FancyBboxPatch( - (0.7, 1.5), - 9.6, - 1.2, - boxstyle="round,pad=0.1", - facecolor=GRAY1, - edgecolor=LN, - lw=2, - ) - ax.add_patch(rect) - ax.text( - 5.5, - 2.1, - "Sesja 1 (ciągła) — 7 kliknięć, każde < 30 min od poprzedniego", - fontsize=FS, - ha="center", - fontweight="bold", - ) - - fig.tight_layout(rect=[0, 0, 1, 0.92]) - save_fig(fig, "q20_session_users.png") - - -# ============================================================ -# 7. Streaming ecosystem overview -# ============================================================ -def gen_streaming_ecosystem() -> None: - """Gen streaming ecosystem.""" - fig, ax = plt.subplots(figsize=(10, 5.5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Ekosystem przetwarzania strumieniowego", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Source - draw_box( - ax, - 0.3, - 2.5, - 2.0, - 3.0, - "Kafka\nTopics\n(źródło)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - # Engines - engines = [ - ("Kafka Streams\n(library w JVM)", GRAY1, 4.7), - ("Apache Flink\n(klaster)", GRAY3, 3.2), - ("Spark Streaming\n(klaster)", GRAY5, 1.7), - ] - for label, color, y in engines: - draw_box( - ax, 4.0, y, 3.0, 1.2, label, fill=color, fontsize=FS, fontweight="bold" - ) - draw_arrow(ax, 2.3, 4.0, 4.0, y + 0.6, lw=1.5) - - # Sinks - sinks = [ - ("Kafka topic\n/ baza danych", GRAY4, 4.7), - ("DB / Kafka\n/ S3", GRAY4, 3.2), - ("HDFS / DB\n/ dashboard", GRAY4, 1.7), - ] - for label, color, y in sinks: - draw_box(ax, 8.5, y, 2.5, 1.2, label, fill=color, fontsize=FS) - draw_arrow(ax, 7.0, y + 0.6, 8.5, y + 0.6, lw=1.5) - - # Labels - ax.text(1.3, 6.0, "ŹRÓDŁO", fontsize=FS_LABEL, ha="center", fontweight="bold") - ax.text(5.5, 6.2, "SILNIK", fontsize=FS_LABEL, ha="center", fontweight="bold") - ax.text(9.75, 6.2, "WYNIK", fontsize=FS_LABEL, ha="center", fontweight="bold") - - # Latency annotations - ax.text(5.5, 5.95, "~1-10 ms", fontsize=FS_SMALL, ha="center", style="italic") - ax.text(5.5, 4.5, "<10 ms", fontsize=FS_SMALL, ha="center", style="italic") - ax.text(5.5, 3.0, "~100 ms", fontsize=FS_SMALL, ha="center", style="italic") - - save_fig(fig, "q20_streaming_ecosystem.png") - - -# ============================================================ -# 8. True streaming vs Micro-batch -# ============================================================ -def gen_true_vs_microbatch() -> None: - """Gen true vs microbatch.""" - fig, axes = plt.subplots(2, 1, figsize=(10, 5.5)) - fig.suptitle("True Streaming vs Micro-Batch", fontsize=FS_TITLE, fontweight="bold") - - # True streaming - ax = axes[0] - ax.set_xlim(0, 12) - ax.set_ylim(0, 3.5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "TRUE STREAMING (Flink, Kafka Streams) — event-by-event", - fontsize=FS_LABEL, - fontweight="bold", - ) - - for i in range(6): - x = 1.0 + i * 1.8 - # Event - draw_box( - ax, - x, - 2.0, - 0.8, - 0.7, - f"e{i + 1}", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - rounded=False, - ) - # Arrow down - draw_arrow(ax, x + 0.4, 2.0, x + 0.4, 1.4, lw=1) - # Result - draw_box( - ax, - x, - 0.5, - 0.8, - 0.7, - f"r{i + 1}", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - rounded=False, - ) - # Latency label - ax.text(x + 0.4, 1.6, "~ms", fontsize=5, ha="center", color="#555") - - ax.text( - 11.5, - 1.3, - "Latencja:\n< 10 ms", - fontsize=FS, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, - ) - - # Micro-batch - ax = axes[1] - ax.set_xlim(0, 12) - ax.set_ylim(0, 3.5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "MICRO-BATCH (Spark Streaming) — grupami co ~100ms", - fontsize=FS_LABEL, - fontweight="bold", - ) - - batch_colors = [GRAY1, GRAY2, GRAY3] - for b in range(3): - bx = 0.8 + b * 3.5 - # Batch boundary - draw_box(ax, bx, 1.8, 3.0, 1.0, "", fill=batch_colors[b], rounded=True, lw=1.5) - ax.text( - bx + 1.5, 2.6, f"Batch {b + 1}", fontsize=FS, ha="center", fontweight="bold" - ) - for j in range(3): - ex = bx + 0.3 + j * 0.9 - draw_box( - ax, - ex, - 2.0, - 0.7, - 0.5, - f"e{b * 3 + j + 1}", - fill="white", - fontsize=FS_SMALL, - rounded=False, - ) - - # Arrow down - draw_arrow(ax, bx + 1.5, 1.8, bx + 1.5, 1.2, lw=1.5) - # Result - draw_box( - ax, - bx + 0.5, - 0.4, - 2.0, - 0.7, - f"result {b + 1}", - fill=GRAY4, - fontsize=FS, - fontweight="bold", - ) - - ax.text( - 11.5, - 1.3, - "Latencja:\n~100ms-s", - fontsize=FS, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, - ) - - fig.tight_layout(rect=[0, 0, 1, 0.92]) - save_fig(fig, "q20_true_vs_microbatch.png") - - -# ============================================================ -# 9. Platform comparison table -# ============================================================ -def gen_platform_comparison() -> None: - """Gen platform comparison.""" - fig, ax = plt.subplots(figsize=(9, 5)) - ax.set_xlim(0, 11.5) - ax.set_ylim(-6, 1) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Porównanie platform strumieniowych", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - headers = ["Cecha", "Kafka Streams", "Apache Flink", "Spark Streaming"] - col_w = [2.5, 2.8, 2.8, 2.8] - rows = [ - ["Model", "event-by-event", "event-by-event", "micro-batch (~100ms)"], - ["Deployment", "library (w JVM)", "klaster", "klaster"], - ["Latencja", "~1-10 ms", "< 10 ms", "100 ms - sekundy"], - ["Exactly-once", "Kafka TXN", "checkpointing", "WAL"], - ["State", "RocksDB local", "RocksDB + ckpt", "in-memory / ext"], - ["Okna", "T, S, Session", "wszystkie + custom", "T, S"], - ["Use case", "Kafka → Kafka", "złożona analityka", "ETL + ML / SQL"], - ] - draw_table( - ax, - headers, - rows, - x0=0.25, - y0=0.5, - col_widths=col_w, - row_h=0.6, - fontsize=7, - header_fontsize=8, - ) - - save_fig(fig, "q20_platform_comparison.png") - - -# ============================================================ -# 10. Kafka Streams architecture -# ============================================================ -def gen_kafka_streams_arch() -> None: - """Gen kafka streams arch.""" - fig, ax = plt.subplots(figsize=(9, 5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Kafka Streams — architektura (library w JVM)", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Outer box: Your Java application - draw_box(ax, 0.5, 0.5, 11.0, 5.5, "", fill=GRAY4, rounded=True, lw=2.5) - ax.text( - 6.0, - 5.7, - "Twoja aplikacja Java (JVM)", - fontsize=FS_LABEL, - ha="center", - fontweight="bold", - ) - - # Kafka Consumer - draw_box( - ax, - 1.0, - 3.0, - 2.5, - 1.5, - "Kafka\nConsumer\n(input topic)", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - ) - - # Processing - draw_box( - ax, - 4.5, - 3.0, - 2.5, - 1.5, - "Kafka Streams\n(logika\nbiznesowa)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - # Kafka Producer - draw_box( - ax, - 8.0, - 3.0, - 2.5, - 1.5, - "Kafka\nProducer\n(output topic)", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - ) - - # Arrows - draw_arrow(ax, 3.5, 3.75, 4.5, 3.75, lw=2) - draw_arrow(ax, 7.0, 3.75, 8.0, 3.75, lw=2) - - # RocksDB state store - draw_box( - ax, - 4.5, - 1.0, - 2.5, - 1.3, - "RocksDB\n(stan lokalny)", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - ax.plot([5.75, 5.75], [3.0, 2.3], color=LN, lw=1.5) - ax.text( - 7.3, - 1.6, - "okna, joiny,\nagregacje", - fontsize=FS_SMALL, - style="italic", - va="center", - ) - - # Key message - ax.text( - 6.0, - 0.2, - "NIE potrzebujesz osobnego klastra! Skalujesz = więcej instancji JVM.", - fontsize=FS, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "edgecolor": LN}, - ) - - save_fig(fig, "q20_kafka_streams_arch.png") - - -# ============================================================ -# 11. Flink architecture + checkpointing -# ============================================================ -def gen_flink_arch() -> None: - """Gen flink arch.""" - fig, ax = plt.subplots(figsize=(9, 6)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 8) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Apache Flink — architektura klastra + checkpointing", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Cluster border - draw_box(ax, 0.3, 1.0, 11.4, 6.2, "", fill=GRAY4, rounded=True, lw=2.5) - ax.text( - 6.0, 6.95, "FLINK CLUSTER", fontsize=FS_LABEL, ha="center", fontweight="bold" - ) - - # Job Manager - draw_box( - ax, - 1.0, - 5.5, - 3.0, - 1.2, - "Job Manager\n(koordynacja,\ncheckpointy)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - # Task Managers - draw_box(ax, 1.0, 3.0, 10.0, 2.0, "", fill="white", rounded=True, lw=1.5) - ax.text( - 6.0, 4.7, "Task Managers (workery)", fontsize=FS, ha="center", fontweight="bold" - ) - - slots = ["source\n& map()", "map()", "window()\n& reduce", "sink()"] - for i, s in enumerate(slots): - x = 1.5 + i * 2.4 - draw_box( - ax, - x, - 3.3, - 2.0, - 1.2, - f"Slot {i + 1}\n{s}", - fill=GRAY1, - fontsize=FS_SMALL, - fontweight="bold", - ) - - draw_arrow(ax, 2.5, 5.5, 6.0, 5.0, lw=1.5, style="->") - ax.text(5.0, 5.5, "przydziela\npodzadania", fontsize=FS_SMALL, style="italic") - - # Checkpoint storage - draw_box( - ax, - 5.5, - 1.2, - 3.5, - 1.2, - "Checkpoint Storage\n(HDFS / S3)", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - ax.plot([7.25, 7.25], [2.4, 3.3], color=LN, lw=1.5, linestyle="--") - ax.text(8.0, 2.7, "snapshoty\nstanu", fontsize=FS_SMALL, style="italic") - - # Barrier concept at bottom - ax.text(3.0, 1.6, "Barrier:", fontsize=FS, fontweight="bold") - barrier_boxes = ["source", "|B|", "map", "|B|", "sink"] - bx = 0.8 - for _i, b in enumerate(barrier_boxes): - if b == "|B|": - ax.text( - bx + 0.3, - 1.5, - b, - fontsize=FS, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY5, "edgecolor": LN}, - ) - draw_arrow(ax, bx, 1.5, bx + 0.1, 1.5, lw=1) - bx += 0.7 - else: - draw_box( - ax, - bx, - 1.3, - 1.0, - 0.45, - b, - fill=GRAY1, - fontsize=FS_SMALL, - fontweight="bold", - ) - bx += 1.2 - - save_fig(fig, "q20_flink_arch.png") - - -# ============================================================ -# 12. Spark Streaming architecture -# ============================================================ -def gen_spark_streaming_arch() -> None: - """Gen spark streaming arch.""" - fig, ax = plt.subplots(figsize=(9, 5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Spark Streaming — architektura (micro-batch)", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Cluster border - draw_box(ax, 0.3, 0.5, 11.4, 5.8, "", fill=GRAY4, rounded=True, lw=2.5) - ax.text( - 6.0, 6.0, "SPARK CLUSTER", fontsize=FS_LABEL, ha="center", fontweight="bold" - ) - - # Driver - draw_box( - ax, - 1.0, - 4.5, - 3.0, - 1.2, - "Driver\n(planuje mini-batche)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - draw_arrow(ax, 2.5, 4.5, 6.0, 4.0, lw=1.5) - - # Batches - batches = ["batch 1\n(e1,e2,e3)", "batch 2\n(e4,e5,e6)", "batch 3\n(e7,e8,e9)"] - for i, b in enumerate(batches): - y = 2.8 - i * 1.0 - draw_box( - ax, 4.5, y, 2.5, 0.8, b, fill=GRAY1, fontsize=FS_SMALL, fontweight="bold" - ) - # map → reduce - draw_arrow(ax, 7.0, y + 0.4, 7.5, y + 0.4, lw=1) - draw_box(ax, 7.5, y, 1.3, 0.8, "map→\nreduce", fill=GRAY3, fontsize=5.5) - draw_arrow(ax, 8.8, y + 0.4, 9.3, y + 0.4, lw=1) - draw_box( - ax, 9.3, y, 1.5, 0.8, f"result {i + 1}", fill="white", fontsize=FS_SMALL - ) - - # Spark ecosystem - draw_box( - ax, - 1.0, - 1.0, - 3.0, - 1.0, - "Spark SQL / MLlib\n(ten sam ekosystem!)", - fill=GRAY5, - fontsize=FS, - fontweight="bold", - ) - - ax.text( - 6.0, - 0.3, - "ZALETA: batch API | WADA: latencja ≥ batch interval (~100ms)", - fontsize=FS, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "edgecolor": LN}, - ) - - save_fig(fig, "q20_spark_streaming_arch.png") - - -# ============================================================ -# 13. Lambda vs Kappa architecture -# ============================================================ -def gen_lambda_vs_kappa() -> None: - """Gen lambda vs kappa.""" - fig, axes = plt.subplots(2, 1, figsize=(10, 7)) - fig.suptitle("Architektura Lambda vs Kappa", fontsize=FS_TITLE, fontweight="bold") - - # --- Lambda --- - ax = axes[0] - ax.set_xlim(0, 12) - ax.set_ylim(0, 5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "LAMBDA — 2 ścieżki (batch + speed)", fontsize=FS_LABEL, fontweight="bold" - ) - - # Source - draw_box( - ax, - 0.3, - 1.8, - 2.0, - 1.5, - "Źródło\ndanych", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - # Batch layer (top) - draw_box( - ax, - 3.5, - 3.3, - 3.0, - 1.2, - "Batch Layer\n(Spark)\nprzelicza co godzinę", - fill=GRAY1, - fontsize=FS_SMALL, - fontweight="bold", - ) - draw_arrow(ax, 2.3, 3.0, 3.5, 3.9, lw=1.5) - - # Speed layer (bottom) - draw_box( - ax, - 3.5, - 0.8, - 3.0, - 1.2, - "Speed Layer\n(Flink)\nreal-time", - fill=GRAY3, - fontsize=FS_SMALL, - fontweight="bold", - ) - draw_arrow(ax, 2.3, 2.2, 3.5, 1.4, lw=1.5) - - # Results - draw_box( - ax, - 7.5, - 3.3, - 2.0, - 1.2, - "Dokładne\nwyniki\n(wolne)", - fill=GRAY4, - fontsize=FS_SMALL, - ) - draw_arrow(ax, 6.5, 3.9, 7.5, 3.9, lw=1.5) - - draw_box( - ax, - 7.5, - 0.8, - 2.0, - 1.2, - "Przybliżone\nwyniki\n(szybkie)", - fill=GRAY4, - fontsize=FS_SMALL, - ) - draw_arrow(ax, 6.5, 1.4, 7.5, 1.4, lw=1.5) - - # Merge - draw_box( - ax, - 10.0, - 2.0, - 1.5, - 1.5, - "MERGE\n→ UI", - fill=GRAY5, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 9.5, 3.5, 10.0, 3.0, lw=1.5) - draw_arrow(ax, 9.5, 1.8, 10.0, 2.5, lw=1.5) - - ax.text( - 6.0, - 0.1, - "2 systemy, 2 kody — złożone ale pewne", - fontsize=FS, - ha="center", - style="italic", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, - ) - - # --- Kappa --- - ax = axes[1] - ax.set_xlim(0, 12) - ax.set_ylim(0, 4) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "KAPPA — 1 ścieżka (streaming only)", fontsize=FS_LABEL, fontweight="bold" - ) - - # Source - draw_box( - ax, - 0.3, - 1.3, - 2.0, - 1.5, - "Źródło\ndanych", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - # Single streaming layer - draw_box( - ax, - 3.5, - 1.3, - 3.5, - 1.5, - "Streaming Layer\n(Flink)\n+ replay z Kafka log", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 2.3, 2.05, 3.5, 2.05, lw=2) - - # Output - draw_box( - ax, - 8.0, - 1.3, - 2.5, - 1.5, - "Wyniki\n→ UI", - fill=GRAY4, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 7.0, 2.05, 8.0, 2.05, lw=2) - - # Replay arrow - ax.annotate( - "", - xy=(3.5, 1.0), - xytext=(7.0, 1.0), - arrowprops={ - "arrowstyle": "<-", - "lw": 1.5, - "color": LN, - "connectionstyle": "arc3,rad=0.3", - "linestyle": "--", - }, - ) - ax.text( - 5.25, - 0.3, - "Replay z Kafka\n(przetwórz historię od nowa)", - fontsize=FS_SMALL, - ha="center", - style="italic", - ) - - ax.text( - 6.0, - 3.3, - "1 system, 1 kod — prostsze, ale replay = dużo I/O", - fontsize=FS, - ha="center", - style="italic", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, - ) - - fig.tight_layout(rect=[0, 0, 1, 0.92]) - save_fig(fig, "q20_lambda_vs_kappa.png") - - -# ============================================================ -# 14. Lambda vs Kappa comparison table -# ============================================================ -def gen_lambda_kappa_table() -> None: - """Gen lambda kappa table.""" - fig, ax = plt.subplots(figsize=(8, 3.5)) - ax.set_xlim(0, 10) - ax.set_ylim(-4.5, 1) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Lambda vs Kappa — porównanie", fontsize=FS_TITLE, fontweight="bold", pad=10 - ) - - headers = ["Cecha", "Lambda", "Kappa"] - col_w = [2.5, 3.5, 3.5] - rows = [ - ["Ścieżki", "2 (batch + speed)", "1 (streaming)"], - ["Kod", "2 implementacje", "1 implementacja"], - ["Złożoność", "wysoka", "niska"], - ["Replay", "batch przelicza", "Kafka replay"], - ["Spójność", "merge wymagany", "natywna"], - ["Przykład", "Netflix, LinkedIn", "Uber, Confluent"], - ] - draw_table( - ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.55, fontsize=7.5 - ) - - save_fig(fig, "q20_lambda_kappa_table.png") - - -# ============================================================ -# 15. Exactly-once comparison -# ============================================================ -def gen_exactly_once() -> None: - """Gen exactly once.""" - fig, ax = plt.subplots(figsize=(9, 5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Exactly-Once — mechanizmy na 3 platformach", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Flink - draw_box(ax, 0.3, 4.3, 11.0, 2.0, "", fill=GRAY4, rounded=True, lw=1.5) - ax.text( - 1.0, - 5.9, - "Flink — Distributed Snapshots (Chandy-Lamport)", - fontsize=FS, - fontweight="bold", - ) - - flink_steps = ["source", "|B|", "map()", "|B|", "sink()"] - bx = 1.0 - for s in flink_steps: - if s == "|B|": - ax.text( - bx + 0.25, - 4.85, - s, - fontsize=FS, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY5, "edgecolor": LN}, - ) - draw_arrow(ax, bx - 0.1, 4.85, bx + 0.05, 4.85, lw=1) - bx += 0.7 - else: - draw_box( - ax, - bx, - 4.6, - 1.5, - 0.55, - s, - fill=GRAY1, - fontsize=FS_SMALL, - fontweight="bold", - ) - bx += 1.8 - ax.text( - 8.5, - 5.0, - "barrier → save state\n→ checkpoint (HDFS/S3)", - fontsize=FS_SMALL, - style="italic", - ) - - # Kafka Streams - draw_box(ax, 0.3, 2.3, 11.0, 1.5, "", fill=GRAY1, rounded=True, lw=1.5) - ax.text( - 1.0, 3.5, "Kafka Streams — Transakcje Kafka", fontsize=FS, fontweight="bold" - ) - ax.text( - 1.5, - 2.85, - "idempotent producer + begin TX → produce → commit TX → consumer offsets w TX", - fontsize=FS_SMALL, - ) - - # Spark - draw_box(ax, 0.3, 0.5, 11.0, 1.5, "", fill=GRAY3, rounded=True, lw=1.5) - ax.text( - 1.0, - 1.7, - "Spark Streaming — Write-Ahead Log (WAL)", - fontsize=FS, - fontweight="bold", - ) - ax.text( - 1.5, - 1.05, - "WAL + checkpointing micro-batchów + idempotent sinks (np. upsert do DB)", - fontsize=FS_SMALL, - ) - - save_fig(fig, "q20_exactly_once.png") - - -# ============================================================ -# 16. Late data strategies (DRAS) -# ============================================================ -def gen_late_data_strategies() -> None: - """Gen late data strategies.""" - fig, ax = plt.subplots(figsize=(9, 5.5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Late Data — 4 strategie (mnemonik DRAS)", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Setup: window closed, late event arrives - draw_box( - ax, - 0.5, - 5.5, - 4.5, - 1.0, - "Okno [14:00-14:05]\nZAMKNIĘTE o 14:05", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - draw_box( - ax, - 6.0, - 5.5, - 4.5, - 1.0, - "Spóźnione zdarzenie\nevent_time=14:00:03\narrives=14:05:30", - fill="#F8D7DA", - fontsize=FS_SMALL, - fontweight="bold", - ) - draw_arrow(ax, 10.5, 6.0, 5.0, 6.0, lw=2, color="#C62828", style="->") - ax.text( - 7.5, - 5.2, - "LATE!", - fontsize=FS_LABEL, - ha="center", - fontweight="bold", - color="#C62828", - ) - - # 4 strategies - strategies = [ - ("D — Drop", "Odrzuć spóźnione", "/dev/null", GRAY4), - ("R — Recompute", "Przelicz okno ponownie", "poprawne ale kosztowne", GRAY1), - ( - "A — Allowed lateness", - "Czekaj dodatkowy czas\n(np. +2 min)", - "kompromis pamięci", - GRAY2, - ), - ( - "S — Side output", - "Przekieruj do osobnej\nkolejki", - "elastyczne, ręczna analiza", - GRAY3, - ), - ] - for i, (name, desc, tradeoff, color) in enumerate(strategies): - y = 3.8 - i * 1.1 - draw_box(ax, 0.5, y, 2.5, 0.9, name, fill=color, fontsize=FS, fontweight="bold") - ax.text(3.3, y + 0.45, desc, fontsize=FS_SMALL, va="center") - ax.text( - 8.5, - y + 0.45, - tradeoff, - fontsize=FS_SMALL, - va="center", - style="italic", - color="#555", - ) - - save_fig(fig, "q20_late_data_strategies.png") - - -# ============================================================ -# 17. Decision tree — which platform -# ============================================================ -def gen_decision_tree() -> None: - """Gen decision tree.""" - fig, ax = plt.subplots(figsize=(10, 5.5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Drzewo decyzyjne — wybór platformy", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Root question - draw_box( - ax, - 3.5, - 5.5, - 4.5, - 1.0, - "Latencja < 10ms\nwymagana?", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - # TAK branch - draw_arrow(ax, 3.5, 5.7, 2.0, 5.0, lw=1.5) - ax.text(2.3, 5.3, "TAK", fontsize=FS, fontweight="bold") - - draw_box( - ax, - 0.3, - 3.5, - 3.5, - 1.0, - "Dane już w Kafce?\nProste transformacje?", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - ) - - # TAK → Kafka Streams - draw_arrow(ax, 0.3, 3.7, -0.1, 3.0, lw=1.5) - ax.text(0.0, 3.3, "TAK", fontsize=FS_SMALL, fontweight="bold") - draw_box( - ax, - -0.3, - 1.8, - 2.5, - 1.0, - "Kafka\nStreams", - fill=GRAY5, - fontsize=FS_LABEL, - fontweight="bold", - ) - - # NIE → Flink - draw_arrow(ax, 3.8, 3.7, 4.5, 3.0, lw=1.5) - ax.text(4.0, 3.3, "NIE\n(złożona logika)", fontsize=FS_SMALL) - draw_box( - ax, - 3.0, - 1.8, - 2.5, - 1.0, - "Apache\nFlink", - fill=GRAY5, - fontsize=FS_LABEL, - fontweight="bold", - ) - - # NIE branch - draw_arrow(ax, 8.0, 5.7, 9.5, 5.0, lw=1.5) - ax.text(8.7, 5.3, "NIE", fontsize=FS, fontweight="bold") - - draw_box( - ax, - 7.5, - 3.5, - 4.2, - 1.0, - "~100ms-1s OK?\nPotrzeba ML / SQL?", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - ) - - # TAK + ML → Spark - draw_arrow(ax, 9.5, 3.5, 9.5, 3.0, lw=1.5) - ax.text(10.0, 3.3, "TAK + ML/SQL", fontsize=FS_SMALL) - draw_box( - ax, - 8.0, - 1.8, - 2.5, - 1.0, - "Spark\nStreaming", - fill=GRAY5, - fontsize=FS_LABEL, - fontweight="bold", - ) - - # TAK + proste → Kafka Streams too - draw_arrow(ax, 7.5, 3.7, 6.5, 3.0, lw=1.5) - ax.text(6.3, 3.3, "proste + TAK", fontsize=FS_SMALL) - draw_box( - ax, - 5.8, - 1.8, - 2.0, - 1.0, - "Kafka\nStreams", - fill=GRAY5, - fontsize=FS, - fontweight="bold", - ) - - # Legend - ax.text( - 6.0, - 0.7, - "Reguła: Kafka Streams = najprostsze (library) | " - "Flink = najpotężniejszy (true streaming) | Spark = ekosystem ML", - fontsize=FS, - ha="center", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, - ) - - save_fig(fig, "q20_decision_tree.png") - - # ============================================================ # MAIN # ============================================================ diff --git a/python_pkg/praca_magisterska_video/generate_images/generate_q23_diagrams.py b/python_pkg/praca_magisterska_video/generate_images/generate_q23_diagrams.py index a0fa9c2..e553f10 100755 --- a/python_pkg/praca_magisterska_video/generate_images/generate_q23_diagrams.py +++ b/python_pkg/praca_magisterska_video/generate_images/generate_q23_diagrams.py @@ -2,2606 +2,47 @@ """Generate all diagrams for PYTANIE 23: Segmentacja obrazu. A4-compatible, monochrome-friendly (grays + one accent), 300 DPI. +Re-exports all diagram generators from submodules. """ from __future__ import annotations import logging from pathlib import Path -from typing import TYPE_CHECKING +import sys -import matplotlib as mpl +# Ensure sibling modules are importable when run as a script. +sys.path.insert(0, str(Path(__file__).resolve().parent)) -mpl.use("Agg") +from _q23_architectures import generate_fcn, generate_unet +from _q23_common import OUTPUT_DIR +from _q23_diy_unet import generate_diy_unet +from _q23_mean_shift_ncuts import generate_mean_shift, generate_normalized_cuts +from _q23_mnemonics import generate_mnemonics +from _q23_nn_basics import generate_dot_product, generate_relu +from _q23_otsu_watershed import generate_otsu_bimodal, generate_watershed +from _q23_receptive_transformer import generate_receptive_field, generate_transformer +from _q23_region_diy import generate_diy_thresholding, generate_region_growing -from matplotlib import patches -from matplotlib.patches import FancyBboxPatch -import matplotlib.pyplot as plt -import numpy as np - -if TYPE_CHECKING: - from matplotlib.axes import Axes +__all__ = [ + "generate_diy_thresholding", + "generate_diy_unet", + "generate_dot_product", + "generate_fcn", + "generate_mean_shift", + "generate_mnemonics", + "generate_normalized_cuts", + "generate_otsu_bimodal", + "generate_receptive_field", + "generate_region_growing", + "generate_relu", + "generate_transformer", + "generate_unet", + "generate_watershed", +] _logger = logging.getLogger(__name__) -rng = np.random.default_rng(42) - -DPI = 300 -OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") -Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) - -# Color palette — monochrome-friendly -BLACK = "#000000" -WHITE = "#FFFFFF" -GRAY1 = "#F5F5F5" -GRAY2 = "#E0E0E0" -GRAY3 = "#BDBDBD" -GRAY4 = "#9E9E9E" -GRAY5 = "#757575" -GRAY6 = "#424242" -ACCENT = "#4A90D9" # single blue accent for highlights -ACCENT_LIGHT = "#B3D4FC" -RED_ACCENT = "#D32F2F" -GREEN_ACCENT = "#388E3C" - -FS = 9 -FS_TITLE = 11 -FS_SMALL = 7 -FS_TINY = 6 - -_RIDGE_X = 5 -_VALLEY2_END = 9 -_DARK_PIXEL_THRESHOLD = 100 -_GRID_LAST_IDX = 3 -_HIGHLIGHT_START = 3 -_HIGHLIGHT_END = 5 -_BRIGHT_THRESHOLD = 170 -_OTSU_THRESHOLD = 128 - - -def _save_figure(name: str) -> None: - """Save current figure and log.""" - plt.tight_layout() - plt.savefig( - str(Path(OUTPUT_DIR) / name), - dpi=DPI, - bbox_inches="tight", - facecolor="white", - ) - plt.close() - _logger.info(" ✓ %s", name) - - -def _render_text_lines( - ax: Axes, - lines: list[tuple[str, int, str, str]], - *, - x_pos: float = 0.5, - start_y: float, - y_step: float = 0.5, - y_empty_step: float = 0.2, -) -> None: - """Render a list of styled text lines on an axis.""" - y = start_y - for txt, size, color, weight in lines: - if txt == "": - y -= y_empty_step - continue - ax.text( - x_pos, y, txt, - fontsize=size, color=color, fontweight=weight, va="top", - ) - y -= y_step - - -def _draw_otsu_variance_panel(ax: Axes) -> None: - """Draw panel 2: within-class variance explanation.""" - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title("Wariancja wewnątrzklasowa", fontsize=FS_TITLE, fontweight="bold") - - texts = [ - ( - "Wariancja = jak bardzo wartości\nróżnią się od średniej", - FS, - "black", - "normal", - ), - ("", 0, "black", "normal"), - ("Klasa 0 (piksele ≤ T):", FS, ACCENT, "bold"), - (" wartości: 30, 50, 45, 60, 55", FS_SMALL, "black", "normal"), - (" średnia μ₀ = 48", FS_SMALL, "black", "normal"), - (" σ₀² = ((30-48)²+(50-48)²+...)/5 = 108", FS_SMALL, "black", "normal"), - ("", 0, "black", "normal"), - ("Klasa 1 (piksele > T):", FS, RED_ACCENT, "bold"), - (" wartości: 180, 200, 190, 210, 195", FS_SMALL, "black", "normal"), - (" średnia μ₁ = 195", FS_SMALL, "black", "normal"), - (" σ₁² = ((180-195)²+...)/5 = 100", FS_SMALL, "black", "normal"), - ("", 0, "black", "normal"), - ("σ²_wewnątrz = w₀·σ₀² + w₁·σ₁²", FS, BLACK, "bold"), - ("= 0.6·108 + 0.4·100 = 104.8", FS_SMALL, "black", "normal"), - ("", 0, "black", "normal"), - ("Otsu próbuje KAŻDE T: 0,1,...,255", FS_SMALL, GREEN_ACCENT, "bold"), - ("Wybiera T dające MINIMUM σ²_wewnątrz", FS_SMALL, GREEN_ACCENT, "bold"), - ] - _render_text_lines( - ax, texts, - x_pos=0.3, start_y=9.2, y_step=0.55, y_empty_step=0.25, - ) - - -# ============================================================ -# 1. OTSU — Bimodal histogram + within-class variance -# ============================================================ -def generate_otsu_bimodal() -> None: - """Generate otsu bimodal.""" - _fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) - - # --- Panel 1: Bimodal histogram --- - ax = axes[0] - dark = rng.normal(60, 20, 3000).clip(0, 255) - bright = rng.normal(190, 25, 2000).clip(0, 255) - all_pixels = np.concatenate([dark, bright]) - - counts, _bins, _bars = ax.hist( - all_pixels, bins=64, color=GRAY3, edgecolor=GRAY5, linewidth=0.5 - ) - ax.axvline( - x=128, color=RED_ACCENT, linewidth=2, linestyle="--", label="Próg Otsu T=128" - ) - ax.fill_betweenx([0, max(counts) * 1.1], 0, 128, alpha=0.12, color=ACCENT) - ax.fill_betweenx([0, max(counts) * 1.1], 128, 255, alpha=0.12, color=RED_ACCENT) - ax.text( - 45, - max(counts) * 0.85, - "Klasa 0\n(tło)", - ha="center", - fontsize=FS, - fontweight="bold", - color=ACCENT, - ) - ax.text( - 195, - max(counts) * 0.85, - "Klasa 1\n(obiekt)", - ha="center", - fontsize=FS, - fontweight="bold", - color=RED_ACCENT, - ) - ax.annotate( - "Garb 1", - xy=(60, max(counts) * 0.6), - fontsize=FS_SMALL, - ha="center", - arrowprops={"arrowstyle": "->", "color": GRAY5}, - xytext=(30, max(counts) * 0.45), - ) - ax.annotate( - "Garb 2", - xy=(190, max(counts) * 0.5), - fontsize=FS_SMALL, - ha="center", - arrowprops={"arrowstyle": "->", "color": GRAY5}, - xytext=(220, max(counts) * 0.35), - ) - ax.set_xlabel("Jasność piksela (0-255)", fontsize=FS) - ax.set_ylabel("Liczba pikseli", fontsize=FS) - ax.set_title("Histogram bimodalny", fontsize=FS_TITLE, fontweight="bold") - ax.legend(fontsize=FS_SMALL, loc="upper right") - ax.set_xlim(0, 255) - - # --- Panel 2: Within-class variance explanation --- - _draw_otsu_variance_panel(axes[1]) - - # --- Panel 3: Jednorodność explanation --- - ax = axes[2] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title('"Jednorodne" = małe σ²', fontsize=FS_TITLE, fontweight="bold") - - # Draw two clusters - # Good separation - c0 = rng.normal(2, 0.4, 15) - c1 = rng.normal(7, 0.4, 15) - y_pos_0 = rng.uniform(6, 8, 15) - y_pos_1 = rng.uniform(6, 8, 15) - ax.scatter(c0, y_pos_0, c=ACCENT, s=30, zorder=5, label="Klasa 0") - ax.scatter(c1, y_pos_1, c=RED_ACCENT, s=30, zorder=5, label="Klasa 1") - ax.axvline(x=4.5, color=GREEN_ACCENT, linewidth=2, linestyle="--") - ax.text( - 4.5, - 8.8, - "T optymalny", - ha="center", - fontsize=FS_SMALL, - color=GREEN_ACCENT, - fontweight="bold", - ) - ax.text( - 2, 5.3, "σ₀² mała\n(skupione)", ha="center", fontsize=FS_SMALL, color=ACCENT - ) - ax.text( - 7, 5.3, "σ₁² mała\n(skupione)", ha="center", fontsize=FS_SMALL, color=RED_ACCENT - ) - ax.text( - 5, - 4, - "→ σ²_wewnątrz MINIMALNA\n→ klasy JEDNORODNE\n→ dobra segmentacja!", - ha="center", - fontsize=FS, - fontweight="bold", - color=GREEN_ACCENT, - ) - - # Bad separation - c0b = rng.normal(3.5, 1.5, 15) - c1b = rng.normal(6, 1.5, 15) - y_pos_0b = rng.uniform(1, 3, 15) - y_pos_1b = rng.uniform(1, 3, 15) - ax.scatter(c0b, y_pos_0b, c=ACCENT, s=30, marker="x", zorder=5) - ax.scatter(c1b, y_pos_1b, c=RED_ACCENT, s=30, marker="x", zorder=5) - ax.axvline(x=4.5, color=GRAY4, linewidth=1, linestyle=":", ymin=0, ymax=0.35) - ax.text( - 5, - 0.3, - "σ²_wewnątrz DUŻA → klasy mieszają się → zły próg", - ha="center", - fontsize=FS_SMALL, - color=GRAY5, - ) - - ax.legend(fontsize=FS_SMALL, loc="upper left") - - _save_figure("q23_otsu_bimodal.png") - - -def _draw_watershed_result_panel(ax: Axes) -> None: - """Draw panel 3: watershed result with over-segmentation problem.""" - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title("Krok 3: wynik", fontsize=FS_TITLE, fontweight="bold") - - rect1 = FancyBboxPatch( - (0.5, 6), 3.5, 3.2, - boxstyle="round,pad=0.1", facecolor=ACCENT_LIGHT, - edgecolor=BLACK, linewidth=1, - ) - ax.add_patch(rect1) - ax.text(2.25, 8.8, "Ideał: 2 segmenty", fontsize=FS, ha="center", fontweight="bold") - ax.text(2.25, 7.5, "Segment A Segment B", fontsize=FS_SMALL, ha="center") - ax.text( - 2.25, 6.7, "(po marker-controlled)", - fontsize=FS_SMALL, ha="center", color=GREEN_ACCENT, - ) - - rect2 = FancyBboxPatch( - (5.5, 6), 4, 3.2, - boxstyle="round,pad=0.1", facecolor="#FFCDD2", - edgecolor=BLACK, linewidth=1, - ) - ax.add_patch(rect2) - ax.text( - 7.5, 8.8, "Problem: over-segmentation", - fontsize=FS, ha="center", fontweight="bold", color=RED_ACCENT, - ) - ax.text( - 7.5, 7.8, "47 regionów zamiast 2!", - fontsize=FS_SMALL, ha="center", color=RED_ACCENT, - ) - ax.text(7.5, 7.1, "Każde mini-minimum", fontsize=FS_SMALL, ha="center") - ax.text(7.5, 6.5, '→ osobna „dolina"', fontsize=FS_SMALL, ha="center") - - # Apply marker-controlled solution - rect3 = FancyBboxPatch( - (1, 0.5), 8, 4.5, - boxstyle="round,pad=0.15", facecolor=GRAY1, - edgecolor=GREEN_ACCENT, linewidth=1.5, - ) - ax.add_patch(rect3) - ax.text( - 5, 4.3, "Rozwiązanie: Marker-controlled watershed", - fontsize=FS, ha="center", fontweight="bold", color=GREEN_ACCENT, - ) - ax.text( - 5, 3.4, - '1. Zaznacz ręcznie „seeds" (markery) w każdym obiekcie', - fontsize=FS_SMALL, ha="center", - ) - ax.text( - 5, 2.7, - "2. Zalewaj TYLKO od tych markerów (nie od wszystkich minimów)", - fontsize=FS_SMALL, ha="center", - ) - ax.text( - 5, 2.0, "3. Eliminuje fałszywe doliny z szumu", - fontsize=FS_SMALL, ha="center", - ) - ax.text( - 5, 1.2, "Wynik: tyle segmentów, ile podano markerów", - fontsize=FS_SMALL, ha="center", fontweight="bold", - ) - - -# ============================================================ -# 2. WATERSHED — Topographic flooding (not ASCII!) -# ============================================================ -def generate_watershed() -> None: - """Generate watershed.""" - _fig, axes = plt.subplots(1, 3, figsize=(11, 3.8)) - - # --- Panel 1: Image as topographic surface --- - ax = axes[0] - x = np.linspace(0, 10, 200) - # Create a surface with two valleys and a ridge - surface = ( - 3 * np.exp(-((x - 3) ** 2) / 1.5) - + 4 * np.exp(-((x - 7) ** 2) / 1.2) - + 0.5 * np.sin(x * 2) - + 1 - ) - # Invert: valleys at objects (dark), peaks at boundaries (bright) - surface_inv = 6 - surface + 1 - - ax.fill_between(x, 0, surface_inv, color=GRAY2, alpha=0.7) - ax.plot(x, surface_inv, color=BLACK, linewidth=1.5) - - # Mark valleys - ax.annotate( - "Dolina 1\n(obiekt A)", - xy=(3, surface_inv[60]), - fontsize=FS_SMALL, - ha="center", - va="bottom", - arrowprops={"arrowstyle": "->", "color": ACCENT}, - xytext=(1.5, 5.5), - ) - ax.annotate( - "Dolina 2\n(obiekt B)", - xy=(7, surface_inv[140]), - fontsize=FS_SMALL, - ha="center", - va="bottom", - arrowprops={"arrowstyle": "->", "color": RED_ACCENT}, - xytext=(8.5, 5.5), - ) - # Mark ridge - ax.annotate( - "Grań\n(granica)", - xy=(5, surface_inv[100]), - fontsize=FS_SMALL, - ha="center", - va="bottom", - arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT}, - xytext=(5, 6.5), - ) - - ax.set_xlabel("Pozycja piksela", fontsize=FS) - ax.set_ylabel("Jasność (= wysokość)", fontsize=FS) - ax.set_title("Krok 1: obraz → teren", fontsize=FS_TITLE, fontweight="bold") - ax.set_ylim(0, 7) - - # --- Panel 2: Flooding --- - ax = axes[1] - ax.fill_between(x, 0, surface_inv, color=GRAY2, alpha=0.7) - ax.plot(x, surface_inv, color=BLACK, linewidth=1.5) - - # Water level - water_level = 3.2 - - # Fill water in valley 1 - x_v1 = x[(x > 1) & (x < _RIDGE_X)] - s_v1 = surface_inv[(x > 1) & (x < _RIDGE_X)] - ax.fill_between( - x_v1, s_v1, water_level, where=s_v1 < water_level, color=ACCENT_LIGHT, alpha=0.6 - ) - # Fill water in valley 2 - x_v2 = x[(x > _RIDGE_X) & (x < _VALLEY2_END)] - s_v2 = surface_inv[(x > _RIDGE_X) & (x < _VALLEY2_END)] - ax.fill_between( - x_v2, s_v2, water_level, where=s_v2 < water_level, color="#FFCDD2", alpha=0.6 - ) - - ax.axhline(y=water_level, color=ACCENT, linewidth=1, linestyle="--", alpha=0.5) - ax.text(3, 2.5, "Woda A", fontsize=FS, ha="center", color=ACCENT, fontweight="bold") - ax.text( - 7, 2.2, "Woda B", fontsize=FS, ha="center", color=RED_ACCENT, fontweight="bold" - ) - ax.annotate( - "Tu się spotkają!\n→ GRANICA", - xy=(5, surface_inv[100]), - fontsize=FS_SMALL, - ha="center", - color=GREEN_ACCENT, - fontweight="bold", - arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT}, - xytext=(5, 6.2), - ) - - ax.set_xlabel("Pozycja piksela", fontsize=FS) - ax.set_title("Krok 2: zalewanie", fontsize=FS_TITLE, fontweight="bold") - ax.set_ylim(0, 7) - - # --- Panel 3: Result with problem --- - _draw_watershed_result_panel(axes[2]) - - _save_figure("q23_watershed.png") - - -# ============================================================ -# 3. MEAN SHIFT — Kernel, density, feature space -# ============================================================ -def generate_mean_shift() -> None: - """Generate mean shift.""" - _fig, axes = plt.subplots(1, 3, figsize=(11, 4)) - - # --- Panel 1: Feature space concept --- - ax = axes[0] - # Three clusters in 2D feature space (brightness, x-position) - c1x = rng.normal(2, 0.5, 40) - c1y = rng.normal(2, 0.5, 40) - c2x = rng.normal(6, 0.6, 35) - c2y = rng.normal(7, 0.5, 35) - c3x = rng.normal(8, 0.4, 25) - c3y = rng.normal(3, 0.6, 25) - - ax.scatter(c1x, c1y, c=GRAY4, s=15, alpha=0.7, zorder=3) - ax.scatter(c2x, c2y, c=GRAY4, s=15, alpha=0.7, zorder=3) - ax.scatter(c3x, c3y, c=GRAY4, s=15, alpha=0.7, zorder=3) - - # Label peaks - ax.scatter([2], [2], c=RED_ACCENT, s=80, marker="*", zorder=5, label="Max gęstości") - ax.scatter([6], [7], c=RED_ACCENT, s=80, marker="*", zorder=5) - ax.scatter([8], [3], c=RED_ACCENT, s=80, marker="*", zorder=5) - - ax.set_xlabel("Cecha 1: jasność", fontsize=FS) - ax.set_ylabel("Cecha 2: pozycja x", fontsize=FS) - ax.set_title("Przestrzeń cech", fontsize=FS_TITLE, fontweight="bold") - for lx, ly, ltxt in [ - (2, 0.3, "Klaster 1\n(ciemne, lewo)"), - (6, 5.3, "Klaster 2\n(jasne, prawo)"), - (8, 1.3, "Klaster 3\n(jasne, dół)"), - ]: - ax.text(lx, ly, ltxt, ha="center", fontsize=FS_TINY, color=GRAY6) - ax.legend(fontsize=FS_SMALL, loc="upper left") - - # --- Panel 2: Kernel/window moving --- - ax = axes[1] - ax.scatter(c1x, c1y, c=ACCENT_LIGHT, s=15, alpha=0.7, zorder=3) - ax.scatter(c2x, c2y, c=GRAY3, s=15, alpha=0.7, zorder=3) - ax.scatter(c3x, c3y, c=GRAY3, s=15, alpha=0.7, zorder=3) - - # Show kernel movement - path_x = [4.5, 3.8, 3.0, 2.3, 2.05] - path_y = [4.0, 3.3, 2.7, 2.2, 2.03] - - for i, (px, py) in enumerate(zip(path_x, path_y, strict=False)): - alpha = 0.3 + 0.15 * i - circle = plt.Circle( - (px, py), - 1.2, - fill=False, - edgecolor=ACCENT, - linewidth=1.5, - linestyle="--" if i < len(path_x) - 1 else "-", - alpha=alpha, - ) - ax.add_patch(circle) - if i < len(path_x) - 1: - ax.annotate( - "", - xy=(path_x[i + 1], path_y[i + 1]), - xytext=(px, py), - arrowprops={"arrowstyle": "->", "color": RED_ACCENT, "lw": 1.5}, - ) - - ax.scatter([path_x[0]], [path_y[0]], c=ACCENT, s=50, marker="o", zorder=5) - ax.scatter([path_x[-1]], [path_y[-1]], c=RED_ACCENT, s=80, marker="*", zorder=5) - - ax.text( - 4.5, 5.2, "Start: losowy\npiksel", fontsize=FS_SMALL, ha="center", color=ACCENT - ) - ax.text( - 2.05, - 0.5, - "Koniec: max\ngęstości", - fontsize=FS_SMALL, - ha="center", - color=RED_ACCENT, - fontweight="bold", - ) - ax.text( - 7, - 8, - "Okno (jądro)\nprzesuwa się\ndo skupiska", - fontsize=FS_SMALL, - ha="center", - color=GRAY6, - bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3}, - ) - - ax.set_xlabel("Cecha 1", fontsize=FS) - ax.set_ylabel("Cecha 2", fontsize=FS) - ax.set_title("Jądro → max gęstości", fontsize=FS_TITLE, fontweight="bold") - ax.set_xlim(0, 10) - ax.set_ylim(0, 9) - - # --- Panel 3: Why no K parameter --- - ax = axes[2] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title("Dlaczego bez K?", fontsize=FS_TITLE, fontweight="bold") - - lines = [ - ("K-means wymaga:", FS, RED_ACCENT, "bold"), - (' „Podaj K=3 klastry"', FS_SMALL, "black", "normal"), - (" Problem: skąd wiesz ile klastrów?", FS_SMALL, GRAY5, "normal"), - ("", 0, "", ""), - ("Mean Shift NIE wymaga K:", FS, GREEN_ACCENT, "bold"), - (" Każdy piksel startuje → toczy się", FS_SMALL, "black", "normal"), - (" → trafia do najbliższego szczytu", FS_SMALL, "black", "normal"), - (" → ile szczytów = tyle segmentów", FS_SMALL, "black", "normal"), - (" → automatycznie!", FS_SMALL, GREEN_ACCENT, "bold"), - ("", 0, "", ""), - ("Parametr: bandwidth (szerokość okna)", FS, "black", "bold"), - (" Duże okno → mało segmentów", FS_SMALL, "black", "normal"), - (" Małe okno → dużo segmentów", FS_SMALL, "black", "normal"), - ("", 0, "", ""), - ("Okno = jądro (kernel):", FS, "black", "bold"), - (" Koło o promieniu h wokół punktu.", FS_SMALL, "black", "normal"), - (" Oblicz średnią pikseli W oknie.", FS_SMALL, "black", "normal"), - (" Przesuń okno na tę średnią.", FS_SMALL, "black", "normal"), - (" Powtórz aż się zatrzyma.", FS_SMALL, "black", "normal"), - ] - _render_text_lines(ax, lines, start_y=9.0) - - _save_figure("q23_mean_shift.png") - - -def _draw_ncuts_pixel_grid( - ax: Axes, pixel_vals: np.ndarray, -) -> None: - """Draw 4x4 pixel grid with value labels and edge weights.""" - for i in range(4): - for j in range(4): - v = pixel_vals[i, j] - gray_val = v / 255.0 - str(gray_val) - rect = patches.Rectangle( - (j - 0.4, 3 - i - 0.4), - 0.8, 0.8, - facecolor=(gray_val, gray_val, gray_val), - edgecolor=BLACK, linewidth=0.8, - ) - ax.add_patch(rect) - text_color = ( - "white" if v < _DARK_PIXEL_THRESHOLD else "black" - ) - ax.text( - j, 3 - i, str(v), - ha="center", va="center", fontsize=FS_SMALL, - color=text_color, fontweight="bold", - ) - - -def _draw_ncuts_edges( - ax: Axes, pixel_vals: np.ndarray, -) -> None: - """Draw weighted edges between adjacent pixels.""" - for i in range(4): - for j in range(4): - if j < _GRID_LAST_IDX: - similarity = max( - 0, - 1 - abs(pixel_vals[i, j] - pixel_vals[i, j + 1]) - / 255, - ) - lw = similarity * 2.5 + 0.3 - alpha = similarity * 0.8 + 0.2 - ax.plot( - [j + 0.4, j + 0.6], [3 - i, 3 - i], - color=GRAY5, linewidth=lw, alpha=alpha, - ) - if i < _GRID_LAST_IDX: - similarity = max( - 0, - 1 - abs(pixel_vals[i, j] - pixel_vals[i + 1, j]) - / 255, - ) - lw = similarity * 2.5 + 0.3 - alpha = similarity * 0.8 + 0.2 - ax.plot( - [j, j], [3 - i - 0.4, 3 - i - 0.6], - color=GRAY5, linewidth=lw, alpha=alpha, - ) - - -# ============================================================ -# 4. NORMALIZED CUTS — Graph cut visualization -# ============================================================ -def generate_normalized_cuts() -> None: - """Generate normalized cuts.""" - _fig, axes = plt.subplots(1, 3, figsize=(11, 4)) - - # --- Panel 1: Image as graph --- - ax = axes[0] - ax.set_xlim(-0.5, 4.5) - ax.set_ylim(-0.5, 4.5) - ax.set_aspect("equal") - ax.set_title("Obraz → graf", fontsize=FS_TITLE, fontweight="bold") - - pixel_vals = np.array( - [ - [30, 35, 180, 190], - [40, 30, 185, 200], - [170, 180, 40, 35], - [190, 175, 30, 45], - ] - ) - _draw_ncuts_pixel_grid(ax, pixel_vals) - _draw_ncuts_edges(ax, pixel_vals) - - ax.text( - 2, - -0.8, - "Grube linie = duże podobieństwo\n(silna krawędź grafu)", - ha="center", - fontsize=FS_TINY, - color=GRAY5, - ) - ax.axis("off") - - # --- Panel 2: Cut concept --- - ax = axes[1] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title("Cięcie grafu (graph cut)", fontsize=FS_TITLE, fontweight="bold") - - # Draw two groups of nodes - # Group A (dark pixels) - positions_a = [(2, 7), (3, 8), (2, 5), (3, 6)] - positions_b = [(7, 7), (8, 8), (7, 5), (8, 6)] - - # Intra-group edges (thick = similar) - for i, (x1, y1) in enumerate(positions_a): - for x2, y2 in positions_a[i + 1 :]: - ax.plot([x1, x2], [y1, y2], color=ACCENT, linewidth=2, alpha=0.5) - for i, (x1, y1) in enumerate(positions_b): - for x2, y2 in positions_b[i + 1 :]: - ax.plot([x1, x2], [y1, y2], color=RED_ACCENT, linewidth=2, alpha=0.5) - - # Inter-group edges (thin = dissimilar) — these get cut - cut_edges = [((3, 8), (7, 7)), ((3, 6), (7, 5)), ((2, 5), (7, 5))] - for (x1, y1), (x2, y2) in cut_edges: - ax.plot([x1, x2], [y1, y2], color=GRAY4, linewidth=0.8, linestyle="--") - - # Draw nodes - for x, y in positions_a: - ax.scatter(x, y, c=ACCENT, s=120, zorder=5, edgecolors=BLACK, linewidth=0.8) - for x, y in positions_b: - ax.scatter(x, y, c="#FFCDD2", s=120, zorder=5, edgecolors=BLACK, linewidth=0.8) - - # Cut line - ax.plot( - [5, 5], [3.5, 9.5], color=RED_ACCENT, linewidth=2.5, linestyle="-", zorder=4 - ) - ax.text( - 5, 9.8, "CIĘCIE", ha="center", fontsize=FS, fontweight="bold", color=RED_ACCENT - ) - - ax.text( - 2.5, - 3.8, - "Segment A\n(ciemne piksele)", - ha="center", - fontsize=FS_SMALL, - color=ACCENT, - ) - ax.text( - 7.5, - 3.8, - "Segment B\n(jasne piksele)", - ha="center", - fontsize=FS_SMALL, - color=RED_ACCENT, - ) - - # Formula - ax.text( - 5, - 1.8, - "Ncut(A,B) = cut(A,B)/assoc(A,V)\n + cut(A,B)/assoc(B,V)", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3}, - ) - ax.text( - 5, - 0.5, - "Minimalizuj Ncut → tnij SŁABE krawędzie\nzachowuj SILNE (wewnątrz grupy)", - ha="center", - fontsize=FS_TINY, - color=GRAY5, - ) - - # --- Panel 3: Algorithm summary --- - ax = axes[2] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title("Algorytm Normalized Cuts", fontsize=FS_TITLE, fontweight="bold") - - steps = [ - ( - "1. Zbuduj graf", - "Piksele = węzły\nKrawędzie = podobieństwo" - " sąsiadów\n(kolor, jasność, odległość)", - ), - ( - "2. Macierz podobieństwa W", - "W[i,j] = exp(-|kolori - kolorj|² / σ²)" - "\n→ im podobniejsze, tym wyższa waga", - ), - ("3. Macierz stopni D", "D[i,i] = Σ W[i,j]\n(suma wszystkich wag z węzła i)"), - ("4. Rozwiąż problem własny", "(D-W)·y = λ·D·y\n→ drugi najm. wektor własny y"), - ("5. Podziel wg y", "y[i] > 0 → segment A\ny[i] ≤ 0 → segment B"), - ] - - y = 9.5 - for title, desc in steps: - ax.text(0.5, y, title, fontsize=FS, fontweight="bold", va="top") - y -= 0.4 - ax.text(0.8, y, desc, fontsize=FS_TINY, va="top", color=GRAY6) - y -= 1.2 - - ax.text( - 5, - 0.3, - "Złożoność: O(n³) — wymaga eigen decomposition!", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - color=RED_ACCENT, - ) - - _save_figure("q23_normalized_cuts.png") - - -# ============================================================ -# 5. RELU — Function plot -# ============================================================ -def generate_relu() -> None: - """Generate relu.""" - _fig, axes = plt.subplots(1, 2, figsize=(8, 3.5)) - - # --- Panel 1: ReLU plot --- - ax = axes[0] - x = np.linspace(-5, 5, 200) - relu = np.maximum(0, x) - ax.plot(x, relu, color=ACCENT, linewidth=2.5, label="ReLU(x) = max(0, x)") - ax.axhline(y=0, color=GRAY3, linewidth=0.5) - ax.axvline(x=0, color=GRAY3, linewidth=0.5) - ax.fill_between(x[x < 0], 0, 0, color=RED_ACCENT, alpha=0.1) - ax.fill_between(x[x >= 0], 0, relu[x >= 0], color=ACCENT, alpha=0.1) - - # Annotations - ax.annotate( - 'x < 0 → output = 0\n(neuron „wyłączony")', - xy=(-3, 0), - fontsize=FS_SMALL, - ha="center", - va="bottom", - color=RED_ACCENT, - arrowprops={"arrowstyle": "->", "color": RED_ACCENT}, - xytext=(-3, 2), - ) - ax.annotate( - 'x ≥ 0 → output = x\n(neuron „włączony")', - xy=(3, 3), - fontsize=FS_SMALL, - ha="center", - va="bottom", - color=ACCENT, - arrowprops={"arrowstyle": "->", "color": ACCENT}, - xytext=(3, 4.5), - ) - ax.scatter([0], [0], c=BLACK, s=40, zorder=5) - ax.text(0.3, -0.5, "(0,0)", fontsize=FS_SMALL, color=GRAY5) - ax.set_xlabel("x (wejście neuronu)", fontsize=FS) - ax.set_ylabel("ReLU(x)", fontsize=FS) - ax.set_title("ReLU — Rectified Linear Unit", fontsize=FS_TITLE, fontweight="bold") - ax.legend(fontsize=FS_SMALL, loc="upper left") - ax.set_ylim(-1, 6) - ax.grid(visible=True, alpha=0.2) - - # --- Panel 2: Why ReLU --- - ax = axes[1] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title("Dlaczego ReLU?", fontsize=FS_TITLE, fontweight="bold") - - y = 9.0 - lines = [ - ("Neuron oblicza:", FS, BLACK, "bold"), - (" z = w₁·x₁ + w₂·x₂ + ... + bias", FS_SMALL, BLACK, "normal"), - (" output = ReLU(z) = max(0, z)", FS_SMALL, ACCENT, "bold"), - ("", 0, "", ""), - ("Przykład:", FS, BLACK, "bold"), - (" wagi: w₁=0.5, w₂=-0.3, bias=0.1", FS_SMALL, BLACK, "normal"), - (" wejścia: x₁=2.0, x₂=4.0", FS_SMALL, BLACK, "normal"), - (" z = 0.5·2 + (-0.3)·4 + 0.1 = -0.1", FS_SMALL, BLACK, "normal"), - (" ReLU(-0.1) = max(0, -0.1) = 0", FS_SMALL, RED_ACCENT, "bold"), - (" → neuron milczy (wejście nieistotne)", FS_SMALL, GRAY5, "normal"), - ("", 0, "", ""), - ("Gdyby z = 2.3:", FS, BLACK, "bold"), - (" ReLU(2.3) = max(0, 2.3) = 2.3", FS_SMALL, GREEN_ACCENT, "bold"), - (" → neuron aktywny! Przekazuje sygnał", FS_SMALL, GRAY5, "normal"), - ("", 0, "", ""), - ("Szybsza niż sigmoid/tanh", FS_SMALL, GRAY5, "normal"), - ("(brak exp() → szybkie obliczenia)", FS_SMALL, GRAY5, "normal"), - ] - for txt, size, color, weight in lines: - if txt == "": - y -= 0.2 - continue - ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top") - y -= 0.5 - - _save_figure("q23_relu.png") - - -# ============================================================ -# 6. DOT PRODUCT — Iloczyn skalarny visual -# ============================================================ -def generate_dot_product() -> None: - """Generate dot product.""" - _fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) - - # --- Panel 1: Concept --- - ax = axes[0] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title( - "Iloczyn skalarny\n(dot product)", fontsize=FS_TITLE, fontweight="bold" - ) - - y = 8.5 - lines = [ - ("Dwa wektory (listy liczb) → JEDNA liczba", FS, BLACK, "bold"), - ("", 0, "", ""), - ("a = [a₁, a₂, a₃] b = [b₁, b₂, b₃]", FS, ACCENT, "normal"), - ("", 0, "", ""), - ("a · b = a₁·b₁ + a₂·b₂ + a₃·b₃", FS, BLACK, "bold"), - ("", 0, "", ""), - ("Przykład:", FS, BLACK, "bold"), - ("a = [1, 3, -2] b = [4, -1, 5]", FS_SMALL, BLACK, "normal"), - ("a·b = 1·4 + 3·(-1) + (-2)·5", FS_SMALL, BLACK, "normal"), - (" = 4 + (-3) + (-10) = -9", FS_SMALL, RED_ACCENT, "bold"), - ("", 0, "", ""), - ( - 'Duży wynik → wektory „podobne" (w tym samym kierunku)', - FS_SMALL, - GREEN_ACCENT, - "normal", - ), - ('Mały/ujemny → wektory „różne"', FS_SMALL, RED_ACCENT, "normal"), - ] - for txt, size, color, weight in lines: - if txt == "": - y -= 0.25 - continue - ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top") - y -= 0.55 - - # --- Panel 2: Convolution as dot product --- - ax = axes[1] - ax.set_xlim(-0.5, 5.5) - ax.set_ylim(-0.5, 5.5) - ax.set_aspect("equal") - ax.set_title( - "Konwolucja = iloczyn skalarny\nfiltra x fragment obrazu", - fontsize=FS_TITLE, - fontweight="bold", - ) - - # Filter 3x3 - filter_vals = [[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]] - for i in range(3): - for j in range(3): - rect = patches.Rectangle( - (j - 0.4, 4 - i - 0.4), - 0.8, - 0.8, - facecolor=ACCENT_LIGHT, - edgecolor=BLACK, - linewidth=0.8, - ) - ax.add_patch(rect) - ax.text( - j, - 4 - i, - str(filter_vals[i][j]), - ha="center", - va="center", - fontsize=FS, - fontweight="bold", - ) - - ax.text(1, 1.5, "Filtr", ha="center", fontsize=FS, fontweight="bold", color=ACCENT) - - # Image patch - img_vals = [[50, 50, 200], [50, 50, 200], [50, 50, 200]] - for i in range(3): - for j in range(3): - rect = patches.Rectangle( - (j + 2.6, 4 - i - 0.4), - 0.8, - 0.8, - facecolor=GRAY2, - edgecolor=BLACK, - linewidth=0.8, - ) - ax.add_patch(rect) - ax.text( - j + 3, - 4 - i, - str(img_vals[i][j]), - ha="center", - va="center", - fontsize=FS, - fontweight="bold", - ) - - ax.text( - 4, - 1.5, - "Fragment\nobrazu", - ha="center", - fontsize=FS, - fontweight="bold", - color=GRAY5, - ) - - ax.text( - 2.5, - 0.5, - "(-1)·50 + 0·50 + 1·200 +\n" - "(-1)·50 + 0·50 + 1·200 +\n" - "(-1)·50 + 0·50 + 1·200\n= 450 (krawędź!)", - ha="center", - fontsize=FS_TINY, - fontweight="bold", - bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GREEN_ACCENT}, - ) - - ax.axis("off") - - # --- Panel 3: Vector visualization --- - ax = axes[2] - # Draw two vectors - ax.quiver( - 0, - 0, - 3, - 4, - angles="xy", - scale_units="xy", - scale=1, - color=ACCENT, - width=0.025, - label="a = [3, 4]", - ) - ax.quiver( - 0, - 0, - 4, - 1, - angles="xy", - scale_units="xy", - scale=1, - color=RED_ACCENT, - width=0.025, - label="b = [4, 1]", - ) - - # Show angle - theta = np.linspace(np.arctan2(1, 4), np.arctan2(4, 3), 30) - r = 1.5 - ax.plot(r * np.cos(theta), r * np.sin(theta), color=GREEN_ACCENT, linewidth=1.5) - ax.text(1.8, 1.3, "θ", fontsize=FS, color=GREEN_ACCENT, fontweight="bold") - - ax.text(3.2, 4.2, "a", fontsize=FS, color=ACCENT, fontweight="bold") - ax.text(4.2, 1.2, "b", fontsize=FS, color=RED_ACCENT, fontweight="bold") - - ax.text( - 2.5, - -1.0, - "a · b = |a|·|b|·cos(θ)\n= 3·4 + 4·1 = 16", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3}, - ) - ax.text( - 2.5, - -2.0, - 'Mały kąt θ → duży dot product\n= wektory „zgadają się"', - ha="center", - fontsize=FS_TINY, - color=GRAY5, - ) - - ax.set_xlim(-0.5, 5.5) - ax.set_ylim(-2.5, 5.5) - ax.set_aspect("equal") - ax.grid(visible=True, alpha=0.2) - ax.legend(fontsize=FS_SMALL, loc="upper left") - ax.set_title("Geometrycznie: kąt", fontsize=FS_TITLE, fontweight="bold") - - _save_figure("q23_dot_product.png") - - -# ============================================================ -# 7. FCN — FC vs Conv 1x1, skip connections -# ============================================================ -def generate_fcn() -> None: - """Generate fcn.""" - _fig, axes = plt.subplots(2, 1, figsize=(10, 7)) - - # --- Panel 1: FC vs Conv 1x1 --- - ax = axes[0] - ax.set_xlim(0, 20) - ax.set_ylim(0, 6) - ax.axis("off") - ax.set_title( - "FC (Fully Connected) vs Conv 1x1", fontsize=FS_TITLE, fontweight="bold" - ) - - # Classic CNN with FC - layer_info_fc = [ - (1.5, "Obraz\n224x224x3", 2.2, GRAY2), - (4.5, "Conv+Pool\n112x112x64", 1.8, GRAY2), - (7.5, "Conv+Pool\n7x7x512", 1.0, GRAY2), - (10, "Flatten\n25088", 0.5, ACCENT_LIGHT), - (12, "FC\n4096", 0.5, ACCENT_LIGHT), - (14, "FC\n1000", 0.3, ACCENT_LIGHT), - (16, '"Kot"', 0.3, "#FFCDD2"), - ] - - y_fc = 4.5 - for i, (x, label, w, color) in enumerate(layer_info_fc): - rect = FancyBboxPatch( - (x - w / 2, y_fc - 0.6), - w, - 1.2, - boxstyle="round,pad=0.05", - facecolor=color, - edgecolor=BLACK, - linewidth=0.8, - ) - ax.add_patch(rect) - ax.text(x, y_fc, label, ha="center", va="center", fontsize=FS_TINY) - if i < len(layer_info_fc) - 1: - next_x = layer_info_fc[i + 1][0] - ax.annotate( - "", - xy=(next_x - layer_info_fc[i + 1][2] / 2, y_fc), - xytext=(x + w / 2, y_fc), - arrowprops={"arrowstyle": "->", "color": GRAY5, "lw": 1}, - ) - - ax.text( - 0.3, y_fc, "CNN:", fontsize=FS, fontweight="bold", color=RED_ACCENT, va="center" - ) - ax.text( - 12, - y_fc + 1, - "PROBLEM: FC wymaga\nSTAŁEGO rozmiaru\n(np. 224x224)", - ha="center", - fontsize=FS_SMALL, - color=RED_ACCENT, - fontweight="bold", - bbox={ - "boxstyle": "round", - "facecolor": "#FFCDD2", - "edgecolor": RED_ACCENT, - "alpha": 0.3, - }, - ) - - # FCN with Conv 1x1 - layer_info_fcn = [ - (1.5, "Obraz\nHxWx3", 2.2, GRAY2), - (4.5, "Conv+Pool\nH/2 x W/2\nx64", 1.8, GRAY2), - (7.5, "Conv+Pool\nH/32 x W/32\nx512", 1.0, GRAY2), - (10.5, "Conv 1x1\nH/32 x W/32\nxC", 0.8, "#C8E6C9"), - (13.5, "Upsample\nHxWxC", 1.8, "#C8E6C9"), - (16.5, "Mapa\nsegmentacji", 1.5, "#C8E6C9"), - ] - - y_fcn = 1.5 - for i, (x, label, w, color) in enumerate(layer_info_fcn): - rect = FancyBboxPatch( - (x - w / 2, y_fcn - 0.7), - w, - 1.4, - boxstyle="round,pad=0.05", - facecolor=color, - edgecolor=BLACK, - linewidth=0.8, - ) - ax.add_patch(rect) - ax.text(x, y_fcn, label, ha="center", va="center", fontsize=FS_TINY) - if i < len(layer_info_fcn) - 1: - next_x = layer_info_fcn[i + 1][0] - ax.annotate( - "", - xy=(next_x - layer_info_fcn[i + 1][2] / 2, y_fcn), - xytext=(x + w / 2, y_fcn), - arrowprops={"arrowstyle": "->", "color": GRAY5, "lw": 1}, - ) - - ax.text( - 0.3, - y_fcn, - "FCN:", - fontsize=FS, - fontweight="bold", - color=GREEN_ACCENT, - va="center", - ) - ax.text( - 10.5, - y_fcn + 1.2, - "Conv 1x1:\nkażdy piksel\nosobno x wagi\n(jak FC ale\nzachowuje HxW)", - ha="center", - fontsize=FS_TINY, - color=GREEN_ACCENT, - bbox={ - "boxstyle": "round", - "facecolor": "#C8E6C9", - "edgecolor": GREEN_ACCENT, - "alpha": 0.3, - }, - ) - - # --- Panel 2: What FC and Conv do --- - ax = axes[1] - ax.set_xlim(0, 20) - ax.set_ylim(0, 6) - ax.axis("off") - ax.set_title( - "Co robi warstwa FC? Co robi konwolucja?", fontsize=FS_TITLE, fontweight="bold" - ) - - # FC explanation - rect = FancyBboxPatch( - (0.3, 3.2), - 9, - 2.5, - boxstyle="round,pad=0.15", - facecolor=ACCENT_LIGHT, - edgecolor=ACCENT, - linewidth=1, - ) - ax.add_patch(rect) - ax.text( - 4.8, 5.2, "Fully Connected (FC)", fontsize=FS, fontweight="bold", ha="center" - ) - ax.text( - 4.8, - 4.5, - "KAŻDY neuron połączony z KAŻDYM wejściem\n" - "25 088 wejść x 4 096 neuronów = ~103 MLN wag!\n" - "Traci informację GDZIE (przestrzenną)\n" - "Wymaga STAŁEGO rozmiaru wejścia", - fontsize=FS_TINY, - ha="center", - va="top", - ) - - # Conv explanation - rect = FancyBboxPatch( - (10.3, 3.2), - 9, - 2.5, - boxstyle="round,pad=0.15", - facecolor="#C8E6C9", - edgecolor=GREEN_ACCENT, - linewidth=1, - ) - ax.add_patch(rect) - ax.text(14.8, 5.2, "Konwolucja (Conv)", fontsize=FS, fontweight="bold", ha="center") - ax.text( - 14.8, - 4.5, - 'Filtr (np. 3x3) „jedzie" po obrazie\n' - "Te same wagi dla KAŻDEJ pozycji\n" - "Zachowuje informację GDZIE\n" - "Akceptuje DOWOLNY rozmiar wejścia", - fontsize=FS_TINY, - ha="center", - va="top", - ) - - # Conv 1x1 explanation - rect = FancyBboxPatch( - (3, 0.3), - 14, - 2.2, - boxstyle="round,pad=0.15", - facecolor=GRAY1, - edgecolor=BLACK, - linewidth=1, - ) - ax.add_patch(rect) - ax.text( - 10, - 2.1, - 'Conv 1x1 = „FC per piksel"', - fontsize=FS, - fontweight="bold", - ha="center", - ) - ax.text( - 10, - 1.5, - "Filtr 1x1: patrzy na JEDEN piksel, ale WSZYSTKIE kanały (512→C klas)\n" - "Działa jak FC ale zachowuje mapę HxW → każdy piksel osobno klasyfikowany\n" - "FCN: zamień FC na Conv1x1 → koniec z wymogiem stałego rozmiaru!", - fontsize=FS_TINY, - ha="center", - va="top", - ) - - _save_figure("q23_fc_vs_conv1x1.png") - - -# ============================================================ -# 8. U-NET ARCHITECTURE — Proper U-shaped diagram -# ============================================================ -def generate_unet() -> None: - """Generate unet.""" - _fig, ax = plt.subplots(1, 1, figsize=(10, 6)) - ax.set_xlim(-1, 21) - ax.set_ylim(-1, 12) - ax.axis("off") - ax.set_title( - "U-Net: architektura w kształcie litery U", - fontsize=FS_TITLE + 1, - fontweight="bold", - ) - - # Encoder layers (going DOWN-LEFT) - encoder_layers = [ - (2, 10, 2.5, 1.5, "572x572x1\n(wejście)", 64), - (2, 7.5, 2.2, 1.3, "284x284\nx64", 64), - (2, 5, 1.8, 1.1, "140x140\nx128", 128), - (2, 2.5, 1.5, 1.0, "68x68\nx256", 256), - ] - - # Bottleneck - bottleneck = (8, 0.5, 2.5, 1.2, "32x32x512\n(bottleneck)", 512) - - # Decoder layers (going UP-RIGHT) - decoder_layers = [ - (14, 2.5, 1.5, 1.0, "68x68\nx256", 256), - (14, 5, 1.8, 1.1, "140x140\nx128", 128), - (14, 7.5, 2.2, 1.3, "284x284\nx64", 64), - (14, 10, 2.5, 1.5, "572x572xC\n(mapa seg.)", "C"), - ] - - def draw_block( - ax: Axes, - x: float, - y: float, - w: float, - h: float, - label: str, - color: str, - ) -> None: - """Draw block.""" - rect = FancyBboxPatch( - (x - w / 2, y - h / 2), - w, - h, - boxstyle="round,pad=0.05", - facecolor=color, - edgecolor=BLACK, - linewidth=1.2, - ) - ax.add_patch(rect) - ax.text(x, y, label, ha="center", va="center", fontsize=FS_TINY) - - # Draw encoder - for x, y, w, h, label, _channels in encoder_layers: - draw_block(ax, x, y, w, h, label, ACCENT_LIGHT) - - # Draw arrows down (encoder) - for i in range(len(encoder_layers) - 1): - x1, y1 = encoder_layers[i][0], encoder_layers[i][1] - encoder_layers[i][3] / 2 - x2, y2 = ( - encoder_layers[i + 1][0], - encoder_layers[i + 1][1] + encoder_layers[i + 1][3] / 2, - ) - ax.annotate( - "", - xy=(x2, y2), - xytext=(x1, y1), - arrowprops={"arrowstyle": "->", "color": ACCENT, "lw": 2}, - ) - ax.text( - x1 - 1.7, - (y1 + y2) / 2, - "MaxPool\n2x2\n↓ zmniejsz", - fontsize=FS_TINY, - ha="center", - color=ACCENT, - fontweight="bold", - ) - - # Encoder to bottleneck - x1, y1 = encoder_layers[-1][0], encoder_layers[-1][1] - encoder_layers[-1][3] / 2 - draw_block( - ax, - bottleneck[0], - bottleneck[1], - bottleneck[2], - bottleneck[3], - bottleneck[4], - GRAY2, - ) - ax.annotate( - "", - xy=(bottleneck[0] - bottleneck[2] / 2, bottleneck[1] + bottleneck[3] / 2), - xytext=(x1, y1), - arrowprops={"arrowstyle": "->", "color": ACCENT, "lw": 2}, - ) - - # Bottleneck to decoder - ax.annotate( - "", - xy=( - decoder_layers[0][0] - decoder_layers[0][2] / 2, - decoder_layers[0][1] - decoder_layers[0][3] / 2, - ), - xytext=(bottleneck[0] + bottleneck[2] / 2, bottleneck[1] + bottleneck[3] / 2), - arrowprops={"arrowstyle": "->", "color": RED_ACCENT, "lw": 2}, - ) - - # Draw decoder - for x, y, w, h, label, channels in decoder_layers: - color = "#C8E6C9" if channels != "C" else "#A5D6A7" - draw_block(ax, x, y, w, h, label, color) - - # Draw arrows up (decoder) - for i in range(len(decoder_layers) - 1): - x1, y1 = decoder_layers[i][0], decoder_layers[i][1] + decoder_layers[i][3] / 2 - x2, y2 = ( - decoder_layers[i + 1][0], - decoder_layers[i + 1][1] - decoder_layers[i + 1][3] / 2, - ) - ax.annotate( - "", - xy=(x2, y2), - xytext=(x1, y1), - arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT, "lw": 2}, - ) - ax.text( - x1 + 2, - (y1 + y2) / 2, - "UpConv\n2x2\n↑ zwiększ", - fontsize=FS_TINY, - ha="center", - color=GREEN_ACCENT, - fontweight="bold", - ) - - # Skip connections (horizontal arrows) - for i in range(len(encoder_layers)): - enc = encoder_layers[i] - dec = decoder_layers[len(decoder_layers) - 1 - i] - ax.annotate( - "", - xy=(dec[0] - dec[2] / 2, dec[1]), - xytext=(enc[0] + enc[2] / 2, enc[1]), - arrowprops={ - "arrowstyle": "->", - "color": GRAY5, - "lw": 1.5, - "linestyle": "dashed", - }, - ) - mid_x = (enc[0] + enc[2] / 2 + dec[0] - dec[2] / 2) / 2 - ax.text( - mid_x, - enc[1] + 0.6, - "skip\n(concat)", - fontsize=FS_TINY, - ha="center", - color=GRAY5, - fontweight="bold", - ) - - # Labels - ax.text( - 0, - 11.5, - "ENCODER\n(↓ zmniejsza)", - fontsize=FS, - fontweight="bold", - color=ACCENT, - ha="center", - ) - ax.text( - 17, - 11.5, - "DECODER\n(↑ zwiększa)", - fontsize=FS, - fontweight="bold", - color=GREEN_ACCENT, - ha="center", - ) - ax.text( - 8, - -0.8, - 'Kształt litery „U": encoder schodzi ↓ → bottleneck na dnie → decoder wraca ↑', - fontsize=FS_SMALL, - ha="center", - color=GRAY5, - fontweight="bold", - ) - - # Concatenation explanation - rect = FancyBboxPatch( - (17.5, 3), - 3, - 5, - boxstyle="round,pad=0.15", - facecolor=GRAY1, - edgecolor=GRAY5, - linewidth=1, - linestyle="--", - ) - ax.add_patch(rect) - ax.text( - 19, 7.5, "Concatenation:", fontsize=FS_SMALL, ha="center", fontweight="bold" - ) - ax.text( - 19, - 6.5, - "Encoder: 64 kanały\nDecoder: 64 kanały\n→ concat → 128 kanałów\n\n" - "Jak sklejenie\ndwóch stosów\nkart:", - fontsize=FS_TINY, - ha="center", - ) - ax.text( - 19, - 3.7, - "[enc₁|enc₂|...|dec₁|dec₂|...]", - fontsize=FS_TINY - 1, - ha="center", - fontweight="bold", - color=ACCENT, - ) - - _save_figure("q23_unet_arch.png") - - -# ============================================================ -# 9. RECEPTIVE FIELD — with dilation -# ============================================================ -def generate_receptive_field() -> None: - """Generate receptive field.""" - _fig, axes = plt.subplots(1, 3, figsize=(11, 4)) - - def draw_grid( - ax: Axes, - size: int, - highlight_cells: list[tuple[int, int]], - highlight_color: str, - title: str, - grid_offset: tuple[int, int] = (0, 0), - ) -> None: - """Draw grid.""" - ox, oy = grid_offset - for i in range(size): - for j in range(size): - color = WHITE - if (i, j) in highlight_cells: - color = highlight_color - rect = patches.Rectangle( - (ox + j, oy + size - 1 - i), - 1, - 1, - facecolor=color, - edgecolor=GRAY4, - linewidth=0.5, - ) - ax.add_patch(rect) - ax.set_title(title, fontsize=FS_TITLE, fontweight="bold") - - # --- Panel 1: Standard 3x3 conv receptive field --- - ax = axes[0] - ax.set_xlim(-0.5, 7.5) - ax.set_ylim(-1, 8) - ax.set_aspect("equal") - ax.axis("off") - - # 7x7 input grid - highlight_3x3 = [ - (2, 2), - (2, 3), - (2, 4), - (3, 2), - (3, 3), - (3, 4), - (4, 2), - (4, 3), - (4, 4), - ] - draw_grid(ax, 7, highlight_3x3, ACCENT_LIGHT, "Zwykła conv 3x3") - ax.text( - 3.5, - -0.5, - "RF = 3x3 pikseli", - fontsize=FS, - ha="center", - fontweight="bold", - color=ACCENT, - ) - - # --- Panel 2: Dilated conv (rate=2) --- - ax = axes[1] - ax.set_xlim(-0.5, 7.5) - ax.set_ylim(-1, 8) - ax.set_aspect("equal") - ax.axis("off") - - # 7x7 input grid with dilated highlights - highlight_dilated = [ - (1, 1), - (1, 3), - (1, 5), - (3, 1), - (3, 3), - (3, 5), - (5, 1), - (5, 3), - (5, 5), - ] - draw_grid(ax, 7, highlight_dilated, "#FFCDD2", "Dilated conv 3x3\n(rate=2)") - ax.text( - 3.5, - -0.5, - "RF = 5x5, ale 9 parametrów!", - fontsize=FS, - ha="center", - fontweight="bold", - color=RED_ACCENT, - ) - - # Connect dots to show pattern - dots_x = [1.5, 3.5, 5.5, 1.5, 3.5, 5.5, 1.5, 3.5, 5.5] - dots_y = [5.5, 5.5, 5.5, 3.5, 3.5, 3.5, 1.5, 1.5, 1.5] - ax.scatter(dots_x, dots_y, c=RED_ACCENT, s=30, zorder=5) - - # --- Panel 3: Comparison --- - ax = axes[2] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title( - "Receptive Field\n(pole widzenia neuronu)", fontsize=FS_TITLE, fontweight="bold" - ) - - y = 8.5 - lines = [ - ("RF = ile pikseli WEJŚCIOWYCH", FS, BLACK, "bold"), - ("wpływa na JEDEN piksel wyjścia", FS, BLACK, "bold"), - ("", 0, "", ""), - ("Rate (współczynnik dylatacji):", FS, BLACK, "bold"), - (' rate=1: filtr „dotyka" sąsiadów', FS_SMALL, BLACK, "normal"), - (" rate=2: co drugi piksel → RF = 5x5", FS_SMALL, BLACK, "normal"), - (" rate=3: co trzeci → RF = 7x7", FS_SMALL, BLACK, "normal"), - (" WIĘCEJ kontekstu, TE SAME wagi!", FS_SMALL, GREEN_ACCENT, "bold"), - ("", 0, "", ""), - ("Dlaczego ważne w segmentacji?", FS, BLACK, "bold"), - (" Piksel sam nie wie czym jest.", FS_SMALL, BLACK, "normal"), - (" Potrzebuje KONTEKSTU (otoczenia).", FS_SMALL, BLACK, "normal"), - (" Większe RF → widzi obok budynki", FS_SMALL, BLACK, "normal"), - (' → wie, że TEN piksel to „droga"', FS_SMALL, GREEN_ACCENT, "bold"), - ("", 0, "", ""), - ("Global Average Pooling:", FS, BLACK, "bold"), - (" Mapa HxWxC → 1x1xC", FS_SMALL, BLACK, "normal"), - (" Średnia z CAŁEGO feature map", FS_SMALL, BLACK, "normal"), - (" RF = nieskończone (cały obraz)", FS_SMALL, GREEN_ACCENT, "bold"), - ] - for txt, size, color, weight in lines: - if txt == "": - y -= 0.2 - continue - ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top") - y -= 0.45 - - _save_figure("q23_receptive_field.png") - - -# ============================================================ -# 10. TRANSFORMER / Self-attention / SOTA -# ============================================================ -def generate_transformer() -> None: - """Generate transformer.""" - _fig, axes = plt.subplots(1, 3, figsize=(11, 4)) - - # --- Panel 1: CNN local vs Transformer global --- - ax = axes[0] - ax.set_xlim(-0.5, 8.5) - ax.set_ylim(-1.5, 8.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("CNN: widzi LOKALNIE", fontsize=FS_TITLE, fontweight="bold") - - # Draw 8x8 grid - for i in range(8): - for j in range(8): - color = WHITE - if (_HIGHLIGHT_START <= i <= _HIGHLIGHT_END - and _HIGHLIGHT_START <= j <= _HIGHLIGHT_END): - color = ACCENT_LIGHT - rect = patches.Rectangle( - (j, 7 - i), 1, 1, facecolor=color, edgecolor=GRAY3, linewidth=0.3 - ) - ax.add_patch(rect) - - # Highlight center - rect = patches.Rectangle( - (4, 4), 1, 1, facecolor=RED_ACCENT, edgecolor=BLACK, linewidth=1.5, alpha=0.7 - ) - ax.add_patch(rect) - ax.text( - 4.5, - 4.5, - "?", - ha="center", - va="center", - fontsize=FS, - fontweight="bold", - color=WHITE, - ) - ax.text( - 4.5, - -0.8, - "Filtr 3x3 widzi tylko\n9 sąsiednich pikseli", - fontsize=FS_SMALL, - ha="center", - color=ACCENT, - ) - - # --- Panel 2: Transformer global --- - ax = axes[1] - ax.set_xlim(-0.5, 8.5) - ax.set_ylim(-1.5, 8.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Transformer: widzi GLOBALNIE", fontsize=FS_TITLE, fontweight="bold") - - # Draw 8x8 grid all highlighted - for i in range(8): - for j in range(8): - color = "#FFCDD2" - rect = patches.Rectangle( - (j, 7 - i), 1, 1, facecolor=color, edgecolor=GRAY3, linewidth=0.3 - ) - ax.add_patch(rect) - - rect = patches.Rectangle( - (4, 4), 1, 1, facecolor=RED_ACCENT, edgecolor=BLACK, linewidth=1.5, alpha=0.9 - ) - ax.add_patch(rect) - ax.text( - 4.5, - 4.5, - "?", - ha="center", - va="center", - fontsize=FS, - fontweight="bold", - color=WHITE, - ) - ax.text( - 4.5, - -0.8, - 'Self-attention „pyta"\nALL 64 piksele naraz', - fontsize=FS_SMALL, - ha="center", - color=RED_ACCENT, - ) - - # --- Panel 3: SOTA + Transformer explanation --- - ax = axes[2] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title("Transformer & SOTA", fontsize=FS_TITLE, fontweight="bold") - - y = 9.2 - lines = [ - ("Transformer:", FS, BLACK, "bold"), - (" Architektura z 2017 (Vaswani et al.)", FS_SMALL, BLACK, "normal"), - (" Oryginalnie do NLP (tłumaczenie)", FS_SMALL, BLACK, "normal"), - (" Kluczowy mechanizm: SELF-ATTENTION", FS_SMALL, ACCENT, "bold"), - ("", 0, "", ""), - ("Self-attention w skrócie:", FS, BLACK, "bold"), - (" Każdy piksel tworzy trzy wektory:", FS_SMALL, BLACK, "normal"), - (' Q (Query — „czego szukam?")', FS_SMALL, ACCENT, "normal"), - (' K (Key — „co oferuję innych")', FS_SMALL, RED_ACCENT, "normal"), - (' V (Value — „moja wartość")', FS_SMALL, GREEN_ACCENT, "normal"), - (" Attention = softmax(Q·Kᵀ/√d)·V", FS_SMALL, BLACK, "bold"), - (" Koszt: O(n²) — n=liczba pikseli", FS_SMALL, RED_ACCENT, "normal"), - ("", 0, "", ""), - ("SOTA = State Of The Art:", FS, BLACK, "bold"), - (" Najlepszy znany wynik na benchmarku", FS_SMALL, BLACK, "normal"), - (' Np. „mIoU 85.1% na ADE20K = SOTA"', FS_SMALL, BLACK, "normal"), - (" Ciągle się zmienia (nowy paper", FS_SMALL, GRAY5, "normal"), - (" → nowy SOTA)", FS_SMALL, GRAY5, "normal"), - ] - for txt, size, color, weight in lines: - if txt == "": - y -= 0.15 - continue - ax.text(0.3, y, txt, fontsize=size, color=color, fontweight=weight, va="top") - y -= 0.45 - - _save_figure("q23_transformer_attention.png") - - -def _draw_region_growing_grid(ax: Axes) -> None: - """Draw panel 2: region growing step-by-step grid.""" - ax.set_xlim(-0.5, 6.5) - ax.set_ylim(-1.5, 7.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Region Growing: krok po kroku", - fontsize=FS_TITLE, fontweight="bold", - ) - - pixel_grid = np.array([ - [150, 153, 148, 200, 210, 205], - [147, 155, 152, 195, 208, 200], - [145, 148, 160, 190, 195, 210], - [200, 195, 190, 155, 148, 150], - [210, 205, 200, 150, 152, 145], - [215, 208, 195, 148, 147, 155], - ]) - region_mask = np.array([ - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - ]) - - for i in range(6): - for j in range(6): - v = pixel_grid[i, j] - if region_mask[i, j] == 1 and v < _BRIGHT_THRESHOLD: - cell_color = ACCENT_LIGHT - elif region_mask[i, j] == 1: - cell_color = GRAY2 - else: - cell_color = WHITE - if i == 1 and j == 1: - cell_color = "#FFD54F" - rect = patches.Rectangle( - (j, 5 - i), 1, 1, - facecolor=cell_color, edgecolor=GRAY4, - linewidth=0.5, - ) - ax.add_patch(rect) - ax.text( - j + 0.5, 5 - i + 0.5, str(v), - ha="center", va="center", - fontsize=FS_TINY, fontweight="bold", - ) - - ax.annotate( - "SEED\n(155)", xy=(1.5, 4.5), - fontsize=FS_SMALL, ha="center", - color=RED_ACCENT, fontweight="bold", - arrowprops={"arrowstyle": "->", "color": RED_ACCENT}, - xytext=(-0.5, 7), - ) - ax.text( - 3, -0.8, - "Próg = 20\nNiebieski = region (|val - seed| < 20)", - fontsize=FS_TINY, ha="center", color=ACCENT, - ) - - -def _draw_bfs_expansion(ax: Axes) -> None: - """Draw panel 3: BFS expansion visualization.""" - ax.set_xlim(-0.5, 6.5) - ax.set_ylim(-1.5, 7.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Rosnący region (BFS)", - fontsize=FS_TITLE, fontweight="bold", - ) - - wave_colors = ["#FFD54F", "#FFF176", "#FFF9C4", ACCENT_LIGHT, "#B3D4FC"] - wave_labels = ["Seed", "Fala 1", "Fala 2", "Fala 3", "Fala 4"] - waves = [ - [(1, 1)], - [(0, 1), (1, 0), (1, 2), (2, 1)], - [(0, 0), (0, 2), (2, 0), (2, 2)], - ] - - for i in range(6): - for j in range(6): - cell_color = WHITE - for w_idx, wave in enumerate(waves): - if (i, j) in wave: - cell_color = wave_colors[w_idx] - rect = patches.Rectangle( - (j, 5 - i), 1, 1, - facecolor=cell_color, edgecolor=GRAY4, - linewidth=0.5, - ) - ax.add_patch(rect) - - seed_x, seed_y = 1.5, 4.5 - for dx, dy, _label in [ - (0, 1, ""), (0, -1, ""), (1, 0, ""), (-1, 0, ""), - ]: - ax.annotate( - "", - xy=(seed_x + dx * 0.7, seed_y + dy * 0.7), - xytext=(seed_x, seed_y), - arrowprops={ - "arrowstyle": "->", "color": RED_ACCENT, "lw": 1.2, - }, - ) - - ax.text( - 3, -0.5, - "BFS: sprawdzaj sąsiadów,\ndodawaj podobne do kolejki", - fontsize=FS_TINY, ha="center", color=GRAY5, - ) - - for w_idx, (wave_color, label) in enumerate( - zip(wave_colors[:3], wave_labels[:3], strict=False) - ): - rect = patches.Rectangle( - (4, 6.5 - w_idx * 0.7), 0.5, 0.5, - facecolor=wave_color, edgecolor=GRAY4, linewidth=0.5, - ) - ax.add_patch(rect) - ax.text( - 4.8, 6.75 - w_idx * 0.7, label, - fontsize=FS_TINY, va="center", - ) - - -# ============================================================ -# 11. REGION GROWING — seed selection + BFS -# ============================================================ -def generate_region_growing() -> None: - """Generate region growing.""" - _fig, axes = plt.subplots(1, 3, figsize=(11, 4.2)) - - # --- Panel 1: Manual vs automatic seed --- - ax = axes[0] - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title("Seed: ręcznie vs automatycznie", fontsize=FS_TITLE, fontweight="bold") - - y = 9.2 - lines = [ - ("Ręczny seed:", FS, ACCENT, "bold"), - (" Użytkownik klika na obraz", FS_SMALL, BLACK, "normal"), - (' → „tu jest obiekt, od tego zacznij"', FS_SMALL, BLACK, "normal"), - (" Użycie: segmentacja interaktywna", FS_SMALL, GRAY5, "normal"), - (" (np. Photoshop — magic wand tool)", FS_SMALL, GRAY5, "normal"), - ("", 0, "", ""), - ("Automatyczny seed:", FS, RED_ACCENT, "bold"), - (" 1. Histogram → lokalne maxima", FS_SMALL, BLACK, "normal"), - (" (najczęstsza jasność → seed)", FS_SMALL, GRAY5, "normal"), - (" 2. Grid: siatka co N pikseli", FS_SMALL, BLACK, "normal"), - (" (np. seed co 50 px → 100 seedów)", FS_SMALL, GRAY5, "normal"), - (" 3. Losowe próbkowanie", FS_SMALL, BLACK, "normal"), - (" 4. Ekstrema lokalne gradientu", FS_SMALL, BLACK, "normal"), - ("", 0, "", ""), - ("Dlaczego OR?", FS, GREEN_ACCENT, "bold"), - (" Ręczny → precyzyjny, ale wolny", FS_SMALL, BLACK, "normal"), - (" Auto → szybki, ale over-segmentation", FS_SMALL, BLACK, "normal"), - ] - for txt, size, color, weight in lines: - if txt == "": - y -= 0.15 - continue - ax.text(0.3, y, txt, fontsize=size, color=color, fontweight=weight, va="top") - y -= 0.45 - - # --- Panel 2: Region growing step by step --- - _draw_region_growing_grid(axes[1]) - - # --- Panel 3: BFS expansion --- - _draw_bfs_expansion(axes[2]) - - _save_figure("q23_region_growing.png") - - -def _draw_otsu_variance_and_pseudocode( - ax_var: Axes, - ax_code: Axes, - img: np.ndarray, -) -> int: - """Draw panels 4 and 5: Otsu variance plot and pseudocode.""" - thresholds = range(10, 245) - variances = [] - for t in thresholds: - c0 = img[img <= t].ravel() - c1 = img[img > t].ravel() - if len(c0) == 0 or len(c1) == 0: - variances.append(np.nan) - continue - w0 = len(c0) / len(img.ravel()) - w1 = len(c1) / len(img.ravel()) - var = w0 * np.var(c0) + w1 * np.var(c1) - variances.append(var) - - ax_var.plot(list(thresholds), variances, color=ACCENT, linewidth=1.5) - best_t = list(thresholds)[np.nanargmin(variances)] - ax_var.axvline( - x=best_t, color=RED_ACCENT, - linewidth=1.5, linestyle="--", - label=f"Otsu T={best_t}", - ) - ax_var.scatter( - [best_t], [np.nanmin(variances)], - c=RED_ACCENT, s=60, zorder=5, - ) - ax_var.set_xlabel("Próg T", fontsize=FS_SMALL) - ax_var.set_ylabel("σ² wewnątrzklasowa", fontsize=FS_SMALL) - ax_var.set_title( - "Krok 4: Otsu szuka min σ²", - fontsize=FS, fontweight="bold", - ) - ax_var.legend(fontsize=FS_TINY) - - ax_code.set_xlim(0, 10) - ax_code.set_ylim(0, 10) - ax_code.axis("off") - ax_code.set_title("Pseudokod Otsu", fontsize=FS, fontweight="bold") - - code_lines = [ - "best_T = 0", "min_var = ∞", "", - "for T in 0..255:", - " c0 = piksele z jasność ≤ T", - " c1 = piksele z jasność > T", - " w0 = len(c0) / len(all)", - " w1 = len(c1) / len(all)", - " var = w0·var(c0) + w1·var(c1)", - " if var < min_var:", - " min_var = var", " best_T = T", "", - "return best_T # optymalny próg", - ] - for i, line in enumerate(code_lines): - txt_color = ( - ACCENT - if "best_T = T" in line or "return" in line - else BLACK - ) - ax_code.text( - 0.5, 9.5 - i * 0.65, line, - fontsize=FS_TINY, fontfamily="monospace", - color=txt_color, - fontweight="bold" if txt_color == ACCENT else "normal", - ) - return int(best_t) - - -# ============================================================ -# 12. DIY THRESHOLDING — Step-by-step example -# ============================================================ -def generate_diy_thresholding() -> None: - """Generate diy thresholding.""" - _fig, axes = plt.subplots(2, 3, figsize=(11, 7)) - - # Create a simple synthetic image: dark circle on bright background - size = 64 - img = np.ones((size, size)) * 200 # bright background - yy, xx = np.mgrid[:size, :size] - mask = ((xx - 32) ** 2 + (yy - 32) ** 2) < 15**2 - img[mask] = 60 # dark circle - # Add some noise - img += rng.normal(0, 10, img.shape) - img = np.clip(img, 0, 255) - - # --- Panel 1: Original image --- - ax = axes[0, 0] - ax.imshow(img, cmap="gray", vmin=0, vmax=255) - ax.set_title("Krok 1: obraz wejściowy", fontsize=FS, fontweight="bold") - ax.axis("off") - ax.text(32, -3, "64x64 pikseli, szare", fontsize=FS_TINY, ha="center") - - # --- Panel 2: Histogram --- - ax = axes[0, 1] - counts, _bins, _ = ax.hist( - img.ravel(), bins=50, color=GRAY3, edgecolor=GRAY5, linewidth=0.5 - ) - ax.axvline( - x=128, color=RED_ACCENT, linewidth=2, linestyle="--", label="T=128 (Otsu)" - ) - ax.set_xlabel("Jasność", fontsize=FS_SMALL) - ax.set_ylabel("Piksele", fontsize=FS_SMALL) - ax.set_title("Krok 2: histogram\n(bimodalny!)", fontsize=FS, fontweight="bold") - ax.legend(fontsize=FS_TINY) - ax.annotate( - "Garb 1\n(obiekt)", - xy=(60, max(counts) * 0.5), - fontsize=FS_TINY, - ha="center", - color=ACCENT, - fontweight="bold", - ) - ax.annotate( - "Garb 2\n(tło)", - xy=(200, max(counts) * 0.5), - fontsize=FS_TINY, - ha="center", - color=RED_ACCENT, - fontweight="bold", - ) - - # --- Panel 3: Thresholding result --- - ax = axes[0, 2] - binary = (img > _OTSU_THRESHOLD).astype(float) - ax.imshow(binary, cmap="gray", vmin=0, vmax=1) - ax.set_title("Krok 3: progowanie T=128", fontsize=FS, fontweight="bold") - ax.axis("off") - ax.text(32, -3, "Biały = tło, Czarny = obiekt", fontsize=FS_TINY, ha="center") - - # --- Panels 4+5: Otsu variance plot + pseudocode --- - best_t = _draw_otsu_variance_and_pseudocode( - axes[1, 0], axes[1, 1], img, - ) - - # --- Panel 6: Final result with Otsu --- - ax = axes[1, 2] - binary_otsu = (img > best_t).astype(float) - ax.imshow(binary_otsu, cmap="gray", vmin=0, vmax=1) - ax.set_title(f"Krok 5: wynik Otsu (T={best_t})", fontsize=FS, fontweight="bold") - ax.axis("off") - ax.text( - 32, - -3, - "Automatyczny próg!", - fontsize=FS_TINY, - ha="center", - color=GREEN_ACCENT, - fontweight="bold", - ) - - _save_figure("q23_diy_thresholding.png") - - -def _draw_unet_layer_stack( - ax: Axes, - layer_sizes: list[tuple[int, int]], - *, - face_color: str, - edge_color: str, - arrow_color: str, - arrow_label: str, - add_skip: bool = False, -) -> None: - """Draw encoder or decoder layer stack for DIY U-Net.""" - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - - y_pos = 8.5 - for i, (s, c) in enumerate(layer_sizes): - w = s / 64 * 4 - h = 0.8 - rect = FancyBboxPatch( - (5 - w / 2, y_pos), w, h, - boxstyle="round,pad=0.05", - facecolor=face_color, edgecolor=edge_color, - linewidth=1, - ) - ax.add_patch(rect) - label = f"{s}x{s}x{c}" - if add_skip and i < len(layer_sizes) - 1: - label += " + skip!" - ax.text( - 5, y_pos + h / 2, label, - ha="center", va="center", - fontsize=FS_SMALL, fontweight="bold", - ) - if i < len(layer_sizes) - 1: - ax.annotate( - "", xy=(5, y_pos - 0.3), xytext=(5, y_pos), - arrowprops={ - "arrowstyle": "->", "color": arrow_color, - "lw": 1.5, - }, - ) - ax.text( - 7, y_pos - 0.15, arrow_label, - fontsize=FS_TINY, color=arrow_color, - ) - y_pos -= 2.2 - - -def _draw_unet_pseudocode(ax: Axes) -> None: - """Draw panel 6: U-Net pseudocode.""" - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title("Pseudokod U-Net", fontsize=FS, fontweight="bold") - - code_lines = [ - "# ENCODER", - "e1 = conv_block(input, 64) # 64x64", - "e2 = conv_block(pool(e1), 128) # 32x32", - "e3 = conv_block(pool(e2), 256) # 16x16", - "", - "# BOTTLENECK", - "b = conv_block(pool(e3), 512) # 8x8", - "", - "# DECODER + SKIP", - "d3 = conv_block(concat(", - " upconv(b), e3), 256) # 16x16", - "d2 = conv_block(concat(", - " upconv(d3), e2), 128) # 32x32", - "d1 = conv_block(concat(", - " upconv(d2), e1), 64) # 64x64", - "", - "output = conv_1x1(d1, n_classes)", - ] - for i, line in enumerate(code_lines): - txt_color = ( - ACCENT - if "concat" in line - else (GREEN_ACCENT if "output" in line else BLACK) - ) - ax.text( - 0.3, 9.5 - i * 0.55, line, - fontsize=FS_TINY, fontfamily="monospace", - color=txt_color, - ) - - -# ============================================================ -# 13. DIY U-NET — Simplified step-by-step -# ============================================================ -def generate_diy_unet() -> None: - """Generate diy unet.""" - fig, axes = plt.subplots(2, 3, figsize=(11, 7)) - - size = 64 - - # Create synthetic image with two regions - img = np.ones((size, size, 3), dtype=np.uint8) * 200 # bright bg - # Dark region (object 1) - yy, xx = np.mgrid[:size, :size] - mask1 = ((xx - 20) ** 2 + (yy - 30) ** 2) < 12**2 - img[mask1] = [60, 60, 60] - # Medium region (object 2) - mask2 = ((xx - 45) ** 2 + (yy - 25) ** 2) < 8**2 - img[mask2] = [120, 120, 120] - - gt = np.zeros((size, size), dtype=np.uint8) - gt[mask1] = 1 # class 1 - gt[mask2] = 2 # class 2 - - # --- Panel 1: Input image --- - ax = axes[0, 0] - ax.imshow(img) - ax.set_title("Krok 1: obraz RGB\n64x64x3", fontsize=FS, fontweight="bold") - ax.axis("off") - - # --- Panel 2: Encoder shrinks --- - ax = axes[0, 1] - ax.set_title("Krok 2: Encoder ZMNIEJSZA", fontsize=FS, fontweight="bold") - _draw_unet_layer_stack( - ax, [(64, 3), (32, 64), (16, 128), (8, 256)], - face_color=ACCENT_LIGHT, edge_color=ACCENT, - arrow_color=ACCENT, arrow_label="Conv+Pool", - ) - ax.text( - 5, 0.3, - "Wyciąga cechy:\nkrawędzie → tekstury → obiekty", - ha="center", fontsize=FS_TINY, color=GRAY5, - ) - - # --- Panel 3: Bottleneck --- - ax = axes[0, 2] - # Show feature maps at bottleneck (abstract) - ax.set_xlim(0, 10) - ax.set_ylim(0, 10) - ax.axis("off") - ax.set_title( - "Krok 3: Bottleneck\n(najbardziej abstrakcyjne cechy)", - fontsize=FS, - fontweight="bold", - ) - - # Show small abstract feature maps - for k in range(4): - small = rng.random((4, 4)) - ax_inset = fig.add_axes( - [0.68 + (k % 2) * 0.08, 0.72 - (k // 2) * 0.1, 0.06, 0.06] - ) - ax_inset.imshow(small, cmap="gray") - ax_inset.axis("off") - - ax.text( - 5, - 5, - '8x8x256\n\nMałe mapy, ale DUŻO kanałów\nKażdy kanał = jedna „cecha"\n' - '(np. kanał 42 = „wykrył koło"\n kanał 78 = „wykrył krawędź")\n\n' - "Wie CO jest na obrazie\nale nie wie GDZIE dokładnie", - ha="center", - va="center", - fontsize=FS_SMALL, - bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3}, - ) - - # --- Panel 4: Decoder enlarges --- - ax = axes[1, 0] - ax.set_title( - "Krok 4: Decoder ZWIĘKSZA\n(+ skip connections!)", - fontsize=FS, - fontweight="bold", - ) - _draw_unet_layer_stack( - ax, [(8, 256), (16, 128), (32, 64), (64, 3)], - face_color="#C8E6C9", edge_color=GREEN_ACCENT, - arrow_color=GREEN_ACCENT, arrow_label="UpConv+Concat", - add_skip=True, - ) - ax.text( - 5, 0.3, - "Odtwarza rozdzielczość:\nskip → przywraca krawędzie", - ha="center", fontsize=FS_TINY, color=GRAY5, - ) - - # --- Panel 5: Output segmentation map --- - ax = axes[1, 1] - cmap = plt.cm.colors.ListedColormap([WHITE, ACCENT_LIGHT, "#FFCDD2"]) - ax.imshow(gt, cmap=cmap, interpolation="nearest") - ax.set_title( - "Krok 5: mapa segmentacji\n64x64 (3 klasy)", fontsize=FS, fontweight="bold" - ) - ax.axis("off") - ax.text(20, -3, "Tło=0, obiekt A=1, obiekt B=2", fontsize=FS_TINY, ha="center") - - # --- Panel 6: Summary pseudocode --- - _draw_unet_pseudocode(axes[1, 2]) - - _save_figure("q23_diy_unet.png") - - -# ============================================================ -# 14. MNEMONICS — Visual mnemonic summary -# ============================================================ -def generate_mnemonics() -> None: - """Generate mnemonics.""" - _fig, ax = plt.subplots(1, 1, figsize=(10, 8)) - ax.set_xlim(0, 20) - ax.set_ylim(0, 16) - ax.axis("off") - ax.set_title( - "Mnemoniki — segmentacja obrazu", fontsize=FS_TITLE + 2, fontweight="bold" - ) - - def draw_card( - ax: Axes, - x: float, - y: float, - w: float, - h: float, - title: str, - mnemonic: str, - color: str, - detail: str = "", - ) -> None: - """Draw card.""" - rect = FancyBboxPatch( - (x, y), - w, - h, - boxstyle="round,pad=0.15", - facecolor=color, - edgecolor=BLACK, - linewidth=1, - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + h - 0.3, - title, - ha="center", - va="top", - fontsize=FS, - fontweight="bold", - ) - ax.text( - x + w / 2, - y + h / 2 - 0.1, - mnemonic, - ha="center", - va="center", - fontsize=FS_SMALL, - fontstyle="italic", - color=GRAY6, - ) - if detail: - ax.text( - x + w / 2, - y + 0.4, - detail, - ha="center", - va="bottom", - fontsize=FS_TINY, - color=GRAY5, - ) - - # Title: STRATEGIE KLASYCZNE - ax.text( - 5, - 15.5, - "STRATEGIE KLASYCZNE", - fontsize=FS_TITLE, - fontweight="bold", - color=ACCENT, - ha="center", - ) - - cards_classic = [ - ( - 0.2, - 12.5, - 4.5, - 2.5, - "Thresholding", - '„PRÓG na bramce"\nPrzepuszcza > T,\nblokuje ≤ T', - ACCENT_LIGHT, - "jasne=1, ciemne=0", - ), - ( - 5, - 12.5, - 4.5, - 2.5, - "Otsu", - '„AUTO-bramkarz"\nSam dobiera próg\nmin σ² wewnątrz', - ACCENT_LIGHT, - "histogram bimodalny", - ), - ( - 0.2, - 9.5, - 4.5, - 2.5, - "Region Growing", - '„PLAMA rozlana"\nSeed → BFS po\npodobnych sąsiadach', - ACCENT_LIGHT, - "jak atrament na papierze", - ), - ( - 5, - 9.5, - 4.5, - 2.5, - "Watershed", - '„ZALEWANIE terenu"\nDoliny=obiekty\nGranie=granice', - ACCENT_LIGHT, - "woda + geography", - ), - ( - 0.2, - 6.5, - 4.5, - 2.5, - "Mean Shift", - '„KULKI toczą się"\nKażda → max gęstości\nBez K!', - ACCENT_LIGHT, - "bandwidth = okno", - ), - ( - 5, - 6.5, - 4.5, - 2.5, - "Normalized Cuts", - '„CIĘCIE sznurków"\nGraf: tnij słabe\nkrawędzie (O(n³)!)', - ACCENT_LIGHT, - "eigenvector problem", - ), - ] - - for args in cards_classic: - draw_card(ax, *args) - - # Title: SIECI NEURONOWE - ax.text( - 15, - 15.5, - "SIECI NEURONOWE", - fontsize=FS_TITLE, - fontweight="bold", - color=GREEN_ACCENT, - ha="center", - ) - - cards_nn = [ - ( - 10.5, - 12.5, - 4.5, - 2.5, - "FCN (2015)", - '„FC → Conv 1x1"\nPierwsza end-to-end\nDowolny rozmiar', - "#C8E6C9", - "skip connections", - ), - ( - 15.3, - 12.5, - 4.5, - 2.5, - "U-Net (2015)", - '„Litera U"\nEncoder↓ Decoder↑\nSkip = concat', - "#C8E6C9", - "medycyna, małe dane", - ), - ( - 10.5, - 9.5, - 4.5, - 2.5, - "DeepLab v3+", - '„DZIURY w filtrze"\nAtrous conv (rate)\nASPP multi-scale', - "#C8E6C9", - "à trous = z dziurami", - ), - ( - 15.3, - 9.5, - 4.5, - 2.5, - "Transformer", - '„WSZYSCY ze\nWSZYSTKIMI"\nSelf-attention O(n²)', - "#C8E6C9", - "SegFormer, Mask2Former", - ), - ] - - for args in cards_nn: - draw_card(ax, *args) - - # Metryki - ax.text( - 10, - 8.3, - "METRYKI I LOSS", - fontsize=FS_TITLE, - fontweight="bold", - color=RED_ACCENT, - ha="center", - ) - - cards_metrics = [ - ( - 10.5, - 6.5, - 4.5, - 1.6, - "mIoU", - '„Nakładka / Suma"\nIoU = A∩B / A\u222aB', - "#FFCDD2", - "", - ), - ( - 15.3, - 6.5, - 4.5, - 1.6, - "Dice / Focal", - '„Dice=2·nakładka"\nFocal=trudne px', - "#FFCDD2", - "", - ), - ] - - for args in cards_metrics: - draw_card(ax, *args) - - # Master mnemonic at bottom - rect = FancyBboxPatch( - (1, 0.3), - 18, - 5.5, - boxstyle="round,pad=0.2", - facecolor=GRAY1, - edgecolor=BLACK, - linewidth=1.5, - ) - ax.add_patch(rect) - ax.text( - 10, - 5.3, - "SUPER-MNEMONIK: kolejność algorytmów segmentacji", - ha="center", - fontsize=FS, - fontweight="bold", - ) - ax.text( - 10, - 4.5, - '„TORW-MN FUD-T"', - ha="center", - fontsize=FS_TITLE + 2, - fontweight="bold", - color=RED_ACCENT, - ) - ax.text( - 10, - 3.5, - "Klasyczne: Thresholding → Otsu → Region" - " growing → Watershed → Mean shift → Norm. cuts", - ha="center", - fontsize=FS_SMALL, - ) - ax.text( - 10, - 2.8, - "Neuronowe: FCN → U-Net → DeepLab → Transformer", - ha="center", - fontsize=FS_SMALL, - ) - ax.text( - 10, - 1.8, - '„Turyści Oglądają Rzekę, Wodospad,' - ' Morze, Nurt — Fotografują Uroczy' - ' Dwór Tajemnic"', - ha="center", - fontsize=FS_SMALL, - fontstyle="italic", - color=ACCENT, - ) - ax.text( - 10, - 1.0, - "Klasyczne: proste→auto→BFS→flood→" - "gęstość→graf | Neuronowe:" - " FC→U-skip→dilated→attention", - ha="center", - fontsize=FS_TINY, - color=GRAY5, - ) - - _save_figure("q23_mnemonics.png") - - # ============================================================ # MAIN # ============================================================ diff --git a/python_pkg/praca_magisterska_video/generate_images/generate_q24_diagrams.py b/python_pkg/praca_magisterska_video/generate_images/generate_q24_diagrams.py old mode 100755 new mode 100644 index 78dded6..f57b4aa --- a/python_pkg/praca_magisterska_video/generate_images/generate_q24_diagrams.py +++ b/python_pkg/praca_magisterska_video/generate_images/generate_q24_diagrams.py @@ -2,2274 +2,72 @@ """Generate ALL diagrams for PYTANIE 24: Detekcja obiektów. Monochrome, A4-printable PNGs (300 DPI). +Re-exports all diagram generators from submodules. """ from __future__ import annotations import logging from pathlib import Path -from typing import TYPE_CHECKING +import sys -import matplotlib as mpl +# Ensure sibling modules are importable when run as a script. +sys.path.insert(0, str(Path(__file__).resolve().parent)) -mpl.use("Agg") +from _q24_common import OUTPUT_DIR +from _q24_fpn_tasks_cnn import ( + draw_anchor_boxes, + draw_cnn_architecture, + draw_detection_tasks, + draw_fpn, +) +from _q24_haar_integral_svm import ( + draw_haar_features, + draw_integral_image, + draw_svm_hyperplane, +) +from _q24_hog_classical import ( + draw_hog_gradient_steps, + draw_hog_svm_pipeline, + draw_viola_jones_cascade, +) +from _q24_iou_nms_detector import ( + draw_detector_from_classifier, + draw_iou_diagram, + draw_nms_steps, +) +from _q24_modern_pipelines import ( + draw_detr_pipeline, + draw_roi_pooling, + draw_sliding_window, + draw_two_vs_one_stage, +) +from _q24_rcnn_yolo import draw_rcnn_evolution, draw_yolo_grid -import matplotlib.patches as mpatches -from matplotlib.patches import FancyBboxPatch -import matplotlib.pyplot as plt -import numpy as np - -if TYPE_CHECKING: - from matplotlib.axes import Axes - from matplotlib.figure import Figure +__all__ = [ + "draw_anchor_boxes", + "draw_cnn_architecture", + "draw_detection_tasks", + "draw_detector_from_classifier", + "draw_detr_pipeline", + "draw_fpn", + "draw_haar_features", + "draw_hog_gradient_steps", + "draw_hog_svm_pipeline", + "draw_integral_image", + "draw_iou_diagram", + "draw_nms_steps", + "draw_rcnn_evolution", + "draw_roi_pooling", + "draw_sliding_window", + "draw_svm_hyperplane", + "draw_two_vs_one_stage", + "draw_viola_jones_cascade", + "draw_yolo_grid", +] _logger = logging.getLogger(__name__) -rng = np.random.default_rng(42) - -DPI = 300 -BG = "white" -LN = "black" -FS = 8 -FS_TITLE = 11 -FS_SMALL = 6.5 -FS_LABEL = 9 -OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") -Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) - -GRAY1 = "#E8E8E8" -GRAY2 = "#D0D0D0" -GRAY3 = "#B8B8B8" -GRAY4 = "#F5F5F5" -GRAY5 = "#C0C0C0" - -_PIXEL_BRIGHT_THRESH = 127 -_GRADIENT_BRIGHT_THRESH = 100 -_DATA_BRIGHT_THRESH = 5 -_II_BRIGHT_THRESH = 25 -_DOTS_STAGE_IDX = 2 - - -def draw_box( - ax: Axes, - x: float, - y: float, - w: float, - h: float, - text: str, - *, - fill: str = "white", - lw: float = 1.2, - fontsize: float = FS, - fontweight: str = "normal", - ha: str = "center", - va: str = "center", - rounded: bool = True, - edgecolor: str = LN, - linestyle: str = "-", -) -> None: - """Draw box.""" - if rounded: - rect = FancyBboxPatch( - (x, y), - w, - h, - boxstyle="round,pad=0.05", - lw=lw, - edgecolor=edgecolor, - facecolor=fill, - linestyle=linestyle, - ) - else: - rect = mpatches.Rectangle( - (x, y), - w, - h, - lw=lw, - edgecolor=edgecolor, - facecolor=fill, - linestyle=linestyle, - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + h / 2, - text, - ha=ha, - va=va, - fontsize=fontsize, - fontweight=fontweight, - wrap=True, - ) - - -def draw_arrow( - ax: Axes, - x1: float, - y1: float, - x2: float, - y2: float, - *, - lw: float = 1.2, - style: str = "->", - color: str = LN, -) -> None: - """Draw arrow.""" - ax.annotate( - "", - xy=(x2, y2), - xytext=(x1, y1), - arrowprops={"arrowstyle": style, "color": color, "lw": lw}, - ) - - -def save_fig(fig: Figure, name: str) -> None: - """Save fig.""" - path = str(Path(OUTPUT_DIR) / name) - fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15) - plt.close(fig) - _logger.info(" Saved: %s", path) - - -def draw_table( - ax: Axes, - headers: list[str], - rows: list[list[str]], - x0: float, - y0: float, - col_widths: list[float], - *, - row_h: float = 0.4, - header_fill: str = GRAY2, - row_fills: list[str] | None = None, - fontsize: float = FS, - header_fontsize: float | None = None, -) -> None: - """Draw table.""" - if header_fontsize is None: - header_fontsize = fontsize - len(headers) - cx = x0 - for j, hdr in enumerate(headers): - draw_box( - ax, - cx, - y0, - col_widths[j], - row_h, - hdr, - fill=header_fill, - fontsize=header_fontsize, - fontweight="bold", - rounded=False, - ) - cx += col_widths[j] - for i, row in enumerate(rows): - cy = y0 - (i + 1) * row_h - cx = x0 - fill = GRAY4 if (i % 2 == 0) else "white" - if row_fills and i < len(row_fills): - fill = row_fills[i] - for j, cell in enumerate(row): - fw = "bold" if j == 0 else "normal" - draw_box( - ax, - cx, - cy, - col_widths[j], - row_h, - cell, - fill=fill, - fontsize=fontsize, - fontweight=fw, - rounded=False, - ) - cx += col_widths[j] - - -# ============================================================ -# 1. HOG + SVM Pipeline -# ============================================================ -def draw_hog_svm_pipeline() -> None: - """Draw hog svm pipeline.""" - fig, ax = plt.subplots(figsize=(10, 4.5)) - ax.set_xlim(-0.5, 10.5) - ax.set_ylim(-1, 4.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "HOG + SVM — pipeline detekcji pieszych", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - # Step 1: Image with sliding window - ax.add_patch( - mpatches.Rectangle((0, 1.5), 2, 2, lw=1.5, edgecolor=LN, facecolor=GRAY1) - ) - ax.text(1, 2.5, "Obraz\nwejściowy", ha="center", va="center", fontsize=FS) - # sliding window overlay - ax.add_patch( - mpatches.Rectangle( - (0.3, 1.8), - 0.8, - 1.2, - lw=1.5, - edgecolor="black", - facecolor="none", - linestyle="--", - ) - ) - ax.text( - 0.7, - 1.35, - "okno 64x128", - ha="center", - va="center", - fontsize=FS_SMALL, - style="italic", - ) - - draw_arrow(ax, 2.1, 2.5, 2.8, 2.5, lw=1.5) - ax.text(2.45, 2.75, "①", ha="center", fontsize=FS_LABEL, fontweight="bold") - - # Step 2: Gradient computation - draw_box( - ax, 2.9, 1.8, 1.6, 1.4, "Oblicz\ngradienty\nGx, Gy", fill=GRAY4, fontsize=FS - ) - ax.text( - 3.7, 1.55, "kierunek + siła", ha="center", fontsize=FS_SMALL, style="italic" - ) - - draw_arrow(ax, 4.6, 2.5, 5.2, 2.5, lw=1.5) - ax.text(4.9, 2.75, "②", ha="center", fontsize=FS_LABEL, fontweight="bold") - - # Step 3: HOG histogram - draw_box( - ax, - 5.3, - 1.8, - 1.6, - 1.4, - "Histogramy\nkierunkowe\n9 binów/cel", - fill=GRAY4, - fontsize=FS, - ) - ax.text(6.1, 1.55, "komórki 8x8 px", ha="center", fontsize=FS_SMALL, style="italic") - - draw_arrow(ax, 7.0, 2.5, 7.6, 2.5, lw=1.5) - ax.text(7.3, 2.75, "③", ha="center", fontsize=FS_LABEL, fontweight="bold") - - # Step 4: SVM - draw_box( - ax, - 7.7, - 1.8, - 1.4, - 1.4, - "SVM\nklasyfikator\npieszy/tło", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - - draw_arrow(ax, 9.2, 2.5, 9.7, 2.5, lw=1.5) - ax.text(9.45, 2.75, "④", ha="center", fontsize=FS_LABEL, fontweight="bold") - - # Step 5: NMS + output - draw_box(ax, 9.3, 2.0, 1.0, 1.0, "NMS\n→ wynik", fill=GRAY1, fontsize=FS) - - # Bottom: HOG feature vector illustration - ax.text( - 5.0, - 0.7, - "Wektor HOG: 3780 cech = 105 bloków x 4 komórki x 9 binów", - ha="center", - fontsize=FS, - style="italic", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - # Show small histogram bars - bar_x = 3.2 - bar_y = 0.0 - angles = [0, 20, 40, 60, 80, 100, 120, 140, 160] - values = [0.3, 0.1, 0.5, 0.8, 0.2, 0.6, 0.15, 0.4, 0.25] - for i, (_a, v) in enumerate(zip(angles, values, strict=False)): - ax.add_patch( - mpatches.Rectangle( - (bar_x + i * 0.18, bar_y), - 0.15, - v * 0.6, - facecolor=GRAY3, - edgecolor=LN, - lw=0.5, - ) - ) - ax.text(bar_x + 0.8, -0.2, "9 binów (0°-160°)", ha="center", fontsize=FS_SMALL) - - save_fig(fig, "q24_hog_svm_pipeline.png") - - -# ============================================================ -# 2. HOG Gradient Step-by-Step -# ============================================================ -def draw_hog_gradient_steps() -> None: - """Draw hog gradient steps.""" - fig, axes = plt.subplots(1, 4, figsize=(12, 3.5)) - fig.suptitle( - "HOG — kroki obliczania cech", fontsize=FS_TITLE, fontweight="bold", y=1.02 - ) - - # Step 1: Original patch - ax = axes[0] - patch = np.array([[50, 50, 200], [50, 50, 200], [50, 50, 200]]) - ax.imshow(patch, cmap="gray", vmin=0, vmax=255) - for i in range(3): - for j in range(3): - ax.text( - j, - i, - str(patch[i, j]), - ha="center", - va="center", - fontsize=FS_LABEL, - fontweight="bold", - color="white" if patch[i, j] > _PIXEL_BRIGHT_THRESH else "black", - ) - ax.set_title("① Fragment obrazu\n(jasność pikseli)", fontsize=FS, fontweight="bold") - ax.set_xticks([]) - ax.set_yticks([]) - - # Step 2: Gradient magnitude - ax = axes[1] - gx = np.array([[0, 150, 0], [0, 150, 0], [0, 150, 0]]) - ax.imshow(gx, cmap="gray", vmin=0, vmax=255) - for i in range(3): - for j in range(3): - ax.text( - j, - i, - str(gx[i, j]), - ha="center", - va="center", - fontsize=FS_LABEL, - fontweight="bold", - color="white" if gx[i, j] > _GRADIENT_BRIGHT_THRESH else "black", - ) - ax.set_title("② Gradient Gx\n(krawędź pionowa!)", fontsize=FS, fontweight="bold") - ax.set_xticks([]) - ax.set_yticks([]) - - # Step 3: Cell histogram - ax = axes[2] - angles = ["0°", "20°", "40°", "60°", "80°", "100°", "120°", "140°", "160°"] - values = [150, 0, 0, 0, 0, 0, 0, 0, 0] - bars = ax.bar(range(9), values, color=GRAY3, edgecolor=LN, linewidth=0.5) - bars[0].set_facecolor(GRAY5) - ax.set_xticks(range(9)) - ax.set_xticklabels(angles, fontsize=5, rotation=45) - ax.set_title( - "③ Histogram komórki\n(bin 0° = krawędź pionowa)", - fontsize=FS, - fontweight="bold", - ) - ax.set_ylabel("siła", fontsize=FS_SMALL) - - # Step 4: Block normalization - ax = axes[3] - # 2x2 block of cells - for i in range(2): - for j in range(2): - rect = mpatches.Rectangle( - (j * 1.2, (1 - i) * 1.2), - 1.0, - 1.0, - lw=1.2, - edgecolor=LN, - facecolor=GRAY4, - ) - ax.add_patch(rect) - ax.text( - j * 1.2 + 0.5, - (1 - i) * 1.2 + 0.5, - f"hist\n{i * 2 + j + 1}", - ha="center", - va="center", - fontsize=FS_SMALL, - ) - ax.add_patch( - mpatches.Rectangle( - (-0.1, -0.1), 2.6, 2.6, lw=2, edgecolor=LN, facecolor="none", linestyle="--" - ) - ) - ax.text( - 1.2, - -0.4, - "blok 2x2 → L2-norm", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - ) - ax.set_xlim(-0.3, 2.8) - ax.set_ylim(-0.7, 2.8) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "④ Normalizacja bloków\n(odporność na oświetlenie)", - fontsize=FS, - fontweight="bold", - ) - - fig.tight_layout() - save_fig(fig, "q24_hog_gradient_steps.png") - - -# ============================================================ -# 3. Viola-Jones Cascade -# ============================================================ -def draw_viola_jones_cascade() -> None: - """Draw viola jones cascade.""" - fig, ax = plt.subplots(figsize=(10, 5)) - ax.set_xlim(-0.5, 10.5) - ax.set_ylim(-1.5, 5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Viola-Jones — kaskada klasyfikatorów (SITO)", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - # Input - draw_box( - ax, - -0.3, - 2.5, - 1.5, - 1.2, - "500 000\nokien", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - ) - - stages = [ - ("Etap 1\n2 cechy", "50%\nodrzucone", "250 000", GRAY4), - ("Etap 2\n10 cech", "80%\nodrzucone", "50 000", GRAY4), - ("Etap 3\n25 cech", "90%\nodrzucone", "5 000", GRAY4), - ("Etap 25\n200 cech", "99%\nodrzucone", "50", GRAY3), - ] - - x_pos = 1.6 - for i, (label, reject, remain, col) in enumerate(stages): - # Stage box - draw_box( - ax, x_pos, 2.5, 1.6, 1.2, label, fill=col, fontsize=FS, fontweight="bold" - ) - - # Arrow from previous - draw_arrow(ax, x_pos - 0.3, 3.1, x_pos - 0.05, 3.1, lw=1.5) - - # Reject arrow down - draw_arrow(ax, x_pos + 0.8, 2.45, x_pos + 0.8, 1.6, lw=1.2) - ax.text( - x_pos + 0.8, - 1.3, - reject, - ha="center", - fontsize=FS_SMALL, - color="black", - style="italic", - ) - ax.text( - x_pos + 0.8, - 0.8, - "✗ NIE-TWARZ", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - ) - - # Remaining count above - if i < len(stages) - 1: - ax.text( - x_pos + 2.0, - 3.9, - f"→ {remain}", - ha="center", - fontsize=FS_SMALL, - style="italic", - ) - - # Dots between stage 3 and stage 25 - if i == _DOTS_STAGE_IDX: - ax.text( - x_pos + 2.0, 3.1, "· · ·", ha="center", fontsize=12, fontweight="bold" - ) - x_pos += 2.5 - else: - x_pos += 2.1 - - # Final output - draw_arrow(ax, x_pos + 0.3, 3.1, x_pos + 0.9, 3.1, lw=1.5) - draw_box( - ax, - x_pos + 0.5, - 2.5, - 1.3, - 1.2, - "~50\nTWARZE\n✓", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - # Timing info - ax.text( - 5.0, - -0.5, - "Czas: 99% okien odrzucone w etapach 1-3 (~5 μs każde)\n" - "Tylko 0.01% dochodzi do etapu 25 → cały obraz w ~30 ms = 30+ fps", - ha="center", - fontsize=FS, - style="italic", - bbox={"boxstyle": "round,pad=0.4", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - save_fig(fig, "q24_viola_jones_cascade.png") - - -# ============================================================ -# 4. Haar Features -# ============================================================ -def draw_haar_features() -> None: - """Draw haar features.""" - fig, axes = plt.subplots(1, 4, figsize=(11, 3)) - fig.suptitle( - "Cechy Haar — typy i zastosowanie na twarzy", - fontsize=FS_TITLE, - fontweight="bold", - y=1.02, - ) - - # Feature 1: Vertical edge - ax = axes[0] - ax.add_patch( - mpatches.Rectangle((0, 0), 1, 2, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - ax.add_patch( - mpatches.Rectangle((1, 0), 1, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5) - ) - ax.text( - 0.5, 1, "+Σ₁", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold" - ) - ax.text( - 1.5, 1, "-Σ₂", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold" - ) - ax.set_xlim(-0.2, 2.2) - ax.set_ylim(-0.5, 2.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Krawędź pionowa\nwartość = Σ₁ - Σ₂", fontsize=FS) - - # Feature 2: Horizontal edge - ax = axes[1] - ax.add_patch( - mpatches.Rectangle((0, 1), 2, 1, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - ax.add_patch( - mpatches.Rectangle((0, 0), 2, 1, facecolor=GRAY3, edgecolor=LN, lw=1.5) - ) - ax.text( - 1, 1.5, "+Σ₁", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold" - ) - ax.text( - 1, 0.5, "-Σ₂", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold" - ) - ax.set_xlim(-0.2, 2.2) - ax.set_ylim(-0.5, 2.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Krawędź pozioma\n(oczy vs czoło)", fontsize=FS) - - # Feature 3: Three-rectangle (line) - ax = axes[2] - ax.add_patch( - mpatches.Rectangle((0, 0), 0.7, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5) - ) - ax.add_patch( - mpatches.Rectangle((0.7, 0), 0.7, 2, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - ax.add_patch( - mpatches.Rectangle((1.4, 0), 0.7, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5) - ) - ax.text( - 0.35, 1, "-Σ₁", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold" - ) - ax.text( - 1.05, 1, "+Σ₂", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold" - ) - ax.text( - 1.75, 1, "-Σ₃", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold" - ) - ax.set_xlim(-0.2, 2.3) - ax.set_ylim(-0.5, 2.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Linia (3 prostokąty)\n(nos vs policzki)", fontsize=FS) - - _draw_haar_face_panel(axes[3]) - - fig.tight_layout() - save_fig(fig, "q24_haar_features.png") - - -def _draw_haar_face_panel(ax: Axes) -> None: - """Draw Haar feature application on face schematic.""" - face = mpatches.Ellipse( - (1.2, 1.2), 2.0, 2.4, facecolor=GRAY4, edgecolor=LN, lw=1.5, - ) - ax.add_patch(face) - ax.add_patch( - mpatches.Ellipse((0.7, 1.6), 0.4, 0.2, facecolor=GRAY3, edgecolor=LN, lw=1) - ) - ax.add_patch( - mpatches.Ellipse((1.7, 1.6), 0.4, 0.2, facecolor=GRAY3, edgecolor=LN, lw=1) - ) - ax.plot([1.2, 1.1, 1.3], [1.3, 0.9, 0.9], color=LN, lw=1) - ax.plot([0.8, 1.0, 1.2, 1.4, 1.6], [0.55, 0.5, 0.55, 0.5, 0.55], color=LN, lw=1) - ax.add_patch( - mpatches.Rectangle( - (0.3, 1.4), 1.8, 0.4, facecolor="none", edgecolor=LN, lw=2, linestyle="--", - ) - ) - ax.annotate( - "cechy Haar\n(oczy ciemne\nvs czoło jasne)", - xy=(1.2, 1.85), - xytext=(2.2, 2.3), - fontsize=FS_SMALL, - ha="center", - arrowprops={"arrowstyle": "->", "color": LN, "lw": 1}, - ) - ax.set_xlim(-0.2, 3.0) - ax.set_ylim(-0.2, 2.8) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Zastosowanie na twarzy", fontsize=FS) - - -# ============================================================ -# 5. Integral Image -# ============================================================ -def draw_integral_image() -> None: - """Draw integral image.""" - fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) - fig.suptitle( - "Integral Image — suma prostokąta w O(1)", - fontsize=FS_TITLE, - fontweight="bold", - y=1.02, - ) - - # Original image - ax = axes[0] - data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - ax.imshow(data, cmap="gray", vmin=0, vmax=10) - for i in range(3): - for j in range(3): - ax.text( - j, - i, - str(data[i, j]), - ha="center", - va="center", - fontsize=12, - fontweight="bold", - color="white" if data[i, j] > _DATA_BRIGHT_THRESH else "black", - ) - ax.set_title("① Obraz oryginalny", fontsize=FS, fontweight="bold") - ax.set_xticks([]) - ax.set_yticks([]) - - # Integral image - ax = axes[1] - ii = np.array([[1, 3, 6], [5, 12, 21], [12, 27, 45]]) - ax.imshow(ii, cmap="gray", vmin=0, vmax=50) - for i in range(3): - for j in range(3): - ax.text( - j, - i, - str(ii[i, j]), - ha="center", - va="center", - fontsize=12, - fontweight="bold", - color="white" if ii[i, j] > _II_BRIGHT_THRESH else "black", - ) - ax.set_title("② Integral Image\n(sumy kumulatywne)", fontsize=FS, fontweight="bold") - ax.set_xticks([]) - ax.set_yticks([]) - - # Formula illustration - ax = axes[2] - ax.axis("off") - ax.set_xlim(0, 4) - ax.set_ylim(0, 4) - # Draw rectangle - ax.add_patch( - mpatches.Rectangle((0.5, 0.5), 3, 3, facecolor="white", edgecolor=LN, lw=1) - ) - ax.add_patch( - mpatches.Rectangle((1.5, 0.5), 2, 2, facecolor=GRAY3, edgecolor=LN, lw=2) - ) - # Labels - ax.text(0.3, 3.7, "A", fontsize=12, fontweight="bold") - ax.text(3.6, 3.7, "B", fontsize=12, fontweight="bold") - ax.text(0.3, 0.3, "C", fontsize=12, fontweight="bold") - ax.text(3.6, 0.3, "D", fontsize=12, fontweight="bold") - ax.text( - 2.5, - 1.5, - "SZUKANA\nSUMA", - ha="center", - va="center", - fontsize=FS, - fontweight="bold", - ) - ax.text( - 2.0, - -0.3, - "Suma = D - B - C + A\n= 4 odczyty → O(1) ZAWSZE!", - ha="center", - fontsize=FS, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - ax.set_title( - "③ Formuła: 4 odczyty\n= O(1) niezależnie od rozmiaru", - fontsize=FS, - fontweight="bold", - ) - - fig.tight_layout() - save_fig(fig, "q24_integral_image.png") - - -# ============================================================ -# 6. R-CNN Evolution -# ============================================================ -def draw_rcnn_evolution() -> None: - """Draw rcnn evolution.""" - fig, ax = plt.subplots(figsize=(11, 7)) - ax.set_xlim(-0.5, 11) - ax.set_ylim(-0.5, 7.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Ewolucja R-CNN: od 50s do 0.2s na obraz", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - y_positions = [5.5, 3.0, 0.5] - labels = [ - "R-CNN (2014) — 50 s/obraz", - "Fast R-CNN (2015) — 2 s/obraz", - "Faster R-CNN (2015) — 0.2 s/obraz", - ] - - # R-CNN - y = y_positions[0] - ax.text(0, y + 1.3, labels[0], fontsize=FS_LABEL, fontweight="bold") - draw_box(ax, 0, y, 2, 0.9, "Selective\nSearch", fill=GRAY2, fontsize=FS) - draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45) - ax.text(2.3, y + 0.8, "~2000", ha="center", fontsize=FS_SMALL, style="italic") - draw_box(ax, 2.6, y, 1.5, 0.9, "Resize\n224x224", fill=GRAY4, fontsize=FS) - draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45) - draw_box( - ax, 4.7, y, 1.5, 0.9, "CNN\nx2000!", fill=GRAY3, fontsize=FS, fontweight="bold" - ) - draw_arrow(ax, 6.3, y + 0.45, 6.7, y + 0.45) - draw_box(ax, 6.8, y, 1.3, 0.9, "SVM\nklasyf.", fill=GRAY4, fontsize=FS) - draw_arrow(ax, 8.2, y + 0.45, 8.6, y + 0.45) - draw_box(ax, 8.7, y, 1.0, 0.9, "NMS", fill=GRAY1, fontsize=FS) - # Problem annotation - ax.text( - 5.5, - y - 0.4, - "⚠ CNN uruchamiane 2000x → 50 sek!", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - # Fast R-CNN - y = y_positions[1] - ax.text(0, y + 1.3, labels[1], fontsize=FS_LABEL, fontweight="bold") - draw_box(ax, 0, y, 2, 0.9, "Selective\nSearch", fill=GRAY2, fontsize=FS) - draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45) - draw_box( - ax, - 2.6, - y, - 1.5, - 0.9, - "CNN\nx1 (RAZ!)", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45) - draw_box( - ax, 4.7, y, 1.5, 0.9, "ROI\nPooling", fill=GRAY1, fontsize=FS, fontweight="bold" - ) - draw_arrow(ax, 6.3, y + 0.45, 6.7, y + 0.45) - draw_box(ax, 6.8, y, 1.3, 0.9, "FC\nklasa+bbox", fill=GRAY4, fontsize=FS) - draw_arrow(ax, 8.2, y + 0.45, 8.6, y + 0.45) - draw_box(ax, 8.7, y, 1.0, 0.9, "NMS", fill=GRAY1, fontsize=FS) - ax.text( - 3.8, - y - 0.4, - "✓ CNN RAZ na cały obraz → 25x szybciej", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - # Faster R-CNN - y = y_positions[2] - ax.text(0, y + 1.3, labels[2], fontsize=FS_LABEL, fontweight="bold") - draw_box( - ax, - 0.5, - y, - 1.5, - 0.9, - "CNN\nBackbone", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45) - draw_box(ax, 2.6, y, 1.5, 0.9, "Feature\nMap", fill=GRAY1, fontsize=FS) - draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45) - draw_box( - ax, - 4.7, - y, - 1.3, - 0.9, - "RPN\n(w sieci!)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 6.1, y + 0.45, 6.5, y + 0.45) - draw_box(ax, 6.6, y, 1.3, 0.9, "ROI\nPooling", fill=GRAY1, fontsize=FS) - draw_arrow(ax, 8.0, y + 0.45, 8.4, y + 0.45) - draw_box(ax, 8.5, y, 1.3, 0.9, "FC\nklasa+bbox", fill=GRAY4, fontsize=FS) - ax.text( - 5.0, - y - 0.4, - "✓ RPN zastępuje Selective Search → end-to-end", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - save_fig(fig, "q24_rcnn_evolution.png") - - -def _draw_yolo_cell_prediction(ax: Axes) -> None: - """Draw YOLO cell prediction vector panel.""" - ax.axis("off") - ax.set_xlim(0, 6) - ax.set_ylim(-1, 5) - - labels = [ - "x", "y", "w", "h", "conf", - "x", "y", "w", "h", "conf", - "P(c₁)", "...", "P(c₂₀)", - ] - colors_vec = [GRAY4] * 5 + [GRAY2] * 5 + [GRAY1] * 3 - - bw = 0.42 - for i, (label, col) in enumerate( - zip(labels, colors_vec, strict=False), - ): - x_pos = 0.3 + i * bw - ax.add_patch( - mpatches.Rectangle( - (x_pos, 2.5), bw - 0.02, 0.6, - facecolor=col, edgecolor=LN, lw=0.8, - ) - ) - ax.text( - x_pos + bw / 2, - 2.8, - label, - ha="center", - va="center", - fontsize=5, - fontweight="bold", - ) - - ax.annotate( - "", xy=(0.3, 2.4), xytext=(2.4, 2.4), - arrowprops={"arrowstyle": "-", "lw": 1}, - ) - ax.text(1.35, 2.15, "bbox 1 (5 wartości)", ha="center", fontsize=FS_SMALL) - - ax.annotate( - "", xy=(2.4, 2.4), xytext=(4.5, 2.4), - arrowprops={"arrowstyle": "-", "lw": 1}, - ) - ax.text(3.45, 2.15, "bbox 2 (5 wartości)", ha="center", fontsize=FS_SMALL) - - ax.annotate( - "", xy=(4.5, 2.4), xytext=(5.8, 2.4), - arrowprops={"arrowstyle": "-", "lw": 1}, - ) - ax.text(5.15, 2.15, "20 klas", ha="center", fontsize=FS_SMALL) - - ax.text( - 3.0, - 3.5, - "Każda komórka → 30 wartości\n= 2x(x,y,w,h,conf) + 20 klas", - ha="center", - fontsize=FS, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - ax.set_title( - "② Predykcja jednej komórki\n(S=7, B=2, C=20)", - fontsize=FS, - fontweight="bold", - ) - - -# ============================================================ -# 7. YOLO Grid -# ============================================================ -def draw_yolo_grid() -> None: - """Draw yolo grid.""" - fig, axes = plt.subplots(1, 3, figsize=(12, 4)) - fig.suptitle( - "YOLO — detekcja jednoetapowa (siatka SxS)", - fontsize=FS_TITLE, - fontweight="bold", - y=1.02, - ) - - # Grid on image - ax = axes[0] - grid_size = 7 - ax.set_xlim(0, grid_size) - ax.set_ylim(0, grid_size) - for i in range(grid_size + 1): - ax.axhline(y=i, color=LN, lw=0.5, alpha=0.5) - ax.axvline(x=i, color=LN, lw=0.5, alpha=0.5) - ax.add_patch( - mpatches.Rectangle( - (0, 0), grid_size, grid_size, - facecolor=GRAY4, edgecolor=LN, lw=1.5, - ) - ) - # Highlight one cell - ax.add_patch(mpatches.Rectangle((3, 3), 1, 1, facecolor=GRAY2, edgecolor=LN, lw=2)) - # Object center dot - ax.plot(3.5, 3.5, "ko", markersize=8) - # Bounding box from that cell - ax.add_patch( - mpatches.Rectangle( - (2.0, 2.2), 3.0, 2.6, facecolor="none", edgecolor=LN, lw=2, linestyle="--" - ) - ) - ax.text( - 3.5, - 1.8, - "bbox z komórki (3,3)", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - ) - ax.set_aspect("equal") - ax.invert_yaxis() - ax.set_title("① Siatka 7x7\nna obrazie", fontsize=FS, fontweight="bold") - ax.set_xticks([]) - ax.set_yticks([]) - - _draw_yolo_cell_prediction(axes[1]) - - # Speed comparison - ax = axes[2] - ax.axis("off") - ax.set_xlim(0, 5) - ax.set_ylim(0, 5) - - methods = ["R-CNN", "Fast R-CNN", "Faster R-CNN", "YOLO", "YOLOv8"] - fps_vals = [0.02, 0.5, 5, 45, 100] - bar_colors = [GRAY3, GRAY3, GRAY3, GRAY2, GRAY1] - - for i, (m, f, c) in enumerate(zip(methods, fps_vals, bar_colors, strict=False)): - bar_w = f / 100 * 4.0 - y_pos = 4.0 - i * 0.8 - ax.add_patch( - mpatches.Rectangle( - (0.5, y_pos), max(bar_w, 0.1), 0.5, facecolor=c, edgecolor=LN, lw=0.8 - ) - ) - ax.text( - 0.4, - y_pos + 0.25, - m, - ha="right", - va="center", - fontsize=FS, - fontweight="bold", - ) - ax.text( - max(0.7, 0.5 + bar_w + 0.1), - y_pos + 0.25, - f"{f} fps", - ha="left", - va="center", - fontsize=FS, - ) - - ax.set_title( - "③ Porównanie szybkości\n(fps = klatki/sek)", fontsize=FS, fontweight="bold" - ) - - fig.tight_layout() - save_fig(fig, "q24_yolo_grid.png") - - -# ============================================================ -# 8. IoU Diagram -# ============================================================ -def draw_iou_diagram() -> None: - """Draw iou diagram.""" - fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) - fig.suptitle( - "IoU (Intersection over Union) — miara nakładania bboxów", - fontsize=FS_TITLE, - fontweight="bold", - y=1.02, - ) - - # Low IoU - ax = axes[0] - ax.add_patch( - mpatches.Rectangle( - (0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5, label="A" - ) - ) - ax.add_patch( - mpatches.Rectangle( - (2.5, 2.5), - 3, - 3, - facecolor=GRAY2, - edgecolor=LN, - lw=1.5, - alpha=0.7, - label="B", - ) - ) - # Intersection - ax.add_patch( - mpatches.Rectangle((2.5, 2.5), 0.5, 0.5, facecolor=GRAY3, edgecolor=LN, lw=2) - ) - ax.text(1.5, 1.5, "A", ha="center", va="center", fontsize=12, fontweight="bold") - ax.text(4, 4, "B", ha="center", va="center", fontsize=12, fontweight="bold") - ax.set_xlim(-0.5, 6) - ax.set_ylim(-0.5, 6) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "IoU ≈ 0.04\n(prawie się nie nakładają)", fontsize=FS, fontweight="bold" - ) - - # Medium IoU - ax = axes[1] - ax.add_patch( - mpatches.Rectangle((0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - ax.add_patch( - mpatches.Rectangle( - (1.5, 1.5), 3, 3, facecolor=GRAY2, edgecolor=LN, lw=1.5, alpha=0.7 - ) - ) - ax.add_patch( - mpatches.Rectangle((1.5, 1.5), 1.5, 1.5, facecolor=GRAY3, edgecolor=LN, lw=2) - ) - ax.text(0.7, 0.7, "A", ha="center", va="center", fontsize=12, fontweight="bold") - ax.text(3.5, 3.5, "B", ha="center", va="center", fontsize=12, fontweight="bold") - ax.text(2.25, 2.25, "∩", ha="center", va="center", fontsize=14, fontweight="bold") - ax.set_xlim(-0.5, 5) - ax.set_ylim(-0.5, 5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("IoU ≈ 0.14\n(częściowe nakładanie)", fontsize=FS, fontweight="bold") - - # High IoU - ax = axes[2] - ax.add_patch( - mpatches.Rectangle((0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - ax.add_patch( - mpatches.Rectangle( - (0.3, 0.3), 3, 3, facecolor=GRAY2, edgecolor=LN, lw=1.5, alpha=0.7 - ) - ) - ax.add_patch( - mpatches.Rectangle((0.3, 0.3), 2.7, 2.7, facecolor=GRAY3, edgecolor=LN, lw=2) - ) - ax.text(-0.3, -0.3, "A", ha="center", va="center", fontsize=12, fontweight="bold") - ax.text(3.5, 3.5, "B", ha="center", va="center", fontsize=12, fontweight="bold") - ax.text(1.65, 1.65, "∩", ha="center", va="center", fontsize=14, fontweight="bold") - ax.set_xlim(-0.8, 4) - ax.set_ylim(-0.8, 4) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "IoU ≈ 0.74\n(duże nakładanie → duplikat!)", fontsize=FS, fontweight="bold" - ) - - fig.tight_layout() - save_fig(fig, "q24_iou_diagram.png") - - -# ============================================================ -# 9. NMS Step-by-Step -# ============================================================ -def draw_nms_steps() -> None: - """Draw nms steps.""" - fig, axes = plt.subplots(1, 3, figsize=(12, 4)) - fig.suptitle( - "NMS (Non-Maximum Suppression) — usuwanie duplikatów", - fontsize=FS_TITLE, - fontweight="bold", - y=1.02, - ) - - # Before NMS - ax = axes[0] - ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1)) - # Multiple overlapping boxes for same object - ax.add_patch( - mpatches.Rectangle((1, 1), 2.5, 3, facecolor="none", edgecolor=LN, lw=2) - ) - ax.text(2.25, 4.2, "conf=0.95", ha="center", fontsize=FS_SMALL, fontweight="bold") - ax.add_patch( - mpatches.Rectangle( - (1.2, 1.3), 2.3, 2.8, facecolor="none", edgecolor=LN, lw=1.5, linestyle="--" - ) - ) - ax.text(2.35, 1.1, "conf=0.90", ha="center", fontsize=FS_SMALL) - ax.add_patch( - mpatches.Rectangle( - (0.8, 0.8), 2.7, 3.2, facecolor="none", edgecolor=LN, lw=1, linestyle=":" - ) - ) - ax.text(2.15, 0.6, "conf=0.85", ha="center", fontsize=FS_SMALL) - # Different object - ax.add_patch( - mpatches.Rectangle((4, 2), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=1.5) - ) - ax.text(4.75, 3.7, "conf=0.80", ha="center", fontsize=FS_SMALL) - ax.text( - 2, - 0.2, - "⚠ 4 detekcje (3 duplikaty!)", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - ) - ax.set_xlim(-0.3, 6.3) - ax.set_ylim(-0.3, 5.3) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "① Przed NMS\n(wiele nakładających się)", fontsize=FS, fontweight="bold" - ) - - # NMS process - ax = axes[1] - ax.axis("off") - ax.set_xlim(0, 6) - ax.set_ylim(0, 5) - - steps = [ - ("1. Sortuj: [0.95, 0.90, 0.85, 0.80]", 4.5), - ("2. Weź najlepszą (0.95) → ZACHOWAJ", 3.7), - ("3. IoU(0.95, 0.90)=0.82 > 0.5 → USUŃ", 2.9), - ("4. IoU(0.95, 0.85)=0.75 > 0.5 → USUŃ", 2.1), - ("5. IoU(0.95, 0.80)=0.10 < 0.5 → ZACHOWAJ", 1.3), - ] - colors = [GRAY4, GRAY2, GRAY4, GRAY4, GRAY2] - for (text, yp), c in zip(steps, colors, strict=False): - ax.text( - 3.0, - yp, - text, - ha="center", - fontsize=FS, - bbox={"boxstyle": "round,pad=0.2", "facecolor": c, "edgecolor": GRAY3}, - ) - - ax.set_title("② Algorytm NMS\n(próg IoU = 0.5)", fontsize=FS, fontweight="bold") - - # After NMS - ax = axes[2] - ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1)) - # Only best box for each object - ax.add_patch( - mpatches.Rectangle((1, 1), 2.5, 3, facecolor="none", edgecolor=LN, lw=2.5) - ) - ax.text(2.25, 4.2, "conf=0.95 ✓", ha="center", fontsize=FS_SMALL, fontweight="bold") - ax.add_patch( - mpatches.Rectangle((4, 2), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=2.5) - ) - ax.text(4.75, 3.7, "conf=0.80 ✓", ha="center", fontsize=FS_SMALL, fontweight="bold") - ax.text( - 3, - 0.2, - "✓ 2 unikalne obiekty", - ha="center", - fontsize=FS_SMALL, - fontweight="bold", - ) - ax.set_xlim(-0.3, 6.3) - ax.set_ylim(-0.3, 5.3) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("③ Po NMS\n(1 bbox na obiekt)", fontsize=FS, fontweight="bold") - - fig.tight_layout() - save_fig(fig, "q24_nms_steps.png") - - -# ============================================================ -# 10. Detector from Classifier — 3 approaches -# ============================================================ -def draw_detector_from_classifier() -> None: - """Draw detector from classifier.""" - fig, ax = plt.subplots(figsize=(11, 9)) - ax.set_xlim(-0.5, 11) - ax.set_ylim(-1, 9.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Jak zbudować detektor z klasyfikatora? — 3 podejścia", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - # ---- Approach 1: Sliding Window ---- - y = 7.0 - ax.text( - 0, - y + 1.5, - "① Sliding Window (NAJWOLNIEJSZE)", - fontsize=FS_LABEL, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - # Image with sliding window - ax.add_patch( - mpatches.Rectangle( - (0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5 - ) - ) - ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL) - # Sliding windows - for dx, dy in [(0.1, 0.1), (0.4, 0.1), (0.7, 0.1), (0.1, 0.5), (0.4, 0.5)]: - ax.add_patch( - mpatches.Rectangle( - (dx, y - 0.5 + dy), - 0.5, - 0.5, - facecolor="none", - edgecolor=LN, - lw=0.8, - linestyle="--", - ) - ) - - draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2) - ax.text(2.35, y + 0.6, "xmiliony", fontsize=FS_SMALL, style="italic") - - draw_box( - ax, - 2.8, - y - 0.3, - 1.8, - 1.2, - 'Klasyfikator\n(ResNet)\n"kot? pies? tło?"', - fill=GRAY4, - fontsize=FS, - ) - draw_arrow(ax, 4.7, y + 0.3, 5.3, y + 0.3, lw=1.2) - draw_box(ax, 5.4, y - 0.3, 1.2, 1.2, "NMS", fill=GRAY1, fontsize=FS) - draw_arrow(ax, 6.7, y + 0.3, 7.3, y + 0.3, lw=1.2) - ax.text( - 8.5, - y + 0.3, - "~3.3h / obraz!\n⚠ NIEPRAKTYCZNE", - ha="center", - fontsize=FS, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - # ---- Approach 2: Region Proposals ---- - y = 3.8 - ax.text( - 0, - y + 1.5, - "② Region Proposals + Klasyfikator (= R-CNN)", - fontsize=FS_LABEL, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - ax.add_patch( - mpatches.Rectangle( - (0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5 - ) - ) - ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL) - # A few smart regions - ax.add_patch( - mpatches.Rectangle( - (0.1, y - 0.4), 0.7, 0.9, facecolor="none", edgecolor=LN, lw=1.5 - ) - ) - ax.add_patch( - mpatches.Rectangle( - (0.9, y + 0.0), 0.7, 0.6, facecolor="none", edgecolor=LN, lw=1.5 - ) - ) - - draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2) - draw_box( - ax, - 2.8, - y - 0.3, - 1.6, - 1.2, - "Selective\nSearch\n~2000 regionów", - fill=GRAY2, - fontsize=FS, - ) - draw_arrow(ax, 4.5, y + 0.3, 5.1, y + 0.3, lw=1.2) - ax.text(4.8, y + 0.6, "x2000", fontsize=FS_SMALL, style="italic") - draw_box(ax, 5.2, y - 0.3, 1.5, 1.2, "Klasyfikator\n(CNN)", fill=GRAY4, fontsize=FS) - draw_arrow(ax, 6.8, y + 0.3, 7.4, y + 0.3, lw=1.2) - draw_box(ax, 7.5, y - 0.3, 1.0, 1.2, "NMS", fill=GRAY1, fontsize=FS) - draw_arrow(ax, 8.6, y + 0.3, 9.0, y + 0.3, lw=1.2) - ax.text( - 10.0, - y + 0.3, - "~20-50 s/obraz\n(250x szybciej)", - ha="center", - fontsize=FS, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - # ---- Approach 3: Fine-tune backbone ---- - y = 0.5 - ax.text( - 0, - y + 1.5, - "③ Fine-tune backbone + detection head (NAJLEPSZE)", - fontsize=FS_LABEL, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY2, "edgecolor": GRAY3}, - ) - - ax.add_patch( - mpatches.Rectangle( - (0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5 - ) - ) - ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL) - - draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2) - draw_box( - ax, - 2.8, - y - 0.3, - 1.8, - 1.2, - "Pretrained\nbackbone\n(ResNet)", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 4.7, y + 0.3, 5.3, y + 0.3, lw=1.2) - - # Two heads from feature map - draw_box(ax, 5.4, y + 0.3, 1.6, 0.6, "cls head\nP(klasa)", fill=GRAY4, fontsize=FS) - draw_box( - ax, 5.4, y - 0.5, 1.6, 0.6, "bbox head\nΔx,Δy,Δw,Δh", fill=GRAY4, fontsize=FS - ) - - draw_arrow(ax, 7.1, y + 0.6, 7.7, y + 0.6, lw=1.0) - draw_arrow(ax, 7.1, y - 0.2, 7.7, y - 0.2, lw=1.0) - draw_box(ax, 7.8, y - 0.3, 1.0, 1.2, "NMS", fill=GRAY1, fontsize=FS) - - draw_arrow(ax, 8.9, y + 0.3, 9.3, y + 0.3, lw=1.2) - ax.text( - 10.2, - y + 0.3, - "5-155 fps!\n✓ NAJLEPSZE", - ha="center", - fontsize=FS, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY2, "edgecolor": GRAY3}, - ) - - save_fig(fig, "q24_detector_from_classifier.png") - - -# ============================================================ -# 11. SVM Hyperplane -# ============================================================ -def draw_svm_hyperplane() -> None: - """Draw svm hyperplane.""" - fig, ax = plt.subplots(figsize=(6, 5)) - ax.set_title( - "SVM — hiperpłaszczyzna i margines", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - x_pos = rng.standard_normal(15) * 0.5 + 3 - y_pos = rng.standard_normal(15) * 0.5 + 3 - ax.scatter( - x_pos, - y_pos, - marker="o", - s=50, - facecolors="white", - edgecolors=LN, - linewidths=1.5, - label="klasa +1 (pieszy)", - zorder=3, - ) - - x_neg = rng.standard_normal(15) * 0.5 + 1 - y_neg = rng.standard_normal(15) * 0.5 + 1 - ax.scatter( - x_neg, - y_neg, - marker="x", - s=50, - c=LN, - linewidths=1.5, - label="klasa -1 (tło)", - zorder=3, - ) - - # Hyperplane (decision boundary) - x_line = np.linspace(-0.5, 5, 100) - y_line = -x_line + 4.0 - ax.plot(x_line, y_line, "k-", lw=2, label="hiperpłaszczyzna") - - # Margin lines - ax.plot(x_line, y_line + 0.7, "k--", lw=1, alpha=0.5) - ax.plot(x_line, y_line - 0.7, "k--", lw=1, alpha=0.5) - - # Margin annotation - ax.annotate( - "", - xy=(2.5, 1.5 + 0.7), - xytext=(2.5, 1.5 - 0.7), - arrowprops={"arrowstyle": "<->", "color": LN, "lw": 1.5}, - ) - ax.text(2.8, 1.5, "margines\n(MAX!)", fontsize=FS, fontweight="bold") - - # Support vectors (highlight closest points) - # Find points closest to the line - ax.scatter( - [2.5], - [2.2], - marker="o", - s=120, - facecolors="none", - edgecolors=LN, - linewidths=2.5, - zorder=4, - ) - ax.scatter([1.5], [1.8], marker="x", s=120, c=LN, linewidths=2.5, zorder=4) - ax.annotate( - "support\nvectors", - xy=(1.5, 1.8), - xytext=(0.2, 3.0), - fontsize=FS, - fontweight="bold", - arrowprops={"arrowstyle": "->", "color": LN, "lw": 1}, - ) - - ax.set_xlim(-0.5, 5) - ax.set_ylim(-0.5, 5) - ax.set_xlabel("cecha 1 (np. gradient pionowy)", fontsize=FS) - ax.set_ylabel("cecha 2 (np. gradient poziomy)", fontsize=FS) - ax.legend(fontsize=FS_SMALL, loc="lower right") - ax.set_aspect("equal") - - save_fig(fig, "q24_svm_hyperplane.png") - - -# ============================================================ -# 12. Two-stage vs One-stage comparison table -# ============================================================ -def draw_two_vs_one_stage() -> None: - """Draw two vs one stage.""" - fig, ax = plt.subplots(figsize=(10, 3.5)) - ax.set_xlim(0, 10) - ax.set_ylim(-0.5, 4.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Two-stage vs One-stage — porównanie", - fontsize=FS_TITLE, - fontweight="bold", - pad=8, - ) - - headers = ["Cecha", "Two-stage\n(Faster R-CNN)", "One-stage\n(YOLO)"] - rows = [ - ["Szybkość", "~5 fps", "45-155 fps"], - ["Dokładność (mAP)", "wyższa (historycznie)", "dorównuje (YOLOv8)"], - ["Małe obiekty", "lepszy", "gorszy (SSD/FPN pomaga)"], - ["Architektura", "2 etapy + NMS", "1 etap + NMS"], - ["Real-time?", "NIE", "TAK"], - ] - col_widths = [2.5, 3.5, 3.5] - draw_table( - ax, - headers, - rows, - 0.2, - 3.8, - col_widths, - row_h=0.65, - fontsize=FS, - header_fontsize=FS, - ) - - save_fig(fig, "q24_two_vs_one_stage.png") - - -# ============================================================ -# 13. ROI Pooling illustration -# ============================================================ -def draw_roi_pooling() -> None: - """Draw roi pooling.""" - fig, axes = plt.subplots(1, 3, figsize=(12, 4)) - fig.suptitle( - "ROI Pooling — dowolny rozmiar → stały rozmiar", - fontsize=FS_TITLE, - fontweight="bold", - y=1.02, - ) - - # Feature map with ROI - ax = axes[0] - # Draw feature map grid - fm = rng.integers(0, 10, (8, 8)) - ax.imshow(fm, cmap="gray", vmin=0, vmax=10, alpha=0.3) - for i in range(9): - ax.axhline(y=i - 0.5, color=LN, lw=0.3) - ax.axvline(x=i - 0.5, color=LN, lw=0.3) - # ROI rectangle - ax.add_patch( - mpatches.Rectangle( - (1.5, 1.5), 4, 4, facecolor="none", edgecolor=LN, lw=3, linestyle="-" - ) - ) - ax.text(3.5, 0.8, "ROI", ha="center", fontsize=FS_LABEL, fontweight="bold") - ax.set_xlim(-0.5, 7.5) - ax.set_ylim(7.5, -0.5) - ax.set_title("① Feature map\nz zaznaczonym ROI", fontsize=FS, fontweight="bold") - ax.set_xticks([]) - ax.set_yticks([]) - - # ROI divided into grid - ax = axes[1] - roi_data = np.array( - [ - [1, 3, 2, 1], - [0, 5, 1, 6], - [0, 4, 1, 0], - [7, 2, 9, 1], - ] - ) - ax.imshow(roi_data, cmap="gray", vmin=0, vmax=10) - for i in range(5): - ax.axhline(y=i - 0.5, color=LN, lw=1) - ax.axvline(x=i - 0.5, color=LN, lw=1) - # Grid lines for 2x2 pooling - ax.axhline(y=0.5, color=LN, lw=3, linestyle="--") - ax.axvline(x=0.5, color=LN, lw=3, linestyle="--") - for i in range(4): - for j in range(4): - ax.text( - j, - i, - str(roi_data[i, j]), - ha="center", - va="center", - fontsize=10, - fontweight="bold", - color="white" if roi_data[i, j] > _DATA_BRIGHT_THRESH else "black", - ) - ax.set_title("② ROI podzielony\nna siatkę 2x2", fontsize=FS, fontweight="bold") - ax.set_xticks([]) - ax.set_yticks([]) - - # Output after pooling - ax = axes[2] - out = np.array([[5, 6], [7, 9]]) - ax.imshow(out, cmap="gray", vmin=0, vmax=10) - for i in range(3): - ax.axhline(y=i - 0.5, color=LN, lw=1.5) - ax.axvline(x=i - 0.5, color=LN, lw=1.5) - for i in range(2): - for j in range(2): - ax.text( - j, - i, - str(out[i, j]), - ha="center", - va="center", - fontsize=14, - fontweight="bold", - color="white" if out[i, j] > _DATA_BRIGHT_THRESH else "black", - ) - ax.set_title( - "③ Po ROI Pool 2x2\n(max z każdej komórki)", fontsize=FS, fontweight="bold" - ) - ax.set_xticks([]) - ax.set_yticks([]) - - fig.tight_layout() - save_fig(fig, "q24_roi_pooling.png") - - -# ============================================================ -# 14. DETR Pipeline -# ============================================================ -def draw_detr_pipeline() -> None: - """Draw detr pipeline.""" - fig, ax = plt.subplots(figsize=(11, 4.5)) - ax.set_xlim(-0.5, 11.5) - ax.set_ylim(-1, 4.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "DETR — Transformer do detekcji (bez NMS, bez anchorów)", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - # Pipeline - draw_box(ax, 0, 1.5, 1.5, 1.5, "Obraz\nwejściowy", fill=GRAY1, fontsize=FS) - draw_arrow(ax, 1.6, 2.25, 2.1, 2.25, lw=1.5) - - draw_box( - ax, - 2.2, - 1.5, - 1.5, - 1.5, - "CNN\nBackbone\n(ResNet)", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 3.8, 2.25, 4.3, 2.25, lw=1.5) - - draw_box( - ax, - 4.4, - 1.5, - 1.8, - 1.5, - "Transformer\nEncoder\n(self-attention)", - fill=GRAY2, - fontsize=FS, - ) - draw_arrow(ax, 6.3, 2.25, 6.8, 2.25, lw=1.5) - - draw_box( - ax, - 6.9, - 1.5, - 1.8, - 1.5, - "Transformer\nDecoder\n(N=100 queries)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - # Output branches - draw_arrow(ax, 8.8, 2.5, 9.5, 3.0, lw=1.2) - draw_box(ax, 9.6, 2.7, 1.5, 0.7, "klasa₁...klasa₁₀₀", fill=GRAY4, fontsize=FS_SMALL) - - draw_arrow(ax, 8.8, 2.0, 9.5, 1.5, lw=1.2) - draw_box(ax, 9.6, 1.2, 1.5, 0.7, "bbox₁...bbox₁₀₀", fill=GRAY4, fontsize=FS_SMALL) - - # Annotations - ax.text( - 7.8, - 0.5, - '100 object queries → 5 obiektów + 95x "brak"', - ha="center", - fontsize=FS, - style="italic", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - ax.text( - 5.5, - 0.0, - "Hungarian matching (trening): optymalne dopasowanie predykcji do GT", - ha="center", - fontsize=FS_SMALL, - style="italic", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5}, - ) - - # Big benefit box - ax.text( - 5.5, - 4.0, - "BEZ anchorów • BEZ NMS • end-to-end • prosty pipeline", - ha="center", - fontsize=FS, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY2, "edgecolor": GRAY3}, - ) - - save_fig(fig, "q24_detr_pipeline.png") - - -# ============================================================ -# 15. Sliding Window illustration -# ============================================================ -def draw_sliding_window() -> None: - """Draw sliding window.""" - fig, axes = plt.subplots(1, 3, figsize=(12, 4)) - fig.suptitle( - "Sliding Window — najprostsze podejście do detekcji", - fontsize=FS_TITLE, - fontweight="bold", - y=1.02, - ) - - # Multi-position - ax = axes[0] - ax.add_patch( - mpatches.Rectangle((0, 0), 8, 6, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - # Grid of sliding windows - for i in range(4): - for j in range(3): - ax.add_patch( - mpatches.Rectangle( - (i * 1.8 + 0.2, j * 1.8 + 0.2), - 1.5, - 1.5, - facecolor="none", - edgecolor=LN, - lw=0.6, - linestyle="--", - ) - ) - # Highlight current window - ax.add_patch( - mpatches.Rectangle((2.0, 2.0), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=2.5) - ) - ax.set_xlim(-0.5, 8.5) - ax.set_ylim(-0.5, 6.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("① Wiele pozycji\n(krok co 8 px)", fontsize=FS, fontweight="bold") - - # Multi-scale - ax = axes[1] - ax.add_patch( - mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - sizes = [(0.8, 0.8), (1.5, 1.5), (2.5, 2.5), (3.5, 3.5)] - for i, (w, h) in enumerate(sizes): - ax.add_patch( - mpatches.Rectangle( - (0.3 + i * 0.3, 0.3 + i * 0.3), - w, - h, - facecolor="none", - edgecolor=LN, - lw=1 + i * 0.3, - linestyle=[":", "--", "-.", "-"][i], - ) - ) - ax.text(3, 0, "4+ skal", ha="center", fontsize=FS_SMALL, fontweight="bold") - ax.set_xlim(-0.5, 6.5) - ax.set_ylim(-0.5, 5.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "② Wiele skal\n(obiekty mają różne rozmiary)", fontsize=FS, fontweight="bold" - ) - - # Count - ax = axes[2] - ax.axis("off") - ax.set_xlim(0, 6) - ax.set_ylim(0, 5) - - lines = [ - ("Obraz: 640 x 480 px", 4.5), - ("Okno: 64 x 64 px, krok 8 px", 3.8), - ("Pozycje: ~72 x 52 = 3 744", 3.1), - ("x 5 skal = 18 720 okien", 2.4), - ("x klasyfikacja = WOLNE!", 1.7), - ("→ ~3h na jeden obraz", 0.8), - ] - for text, yp in lines: - fw = "bold" if "~3h" in text or "WOLNE" in text else "normal" - col = GRAY2 if "WOLNE" in text or "~3h" in text else GRAY4 - ax.text( - 3.0, - yp, - text, - ha="center", - fontsize=FS, - fontweight=fw, - bbox={"boxstyle": "round,pad=0.2", "facecolor": col, "edgecolor": GRAY3}, - ) - - ax.set_title( - "③ Dlaczego wolne?\n(miliony klasyfikacji)", fontsize=FS, fontweight="bold" - ) - - fig.tight_layout() - save_fig(fig, "q24_sliding_window.png") - - -# ============================================================ -# 16. FPN (Feature Pyramid Network) -# ============================================================ -def draw_fpn() -> None: - """Draw fpn.""" - fig, ax = plt.subplots(figsize=(9, 5)) - ax.set_xlim(-0.5, 9.5) - ax.set_ylim(-0.5, 5.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "FPN (Feature Pyramid Network) — detekcja obiektów wszystkich rozmiarów", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - levels = [ - (0, 0, 2.0, 2.0, "C2\n56x56", "duże\ndetale"), - (0, 2.2, 1.5, 1.5, "C3\n28x28", ""), - (0, 3.9, 1.0, 1.0, "C4\n14x14", ""), - (0, 5.1, 0.6, 0.6, "C5\n7x7", "kontekst"), - ] - - for x, y, w, h, label, note in levels: - ax.add_patch( - mpatches.Rectangle((x, y - h), w, h, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - ax.text( - x + w / 2, - y - h / 2, - label, - ha="center", - va="center", - fontsize=FS_SMALL, - fontweight="bold", - ) - if note: - ax.text( - x + w + 0.15, - y - h / 2, - note, - ha="left", - va="center", - fontsize=5, - style="italic", - ) - - ax.text( - 1.0, -0.3, "Bottom-up\n(backbone)", ha="center", fontsize=FS, fontweight="bold" - ) - - # Top-down + lateral - td_levels = [ - (4.5, 5.1, 0.6, 0.6, "P5"), - (4.5, 3.9, 1.0, 1.0, "P4"), - (4.5, 2.2, 1.5, 1.5, "P3"), - (4.5, 0, 2.0, 2.0, "P2"), - ] - - for x, y, w, h, label in td_levels: - ax.add_patch( - mpatches.Rectangle( - (x, y - h + h), w, h, facecolor=GRAY2, edgecolor=LN, lw=1.5 - ) - ) - ax.text( - x + w / 2, - y - h / 2 + h, - label, - ha="center", - va="center", - fontsize=FS_SMALL, - fontweight="bold", - ) - - # Lateral connections - for (_, y1, w1, h1, _, _), (x2, y2, _w2, h2, _) in zip( - levels, td_levels, strict=False - ): - draw_arrow(ax, w1 + 0.2, y1 - h1 / 2, x2 - 0.1, y2 + h2 / 2, lw=1, style="->") - - # Top-down arrows - for i in range(len(td_levels) - 1): - x2, y2, w2, h2, _ = td_levels[i] - x3, y3, w3, h3, _ = td_levels[i + 1] - draw_arrow( - ax, - x2 + w2 / 2, - y2, - x3 + w3 / 2, - y3 + h3 + 0.1, - lw=1.2, - style="->", - color=GRAY3, - ) - - ax.text( - 5.5, - -0.3, - "Top-down + lateral\n(FPN)", - ha="center", - fontsize=FS, - fontweight="bold", - ) - - # Detection outputs - det_labels = ["małe obj.", "średnie", "duże", "b. duże"] - for i, (x, y, w, h, _label) in enumerate(td_levels): - draw_arrow(ax, x + w + 0.1, y + h / 2, 7.5, y + h / 2, lw=0.8) - ax.text( - 7.7, - y + h / 2, - f"detekcja:\n{det_labels[3 - i]}", - fontsize=FS_SMALL, - va="center", - ) - - save_fig(fig, "q24_fpn.png") - - -# ============================================================ -# 17. Anchor boxes -# ============================================================ -def draw_anchor_boxes() -> None: - """Draw anchor boxes.""" - fig, ax = plt.subplots(figsize=(7, 5)) - ax.set_title( - "Anchor boxes — predefiniowane kształty", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1)) - - # Center point - cx, cy = 3, 2.5 - ax.plot(cx, cy, "ko", markersize=8, zorder=5) - ax.text(cx + 0.15, cy + 0.15, "(x, y)", fontsize=FS, fontweight="bold") - - # 9 anchors: 3 sizes x 3 ratios - anchors = [ - (0.8, 0.8, "-", "1:1 small"), - (1.6, 1.6, "-", "1:1 medium"), - (2.4, 2.4, "-", "1:1 large"), - (0.6, 1.2, "--", "1:2 small"), - (1.2, 2.4, "--", "1:2 medium"), - (1.8, 3.6, "--", "1:2 large"), - (1.2, 0.6, ":", "2:1 small"), - (2.4, 1.2, ":", "2:1 medium"), - (3.6, 1.8, ":", "2:1 large"), - ] - - for w, h, ls, _label in anchors: - rect = mpatches.Rectangle( - (cx - w / 2, cy - h / 2), - w, - h, - facecolor="none", - edgecolor=LN, - lw=1.2, - linestyle=ls, - ) - ax.add_patch(rect) - - # Legend-style labels - ax.text( - 3, - -0.5, - "9 anchorów = 3 rozmiary x 3 proporcje (1:1, 1:2, 2:1)\n" - "Sieć predykuje PRZESUNIĘCIE od najbliższego anchora", - ha="center", - fontsize=FS, - style="italic", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - ax.set_xlim(-0.5, 6.5) - ax.set_ylim(-1.2, 5.5) - ax.set_aspect("equal") - ax.axis("off") - - save_fig(fig, "q24_anchor_boxes.png") - - -# ============================================================ -# 18. Detection task comparison -# ============================================================ -def draw_detection_tasks() -> None: - """Draw detection tasks.""" - fig, axes = plt.subplots(1, 3, figsize=(12, 4)) - fig.suptitle( - "Klasyfikacja vs Detekcja vs Segmentacja", - fontsize=FS_TITLE, - fontweight="bold", - y=1.02, - ) - - # Classification - ax = axes[0] - ax.add_patch( - mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - # Simple cat silhouette - ax.add_patch(mpatches.Ellipse((2, 2), 2, 1.5, facecolor=GRAY3, edgecolor=LN, lw=1)) - ax.add_patch(mpatches.Ellipse((2, 3), 1, 0.8, facecolor=GRAY3, edgecolor=LN, lw=1)) - # Ears - ax.plot([1.6, 1.5, 1.8], [3.3, 3.8, 3.4], color=LN, lw=1.5) - ax.plot([2.2, 2.5, 2.4], [3.3, 3.8, 3.4], color=LN, lw=1.5) - ax.text( - 2, -0.4, '→ "KOT" (jedna etykieta)', ha="center", fontsize=FS, fontweight="bold" - ) - ax.set_xlim(-0.5, 4.5) - ax.set_ylim(-0.8, 4.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Klasyfikacja\n(co?)", fontsize=FS, fontweight="bold") - - # Detection - ax = axes[1] - ax.add_patch( - mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - # Cat - ax.add_patch( - mpatches.Ellipse((1.2, 2), 1.2, 1, facecolor=GRAY3, edgecolor=LN, lw=1) - ) - ax.add_patch( - mpatches.Ellipse((1.2, 2.8), 0.7, 0.5, facecolor=GRAY3, edgecolor=LN, lw=1) - ) - # Dog - ax.add_patch( - mpatches.Ellipse((3, 1.5), 1.2, 1, facecolor=GRAY2, edgecolor=LN, lw=1) - ) - ax.add_patch( - mpatches.Ellipse((3, 2.3), 0.7, 0.5, facecolor=GRAY2, edgecolor=LN, lw=1) - ) - # Bounding boxes - ax.add_patch( - mpatches.Rectangle((0.3, 1.2), 1.8, 2.2, facecolor="none", edgecolor=LN, lw=2.5) - ) - ax.text(1.2, 3.5, "KOT", ha="center", fontsize=FS_SMALL, fontweight="bold") - ax.add_patch( - mpatches.Rectangle((2.1, 0.8), 1.7, 2.0, facecolor="none", edgecolor=LN, lw=2.5) - ) - ax.text(3.0, 2.9, "PIES", ha="center", fontsize=FS_SMALL, fontweight="bold") - ax.text( - 2, - -0.4, - "→ bbox + klasa (N obiektów)", - ha="center", - fontsize=FS, - fontweight="bold", - ) - ax.set_xlim(-0.5, 4.5) - ax.set_ylim(-0.8, 4.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Detekcja\n(co? + gdzie?)", fontsize=FS, fontweight="bold") - - # Segmentation - ax = axes[2] - ax.add_patch( - mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5) - ) - # Cat mask (detailed) - theta = np.linspace(0, 2 * np.pi, 30) - cat_x = 1.2 + 0.6 * np.cos(theta) + 0.1 * np.sin(3 * theta) - cat_y = 2 + 0.5 * np.sin(theta) + 0.1 * np.cos(2 * theta) - ax.fill(cat_x, cat_y, facecolor=GRAY3, edgecolor=LN, lw=1.5) - # Dog mask - dog_x = 3.0 + 0.6 * np.cos(theta) + 0.05 * np.sin(4 * theta) - dog_y = 1.5 + 0.5 * np.sin(theta) + 0.08 * np.cos(3 * theta) - ax.fill(dog_x, dog_y, facecolor=GRAY2, edgecolor=LN, lw=1.5) - ax.text(1.2, 2, "KOT", ha="center", fontsize=FS_SMALL, fontweight="bold") - ax.text(3.0, 1.5, "PIES", ha="center", fontsize=FS_SMALL, fontweight="bold") - ax.text( - 2, - -0.4, - "→ maska pikseli (per piksel)", - ha="center", - fontsize=FS, - fontweight="bold", - ) - ax.set_xlim(-0.5, 4.5) - ax.set_ylim(-0.8, 4.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Segmentacja\n(dokładna maska)", fontsize=FS, fontweight="bold") - - fig.tight_layout() - save_fig(fig, "q24_detection_tasks.png") - - -# ============================================================ -# 19. CNN Architecture overview -# ============================================================ -def draw_cnn_architecture() -> None: - """Draw cnn architecture.""" - fig, ax = plt.subplots(figsize=(12, 4)) - ax.set_xlim(-0.5, 12.5) - ax.set_ylim(-1, 4.5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "CNN — od obrazu do predykcji (architektura)", - fontsize=FS_TITLE, - fontweight="bold", - pad=12, - ) - - # Input image - draw_box(ax, 0, 0.5, 1.5, 3, "Obraz\n224x224x3", fill=GRAY1, fontsize=FS) - - # Conv1 - draw_arrow(ax, 1.6, 2.0, 2.1, 2.0, lw=1.2) - draw_box( - ax, 2.2, 0.8, 1.2, 2.4, "Conv1\n+ReLU\n55x55x96", fill=GRAY4, fontsize=FS_SMALL - ) - - # Pool1 - draw_arrow(ax, 3.5, 2.0, 3.9, 2.0, lw=1.2) - draw_box(ax, 4.0, 1.0, 1.0, 2.0, "Pool\n27x27\nx96", fill=GRAY2, fontsize=FS_SMALL) - - # Conv2 - draw_arrow(ax, 5.1, 2.0, 5.5, 2.0, lw=1.2) - draw_box( - ax, - 5.6, - 0.8, - 1.2, - 2.4, - "Conv2\n+ReLU\n27x27\nx256", - fill=GRAY4, - fontsize=FS_SMALL, - ) - - # Pool2 - draw_arrow(ax, 6.9, 2.0, 7.3, 2.0, lw=1.2) - draw_box(ax, 7.4, 1.2, 0.8, 1.6, "Pool\n13x13\nx256", fill=GRAY2, fontsize=FS_SMALL) - - # More conv... - draw_arrow(ax, 8.3, 2.0, 8.7, 2.0, lw=1.2) - ax.text(9.0, 2.0, "...", fontsize=14, ha="center", va="center") - draw_arrow(ax, 9.3, 2.0, 9.7, 2.0, lw=1.2) - - # FC - draw_box(ax, 9.8, 1.2, 1.0, 1.6, "FC\n4096", fill=GRAY3, fontsize=FS) - - draw_arrow(ax, 10.9, 2.0, 11.3, 2.0, lw=1.2) - - # Output - draw_box( - ax, 11.4, 1.5, 1.0, 1.0, "Softmax\n1000 klas", fill=GRAY1, fontsize=FS_SMALL - ) - - # Annotations below - ax.text( - 3.0, - 0.0, - "rozmiar maleje\n224→55→27→13→6", - ha="center", - fontsize=FS_SMALL, - style="italic", - ) - ax.text( - 6.0, - 0.0, - "kanały rosną\n3→96→256→384", - ha="center", - fontsize=FS_SMALL, - style="italic", - ) - ax.text( - 10.0, 0.0, "decyzja\nkońcowa", ha="center", fontsize=FS_SMALL, style="italic" - ) - - # hierarchy - ax.text( - 6.0, - 4.0, - "Hierarchia: krawędzie → rogi → fragmenty → obiekty (K-R-F-O)", - ha="center", - fontsize=FS, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - save_fig(fig, "q24_cnn_architecture.png") - - # ============================================================ # MAIN # ============================================================ @@ -2294,4 +92,4 @@ if __name__ == "__main__": draw_anchor_boxes() draw_detection_tasks() draw_cnn_architecture() - _logger.info("All PYTANIE 24 diagrams generated!") + _logger.info("All PYTANIE 24 diagrams saved to: %s", OUTPUT_DIR) diff --git a/python_pkg/praca_magisterska_video/generate_images/generate_q31_diagrams.py b/python_pkg/praca_magisterska_video/generate_images/generate_q31_diagrams.py index bbf0133..fb006c7 100755 --- a/python_pkg/praca_magisterska_video/generate_images/generate_q31_diagrams.py +++ b/python_pkg/praca_magisterska_video/generate_images/generate_q31_diagrams.py @@ -17,1201 +17,39 @@ All: A4-compatible, B&W, 300 DPI, laser-printer-friendly. from __future__ import annotations import logging -from typing import TYPE_CHECKING import matplotlib as mpl mpl.use("Agg") -from pathlib import Path -import matplotlib.patches as mpatches -from matplotlib.patches import FancyBboxPatch -import matplotlib.pyplot as plt -import numpy as np +from python_pkg.praca_magisterska_video.generate_images._q31_common import ( + OUTPUT_DIR, + _logger, +) +from python_pkg.praca_magisterska_video.generate_images._q31_criteria_comparison import ( + draw_criteria_comparison, +) +from python_pkg.praca_magisterska_video.generate_images._q31_ev_spectrum import ( + draw_conditions_spectrum, + draw_expected_value, +) +from python_pkg.praca_magisterska_video.generate_images._q31_hurwicz_mnemonic import ( + draw_criteria_mnemonic, + draw_hurwicz_interpolation, +) +from python_pkg.praca_magisterska_video.generate_images._q31_regret_matrix import ( + draw_regret_matrix, +) + +__all__ = [ + "draw_conditions_spectrum", + "draw_criteria_comparison", + "draw_criteria_mnemonic", + "draw_expected_value", + "draw_hurwicz_interpolation", + "draw_regret_matrix", +] -if TYPE_CHECKING: - from matplotlib.axes import Axes - -_logger = logging.getLogger(__name__) - -DPI = 300 -BG = "white" -LN = "black" -FS = 8 -FS_TITLE = 11 -OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") -Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) - -GRAY1 = "#E8E8E8" -GRAY2 = "#D0D0D0" -GRAY3 = "#B8B8B8" -GRAY4 = "#F5F5F5" -GRAY5 = "#C0C0C0" - -# Number of regret table header columns before the max-regret column -_REGRET_HEADER_COLS = 4 -# Number of data state columns -_DATA_STATE_COLS = 3 -# Expected-value for the winning alternative -_WINNING_EV = 95 - - -def draw_box( - ax: Axes, - x: float, - y: float, - w: float, - h: float, - text: str, - *, - fill: str = "white", - lw: float = 1.2, - fontsize: float = FS, - fontweight: str = "normal", - ha: str = "center", - va: str = "center", - rounded: bool = True, -) -> None: - """Draw a labeled box on the axes.""" - if rounded: - rect = FancyBboxPatch( - (x, y), - w, - h, - boxstyle="round,pad=0.05", - lw=lw, - edgecolor=LN, - facecolor=fill, - ) - else: - rect = mpatches.Rectangle( - (x, y), w, h, lw=lw, edgecolor=LN, facecolor=fill - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + h / 2, - text, - ha=ha, - va=va, - fontsize=fontsize, - fontweight=fontweight, - wrap=True, - ) - - -def draw_arrow( - ax: Axes, - x1: float, - y1: float, - x2: float, - y2: float, - *, - lw: float = 1.2, - style: str = "->", - color: str = LN, -) -> None: - """Draw an arrow between two points.""" - ax.annotate( - "", - xy=(x2, y2), - xytext=(x1, y1), - arrowprops={ - "arrowstyle": style, - "color": color, - "lw": lw, - }, - ) - - -# ============================================================ -# 1. PAYOFF MATRIX + ALL CRITERIA BAR CHART -# ============================================================ -def _draw_payoff_table(ax: Axes) -> None: - """Draw the payoff matrix table on the left panel.""" - ax.axis("off") - ax.set_xlim(0, 6) - ax.set_ylim(0, 6) - ax.set_title( - "Macierz wypłat (tys. zł)", - fontsize=FS_TITLE, - fontweight="bold", - pad=8, - ) - - headers_col = ["", "S₁\n(dobra)", "S₂\n(średnia)", "S₃\n(zła)"] - rows = [ - ["A₁ (fabryka)", "200", "50", "-100"], - ["A₂ (sklep)", "80", "70", "40"], - ["A₃ (obligacje)", "30", "30", "30"], - ] - - col_w = [1.8, 1.2, 1.2, 1.2] - row_h = 0.7 - start_y = 4.5 - start_x = 0.2 - - # Draw header row - x = start_x - for j, h in enumerate(headers_col): - fill = GRAY2 if j > 0 else GRAY3 - rect = mpatches.Rectangle( - (x, start_y), - col_w[j], - row_h, - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - x + col_w[j] / 2, - start_y + row_h / 2, - h, - ha="center", - va="center", - fontsize=FS, - fontweight="bold", - ) - x += col_w[j] - - # Draw data rows - for i, row in enumerate(rows): - x = start_x - y = start_y - (i + 1) * row_h - for j, val in enumerate(row): - fill = GRAY4 if j == 0 else ( - "white" if i % 2 == 0 else GRAY1 - ) - if val.startswith("-"): - fill = "#D8D8D8" - rect = mpatches.Rectangle( - (x, y), - col_w[j], - row_h, - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - fw = "bold" if j == 0 else "normal" - ax.text( - x + col_w[j] / 2, - y + row_h / 2, - val, - ha="center", - va="center", - fontsize=FS, - fontweight=fw, - ) - x += col_w[j] - - # Probability row for EV - x = start_x - y = start_y - 4 * row_h - probs = ["p (dla E[X]):", "0.5", "0.3", "0.2"] - for j, val in enumerate(probs): - fill = GRAY5 if j > 0 else GRAY3 - rect = mpatches.Rectangle( - (x, y), - col_w[j], - row_h * 0.7, - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - x + col_w[j] / 2, - y + row_h * 0.35, - val, - ha="center", - va="center", - fontsize=7, - fontweight="bold", - style="italic", - ) - x += col_w[j] - - -def _draw_criteria_bars(ax2: Axes) -> None: - """Draw the criteria comparison bar chart on the right panel.""" - criteria = [ - "E[X]", - "Laplace", - "Maximax", - "Maximin", - "Hurwicz\n\u03b1=0.6", - "Savage", - ] - - ev = [95, 69, 30] - laplace = [50, 63.3, 30] - maximax = [200, 80, 30] - maximin = [-100, 40, 30] - hurwicz = [80, 64, 30] - savage_maxregret = [140, 120, 170] - - winners = [0, 1, 0, 1, 0, 1] - - x_pos = np.arange(len(criteria)) - width = 0.22 - hatches = ["///", "...", "xxx"] - labels = ["A₁ (fabryka)", "A₂ (sklep)", "A₃ (obligacje)"] - - all_vals = [ - [ - ev[0], laplace[0], maximax[0], - maximin[0], hurwicz[0], savage_maxregret[0], - ], - [ - ev[1], laplace[1], maximax[1], - maximin[1], hurwicz[1], savage_maxregret[1], - ], - [ - ev[2], laplace[2], maximax[2], - maximin[2], hurwicz[2], savage_maxregret[2], - ], - ] - - for i in range(_DATA_STATE_COLS): - ax2.bar( - x_pos + (i - 1) * width, - all_vals[i], - width, - label=labels[i], - color="white", - edgecolor=LN, - hatch=hatches[i], - lw=0.8, - ) - - for c_idx in range(len(criteria)): - w = winners[c_idx] - val = all_vals[w][c_idx] - ax2.text( - x_pos[c_idx] + (w - 1) * width, - val + 5, - "★", - ha="center", - va="bottom", - fontsize=10, - fontweight="bold", - ) - - ax2.set_xticks(x_pos) - ax2.set_xticklabels(criteria, fontsize=7) - ax2.set_ylabel("Wartość kryterium", fontsize=8) - ax2.set_title( - "Porównanie kryteriów", - fontsize=FS_TITLE, - fontweight="bold", - pad=8, - ) - ax2.legend(fontsize=7, loc="upper right") - ax2.axhline(y=0, color=LN, lw=0.5, ls="-") - ax2.spines["top"].set_visible(False) - ax2.spines["right"].set_visible(False) - ax2.tick_params(labelsize=7) - - ax2.text( - 5, - -30, - "(Savage: niżej\n= lepiej)", - fontsize=6, - ha="center", - va="top", - style="italic", - ) - - -def draw_criteria_comparison() -> None: - """Draw payoff matrix and criteria comparison chart.""" - fig, axes = plt.subplots( - 1, - 2, - figsize=(8.27, 4.5), - gridspec_kw={"width_ratios": [1.2, 1]}, - ) - - _draw_payoff_table(axes[0]) - _draw_criteria_bars(axes[1]) - - plt.tight_layout() - outpath = str(Path(OUTPUT_DIR) / "q31_criteria_comparison.png") - fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) - plt.close(fig) - _logger.info(" Saved: %s", outpath) - - -# ============================================================ -# 2. REGRET MATRIX CONSTRUCTION -# ============================================================ -def _draw_original_payoff( - ax: Axes, - start_y: float, - row_h: float, -) -> None: - """Draw the original payoff matrix (left side of regret fig).""" - ax.text( - 2.2, - 6.3, - "Krok 1: Macierz wypłat", - fontsize=9, - fontweight="bold", - ha="center", - va="center", - ) - - col_w = 1.0 - headers = ["", "S₁", "S₂", "S₃"] - data = [ - ["A₁", "200", "50", "-100"], - ["A₂", "80", "70", "40"], - ["A₃", "30", "30", "30"], - ] - start_x = 0.3 - - for j, h in enumerate(headers): - w = 0.7 if j == 0 else col_w - x = start_x + (0 if j == 0 else 0.7 + (j - 1) * col_w) - rect = mpatches.Rectangle( - (x, start_y), - w, - row_h, - lw=1, - edgecolor=LN, - facecolor=GRAY2, - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - start_y + row_h / 2, - h, - ha="center", - va="center", - fontsize=FS, - fontweight="bold", - ) - - for i, row in enumerate(data): - y = start_y - (i + 1) * row_h - for j, val in enumerate(row): - w = 0.7 if j == 0 else col_w - x = start_x + (0 if j == 0 else 0.7 + (j - 1) * col_w) - fill = GRAY4 if j == 0 else "white" - rect = mpatches.Rectangle( - (x, y), - w, - row_h, - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + row_h / 2, - val, - ha="center", - va="center", - fontsize=FS, - ) - - # Max per column annotation - max_y = start_y - _DATA_STATE_COLS * row_h - 0.1 - col_maxes = ["max=200", "max=70", "max=40"] - for idx, label in enumerate(col_maxes): - ax.text( - start_x + 0.7 + (idx + 0.5) * col_w, - max_y, - label, - fontsize=7, - ha="center", - va="top", - fontweight="bold", - color="#333", - ) - - # Arrow - ax.annotate( - "", - xy=(5.0, 4.8), - xytext=(4.2, 4.8), - arrowprops={"arrowstyle": "->", "color": LN, "lw": 2}, - ) - ax.text( - 4.6, - 5.0, - "rᵢⱼ = max - aᵢⱼ", - fontsize=8, - ha="center", - va="bottom", - fontweight="bold", - ) - - -def _draw_regret_table( - ax: Axes, - start_y: float, - row_h: float, -) -> None: - """Draw the regret matrix (right side of regret fig).""" - ax.text( - 7.5, - 6.3, - "Krok 2: Macierz żalu", - fontsize=9, - fontweight="bold", - ha="center", - va="center", - ) - - regret_data = [ - ["A₁", "0", "20", "140"], - ["A₂", "120", "0", "0"], - ["A₃", "170", "40", "10"], - ] - headers2 = ["", "S₁", "S₂", "S₃", "max rᵢ"] - start_x2 = 5.3 - - for j, h in enumerate(headers2): - w = 0.7 if j == 0 else ( - 0.9 if j < _REGRET_HEADER_COLS else 1.0 - ) - x = start_x2 - if j == 0: - x = start_x2 - elif j <= _DATA_STATE_COLS: - x = start_x2 + 0.7 + (j - 1) * 0.9 - else: - x = start_x2 + 0.7 + _DATA_STATE_COLS * 0.9 - rect = mpatches.Rectangle( - (x, start_y), - w, - row_h, - lw=1, - edgecolor=LN, - facecolor=( - GRAY2 if j < _REGRET_HEADER_COLS else GRAY3 - ), - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - start_y + row_h / 2, - h, - ha="center", - va="center", - fontsize=FS, - fontweight="bold", - ) - - max_regrets = [140, 120, 170] - for i, row in enumerate(regret_data): - y = start_y - (i + 1) * row_h - for j, val in enumerate(row): - w = 0.7 if j == 0 else 0.9 - x = start_x2 + ( - 0 if j == 0 else 0.7 + (j - 1) * 0.9 - ) - fill = GRAY4 if j == 0 else "white" - if j > 0 and int(val) == max_regrets[i]: - fill = GRAY2 - rect = mpatches.Rectangle( - (x, y), - w, - row_h, - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - fw = ( - "bold" - if (j > 0 and int(val) == max_regrets[i]) - else "normal" - ) - ax.text( - x + w / 2, - y + row_h / 2, - val, - ha="center", - va="center", - fontsize=FS, - fontweight=fw, - ) - - # Max regret column - x = start_x2 + 0.7 + _DATA_STATE_COLS * 0.9 - w = 1.0 - is_winner = max_regrets[i] == min(max_regrets) - fill = "#C8C8C8" if is_winner else GRAY1 - rect = mpatches.Rectangle( - (x, y), - w, - row_h, - lw=1.5 if is_winner else 1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - marker = " ★" if is_winner else "" - ax.text( - x + w / 2, - y + row_h / 2, - f"{max_regrets[i]}{marker}", - ha="center", - va="center", - fontsize=FS, - fontweight="bold", - ) - - -def draw_regret_matrix() -> None: - """Draw the regret matrix construction diagram.""" - fig, ax = plt.subplots(1, 1, figsize=(8.27, 5)) - ax.axis("off") - ax.set_xlim(0, 10) - ax.set_ylim(0, 7) - ax.set_title( - "Kryterium Savage'a \u2014 budowa macierzy żalu", - fontsize=FS_TITLE + 1, - fontweight="bold", - pad=10, - ) - - start_y = 5.5 - row_h = 0.55 - - _draw_original_payoff(ax, start_y, row_h) - _draw_regret_table(ax, start_y, row_h) - - # Bottom conclusion - ax.text( - 5.0, - 2.8, - "Krok 3: Wybierz min z max żalu" - " → A₂ (max żal = 120)", - fontsize=10, - ha="center", - va="center", - fontweight="bold", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": GRAY1, - "edgecolor": LN, - "lw": 1.5, - }, - ) - - # Interpretation examples - ax.text( - 5.0, - 2.0, - "Interpretacja żalu: r₁₃ = 140 oznacza:\n" - "\u201eGdyby nastąpił S₃ (zła koniunktura)," - " a wybrałbym A₁,\n" - "żałowałbym, bo najlepszą opcją byłoby" - " A₂ z wynikiem 40 \u2014 traciłbym 140\u201d", - fontsize=7.5, - ha="center", - va="center", - style="italic", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": GRAY4, - "edgecolor": GRAY3, - "lw": 0.8, - }, - ) - - # Mnemonic - ax.text( - 5.0, - 0.8, - "Mnemonik: Savage = \u201eŻal jak nóż\u201d\n" - "Maksymalny żal to nóż " - "\u2014 wybierz opcję z NAJMNIEJSZYM nożem", - fontsize=8, - ha="center", - va="center", - fontweight="bold", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": "white", - "edgecolor": LN, - "lw": 1, - }, - ) - - plt.tight_layout() - outpath = str(Path(OUTPUT_DIR) / "q31_regret_matrix.png") - fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) - plt.close(fig) - _logger.info(" Saved: %s", outpath) - - -# ============================================================ -# 3. HURWICZ alpha INTERPOLATION -# ============================================================ -def draw_hurwicz_interpolation() -> None: - """Draw Hurwicz alpha interpolation diagram.""" - fig, ax = plt.subplots(1, 1, figsize=(8.27, 4)) - ax.set_title( - "Kryterium Hurwicza" - " \u2014 wpływ \u03b1 na wybór alternatywy", - fontsize=FS_TITLE + 1, - fontweight="bold", - pad=10, - ) - - alphas = np.linspace(0, 1, 200) - - v1 = alphas * 200 + (1 - alphas) * (-100) - v2 = alphas * 80 + (1 - alphas) * 40 - v3 = alphas * 30 + (1 - alphas) * 30 - - ax.plot( - alphas, v1, "k-", lw=2, - label="A₁ (fabryka): V = 300\u03b1 - 100", - ) - ax.plot( - alphas, v2, "k--", lw=2, - label="A₂ (sklep): V = 40\u03b1 + 40", - ) - ax.plot( - alphas, v3, "k:", lw=2, - label="A₃ (obligacje): V = 30", - ) - - # Crossover A2=A1 - alpha_cross_12 = 140 / 260 - v_cross_12 = 40 * alpha_cross_12 + 40 - - ax.plot(alpha_cross_12, v_cross_12, "ko", markersize=8, zorder=5) - ax.annotate( - f"\u03b1 ≈ {alpha_cross_12:.2f}\nA₁ = A₂", - xy=(alpha_cross_12, v_cross_12), - xytext=(alpha_cross_12 + 0.12, v_cross_12 - 30), - fontsize=8, - fontweight="bold", - arrowprops={ - "arrowstyle": "->", - "color": LN, - "lw": 1, - }, - ) - - # Shade winning regions - ax.axvspan(0, alpha_cross_12, alpha=0.08, color="black") - ax.axvspan(alpha_cross_12, 1, alpha=0.15, color="black") - - ax.text( - alpha_cross_12 / 2, - -60, - "A₂ wygrywa\n(pesymistycznie)", - fontsize=8, - ha="center", - va="center", - bbox={ - "boxstyle": "round", - "facecolor": "white", - "edgecolor": LN, - }, - ) - ax.text( - (alpha_cross_12 + 1) / 2, - 160, - "A₁ wygrywa\n(optymistycznie)", - fontsize=8, - ha="center", - va="center", - bbox={ - "boxstyle": "round", - "facecolor": "white", - "edgecolor": LN, - }, - ) - - # Special alpha values - ax.axvline(x=0, color=LN, lw=0.5, ls=":") - ax.axvline(x=1, color=LN, lw=0.5, ls=":") - ax.text( - 0, - -115, - "\u03b1=0\nmaximin", - fontsize=7, - ha="center", - va="top", - fontweight="bold", - ) - ax.text( - 1, - -115, - "\u03b1=1\nmaximax", - fontsize=7, - ha="center", - va="top", - fontweight="bold", - ) - - ax.set_xlabel("Współczynnik optymizmu \u03b1", fontsize=9) - ax.set_ylabel( - "V(Aᵢ) = \u03b1·max + (1-\u03b1)·min", fontsize=9 - ) - ax.legend(fontsize=8, loc="upper left") - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xlim(-0.05, 1.05) - ax.axhline(y=0, color=LN, lw=0.3, ls="-") - ax.tick_params(labelsize=8) - - plt.tight_layout() - outpath = str(Path(OUTPUT_DIR) / "q31_hurwicz_alpha.png") - fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) - plt.close(fig) - _logger.info(" Saved: %s", outpath) - - -# ============================================================ -# 4. DECISION CRITERIA MNEMONIC MAP -# ============================================================ -def draw_criteria_mnemonic() -> None: - """Draw decision criteria mnemonic map diagram.""" - fig, ax = plt.subplots(1, 1, figsize=(8.27, 6)) - ax.set_xlim(0, 10) - ax.set_ylim(0, 8) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Mapa mnemoniczna \u2014 6 kryteriów decyzyjnych", - fontsize=FS_TITLE + 2, - fontweight="bold", - pad=10, - ) - - # Central node - draw_box( - ax, - 3.5, - 3.5, - 3, - 1, - "MACIERZ\nWYPŁAT", - fill=GRAY2, - lw=2, - fontsize=11, - fontweight="bold", - ) - - # Criteria boxes around the center - criteria = [ - ( - 0, 6.5, 3, 1.2, - "WARTOŚĆ OCZEKIWANA", - "\u201eMam prawdopodobieństwa\u201d", - "E[Aᵢ] = Σ pⱼ·aᵢⱼ", - ), - ( - 3.5, 6.5, 3, 1.2, - "LAPLACE", - "\u201eWszystko po równo\u201d", - "V = Σaᵢⱼ / n", - ), - ( - 7, 6.5, 3, 1.2, - "MAXIMAX", - "\u201eOptymista: max z max\u201d", - "max maxⱼ aᵢⱼ", - ), - ( - 0, 0.5, 3, 1.2, - "MAXIMIN (Wald)", - "\u201ePesymista: max z min\u201d", - "max minⱼ aᵢⱼ", - ), - ( - 3.5, 0.5, 3, 1.2, - "HURWICZ", - "\u201e\u03b1 pomiędzy\u201d", - "\u03b1·max + (1-\u03b1)·min", - ), - ( - 7, 0.5, 3, 1.2, - "SAVAGE", - "\u201eMin max żalu\u201d", - "min maxⱼ rᵢⱼ", - ), - ] - - fills = [GRAY3, GRAY1, "white", "white", GRAY1, GRAY3] - - for i, (x, y, w, h, title, mnem, formula) in enumerate( - criteria - ): - rect = FancyBboxPatch( - (x, y), - w, - h, - boxstyle="round,pad=0.08", - lw=1.5, - edgecolor=LN, - facecolor=fills[i], - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + h * 0.78, - title, - ha="center", - va="center", - fontsize=8, - fontweight="bold", - ) - ax.text( - x + w / 2, - y + h * 0.45, - mnem, - ha="center", - va="center", - fontsize=7, - style="italic", - ) - ax.text( - x + w / 2, - y + h * 0.15, - formula, - ha="center", - va="center", - fontsize=7, - fontweight="bold", - family="monospace", - ) - - # Arrows from center to each box - cx, cy = 5, 4 - bx = x + w / 2 - by_center = y + h / 2 - if by_center > cy: - ax.annotate( - "", - xy=(bx, y), - xytext=(cx, 4.5), - arrowprops={ - "arrowstyle": "->", - "color": LN, - "lw": 1, - "connectionstyle": "arc3,rad=0", - }, - ) - else: - ax.annotate( - "", - xy=(bx, y + h), - xytext=(cx, 3.5), - arrowprops={ - "arrowstyle": "->", - "color": LN, - "lw": 1, - "connectionstyle": "arc3,rad=0", - }, - ) - - # Labels on arrows - arrow_labels = [ - (1.2, 5.6, "znane p"), - (5, 5.6, "p = 1/n"), - (8.7, 5.6, "max ↑"), - (1.2, 2.5, "min ↑"), - (5, 2.5, "podaj \u03b1"), - (8.7, 2.5, "macierz\nżalu"), - ] - for lx, ly, ltext in arrow_labels: - ax.text( - lx, - ly, - ltext, - fontsize=7, - ha="center", - va="center", - bbox={ - "boxstyle": "round,pad=0.15", - "facecolor": "white", - "edgecolor": GRAY3, - "lw": 0.5, - }, - ) - - plt.tight_layout() - outpath = str(Path(OUTPUT_DIR) / "q31_criteria_mnemonic.png") - fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) - plt.close(fig) - _logger.info(" Saved: %s", outpath) - - -# ============================================================ -# 5. EXPECTED VALUE CRITERION WITH PROBABILITY BARS -# ============================================================ -def draw_expected_value() -> None: - """Draw expected value criterion with probability-weighted bars.""" - fig, axes = plt.subplots(1, 3, figsize=(8.27, 3.5), sharey=True) - fig.suptitle( - "Kryterium wartości oczekiwanej E[X]" - " \u2014 rozkład wyników per alternatywa", - fontsize=FS_TITLE, - fontweight="bold", - y=1.02, - ) - - probs = [0.5, 0.3, 0.2] - alts = [ - ("A₁ (fabryka)", [200, 50, -100], 95), - ("A₂ (sklep)", [80, 70, 40], 69), - ("A₃ (obligacje)", [30, 30, 30], 30), - ] - - hatches = ["///", "...", "xxx"] - - for _idx, (ax, (name, vals, ev)) in enumerate( - zip(axes, alts, strict=False) - ): - x_positions = [0, 0.6, 1.0] - widths = [p * 0.9 for p in probs] - - for i, (v, p, h) in enumerate( - zip(vals, probs, hatches, strict=False) - ): - color = "white" if v >= 0 else GRAY2 - ax.bar( - x_positions[i], - v, - width=widths[i], - color=color, - edgecolor=LN, - hatch=h, - lw=0.8, - align="edge", - ) - offset = 8 if v >= 0 else -12 - ax.text( - x_positions[i] + widths[i] / 2, - v + offset, - f"{v}", - ha="center", - va="center", - fontsize=8, - fontweight="bold", - ) - contrib = v * p - ax.text( - x_positions[i] + widths[i] / 2, - v / 2, - f"{v}x{p}\n={contrib:.0f}", - ha="center", - va="center", - fontsize=6, - style="italic", - ) - - # Expected value line - ax.axhline(y=ev, color=LN, lw=2, ls="--") - ax.text( - 1.35, - ev, - f"E[X]={ev}", - fontsize=8, - fontweight="bold", - va="center", - ha="left", - bbox={ - "boxstyle": "round,pad=0.15", - "facecolor": GRAY1, - "edgecolor": LN, - }, - ) - - ax.set_title(name, fontsize=9, fontweight="bold") - ax.set_xticks([0.225, 0.735, 1.09]) - ax.set_xticklabels(["S₁", "S₂", "S₃"], fontsize=7) - ax.axhline(y=0, color=LN, lw=0.5) - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.tick_params(labelsize=7) - - # Star on winner - if ev == _WINNING_EV: - ax.text( - 0.7, - ev + 20, - "★ MAX", - fontsize=9, - fontweight="bold", - ha="center", - va="bottom", - ) - - axes[0].set_ylabel("Wypłata (tys. zł)", fontsize=8) - - plt.tight_layout() - outpath = str(Path(OUTPUT_DIR) / "q31_expected_value.png") - fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) - plt.close(fig) - _logger.info(" Saved: %s", outpath) - - -# ============================================================ -# 6. DECISION CONDITIONS SPECTRUM -# ============================================================ -def draw_conditions_spectrum() -> None: - """Draw decision conditions spectrum diagram.""" - fig, ax = plt.subplots(1, 1, figsize=(8.27, 3.5)) - ax.set_xlim(0, 10) - ax.set_ylim(0, 5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Warunki decyzyjne" - " \u2014 spektrum wiedzy decydenta", - fontsize=FS_TITLE + 1, - fontweight="bold", - pad=10, - ) - - # Three zones - zones = [ - ( - 0.3, 1.5, 2.8, 2.5, - "PEWNOŚĆ", "white", - [ - "Znamy dokładny wynik", - "Przykład: lokata 5%", - "Metoda: po prostu wybierz", - "najlepszy wynik", - ], - ), - ( - 3.5, 1.5, 2.8, 2.5, - "RYZYKO", GRAY1, - [ - "Znamy wyniki I prawdop.", - "Przykład: gra w kości", - "Metoda: wartość", - "oczekiwana E[X]", - ], - ), - ( - 6.7, 1.5, 2.8, 2.5, - "NIEPEWNOŚĆ", GRAY3, - [ - "Znamy wyniki, ale", - "NIE znamy prawdop.", - "Metody: Laplace, maximax,", - "maximin, Hurwicz, Savage", - ], - ), - ] - - for x, y, w, h, title, fill, lines in zones: - rect = FancyBboxPatch( - (x, y), - w, - h, - boxstyle="round,pad=0.1", - lw=2, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + h - 0.3, - title, - ha="center", - va="center", - fontsize=11, - fontweight="bold", - ) - for i, line in enumerate(lines): - ax.text( - x + w / 2, - y + h - 0.7 - i * 0.4, - line, - ha="center", - va="center", - fontsize=7, - ) - - # Arrows between zones - ax.annotate( - "", - xy=(3.4, 2.75), - xytext=(3.15, 2.75), - arrowprops={"arrowstyle": "->", "color": LN, "lw": 2}, - ) - ax.annotate( - "", - xy=(6.6, 2.75), - xytext=(6.35, 2.75), - arrowprops={"arrowstyle": "->", "color": LN, "lw": 2}, - ) - - # Bottom: knowledge gradient bar - gradient_y = 0.5 - gradient_h = 0.5 - n_steps = 50 - for i in range(n_steps): - x = 0.3 + i * (9.2 / n_steps) - w = 9.2 / n_steps + 0.01 - gray_val = 1 - (i / n_steps) * 0.7 - rect = mpatches.Rectangle( - (x, gradient_y), - w, - gradient_h, - lw=0, - facecolor=str(gray_val), - ) - ax.add_patch(rect) - - rect = mpatches.Rectangle( - (0.3, gradient_y), - 9.2, - gradient_h, - lw=1.5, - edgecolor=LN, - facecolor="none", - ) - ax.add_patch(rect) - - ax.text( - 0.3, gradient_y - 0.15, - "Dużo wiedzy", fontsize=7, ha="left", va="top", - ) - ax.text( - 9.5, gradient_y - 0.15, - "Mało wiedzy", fontsize=7, ha="right", va="top", - ) - ax.text( - 4.95, - gradient_y + gradient_h / 2, - "POZIOM WIEDZY DECYDENTA", - fontsize=8, - fontweight="bold", - ha="center", - va="center", - color="white", - ) - - plt.tight_layout() - outpath = str(Path(OUTPUT_DIR) / "q31_conditions_spectrum.png") - fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG) - plt.close(fig) - _logger.info(" Saved: %s", outpath) - - -# ============================================================ -# MAIN -# ============================================================ if __name__ == "__main__": logging.basicConfig(level=logging.INFO) _logger.info("Generating PYTANIE 31 diagrams...") diff --git a/python_pkg/praca_magisterska_video/generate_images/generate_q9_all_diagrams.py b/python_pkg/praca_magisterska_video/generate_images/generate_q9_all_diagrams.py index 453c52a..553e8e2 100755 --- a/python_pkg/praca_magisterska_video/generate_images/generate_q9_all_diagrams.py +++ b/python_pkg/praca_magisterska_video/generate_images/generate_q9_all_diagrams.py @@ -2,1608 +2,60 @@ """Generate ALL diagrams for PYTANIE 9: Procesy i wątki (SOI). Replaces every ASCII diagram with a monochrome A4-printable PNG (300 DPI). +Re-exports all diagram generators from submodules. """ from __future__ import annotations import logging -from typing import TYPE_CHECKING - -import matplotlib as mpl - -mpl.use("Agg") from pathlib import Path +import sys -import matplotlib.patches as mpatches -from matplotlib.patches import FancyBboxPatch -import matplotlib.pyplot as plt -import numpy as np +# Ensure sibling modules are importable when run as a script. +sys.path.insert(0, str(Path(__file__).resolve().parent)) -if TYPE_CHECKING: - from matplotlib.axes import Axes - from matplotlib.figure import Figure +from _q9_basics import ( + gen_memory_layout, + gen_pcb_structure, + gen_process_states, + gen_process_vs_thread, + gen_speed_comparison, + gen_thread_structure, +) +from _q9_classic_sync import ( + gen_classic_problems, + gen_semaphore_concept, + gen_sync_comparison, +) +from _q9_ipc import gen_ipc_details, gen_ipc_table, gen_scenario_table +from _q9_race_deadlock import ( + gen_coffman_strategies, + gen_deadlock_scenario, + gen_race_condition, + gen_starvation_priority, +) + +__all__ = [ + "gen_classic_problems", + "gen_coffman_strategies", + "gen_deadlock_scenario", + "gen_ipc_details", + "gen_ipc_table", + "gen_memory_layout", + "gen_pcb_structure", + "gen_process_states", + "gen_process_vs_thread", + "gen_race_condition", + "gen_scenario_table", + "gen_semaphore_concept", + "gen_speed_comparison", + "gen_starvation_priority", + "gen_sync_comparison", + "gen_thread_structure", +] _logger = logging.getLogger(__name__) -DPI = 300 -BG = "white" -LN = "black" -FS = 8 -FS_TITLE = 11 -FS_SMALL = 6.5 -FS_LABEL = 9 -OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") -Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) - -GRAY1 = "#E8E8E8" -GRAY2 = "#D0D0D0" -GRAY3 = "#B8B8B8" -GRAY4 = "#F5F5F5" -GRAY5 = "#C0C0C0" -OCCUPIED_SLOTS = 2 - - -def draw_box( - ax: Axes, - x: float, - y: float, - w: float, - h: float, - text: str, - fill: str = "white", - lw: float = 1.2, - fontsize: float = FS, - fontweight: str = "normal", - ha: str = "center", - va: str = "center", - *, - rounded: bool = True, - edgecolor: str = LN, - linestyle: str = "-", -) -> None: - """Draw box.""" - if rounded: - rect = FancyBboxPatch( - (x, y), - w, - h, - boxstyle="round,pad=0.05", - lw=lw, - edgecolor=edgecolor, - facecolor=fill, - linestyle=linestyle, - ) - else: - rect = mpatches.Rectangle( - (x, y), - w, - h, - lw=lw, - edgecolor=edgecolor, - facecolor=fill, - linestyle=linestyle, - ) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + h / 2, - text, - ha=ha, - va=va, - fontsize=fontsize, - fontweight=fontweight, - wrap=True, - ) - - -def draw_arrow( - ax: Axes, - x1: float, - y1: float, - x2: float, - y2: float, - lw: float = 1.2, - style: str = "->", - color: str = LN, -) -> None: - """Draw arrow.""" - ax.annotate( - "", - xy=(x2, y2), - xytext=(x1, y1), - arrowprops={"arrowstyle": style, "color": color, "lw": lw}, - ) - - -def draw_double_arrow( - ax: Axes, - x1: float, - y1: float, - x2: float, - y2: float, - lw: float = 1.2, - color: str = LN, -) -> None: - """Draw double arrow.""" - ax.annotate( - "", - xy=(x2, y2), - xytext=(x1, y1), - arrowprops={"arrowstyle": "<->", "color": color, "lw": lw}, - ) - - -def save_fig(fig: Figure, name: str) -> None: - """Save fig.""" - path = str(Path(OUTPUT_DIR) / name) - fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15) - plt.close(fig) - _logger.info(" Saved: %s", path) - - -def draw_table( - ax: Axes, - headers: list[str], - rows: list[list[str]], - x0: float, - y0: float, - col_widths: list[float], - row_h: float = 0.4, - header_fill: str = GRAY2, - row_fills: list[str] | None = None, - fontsize: float = FS, - header_fontsize: float | None = None, -) -> None: - """Draw a clean table on axes.""" - if header_fontsize is None: - header_fontsize = fontsize - len(headers) - len(rows) - sum(col_widths) - - # Header - cx = x0 - for j, hdr in enumerate(headers): - draw_box( - ax, - cx, - y0, - col_widths[j], - row_h, - hdr, - fill=header_fill, - fontsize=header_fontsize, - fontweight="bold", - rounded=False, - ) - cx += col_widths[j] - - # Rows - for i, row in enumerate(rows): - cy = y0 - (i + 1) * row_h - cx = x0 - fill = GRAY4 if (i % 2 == 0) else "white" - if row_fills and i < len(row_fills): - fill = row_fills[i] - for j, cell in enumerate(row): - fw = "bold" if j == 0 else "normal" - draw_box( - ax, - cx, - cy, - col_widths[j], - row_h, - cell, - fill=fill, - fontsize=fontsize, - fontweight=fw, - rounded=False, - ) - cx += col_widths[j] - - -# ============================================================ -# 1. Process vs Thread comparison table -# ============================================================ -def gen_process_vs_thread() -> None: - """Gen process vs thread.""" - fig, ax = plt.subplots(figsize=(7.5, 4.5)) - ax.set_xlim(0, 10) - ax.set_ylim(-4.5, 1.5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Proces vs Wątek — porównanie", fontsize=FS_TITLE, fontweight="bold", pad=10 - ) - - headers = ["Cecha", "Proces", "Wątek"] - col_w = [2.5, 3.5, 3.5] - rows = [ - ["Pamięć", "Własna, izolowana", "Współdzielona (heap)"], - ["Tworzenie", "~1-10 ms", "~10-100 μs (100x szybciej)"], - ["Przełączanie", "~1-5 μs (TLB flush)", "~0.1-0.5 μs (10x)"], - ["Komunikacja", "IPC (pipe, socket, shm)", "Bezpośrednia (wspólna pam.)"], - ["Izolacja", "Pełna — awaria izolowana", "Brak — może zabić proces"], - ["Zastosowanie", "Bezpieczeństwo, izolacja", "Wydajność, współdzielenie"], - ] - draw_table( - ax, - headers, - rows, - x0=0.25, - y0=0.8, - col_widths=col_w, - row_h=0.55, - fontsize=7.5, - header_fontsize=FS_LABEL, - ) - - # Analogy at bottom - ax.text( - 5.0, - -4.2, - "Analogia: Proces = mieszkanie (własny adres) " - "Wątek = pokój w mieszkaniu (wspólna kuchnia = heap)", - ha="center", - fontsize=FS, - style="italic", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - save_fig(fig, "q9_process_vs_thread.png") - - -# ============================================================ -# 2. Memory segments layout -# ============================================================ -def gen_memory_layout() -> None: - """Gen memory layout.""" - fig, ax = plt.subplots(figsize=(6, 5)) - ax.set_xlim(0, 10) - ax.set_ylim(0, 8) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Segmenty pamięci procesu", fontsize=FS_TITLE, fontweight="bold", pad=10 - ) - - segments = [ - ("STACK ↓", "zmienne lokalne, adresy\npowrotu (każdy wątek WŁASNY)", GRAY1), - ("...", "(wolna przestrzeń)", "white"), - ("HEAP ↑", "malloc/new — dynamiczna\nalokacja (współdzielony)", GRAY4), - ("BSS", "zmienne globalne\nniezainicjalizowane (zerowane)", GRAY2), - ("DATA", "zmienne globalne\nzainicjalizowane", GRAY3), - ("TEXT", "kod maszynowy\n(read-only, współdzielony)", GRAY5), - ] - bx, bw = 2.0, 2.5 - seg_h = 0.9 - gap = 0.05 - top_y = 7.0 - - for i, (name, desc, color) in enumerate(segments): - y = top_y - i * (seg_h + gap) - draw_box( - ax, - bx, - y, - bw, - seg_h, - name, - fill=color, - fontsize=FS_LABEL, - fontweight="bold", - rounded=False, - ) - ax.text(bx + bw + 0.3, y + seg_h / 2, desc, fontsize=7.5, va="center") - - # Address labels - ax.text( - bx - 0.2, - top_y + seg_h / 2, - "wysoki\nadres", - fontsize=FS_SMALL, - va="center", - ha="right", - style="italic", - ) - bottom_y = top_y - 5 * (seg_h + gap) - ax.text( - bx - 0.2, - bottom_y + seg_h / 2, - "niski\nadres", - fontsize=FS_SMALL, - va="center", - ha="right", - style="italic", - ) - - # Arrows for growth - ax.annotate( - "", - xy=(bx - 0.5, top_y - 0.1), - xytext=(bx - 0.5, top_y + seg_h + 0.1), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - ax.text(bx - 0.9, top_y + 0.4, "rośnie\nw dół", fontsize=FS_SMALL, ha="center") - - heap_y = top_y - 2 * (seg_h + gap) - ax.annotate( - "", - xy=(bx - 0.5, heap_y + seg_h + 0.1), - xytext=(bx - 0.5, heap_y - 0.1), - arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN}, - ) - ax.text(bx - 0.9, heap_y + 0.5, "rośnie\nw górę", fontsize=FS_SMALL, ha="center") - - save_fig(fig, "q9_memory_layout.png") - - -# ============================================================ -# 3. Process states diagram -# ============================================================ -def gen_process_states() -> None: - """Gen process states.""" - fig, ax = plt.subplots(figsize=(7, 3.5)) - ax.set_xlim(0, 12) - ax.set_ylim(0, 5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Stany procesu — diagram przejść", fontsize=FS_TITLE, fontweight="bold", pad=10 - ) - - states = { - "NEW": (1.0, 2.5), - "READY": (3.5, 2.5), - "RUNNING": (6.5, 2.5), - "BLOCKED": (6.5, 0.5), - "TERMINATED": (10.0, 2.5), - } - fills = { - "NEW": GRAY4, - "READY": GRAY1, - "RUNNING": GRAY3, - "BLOCKED": GRAY2, - "TERMINATED": GRAY5, - } - bw, bh = 1.8, 0.9 - for name, (x, y) in states.items(): - draw_box( - ax, - x, - y, - bw, - bh, - name, - fill=fills[name], - fontsize=FS_LABEL, - fontweight="bold", - ) - - # Transitions - transitions = [ - ("NEW", "READY", "admit"), - ("READY", "RUNNING", "dispatch\n(scheduler)"), - ("RUNNING", "TERMINATED", "exit"), - ("RUNNING", "BLOCKED", "I/O wait"), - ] - for src, dst, label in transitions: - sx, sy = states[src] - dx, dy = states[dst] - if sy == dy: # horizontal - draw_arrow(ax, sx + bw, sy + bh / 2, dx, dy + bh / 2, lw=1.5) - mx = (sx + bw + dx) / 2 - ax.text( - mx, - sy + bh / 2 + 0.25, - label, - fontsize=FS_SMALL, - ha="center", - va="bottom", - ) - else: # vertical - draw_arrow(ax, sx + bw / 2, sy, dx + bw / 2, dy + bh, lw=1.5) - ax.text( - sx + bw + 0.2, - (sy + dy + bh) / 2, - label, - fontsize=FS_SMALL, - ha="left", - va="center", - ) - - # BLOCKED → READY - bx, by = states["BLOCKED"] - rx, ry = states["READY"] - ax.annotate( - "", - xy=(rx + bw / 2, ry), - xytext=(bx - 0.3, by + bh / 2), - arrowprops={ - "arrowstyle": "->", - "lw": 1.5, - "color": LN, - "connectionstyle": "arc3,rad=0.3", - }, - ) - ax.text(3.5, 0.7, "I/O done", fontsize=FS_SMALL, ha="center") - - # RUNNING → READY (preemption) - rux, ruy = states["RUNNING"] - draw_arrow(ax, rux, ruy + bh, rx + bw, ry + bh, lw=1.2) - ax.text(5.0, 3.7, "preempt /\ntimeout", fontsize=FS_SMALL, ha="center") - - save_fig(fig, "q9_process_states.png") - - -# ============================================================ -# 4. Thread structure within process -# ============================================================ -def gen_thread_structure() -> None: - """Gen thread structure.""" - fig, ax = plt.subplots(figsize=(8, 4.5)) - ax.set_xlim(0, 10) - ax.set_ylim(0, 6) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Wątki wewnątrz procesu (PID=42)", fontsize=FS_TITLE, fontweight="bold", pad=10 - ) - - # Shared memory region - draw_box(ax, 0.5, 3.5, 9.0, 1.8, "", fill=GRAY1, rounded=False, lw=2) - ax.text(5.0, 5.0, "WSPÓŁDZIELONE", fontsize=FS, fontweight="bold", ha="center") - - labels_shared = ["TEXT", "DATA", "BSS", "HEAP", "pliki", "PID"] - for i, lab in enumerate(labels_shared): - x = 1.0 + i * 1.4 - draw_box( - ax, - x, - 3.8, - 1.1, - 0.6, - lab, - fill=GRAY3, - fontsize=FS, - fontweight="bold", - rounded=False, - ) - ax.text(x + 0.55, 4.6, lab, fontsize=FS_SMALL, ha="center", color="#555555") - - # Per-thread regions - draw_box( - ax, 0.5, 0.5, 9.0, 2.7, "", fill="white", rounded=False, lw=2, linestyle="--" - ) - ax.text( - 5.0, 2.95, "PRYWATNE (każdy wątek)", fontsize=FS, fontweight="bold", ha="center" - ) - - for i in range(3): - x = 1.0 + i * 3.0 - tid = i + 1 - draw_box(ax, x, 0.7, 2.3, 2.0, "", fill=GRAY4, rounded=False) - ax.text( - x + 1.15, - 2.4, - f"Wątek {tid}", - fontsize=FS_LABEL, - fontweight="bold", - ha="center", - ) - items = [f"stos_{tid}", f"rejestry_{tid}", f"PC_{tid}", f"TID={40 + tid}"] - for j, item in enumerate(items): - ax.text( - x + 1.15, - 2.0 - j * 0.35, - item, - fontsize=FS_SMALL, - ha="center", - family="monospace", - ) - - save_fig(fig, "q9_thread_structure.png") - - -# ============================================================ -# 5. PCB structure -# ============================================================ -def gen_pcb_structure() -> None: - """Gen pcb structure.""" - fig, ax = plt.subplots(figsize=(5, 3.5)) - ax.set_xlim(0, 8) - ax.set_ylim(0, 5.5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "PCB (Process Control Block)", fontsize=FS_TITLE, fontweight="bold", pad=10 - ) - - fields = [ - ("PID", "42"), - ("Stan", "READY / RUNNING / BLOCKED"), - ("Rejestry CPU", "EAX, EBX, ESP, EIP ..."), - ("Tablice stron", "mapowanie wirtualne → fizyczne"), - ("Otwarte pliki", "fd[0], fd[1], fd[2] ..."), - ("Priorytety", "nice value, scheduling class"), - ("Statystyki", "CPU time, I/O count"), - ] - - top_y = 4.8 - for i, (field, value) in enumerate(fields): - y = top_y - i * 0.55 - draw_box( - ax, - 0.5, - y, - 2.2, - 0.45, - field, - fill=GRAY2, - fontsize=FS, - fontweight="bold", - rounded=False, - ) - draw_box(ax, 2.7, y, 4.5, 0.45, value, fill=GRAY4, fontsize=FS, rounded=False) - - ax.text( - 4.0, - 0.3, - "Context switch = zapisz PCB starego → wczytaj PCB nowego", - fontsize=FS_SMALL, - ha="center", - style="italic", - ) - - save_fig(fig, "q9_pcb_structure.png") - - -# ============================================================ -# 6. Speed comparison -# ============================================================ -def gen_speed_comparison() -> None: - """Gen speed comparison.""" - fig, axes = plt.subplots(1, 2, figsize=(9, 3.5)) - fig.suptitle( - "Szybkość — procesy vs wątki (benchmarki Linux)", - fontsize=FS_TITLE, - fontweight="bold", - ) - - # Creation time panel - ax = axes[0] - ops = ["fork()\n(nowy proces)", "pthread_create()\n(nowy wątek)"] - times = [3.0, 0.05] # ms - colors = [GRAY3, GRAY1] - bars = ax.barh(ops, times, color=colors, edgecolor=LN, height=0.5, linewidth=1.2) - ax.set_xlabel("Czas [ms]", fontsize=FS) - ax.set_title("Tworzenie", fontsize=FS_LABEL, fontweight="bold") - ax.set_xlim(0, 4.5) - for bar, t in zip(bars, times, strict=False): - ax.text( - bar.get_width() + 0.1, - bar.get_y() + bar.get_height() / 2, - f"{t} ms", - va="center", - fontsize=FS, - ) - ax.text( - 2.5, - -0.6, - "~100x szybciej", - fontsize=FS, - ha="center", - fontweight="bold", - transform=ax.get_xaxis_transform(), - ) - ax.tick_params(labelsize=FS) - - # Right: context switch - ax = axes[1] - ops2 = ["Proces→Proces\n(TLB flush)", "Wątek→Wątek\n(TLB warm)"] - times2 = [3000, 300] # ns - bars2 = ax.barh(ops2, times2, color=colors, edgecolor=LN, height=0.5, linewidth=1.2) - ax.set_xlabel("Czas [ns]", fontsize=FS) - ax.set_title("Przełączanie kontekstu", fontsize=FS_LABEL, fontweight="bold") - ax.set_xlim(0, 4500) - for bar, t in zip(bars2, times2, strict=False): - ax.text( - bar.get_width() + 50, - bar.get_y() + bar.get_height() / 2, - f"{t} ns", - va="center", - fontsize=FS, - ) - ax.text( - 2500, - -0.6, - "~10x szybciej", - fontsize=FS, - ha="center", - fontweight="bold", - transform=ax.get_xaxis_transform(), - ) - ax.tick_params(labelsize=FS) - - fig.tight_layout(rect=[0, 0.05, 1, 0.92]) - save_fig(fig, "q9_speed_comparison.png") - - -# ============================================================ -# 7. Scenario table (when to use process vs thread) -# ============================================================ -def gen_scenario_table() -> None: - """Gen scenario table.""" - fig, ax = plt.subplots(figsize=(8.5, 4.5)) - ax.set_xlim(0, 11) - ax.set_ylim(-5.5, 1) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Kiedy proces, kiedy wątek? — typowe scenariusze", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - headers = ["Scenariusz", "Wybór", "Dlaczego?"] - col_w = [3.5, 2.5, 4.5] - rows = [ - ["Serwer WWW (Apache)", "Proces", "izolacja klientów"], - ["Serwer WWW (nginx)", "Wątek / async", "szybkość, cooperacja"], - ["Przeglądarka (karty)", "Proces", "crash isolation"], - ["Przeglądarka (JS+render)", "Wątek", "współdzielony DOM"], - ["Gra (fizyka+rendering)", "Wątek", "współdzielony świat gry"], - ["Kompilacja (make -j8)", "Proces", "izolacja, prostota"], - ["Baza danych (zapytania)", "Wątek", "współdzielony cache"], - ["Microservices", "Proces (kontener)", "izolacja, deployment"], - ] - draw_table( - ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7 - ) - - save_fig(fig, "q9_scenario_table.png") - - -# ============================================================ -# 8. IPC details: pipe, shared memory, socket (3-panel) -# ============================================================ -def gen_ipc_details() -> None: - """Gen ipc details.""" - fig, axes = plt.subplots(1, 3, figsize=(11, 3.5)) - fig.suptitle("Mechanizmy IPC — szczegóły", fontsize=FS_TITLE, fontweight="bold") - - # Panel 1: Pipe - ax = axes[0] - ax.set_xlim(0, 8) - ax.set_ylim(0, 5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Pipe (potok)", fontsize=FS_LABEL, fontweight="bold") - - draw_box(ax, 0.2, 2.0, 1.8, 1.2, "Proces A\n(ls)\nstdout", fill=GRAY1, fontsize=FS) - draw_box( - ax, - 3.0, - 2.0, - 1.8, - 1.2, - "Bufor\njądra\n(4 KB)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - draw_box(ax, 5.8, 2.0, 1.8, 1.2, "Proces B\n(grep)\nstdin", fill=GRAY1, fontsize=FS) - draw_arrow(ax, 2.0, 2.6, 3.0, 2.6, lw=1.5) - ax.text(2.5, 3.0, "write()\nfd[1]", fontsize=FS_SMALL, ha="center") - draw_arrow(ax, 4.8, 2.6, 5.8, 2.6, lw=1.5) - ax.text(5.3, 3.0, "read()\nfd[0]", fontsize=FS_SMALL, ha="center") - ax.text( - 4.0, - 0.8, - "Jednokierunkowy\nBufor pełny → write() blokuje", - fontsize=FS_SMALL, - ha="center", - style="italic", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - # Panel 2: Shared Memory - ax = axes[1] - ax.set_xlim(0, 8) - ax.set_ylim(0, 5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Shared Memory", fontsize=FS_LABEL, fontweight="bold") - - draw_box(ax, 0.3, 3.0, 2.2, 1.2, "Proces A\nstrona 7", fill=GRAY1, fontsize=FS) - draw_box(ax, 5.5, 3.0, 2.2, 1.2, "Proces B\nstrona 3", fill=GRAY1, fontsize=FS) - draw_box( - ax, - 2.8, - 1.0, - 2.4, - 1.2, - "RAM\nramka 42", - fill=GRAY3, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, 2.0, 3.0, 3.5, 2.2, lw=1.5) - draw_arrow(ax, 6.0, 3.0, 4.5, 2.2, lw=1.5) - ax.text( - 4.0, - 0.3, - "Zero kopiowania!\nA pisze → B widzi od razu\nWymaga synchronizacji (semafor)", - fontsize=FS_SMALL, - ha="center", - style="italic", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - # Panel 3: Socket - ax = axes[2] - ax.set_xlim(0, 8) - ax.set_ylim(0, 5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Socket", fontsize=FS_LABEL, fontweight="bold") - - # Network socket - draw_box( - ax, 0.3, 3.2, 1.8, 0.9, "Klient", fill=GRAY1, fontsize=FS, fontweight="bold" - ) - draw_box( - ax, 5.5, 3.2, 1.8, 0.9, "Serwer", fill=GRAY1, fontsize=FS, fontweight="bold" - ) - draw_double_arrow(ax, 2.1, 3.65, 5.5, 3.65, lw=1.5) - ax.text(3.8, 4.3, "TCP/IP (sieciowy)", fontsize=FS, ha="center", fontweight="bold") - - # Unix socket - draw_box( - ax, 0.3, 1.3, 1.8, 0.9, "Proces A", fill=GRAY4, fontsize=FS, fontweight="bold" - ) - draw_box( - ax, 5.5, 1.3, 1.8, 0.9, "Proces B", fill=GRAY4, fontsize=FS, fontweight="bold" - ) - draw_double_arrow(ax, 2.1, 1.75, 5.5, 1.75, lw=1.5) - ax.text( - 3.8, - 2.4, - "Unix domain socket\n(/tmp/app.sock)", - fontsize=FS, - ha="center", - fontweight="bold", - ) - - ax.text( - 3.8, - 0.5, - "Dwukierunkowy\nNajbardziej uniwersalny IPC", - fontsize=FS_SMALL, - ha="center", - style="italic", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - fig.tight_layout(rect=[0, 0, 1, 0.9]) - save_fig(fig, "q9_ipc_details.png") - - -# ============================================================ -# 9. IPC comparison table -# ============================================================ -def gen_ipc_table() -> None: - """Gen ipc table.""" - fig, ax = plt.subplots(figsize=(8.5, 3.5)) - ax.set_xlim(0, 11) - ax.set_ylim(-4.5, 1) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Porównanie mechanizmów IPC", fontsize=FS_TITLE, fontweight="bold", pad=10 - ) - - headers = ["Mechanizm", "Kierunek", "Szybkość", "Zastosowanie"] - col_w = [2.5, 2.0, 2.5, 3.5] - rows = [ - ["Pipe", "jednokierunkowy", "średnia", "ls | grep"], - ["Named Pipe", "jednokierunkowy", "średnia", "demon → klient"], - ["Shared Memory", "dwukierunkowy", "NAJSZYBSZA", "video, bazy danych"], - ["Message Queue", "dwukierunkowy", "średnia", "wieloproducentowe"], - ["Socket", "dwukierunkowy", "wolna (sieć)", "klient-serwer"], - ["Signal", "jednokierunkowy", "natychmiastowa", "powiadomienia (nr)"], - ] - draw_table( - ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7.5 - ) - - save_fig(fig, "q9_ipc_table.png") - - -# ============================================================ -# 10. Race condition (simple x + bank timeline) -# ============================================================ -def gen_race_condition() -> None: - """Gen race condition.""" - fig, axes = plt.subplots(1, 2, figsize=(11, 5)) - fig.suptitle( - "Wyścig (Race Condition) — przykłady", fontsize=FS_TITLE, fontweight="bold" - ) - - # Panel 1: simple x increment - ax = axes[0] - ax.set_xlim(0, 8) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Prosty wyścig: x = x + 1", fontsize=FS_LABEL, fontweight="bold") - - # Timeline - steps_a = ["czytaj x (=0)", "dodaj 1", "zapisz x (=1)"] - steps_b = ["czytaj x (=0)", "dodaj 1", "zapisz x (=1)"] - ax.text(2.0, 6.3, "Wątek A", fontsize=FS_LABEL, ha="center", fontweight="bold") - ax.text(6.0, 6.3, "Wątek B", fontsize=FS_LABEL, ha="center", fontweight="bold") - ax.plot([2, 2], [0.8, 6.0], color=LN, lw=1) - ax.plot([6, 6], [0.8, 6.0], color=LN, lw=1) - - for i, (sa, sb) in enumerate(zip(steps_a, steps_b, strict=False)): - y = 5.3 - i * 1.2 - draw_box(ax, 0.5, y, 3.0, 0.6, sa, fill=GRAY4, fontsize=FS) - draw_box(ax, 4.5, y - 0.3, 3.0, 0.6, sb, fill=GRAY1, fontsize=FS) - - ax.text( - 4.0, - 0.4, - "Wynik: x = 1 (powinno 2!)", - fontsize=FS, - ha="center", - fontweight="bold", - color="#C62828", - bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"}, - ) - - # Panel 2: bank account - ax = axes[1] - ax.set_xlim(0, 10) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Konto bankowe: saldo = 1000 zł", fontsize=FS_LABEL, fontweight="bold") - - ax.text(2.5, 6.3, "Wątek A (+500)", fontsize=FS, ha="center", fontweight="bold") - ax.text(7.5, 6.3, "Wątek B (-200)", fontsize=FS, ha="center", fontweight="bold") - ax.plot([2.5, 2.5], [0.8, 6.0], color=LN, lw=1) - ax.plot([7.5, 7.5], [0.8, 6.0], color=LN, lw=1) - - events = [ - ("t1", "czytaj → 1000", "", 5.3), - ("t2", "", "czytaj → 1000", 4.6), - ("t3", "1000+500=1500", "", 3.9), - ("t4", "", "1000-200=800", 3.2), - ("t5", "zapisz 1500", "", 2.5), - ("t6", "", "zapisz 800 ✗", 1.8), - ] - for t, a, b, y in events: - ax.text(0.3, y + 0.15, t, fontsize=FS_SMALL, fontweight="bold", va="center") - if a: - draw_box(ax, 1.0, y, 3.0, 0.45, a, fill=GRAY4, fontsize=FS_SMALL) - if b: - fill = "#F8D7DA" if "✗" in b else GRAY1 - draw_box(ax, 6.0, y, 3.0, 0.45, b, fill=fill, fontsize=FS_SMALL) - - ax.text( - 5.0, - 0.4, - "Wynik: 800 zł (powinno 1300!)", - fontsize=FS, - ha="center", - fontweight="bold", - color="#C62828", - bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"}, - ) - - fig.tight_layout(rect=[0, 0, 1, 0.9]) - save_fig(fig, "q9_race_condition.png") - - -# ============================================================ -# 11. Deadlock scenario + cycle -# ============================================================ -def gen_deadlock_scenario() -> None: - """Gen deadlock scenario.""" - fig, axes = plt.subplots(1, 2, figsize=(11, 4.5)) - fig.suptitle("Zakleszczenie (Deadlock)", fontsize=FS_TITLE, fontweight="bold") - - # Panel 1: timeline - ax = axes[0] - ax.set_xlim(0, 8) - ax.set_ylim(0, 6) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Scenariusz z 2 mutexami", fontsize=FS_LABEL, fontweight="bold") - - ax.text(2.5, 5.3, "Wątek A", fontsize=FS_LABEL, ha="center", fontweight="bold") - ax.text(6.0, 5.3, "Wątek B", fontsize=FS_LABEL, ha="center", fontweight="bold") - - steps = [ - ("lock(mutex1) OK", "", "trzyma", False, 4.5), - ("", "lock(mutex2) OK", "trzyma", False, 3.7), - ("lock(mutex2) ...WAIT", "", "CZEKA!", True, 2.9), - ("", "lock(mutex1) ...WAIT", "CZEKA!", True, 2.1), - ] - for a_text, b_text, _note, is_wait, y in steps: - if a_text: - fill = "#F8D7DA" if is_wait else GRAY4 - draw_box(ax, 0.5, y, 3.3, 0.55, a_text, fill=fill, fontsize=FS_SMALL) - if b_text: - fill = "#F8D7DA" if is_wait else GRAY4 - draw_box(ax, 4.3, y, 3.3, 0.55, b_text, fill=fill, fontsize=FS_SMALL) - - ax.text( - 4.0, - 1.2, - "DEADLOCK!\nŻaden nie odpuści", - fontsize=FS, - ha="center", - fontweight="bold", - color="#C62828", - bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"}, - ) - - # Panel 2: cycle diagram - ax = axes[1] - ax.set_xlim(0, 8) - ax.set_ylim(0, 6) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Cykl oczekiwania", fontsize=FS_LABEL, fontweight="bold") - - # Thread boxes - draw_box( - ax, - 0.5, - 3.5, - 2.2, - 1.2, - "Wątek A\ntrzyma Mutex 1", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - ) - draw_box( - ax, - 5.3, - 3.5, - 2.2, - 1.2, - "Wątek B\ntrzyma Mutex 2", - fill=GRAY1, - fontsize=FS, - fontweight="bold", - ) - - # Mutex boxes - draw_box( - ax, 0.5, 1.0, 2.2, 1.0, "Mutex 1", fill=GRAY3, fontsize=FS, fontweight="bold" - ) - draw_box( - ax, 5.3, 1.0, 2.2, 1.0, "Mutex 2", fill=GRAY3, fontsize=FS, fontweight="bold" - ) - - # holds arrows (down) - draw_arrow(ax, 1.6, 3.5, 1.6, 2.0, lw=2) - ax.text(0.9, 2.7, "trzyma", fontsize=FS_SMALL, rotation=90, va="center") - draw_arrow(ax, 6.4, 3.5, 6.4, 2.0, lw=2) - ax.text(7.0, 2.7, "trzyma", fontsize=FS_SMALL, rotation=90, va="center") - - # waits-for arrows (across, red) - draw_arrow(ax, 2.7, 4.3, 5.3, 4.3, lw=2.5, color="#C62828") - ax.text( - 4.0, - 4.7, - "czeka na Mutex 2", - fontsize=FS_SMALL, - ha="center", - fontweight="bold", - color="#C62828", - ) - draw_arrow(ax, 5.3, 3.7, 2.7, 3.7, lw=2.5, color="#C62828") - ax.text( - 4.0, - 3.1, - "czeka na Mutex 1", - fontsize=FS_SMALL, - ha="center", - fontweight="bold", - color="#C62828", - ) - - fig.tight_layout(rect=[0, 0, 1, 0.9]) - save_fig(fig, "q9_deadlock_scenario.png") - - -# ============================================================ -# 12. Coffman conditions + prevention strategies -# ============================================================ -def gen_coffman_strategies() -> None: - """Gen coffman strategies.""" - fig, ax = plt.subplots(figsize=(9, 4)) - ax.set_xlim(0, 11.5) - ax.set_ylim(-3.5, 1) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Warunki Coffmana — zapobieganie deadlockowi", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - headers = ["Warunek", "Opis", "Jak złamać", "Przykład"] - col_w = [2.5, 2.5, 3.0, 3.0] - rows = [ - [ - "1. Mutual Exclusion", - "zasób wyłączny", - "współdzielony zasób", - "Read-write lock", - ], - [ - "2. Hold and Wait", - "trzymaj + czekaj", - "bierz WSZYSTKIE naraz", - "lock(m1,m2) atomowo", - ], - [ - "3. No Preemption", - "nie zabierzesz siłą", - "timeout / trylock", - "pthread_mutex_trylock()", - ], - [ - "4. Circular Wait", - "cykliczne oczekiw.", - "porządek liniowy", - "zawsze m1 przed m2", - ], - ] - draw_table( - ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.6, fontsize=7 - ) - - ax.text( - 5.75, - -3.1, - "▸ Najczęstsza strategia: PORZĄDEK LINIOWY — " - "numeruj mutexy, zawsze blokuj rosnąco", - fontsize=FS, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - save_fig(fig, "q9_coffman_strategies.png") - - -# ============================================================ -# 13. Starvation + Priority Inversion (2-panel) -# ============================================================ -def gen_starvation_priority() -> None: - """Gen starvation priority.""" - fig, axes = plt.subplots(1, 2, figsize=(11, 4.5)) - fig.suptitle( - "Zagłodzenie i Inwersja priorytetów", fontsize=FS_TITLE, fontweight="bold" - ) - - # Panel 1: Starvation + aging - ax = axes[0] - ax.set_xlim(0, 8) - ax.set_ylim(0, 6) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Zagłodzenie (Starvation)", fontsize=FS_LABEL, fontweight="bold") - - threads = [ - ("Wątek HIGH", "prio=10", GRAY5, 3.0), - ("Wątek HIGH", "prio=9", GRAY3, 2.2), - ("Wątek MED", "prio=5", GRAY2, 1.4), - ("Wątek LOW", "prio=1 → głoduje!", "#F8D7DA", 0.6), - ] - for name, prio, color, y in threads: - draw_box( - ax, 0.5, y, 2.0, 0.6, name, fill=color, fontsize=FS_SMALL, fontweight="bold" - ) - ax.text(2.8, y + 0.3, prio, fontsize=FS_SMALL, va="center") - - ax.text( - 1.5, - 4.2, - "CPU zawsze\ndostaje HIGH!", - fontsize=FS, - ha="center", - fontweight="bold", - ) - draw_arrow(ax, 1.5, 3.9, 1.5, 3.65, lw=1.5) - - # Aging solution - draw_box(ax, 4.5, 1.5, 3.2, 2.5, "", fill=GRAY4, rounded=True) - ax.text(6.1, 3.7, "Rozwiązanie: AGING", fontsize=FS, fontweight="bold", ha="center") - aging = [ - "t=0: prio=1", - "t=100ms: prio=2", - "t=200ms: prio=3", - "...", - "w końcu → CPU!", - ] - for i, line in enumerate(aging): - ax.text( - 6.1, 3.2 - i * 0.4, line, fontsize=FS_SMALL, ha="center", family="monospace" - ) - - # Panel 2: Priority Inversion - ax = axes[1] - ax.set_xlim(0, 10) - ax.set_ylim(0, 6) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title("Inwersja priorytetów", fontsize=FS_LABEL, fontweight="bold") - - # Timeline - labels = ["H (wysoki)", "M (średni)", "L (niski)"] - ys = [4.2, 2.8, 1.4] - for label, y in zip(labels, ys, strict=False): - ax.text(0.3, y + 0.2, label, fontsize=FS, fontweight="bold", va="center") - - # L runs and locks mutex - draw_box(ax, 2.0, ys[2], 1.2, 0.5, "lock(m)", fill=GRAY1, fontsize=FS_SMALL) - - # M preempts L - draw_box(ax, 3.5, ys[1], 3.0, 0.5, "M pracuje...", fill=GRAY3, fontsize=FS_SMALL) - - # H waits for mutex - draw_box( - ax, - 3.5, - ys[0], - 3.0, - 0.5, - "CZEKA na mutex!", - fill="#F8D7DA", - fontsize=FS_SMALL, - fontweight="bold", - ) - - # M finishes, L continues, unlocks - draw_box(ax, 6.8, ys[2], 1.5, 0.5, "unlock(m)", fill=GRAY1, fontsize=FS_SMALL) - draw_box(ax, 8.5, ys[0], 1.2, 0.5, "H runs", fill=GRAY4, fontsize=FS_SMALL) - - # Explanation - ax.text( - 5.0, - 0.5, - "H czeka na M (mimo H > M)!\n" - "Rozwiązanie: Priority Inheritance\n" - "L dziedziczy priorytet H → M nie wypycha L", - fontsize=FS_SMALL, - ha="center", - style="italic", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - ax.text( - 5.0, - 0.0, - "Mars Pathfinder (1997) — klasyczny bug!", - fontsize=FS_SMALL, - ha="center", - fontweight="bold", - ) - - fig.tight_layout(rect=[0, 0, 1, 0.9]) - save_fig(fig, "q9_starvation_priority.png") - - -# ============================================================ -# 14. Bounded buffer + readers-writers + philosophers -# ============================================================ -def _draw_bounded_buffer_panel(ax: Axes) -> None: - """Draw the bounded-buffer (producer-consumer) panel.""" - ax.set_xlim(0, 8) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Producent-Konsument\n(Bounded Buffer, N=4)", fontsize=FS, fontweight="bold" - ) - - draw_box( - ax, - 0.2, - 4.0, - 2.0, - 1.2, - "Producent\nP(empty)\nP(mutex)\nwstaw()\nV(mutex)\nV(full)", - fill=GRAY1, - fontsize=5.5, - ) - items = ["A", "B", "", ""] - for i, item in enumerate(items): - x = 2.8 + i * 0.9 - fill = GRAY3 if item else "white" - draw_box( - ax, - x, - 4.3, - 0.9, - 0.7, - item, - fill=fill, - fontsize=FS, - fontweight="bold", - rounded=False, - ) - ax.text(4.6, 5.2, "Bufor (N=4)", fontsize=FS_SMALL, ha="center", fontweight="bold") - draw_box( - ax, - 6.0, - 4.0, - 2.0, - 1.2, - "Konsument\nP(full)\nP(mutex)\npobierz()\nV(mutex)\nV(empty)", - fill=GRAY4, - fontsize=5.5, - ) - draw_arrow(ax, 2.2, 4.6, 2.8, 4.65, lw=1.2) - draw_arrow(ax, 6.4, 4.65, 6.0, 4.6, lw=1.2) - - sems = [("mutex = 1", GRAY2), ("empty = N", GRAY1), ("full = 0", GRAY3)] - for i, (s, c) in enumerate(sems): - draw_box( - ax, - 2.0, - 2.5 - i * 0.6, - 4.0, - 0.45, - s, - fill=c, - fontsize=FS_SMALL, - fontweight="bold", - ) - - ax.text( - 4.0, - 0.5, - "KOLEJNOŚĆ: P(empty/full)\nPRZED P(mutex)!\nOdwrotnie = DEADLOCK", - fontsize=5.5, - ha="center", - fontweight="bold", - color="#C62828", - bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"}, - ) - - -def _draw_readers_writers_panel(ax: Axes) -> None: - """Draw the readers-writers panel.""" - ax.set_xlim(0, 8) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Czytelnicy-Pisarze\n(Readers-Writers)", fontsize=FS, fontweight="bold" - ) - - draw_box( - ax, - 2.5, - 3.5, - 3.0, - 1.5, - "Dane\n(współdzielone)", - fill=GRAY2, - fontsize=FS, - fontweight="bold", - ) - - for i in range(3): - x = 0.3 + i * 1.0 - draw_box( - ax, - x, - 5.5, - 0.8, - 0.7, - f"R{i + 1}", - fill=GRAY4, - fontsize=FS, - fontweight="bold", - ) - draw_arrow(ax, x + 0.4, 5.5, 3.0 + i * 0.5, 5.0, lw=1) - - ax.text( - 1.5, - 6.5, - "Czytelnicy (wielu naraz)", - fontsize=FS_SMALL, - ha="center", - fontweight="bold", - ) - - draw_box( - ax, 5.5, 5.5, 1.5, 0.7, "Pisarz", fill=GRAY5, fontsize=FS, fontweight="bold" - ) - draw_arrow(ax, 6.25, 5.5, 5.0, 5.0, lw=1.5) - ax.text( - 6.25, - 6.5, - "WYŁĄCZNY", - fontsize=FS_SMALL, - ha="center", - fontweight="bold", - color="#C62828", - ) - - rules = [ - "Wielu czytelników = OK", - "Jeden pisarz = wyłączny", - "Czytelnik + Pisarz = ✗", - "Problem: pisarze głodują", - ] - for i, r in enumerate(rules): - ax.text(4.0, 2.5 - i * 0.45, r, fontsize=FS_SMALL, ha="center") - - ax.text( - 4.0, - 0.5, - "Rozwiązanie:\nrw_mutex + count_mutex\n+ zmienna readers", - fontsize=5.5, - ha="center", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - -def _draw_philosophers_panel(ax: Axes) -> None: - """Draw the dining-philosophers panel.""" - ax.set_xlim(0, 8) - ax.set_ylim(0, 7) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Ucztujący filozofowie\n(Dining Philosophers)", fontsize=FS, fontweight="bold" - ) - - cx, cy, r = 4.0, 3.8, 1.8 - table = plt.Circle((cx, cy), 0.8, fill=True, facecolor=GRAY2, edgecolor=LN, lw=1.5) - ax.add_patch(table) - ax.text(cx, cy, "Stół", fontsize=FS, ha="center", fontweight="bold") - - for i in range(5): - angle = np.pi / 2 + i * 2 * np.pi / 5 - px = cx + r * np.cos(angle) - py = cy + r * np.sin(angle) - circle = plt.Circle( - (px, py), 0.35, fill=True, facecolor=GRAY1, edgecolor=LN, lw=1.2 - ) - ax.add_patch(circle) - ax.text( - px, py, f"F{i}", ha="center", va="center", fontsize=FS, fontweight="bold" - ) - - fork_angle = np.pi / 2 + (i + 0.5) * 2 * np.pi / 5 - fx = cx + (r * 0.6) * np.cos(fork_angle) - fy = cy + (r * 0.6) * np.sin(fork_angle) - ax.plot( - [fx - 0.1, fx + 0.1], - [fy - 0.15, fy + 0.15], - color=LN, - lw=2.5, - solid_capstyle="round", - ) - ax.text(fx + 0.2, fy + 0.15, f"w{i}", fontsize=5, color="#555555") - - rules = [ - "Jedzenie = 2 widelce", - "Naiwne → DEADLOCK", - "Fix: F4 bierze odwrotnie", - "Alt: semafor(4)", - ] - for i, r in enumerate(rules): - ax.text(4.0, 1.2 - i * 0.35, r, fontsize=FS_SMALL, ha="center") - - -# ============================================================ -# 14. Bounded buffer + readers-writers + philosophers -# ============================================================ -def gen_classic_problems() -> None: - """Gen classic problems.""" - fig, axes = plt.subplots(1, 3, figsize=(12, 5)) - fig.suptitle( - "Klasyczne problemy synchronizacji", fontsize=FS_TITLE, fontweight="bold" - ) - - _draw_bounded_buffer_panel(axes[0]) - _draw_readers_writers_panel(axes[1]) - _draw_philosophers_panel(axes[2]) - - fig.tight_layout(rect=[0, 0, 1, 0.88]) - save_fig(fig, "q9_classic_problems.png") - - -# ============================================================ -# 15. Sync mechanisms comparison + mutex/sem/spinlock -# ============================================================ -def gen_sync_comparison() -> None: - """Gen sync comparison.""" - fig, axes = plt.subplots(2, 1, figsize=(9, 7)) - - # Top: comparison table - ax = axes[0] - ax.set_xlim(0, 11.5) - ax.set_ylim(-5, 1) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Mechanizmy synchronizacji — porównanie", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - headers = ["Mechanizm", "Opis", "Kiedy używać"] - col_w = [2.5, 4.5, 4.0] - rows = [ - ["Mutex", "Zamek: 1 wątek w sekcji", "Sekcja krytyczna"], - ["Semafor(n)", "Licznik: max n wątków", "Ograniczone zasoby (n miejsc)"], - ["Monitor", "Obiekt z wbudowanym mutex", "Java synchronized"], - ["Cond. Variable", "wait()/signal() na warunek", "Producent-konsument"], - ["Spinlock", "Aktywne czekanie (busy-wait)", "Bardzo krótkie sekcje (<1 μs)"], - ["RW Lock", "Wielu czytelników LUB 1 pisarz", "Bazy danych, cache"], - ["Barrier", "Czekaj aż wszyscy dotrą", "Obliczenia równoległe"], - ] - draw_table( - ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7 - ) - - # Bottom: mutex vs semafor vs spinlock - ax = axes[1] - ax.set_xlim(0, 12) - ax.set_ylim(0, 5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Mutex vs Semafor vs Spinlock", fontsize=FS_TITLE, fontweight="bold", pad=5 - ) - - # Mutex - draw_box(ax, 0.3, 2.5, 3.5, 2.0, "", fill=GRAY4) - ax.text(2.05, 4.2, "MUTEX", fontsize=FS_LABEL, fontweight="bold", ha="center") - ax.text(2.05, 3.6, "= klucz do łazienki\n(1 osoba)", fontsize=FS, ha="center") - ax.text( - 2.05, - 2.8, - "Wątek ZASYPIA gdy czeka\nOS go obudzi (~μs)", - fontsize=FS_SMALL, - ha="center", - style="italic", - ) - - # Semafor - draw_box(ax, 4.3, 2.5, 3.5, 2.0, "", fill=GRAY1) - ax.text(6.05, 4.2, "SEMAFOR(n)", fontsize=FS_LABEL, fontweight="bold", ha="center") - ax.text( - 6.05, 3.6, "= parking na n miejsc\n(n wątków naraz)", fontsize=FS, ha="center" - ) - ax.text( - 6.05, - 2.8, - "Semafor(1) = mutex\nP() = zmniejsz, V() = zwiększ", - fontsize=FS_SMALL, - ha="center", - style="italic", - ) - - # Spinlock - draw_box(ax, 8.3, 2.5, 3.5, 2.0, "", fill=GRAY2) - ax.text(10.05, 4.2, "SPINLOCK", fontsize=FS_LABEL, fontweight="bold", ha="center") - ax.text( - 10.05, 3.6, "= obrotowe drzwi\n(kręcisz się w kółko)", fontsize=FS, ha="center" - ) - ax.text( - 10.05, - 2.8, - "Wątek KRĘCI się w pętli\nLepszy gdy sekcja < 1 μs", - fontsize=FS_SMALL, - ha="center", - style="italic", - ) - - # Dividing rule - ax.text( - 6.0, - 1.5, - "Reguła kciuka: sekcja > 1 μs → MUTEX | " - "sekcja < 1 μs → SPINLOCK | n jednocześnie → SEMAFOR(n)", - fontsize=FS, - ha="center", - fontweight="bold", - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - fig.tight_layout() - save_fig(fig, "q9_sync_comparison.png") - - -# ============================================================ -# 16. Semaphore concept diagram -# ============================================================ -def gen_semaphore_concept() -> None: - """Gen semaphore concept.""" - fig, ax = plt.subplots(figsize=(6, 3)) - ax.set_xlim(0, 10) - ax.set_ylim(0, 5) - ax.set_aspect("auto") - ax.axis("off") - ax.set_title( - "Semafor — koncepcja (parking na 3 miejsca)", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Parking slots - for i in range(3): - x = 2.0 + i * 2.0 - occupied = i < OCCUPIED_SLOTS - fill = GRAY3 if occupied else "white" - label = f"Wątek {i + 1}" if occupied else "(wolne)" - draw_box( - ax, - x, - 2.5, - 1.5, - 1.2, - label, - fill=fill, - fontsize=FS, - fontweight="bold" if occupied else "normal", - rounded=False, - ) - - ax.text( - 5.0, - 4.2, - "semafor(3): counter = 1 (jedno wolne miejsce)", - fontsize=FS, - ha="center", - fontweight="bold", - ) - - # Waiting thread - draw_box( - ax, - 0.2, - 0.5, - 1.5, - 0.8, - "Wątek 4\nP() → czeka", - fill="#F8D7DA", - fontsize=FS_SMALL, - ) - draw_arrow(ax, 1.7, 0.9, 2.0, 2.5, lw=1.2, color="#C62828") - - ax.text( - 5.0, - 0.6, - "P() = counter-- (jeśli 0 → czekaj)\nV() = counter++ (obudź czekającego)", - fontsize=FS, - ha="center", - family="monospace", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - save_fig(fig, "q9_semaphore_concept.png") - - # ============================================================ # MAIN — generate all # ============================================================ diff --git a/python_pkg/praca_magisterska_video/generate_images/generate_scheduling_diagrams.py b/python_pkg/praca_magisterska_video/generate_images/generate_scheduling_diagrams.py old mode 100755 new mode 100644 index ea7aeef..b68a5b1 --- a/python_pkg/praca_magisterska_video/generate_images/generate_scheduling_diagrams.py +++ b/python_pkg/praca_magisterska_video/generate_images/generate_scheduling_diagrams.py @@ -2,11 +2,12 @@ """Generate diagrams for PYTANIE 17: Szeregowanie zadań (Scheduling). Diagrams: - 1. Graham notation \u03b1|β|\u03b3 visual mnemonic map + 1. Graham notation α|β|γ visual mnemonic map 2. Johnson's algorithm Gantt chart (F2||Cmax example) 3. SPT vs LPT comparison Gantt (1||ΣCⱼ) 4. Flow shop vs Job shop visual comparison 5. Scheduling complexity landscape + 6. EDD example (1 || Lmax) All: A4-compatible, B&W, 300 DPI, laser-printer-friendly. """ @@ -14,1452 +15,46 @@ All: A4-compatible, B&W, 300 DPI, laser-printer-friendly. from __future__ import annotations import logging -from typing import TYPE_CHECKING import matplotlib as mpl mpl.use("Agg") -from pathlib import Path -import matplotlib.patches as mpatches -from matplotlib.patches import FancyBboxPatch -import matplotlib.pyplot as plt - -if TYPE_CHECKING: - from matplotlib.axes import Axes +# Re-export common utilities for backward compatibility +from python_pkg.praca_magisterska_video.generate_images._sched_common import ( # noqa: F401 + BG, + DPI, + FONTWEIGHT_THRESHOLD, + FS, + FS_TITLE, + GRAY1, + GRAY2, + GRAY3, + GRAY4, + GRAY5, + LN, + MIN_COLUMN_INDEX, + OUTPUT_DIR, + draw_arrow, + draw_box, +) +from python_pkg.praca_magisterska_video.generate_images._sched_complexity_edd import ( + draw_complexity_map, + draw_edd_example, +) +from python_pkg.praca_magisterska_video.generate_images._sched_graham import ( + draw_graham_notation, +) +from python_pkg.praca_magisterska_video.generate_images._sched_johnson import ( + draw_johnson_gantt, +) +from python_pkg.praca_magisterska_video.generate_images._sched_spt_flow_job import ( + draw_flow_vs_job, + draw_spt_comparison, +) _logger = logging.getLogger(__name__) -DPI = 300 -BG = "white" -LN = "black" -FS = 8 -FS_TITLE = 11 -OUTPUT_DIR = str(Path(__file__).resolve().parent / "img") -Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) - -GRAY1 = "#E8E8E8" -GRAY2 = "#D0D0D0" -GRAY3 = "#B8B8B8" -GRAY4 = "#F5F5F5" -GRAY5 = "#C0C0C0" - -MIN_COLUMN_INDEX = 3 -FONTWEIGHT_THRESHOLD = 3 - - -def draw_box( - ax: Axes, - x: float, - y: float, - w: float, - h: float, - text: str, - fill: str = "white", - lw: float = 1.2, - fontsize: float = FS, - fontweight: str = "normal", - ha: str = "center", - va: str = "center", - *, - rounded: bool = True, -) -> None: - """Draw box.""" - if rounded: - rect = FancyBboxPatch( - (x, y), w, h, boxstyle="round,pad=0.05", lw=lw, edgecolor=LN, facecolor=fill - ) - else: - rect = mpatches.Rectangle((x, y), w, h, lw=lw, edgecolor=LN, facecolor=fill) - ax.add_patch(rect) - ax.text( - x + w / 2, - y + h / 2, - text, - ha=ha, - va=va, - fontsize=fontsize, - fontweight=fontweight, - wrap=True, - ) - - -def draw_arrow( - ax: Axes, - x1: float, - y1: float, - x2: float, - y2: float, - lw: float = 1.2, - style: str = "->", - color: str = LN, -) -> None: - """Draw arrow.""" - ax.annotate( - "", - xy=(x2, y2), - xytext=(x1, y1), - arrowprops={"arrowstyle": style, "color": color, "lw": lw}, - ) - - -# ============================================================ -# 1. GRAHAM NOTATION alpha|β|gamma — MNEMONIC MAP -# ============================================================ -def draw_graham_notation() -> None: - """Draw graham notation.""" - _fig, ax = plt.subplots(1, 1, figsize=(8.27, 10)) - ax.set_xlim(0, 10) - ax.set_ylim(0, 14) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Notacja Grahama \u03b1 | β | \u03b3 — Mapa mnemoniczna", - fontsize=FS_TITLE + 1, - fontweight="bold", - pad=12, - ) - - _draw_graham_formula_bar(ax) - _draw_graham_alpha_beta(ax) - _draw_graham_lower(ax) - - plt.tight_layout() - plt.savefig( - str(Path(OUTPUT_DIR) / "scheduling_graham_notation.png"), - dpi=DPI, - bbox_inches="tight", - facecolor=BG, - ) - plt.close() - _logger.info(" ✓ scheduling_graham_notation.png") - - -def _draw_graham_formula_bar(ax: Axes) -> None: - """Draw the top alpha|beta|gamma formula bar.""" - bar_y = 12.5 - bar_h = 1.0 - # alpha box - rect = FancyBboxPatch( - (0.5, bar_y), - 2.5, - bar_h, - boxstyle="round,pad=0.08", - lw=2, - edgecolor=LN, - facecolor=GRAY1, - ) - ax.add_patch(rect) - ax.text( - 1.75, - bar_y + bar_h / 2, - "\u03b1", - fontsize=20, - fontweight="bold", - ha="center", - va="center", - ) - ax.text( - 1.75, - bar_y - 0.25, - "MASZYNY", - fontsize=8, - fontweight="bold", - ha="center", - va="top", - color="#444444", - ) - - # separator | - ax.text( - 3.3, - bar_y + bar_h / 2, - "|", - fontsize=24, - fontweight="bold", - ha="center", - va="center", - ) - - # β box - rect = FancyBboxPatch( - (3.7, bar_y), - 2.5, - bar_h, - boxstyle="round,pad=0.08", - lw=2, - edgecolor=LN, - facecolor=GRAY2, - ) - ax.add_patch(rect) - ax.text( - 4.95, - bar_y + bar_h / 2, - "β", - fontsize=20, - fontweight="bold", - ha="center", - va="center", - ) - ax.text( - 4.95, - bar_y - 0.25, - "OGRANICZENIA", - fontsize=8, - fontweight="bold", - ha="center", - va="top", - color="#444444", - ) - - # separator | - ax.text( - 6.5, - bar_y + bar_h / 2, - "|", - fontsize=24, - fontweight="bold", - ha="center", - va="center", - ) - - # gamma box - rect = FancyBboxPatch( - (6.9, bar_y), - 2.5, - bar_h, - boxstyle="round,pad=0.08", - lw=2, - edgecolor=LN, - facecolor=GRAY3, - ) - ax.add_patch(rect) - ax.text( - 8.15, - bar_y + bar_h / 2, - "\u03b3", - fontsize=20, - fontweight="bold", - ha="center", - va="center", - ) - ax.text( - 8.15, - bar_y - 0.25, - "CEL", - fontsize=8, - fontweight="bold", - ha="center", - va="top", - color="#444444", - ) - - -def _draw_graham_alpha_beta(ax: Axes) -> None: - """Draw alpha (machines) and beta (constraints) sections.""" - start_x = 0.3 - col_w = 1.28 - - # === SECTION alpha: MACHINES === - sec_y = 11.5 - ax.text( - 0.3, - sec_y, - '\u03b1 — „1 Prawdziwy Quasi-Rycerz Forsuje Jaskinię Orków"', - fontsize=8, - fontweight="bold", - va="top", - style="italic", - color="#333333", - ) - - alpha_items = [ - ("1", "jedna maszyna", "●", GRAY4), - ("Pm", "identyczne Parallel", "●●●", GRAY1), - ("Qm", "Quasi-uniform\n(różne prędkości)", "●●◐", GRAY4), - ("Rm", "Random unrelated\n(czasy per para)", "●◆▲", GRAY1), - ("Fm", "Flow shop\n(ta sama kolejność)", "→→→", GRAY2), - ("Jm", "Job shop\n(indyw. trasy)", "↗↙↘", GRAY4), - ("Om", "Open shop\n(dowolna kolej.)", "?→?", GRAY1), - ] - - col_w = 1.28 - box_h_a = 1.1 - start_x = 0.3 - start_y = 9.6 - - for i, (symbol, desc, icon, fill) in enumerate(alpha_items): - x = start_x + i * col_w - y = start_y - rect = FancyBboxPatch( - (x, y), - col_w - 0.1, - box_h_a, - boxstyle="round,pad=0.04", - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - x + (col_w - 0.1) / 2, - y + box_h_a - 0.15, - symbol, - ha="center", - va="top", - fontsize=9, - fontweight="bold", - ) - ax.text( - x + (col_w - 0.1) / 2, - y + box_h_a / 2 - 0.1, - desc, - ha="center", - va="center", - fontsize=5.5, - ) - ax.text( - x + (col_w - 0.1) / 2, y + 0.12, icon, ha="center", va="bottom", fontsize=7 - ) - - # Complexity arrow under alpha - arr_y = 9.35 - ax.annotate( - "", - xy=(9.0, arr_y), - xytext=(0.5, arr_y), - arrowprops={"arrowstyle": "->", "color": "#666666", "lw": 1.5}, - ) - ax.text( - 4.8, - arr_y - 0.18, - "rosnąca złożoność →", - ha="center", - fontsize=6, - color="#666666", - ) - - # === SECTION β: CONSTRAINTS === - sec_y2 = 8.9 - ax.text( - 0.3, - sec_y2, - "β — „Robak Daje Deadline: Przerwy Poprzedzają Pojedyncze Setup'y\"", - fontsize=8, - fontweight="bold", - va="top", - style="italic", - color="#333333", - ) - - beta_items = [ - ("rⱼ", "release\ndates", "Robak\ndostępne\nod czasu rⱼ", GRAY1), - ("dⱼ", "due\ndates", "Daje\ntermin soft\n(kara za spóźn.)", GRAY4), - ("d̄ⱼ", "dead-\nlines", "Deadline\ntermin hard\n(musi dotrzymać)", GRAY1), - ("pmtn", "preemp-\ntion", "Przerwy\nmożna\nprzerwać", GRAY2), - ("prec", "prece-\ndencje", "Poprzedzają\nA->B (DAG)", GRAY4), - ("pⱼ=1", "unit\ntime", "Pojedyncze\nwszystkie = 1", GRAY1), - ("sⱼₖ", "setup\ntimes", "Setup'y\nprzezbrojenie\nmiędzy j->k", GRAY4), - ] - - start_y2 = 7.0 - box_h_b = 1.4 - - for i, (symbol, _label, desc, fill) in enumerate(beta_items): - x = start_x + i * col_w - y = start_y2 - rect = FancyBboxPatch( - (x, y), - col_w - 0.1, - box_h_b, - boxstyle="round,pad=0.04", - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - x + (col_w - 0.1) / 2, - y + box_h_b - 0.12, - symbol, - ha="center", - va="top", - fontsize=9, - fontweight="bold", - ) - ax.text( - x + (col_w - 0.1) / 2, - y + box_h_b / 2 - 0.05, - desc, - ha="center", - va="center", - fontsize=5, - ) - - -def _draw_graham_lower(ax: Axes) -> None: - """Draw gamma criteria, examples, and footer sections.""" - start_x = 0.3 - - # === SECTION gamma: CRITERIA === - sec_y3 = 6.5 - ax.text( - 0.3, - sec_y3, - '\u03b3 — „Ciężki Sum Ważony Lata, Tardiness Uderza"', - fontsize=8, - fontweight="bold", - va="top", - style="italic", - color="#333333", - ) - - gamma_items = [ - ("Cmax", "makespan\nmax(Cⱼ)", "Jak długo\ntrwa WSZYSTKO?", GRAY2), - ("ΣCⱼ", "suma\nukończeń", "Średni czas\noczekiwania?", GRAY4), - ("ΣwⱼCⱼ", "ważona\nsuma", "Priorytety\nzadań?", GRAY1), - ("Lmax", "max\nopóźnienie", "Najgorsze\nspóźnienie?", GRAY2), - ("ΣTⱼ", "suma\nspóźnień", "Łączne\nspóźnienia?", GRAY4), - ("ΣUⱼ", "liczba\nspóźnionych", "Ile spóźnionych\nzadań?", GRAY1), - ] - - start_y3 = 4.5 - box_h_g = 1.4 - col_w_g = 1.5 - - for i, (symbol, label, question, fill) in enumerate(gamma_items): - x = start_x + i * col_w_g - y = start_y3 - rect = FancyBboxPatch( - (x, y), - col_w_g - 0.1, - box_h_g, - boxstyle="round,pad=0.04", - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - x + (col_w_g - 0.1) / 2, - y + box_h_g - 0.1, - symbol, - ha="center", - va="top", - fontsize=9, - fontweight="bold", - ) - ax.text( - x + (col_w_g - 0.1) / 2, - y + box_h_g / 2 - 0.05, - label, - ha="center", - va="center", - fontsize=6, - ) - ax.text( - x + (col_w_g - 0.1) / 2, - y + 0.15, - f'„{question}"', - ha="center", - va="bottom", - fontsize=5, - style="italic", - ) - - # === BOTTOM: Example + Optimal methods === - ex_y = 3.5 - ax.text( - 0.3, - ex_y, - "Przykłady zapisu i optymalne metody:", - fontsize=8, - fontweight="bold", - va="top", - ) - - examples = [ - ("1 || ΣCⱼ", "SPT (najkrótsze\nnajpierw)", "O(n log n)", GRAY1), - ("1 || Lmax", "EDD (najwcześniejszy\ntermin)", "O(n log n)", GRAY4), - ("F2 || Cmax", "Algorytm\nJohnsona", "O(n log n)", GRAY2), - ("Pm || Cmax", "LPT heurystyka\n(NP-trudny!)", "NP-hard", GRAY3), - ("Jm || Cmax", "Branch & Bound\n(NP-trudny!)", "NP-hard", GRAY5), - ] - - ex_start_y = 1.8 - ex_box_w = 1.72 - ex_box_h = 1.4 - - for i, (notation, method, complexity, fill) in enumerate(examples): - x = start_x + i * (ex_box_w + 0.1) - y = ex_start_y - rect = FancyBboxPatch( - (x, y), - ex_box_w, - ex_box_h, - boxstyle="round,pad=0.04", - lw=1.2, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - x + ex_box_w / 2, - y + ex_box_h - 0.12, - notation, - ha="center", - va="top", - fontsize=8, - fontweight="bold", - fontfamily="monospace", - ) - ax.text( - x + ex_box_w / 2, - y + ex_box_h / 2 - 0.05, - method, - ha="center", - va="center", - fontsize=6, - ) - ax.text( - x + ex_box_w / 2, - y + 0.12, - complexity, - ha="center", - va="bottom", - fontsize=6.5, - fontweight="bold", - color="#555555", - ) - - # Footer mnemonic summary - ax.text( - 5.0, - 0.8, - '„\u03b1|β|\u03b3 = Maszyny | Ograniczenia | Cel"', - ha="center", - fontsize=9, - fontweight="bold", - style="italic", - color="#333333", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": GRAY4, - "edgecolor": GRAY3, - "lw": 1, - }, - ) - - ax.text( - 5.0, - 0.2, - "\u03b1: ILE maszyn i JAKIE? " - "β: JAKIE ograniczenia zadań? " - "\u03b3: CO minimalizujemy?", - ha="center", - fontsize=7, - color="#555555", - ) - - -# ============================================================ -# 2. JOHNSON'S ALGORITHM GANTT CHART -# ============================================================ -def draw_johnson_gantt() -> None: - """Draw johnson gantt.""" - _fig, axes = plt.subplots( - 2, 1, figsize=(8.27, 7), gridspec_kw={"height_ratios": [1, 1.8]} - ) - - _draw_johnson_decision_table(axes[0]) - _draw_johnson_gantt_chart(axes[1]) - - plt.tight_layout() - plt.savefig( - str(Path(OUTPUT_DIR) / "scheduling_johnson_gantt.png"), - dpi=DPI, - bbox_inches="tight", - facecolor=BG, - ) - plt.close() - _logger.info(" ✓ scheduling_johnson_gantt.png") - - -def _draw_johnson_decision_table(ax: Axes) -> None: - """Draw the Johnson algorithm decision table.""" - ax.set_xlim(0, 10) - ax.set_ylim(0, 5) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Algorytm Johnsona (F2 || Cmax) — Decyzja + Diagram Gantta", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Task table - tasks = ["J1", "J2", "J3", "J4", "J5"] - a_times = [4, 2, 6, 1, 3] - b_times = [5, 3, 2, 7, 4] - min_vals = [min(a, b) for a, b in zip(a_times, b_times, strict=False)] - min_on = ["M1" if a <= b else "M2" for a, b in zip(a_times, b_times, strict=False)] - assign = ["POCZątek" if m == "M1" else "KONIEC" for m in min_on] - - # Draw table - col_w_t = 1.3 - row_h = 0.55 - headers = ["Zadanie", "aⱼ (M1)", "bⱼ (M2)", "min", "min na", "Przydziel"] - table_x = 0.8 - table_y = 3.8 - - for j, hdr in enumerate(headers): - x = table_x + j * col_w_t - rect = mpatches.Rectangle( - (x, table_y), col_w_t, row_h, lw=1, edgecolor=LN, facecolor=GRAY2 - ) - ax.add_patch(rect) - ax.text( - x + col_w_t / 2, - table_y + row_h / 2, - hdr, - ha="center", - va="center", - fontsize=6.5, - fontweight="bold", - ) - - for i in range(5): - row_data = [ - tasks[i], - str(a_times[i]), - str(b_times[i]), - str(min_vals[i]), - min_on[i], - assign[i], - ] - for j, val in enumerate(row_data): - x = table_x + j * col_w_t - y = table_y - (i + 1) * row_h - fill_c = GRAY1 if min_on[i] == "M1" else GRAY4 - if j == MIN_COLUMN_INDEX: # min column - highlight - fill_c = GRAY3 - rect = mpatches.Rectangle( - (x, y), col_w_t, row_h, lw=0.8, edgecolor=LN, facecolor=fill_c - ) - ax.add_patch(rect) - fw = "bold" if j >= FONTWEIGHT_THRESHOLD else "normal" - ax.text( - x + col_w_t / 2, - y + row_h / 2, - val, - ha="center", - va="center", - fontsize=6.5, - fontweight=fw, - ) - - # Sorting result - result_y = 0.7 - ax.text( - 5.0, - result_y + 0.4, - "Sortuj → POCZĄTEK ↑aⱼ: J4(1), J2(2), J5(3), J1(4) | KONIEC ↓bⱼ: J3(2)", - ha="center", - fontsize=7, - color="#333333", - ) - ax.text( - 5.0, - result_y, - "Optymalna kolejność: J4 → J2 → J5 → J1 → J3", - ha="center", - fontsize=9, - fontweight="bold", - bbox={ - "boxstyle": "round,pad=0.2", - "facecolor": GRAY1, - "edgecolor": LN, - "lw": 1.2, - }, - ) - - -def _draw_johnson_gantt_chart(ax2: Axes) -> None: - """Draw the Johnson algorithm Gantt chart.""" - ax2.set_xlim(-1, 24) - ax2.set_ylim(-1, 4) - ax2.axis("off") - - # Machines labels - m1_y = 2.5 - m2_y = 0.8 - bar_h = 0.9 - - ax2.text( - -0.8, - m1_y + bar_h / 2, - "M1", - ha="center", - va="center", - fontsize=11, - fontweight="bold", - ) - ax2.text( - -0.8, - m2_y + bar_h / 2, - "M2", - ha="center", - va="center", - fontsize=11, - fontweight="bold", - ) - - # Schedule: J4 → J2 → J5 → J1 → J3 - order = ["J4", "J2", "J5", "J1", "J3"] - a_ord = [1, 2, 3, 4, 6] # M1 times in order - b_ord = [7, 3, 4, 5, 2] # M2 times in order - fills = [GRAY1, GRAY2, GRAY4, GRAY3, GRAY5] - hatches = ["", "///", "", "\\\\\\", "xxx"] - - # M1 schedule - m1_starts = [] - t = 0 - for a in a_ord: - m1_starts.append(t) - t += a - m1_ends = [s + a for s, a in zip(m1_starts, a_ord, strict=False)] - - # M2 schedule (must wait for M1 finish AND previous M2 finish) - m2_starts = [] - m2_ends = [] - prev_m2_end = 0 - for i, b in enumerate(b_ord): - start = max(m1_ends[i], prev_m2_end) - m2_starts.append(start) - m2_ends.append(start + b) - prev_m2_end = start + b - - # Draw M1 bars - for i in range(5): - rect = mpatches.Rectangle( - (m1_starts[i], m1_y), - a_ord[i], - bar_h, - lw=1.2, - edgecolor=LN, - facecolor=fills[i], - hatch=hatches[i], - ) - ax2.add_patch(rect) - ax2.text( - m1_starts[i] + a_ord[i] / 2, - m1_y + bar_h / 2, - f"{order[i]}\n({a_ord[i]})", - ha="center", - va="center", - fontsize=7, - fontweight="bold", - ) - - # Draw M2 bars - for i in range(5): - rect = mpatches.Rectangle( - (m2_starts[i], m2_y), - b_ord[i], - bar_h, - lw=1.2, - edgecolor=LN, - facecolor=fills[i], - hatch=hatches[i], - ) - ax2.add_patch(rect) - ax2.text( - m2_starts[i] + b_ord[i] / 2, - m2_y + bar_h / 2, - f"{order[i]}\n({b_ord[i]})", - ha="center", - va="center", - fontsize=7, - fontweight="bold", - ) - - # Draw idle regions on M2 - idle_starts = [0] - idle_ends = [m2_starts[0]] - for i in range(1, 5): - if m2_starts[i] > m2_ends[i - 1]: - idle_starts.append(m2_ends[i - 1]) - idle_ends.append(m2_starts[i]) - - for s, e in zip(idle_starts, idle_ends, strict=False): - if e > s: - rect = mpatches.Rectangle( - (s, m2_y), - e - s, - bar_h, - lw=0.5, - edgecolor="#AAAAAA", - facecolor="white", - linestyle="--", - ) - ax2.add_patch(rect) - ax2.text( - s + (e - s) / 2, - m2_y + bar_h / 2, - "idle", - ha="center", - va="center", - fontsize=5, - color="#999999", - ) - - # Time axis - ax_y = m2_y - 0.15 - ax2.plot([0, 23], [ax_y, ax_y], color=LN, lw=0.8) - for t in range(0, 24, 2): - ax2.plot([t, t], [ax_y - 0.08, ax_y + 0.08], color=LN, lw=0.8) - ax2.text(t, ax_y - 0.25, str(t), ha="center", va="top", fontsize=6) - ax2.text(11.5, ax_y - 0.55, "czas", ha="center", fontsize=7) - - # Cmax annotation - ax2.annotate( - f"Cmax = {m2_ends[-1]}", - xy=(m2_ends[-1], m2_y + bar_h), - xytext=(m2_ends[-1] + 0.5, m2_y + bar_h + 0.6), - fontsize=10, - fontweight="bold", - color="#333333", - arrowprops={"arrowstyle": "->", "color": "#333333", "lw": 1.5}, - ) - - # Mnemonic at bottom - ax2.text( - 11, - -0.7, - '„Krótki na M1 → START (szybko karmi M2)' - ' Krótki na M2 → KONIEC' - ' (szybko kończy)"', - ha="center", - fontsize=7.5, - fontweight="bold", - style="italic", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": GRAY4, - "edgecolor": GRAY3, - "lw": 0.8, - }, - ) - - -# ============================================================ -# 3. SPT vs LPT COMPARISON (1 || ΣCⱼ) -# ============================================================ -def draw_spt_comparison() -> None: - """Draw spt comparison.""" - fig, axes = plt.subplots(2, 1, figsize=(8.27, 5.5)) - - tasks_orig = [("J1", 5), ("J2", 3), ("J3", 8), ("J4", 2), ("J5", 6)] - - spt_order = sorted(tasks_orig, key=lambda x: x[1]) - lpt_order = sorted(tasks_orig, key=lambda x: -x[1]) - - fills_map = {"J1": GRAY1, "J2": GRAY2, "J3": GRAY3, "J4": GRAY4, "J5": GRAY5} - hatch_map = {"J1": "", "J2": "///", "J3": "xxx", "J4": "", "J5": "\\\\\\"} - - for _idx, (ax, order_list, title, is_optimal) in enumerate( - [ - (axes[0], spt_order, "SPT (Shortest Processing Time) — OPTYMALNE", True), - (axes[1], lpt_order, "LPT (Longest Processing Time) — gorsze!", False), - ] - ): - ax.set_xlim(-2, 26) - ax.set_ylim(-0.5, 2.5) - ax.axis("off") - color = "#222222" if is_optimal else "#666666" - marker = "✓" if is_optimal else "✗" - ax.set_title( - f"{marker} {title}", - fontsize=9, - fontweight="bold", - loc="left", - color=color, - pad=5, - ) - - bar_y = 1.0 - bar_h = 0.8 - t = 0 - completions = [] - - for name, duration in order_list: - rect = mpatches.Rectangle( - (t, bar_y), - duration, - bar_h, - lw=1.2, - edgecolor=LN, - facecolor=fills_map[name], - hatch=hatch_map[name], - ) - ax.add_patch(rect) - ax.text( - t + duration / 2, - bar_y + bar_h / 2, - f"{name}\n({duration})", - ha="center", - va="center", - fontsize=7, - fontweight="bold", - ) - t += duration - completions.append(t) - - # Completion time marker - ax.plot([t, t], [bar_y - 0.15, bar_y], color=LN, lw=0.8) - ax.text( - t, - bar_y - 0.25, - f"C={t}", - ha="center", - va="top", - fontsize=6, - color="#555555", - ) - - total = sum(completions) - # Time axis - ax.plot([0, 25], [bar_y - 0.05, bar_y - 0.05], color=LN, lw=0.5) - - # Sum annotation - comp_str = " + ".join(str(c) for c in completions) - ax.text( - 25, - bar_y + bar_h / 2, - f"ΣCⱼ = {comp_str}\n = {total}", - ha="left", - va="center", - fontsize=7, - fontweight="bold" if is_optimal else "normal", - color=color, - bbox={ - "boxstyle": "round,pad=0.2", - "facecolor": GRAY1 if is_optimal else "white", - "edgecolor": color, - "lw": 1, - }, - ) - - # Bottom annotation - fig.text( - 0.5, - 0.02, - '„Short People To the front"' - ' — krótkie najpierw,' - ' jak niskie osoby w zdjęciu klasowym', - ha="center", - fontsize=8, - fontweight="bold", - style="italic", - color="#444444", - ) - - plt.tight_layout(rect=[0, 0.05, 1, 1]) - plt.savefig( - str(Path(OUTPUT_DIR) / "scheduling_spt_comparison.png"), - dpi=DPI, - bbox_inches="tight", - facecolor=BG, - ) - plt.close() - _logger.info(" ✓ scheduling_spt_comparison.png") - - -# ============================================================ -# 4. FLOW SHOP vs JOB SHOP -# ============================================================ -def draw_flow_vs_job() -> None: - """Draw flow vs job.""" - _fig, axes = plt.subplots(1, 2, figsize=(8.27, 4.5)) - - _draw_flow_shop(axes[0]) - _draw_job_shop(axes[1]) - - plt.tight_layout() - plt.savefig( - str(Path(OUTPUT_DIR) / "scheduling_flow_vs_job.png"), - dpi=DPI, - bbox_inches="tight", - facecolor=BG, - ) - plt.close() - _logger.info(" ✓ scheduling_flow_vs_job.png") - - -def _draw_flow_shop(ax: Axes) -> None: - """Draw the Flow Shop diagram.""" - ax.set_xlim(0, 6) - ax.set_ylim(0, 6) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Flow Shop (Fm)", fontsize=10, fontweight="bold", pad=8) - - # Machines in a row - machines_x = [1, 3, 5] - machines_y = 3 - mach_r = 0.4 - - for i, mx in enumerate(machines_x): - circle = plt.Circle( - (mx, machines_y), mach_r, facecolor=GRAY2, edgecolor=LN, lw=1.5 - ) - ax.add_patch(circle) - ax.text( - mx, - machines_y, - f"M{i + 1}", - ha="center", - va="center", - fontsize=9, - fontweight="bold", - ) - - # Arrows between machines - for i in range(len(machines_x) - 1): - draw_arrow( - ax, - machines_x[i] + mach_r + 0.05, - machines_y, - machines_x[i + 1] - mach_r - 0.05, - machines_y, - lw=2, - ) - - # Jobs all flowing the same way - jobs_flow = ["J1", "J2", "J3"] - for _j, (job, y_off) in enumerate(zip(jobs_flow, [0.8, 0, -0.8], strict=False)): - ax.text( - 0.2, - machines_y + y_off, - job, - ha="center", - va="center", - fontsize=7, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.15", "facecolor": GRAY1, "edgecolor": LN}, - ) - # Dashed flow line - ax.annotate( - "", - xy=(5.5, machines_y + y_off * 0.3), - xytext=(0.5, machines_y + y_off), - arrowprops={ - "arrowstyle": "->", - "color": "#888888", - "lw": 0.8, - "linestyle": "dashed", - }, - ) - - ax.text( - 3, - 1.2, - "Wszystkie zadania:\nM1 → M2 → M3", - ha="center", - va="center", - fontsize=8, - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - ax.text( - 3, - 0.4, - "Jak taśma montażowa", - ha="center", - fontsize=7, - style="italic", - color="#666666", - ) - - -def _draw_job_shop(ax: Axes) -> None: - """Draw the Job Shop diagram.""" - mach_r = 0.4 - ax.set_xlim(0, 6) - ax.set_ylim(0, 6) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title("Job Shop (Jm)", fontsize=10, fontweight="bold", pad=8) - - # Machines scattered - m_positions = [(1.5, 4.2), (4.5, 4.2), (3, 2.5)] - - for i, (mx, my) in enumerate(m_positions): - circle = plt.Circle((mx, my), mach_r, facecolor=GRAY2, edgecolor=LN, lw=1.5) - ax.add_patch(circle) - ax.text( - mx, my, f"M{i + 1}", ha="center", va="center", fontsize=9, fontweight="bold" - ) - - # J1: M1 → M2 → M3 (solid) - route1 = [(1.5, 4.2), (4.5, 4.2), (3, 2.5)] - for i in range(len(route1) - 1): - x1, y1 = route1[i] - x2, y2 = route1[i + 1] - dx = x2 - x1 - dy = y2 - y1 - d = (dx**2 + dy**2) ** 0.5 - draw_arrow( - ax, - x1 + mach_r * dx / d + 0.05, - y1 + mach_r * dy / d, - x2 - mach_r * dx / d - 0.05, - y2 - mach_r * dy / d, - lw=1.5, - ) - ax.text( - 0.3, - 4.8, - "J1: M1→M2→M3", - fontsize=7, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY1, "edgecolor": LN}, - ) - - # J2: M2 → M3 → M1 (dashed) - route2 = [(4.5, 4.2), (3, 2.5), (1.5, 4.2)] - for i in range(len(route2) - 1): - x1, y1 = route2[i] - x2, y2 = route2[i + 1] - dx = x2 - x1 - dy = y2 - y1 - d = (dx**2 + dy**2) ** 0.5 - off = 0.15 # offset to avoid overlap - ax.annotate( - "", - xy=(x2 - mach_r * dx / d - 0.05, y2 - mach_r * dy / d + off), - xytext=(x1 + mach_r * dx / d + 0.05, y1 + mach_r * dy / d + off), - arrowprops={ - "arrowstyle": "->", - "color": "#555555", - "lw": 1.5, - "linestyle": "dashed", - }, - ) - ax.text( - 3.8, - 5.2, - "J2: M2→M3→M1", - fontsize=7, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY4, "edgecolor": LN}, - ) - - ax.text( - 3, - 1.2, - "Każde zadanie:\nwłasna trasa!", - ha="center", - va="center", - fontsize=8, - bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3}, - ) - - ax.text( - 3, - 0.4, - "NP-trudny już dla 3 maszyn", - ha="center", - fontsize=7, - style="italic", - color="#666666", - ) - - -# ============================================================ -# 5. SCHEDULING COMPLEXITY LANDSCAPE -# ============================================================ -def draw_complexity_map() -> None: - """Draw complexity map.""" - _fig, ax = plt.subplots(1, 1, figsize=(8.27, 5)) - ax.set_xlim(0, 10) - ax.set_ylim(0, 7) - ax.set_aspect("equal") - ax.axis("off") - ax.set_title( - "Złożoność problemów szeregowania — od łatwych do NP-trudnych", - fontsize=FS_TITLE, - fontweight="bold", - pad=10, - ) - - # Gradient arrow at the top - ax.annotate( - "", - xy=(9.5, 6.2), - xytext=(0.5, 6.2), - arrowprops={"arrowstyle": "->", "color": LN, "lw": 2}, - ) - ax.text(5, 6.5, "Rosnąca złożoność", ha="center", fontsize=9, fontweight="bold") - - # Easy (polynomial) region - easy_rect = FancyBboxPatch( - (0.3, 2.8), - 4.0, - 3.0, - boxstyle="round,pad=0.15", - lw=1.5, - edgecolor="#666666", - facecolor=GRAY4, - linestyle="-", - ) - ax.add_patch(easy_rect) - ax.text( - 2.3, - 5.5, - "WIELOMIANOWE O(n log n)", - ha="center", - fontsize=9, - fontweight="bold", - color="#444444", - ) - - easy_problems = [ - ("1 || ΣCⱼ", "SPT", GRAY1, 4.8), - ("1 || Lmax", "EDD", GRAY2, 4.0), - ("F2 || Cmax", "Johnson", GRAY1, 3.2), - ] - for prob, method, fill, y in easy_problems: - rect = FancyBboxPatch( - (0.6, y), - 3.5, - 0.6, - boxstyle="round,pad=0.05", - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - 1.2, - y + 0.3, - prob, - ha="center", - va="center", - fontsize=8, - fontweight="bold", - fontfamily="monospace", - ) - ax.text(3.0, y + 0.3, f"→ {method}", ha="center", va="center", fontsize=8) - - # Hard (NP) region - hard_rect = FancyBboxPatch( - (5.3, 2.8), - 4.3, - 3.0, - boxstyle="round,pad=0.15", - lw=1.5, - edgecolor="#444444", - facecolor=GRAY3, - linestyle="-", - ) - ax.add_patch(hard_rect) - ax.text( - 7.45, - 5.5, - "NP-TRUDNE", - ha="center", - fontsize=9, - fontweight="bold", - color="#333333", - ) - - hard_problems = [ - ("Pm || Cmax\n(m≥2)", "LPT heuryst.", GRAY2, 4.5), - ("1 || ΣTⱼ", "branch&bound", GRAY4, 3.7), - ("Jm || Cmax\n(m≥3)", "metaheuryst.", GRAY5, 2.9), - ] - for prob, method, fill, y in hard_problems: - rect = FancyBboxPatch( - (5.6, y), - 3.7, - 0.7, - boxstyle="round,pad=0.05", - lw=1, - edgecolor=LN, - facecolor=fill, - ) - ax.add_patch(rect) - ax.text( - 6.5, - y + 0.35, - prob, - ha="center", - va="center", - fontsize=7, - fontweight="bold", - fontfamily="monospace", - ) - ax.text(8.2, y + 0.35, f"→ {method}", ha="center", va="center", fontsize=7) - - # Arrow connecting - draw_arrow(ax, 4.4, 4.0, 5.2, 4.0, lw=2, color="#888888") - ax.text(4.8, 4.25, "+1\nmaszyna", ha="center", fontsize=6, color="#888888") - - # Bottom: key insight - ax.text( - 5.0, - 1.8, - "„Dodanie jednej maszyny lub jednego ograniczenia\n" - 'może zmienić problem z łatwego na NP-trudny!"', - ha="center", - fontsize=8, - fontweight="bold", - style="italic", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": GRAY4, - "edgecolor": GRAY3, - "lw": 1, - }, - ) - - # Bottom examples - ax.text( - 5.0, - 0.8, - "1 maszyna → łatwe (sortuj) | ≥2 maszyny równoległe → NP-trudne\n" - "Flow shop 2 maszyny → Johnson O(n log n) | Flow shop 3 maszyny → NP-trudne", - ha="center", - fontsize=7, - color="#555555", - ) - - plt.tight_layout() - plt.savefig( - str(Path(OUTPUT_DIR) / "scheduling_complexity_map.png"), - dpi=DPI, - bbox_inches="tight", - facecolor=BG, - ) - plt.close() - _logger.info(" ✓ scheduling_complexity_map.png") - - -# ============================================================ -# 6. EDD EXAMPLE (1 || Lmax) -# ============================================================ -def draw_edd_example() -> None: - """Draw edd example.""" - _fig, ax = plt.subplots(1, 1, figsize=(8.27, 4)) - ax.set_xlim(-2, 28) - ax.set_ylim(-2, 4) - ax.axis("off") - ax.set_title( - "EDD (Earliest Due Date) — 1 || Lmax — Przykład", - fontsize=FS_TITLE, - fontweight="bold", - pad=8, - ) - - # Tasks: name, processing time, due date - tasks = [("J1", 4, 10), ("J2", 2, 6), ("J3", 6, 15), ("J4", 3, 8), ("J5", 5, 18)] - # EDD: sort by due date - edd_order = sorted(tasks, key=lambda x: x[2]) - - bar_y = 1.5 - bar_h = 0.8 - t = 0 - fills_edd = [GRAY1, GRAY2, GRAY4, GRAY3, GRAY5] - - lateness_vals = [] - for i, (name, p, d) in enumerate(edd_order): - rect = mpatches.Rectangle( - (t, bar_y), p, bar_h, lw=1.2, edgecolor=LN, facecolor=fills_edd[i] - ) - ax.add_patch(rect) - ax.text( - t + p / 2, - bar_y + bar_h / 2, - f"{name}\np={p}, d={d}", - ha="center", - va="center", - fontsize=6.5, - fontweight="bold", - ) - t += p - lateness = t - d - lateness_vals.append(lateness) - - # Due date marker - ax.plot( - [d, d], [bar_y - 0.4, bar_y - 0.1], color="#888888", lw=0.8, linestyle="--" - ) - ax.text( - d, - bar_y - 0.5, - f"d={d}", - ha="center", - va="top", - fontsize=5.5, - color="#888888", - ) - - # Completion + lateness - ax.plot([t, t], [bar_y + bar_h, bar_y + bar_h + 0.15], color=LN, lw=0.8) - ax.text( - t, - bar_y + bar_h + 0.2, - f"C={t}\nL={lateness}", - ha="center", - va="bottom", - fontsize=5.5, - ) - - # Time axis - ax.plot([0, 22], [bar_y - 0.05, bar_y - 0.05], color=LN, lw=0.5) - - lmax = max(lateness_vals) - ax.text( - 22, - bar_y + bar_h / 2, - f"Lmax = {lmax}", - ha="left", - va="center", - fontsize=10, - fontweight="bold", - bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY1, "edgecolor": LN}, - ) - - # Bottom mnemonic - ax.text( - 10, - -1.3, - '„Early Due Date Does it first" — najpilniejszy deadline idzie pierwszy', - ha="center", - fontsize=8, - fontweight="bold", - style="italic", - bbox={ - "boxstyle": "round,pad=0.3", - "facecolor": GRAY4, - "edgecolor": GRAY3, - "lw": 0.8, - }, - ) - - plt.tight_layout() - plt.savefig( - str(Path(OUTPUT_DIR) / "scheduling_edd_example.png"), - dpi=DPI, - bbox_inches="tight", - facecolor=BG, - ) - plt.close() - _logger.info(" ✓ scheduling_edd_example.png") - - # ============================================================ # MAIN # ============================================================ diff --git a/python_pkg/praca_magisterska_video/visualize_q23.py b/python_pkg/praca_magisterska_video/visualize_q23.py index 91894e8..20cd7e7 100644 --- a/python_pkg/praca_magisterska_video/visualize_q23.py +++ b/python_pkg/praca_magisterska_video/visualize_q23.py @@ -1,1548 +1,33 @@ """MoviePy visualization for PYTANIE 23: Image Segmentation. -Creates animated video demonstrating: -- What segmentation is (pixel-level classification) -- Thresholding / Otsu (bimodal histogram) -- Region Growing (BFS flood fill) -- Watershed (topographic flooding) -- U-Net encoder-decoder architecture +Thin orchestrator that assembles sections from submodules into +the final video. """ from __future__ import annotations -import logging -import os -from pathlib import Path +from moviepy import VideoClip, concatenate_videoclips -import numpy as np - -os.environ["FFMPEG_BINARY"] = "/usr/bin/ffmpeg" - -from moviepy import ( - ColorClip, - CompositeVideoClip, - TextClip, - VideoClip, - concatenate_videoclips, +from python_pkg.praca_magisterska_video._q23_classical import ( + _region_growing_demo, + _segmentation_concept, + _thresholding_demo, + _watershed_demo, ) -from moviepy.video.fx import FadeIn, FadeOut +from python_pkg.praca_magisterska_video._q23_deeplab import _deeplab_demo +from python_pkg.praca_magisterska_video._q23_helpers import ( + FPS, + OUTPUT, + _logger, + _make_header, +) +from python_pkg.praca_magisterska_video._q23_transformer import ( + _methods_comparison, + _transformer_seg_demo, +) +from python_pkg.praca_magisterska_video._q23_unet_fcn import _fcn_demo, _unet_demo -# ── Constants ───────────────────────────────────────────────────── -W, H = 1280, 720 -FPS = 24 -STEP_DUR = 7.0 -HEADER_DUR = 4.0 -FONT_B = "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf" -FONT_R = "/usr/share/fonts/TTF/DejaVuSans.ttf" -OUTPUT_DIR = Path(__file__).resolve().parent / "videos" -OUTPUT_DIR.mkdir(parents=True, exist_ok=True) -OUTPUT = str(OUTPUT_DIR / "q23_segmentation.mp4") -logging.basicConfig(level=logging.INFO) -_logger = logging.getLogger(__name__) - -BG_COLOR = (15, 20, 35) -rng = np.random.default_rng(42) - - -def _tc(**kwargs: object) -> TextClip: - """TextClip wrapper that adds enough bottom margin to prevent clipping.""" - fs = kwargs.get("font_size", 24) - m = int(fs) // 3 + 2 - kwargs["margin"] = (0, m) - return TextClip(**kwargs) - - -def _make_header( - title: str, subtitle: str, duration: float = HEADER_DUR -) -> CompositeVideoClip: - bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration) - t = ( - _tc( - text=title, - font_size=48, - color="white", - font=FONT_B, - ) - .with_duration(duration) - .with_position(("center", 260)) - ) - s = ( - _tc( - text=subtitle, - font_size=24, - color="#90CAF9", - font=FONT_R, - ) - .with_duration(duration) - .with_position(("center", 340)) - ) - return CompositeVideoClip([bg, t, s], size=(W, H)).with_effects( - [FadeIn(0.5), FadeOut(0.5)] - ) - - -def _text_slide( - lines: list[tuple[str, int, str, str, tuple[str | int, str | int]]], - duration: float = STEP_DUR, -) -> CompositeVideoClip: - """Create a slide with multiple text elements.""" - bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration) - clips: list[VideoClip] = [bg] - for text, font_size, color, font, pos in lines: - tc = ( - _tc( - text=text, - font_size=font_size, - color=color, - font=font, - ) - .with_duration(duration) - .with_position(pos) - ) - clips.append(tc) - return CompositeVideoClip(clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - - -def _compose_slide( - base_clip: VideoClip, - labels: list[tuple[str, int, str, str, tuple[int, int]]], - duration: float, -) -> CompositeVideoClip: - """Overlay text labels on an animated base clip.""" - text_clips: list[VideoClip] = [base_clip] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(duration) - .with_position(pos) - ) - text_clips.append(tc) - return CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - - -# ── Segmentation concept ───────────────────────────────────────── -def _segmentation_concept() -> list[CompositeVideoClip]: - """Show what segmentation is: pixel-level labeling.""" - slides = [] - - # Synthetic image: grid of colored pixels - def make_image_frame(_t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - # Draw a small "image" grid - grid_x, grid_y = 100, 150 - cell = 40 - # Sky (top rows) - colors_map = [ - [(135, 206, 235)] * 8, # sky - [(135, 206, 235)] * 5 + [(34, 139, 34)] * 3, # sky + tree - [(34, 139, 34)] * 3 - + [(128, 128, 128)] * 3 - + [(34, 139, 34)] * 2, # tree+road+tree - [(128, 128, 128)] * 3 - + [(200, 50, 50)] * 2 - + [(128, 128, 128)] * 3, # road+car+road - ] - labels_map = [ - ["niebo"] * 8, - ["niebo"] * 5 + ["drzewo"] * 3, - ["drzewo"] * 3 + ["droga"] * 3 + ["drzewo"] * 2, - ["droga"] * 3 + ["samochód"] * 2 + ["droga"] * 3, - ] - label_colors = { - "niebo": (100, 180, 255), - "drzewo": (50, 200, 50), - "droga": (180, 180, 180), - "samochód": (255, 80, 80), - } - - for r, row in enumerate(colors_map): - for c, col in enumerate(row): - y = grid_y + r * cell - x = grid_x + c * cell - frame[y : y + cell - 2, x : x + cell - 2] = col - - # Draw segmentation map on the right - seg_x = 600 - for r, row in enumerate(labels_map): - for c, lab in enumerate(row): - y = grid_y + r * cell - x = seg_x + c * cell - frame[y : y + cell - 2, x : x + cell - 2] = label_colors[lab] - - return frame - - image_clip = VideoClip(make_image_frame, duration=STEP_DUR).with_fps(FPS) - labels_text = [ - ("Obraz wejściowy", 22, "white", FONT_B, (170, 100)), - ("Mapa segmentacji", 22, "white", FONT_B, (660, 100)), - ("→", 50, "#FFE082", FONT_B, (450, 250)), - ("Każdy piksel → etykieta klasy", 20, "#B0BEC5", FONT_R, (100, 420)), - ("niebo | drzewo | droga | samochód", 18, "#90CAF9", FONT_R, (600, 420)), - ("Segmentacja = klasyfikacja per-piksel", 24, "#FFE082", FONT_B, (100, 500)), - ( - "Semantic: klasy bez instancji | Instance: " - "rozróżnia obiekty | Panoptic: oba", - 16, - "#78909C", - FONT_R, - (100, 560), - ), - ] - clips: list[VideoClip] = [image_clip] - for text, fs, color, font, pos in labels_text: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR) - .with_position(pos) - ) - clips.append(tc) - - slides.append( - CompositeVideoClip(clips, size=(W, H)).with_effects([FadeIn(0.3), FadeOut(0.3)]) - ) - return slides - - -# ── Thresholding / Otsu ─────────────────────────────────────────── -def _thresholding_demo() -> list[CompositeVideoClip]: - """Animate thresholding and Otsu concept.""" - slides = [] - - # Show histogram & threshold - def make_threshold_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - # Draw bimodal histogram bars - bar_start_x = 80 - bar_y = 500 - bar_w = 4 - - for i in range(256): - # Bimodal: peaks at 60 and 190 - h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2)) - h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2)) - bar_h = int(h1 + h2) - x = bar_start_x + i * bar_w - if x + bar_w < W: - frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (150, 150, 170) - - # Animated threshold line - threshold = int(60 + (190 - 60) * min(t / (STEP_DUR * 0.7), 1.0)) - tx = bar_start_x + threshold * bar_w - if tx < W: - frame[bar_y - 250 : bar_y + 10, tx : tx + 3] = (255, 80, 80) - - # Color the two sides - for i in range(threshold): - x = bar_start_x + i * bar_w - h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2)) - h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2)) - bar_h = int(h1 + h2) - if x + bar_w < W and bar_h > 0: - frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (70, 130, 200) - - for i in range(threshold, 256): - x = bar_start_x + i * bar_w - h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2)) - h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2)) - bar_h = int(h1 + h2) - if x + bar_w < W and bar_h > 0: - frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (200, 100, 80) - - return frame - - hist_clip = VideoClip(make_threshold_frame, duration=STEP_DUR).with_fps(FPS) - text_clips: list[VideoClip] = [hist_clip] - labels = [ - ("Progowanie (Thresholding) z metodą Otsu", 28, "#FFE082", FONT_B, (80, 30)), - ( - "Histogram jasności pikseli — dwumodalny (bimodal)", - 20, - "#B0BEC5", - FONT_R, - (80, 80), - ), - ("Garb 1: piksele obiektu (ciemne ~60)", 16, "#64B5F6", FONT_R, (80, 120)), - ("Garb 2: piksele tła (jasne ~190)", 16, "#EF9A9A", FONT_R, (80, 150)), - ( - "Próg T (czerwona linia) dzieli piksele na 2 klasy", - 18, - "white", - FONT_R, - (80, 540), - ), - ( - "Otsu: automatycznie testuje T=0..255, minimalizuje σ² wewnątrzklasową", - 16, - "#A5D6A7", - FONT_R, - (80, 580), - ), - ( - "Piksel ≤ T → klasa 0 (tło) | Piksel > T → klasa 1 (obiekt)", - 16, - "#78909C", - FONT_R, - (80, 620), - ), - ] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR) - .with_position(pos) - ) - text_clips.append(tc) - - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── Region Growing ──────────────────────────────────────────────── -def _region_growing_demo() -> list[CompositeVideoClip]: - """Animate region growing BFS from a seed pixel.""" - slides = [] - - grid_size = 10 - cell_size = 40 - rng = np.random.default_rng(42) - # Create a simple grid: dark region (30-80) and bright region (160-220) - grid = np.zeros((grid_size, grid_size), dtype=np.uint8) - grid[:] = 60 # dark background - grid[2:7, 3:8] = 180 # bright rectangle - - # Add some noise - noise = rng.integers(-15, 15, (grid_size, grid_size)) - grid = np.clip(grid.astype(int) + noise, 0, 255).astype(np.uint8) - - # BFS steps from seed (4, 5) - seed = (4, 5) - threshold_val = 50 - visited_order: list[tuple[int, int]] = [] - queue = [seed] - visited_set = {seed} - while queue: - r, c = queue.pop(0) - visited_order.append((r, c)) - for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]: - nr, nc = r + dr, c + dc - if ( - 0 <= nr < grid_size - and 0 <= nc < grid_size - and (nr, nc) not in visited_set - ) and abs(int(grid[nr, nc]) - int(grid[seed])) < threshold_val: - visited_set.add((nr, nc)) - queue.append((nr, nc)) - - def make_region_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - ox, oy = 100, 180 - - # How many cells to show as visited - progress = min(t / (STEP_DUR * 0.8), 1.0) - n_visited = int(progress * len(visited_order)) - - for r in range(grid_size): - for c in range(grid_size): - x = ox + c * cell_size - y = oy + r * cell_size - val = grid[r, c] - color = (val, val, val) - - # Highlight visited - if (r, c) in visited_order[:n_visited]: - color = (80, 200, 120) # green for region - elif (r, c) == seed: - color = (255, 200, 50) # yellow seed - - frame[y : y + cell_size - 2, x : x + cell_size - 2] = color - - # Show value - # (drawn as a simple marker since we can't render text in numpy easily) - - # Mark the seed with a bright border - sx = ox + seed[1] * cell_size - sy = ox + seed[0] * cell_size + 80 - frame[sy : sy + cell_size, sx : sx + 2] = (255, 200, 50) - frame[sy : sy + cell_size, sx + cell_size - 2 : sx + cell_size] = (255, 200, 50) - frame[sy : sy + 2, sx : sx + cell_size] = (255, 200, 50) - frame[sy + cell_size - 2 : sy + cell_size, sx : sx + cell_size] = (255, 200, 50) - - return frame - - region_clip = VideoClip(make_region_frame, duration=STEP_DUR).with_fps(FPS) - text_clips: list[VideoClip] = [region_clip] - labels = [ - ("Region Growing — rozrastanie regionu", 28, "#FFE082", FONT_B, (100, 30)), - ("Seed (ziarno) → BFS do podobnych sąsiadów", 20, "#B0BEC5", FONT_R, (100, 80)), - ( - "Żółty = seed | Zielony = region | Szary = nieodwiedzone", - 16, - "#78909C", - FONT_R, - (100, 120), - ), - ( - "Sąsiad PODOBNY (|jasność - jasność_regionu| < próg) → dodaj do regionu", - 16, - "#A5D6A7", - FONT_R, - (100, 600), - ), - ( - "Algorytm zatrzymuje się gdy brak podobnych sąsiadów", - 16, - "#90CAF9", - FONT_R, - (100, 640), - ), - ( - "Mnemonik: PLAMA atramentu — rozlewa się na podobne piksele", - 18, - "#EF9A9A", - FONT_R, - (100, 670), - ), - ] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR) - .with_position(pos) - ) - text_clips.append(tc) - - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── Watershed ───────────────────────────────────────────────────── -def _watershed_demo() -> list[CompositeVideoClip]: - """Animate watershed flooding concept.""" - slides = [] - - def make_watershed_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - # Draw terrain profile (1D cross-section) - ox, oy = 100, 450 - terrain_w = 900 - terrain_points = 100 - - xs = np.linspace(0, 1, terrain_points) - # Two valleys with a ridge - terrain = ( - 120 * np.exp(-((xs - 0.25) ** 2) / 0.005) - + 80 * np.exp(-((xs - 0.75) ** 2) / 0.008) - + 30 - ) - terrain = 250 - terrain # invert for visual (valleys at bottom) - - # Water level rises over time - water_level = int(160 + 80 * min(t / (STEP_DUR * 0.7), 1.0)) - - for i in range(terrain_points - 1): - x1 = ox + int(xs[i] * terrain_w) - x2 = ox + int(xs[i + 1] * terrain_w) - y1 = oy - int(terrain[i]) - y2 = oy - int(terrain[i + 1]) - - # Fill terrain - for x in range(x1, min(x2 + 1, W)): - top = min(y1, y2) - 5 - frame[top:oy, x : x + 1] = (100, 80, 60) - - # Fill water - water_y = oy - water_level - for x in range(x1, min(x2 + 1, W)): - t_y = oy - int(terrain[i]) - if water_y < t_y: - # Water fills below terrain surface - fill_top = max(water_y, 0) - fill_bot = min(t_y, oy) - if fill_top < fill_bot: - frame[fill_top:fill_bot, x : x + 1] = (70, 130, 220) - - # Dam marker at ridge - ridge_x = ox + int(0.5 * terrain_w) - dam_visible_threshold = 160 - if water_level > dam_visible_threshold: - frame[oy - water_level : oy - 140, ridge_x - 2 : ridge_x + 2] = ( - 255, - 80, - 80, - ) - - return frame - - ws_clip = VideoClip(make_watershed_frame, duration=STEP_DUR).with_fps(FPS) - text_clips: list[VideoClip] = [ws_clip] - labels = [ - ("Watershed — metoda zlewiska", 28, "#FFE082", FONT_B, (100, 20)), - ( - "Obraz = mapa topograficzna (jasność = wysokość)", - 20, - "#B0BEC5", - FONT_R, - (100, 65), - ), - ( - "Brązowy = teren (ciemne=doliny, jasne=szczyty)", - 16, - "#8D6E63", - FONT_R, - (100, 100), - ), - ("Niebieski = woda zalewająca od minimów", 16, "#64B5F6", FONT_R, (100, 130)), - ( - "Czerwony = TAMA (granica segmentu) — gdy woda z 2 dolin się spotka", - 16, - "#EF9A9A", - FONT_R, - (100, 160), - ), - ( - "Problem: over-segmentation " - "(za dużo regionów). " - "Rozwiązanie: marker-controlled.", - 16, - "#A5D6A7", - FONT_R, - (100, 560), - ), - ( - "Mnemonik: ZALEWANIE terenu — granie gór = granice segmentów", - 18, - "#FFE082", - FONT_R, - (100, 600), - ), - ] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR) - .with_position(pos) - ) - text_clips.append(tc) - - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── U-Net Architecture ─────────────────────────────────────────── -def _draw_unet_skips( - frame: np.ndarray, - enc_positions: list[tuple[int, int, int, int]], - n_blocks: int, - dec_x: int, - skip_threshold: int, -) -> None: - """Draw horizontal dashed skip-connection lines.""" - if n_blocks <= skip_threshold: - return - for i in range(min(n_blocks - 5, 4)): - ey = enc_positions[i][1] + enc_positions[i][3] // 2 - ex_end = enc_positions[i][0] + enc_positions[i][2] - for dash_x in range(ex_end + 10, dec_x - 10, 15): - frame[ey : ey + 2, dash_x : dash_x + 8] = (255, 200, 50) - - -def _make_unet_frame(t: float) -> np.ndarray: - """Render a single U-Net animation frame.""" - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - enc_sizes = [(80, 120), (60, 100), (45, 80), (30, 60)] - dec_sizes = list(reversed(enc_sizes)) - enc_x = 150 - dec_x = 850 - - progress = min(t / (STEP_DUR * 0.6), 1.0) - n_blocks = int(progress * 8) + 1 - - enc_positions: list[tuple[int, int, int, int]] = [] - y_offset = 120 - for i, (bw, bh) in enumerate(enc_sizes): - x = enc_x - y = y_offset + i * 130 - enc_positions.append((x, y, bw, bh)) - if i < n_blocks: - frame[y : y + bh, x : x + bw] = (70, 130, 200) - frame[y : y + 2, x : x + bw] = (100, 180, 255) - frame[y + bh - 2 : y + bh, x : x + bw] = (100, 180, 255) - frame[y : y + bh, x : x + 2] = (100, 180, 255) - frame[y : y + bh, x + bw - 2 : x + bw] = (100, 180, 255) - if i < len(enc_sizes) - 1: - ax = x + bw // 2 - ay = y + bh + 10 - frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170) - - bx, by = 500, y_offset + 3 * 130 + 30 - encoder_count = 4 - if n_blocks > encoder_count: - frame[by : by + 50, bx : bx + 25] = (200, 100, 80) - frame[by : by + 2, bx : bx + 25] = (255, 140, 100) - frame[by + 48 : by + 50, bx : bx + 25] = (255, 140, 100) - - for i, (bw, bh) in enumerate(dec_sizes): - x = dec_x - y = y_offset + (3 - i) * 130 - if n_blocks > 4 + i + 1: - frame[y : y + bh, x : x + bw] = (80, 200, 120) - frame[y : y + 2, x : x + bw] = (120, 230, 150) - frame[y + bh - 2 : y + bh, x : x + bw] = (120, 230, 150) - frame[y : y + bh, x : x + 2] = (120, 230, 150) - frame[y : y + bh, x + bw - 2 : x + bw] = (120, 230, 150) - if i < len(dec_sizes) - 1: - ax = x + bw // 2 - ay = y - 30 - frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170) - - skip_threshold = 5 - _draw_unet_skips(frame, enc_positions, n_blocks, dec_x, skip_threshold) - - return frame - - -def _unet_demo() -> list[CompositeVideoClip]: - """Animate U-Net encoder-decoder architecture.""" - dur = STEP_DUR + 1 - unet_clip = VideoClip(_make_unet_frame, duration=dur).with_fps(FPS) - labels = [ - ("U-Net: Encoder-Decoder + Skip Connections", 28, "#FFE082", FONT_B, (80, 20)), - ( - "Niebieski = Encoder (↓ zmniejsza rozdzielczość, wyciąga cechy)", - 16, - "#64B5F6", - FONT_R, - (80, 65), - ), - ( - "Zielony = Decoder (↑ zwiększa rozdzielczość, odtwarza mapę)", - 16, - "#A5D6A7", - FONT_R, - (80, 90), - ), - ( - "Żółte przerywane = Skip connections (przenoszą detale z encodera)", - 16, - "#FFE082", - FONT_R, - (80, 115), - ), - ( - "Czerwony = Bottleneck (najgłębsza warstwa, max abstrakcja)", - 16, - "#EF9A9A", - FONT_R, - (450, 570), - ), - ( - "Kształt U: encoder ↓ decoder ↑, mosty pośrodku", - 18, - "white", - FONT_R, - (80, 640), - ), - ( - "Concatenation: skip łączy kanały (więcej informacji niż dodawanie)", - 16, - "#78909C", - FONT_R, - (80, 670), - ), - ] - return [_compose_slide(unet_clip, labels, dur)] - - -# ── FCN Architecture ───────────────────────────────────────────── -def _draw_pipeline_blocks( - frame: np.ndarray, - blocks: list[ - tuple[tuple[int, int], tuple[int, int], tuple[int, int, int]] - ], - n_visible: int, - arrow_limit: int, -) -> None: - """Draw coloured blocks with connecting arrows.""" - for i, ((bx, by), (bw, bh), color) in enumerate(blocks): - if i < n_visible: - frame[by : by + bh, bx : bx + bw] = color - frame[by : by + 2, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - if i < arrow_limit: - ax = bx + bw + 3 - ay = by + bh // 2 - frame[ay - 1 : ay + 2, ax : ax + 12] = (150, 150, 170) - - -def _draw_red_cross( - frame: np.ndarray, - x_start: int, - width: int, - top_y: int, - height: int, -) -> None: - """Draw a red X across the given rectangle.""" - for d in range(-2, 3): - for step in range(height): - x1 = x_start + int(step * width / height) - y1 = top_y + step + d - if 0 <= y1 < H and 0 <= x1 < W: - frame[y1, x1] = (255, 80, 80) - y2 = top_y + height - step + d - if 0 <= y2 < H and 0 <= x1 < W: - frame[y2, x1] = (255, 80, 80) - - -def _make_fcn_frame(t: float) -> np.ndarray: - """Render a single FCN comparison frame.""" - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - progress = min(t / (STEP_DUR * 0.8), 1.0) - - top_y = 140 - blocks_classic = [ - ((80, top_y), (70, 50), (70, 130, 200)), - ((170, top_y), (50, 40), (50, 100, 160)), - ((240, top_y), (60, 50), (70, 130, 200)), - ((320, top_y), (40, 35), (50, 100, 160)), - ((385, top_y), (55, 50), (160, 80, 60)), - ((465, top_y), (55, 50), (180, 60, 60)), - ((545, top_y), (80, 50), (200, 80, 80)), - ] - n_top = min(int(progress * 7) + 1, 7) - arrow_limit = 6 - _draw_pipeline_blocks(frame, blocks_classic, n_top, arrow_limit) - - cross_phase = 0.6 - if progress > cross_phase: - _draw_red_cross(frame, 385, 135, top_y, 50) - - bot_y = 380 - blocks_fcn = [ - ((80, bot_y), (70, 50), (70, 130, 200)), - ((170, bot_y), (50, 40), (50, 100, 160)), - ((240, bot_y), (60, 50), (70, 130, 200)), - ((320, bot_y), (40, 35), (50, 100, 160)), - ((385, bot_y), (70, 50), (80, 200, 120)), - ((480, bot_y), (75, 50), (200, 160, 80)), - ((580, bot_y), (80, 50), (100, 200, 100)), - ] - fcn_phase = 0.4 - if progress > fcn_phase: - n_bot = min(int((progress - fcn_phase) / 0.6 * 7) + 1, 7) - _draw_pipeline_blocks(frame, blocks_fcn, n_bot, arrow_limit) - - return frame - - -def _fcn_demo() -> list[CompositeVideoClip]: - """Animate FCN step-by-step: FC → Conv 1x1 transformation.""" - dur = STEP_DUR + 1 - fcn_clip = VideoClip(_make_fcn_frame, duration=dur).with_fps(FPS) - labels = [ - ("FCN: Fully Convolutional Network (2015)", 26, "#FFE082", FONT_B, (80, 20)), - ("KROK 1: Zamień FC → Conv 1x1", 18, "#A5D6A7", FONT_R, (80, 60)), - ("Klasyczny CNN:", 16, "#EF9A9A", FONT_B, (80, 105)), - ("Conv", 11, "white", FONT_R, (92, 148)), - ("Pool", 11, "white", FONT_R, (178, 148)), - ("Conv", 11, "white", FONT_R, (250, 148)), - ("Pool", 11, "white", FONT_R, (325, 148)), - ("Flatten", 11, "#EF9A9A", FONT_R, (390, 148)), - ("FC", 11, "#EF9A9A", FONT_R, (480, 148)), - ("1 label", 11, "#EF9A9A", FONT_R, (555, 148)), - ("FCN:", 16, "#A5D6A7", FONT_B, (80, 350)), - ("Conv", 11, "white", FONT_R, (92, 388)), - ("Pool", 11, "white", FONT_R, (178, 388)), - ("Conv", 11, "white", FONT_R, (250, 388)), - ("Pool", 11, "white", FONT_R, (325, 388)), - ("Conv1x1", 11, "#A5D6A7", FONT_R, (390, 388)), - ("Upsample", 11, "#FFE082", FONT_R, (486, 388)), - ("Mapa", 11, "#A5D6A7", FONT_R, (595, 388)), - ( - "FC: spłaszcza 3D→1D, wymusza stały rozmiar → 1 etykieta", - 16, - "#EF9A9A", - FONT_R, - (80, 250), - ), - ( - "Conv1x1: działa per piksel x kanały → DOWOLNY rozmiar → mapa klasy", - 16, - "#A5D6A7", - FONT_R, - (80, 460), - ), - ( - "KROK 2: Skip connections — łączą wczesne detale z późną abstrakcją", - 17, - "#64B5F6", - FONT_R, - (80, 510), - ), - ( - "Wczesne warstwy = krawędzie, tekstury | Późne = koncepty obiektów", - 15, - "#78909C", - FONT_R, - (80, 545), - ), - ( - "FCN = PIERWSZA sieć end-to-end do segmentacji per-piksel!", - 18, - "white", - FONT_R, - (80, 590), - ), - ( - "Mnemonik: FC → Conv 1x1 = otwieramy bramkę dla DOWOLNEGO rozmiaru", - 16, - "#FFE082", - FONT_R, - (80, 640), - ), - ] - slides = [_compose_slide(fcn_clip, labels, dur)] - - # Slide 2: FCN skip connections step by step - skip_lines = [ - ("FCN: Skip Connections — krok po kroku", 26, "#FFE082", FONT_B, (80, 30)), - ( - "1. Encoder zmniejsza: 224→112→56→28→14 (pooling)", - 18, - "#64B5F6", - FONT_R, - (100, 100), - ), - ( - " Każdy pooling traci detale przestrzenne (dokładne krawędzie)", - 15, - "#78909C", - FONT_R, - (100, 135), - ), - ( - "2. Decoder powiększa: 14→28→56→112→224 (upsample/deconv)", - 18, - "#A5D6A7", - FONT_R, - (100, 190), - ), - ( - " Upsample ODGADUJE piksele — rozmyty wynik!", - 15, - "#78909C", - FONT_R, - (100, 225), - ), - ( - "3. Skip connections: dodaj cechy z encodera do decodera", - 18, - "#FFE082", - FONT_R, - (100, 280), - ), - ( - " Wczesne cechy = GDZIE (precyzyjne krawędzie)", - 15, - "#64B5F6", - FONT_R, - (100, 315), - ), - ( - " Późne cechy = CO (abstrakcyjne koncepty)", - 15, - "#A5D6A7", - FONT_R, - (100, 345), - ), - ( - " Skip = daje decoderowi OBA → ostry wynik!", - 15, - "#FFE082", - FONT_R, - (100, 375), - ), - ( - "Warianty: FCN-32s (brak skip, rozmyty) → FCN-16s → FCN-8s (najlepszy)", - 16, - "#B0BEC5", - FONT_R, - (80, 440), - ), - ( - "FCN-32s: upsample 32x naraz → ROZMYTE granice", - 15, - "#EF9A9A", - FONT_R, - (100, 485), - ), - ( - "FCN-16s: skip z pool4 + upsample 16x → lepiej", - 15, - "#FFE082", - FONT_R, - (100, 520), - ), - ( - "FCN-8s: skip z pool3+pool4 + upsample 8x → OSTRE granice!", - 15, - "#A5D6A7", - FONT_R, - (100, 555), - ), - ( - "Im więcej skip connections → tym więcej " - "detali z encodera → ostrzejszy wynik", - 17, - "white", - FONT_R, - (80, 620), - ), - ] - slides.append(_text_slide(skip_lines, duration=STEP_DUR + 1)) - - return slides - - -# ── DeepLab Architecture ───────────────────────────────────────── -def _make_dilated_frame(t: float) -> np.ndarray: - """Render a dilated convolution comparison frame.""" - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - progress = min(t / (STEP_DUR * 0.7), 1.0) - - cell = 36 - grids = [ - ( - "rate=1", - 60, - [ - (0, 0), - (0, 1), - (0, 2), - (1, 0), - (1, 1), - (1, 2), - (2, 0), - (2, 1), - (2, 2), - ], - ), - ( - "rate=2", - 420, - [ - (0, 0), - (0, 2), - (0, 4), - (2, 0), - (2, 2), - (2, 4), - (4, 0), - (4, 2), - (4, 4), - ], - ), - ( - "rate=3", - 820, - [ - (0, 0), - (0, 3), - (0, 6), - (3, 0), - (3, 3), - (3, 6), - (6, 0), - (6, 3), - (6, 6), - ], - ), - ] - - for gi, (_label, gx, positions) in enumerate(grids): - if progress < gi * 0.3: - break - gy = 180 - grid_size = 7 - for r in range(grid_size): - for c in range(grid_size): - x = gx + c * cell - y = gy + r * cell - frame[y : y + cell - 2, x : x + cell - 2] = (35, 40, 55) - for r, c in positions: - x = gx + c * cell - y = gy + r * cell - frame[y : y + cell - 2, x : x + cell - 2] = (70, 130, 200) - frame[y : y + 2, x : x + cell - 2] = (120, 180, 255) - frame[y + cell - 4 : y + cell - 2, x : x + cell - 2] = (120, 180, 255) - - return frame - - -def _make_aspp_frame(t: float) -> np.ndarray: - """Render a single ASPP module animation frame.""" - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - progress = min(t / (STEP_DUR * 0.7), 1.0) - - frame[250:330, 50:130] = (70, 130, 200) - frame[250:252, 50:130] = (120, 180, 255) - frame[328:330, 50:130] = (120, 180, 255) - - branches = [ - ("1x1 conv", 250, (200, 170), (100, 40), (80, 200, 120)), - ("rate=6", 310, (200, 250), (100, 40), (200, 160, 80)), - ("rate=12", 370, (200, 330), (100, 40), (200, 120, 60)), - ("rate=18", 430, (200, 410), (100, 40), (180, 100, 80)), - ("GAP", 490, (200, 490), (100, 40), (160, 80, 160)), - ] - n_branches = min(int(progress * 5) + 1, 5) - for i, (_lbl, _h, (bx, by), (bw, bh), color) in enumerate(branches): - if i < n_branches: - frame[by : by + bh, bx : bx + bw] = color - frame[by : by + 2, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - ay = by + bh // 2 - frame[ay - 1 : ay + 2, 133:197] = (150, 150, 170) - - concat_phase = 0.6 - if progress > concat_phase: - frame[250:530, 380:420] = (50, 60, 80) - frame[250:252, 380:420] = (200, 200, 100) - frame[528:530, 380:420] = (200, 200, 100) - for i, (_lbl, _h, (bx, by), (bw, bh), _c) in enumerate(branches): - if i < n_branches: - ay = by + bh // 2 - frame[ay - 1 : ay + 2, bx + bw + 3 : 378] = (150, 150, 170) - - final_conv_phase = 0.8 - if progress > final_conv_phase: - frame[350:420, 450:550] = (100, 200, 100) - frame[350:352, 450:550] = (150, 230, 150) - frame[418:420, 450:550] = (150, 230, 150) - frame[388:391, 423:448] = (150, 150, 170) - - return frame - - -def _deeplab_demo() -> list[CompositeVideoClip]: - """Animate DeepLab: dilated convolution + ASPP step by step.""" - dur = STEP_DUR + 1 - - # Slide 1: Regular vs Dilated convolution - dil_clip = VideoClip(_make_dilated_frame, duration=dur).with_fps(FPS) - labels = [ - ("DeepLab: Atrous (Dilated) Convolution", 26, "#FFE082", FONT_B, (80, 20)), - ( - "KROK 1: Zrozum dilated convolution — filtr z DZIURAMI", - 18, - "#A5D6A7", - FONT_R, - (80, 60), - ), - ("rate=1 (zwykła)", 14, "#64B5F6", FONT_B, (60, 160)), - ("RF = 3x3", 14, "#64B5F6", FONT_R, (60, 440)), - ("9 wag, kontekst 3px", 12, "#78909C", FONT_R, (60, 470)), - ("rate=2 (dilated)", 14, "#FFE082", FONT_B, (420, 160)), - ("RF = 5x5", 14, "#FFE082", FONT_R, (420, 440)), - ("9 wag, kontekst 5px!", 12, "#78909C", FONT_R, (420, 470)), - ("rate=3 (dilated)", 14, "#A5D6A7", FONT_B, (820, 160)), - ("RF = 7x7", 14, "#A5D6A7", FONT_R, (820, 440)), - ("9 wag, kontekst 7px!", 12, "#78909C", FONT_R, (820, 470)), - ( - "Niebieski = pozycja wag filtra 3x3 | Szary = pominięte (dziury)", - 15, - "#B0BEC5", - FONT_R, - (80, 510), - ), - ( - "TE SAME 9 wag → WIĘKSZE pole widzenia " - "→ lepszy kontekst BEZ dodatkowych parametrów!", - 16, - "white", - FONT_R, - (80, 550), - ), - ( - "Mnemonik: DZIURY w filtrze — à trous = z dziurami (fr.)", - 16, - "#FFE082", - FONT_R, - (80, 600), - ), - ] - slides = [_compose_slide(dil_clip, labels, dur)] - - # Slide 2: ASPP module step by step - aspp_clip = VideoClip(_make_aspp_frame, duration=dur).with_fps(FPS) - labels2 = [ - ( - "DeepLab: ASPP (Atrous Spatial Pyramid Pooling)", - 24, - "#FFE082", - FONT_B, - (80, 20), - ), - ( - "KROK 2: Multi-scale — analizuj obraz na WIELU skalach naraz", - 17, - "#A5D6A7", - FONT_R, - (80, 60), - ), - ("Wejście", 13, "#64B5F6", FONT_B, (55, 235)), - ("Conv 1x1", 12, "white", FONT_R, (210, 178)), - ("Dilated r=6", 12, "white", FONT_R, (205, 258)), - ("Dilated r=12", 12, "white", FONT_R, (203, 338)), - ("Dilated r=18", 12, "white", FONT_R, (203, 418)), - ("GAP (global)", 12, "white", FONT_R, (205, 498)), - ("Concat", 13, "#FFE082", FONT_B, (381, 537)), - ("Conv", 13, "#A5D6A7", FONT_B, (470, 425)), - ( - "5 gałęzi RÓWNOLEGŁYCH → różne skale kontekstu:", - 16, - "#B0BEC5", - FONT_R, - (550, 170), - ), - (" 1x1: kontekst punktowy (piksel)", 14, "#A5D6A7", FONT_R, (560, 210)), - (" r=6: kontekst lokalny (~13px)", 14, "#FFE082", FONT_R, (560, 245)), - (" r=12: kontekst średni (~25px)", 14, "#FFE082", FONT_R, (560, 280)), - (" r=18: kontekst szeroki (~37px)", 14, "#FFE082", FONT_R, (560, 315)), - (" GAP: kontekst GLOBALNY (cały obraz)", 14, "#CE93D8", FONT_R, (560, 350)), - ("Concat → 1x1 conv → mapa segmentacji", 16, "#A5D6A7", FONT_R, (550, 400)), - ( - "Efekt: sieć widzi OD piksela DO całego obrazu naraz!", - 17, - "white", - FONT_R, - (80, 600), - ), - ( - "Mnemonik: ASPP = Piramida z DZIURAMI, patrzy na 5 skal jednocześnie", - 15, - "#FFE082", - FONT_R, - (80, 645), - ), - ] - slides.append(_compose_slide(aspp_clip, labels2, dur)) - - return slides - - -# ── Transformer Segmentation ──────────────────────────────────── -def _draw_base_grid( - frame: np.ndarray, gx: int, gy: int, grid_n: int, cell: int, -) -> None: - """Draw an empty grid of cells.""" - for r in range(grid_n): - for c in range(grid_n): - x = gx + c * cell - y = gy + r * cell - frame[y : y + cell - 2, x : x + cell - 2] = (35, 40, 55) - - -def _draw_cnn_kernel( - frame: np.ndarray, lx: int, ly: int, cell: int, progress: float, -) -> None: - """Highlight a 3x3 CNN kernel on the grid.""" - cnn_phase = 0.2 - if progress <= cnn_phase: - return - cx, cy = 2, 2 - for dr in range(-1, 2): - for dc in range(-1, 2): - r, c = cy + dr, cx + dc - x = lx + c * cell - y = ly + r * cell - frame[y : y + cell - 2, x : x + cell - 2] = (70, 130, 200) - x = lx + cx * cell - y = ly + cy * cell - frame[y : y + cell - 2, x : x + cell - 2] = (120, 180, 255) - - -def _draw_conn_line( - frame: np.ndarray, x0: int, y0: int, x1: int, y1: int, -) -> None: - """Draw a dashed connection line between two points.""" - steps = max(abs(x1 - x0), abs(y1 - y0)) - if steps <= 0: - return - for s in range(0, steps, 3): - px = x0 + int((x1 - x0) * s / steps) - py = y0 + int((y1 - y0) * s / steps) - if 0 <= px < W - 1 and 0 <= py < H - 1: - frame[py : py + 1, px : px + 1] = (200, 180, 50) - - -def _draw_attention_connections( - frame: np.ndarray, - origin: tuple[int, int], - grid_n: int, - cell: int, - progress: float, -) -> None: - """Draw transformer self-attention connections on the grid.""" - rx, ry = origin - transformer_phase = 0.4 - if progress <= transformer_phase: - return - cx_t, cy_t = 2, 2 - x0 = rx + cx_t * cell + cell // 2 - y0 = ry + cy_t * cell + cell // 2 - n_connections = int(progress * 36) - conn_idx = 0 - for r in range(grid_n): - for c in range(grid_n): - conn_idx += 1 - if conn_idx > n_connections: - break - x = rx + c * cell - y = ry + r * cell - dist = abs(r - cy_t) + abs(c - cx_t) - strength = max(30, 200 - dist * 30) - frame[y : y + cell - 2, x : x + cell - 2] = ( - strength // 3, - strength // 2, - strength, - ) - _draw_conn_line(frame, x0, y0, x + cell // 2, y + cell // 2) - else: - continue - break - x = rx + cx_t * cell - y = ry + cy_t * cell - frame[y : y + cell - 2, x : x + cell - 2] = (255, 200, 50) - - -def _make_attention_frame(t: float) -> np.ndarray: - """Render a CNN-vs-Transformer attention comparison frame.""" - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - progress = min(t / (STEP_DUR * 0.7), 1.0) - - cell = 40 - grid_n = 6 - - lx, ly = 60, 200 - _draw_base_grid(frame, lx, ly, grid_n, cell) - _draw_cnn_kernel(frame, lx, ly, cell, progress) - - rx, ry = 680, 200 - _draw_base_grid(frame, rx, ry, grid_n, cell) - _draw_attention_connections(frame, (rx, ry), grid_n, cell, progress) - - return frame - - -def _transformer_seg_demo() -> list[CompositeVideoClip]: - """Animate transformer-based segmentation: self-attention concept.""" - dur = STEP_DUR + 1 - - # Slide 1: CNN local vs Transformer global - att_clip = VideoClip(_make_attention_frame, duration=dur).with_fps(FPS) - labels = [ - ("Transformer: Self-Attention w segmentacji", 26, "#FFE082", FONT_B, (80, 20)), - ("CNN = LOKALNY kontekst", 18, "#64B5F6", FONT_B, (60, 160)), - ("Transformer = GLOBALNY kontekst", 18, "#FFE082", FONT_B, (680, 160)), - ("Filtr 3x3 widzi", 14, "#64B5F6", FONT_R, (60, 460)), - ("TYLKO 9 sąsiadów", 14, "#64B5F6", FONT_R, (60, 485)), - ("Self-attention: każdy", 14, "#FFE082", FONT_R, (680, 460)), - ("piksel widzi WSZYSTKIE!", 14, "#FFE082", FONT_R, (680, 485)), - ("vs", 28, "#B0BEC5", FONT_B, (450, 300)), - ] - slides = [_compose_slide(att_clip, labels, dur)] - - # Slide 2: Self-attention Q/K/V step by step - qkv_lines = [ - ("Self-Attention: Q / K / V krok po kroku", 26, "#FFE082", FONT_B, (80, 30)), - ("Każdy piksel (token) tworzy 3 wektory:", 18, "#B0BEC5", FONT_R, (100, 100)), - ( - " Q (Query) = 'czego szukam?' - pytanie piksela", - 17, - "#64B5F6", - FONT_R, - (120, 145), - ), - ( - " K (Key) = 'co oferuj\u0119?' - odpowied\u017a piksela", - 17, - "#A5D6A7", - FONT_R, - (120, 185), - ), - ( - " V (Value) = 'moja warto\u015b\u0107' - informacja do przekazania", - 17, - "#FFE082", - FONT_R, - (120, 225), - ), - ("Algorytm attention:", 18, "#B0BEC5", FONT_R, (100, 285)), - ( - " 1. Mnożenie Q x K\u1d40 → macierz NxN (kto ważny dla kogo)", - 16, - "white", - FONT_R, - (120, 320), - ), - ( - " 2. Skalowanie: / √d (stabilność gradientów)", - 16, - "white", - FONT_R, - (120, 355), - ), - ( - " 3. Softmax → wagi attention (sumują się do 1)", - 16, - "white", - FONT_R, - (120, 390), - ), - ( - " 4. Mnożenie wag x V → ważona suma wartości", - 16, - "white", - FONT_R, - (120, 425), - ), - ( - "Attention(Q,K,V) = softmax(Q · K\u1d40 / √d) · V", - 20, - "#FFE082", - FONT_B, - (100, 480), - ), - ( - "Złożoność: O(n²) pamięci — n = liczba pikseli/tokenów", - 16, - "#EF9A9A", - FONT_R, - (100, 535), - ), - ( - "Dlatego SegFormer używa efficient attention (liniowa złożoność)", - 15, - "#78909C", - FONT_R, - (100, 570), - ), - ( - "SegFormer (2021): lightweight + hierarchiczny encoder", - 16, - "#A5D6A7", - FONT_R, - (100, 610), - ), - ( - "Mask2Former (2022): masked attention + " - "unified (semantic+instance+panoptic)", - 16, - "#CE93D8", - FONT_R, - (100, 645), - ), - ] - slides.append(_text_slide(qkv_lines, duration=STEP_DUR + 1)) - - # Slide 3: Encoder-Decoder in DL summary - summary_lines = [ - ( - "Podsumowanie: Encoder-Decoder w segmentacji DL", - 24, - "#FFE082", - FONT_B, - (80, 30), - ), - ("Wspólna idea WSZYSTKICH sieci segmentacji:", 18, "#B0BEC5", FONT_R, (80, 90)), - ( - "Encoder: obraz → cechy (zmniejsza rozdzielczość, wyciąga CO)", - 16, - "#64B5F6", - FONT_R, - (100, 140), - ), - ( - "Decoder: cechy → mapa (zwiększa rozdzielczość, odtwarza GDZIE)", - 16, - "#A5D6A7", - FONT_R, - (100, 175), - ), - ( - "Skip: przenosi detale z encodera do decodera", - 16, - "#FFE082", - FONT_R, - (100, 210), - ), - ("", 10, "white", FONT_R, (100, 240)), - ( - "FCN (2015): Conv1x1 + skip → pierwsza end-to-end", - 16, - "#64B5F6", - FONT_R, - (100, 275), - ), - ( - "U-Net (2015): U-shape + skip concat → segmentacja medyczna", - 16, - "#A5D6A7", - FONT_R, - (100, 310), - ), - ( - "DeepLab (2018): dilated conv + ASPP → multi-scale kontekst", - 16, - "#FFE082", - FONT_R, - (100, 345), - ), - ( - "SegFormer: transformer encoder (globalny kontekst)", - 16, - "#CE93D8", - FONT_R, - (100, 380), - ), - ( - "Mask2Former: masked attention (unified, SOTA)", - 16, - "#CE93D8", - FONT_R, - (100, 415), - ), - ("", 10, "white", FONT_R, (100, 440)), - ( - "Ewolucja: więcej kontekstu + lepsze skip connections:", - 17, - "white", - FONT_R, - (80, 465), - ), - ( - " CNN lokal. → dilated (szersze RF) → transformer (global) → masked att.", - 16, - "#B0BEC5", - FONT_R, - (80, 505), - ), - ( - " addition skip → concat skip → cross-attention skip", - 16, - "#B0BEC5", - FONT_R, - (80, 540), - ), - ( - "Metryki: mIoU (standard), Dice (medycyna), Focal Loss (imbalance)", - 16, - "#90CAF9", - FONT_R, - (80, 590), - ), - ( - "Loss: Cross-Entropy per piksel + opcjonalnie Dice/Focal", - 15, - "#78909C", - FONT_R, - (80, 625), - ), - ] - slides.append(_text_slide(summary_lines, duration=STEP_DUR + 1)) - - return slides - - -# ── Methods comparison ──────────────────────────────────────────── -def _methods_comparison() -> CompositeVideoClip: - bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(10.0) - title = ( - _tc( - text="Porównanie metod segmentacji", - font_size=36, - color="white", - font=FONT_B, - ) - .with_duration(10.0) - .with_position(("center", 20)) - ) - - rows = [ - ("Metoda", "Typ", "Idea", "Mnemonik"), - ("Thresholding", "Klasyczna", "piksel > T → klasa 1", "PRÓG na bramce"), - ("Otsu", "Klasyczna", "auto-próg, min σ²", "AUTO-bramkarz"), - ("Region Growing", "Klasyczna", "BFS od seeda", "PLAMA atramentu"), - ("Watershed", "Klasyczna", "zalewanie minimów", "ZALEWANIE terenu"), - ("Mean Shift", "Klasyczna", "jądro → max gęstości", "KULKI do dołków"), - ("U-Net", "Deep Learning", "encoder-decoder + skip", "Litera U + mosty"), - ("DeepLab", "Deep Learning", "dilated conv + ASPP", "DZIURY w filtrze"), - ] - - clips: list[VideoClip] = [bg, title] - mnemonic_col = 3 - for i, row in enumerate(rows): - y_pos = 75 + i * 72 - col_x = [40, 210, 340, 660] - for j, cell in enumerate(row): - fs = 16 if i > 0 else 18 - color = ( - "#64B5F6" if i == 0 - else ("#E0E0E0" if j < mnemonic_col else "#FFE082") - ) - tc = ( - _tc( - text=cell, - font_size=fs, - color=color, - font=FONT_B if i == 0 else FONT_R, - ) - .with_duration(10.0) - .with_position((col_x[j], y_pos)) - ) - clips.append(tc) - - return CompositeVideoClip(clips, size=(W, H)).with_effects( - [FadeIn(0.5), FadeOut(0.5)] - ) - - -# ── Main ────────────────────────────────────────────────────────── def main() -> None: """Generate the Q23 segmentation visualization video.""" sections: list[VideoClip] = [] diff --git a/python_pkg/praca_magisterska_video/visualize_q24.py b/python_pkg/praca_magisterska_video/visualize_q24.py index 5e3f1b5..b3e5370 100644 --- a/python_pkg/praca_magisterska_video/visualize_q24.py +++ b/python_pkg/praca_magisterska_video/visualize_q24.py @@ -13,1813 +13,28 @@ from __future__ import annotations import logging import os -from pathlib import Path - -import numpy as np os.environ["FFMPEG_BINARY"] = "/usr/bin/ffmpeg" -from moviepy import ( - ColorClip, - CompositeVideoClip, - TextClip, - VideoClip, - concatenate_videoclips, +from _q24_classical import ( + _detection_concept, + _hog_svm_demo, + _viola_jones_demo, ) -from moviepy.video.fx import FadeIn, FadeOut - -# ── Constants ───────────────────────────────────────────────────── -W, H = 1280, 720 -FPS = 24 -STEP_DUR = 7.0 -HEADER_DUR = 4.0 -FONT_B = "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf" -FONT_R = "/usr/share/fonts/TTF/DejaVuSans.ttf" -OUTPUT_DIR = Path(__file__).resolve().parent / "videos" -OUTPUT_DIR.mkdir(parents=True, exist_ok=True) -OUTPUT = str(OUTPUT_DIR / "q24_object_detection.mp4") - -BG_COLOR = (15, 20, 35) - -_logger = logging.getLogger(__name__) - - -def _tc(**kwargs: object) -> TextClip: - """TextClip wrapper that adds enough bottom margin to prevent clipping.""" - fs = kwargs.get("font_size", 24) - m = int(fs) // 3 + 2 - kwargs["margin"] = (0, m) - return TextClip(**kwargs) - - -def _make_header( - title: str, subtitle: str, duration: float = HEADER_DUR -) -> CompositeVideoClip: - bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration) - t = ( - _tc( - text=title, - font_size=48, - color="white", - font=FONT_B, - ) - .with_duration(duration) - .with_position(("center", 260)) - ) - s = ( - _tc( - text=subtitle, - font_size=24, - color="#90CAF9", - font=FONT_R, - ) - .with_duration(duration) - .with_position(("center", 340)) - ) - return CompositeVideoClip([bg, t, s], size=(W, H)).with_effects( - [FadeIn(0.5), FadeOut(0.5)] - ) - - -# ── Detection concept ──────────────────────────────────────────── -def _detection_concept() -> list[CompositeVideoClip]: - """Show what detection is: bounding box + class + confidence.""" - slides = [] - - def make_det_frame(_t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - # Draw a "scene" with colored rectangles representing objects - # Sky background area - frame[140:500, 100:700] = (40, 50, 70) - - # "Car" object - frame[350:430, 150:320] = (180, 60, 60) - # "Person" object - frame[280:440, 450:520] = (60, 120, 180) - # "Tree" object - frame[200:400, 580:650] = (40, 130, 50) - - # Bounding boxes (with labels drawn as colored borders) - # Car bbox - for thickness in range(3): - t = thickness - frame[348 - t : 432 + t, 148 - t : 148 - t + 2] = (255, 80, 80) - frame[348 - t : 432 + t, 322 + t - 2 : 322 + t] = (255, 80, 80) - frame[348 - t : 348 - t + 2, 148 - t : 322 + t] = (255, 80, 80) - frame[432 + t - 2 : 432 + t, 148 - t : 322 + t] = (255, 80, 80) - - # Person bbox - for thickness in range(3): - t = thickness - frame[278 - t : 442 + t, 448 - t : 448 - t + 2] = (80, 180, 255) - frame[278 - t : 442 + t, 522 + t - 2 : 522 + t] = (80, 180, 255) - frame[278 - t : 278 - t + 2, 448 - t : 522 + t] = (80, 180, 255) - frame[442 + t - 2 : 442 + t, 448 - t : 522 + t] = (80, 180, 255) - - # Tree bbox - for thickness in range(3): - t = thickness - frame[198 - t : 402 + t, 578 - t : 578 - t + 2] = (80, 220, 100) - frame[198 - t : 402 + t, 652 + t - 2 : 652 + t] = (80, 220, 100) - frame[198 - t : 198 - t + 2, 578 - t : 652 + t] = (80, 220, 100) - frame[402 + t - 2 : 402 + t, 578 - t : 652 + t] = (80, 220, 100) - - # Comparison boxes on right side - # Classification - frame[180:260, 800:1150] = (35, 45, 65) - # Detection - frame[290:370, 800:1150] = (35, 45, 65) - # Segmentation - frame[400:480, 800:1150] = (35, 45, 65) - - return frame - - det_clip = VideoClip(make_det_frame, duration=STEP_DUR).with_fps(FPS) - text_clips: list[VideoClip] = [det_clip] - labels = [ - ("Detekcja obiektów — co to jest?", 28, "#FFE082", FONT_B, (100, 20)), - ("Wynik: (klasa, bounding box, pewność)", 20, "#B0BEC5", FONT_R, (100, 65)), - ("samochód 95%", 14, "#EF9A9A", FONT_B, (150, 340)), - ("osoba 88%", 14, "#64B5F6", FONT_B, (450, 268)), - ("drzewo 72%", 14, "#A5D6A7", FONT_B, (580, 188)), - ("Klasyfikacja: cały obraz → 1 etykieta", 15, "#78909C", FONT_R, (810, 210)), - ("Detekcja: bbox + klasa + pewność", 15, "#FFE082", FONT_R, (810, 320)), - ("Segmentacja: maska per piksel", 15, "#78909C", FONT_R, (810, 430)), - ("← granulacja rośnie →", 14, "#90CAF9", FONT_R, (810, 520)), - ] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR) - .with_position(pos) - ) - text_clips.append(tc) - - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── HOG + SVM pipeline ─────────────────────────────────────────── -def _hog_svm_demo() -> list[CompositeVideoClip]: - """Animate HOG feature computation and SVM classification.""" - slides = [] - - def make_hog_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - progress = min(t / (STEP_DUR * 0.8), 1.0) - - # Pipeline stages as boxes with arrows - stages = [ - ("Gradient", (80, 250), (130, 80), (100, 160, 220)), - ("Orientacja", (260, 250), (130, 80), (80, 180, 140)), - ("Komórki 8x8", (440, 250), (130, 80), (200, 160, 80)), - ("Bloki 2x2", (620, 250), (130, 80), (200, 120, 60)), - ("Normalizacja", (800, 250), (130, 80), (180, 100, 80)), - ("SVM", (980, 250), (130, 80), (220, 80, 80)), - ] - - n_active = int(progress * len(stages)) + 1 - - for i, (_label, (sx, sy), (sw, sh), color) in enumerate(stages): - if i < n_active: - frame[sy : sy + sh, sx : sx + sw] = color - # Border - frame[sy : sy + 2, sx : sx + sw] = tuple( - min(c + 60, 255) for c in color - ) - frame[sy + sh - 2 : sy + sh, sx : sx + sw] = tuple( - min(c + 60, 255) for c in color - ) - - # Arrow to next - if i < len(stages) - 1: - ax = sx + sw + 5 - ay = sy + sh // 2 - frame[ay - 1 : ay + 2, ax : ax + 20] = (150, 150, 170) - - # Show gradient computation example at bottom - gradient_phase = 0.2 - if progress > gradient_phase: - # Mini pixel grid showing gradient computation - gx, gy = 100, 430 - pixels = [50, 50, 200] - for idx, val in enumerate(pixels): - x = gx + idx * 50 - frame[gy : gy + 40, x : x + 40] = (val, val, val) - - return frame - - hog_clip = VideoClip(make_hog_frame, duration=STEP_DUR).with_fps(FPS) - text_clips: list[VideoClip] = [hog_clip] - labels = [ - ("HOG + SVM — pipeline detekcji pieszych", 28, "#FFE082", FONT_B, (80, 20)), - ( - "Mnemonik: GOKBN = Gradienty→Orientacja→Komórki→Bloki→Normalizacja", - 16, - "#A5D6A7", - FONT_R, - (80, 65), - ), - ("Gradient: siła i kierunek zmiany jasności", 14, "#64B5F6", FONT_R, (80, 95)), - ( - "Histogram: 9 binów (0°-180°, co 20°) per komórka 8x8", - 14, - "#78909C", - FONT_R, - (80, 120), - ), - ( - "[50][50][200] → Gx = 200-50 = 150 = silna krawędź!", - 16, - "#EF9A9A", - FONT_R, - (80, 490), - ), - ( - "Wektor HOG (3780 cech) → SVM: pieszy (+1) / tło (-1)", - 16, - "white", - FONT_R, - (80, 540), - ), - ( - "Sliding window 64x128 przesuwa się po obrazie → NMS → wynik", - 16, - "#90CAF9", - FONT_R, - (80, 580), - ), - ( - "SVM = LINIA MAKSYMALNEGO ODDECHU (max margines, support vectors)", - 16, - "#FFE082", - FONT_R, - (80, 620), - ), - ] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR) - .with_position(pos) - ) - text_clips.append(tc) - - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── Viola-Jones ─────────────────────────────────────────────────── -def _viola_jones_demo() -> list[CompositeVideoClip]: - """Animate Viola-Jones cascade concept.""" - slides = [] - - def make_cascade_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - progress = min(t / (STEP_DUR * 0.8), 1.0) - - # Draw cascade "funnel" — stages filtering out non-faces - stages = 5 - start_width = 1000 - start_count = 10000 - x_center = W // 2 - - for i in range(stages): - stage_progress = min(progress * stages - i, 1.0) - if stage_progress <= 0: - break - - width = int(start_width * (1 - i * 0.18)) - int(start_count * (0.3**i)) - y = 150 + i * 100 - h_box = 60 - - # Stage box - x1 = x_center - width // 2 - frame[y : y + h_box, x1 : x1 + width] = ( - 50 + i * 10, - 60 + i * 10, - 80 + i * 10, - ) - # Border - frame[y : y + 2, x1 : x1 + width] = (100 + i * 20, 130 + i * 15, 200) - frame[y + h_box - 2 : y + h_box, x1 : x1 + width] = ( - 100 + i * 20, - 130 + i * 15, - 200, - ) - - # Arrow down to next - if i < stages - 1: - frame[y + h_box + 5 : y + h_box + 25, x_center - 1 : x_center + 2] = ( - 150, - 150, - 170, - ) - - # Red "rejected" arrows on sides - if i > 0: - # Left reject arrow - rx = x1 - 30 - ry = y + h_box // 2 - frame[ry - 1 : ry + 2, rx : rx + 25] = (200, 80, 80) - - return frame - - cascade_clip = VideoClip(make_cascade_frame, duration=STEP_DUR).with_fps(FPS) - text_clips: list[VideoClip] = [cascade_clip] - labels = [ - ( - "Viola-Jones — kaskada klasyfikatorów (2001)", - 28, - "#FFE082", - FONT_B, - (80, 20), - ), - ( - "3 innowacje: HIC = Haar + Integral Image + Cascade", - 20, - "#B0BEC5", - FONT_R, - (80, 65), - ), - ("Etap 1: 2 cechy Haar", 14, "#64B5F6", FONT_R, (170, 170)), - ("Etap 2: 10 cech", 14, "#64B5F6", FONT_R, (210, 270)), - ("Etap 3: 25 cech", 14, "#64B5F6", FONT_R, (240, 370)), - ("Etap 4: 50 cech", 14, "#64B5F6", FONT_R, (260, 470)), - ("→ TWARZ!", 16, "#A5D6A7", FONT_B, (590, 560)), - ( - "SITO: 99% okien odpada w pierwszych 3 etapach → REAL-TIME!", - 16, - "#EF9A9A", - FONT_R, - (80, 620), - ), - ( - "Haar: kontrast jasna/ciemna | Integral Image: " - "suma prostokąta O(1) = 4 odczyty", - 14, - "#78909C", - FONT_R, - (80, 655), - ), - ("odrzucone →", 12, "#EF9A9A", FONT_R, (60, 275)), - ("odrzucone →", 12, "#EF9A9A", FONT_R, (60, 375)), - ] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR) - .with_position(pos) - ) - text_clips.append(tc) - - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── R-CNN Evolution ─────────────────────────────────────────────── -def _rcnn_evolution() -> list[CompositeVideoClip]: - """Animate R-CNN → Fast R-CNN → Faster R-CNN evolution.""" - slides = [] - - def make_evolution_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - progress = min(t / (STEP_DUR * 0.8), 1.0) - - # Three rows: R-CNN, Fast R-CNN, Faster R-CNN - models = [ - ( - "R-CNN (2014)", - 50, - [ - ("Selective\nSearch", (200, 150), (100, 50), (120, 100, 60)), - ("2000x\nCNN", (350, 150), (80, 50), (180, 60, 60)), - ("2000x\nSVM", (480, 150), (80, 50), (180, 60, 60)), - ("NMS", (610, 150), (60, 50), (100, 140, 100)), - ], - "50 sec/obraz!", - ), - ( - "Fast R-CNN (2015)", - 300, - [ - ("Selective\nSearch", (200, 150), (100, 50), (120, 100, 60)), - ("1x CNN\n(cały obraz)", (350, 150), (100, 50), (80, 140, 200)), - ("ROI Pool\n(2000)", (500, 150), (90, 50), (200, 160, 80)), - ("FC", (640, 150), (50, 50), (100, 140, 100)), - ], - "2 sec/obraz", - ), - ( - "Faster R-CNN (2015)", - 300, - [ - ("CNN\nbackbone", (200, 150), (90, 50), (80, 140, 200)), - ("RPN\n(~300)", (340, 150), (80, 50), (200, 120, 60)), - ("ROI Pool", (470, 150), (80, 50), (200, 160, 80)), - ("FC", (600, 150), (50, 50), (100, 140, 100)), - ], - "0.2 sec → 5 fps!", - ), - ] - - n_models = int(progress * 3) + 1 - - for mi, (_name, base_y, stages, _speed) in enumerate(models): - if mi >= n_models: - break - for _label, (bx, by_off), (bw, bh), color in stages: - by = base_y + by_off - 150 - frame[by : by + bh, bx : bx + bw] = color - frame[by : by + 2, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - - # Arrows between stages - for si in range(len(stages) - 1): - sx = stages[si][1][0] + stages[si][2][0] - ex = stages[si + 1][1][0] - ay = base_y + 25 - frame[ay - 1 : ay + 2, sx + 3 : ex - 3] = (150, 150, 170) - - return frame - - evo_clip = VideoClip(make_evolution_frame, duration=STEP_DUR + 1).with_fps(FPS) - text_clips: list[VideoClip] = [evo_clip] - labels = [ - ("Ewolucja R-CNN — CORAZ MNIEJ MARNOWANIA", 28, "#FFE082", FONT_B, (80, 20)), - ("R-CNN (2014)", 20, "#EF9A9A", FONT_B, (50, 80)), - ("50 sec/obraz (2000x forward pass!)", 14, "#EF9A9A", FONT_R, (720, 100)), - ("Fast R-CNN (2015)", 20, "#64B5F6", FONT_B, (50, 330)), - ("2 sec/obraz (CNN raz + ROI Pool)", 14, "#64B5F6", FONT_R, (720, 350)), - ("Faster R-CNN (2015)", 20, "#A5D6A7", FONT_B, (50, 580)), - ("0.2 sec → 5 fps (RPN w sieci!)", 14, "#A5D6A7", FONT_R, (720, 600)), - ( - "Kluczowe innowacje: ROI Pooling → stały rozmiar " - "| RPN → propozycje w sieci", - 14, - "#78909C", - FONT_R, - (80, 660), - ), - ] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR + 1) - .with_position(pos) - ) - text_clips.append(tc) - - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── R-CNN Detailed Pipeline ────────────────────────────────────── -def _rcnn_detailed() -> list[CompositeVideoClip]: - """Animate R-CNN step-by-step pipeline in detail.""" - slides = [] - - # Slide 1: R-CNN pipeline step by step - def make_rcnn_pipeline(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - progress = min(t / (STEP_DUR * 0.8), 1.0) - - # Step boxes arranged vertically with arrows - steps = [ - ((80, 130), (200, 55), (120, 100, 60), "1. Selective Search"), - ((80, 230), (200, 55), (180, 60, 60), "2. Wytnij 2000 regionów"), - ((80, 330), (200, 55), (70, 130, 200), "3. CNN per region"), - ((80, 430), (200, 55), (200, 100, 80), "4. SVM klasyfikuje"), - ((80, 530), (200, 55), (100, 180, 100), "5. Bbox regresja + NMS"), - ] - n_steps = min(int(progress * 5) + 1, 5) - for i, ((bx, by), (bw, bh), color, _lbl) in enumerate(steps): - if i < n_steps: - frame[by : by + bh, bx : bx + bw] = color - frame[by : by + 2, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - # Arrow down - arrow_limit = 4 - if i < arrow_limit: - ax = bx + bw // 2 - ay = by + bh + 5 - frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170) - - # Illustration: many overlapping regions from Selective Search - overlay_phase = 0.2 - if progress > overlay_phase: - rng_local = np.random.default_rng(42) - n_boxes = min(int((progress - 0.2) * 15), 8) - for i in range(n_boxes): - rx = 500 + rng_local.integers(-30, 100) - ry = 200 + rng_local.integers(-20, 120) - rw = 60 + rng_local.integers(0, 80) - rh = 50 + rng_local.integers(0, 70) - c = (80 + i * 15, 100 + i * 10, 60 + i * 20) - for tt in range(2): - frame[ry - tt : ry + rh + tt, rx - tt : rx - tt + 2] = c - frame[ry - tt : ry + rh + tt, rx + rw + tt - 2 : rx + rw + tt] = c - frame[ry - tt : ry - tt + 2, rx - tt : rx + rw + tt] = c - frame[ry + rh + tt - 2 : ry + rh + tt, rx - tt : rx + rw + tt] = c - - return frame - - rcnn_clip = VideoClip(make_rcnn_pipeline, duration=STEP_DUR + 1).with_fps(FPS) - dur = STEP_DUR + 1 - labels = [ - ("R-CNN: krok po kroku (2014, Girshick)", 26, "#FFE082", FONT_B, (80, 20)), - ("Pipeline detekcji two-stage", 16, "#B0BEC5", FONT_R, (80, 60)), - ("Selective Search", 11, "white", FONT_R, (105, 145)), - ("2000 regionów", 11, "white", FONT_R, (105, 245)), - ("CNN per region", 11, "white", FONT_R, (105, 345)), - ("SVM klasyfikuje", 11, "white", FONT_R, (105, 445)), - ("Regresja + NMS", 11, "white", FONT_R, (105, 545)), - ("~2000 propozycji regionów", 14, "#78909C", FONT_R, (500, 155)), - ("(inteligentne łączenie", 13, "#78909C", FONT_R, (500, 180)), - ("podobnych fragmentów)", 13, "#78909C", FONT_R, (500, 200)), - ("Problem: 2000 x CNN forward pass", 16, "#EF9A9A", FONT_R, (400, 400)), - ("= 50 SEKUND na obraz!", 18, "#EF9A9A", FONT_B, (400, 430)), - ("CNN liczy cechy per region OSOBNO", 14, "#EF9A9A", FONT_R, (400, 470)), - ( - "→ regiony się nakładają → obliczenia się powtarzają!", - 14, - "#EF9A9A", - FONT_R, - (400, 495), - ), - ( - "Rozwiązanie: CNN raz na cały obraz → Fast R-CNN →", - 16, - "#A5D6A7", - FONT_R, - (80, 620), - ), - ] - text_clips: list[VideoClip] = [rcnn_clip] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(dur) - .with_position(pos) - ) - text_clips.append(tc) - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - - return slides - - -# ── ROI Pooling ────────────────────────────────────────────────── - - -def _draw_roi_pool_grid(frame: np.ndarray) -> None: - """Draw the 3x3 ROI pool grid with max-pooled feature values.""" - out_x, out_y = 400, 220 - out_cell = 50 - out_n = 3 - roi_r1, roi_c1 = 2, 1 - roi_r2, roi_c2 = 6, 5 - roi_h = roi_r2 - roi_r1 - roi_w = roi_c2 - roi_c1 - for r in range(out_n): - for c in range(out_n): - x = out_x + c * out_cell - y = out_y + r * out_cell - - # Compute the max from corresponding region - src_r1 = roi_r1 + r * roi_h // out_n - src_r2 = roi_r1 + (r + 1) * roi_h // out_n - src_c1 = roi_c1 + c * roi_w // out_n - src_c2 = roi_c1 + (c + 1) * roi_w // out_n - max_val = 0 - for sr in range(src_r1, src_r2): - for sc in range(src_c1, src_c2): - v = 30 + ((sr * 7 + sc * 13 + 42) % 40) - max_val = max(max_val, v) - - frame[y : y + out_cell - 2, x : x + out_cell - 2] = ( - max_val, - max_val + 20, - max_val + 40, - ) - frame[y : y + 2, x : x + out_cell - 2] = (80, 200, 120) - frame[y + out_cell - 4 : y + out_cell - 2, x : x + out_cell - 2] = ( - 80, - 200, - 120, - ) - - -def _make_roi_frame(t: float) -> np.ndarray: - """Render a single frame for the ROI pooling animation.""" - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - progress = min(t / (STEP_DUR * 0.7), 1.0) - - # Left: feature map with ROI highlighted - fm_x, fm_y = 60, 180 - fm_cell = 30 - fm_grid = 8 - for r in range(fm_grid): - for c in range(fm_grid): - x = fm_x + c * fm_cell - y = fm_y + r * fm_cell - # Random-looking feature values - val = 30 + ((r * 7 + c * 13 + 42) % 40) - frame[y : y + fm_cell - 1, x : x + fm_cell - 1] = ( - val, - val + 10, - val + 20, - ) - - # ROI region highlighted - roi_r1, roi_c1 = 2, 1 - roi_r2, roi_c2 = 6, 5 - for tt in range(3): - ry1 = fm_y + roi_r1 * fm_cell - tt - ry2 = fm_y + roi_r2 * fm_cell + tt - rx1 = fm_x + roi_c1 * fm_cell - tt - rx2 = fm_x + roi_c2 * fm_cell + tt - frame[ry1:ry2, rx1 : rx1 + 2] = (255, 200, 50) - frame[ry1:ry2, rx2 - 2 : rx2] = (255, 200, 50) - frame[ry1 : ry1 + 2, rx1:rx2] = (255, 200, 50) - frame[ry2 - 2 : ry2, rx1:rx2] = (255, 200, 50) - - # Arrow - arrow_phase = 0.3 - if progress > arrow_phase: - frame[300:303, 310:380] = (150, 150, 170) - - # Middle: ROI divided into 3x3 grid (output_size) - grid_phase = 0.3 - if progress > grid_phase: - _draw_roi_pool_grid(frame) - - # Arrow to FC - fc_phase = 0.6 - if progress > fc_phase: - frame[300:303, 560:630] = (150, 150, 170) - # FC box - frame[270:340, 650:730] = (200, 100, 80) - frame[270:272, 650:730] = (240, 140, 120) - frame[338:340, 650:730] = (240, 140, 120) - - return frame - - -def _roi_pooling_demo() -> list[CompositeVideoClip]: - """Animate ROI Pooling: key Fast R-CNN innovation.""" - slides = [] - - roi_clip = VideoClip(_make_roi_frame, duration=STEP_DUR + 1).with_fps(FPS) - dur = STEP_DUR + 1 - labels = [ - ("ROI Pooling: kluczowa innowacja Fast R-CNN", 26, "#FFE082", FONT_B, (80, 20)), - ( - "KROK 1: CNN raz na CAŁY obraz → feature mapa", - 17, - "#64B5F6", - FONT_R, - (80, 60), - ), - ( - "KROK 2: Wytnij ROI z feature mapy (nie z obrazu!)", - 17, - "#FFE082", - FONT_R, - (80, 90), - ), - ( - "KROK 3: Siatkuj ROI na 3x3 → max pool per komórka → stały rozmiar", - 17, - "#A5D6A7", - FONT_R, - (80, 120), - ), - ("Feature mapa", 14, "#64B5F6", FONT_B, (60, 160)), - ("ROI (żółta ramka)", 13, "#FFE082", FONT_R, (60, 440)), - ("ROI Pool 3x3", 14, "#A5D6A7", FONT_B, (400, 195)), - ("(max z komórki)", 13, "#78909C", FONT_R, (400, 380)), - ("FC", 14, "white", FONT_B, (670, 280)), - ( - "Problem: ROI mają RÓŻNE rozmiary, FC wymaga STAŁEGO", - 15, - "#B0BEC5", - FONT_R, - (80, 500), - ), - ( - "ROI Pooling: dzieli ROI na siatkę, max pool → STAŁY rozmiar!", - 16, - "white", - FONT_R, - (80, 535), - ), - ( - "Fast R-CNN: CNN raz → 1 feature mapa → " - "ROI Pool 2000 regionów → 25x szybciej!", - 16, - "#A5D6A7", - FONT_R, - (80, 580), - ), - ( - "(R-CNN: 2000x CNN = 50s | Fast R-CNN: 1xCNN + ROI Pool = 2s)", - 15, - "#EF9A9A", - FONT_R, - (80, 620), - ), - ] - text_clips: list[VideoClip] = [roi_clip] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(dur) - .with_position(pos) - ) - text_clips.append(tc) - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── RPN + Anchor Boxes ─────────────────────────────────────────── -def _rpn_anchors_demo() -> list[CompositeVideoClip]: - """Animate RPN and anchor boxes: Faster R-CNN innovation.""" - slides = [] - - # Slide 1: Anchor boxes concept - def make_anchors_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - progress = min(t / (STEP_DUR * 0.7), 1.0) - - # Draw feature map grid point with multiple anchors - cx, cy = 350, 360 # center point on feature map - - # Draw a "feature map" grid background - cell = 60 - for r in range(-3, 4): - for c in range(-3, 4): - x = cx + c * cell - cell // 2 - y = cy + r * cell - cell // 2 - frame[y : y + cell - 1, x : x + cell - 1] = (30, 35, 48) - - # Center point highlighted - frame[cy - 5 : cy + 5, cx - 5 : cx + 5] = (255, 200, 50) - - # Draw anchors around center: 3 sizes x 3 ratios = 9 - anchor_specs = [ - (30, 30, (200, 80, 80)), # small 1:1 - (20, 40, (200, 60, 60)), # small 1:2 - (40, 20, (180, 60, 60)), # small 2:1 - (60, 60, (80, 200, 80)), # medium 1:1 - (40, 80, (60, 180, 60)), # medium 1:2 - (80, 40, (60, 160, 60)), # medium 2:1 - (90, 90, (80, 80, 200)), # large 1:1 - (60, 120, (60, 60, 180)), # large 1:2 - (120, 60, (60, 60, 160)), # large 2:1 - ] - n_anchors = min(int(progress * 9) + 1, 9) - for i in range(n_anchors): - hw, hh, color = anchor_specs[i] - x1 = max(0, cx - hw) - y1 = max(0, cy - hh) - x2 = min(W - 1, cx + hw) - y2 = min(H - 1, cy + hh) - for tt in range(2): - frame[y1 - tt : y2 + tt, x1 - tt : x1 - tt + 2] = color - frame[y1 - tt : y2 + tt, x2 + tt - 2 : x2 + tt] = color - frame[y1 - tt : y1 - tt + 2, x1 - tt : x2 + tt] = color - frame[y2 + tt - 2 : y2 + tt, x1 - tt : x2 + tt] = color - - return frame - - anch_clip = VideoClip(make_anchors_frame, duration=STEP_DUR + 1).with_fps(FPS) - dur = STEP_DUR + 1 - labels = [ - ("Anchor Boxes + RPN (Faster R-CNN)", 26, "#FFE082", FONT_B, (80, 20)), - ( - "KROK 1: Anchory = predefiniowane kształty w każdej pozycji", - 17, - "#A5D6A7", - FONT_R, - (80, 60), - ), - ( - "3 rozmiary x 3 proporcje = 9 anchorów per punkt", - 16, - "#B0BEC5", - FONT_R, - (80, 90), - ), - ("Małe (1:1, 1:2, 2:1)", 14, "#EF9A9A", FONT_R, (750, 170)), - ("Średnie (1:1, 1:2, 2:1)", 14, "#A5D6A7", FONT_R, (750, 210)), - ("Duże (1:1, 1:2, 2:1)", 14, "#64B5F6", FONT_R, (750, 250)), - ("Żółty punkt = pozycja", 14, "#FFE082", FONT_R, (750, 310)), - ("na feature mapie", 14, "#FFE082", FONT_R, (750, 335)), - ("Sieć NIE predykuje bbox od zera!", 16, "white", FONT_R, (80, 530)), - ( - "Predykuje OFFSET od najbliższego anchora: (Δx, Δy, Δw, Δh)", - 16, - "#FFE082", - FONT_R, - (80, 565), - ), - ( - "+ P(obiekt) = 'czy w tym anchorze jest coś?'", - 16, - "#A5D6A7", - FONT_R, - (80, 600), - ), - ( - "Mnemonik: Anchor = KOTWICA — sieć dopasowuje bbox do kotwicy", - 15, - "#78909C", - FONT_R, - (80, 645), - ), - ] - text_clips: list[VideoClip] = [anch_clip] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(dur) - .with_position(pos) - ) - text_clips.append(tc) - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - - # Slide 2: RPN step by step - rpn_lines = [ - ( - "RPN: Region Proposal Network — krok po kroku", - 24, - "#FFE082", - FONT_B, - (80, 30), - ), - ( - "Zastępuje Selective Search SIECIĄ NEURONOWĄ (end-to-end!)", - 17, - "#B0BEC5", - FONT_R, - (80, 85), - ), - ("", 10, "white", FONT_R, (80, 110)), - ( - "1. Backbone (ResNet) przetwarza obraz → feature mapa [40x60x256]", - 16, - "#64B5F6", - FONT_R, - (100, 140), - ), - ( - "2. Filtr 3x3 przesuwa się po feature mapie", - 16, - "#A5D6A7", - FONT_R, - (100, 180), - ), - ( - "3. W KAŻDEJ pozycji (x,y) rozważ k=9 anchorów:", - 16, - "#FFE082", - FONT_R, - (100, 220), - ), - (" → P(obiekt) — 'czy tu jest coś?'", 15, "white", FONT_R, (120, 255)), - (" → (Δx, Δy, Δw, Δh) — poprawka pozycji", 15, "white", FONT_R, (120, 285)), - ( - "4. 40x60 pozycji x 9 anchorów = 21 600 kandydatów!", - 16, - "#EF9A9A", - FONT_R, - (100, 325), - ), - ( - "5. Weź ~300 z najwyższym P(obiekt) → ROI Pool → FC", - 16, - "#A5D6A7", - FONT_R, - (100, 365), - ), - ("", 10, "white", FONT_R, (100, 395)), - ("Porównanie generowania propozycji:", 17, "white", FONT_B, (80, 420)), - ( - " Selective Search: ~2000 regionów, osobny algorytm, ~2 sec", - 15, - "#EF9A9A", - FONT_R, - (100, 460), - ), - ( - " RPN: ~300 regionów, W SIECI, ~10 ms → 200x szybciej!", - 15, - "#A5D6A7", - FONT_R, - (100, 495), - ), - ("", 10, "white", FONT_R, (100, 520)), - ( - "Faster R-CNN = Backbone + RPN + ROI Pool + FC — WSZYSTKO end-to-end", - 17, - "#FFE082", - FONT_R, - (80, 545), - ), - ( - "→ 5 fps (0.2 sec/obraz) vs R-CNN 50 sec = 250x szybciej!", - 17, - "#A5D6A7", - FONT_R, - (80, 585), - ), - ( - "Wciąż two-stage: (1) RPN generuje propozycje, (2) FC klasyfikuje", - 15, - "#78909C", - FONT_R, - (80, 630), - ), - ] - slides.append(_text_slide(rpn_lines, duration=STEP_DUR + 1)) - - return slides - - -# ── YOLO ────────────────────────────────────────────────────────── -def _yolo_demo() -> list[CompositeVideoClip]: - """Animate YOLO grid detection concept.""" - slides = [] - - def make_yolo_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - progress = min(t / (STEP_DUR * 0.7), 1.0) - - # Draw image with grid overlay - img_x, img_y = 100, 140 - img_size = 420 - grid_n = 7 - - # Background "image" - frame[img_y : img_y + img_size, img_x : img_x + img_size] = (50, 55, 70) - - # Objects in the image - frame[img_y + 80 : img_y + 200, img_x + 50 : img_x + 180] = ( - 180, - 60, - 60, - ) # "car" - frame[img_y + 150 : img_y + 350, img_x + 250 : img_x + 330] = ( - 60, - 120, - 180, - ) # "person" - - # Grid lines - cell = img_size // grid_n - for i in range(grid_n + 1): - # Vertical - x = img_x + i * cell - frame[img_y : img_y + img_size, x : x + 1] = (100, 100, 120) - # Horizontal - y = img_y + i * cell - frame[y : y + 1, img_x : img_x + img_size] = (100, 100, 120) - - # Highlight cells containing object centers - car_phase = 0.3 - if progress > car_phase: - # Car center ~ cell (1, 1) - cx, cy = 1, 2 - hx = img_x + cx * cell - hy = img_y + cy * cell - frame[hy : hy + cell, hx : hx + cell] = np.clip( - frame[hy : hy + cell, hx : hx + cell].astype(int) + 40, 0, 255 - ).astype(np.uint8) - - person_phase = 0.5 - if progress > person_phase: - # Person center ~ cell (4, 4) - cx, cy = 4, 4 - hx = img_x + cx * cell - hy = img_y + cy * cell - frame[hy : hy + cell, hx : hx + cell] = np.clip( - frame[hy : hy + cell, hx : hx + cell].astype(int) + 40, 0, 255 - ).astype(np.uint8) - - # Bounding boxes predictions from cells - bbox_phase = 0.6 - if progress > bbox_phase: - # Car bbox - for tt in range(2): - frame[ - img_y + 78 - tt : img_y + 202 + tt, - img_x + 48 - tt : img_x + 48 - tt + 2, - ] = (255, 80, 80) - frame[ - img_y + 78 - tt : img_y + 202 + tt, - img_x + 182 + tt - 2 : img_x + 182 + tt, - ] = (255, 80, 80) - frame[ - img_y + 78 - tt : img_y + 78 - tt + 2, - img_x + 48 - tt : img_x + 182 + tt, - ] = (255, 80, 80) - frame[ - img_y + 202 + tt - 2 : img_y + 202 + tt, - img_x + 48 - tt : img_x + 182 + tt, - ] = (255, 80, 80) - - # Person bbox - for tt in range(2): - frame[ - img_y + 148 - tt : img_y + 352 + tt, - img_x + 248 - tt : img_x + 248 - tt + 2, - ] = (80, 180, 255) - frame[ - img_y + 148 - tt : img_y + 352 + tt, - img_x + 332 + tt - 2 : img_x + 332 + tt, - ] = (80, 180, 255) - frame[ - img_y + 148 - tt : img_y + 148 - tt + 2, - img_x + 248 - tt : img_x + 332 + tt, - ] = (80, 180, 255) - frame[ - img_y + 352 + tt - 2 : img_y + 352 + tt, - img_x + 248 - tt : img_x + 332 + tt, - ] = (80, 180, 255) - - return frame - - yolo_clip = VideoClip(make_yolo_frame, duration=STEP_DUR).with_fps(FPS) - text_clips: list[VideoClip] = [yolo_clip] - labels = [ - ("YOLO — You Only Look Once", 28, "#FFE082", FONT_B, (80, 20)), - ( - "Jednoetapowy detektor: siatka SxS → wszystkie detekcje naraz!", - 18, - "#B0BEC5", - FONT_R, - (80, 65), - ), - ("Siatka 7x7 = 49 komórek", 16, "#64B5F6", FONT_R, (600, 180)), - ("Każda komórka predykuje:", 16, "white", FONT_R, (600, 220)), - (" • B bbox (x, y, w, h, conf)", 14, "#B0BEC5", FONT_R, (600, 255)), - (" • C klas (prawdopodobieństwa)", 14, "#B0BEC5", FONT_R, (600, 285)), - ("Komórka odpowiada za obiekt", 14, "#A5D6A7", FONT_R, (600, 325)), - ("którego ŚRODEK w niej wpada", 14, "#A5D6A7", FONT_R, (600, 350)), - ("45-155 fps! (vs 5 fps Faster R-CNN)", 18, "#EF9A9A", FONT_B, (600, 400)), - ( - "Jedno przejście przez sieć → WSZYSTKIE detekcje naraz → NMS → wynik", - 14, - "#78909C", - FONT_R, - (80, 620), - ), - ( - "Two-stage (R-CNN): propozycje+klasyfikacja " - "| One-stage (YOLO): bez propozycji!", - 14, - "#90CAF9", - FONT_R, - (80, 655), - ), - ] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR) - .with_position(pos) - ) - text_clips.append(tc) - - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── YOLO Architecture Detail ────────────────────────────────────── -def _yolo_architecture() -> list[CompositeVideoClip]: - """Show YOLO architecture: backbone → head, output tensor.""" - slides = [] - - # Slide 1: YOLO architecture breakdown - def make_yolo_arch(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - progress = min(t / (STEP_DUR * 0.7), 1.0) - - # Pipeline: Image → Backbone → Neck → Head → SxSx(B*5+C) tensor - blocks = [ - ((60, 280), (100, 80), (50, 70, 90), "Obraz"), - ((200, 280), (100, 80), (70, 130, 200), "Backbone"), - ((340, 280), (100, 80), (200, 160, 80), "Neck"), - ((480, 280), (100, 80), (200, 100, 60), "Head"), - ((620, 280), (160, 80), (80, 200, 120), "SxSx(B*5+C)"), - ] - n_blocks = min(int(progress * 5) + 1, 5) - for i, ((bx, by), (bw, bh), color, _lbl) in enumerate(blocks): - if i < n_blocks: - frame[by : by + bh, bx : bx + bw] = color - frame[by : by + 2, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - arrow_limit = 4 - if i < arrow_limit: - ax = bx + bw + 5 - ay = by + bh // 2 - frame[ay - 1 : ay + 2, ax : ax + 25] = (150, 150, 170) - - # Output tensor breakdown (right side) - tensor_phase = 0.6 - if progress > tensor_phase: - # Show SxS grid - gx, gy = 850, 180 - gs = 120 - gn = 4 # simplified from 7 - gc = gs // gn - for r in range(gn): - for c in range(gn): - x = gx + c * gc - y = gy + r * gc - frame[y : y + gc - 1, x : x + gc - 1] = (40, 50, 65) - # Highlight one cell - frame[gy + gc : gy + 2 * gc - 1, gx + gc : gx + 2 * gc - 1] = (80, 200, 120) - - return frame - - arch_clip = VideoClip(make_yolo_arch, duration=STEP_DUR + 1).with_fps(FPS) - dur = STEP_DUR + 1 - labels = [ - ("YOLO: Architektura — krok po kroku", 26, "#FFE082", FONT_B, (80, 20)), - ( - "One-stage: JEDEN forward pass → WSZYSTKIE detekcje naraz", - 17, - "#B0BEC5", - FONT_R, - (80, 60), - ), - ("Obraz", 13, "white", FONT_R, (85, 295)), - ("Backbone", 13, "white", FONT_R, (215, 295)), - ("(ResNet/", 11, "#78909C", FONT_R, (210, 370)), - ("Darknet)", 11, "#78909C", FONT_R, (210, 390)), - ("Neck", 13, "white", FONT_R, (365, 295)), - ("(FPN/", 11, "#78909C", FONT_R, (360, 370)), - ("PANet)", 11, "#78909C", FONT_R, (360, 390)), - ("Head", 13, "white", FONT_R, (505, 295)), - ("(conv)", 11, "#78909C", FONT_R, (500, 370)), - ("Tensor wyjścia", 13, "#A5D6A7", FONT_R, (640, 295)), - ("Każda komórka SxS predykuje:", 15, "#FFE082", FONT_R, (830, 320)), - (" B bbox x (x,y,w,h,conf)", 13, "#B0BEC5", FONT_R, (830, 350)), - (" + C klas (prob.)", 13, "#B0BEC5", FONT_R, (830, 375)), - ("= SxSx(Bx5+C) tensor", 13, "#A5D6A7", FONT_R, (830, 400)), - ("Np. 7x7x(2x5+20) = 7x7x30", 13, "#78909C", FONT_R, (830, 430)), - ( - "Two-stage (R-CNN): (1) propozycje → (2) klasyfikacja = 2 przejścia", - 15, - "#EF9A9A", - FONT_R, - (80, 470), - ), - ( - "One-stage (YOLO): siatka → predykcja all-in-one = 1 przejście!", - 15, - "#A5D6A7", - FONT_R, - (80, 505), - ), - ( - "Ewolucja YOLO: v1(2016)→v3→v5→v8(2023, anchor-free, SOTA)", - 16, - "#FFE082", - FONT_R, - (80, 555), - ), - ( - "SSD (2016): multi-scale feature maps → lepsza detekcja małych obiektów", - 15, - "#64B5F6", - FONT_R, - (80, 595), - ), - ( - "FPN: łączy wczesne warstwy (małe obiekty) + późne (duże obiekty)", - 15, - "#78909C", - FONT_R, - (80, 630), - ), - ] - text_clips: list[VideoClip] = [arch_clip] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(dur) - .with_position(pos) - ) - text_clips.append(tc) - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - - return slides - - -# ── DETR ────────────────────────────────────────────────────────── -def _detr_demo() -> list[CompositeVideoClip]: - """Animate DETR: transformer detection, object queries, no NMS.""" - slides = [] - - # Slide 1: DETR pipeline - def make_detr_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - progress = min(t / (STEP_DUR * 0.7), 1.0) - - # DETR pipeline: Image → Backbone → Encoder → Decoder → N predictions - blocks = [ - ((50, 260), (80, 60), (50, 70, 90)), - ((170, 260), (90, 60), (70, 130, 200)), - ((300, 260), (110, 60), (200, 120, 60)), - ((450, 260), (110, 60), (200, 80, 160)), - ((600, 260), (120, 60), (80, 200, 120)), - ] - n_blocks = min(int(progress * 5) + 1, 5) - for i, ((bx, by), (bw, bh), color) in enumerate(blocks): - if i < n_blocks: - frame[by : by + bh, bx : bx + bw] = color - frame[by : by + 2, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple( - min(c + 50, 255) for c in color - ) - arrow_limit = 4 - if i < arrow_limit: - ax = bx + bw + 5 - ay = by + bh // 2 - frame[ay - 1 : ay + 2, ax : ax + 25] = (150, 150, 170) - - # Object queries illustration (right side) - query_phase = 0.5 - if progress > query_phase: - qx, qy = 800, 140 - for i in range(6): - y = qy + i * 50 - w = 130 - active_limit = 3 - active = i < active_limit - color = (80, 180, 120) if active else (60, 50, 50) - frame[y : y + 35, qx : qx + w] = color - frame[y : y + 1, qx : qx + w] = tuple(min(c + 40, 255) for c in color) - - # Arrow from decoder to queries - frame[285:288, 723:798] = (150, 150, 170) - - return frame - - detr_clip = VideoClip(make_detr_frame, duration=STEP_DUR + 1).with_fps(FPS) - dur = STEP_DUR + 1 - labels = [ - ("DETR: DEtection TRansformer (2020)", 26, "#FFE082", FONT_B, (80, 20)), - ( - "Radykalnie prostszy pipeline: BEZ anchorów, BEZ NMS!", - 17, - "#B0BEC5", - FONT_R, - (80, 60), - ), - ("Obraz", 12, "white", FONT_R, (65, 275)), - ("Backbone", 12, "white", FONT_R, (185, 275)), - ("Transformer", 12, "white", FONT_R, (310, 275)), - ("Encoder", 12, "white", FONT_R, (325, 295)), - ("Transformer", 12, "white", FONT_R, (460, 275)), - ("Decoder", 12, "white", FONT_R, (478, 295)), - ("N predykcji", 12, "white", FONT_R, (615, 275)), - ("Object Queries:", 14, "#FFE082", FONT_B, (800, 115)), - ("samochód 95%", 11, "white", FONT_R, (810, 148)), - ("pies 88%", 11, "white", FONT_R, (810, 198)), - ("rower 72%", 11, "white", FONT_R, (810, 248)), - ("brak", 11, "#78909C", FONT_R, (810, 298)), - ("brak", 11, "#78909C", FONT_R, (810, 348)), - ("brak", 11, "#78909C", FONT_R, (810, 398)), - ("100 wyuczonych queries", 13, "#FFE082", FONT_R, (800, 440)), - ("→ każdy 'szuka' obiektu", 13, "#FFE082", FONT_R, (800, 465)), - ] - text_clips: list[VideoClip] = [detr_clip] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(dur) - .with_position(pos) - ) - text_clips.append(tc) - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - - # Slide 2: Why no NMS + Hungarian matching - detr_details = [ - ("DETR: Dlaczego bez NMS? — krok po kroku", 24, "#FFE082", FONT_B, (80, 30)), - ( - "Problem NMS: duplikaty detekcji → ręcznie usuwaj post-hoc", - 16, - "#EF9A9A", - FONT_R, - (80, 90), - ), - ( - "DETR rozwiązanie: Hungarian matching (dopasowanie węgierskie)", - 17, - "#A5D6A7", - FONT_R, - (80, 130), - ), - ("", 10, "white", FONT_R, (80, 155)), - ("Jak to działa podczas TRENINGU:", 17, "white", FONT_B, (80, 180)), - (" 1. Sieć daje N=100 predykcji (queries)", 15, "#64B5F6", FONT_R, (100, 220)), - ( - " 2. Na obrazie jest np. 5 obiektów (ground truth)", - 15, - "#64B5F6", - FONT_R, - (100, 255), - ), - ( - " 3. Hungarian matching: optymalne dopasowanie 1:1", - 15, - "#FFE082", - FONT_R, - (100, 290), - ), - ( - " → query_1 ↔ gt_samochód (najlepsze dopasowanie)", - 14, - "#A5D6A7", - FONT_R, - (120, 325), - ), - (" → query_7 ↔ gt_pies", 14, "#A5D6A7", FONT_R, (120, 355)), - (" → query_3 ↔ gt_rower", 14, "#A5D6A7", FONT_R, (120, 385)), - ( - " → pozostałe 97 queries ↔ klasa 'brak obiektu'", - 14, - "#78909C", - FONT_R, - (120, 415), - ), - ( - " 4. Każdy obiekt ma DOKŁADNIE 1 predykcję → BRAK duplikatów!", - 15, - "#A5D6A7", - FONT_R, - (100, 455), - ), - ("", 10, "white", FONT_R, (100, 475)), - ( - "Self-attention w encoderze: cechy obrazu 'rozmawiają' ze sobą", - 15, - "#64B5F6", - FONT_R, - (80, 500), - ), - ( - "Cross-attention w decoderze: queries 'pytają' cechy obrazu", - 15, - "#CE93D8", - FONT_R, - (80, 535), - ), - ( - "→ query 'rozumie' który fragment obrazu to 'jego' obiekt", - 15, - "#FFE082", - FONT_R, - (80, 570), - ), - ( - "DETR = Detekcja Eliminująca Trikowe Redundancje (NMS, anchory)", - 16, - "#FFE082", - FONT_R, - (80, 620), - ), - ( - "Wada: wolniejszy trening (O(n²) attention) | Zaleta: prostszy pipeline!", - 15, - "#78909C", - FONT_R, - (80, 660), - ), - ] - slides.append(_text_slide(detr_details, duration=STEP_DUR + 1)) - - # Slide 3: Two-stage vs One-stage vs Transformer summary - summary_lines = [ - ( - "Podsumowanie: Two-stage vs One-stage vs Transformer", - 22, - "#FFE082", - FONT_B, - (80, 30), - ), - ("", 10, "white", FONT_R, (80, 55)), - ("TWO-STAGE (R-CNN family):", 18, "#EF9A9A", FONT_B, (80, 90)), - ( - " (1) Generuj propozycje → (2) Klasyfikuj per region", - 15, - "white", - FONT_R, - (100, 125), - ), - ( - " + Wysoka precyzja | - Wolniejsze (2 przejścia)", - 15, - "#78909C", - FONT_R, - (100, 155), - ), - ( - " R-CNN → Fast R-CNN → Faster R-CNN (0.2s)", - 15, - "#B0BEC5", - FONT_R, - (100, 185), - ), - ("", 10, "white", FONT_R, (80, 210)), - ("ONE-STAGE (YOLO, SSD):", 18, "#A5D6A7", FONT_B, (80, 240)), - ( - " Siatka → predykcja all-in-one (1 przejście)", - 15, - "white", - FONT_R, - (100, 275), - ), - ( - " + Bardzo szybkie (45-155 fps) | - Historycznie mniej precyzyjne", - 15, - "#78909C", - FONT_R, - (100, 305), - ), - ( - " YOLOv8 (2023): anchor-free, dorównuje two-stage!", - 15, - "#B0BEC5", - FONT_R, - (100, 335), - ), - ("", 10, "white", FONT_R, (80, 360)), - ("TRANSFORMER (DETR):", 18, "#CE93D8", FONT_B, (80, 390)), - ( - " Object queries + self-attention (globalny kontekst)", - 15, - "white", - FONT_R, - (100, 425), - ), - ( - " + Brak NMS/anchorów | - Wolniejszy trening (O(n²))", - 15, - "#78909C", - FONT_R, - (100, 455), - ), - ( - " Hungarian matching → 1:1 obiekt↔predykcja → brak duplikatów", - 15, - "#B0BEC5", - FONT_R, - (100, 485), - ), - ("", 10, "white", FONT_R, (80, 510)), - ( - "Trend: coraz prostsze pipeline, mniej ręcznych komponentów", - 17, - "white", - FONT_R, - (80, 540), - ), - ( - " R-CNN (SS+CNN+SVM+NMS) → YOLO " - "(backbone+head+NMS) → DETR (backbone+transformer)", - 14, - "#90CAF9", - FONT_R, - (80, 580), - ), - ( - "Metryki: mAP@0.5 (standard), mAP@0.5:0.95 (surowsza), IoU do dopasowania", - 15, - "#78909C", - FONT_R, - (80, 630), - ), - ] - slides.append(_text_slide(summary_lines, duration=STEP_DUR + 1)) - - return slides - - -# ── NMS + IoU ───────────────────────────────────────────────────── -def _nms_iou_demo() -> list[CompositeVideoClip]: - """Animate NMS and IoU concepts.""" - slides = [] - - def make_nms_frame(t: float) -> np.ndarray: - frame = np.zeros((H, W, 3), dtype=np.uint8) - frame[:] = BG_COLOR - - progress = min(t / (STEP_DUR * 0.7), 1.0) - - # Draw overlapping bounding boxes - ox, oy = 100, 200 - obj_w, obj_h = 150, 120 - - # Multiple overlapping detections for same object - boxes = [ - (ox, oy, obj_w, obj_h, 0.95, (255, 80, 80)), # best - (ox + 15, oy - 10, obj_w + 10, obj_h + 5, 0.90, (200, 60, 60)), - (ox - 10, oy + 5, obj_w - 5, obj_h + 10, 0.85, (160, 50, 50)), - ] - # Different object far away - boxes.append((ox + 350, oy + 50, 100, 100, 0.40, (80, 180, 255))) - - for i, (bx, by, bw, bh, _conf, color) in enumerate(boxes): - dc = color - nms_phase = 0.4 - nms_limit = 3 - if progress > nms_phase and i > 0 and i < nms_limit: - # After NMS, these get removed (shown as faded/crossed) - dc = (60, 40, 40) - - for tt in range(2): - frame[by - tt : by + bh + tt, bx - tt : bx - tt + 2] = dc - frame[by - tt : by + bh + tt, bx + bw + tt - 2 : bx + bw + tt] = dc - frame[by - tt : by - tt + 2, bx - tt : bx + bw + tt] = dc - frame[by + bh + tt - 2 : by + bh + tt, bx - tt : bx + bw + tt] = dc - - # IoU visualization on right side - iou_x, iou_y = 700, 200 - # Box A - frame[iou_y : iou_y + 100, iou_x : iou_x + 100] = (80, 80, 200) - # Box B (overlapping) - frame[iou_y + 40 : iou_y + 140, iou_x + 40 : iou_x + 140] = (200, 80, 80) - # Intersection highlighted - frame[iou_y + 40 : iou_y + 100, iou_x + 40 : iou_x + 100] = (200, 150, 200) - - return frame - - nms_clip = VideoClip(make_nms_frame, duration=STEP_DUR).with_fps(FPS) - text_clips: list[VideoClip] = [nms_clip] - labels = [ - ("NMS (Non-Maximum Suppression) + IoU", 28, "#FFE082", FONT_B, (80, 20)), - ( - "NMS = Najlepszy Ma Się dobrze — zachowaj najlepszą, usuń duplikaty", - 18, - "#B0BEC5", - FONT_R, - (80, 65), - ), - ("conf=0.95 ✓", 14, "#A5D6A7", FONT_B, (100, 340)), - ("0.90 ✗ IoU>0.5", 13, "#EF9A9A", FONT_R, (100, 365)), - ("0.85 ✗ IoU>0.5", 13, "#EF9A9A", FONT_R, (100, 390)), - ("0.40 ✓ INNY obiekt", 13, "#64B5F6", FONT_R, (100, 420)), - ("IoU = Intersection over Union", 18, "#FFE082", FONT_B, (700, 160)), - ("IoU = pole(∩) / pole(AUB)", 16, "white", FONT_R, (700, 380)), - ("Fioletowy = intersection", 14, "#CE93D8", FONT_R, (700, 410)), - ("IoU > 0.5 → TEN SAM obiekt → usuń", 14, "#EF9A9A", FONT_R, (700, 440)), - ("IoU < 0.5 → INNY obiekt → zachowaj", 14, "#A5D6A7", FONT_R, (700, 470)), - ( - "DETR: jedyny detektor BEZ NMS (Hungarian matching zamiast tego)", - 14, - "#78909C", - FONT_R, - (80, 620), - ), - ] - for text, fs, color, font, pos in labels: - tc = ( - _tc(text=text, font_size=fs, color=color, font=font) - .with_duration(STEP_DUR) - .with_position(pos) - ) - text_clips.append(tc) - - slides.append( - CompositeVideoClip(text_clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - ) - return slides - - -# ── Detector from Classifier ───────────────────────────────────── -def _detector_from_classifier() -> list[CompositeVideoClip]: - """Show 3 approaches to building a detector from a classifier.""" - slides = [] - - approaches = [ - ( - "Podejście 1: Sliding Window (NAJWOLNIEJSZE)", - [ - ("Okno przesuwa się po obrazie w wielu skalach", "#B0BEC5"), - ("Każde okno → klasyfikator (np. ResNet) → klasa + pewność", "#B0BEC5"), - ("~18 000 okien x 10ms = ~3 minuty na obraz!", "#EF9A9A"), - ("Mnemonik: WYCINAJ i PYTAJ — jak wycinanie ciasteczek", "#FFE082"), - ], - "SRF", - ), - ( - "Podejście 2: Region Proposals (= R-CNN)", - [ - ("Selective Search → ~2000 inteligentnych regionów", "#B0BEC5"), - ("Każdy region → CNN → wektor cech → SVM klasyfikuje", "#B0BEC5"), - ("~2000 x 10ms = ~20 sec — 9x szybciej!", "#64B5F6"), - ( - "Mnemonik: INTELIGENTNE CIĘCIE — wytnij tylko tam gdzie wiśnie", - "#FFE082", - ), - ], - "SRF", - ), - ( - "Podejście 3: Fine-tune backbone (NAJLEPSZE)", - [ - ( - "Pretrained backbone (ResNet) → odetnij FC → dodaj detection head", - "#B0BEC5", - ), - ( - "Detection head = głowica klasyfikacji + głowica regresji bbox", - "#B0BEC5", - ), - ("~0.2 sec/obraz, najlepsza jakość (mAP ~42%)", "#A5D6A7"), - ("Mnemonik: PRZESZCZEP GŁOWY — ten sam silnik, nowa głowa", "#FFE082"), - ], - "SRF", - ), - ] - - for title, points, _mnem in approaches: - lines = [ - (title, 24, "#FFE082", FONT_B, (80, 140)), - ] - for i, (text, color) in enumerate(points): - lines.append((f"• {text}", 18, color, FONT_R, (100, 220 + i * 50))) - - lines.append( - ( - "Detektor z klasyfikatora: SRF = Sliding → Region → Fine-tune", - 16, - "#78909C", - FONT_R, - (80, 520), - ) - ) - lines.append( - ( - "= Szukaj Ręcznie, Finalnie optymalizuj!", - 16, - "#90CAF9", - FONT_R, - (80, 550), - ) - ) - - slides.append(_text_slide(lines, duration=STEP_DUR)) - - return slides - - -def _text_slide( - lines: list[tuple[str, int, str, str, tuple[str | int, str | int]]], - duration: float = STEP_DUR, -) -> CompositeVideoClip: - bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration) - clips: list[VideoClip] = [bg] - for text, font_size, color, font, pos in lines: - tc = ( - _tc( - text=text, - font_size=font_size, - color=color, - font=font, - ) - .with_duration(duration) - .with_position(pos) - ) - clips.append(tc) - return CompositeVideoClip(clips, size=(W, H)).with_effects( - [FadeIn(0.3), FadeOut(0.3)] - ) - - -# ── Methods comparison ──────────────────────────────────────────── -def _methods_comparison() -> CompositeVideoClip: - bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(10.0) - title = ( - _tc( - text="Porównanie detektorów", - font_size=36, - color="white", - font=FONT_B, - ) - .with_duration(10.0) - .with_position(("center", 20)) - ) - - rows = [ - ("Model", "Rok", "Typ", "Szybkość", "Kluczowe"), - ("HOG+SVM", "2005", "Klasyczny", "~1 fps", "Gradient histogramy"), - ("Viola-Jones", "2001", "Klasyczny", "30+ fps", "Haar+Cascade"), - ("R-CNN", "2014", "Two-stage", "50 sec!", "CNN per region"), - ("Fast R-CNN", "2015", "Two-stage", "2 sec", "ROI Pooling"), - ("Faster R-CNN", "2015", "Two-stage", "5 fps", "RPN w sieci"), - ("YOLO", "2016", "One-stage", "45+ fps", "Siatka SxS"), - ("DETR", "2020", "Transformer", "~40 fps", "Bez NMS!"), - ] - - clips: list[VideoClip] = [bg, title] - for i, row in enumerate(rows): - y_pos = 75 + i * 72 - col_x = [40, 200, 280, 400, 530] - for j, cell in enumerate(row): - fs = 16 if i > 0 else 18 - color = "#64B5F6" if i == 0 else "#E0E0E0" - tc = ( - _tc( - text=cell, - font_size=fs, - color=color, - font=FONT_B if i == 0 else FONT_R, - ) - .with_duration(10.0) - .with_position((col_x[j], y_pos)) - ) - clips.append(tc) - - return CompositeVideoClip(clips, size=(W, H)).with_effects( - [FadeIn(0.5), FadeOut(0.5)] - ) +from _q24_common import FPS, OUTPUT, _logger, _make_header +from _q24_nms_final import ( + _detector_from_classifier, + _methods_comparison, + _nms_iou_demo, +) +from _q24_rcnn import ( + _rcnn_detailed, + _rcnn_evolution, + _roi_pooling_demo, +) +from _q24_rpn_yolo import _rpn_anchors_demo, _yolo_demo +from _q24_yolo_arch_detr import _detr_demo, _yolo_architecture +from moviepy import VideoClip, concatenate_videoclips # ── Main ────────────────────────────────────────────────────────── @@ -1923,4 +138,5 @@ def main() -> None: if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) main() diff --git a/python_pkg/screen_locker/_constants.py b/python_pkg/screen_locker/_constants.py new file mode 100644 index 0000000..d7f22cc --- /dev/null +++ b/python_pkg/screen_locker/_constants.py @@ -0,0 +1,36 @@ +"""Constants for the screen locker module.""" + +from __future__ import annotations + +from pathlib import Path + +# Validation limits for workout data +MAX_DISTANCE_KM = 100 +MAX_TIME_MINUTES = 600 +MAX_PACE_MIN_PER_KM = 20 +MIN_EXERCISE_NAME_LEN = 3 +MAX_SETS = 20 +MAX_REPS = 100 +MAX_WEIGHT_KG = 500 +SICK_LOCKOUT_SECONDS = 120 # 2 minutes wait when sick +SUBMIT_DELAY_DEMO = 30 +SUBMIT_DELAY_PRODUCTION = 180 +PHONE_PENALTY_DELAY_DEMO = 10 +PHONE_PENALTY_DELAY_PRODUCTION = 600 +ADB_TIMEOUT = 15 +STRONGLIFTS_DB_REMOTE = ( + "/data/data/com.stronglifts.app/databases/StrongLifts-Database-3" +) +SHUTDOWN_CONFIG_FILE = Path("/etc/shutdown-schedule.conf") +# Helper script path (relative to this file) +ADJUST_SHUTDOWN_SCRIPT = Path(__file__).resolve().parent / "adjust_shutdown_schedule.sh" +# State file to track sick day usage and original config values +SICK_DAY_STATE_FILE = Path(__file__).resolve().parent / "sick_day_state.json" + +STRENGTH_FIELDS: list[tuple[str, int]] = [ + ("Exercises (comma-separated):", 50), + ("Sets per exercise (comma-separated):", 20), + ("Reps (comma-sep, + for variable: 12+11+12):", 30), + ("Weight per exercise kg (comma-separated):", 20), + ("Total weight lifted (kg):", 15), +] diff --git a/python_pkg/screen_locker/_phone_verification.py b/python_pkg/screen_locker/_phone_verification.py new file mode 100644 index 0000000..5d158ef --- /dev/null +++ b/python_pkg/screen_locker/_phone_verification.py @@ -0,0 +1,203 @@ +"""Phone workout verification mixin using ADB and StrongLifts.""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor, as_completed +import contextlib +import logging +from pathlib import Path +import shutil +import socket +import sqlite3 +import subprocess +import tempfile + +from python_pkg.screen_locker._constants import ADB_TIMEOUT, STRONGLIFTS_DB_REMOTE + +_logger = logging.getLogger(__name__) + + +class PhoneVerificationMixin: + """Mixin providing phone-based workout verification via ADB.""" + + def _run_adb(self, args: list[str]) -> tuple[bool, str]: + """Run an ADB command and return success flag and stdout.""" + adb = shutil.which("adb") or "adb" + # When multiple devices are connected (e.g. USB + wireless), pin to + # the wireless device's serial to avoid "more than one device" errors. + _discovery_cmds = {"devices", "connect", "disconnect", "kill-server"} + serial = ( + self._get_wireless_serial() + if args and args[0] not in _discovery_cmds + else None + ) + serial_args = ["-s", serial] if serial else [] + try: + result = subprocess.run( + [adb, *serial_args, *args], + capture_output=True, + text=True, + timeout=ADB_TIMEOUT, + check=False, + ) + except (FileNotFoundError, OSError) as exc: + _logger.warning("ADB not available: %s", exc) + return False, "" + except subprocess.TimeoutExpired: + _logger.warning("ADB command timed out: %s", args) + return False, "" + return result.returncode == 0, result.stdout + + def _adb_shell( + self, + command: str, + *, + root: bool = False, + ) -> tuple[bool, str]: + """Run a shell command on the connected Android device.""" + if root: + return self._run_adb(["shell", "su", "-c", command]) + return self._run_adb(["shell", command]) + + def _get_wireless_serial(self) -> str | None: + """Return the serial (ip:port) of the first connected wireless ADB device. + + Used to pin ADB commands to the wireless device when multiple devices + (e.g. USB cable + wireless debugging) are simultaneously connected. + """ + success, output = self._run_adb(["devices"]) + if not success: + return None + for line in output.strip().split("\n")[1:]: + parts = line.split() + if parts and ":" in parts[0] and "device" in line and "offline" not in line: + return parts[0] + return None + + def _has_adb_device(self) -> bool: + """Return True if adb devices shows at least one connected device.""" + success, output = self._run_adb(["devices"]) + if not success: + return False + lines = output.strip().split("\n")[1:] + return any("device" in line and "offline" not in line for line in lines) + + def _try_adb_connect(self, address: str) -> bool: + """Run adb connect to address. Returns True on success.""" + _, output = self._run_adb(["connect", address]) + lower = output.lower() + return "connected" in lower and "unable" not in lower and "failed" not in lower + + def _get_local_subnet_prefix(self) -> str | None: + """Detect the local /24 network prefix (e.g. '192.168.1').""" + with ( + contextlib.suppress(OSError), + socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock, + ): + sock.connect(("8.8.8.8", 80)) + return ".".join(sock.getsockname()[0].split(".")[:3]) + return None + + def _try_wireless_reconnect(self) -> bool: + """Scan local /24 subnet on port 5555 and attempt ADB connect to phone.""" + prefix = self._get_local_subnet_prefix() + if prefix is None: + _logger.info("Could not determine local subnet for wireless scan") + return False + + def probe(i: int) -> bool: + ip = f"{prefix}.{i}" + with ( + contextlib.suppress(OSError), + socket.create_connection((ip, 5555), timeout=0.5), + ): + if self._try_adb_connect(f"{ip}:5555"): + return self._has_adb_device() + return False + + _logger.info("Scanning %s.1-254:5555 for phone...", prefix) + with ThreadPoolExecutor(max_workers=64) as executor: + for future in as_completed( + executor.submit(probe, i) for i in range(1, 255) + ): + if future.result(): + return True + return False + + def _is_phone_connected(self) -> bool: + """Check if an Android device is connected via ADB. + + If no device is visible, attempts wireless reconnection using the + stored phone IP/port config. USB-connected devices are detected + automatically by adb devices without any extra steps. + """ + if self._has_adb_device(): + return True + _logger.info("No ADB device detected — attempting wireless reconnect...") + return self._try_wireless_reconnect() + + def _pull_stronglifts_db(self) -> Path | None: + """Pull StrongLifts database from phone to a local temp file. + + Returns: + Path to the local copy, or None on failure. + """ + tmp = Path(tempfile.gettempdir()) / "stronglifts_check.db" + success, _ = self._adb_shell( + f"cat '{STRONGLIFTS_DB_REMOTE}' > /sdcard/_sl_tmp.db", + root=True, + ) + if not success: + return None + ok, _ = self._run_adb(["pull", "/sdcard/_sl_tmp.db", str(tmp)]) + if not ok: + return None + return tmp + + def _count_today_workouts(self, db_path: Path) -> int: + """Count today's workouts in a local copy of StrongLifts DB. + + Args: + db_path: Path to the locally-pulled StrongLifts database. + + Returns: + Number of workouts started today (local time). + """ + try: + conn = sqlite3.connect(str(db_path)) + try: + cursor = conn.execute( + "SELECT COUNT(*) FROM workouts " + "WHERE date(start / 1000, 'unixepoch', 'localtime') " + "= date('now', 'localtime')", + ) + row = cursor.fetchone() + return int(row[0]) if row else 0 + finally: + conn.close() + except (sqlite3.Error, ValueError, TypeError): + _logger.warning("Failed to query StrongLifts database") + return 0 + + def _verify_phone_workout(self) -> tuple[str, str]: + """Verify workout was recorded in StrongLifts on the phone. + + Returns: + Tuple of (status, message) where status is one of: + - "verified": Workout confirmed on phone. + - "not_verified": Phone connected but no workout found. + - "no_phone": No phone connected via ADB. + - "error": Could not access StrongLifts database. + """ + if not self._is_phone_connected(): + return "no_phone", "No phone connected via ADB" + local_db = self._pull_stronglifts_db() + if local_db is None: + return "error", "StrongLifts database not found on phone" + count = self._count_today_workouts(local_db) + if count > 0: + return ( + "verified", + f"Workout verified! ({count} session(s) found on phone)", + ) + return "not_verified", "No workout found on phone today" diff --git a/python_pkg/screen_locker/_shutdown.py b/python_pkg/screen_locker/_shutdown.py new file mode 100644 index 0000000..bb90a56 --- /dev/null +++ b/python_pkg/screen_locker/_shutdown.py @@ -0,0 +1,262 @@ +"""Shutdown schedule adjustment mixin for the screen locker.""" + +from __future__ import annotations + +from datetime import datetime, timezone +import json +import logging +import subprocess + +from python_pkg.screen_locker._constants import ( + ADJUST_SHUTDOWN_SCRIPT, + SHUTDOWN_CONFIG_FILE, + SICK_DAY_STATE_FILE, +) + +_logger = logging.getLogger(__name__) + + +class ShutdownMixin: + """Mixin providing shutdown schedule adjustment functionality.""" + + def _apply_earlier_shutdown(self, today: str) -> bool: + """Read config, save state, and write earlier shutdown hours.""" + config_values = self._read_shutdown_config() + if config_values is None: + return False + mon_wed_hour, thu_sun_hour, morning_end_hour = config_values + if not self._save_sick_day_state(today, mon_wed_hour, thu_sun_hour): + _logger.error("Failed to save state - aborting adjustment") + return False + new_mon_wed = max(18, mon_wed_hour - 1) + new_thu_sun = max(18, thu_sun_hour - 1) + return self._write_shutdown_config( + new_mon_wed, + new_thu_sun, + morning_end_hour, + ) + + def _adjust_shutdown_time_earlier(self) -> bool: + """Adjust shutdown schedule 1.5 hours earlier (stricter). + + This can only be used once per day. Original values are saved and + automatically restored when checked the next day. + + Returns True if successful, False otherwise. + """ + today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + self._restore_original_config_if_needed() + if self._sick_mode_used_today(): + _logger.warning("Sick mode already used today") + return False + try: + return self._apply_earlier_shutdown(today) + except (OSError, ValueError) as e: + _logger.warning("Failed to adjust shutdown time: %s", e) + return False + + def _adjust_shutdown_time_later(self) -> bool: + """Adjust shutdown schedule 2 hours later as workout reward. + + Returns True if successful, False otherwise. + """ + try: + config_values = self._read_shutdown_config() + if config_values is None: + return False + mon_wed_hour, thu_sun_hour, morning_end_hour = config_values + new_mon_wed = min(23, mon_wed_hour + 2) + new_thu_sun = min(23, thu_sun_hour + 2) + return self._write_shutdown_config( + new_mon_wed, + new_thu_sun, + morning_end_hour, + restore=True, + ) + except (OSError, ValueError) as e: + _logger.warning("Failed to adjust shutdown time for workout: %s", e) + return False + + def _sick_mode_used_today(self) -> bool: + """Check if sick mode was already used today.""" + if not SICK_DAY_STATE_FILE.exists(): + return False + + try: + with SICK_DAY_STATE_FILE.open() as f: + state = json.load(f) + today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + return state.get("date") == today + except (OSError, json.JSONDecodeError): + return False + + def _save_sick_day_state( + self, + date: str, + orig_mon_wed: int, + orig_thu_sun: int, + ) -> bool: + """Save sick day state with original config values. + + Returns True if saved successfully, False otherwise. + """ + state = { + "date": date, + "original_mon_wed_hour": orig_mon_wed, + "original_thu_sun_hour": orig_thu_sun, + } + try: + with SICK_DAY_STATE_FILE.open("w") as f: + json.dump(state, f, indent=2) + except OSError as e: + _logger.warning("Failed to save sick day state: %s", e) + return False + + _logger.info("Saved sick day state for %s", date) + return True + + def _load_sick_day_state(self) -> tuple[str, int, int] | None: + """Load sick day state file. + + Returns (date, orig_mon_wed_hour, orig_thu_sun_hour) or None. + """ + with SICK_DAY_STATE_FILE.open() as f: + state = json.load(f) + date = state.get("date") + orig_mw = state.get("original_mon_wed_hour") + orig_ts = state.get("original_thu_sun_hour") + if date is None or orig_mw is None or orig_ts is None: + return None + return (str(date), int(orig_mw), int(orig_ts)) + + def _write_restored_config( + self, + orig_mw: int, + orig_ts: int, + state_date: str, + ) -> None: + """Write restored config values and clean up state file.""" + config_values = self._read_shutdown_config() + if config_values: + _, _, morning_end = config_values + _logger.info( + "Restoring original shutdown config from %s", + state_date, + ) + self._write_shutdown_config( + orig_mw, + orig_ts, + morning_end, + restore=True, + ) + SICK_DAY_STATE_FILE.unlink() + _logger.info("Removed stale sick day state from %s", state_date) + + def _restore_original_config_if_needed(self) -> None: + """Restore original config if sick day state is from a previous day.""" + if not SICK_DAY_STATE_FILE.exists(): + return + try: + loaded = self._load_sick_day_state() + if loaded is None: + return + state_date, orig_mw, orig_ts = loaded + today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + if state_date != today: + self._write_restored_config(orig_mw, orig_ts, state_date) + except (OSError, json.JSONDecodeError) as e: + _logger.warning("Error checking sick day state: %s", e) + + def _read_shutdown_config(self) -> tuple[int, int, int] | None: + """Read shutdown config. Returns (mw_hour, ts_hour, me_hour) or None.""" + if not SHUTDOWN_CONFIG_FILE.exists(): + _logger.warning("Config not found: %s", SHUTDOWN_CONFIG_FILE) + return None + parsed: dict[str, int] = {} + keys = ("MON_WED_HOUR", "THU_SUN_HOUR", "MORNING_END_HOUR") + with SHUTDOWN_CONFIG_FILE.open() as f: + for line in f: + stripped = line.strip() + for key in keys: + if stripped.startswith(f"{key}="): + parsed[key] = int(stripped.split("=")[1]) + if len(parsed) < len(keys): + _logger.warning("Shutdown config missing required values") + return None + return ( + parsed["MON_WED_HOUR"], + parsed["THU_SUN_HOUR"], + parsed["MORNING_END_HOUR"], + ) + + def _build_shutdown_cmd( + self, + mon_wed: int, + thu_sun: int, + morning: int, + *, + restore: bool, + ) -> list[str]: + """Build the shutdown adjustment command.""" + cmd = ["/usr/bin/sudo", str(ADJUST_SHUTDOWN_SCRIPT)] + if restore: + cmd.append("--restore") + cmd.extend([str(mon_wed), str(thu_sun), str(morning)]) + return cmd + + def _write_shutdown_config( + self, + mon_wed_hour: int, + thu_sun_hour: int, + morning_end_hour: int, + *, + restore: bool = False, + ) -> bool: + """Write new shutdown config values using helper script. + + Args: + mon_wed_hour: Shutdown hour for Monday-Wednesday. + thu_sun_hour: Shutdown hour for Thursday-Sunday. + morning_end_hour: Morning end hour. + restore: If True, allows restoring to later times. + + Returns True if successful, False otherwise. + """ + if not ADJUST_SHUTDOWN_SCRIPT.exists(): + _logger.warning( + "Script not found: %s", + ADJUST_SHUTDOWN_SCRIPT, + ) + return False + cmd = self._build_shutdown_cmd( + mon_wed_hour, + thu_sun_hour, + morning_end_hour, + restore=restore, + ) + return self._run_shutdown_cmd(cmd, mon_wed_hour, thu_sun_hour) + + def _run_shutdown_cmd( + self, + cmd: list[str], + mon_wed_hour: int, + thu_sun_hour: int, + ) -> bool: + """Execute the shutdown adjustment command.""" + try: + result = subprocess.run( + cmd, + check=True, + capture_output=True, + text=True, + ) + except subprocess.SubprocessError as e: + _logger.warning("Failed to adjust shutdown config: %s", e) + return False + _logger.info( + "Adjusted shutdown: Mon-Wed=%d, Thu-Sun=%d. %s", + mon_wed_hour, + thu_sun_hour, + result.stdout.strip(), + ) + return True diff --git a/python_pkg/screen_locker/_ui_flows.py b/python_pkg/screen_locker/_ui_flows.py new file mode 100644 index 0000000..f6045d7 --- /dev/null +++ b/python_pkg/screen_locker/_ui_flows.py @@ -0,0 +1,294 @@ +"""UI flow methods mixin for the screen locker.""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +import contextlib +import tkinter as tk +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable + +from python_pkg.screen_locker._constants import ( + PHONE_PENALTY_DELAY_DEMO, + PHONE_PENALTY_DELAY_PRODUCTION, + SICK_LOCKOUT_SECONDS, +) + + +class UIFlowsMixin: + """Mixin providing UI flow logic for the screen locker.""" + + def ask_workout_done(self) -> None: + """Display the initial workout question dialog.""" + self.clear_container() + self._label("Did you work out today?", pady=30) + frame = self._button_row() + self._button( + frame, + "YES", + bg="#00aa00", + command=self.ask_workout_type, + ).pack(side="left", padx=20) + self._button( + frame, + "NO", + bg="#aa0000", + command=self.ask_if_sick, + ).pack(side="left", padx=20) + + def _start_phone_check(self) -> None: + """Check phone for today's workout immediately at startup.""" + self.clear_container() + self._label("Checking phone...", font_size=36, color="#ffaa00", pady=30) + self._text("Looking for today's workout in StrongLifts...", font_size=18) + executor = ThreadPoolExecutor(max_workers=1) + self._phone_future = executor.submit(self._verify_phone_workout) + executor.shutdown(wait=False) + self._poll_phone_check() + + def _poll_phone_check(self) -> None: + """Poll background phone check and route to result handler when done.""" + if self._phone_future is not None and self._phone_future.done(): + status, message = self._phone_future.result() + self._handle_startup_phone_result(status, message) + else: + self.root.after(500, self._poll_phone_check) + + def _handle_startup_phone_result(self, status: str, message: str) -> None: + """Route to appropriate screen based on startup phone check result.""" + if status == "verified": + self.workout_data["type"] = "phone_verified" + self.workout_data["source"] = message + self.clear_container() + self._label( + "\u2713 Workout Verified!", font_size=42, color="#00cc44", pady=30 + ) + self._text(message, font_size=20, color="#aaffaa") + self._text("Unlocking...", font_size=18, color="#888888") + unlock_delay = 1500 if self.demo_mode else 2000 + self.root.after(unlock_delay, self.unlock_screen) + elif status == "not_verified": + self.clear_container() + self._label("No Workout Found", font_size=36, color="#ff4444", pady=20) + self._text( + f"\u274c {message}\n\n" + "StrongLifts shows no workout today.\n" + "Go do your workout first!", + color="#ffaa00", + ) + frame = self._button_row() + self._button( + frame, + "TRY AGAIN", + bg="#0066cc", + command=self._start_phone_check, + width=12, + ).pack(side="left", padx=10) + self._button( + frame, + "I'm sick", + bg="#cc6600", + command=self.ask_if_sick, + width=12, + ).pack(side="left", padx=10) + else: + # no_phone or error — penalty timer, then proceed to logging form + self._show_phone_penalty(message, on_done=self.ask_workout_done) + + def ask_if_sick(self) -> None: + """Display sick day question dialog.""" + self.clear_container() + self._label("Are you sick?", pady=30) + self._text( + "If yes, shutdown time will be moved 1.5 hours earlier", + color="#ffaa00", + ) + self._sick_question_buttons() + + def _sick_question_buttons(self) -> None: + """Create the sick day yes/no buttons.""" + frame = self._button_row() + self._button( + frame, + "YES (sick)", + bg="#cc6600", + command=self.handle_sick_day, + width=12, + ).pack(side="left", padx=20) + self._button( + frame, + "NO", + bg="#aa0000", + command=self.lockout, + width=12, + ).pack(side="left", padx=20) + + def _get_sick_day_status(self) -> tuple[str, str]: + """Determine sick day status text and color.""" + if self._sick_mode_used_today(): + return "Shutdown time already adjusted today", "#ffaa00" + if self._adjust_shutdown_time_earlier(): + return ( + "Shutdown time moved 1.5 hours earlier \u2713\n(Will revert tomorrow)" + ), "#00aa00" + return "Could not adjust shutdown time (check permissions)", "#ff4444" + + def handle_sick_day(self) -> None: + """Handle sick day: adjust shutdown time and start 2-minute wait.""" + self.clear_container() + status_text, status_color = self._get_sick_day_status() + self._show_sick_day_ui(status_text, status_color) + self.sick_remaining_time = SICK_LOCKOUT_SECONDS + self._update_sick_countdown() + + def _show_sick_day_ui(self, status_text: str, status_color: str) -> None: + """Display sick day UI labels and countdown.""" + self._label("Sick Day Mode", color="#cc6600", pady=20) + self._text(status_text, color=status_color) + self._text( + "Please wait 2 minutes before unlocking...", + font_size=24, + pady=20, + ) + self.sick_countdown_label = self._label( + str(SICK_LOCKOUT_SECONDS), + font_size=80, + pady=30, + ) + + def _update_sick_countdown(self) -> None: + """Update the sick day countdown timer.""" + if self.sick_remaining_time > 0: + self.sick_countdown_label.config(text=str(self.sick_remaining_time)) + self.sick_remaining_time -= 1 + self.root.after(1000, self._update_sick_countdown) + else: + # Record sick day and unlock + self.workout_data["type"] = "sick_day" + self.workout_data["note"] = "Sick day - shutdown moved earlier" + self.unlock_screen() + + # ------------------------------------------------------------------ + # Lockout flow + # ------------------------------------------------------------------ + + def lockout(self) -> None: + """Display lockout screen with countdown timer.""" + self.clear_container() + self.lockout_label = self._label( + f"Go work out!\nLocked for {self.lockout_time} seconds", + font_size=48, + color="#ff4444", + pady=30, + ) + self.countdown_label = self._label( + str(self.lockout_time), + font_size=120, + pady=30, + ) + self.remaining_time = self.lockout_time + self.update_lockout_countdown() + + def update_lockout_countdown(self) -> None: + """Update the lockout countdown timer display.""" + if self.remaining_time > 0: + self.countdown_label.config(text=str(self.remaining_time)) + self.remaining_time -= 1 + self.root.after(1000, self.update_lockout_countdown) + else: + self.ask_workout_done() + + # ------------------------------------------------------------------ + # Phone penalty + # ------------------------------------------------------------------ + + def _attempt_unlock(self) -> None: + """Unlock screen after workout form submission.""" + self.unlock_screen() + + def _show_phone_penalty( + self, message: str, *, on_done: Callable[[], None] | None = None + ) -> None: + """Show penalty countdown when phone verification is unavailable.""" + self.clear_container() + self._phone_penalty_done_fn: Callable[[], None] = ( + on_done if on_done is not None else self.unlock_screen + ) + delay = ( + PHONE_PENALTY_DELAY_DEMO + if self.demo_mode + else PHONE_PENALTY_DELAY_PRODUCTION + ) + self._label( + "Cannot Verify Workout", + font_size=36, + color="#ff8800", + pady=20, + ) + self._text(message, color="#ffaa00") + self._text( + "Connect phone via ADB to skip this wait,\n" + "or wait for the penalty timer.\n\n" + "Note: Phone must be rooted and StrongLifts installed.", + font_size=18, + ) + self.phone_penalty_remaining = delay + self.phone_penalty_label = self._label( + str(delay), + font_size=80, + pady=20, + ) + self._update_phone_penalty() + + def _update_phone_penalty(self) -> None: + """Update phone penalty countdown.""" + if self.phone_penalty_remaining > 0: + self.phone_penalty_label.config( + text=str(self.phone_penalty_remaining), + ) + self.phone_penalty_remaining -= 1 + self.root.after(1000, self._update_phone_penalty) + else: + self._phone_penalty_done_fn() + + # ------------------------------------------------------------------ + # Submit timer and entry checking + # ------------------------------------------------------------------ + + def _tick_submit_timer(self) -> None: + """Decrement submit timer and schedule next tick.""" + self.timer_label.config( + text=f"Submit available in {self.submit_unlock_time} seconds...", + ) + self.submit_unlock_time -= 1 + self.root.after(1000, self.update_submit_timer) + + def _try_enable_submit(self) -> None: + """Enable submit button if all entries are filled.""" + all_filled = all(entry.get().strip() for entry in self.entries_to_check) + if all_filled: + self.submit_btn.config( + text="SUBMIT", + state="normal", + bg="#00aa00", + command=self.submit_command, + ) + self.timer_label.config(text="You can now submit!") + else: + self.timer_label.config(text="Fill all fields to enable submit") + self.root.after(1000, self.check_entries_filled) + + def update_submit_timer(self) -> None: + """Update countdown timer and check if submit can be enabled.""" + with contextlib.suppress(tk.TclError): + if self.submit_unlock_time > 0: + self._tick_submit_timer() + else: + self._try_enable_submit() + + def check_entries_filled(self) -> None: + """Continuously check if entries are filled after timer expires.""" + with contextlib.suppress(tk.TclError): + self._try_enable_submit() diff --git a/python_pkg/screen_locker/_workout_forms.py b/python_pkg/screen_locker/_workout_forms.py new file mode 100644 index 0000000..e3a43c3 --- /dev/null +++ b/python_pkg/screen_locker/_workout_forms.py @@ -0,0 +1,269 @@ +"""Workout form methods mixin for the screen locker.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from python_pkg.screen_locker._constants import ( + MAX_DISTANCE_KM, + MAX_PACE_MIN_PER_KM, + MAX_REPS, + MAX_SETS, + MAX_TIME_MINUTES, + MAX_WEIGHT_KG, + MIN_EXERCISE_NAME_LEN, + STRENGTH_FIELDS, +) + +if TYPE_CHECKING: + import tkinter as tk + + +class WorkoutFormsMixin: + """Mixin providing workout form creation and validation.""" + + # ------------------------------------------------------------------ + # Workout type selection + # ------------------------------------------------------------------ + + def ask_workout_type(self) -> None: + """Display workout type selection dialog.""" + self.clear_container() + self._label("What type of workout?", pady=30) + frame = self._button_row() + self._button( + frame, + "STRENGTH", + bg="#cc6600", + command=self.ask_strength_details, + width=12, + ).pack(side="left", padx=20) + + # ------------------------------------------------------------------ + # Running workout + # ------------------------------------------------------------------ + + def _create_running_entries(self) -> list[tk.Entry]: + """Create running workout entry fields.""" + self.distance_entry = self._entry_row("Distance (km):") + self.time_entry = self._entry_row("Time (minutes):") + self.pace_entry = self._entry_row("Pace (min/km):") + return [self.distance_entry, self.time_entry, self.pace_entry] + + def ask_running_details(self) -> None: + """Display running workout input form.""" + self.clear_container() + self.workout_data["type"] = "running" + self._label("Running Details", pady=20) + entries = self._create_running_entries() + self._setup_form_controls( + entries, + self.verify_running_data, + self.ask_workout_type, + ) + + def _check_running_ranges( + self, + distance: float, + time_mins: float, + pace: float, + ) -> str | None: + """Check if running values are in valid ranges.""" + if distance <= 0 or distance > MAX_DISTANCE_KM: + return f"Distance seems unrealistic (0-{MAX_DISTANCE_KM} km)" + if time_mins <= 0 or time_mins > MAX_TIME_MINUTES: + return f"Time seems unrealistic (0-{MAX_TIME_MINUTES} minutes)" + if pace <= 0 or pace > MAX_PACE_MIN_PER_KM: + return f"Pace seems unrealistic (0-{MAX_PACE_MIN_PER_KM} min/km)" + expected_pace = time_mins / distance + tolerance = expected_pace * 0.15 # 15% tolerance + if abs(pace - expected_pace) > tolerance: + return ( + f"Pace doesn't match! " + f"Expected ~{expected_pace:.2f} min/km, got {pace:.2f}" + ) + return None + + def _validate_running_input(self) -> tuple[float, float, float] | None: + """Parse and validate running input fields.""" + try: + distance = float(self.distance_entry.get()) + time_mins = float(self.time_entry.get()) + pace = float(self.pace_entry.get()) + except ValueError: + self.show_error("Please enter valid numbers") + return None + error = self._check_running_ranges(distance, time_mins, pace) + if error: + self.show_error(error) + return None + return distance, time_mins, pace + + def verify_running_data(self) -> None: + """Validate running workout data and unlock if valid.""" + result = self._validate_running_input() + if result is None: + return + distance, time_mins, pace = result + self.workout_data["distance_km"] = str(distance) + self.workout_data["time_minutes"] = str(time_mins) + self.workout_data["pace_min_per_km"] = str(pace) + self._attempt_unlock() + + # ------------------------------------------------------------------ + # Strength workout + # ------------------------------------------------------------------ + + def _create_strength_entries(self) -> list[tk.Entry]: + """Create strength training entry fields.""" + entries = [ + self._entry_row(lbl, width=w, font_size=18) for lbl, w in STRENGTH_FIELDS + ] + ( + self.exercises_entry, + self.sets_entry, + self.reps_entry, + self.weights_entry, + self.total_weight_entry, + ) = entries + return entries + + def ask_strength_details(self) -> None: + """Display strength training input form.""" + self.clear_container() + self.workout_data["type"] = "strength" + self._label("Strength Training Details", pady=20) + entries = self._create_strength_entries() + self._setup_form_controls( + entries, + self.verify_strength_data, + self.ask_workout_type, + ) + + def _parse_reps(self, reps_raw: list[str]) -> list[list[int]]: + """Parse reps input - single number or variable reps like '12+11+12'.""" + reps: list[list[int]] = [] + for r in reps_raw: + if "+" in r: + reps.append([int(x.strip()) for x in r.split("+")]) + else: + reps.append([int(r)]) + return reps + + def _validate_strength_inputs( + self, + exercises: list[str], + sets: list[int], + reps: list[list[int]], + weights: list[float], + ) -> str | None: + """Validate strength workout inputs. Returns error message or None.""" + if not (len(exercises) == len(sets) == len(reps) == len(weights)): + return "Number of exercises, sets, reps, and weights must match" + if any(len(ex) < MIN_EXERCISE_NAME_LEN for ex in exercises): + return "Exercise names too short - be specific" + if any(s < 1 or s > MAX_SETS for s in sets): + return f"Sets should be between 1-{MAX_SETS}" + if any(w < 0 or w > MAX_WEIGHT_KG for w in weights): + return f"Weights should be between 0-{MAX_WEIGHT_KG} kg" + return self._validate_reps(exercises, sets, reps) + + def _validate_reps( + self, + exercises: list[str], + sets: list[int], + reps: list[list[int]], + ) -> str | None: + """Validate reps data. Returns error message or None if valid.""" + for i, rep_list in enumerate(reps): + if any(r < 1 or r > MAX_REPS for r in rep_list): + return f"Reps should be between 1-{MAX_REPS}" + if len(rep_list) > 1 and len(rep_list) != sets[i]: + return ( + f"For {exercises[i]!r}: variable reps count " + f"({len(rep_list)}) doesn't match sets ({sets[i]})" + ) + return None + + def _calculate_expected_total( + self, + sets: list[int], + reps: list[list[int]], + weights: list[float], + ) -> float: + """Calculate expected total weight lifted.""" + expected_total = 0.0 + for i, rep_list in enumerate(reps): + if len(rep_list) == 1: + expected_total += sets[i] * rep_list[0] * weights[i] + else: + expected_total += sum(rep_list) * weights[i] + return expected_total + + def _parse_strength_entries( + self, + ) -> tuple[list[str], list[int], list[list[int]], list[float], float]: + """Parse raw strength training input from entry widgets.""" + exercises = [e.strip() for e in self.exercises_entry.get().split(",")] + sets = [int(s.strip()) for s in self.sets_entry.get().split(",")] + reps_raw = [r.strip() for r in self.reps_entry.get().split(",")] + reps = self._parse_reps(reps_raw) + weights = [float(w.strip()) for w in self.weights_entry.get().split(",")] + total_weight = float(self.total_weight_entry.get()) + return exercises, sets, reps, weights, total_weight + + def _check_total_weight( + self, + sets: list[int], + reps: list[list[int]], + weights: list[float], + total_weight: float, + ) -> str | None: + """Verify total weight matches individual exercise calculations.""" + expected = self._calculate_expected_total(sets, reps, weights) + tolerance = expected * 0.15 # 15% tolerance + if abs(total_weight - expected) > tolerance: + return ( + f"Total weight doesn't match! " + f"Expected ~{expected:.1f} kg, got {total_weight:.1f}" + ) + return None + + def _store_strength_data( + self, + exercises: list[str], + sets: list[int], + reps: list[list[int]], + weights: list[float], + total_weight: float, + ) -> None: + """Store validated strength workout data.""" + self.workout_data["exercises"] = exercises + self.workout_data["sets"] = [str(s) for s in sets] + self.workout_data["reps"] = [ + "+".join(str(r) for r in rep_list) for rep_list in reps + ] + self.workout_data["weights_kg"] = [str(w) for w in weights] + self.workout_data["total_weight_kg"] = str(total_weight) + + def verify_strength_data(self) -> None: + """Validate strength workout data and unlock if valid.""" + try: + self._verify_strength_data_inner() + except ValueError: + self.show_error("Please enter valid data in correct format") + + def _verify_strength_data_inner(self) -> None: + """Parse, validate, and store strength data.""" + data = self._parse_strength_entries() + exercises, sets, reps, weights, total_weight = data + error = self._validate_strength_inputs(exercises, sets, reps, weights) + if error: + self.show_error(error) + return + total_err = self._check_total_weight(sets, reps, weights, total_weight) + if total_err: + self.show_error(total_err) + return + self._store_strength_data(exercises, sets, reps, weights, total_weight) + self._attempt_unlock() diff --git a/python_pkg/screen_locker/screen_lock.py b/python_pkg/screen_locker/screen_lock.py index f217786..3f376a3 100755 --- a/python_pkg/screen_locker/screen_lock.py +++ b/python_pkg/screen_locker/screen_lock.py @@ -6,59 +6,48 @@ Requires user to log their workout to unlock the screen. from __future__ import annotations -from concurrent.futures import Future, ThreadPoolExecutor, as_completed import contextlib from datetime import datetime, timezone import json import logging from pathlib import Path -import shutil -import socket -import sqlite3 -import subprocess import sys -import tempfile import tkinter as tk from typing import TYPE_CHECKING if TYPE_CHECKING: from collections.abc import Callable + from concurrent.futures import Future + +from python_pkg.screen_locker._constants import ( # noqa: F401 + MAX_DISTANCE_KM, + MAX_PACE_MIN_PER_KM, + MAX_REPS, + MAX_SETS, + MAX_TIME_MINUTES, + MAX_WEIGHT_KG, + MIN_EXERCISE_NAME_LEN, + PHONE_PENALTY_DELAY_DEMO, + PHONE_PENALTY_DELAY_PRODUCTION, + SICK_LOCKOUT_SECONDS, + STRONGLIFTS_DB_REMOTE, + SUBMIT_DELAY_DEMO, + SUBMIT_DELAY_PRODUCTION, +) +from python_pkg.screen_locker._phone_verification import PhoneVerificationMixin +from python_pkg.screen_locker._shutdown import ShutdownMixin +from python_pkg.screen_locker._ui_flows import UIFlowsMixin +from python_pkg.screen_locker._workout_forms import WorkoutFormsMixin _logger = logging.getLogger(__name__) -# Validation limits for workout data -MAX_DISTANCE_KM = 100 -MAX_TIME_MINUTES = 600 -MAX_PACE_MIN_PER_KM = 20 -MIN_EXERCISE_NAME_LEN = 3 -MAX_SETS = 20 -MAX_REPS = 100 -MAX_WEIGHT_KG = 500 -SICK_LOCKOUT_SECONDS = 120 # 2 minutes wait when sick -SUBMIT_DELAY_DEMO = 30 -SUBMIT_DELAY_PRODUCTION = 180 -PHONE_PENALTY_DELAY_DEMO = 10 -PHONE_PENALTY_DELAY_PRODUCTION = 600 -ADB_TIMEOUT = 15 -STRONGLIFTS_DB_REMOTE = ( - "/data/data/com.stronglifts.app/databases/StrongLifts-Database-3" -) -SHUTDOWN_CONFIG_FILE = Path("/etc/shutdown-schedule.conf") -# Helper script path (relative to this file) -ADJUST_SHUTDOWN_SCRIPT = Path(__file__).resolve().parent / "adjust_shutdown_schedule.sh" -# State file to track sick day usage and original config values -SICK_DAY_STATE_FILE = Path(__file__).resolve().parent / "sick_day_state.json" -_STRENGTH_FIELDS: list[tuple[str, int]] = [ - ("Exercises (comma-separated):", 50), - ("Sets per exercise (comma-separated):", 20), - ("Reps (comma-sep, + for variable: 12+11+12):", 30), - ("Weight per exercise kg (comma-separated):", 20), - ("Total weight lifted (kg):", 15), -] - - -class ScreenLocker: +class ScreenLocker( + ShutdownMixin, + PhoneVerificationMixin, + WorkoutFormsMixin, + UIFlowsMixin, +): """Screen locker that requires workout logging to unlock.""" def __init__(self, *, demo_mode: bool = True) -> None: @@ -262,957 +251,6 @@ class ScreenLocker: self.submit_command = verify_command self.update_submit_timer() - # ------------------------------------------------------------------ - # Main screen flows - # ------------------------------------------------------------------ - - def ask_workout_done(self) -> None: - """Display the initial workout question dialog.""" - self.clear_container() - self._label("Did you work out today?", pady=30) - frame = self._button_row() - self._button( - frame, - "YES", - bg="#00aa00", - command=self.ask_workout_type, - ).pack(side="left", padx=20) - self._button( - frame, - "NO", - bg="#aa0000", - command=self.ask_if_sick, - ).pack(side="left", padx=20) - - def _start_phone_check(self) -> None: - """Check phone for today's workout immediately at startup.""" - self.clear_container() - self._label("Checking phone...", font_size=36, color="#ffaa00", pady=30) - self._text("Looking for today's workout in StrongLifts...", font_size=18) - executor = ThreadPoolExecutor(max_workers=1) - self._phone_future = executor.submit(self._verify_phone_workout) - executor.shutdown(wait=False) - self._poll_phone_check() - - def _poll_phone_check(self) -> None: - """Poll background phone check and route to result handler when done.""" - if self._phone_future is not None and self._phone_future.done(): - status, message = self._phone_future.result() - self._handle_startup_phone_result(status, message) - else: - self.root.after(500, self._poll_phone_check) - - def _handle_startup_phone_result(self, status: str, message: str) -> None: - """Route to appropriate screen based on startup phone check result.""" - if status == "verified": - self.workout_data["type"] = "phone_verified" - self.workout_data["source"] = message - self.clear_container() - self._label( - "\u2713 Workout Verified!", font_size=42, color="#00cc44", pady=30 - ) - self._text(message, font_size=20, color="#aaffaa") - self._text("Unlocking...", font_size=18, color="#888888") - unlock_delay = 1500 if self.demo_mode else 2000 - self.root.after(unlock_delay, self.unlock_screen) - elif status == "not_verified": - self.clear_container() - self._label("No Workout Found", font_size=36, color="#ff4444", pady=20) - self._text( - f"\u274c {message}\n\n" - "StrongLifts shows no workout today.\n" - "Go do your workout first!", - color="#ffaa00", - ) - frame = self._button_row() - self._button( - frame, - "TRY AGAIN", - bg="#0066cc", - command=self._start_phone_check, - width=12, - ).pack(side="left", padx=10) - self._button( - frame, - "I'm sick", - bg="#cc6600", - command=self.ask_if_sick, - width=12, - ).pack(side="left", padx=10) - else: - # no_phone or error — penalty timer, then proceed to logging form - self._show_phone_penalty(message, on_done=self.ask_workout_done) - - def ask_if_sick(self) -> None: - """Display sick day question dialog.""" - self.clear_container() - self._label("Are you sick?", pady=30) - self._text( - "If yes, shutdown time will be moved 1.5 hours earlier", - color="#ffaa00", - ) - self._sick_question_buttons() - - def _sick_question_buttons(self) -> None: - """Create the sick day yes/no buttons.""" - frame = self._button_row() - self._button( - frame, - "YES (sick)", - bg="#cc6600", - command=self.handle_sick_day, - width=12, - ).pack(side="left", padx=20) - self._button( - frame, - "NO", - bg="#aa0000", - command=self.lockout, - width=12, - ).pack(side="left", padx=20) - - def _get_sick_day_status(self) -> tuple[str, str]: - """Determine sick day status text and color.""" - if self._sick_mode_used_today(): - return "Shutdown time already adjusted today", "#ffaa00" - if self._adjust_shutdown_time_earlier(): - return ( - "Shutdown time moved 1.5 hours earlier ✓\n(Will revert tomorrow)" - ), "#00aa00" - return "Could not adjust shutdown time (check permissions)", "#ff4444" - - def handle_sick_day(self) -> None: - """Handle sick day: adjust shutdown time and start 2-minute wait.""" - self.clear_container() - status_text, status_color = self._get_sick_day_status() - self._show_sick_day_ui(status_text, status_color) - self.sick_remaining_time = SICK_LOCKOUT_SECONDS - self._update_sick_countdown() - - def _show_sick_day_ui(self, status_text: str, status_color: str) -> None: - """Display sick day UI labels and countdown.""" - self._label("Sick Day Mode", color="#cc6600", pady=20) - self._text(status_text, color=status_color) - self._text( - "Please wait 2 minutes before unlocking...", - font_size=24, - pady=20, - ) - self.sick_countdown_label = self._label( - str(SICK_LOCKOUT_SECONDS), - font_size=80, - pady=30, - ) - - def _update_sick_countdown(self) -> None: - """Update the sick day countdown timer.""" - if self.sick_remaining_time > 0: - self.sick_countdown_label.config(text=str(self.sick_remaining_time)) - self.sick_remaining_time -= 1 - self.root.after(1000, self._update_sick_countdown) - else: - # Record sick day and unlock - self.workout_data["type"] = "sick_day" - self.workout_data["note"] = "Sick day - shutdown moved earlier" - self.unlock_screen() - - # ------------------------------------------------------------------ - # Shutdown schedule adjustment - # ------------------------------------------------------------------ - - def _apply_earlier_shutdown(self, today: str) -> bool: - """Read config, save state, and write earlier shutdown hours.""" - config_values = self._read_shutdown_config() - if config_values is None: - return False - mon_wed_hour, thu_sun_hour, morning_end_hour = config_values - if not self._save_sick_day_state(today, mon_wed_hour, thu_sun_hour): - _logger.error("Failed to save state - aborting adjustment") - return False - new_mon_wed = max(18, mon_wed_hour - 1) - new_thu_sun = max(18, thu_sun_hour - 1) - return self._write_shutdown_config( - new_mon_wed, - new_thu_sun, - morning_end_hour, - ) - - def _adjust_shutdown_time_earlier(self) -> bool: - """Adjust shutdown schedule 1.5 hours earlier (stricter). - - This can only be used once per day. Original values are saved and - automatically restored when checked the next day. - - Returns True if successful, False otherwise. - """ - today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - self._restore_original_config_if_needed() - if self._sick_mode_used_today(): - _logger.warning("Sick mode already used today") - return False - try: - return self._apply_earlier_shutdown(today) - except (OSError, ValueError) as e: - _logger.warning("Failed to adjust shutdown time: %s", e) - return False - - def _adjust_shutdown_time_later(self) -> bool: - """Adjust shutdown schedule 2 hours later as workout reward. - - Returns True if successful, False otherwise. - """ - try: - config_values = self._read_shutdown_config() - if config_values is None: - return False - mon_wed_hour, thu_sun_hour, morning_end_hour = config_values - new_mon_wed = min(23, mon_wed_hour + 2) - new_thu_sun = min(23, thu_sun_hour + 2) - return self._write_shutdown_config( - new_mon_wed, - new_thu_sun, - morning_end_hour, - restore=True, - ) - except (OSError, ValueError) as e: - _logger.warning("Failed to adjust shutdown time for workout: %s", e) - return False - - def _sick_mode_used_today(self) -> bool: - """Check if sick mode was already used today.""" - if not SICK_DAY_STATE_FILE.exists(): - return False - - try: - with SICK_DAY_STATE_FILE.open() as f: - state = json.load(f) - today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - return state.get("date") == today - except (OSError, json.JSONDecodeError): - return False - - def _save_sick_day_state( - self, - date: str, - orig_mon_wed: int, - orig_thu_sun: int, - ) -> bool: - """Save sick day state with original config values. - - Returns True if saved successfully, False otherwise. - """ - state = { - "date": date, - "original_mon_wed_hour": orig_mon_wed, - "original_thu_sun_hour": orig_thu_sun, - } - try: - with SICK_DAY_STATE_FILE.open("w") as f: - json.dump(state, f, indent=2) - except OSError as e: - _logger.warning("Failed to save sick day state: %s", e) - return False - - _logger.info("Saved sick day state for %s", date) - return True - - def _load_sick_day_state(self) -> tuple[str, int, int] | None: - """Load sick day state file. - - Returns (date, orig_mon_wed_hour, orig_thu_sun_hour) or None. - """ - with SICK_DAY_STATE_FILE.open() as f: - state = json.load(f) - date = state.get("date") - orig_mw = state.get("original_mon_wed_hour") - orig_ts = state.get("original_thu_sun_hour") - if date is None or orig_mw is None or orig_ts is None: - return None - return (str(date), int(orig_mw), int(orig_ts)) - - def _write_restored_config( - self, - orig_mw: int, - orig_ts: int, - state_date: str, - ) -> None: - """Write restored config values and clean up state file.""" - config_values = self._read_shutdown_config() - if config_values: - _, _, morning_end = config_values - _logger.info( - "Restoring original shutdown config from %s", - state_date, - ) - self._write_shutdown_config( - orig_mw, - orig_ts, - morning_end, - restore=True, - ) - SICK_DAY_STATE_FILE.unlink() - _logger.info("Removed stale sick day state from %s", state_date) - - def _restore_original_config_if_needed(self) -> None: - """Restore original config if sick day state is from a previous day.""" - if not SICK_DAY_STATE_FILE.exists(): - return - try: - loaded = self._load_sick_day_state() - if loaded is None: - return - state_date, orig_mw, orig_ts = loaded - today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - if state_date != today: - self._write_restored_config(orig_mw, orig_ts, state_date) - except (OSError, json.JSONDecodeError) as e: - _logger.warning("Error checking sick day state: %s", e) - - def _read_shutdown_config(self) -> tuple[int, int, int] | None: - """Read shutdown config. Returns (mw_hour, ts_hour, me_hour) or None.""" - if not SHUTDOWN_CONFIG_FILE.exists(): - _logger.warning("Config not found: %s", SHUTDOWN_CONFIG_FILE) - return None - parsed: dict[str, int] = {} - keys = ("MON_WED_HOUR", "THU_SUN_HOUR", "MORNING_END_HOUR") - with SHUTDOWN_CONFIG_FILE.open() as f: - for line in f: - stripped = line.strip() - for key in keys: - if stripped.startswith(f"{key}="): - parsed[key] = int(stripped.split("=")[1]) - if len(parsed) < len(keys): - _logger.warning("Shutdown config missing required values") - return None - return ( - parsed["MON_WED_HOUR"], - parsed["THU_SUN_HOUR"], - parsed["MORNING_END_HOUR"], - ) - - def _build_shutdown_cmd( - self, - mon_wed: int, - thu_sun: int, - morning: int, - *, - restore: bool, - ) -> list[str]: - """Build the shutdown adjustment command.""" - cmd = ["/usr/bin/sudo", str(ADJUST_SHUTDOWN_SCRIPT)] - if restore: - cmd.append("--restore") - cmd.extend([str(mon_wed), str(thu_sun), str(morning)]) - return cmd - - def _write_shutdown_config( - self, - mon_wed_hour: int, - thu_sun_hour: int, - morning_end_hour: int, - *, - restore: bool = False, - ) -> bool: - """Write new shutdown config values using helper script. - - Args: - mon_wed_hour: Shutdown hour for Monday-Wednesday. - thu_sun_hour: Shutdown hour for Thursday-Sunday. - morning_end_hour: Morning end hour. - restore: If True, allows restoring to later times. - - Returns True if successful, False otherwise. - """ - if not ADJUST_SHUTDOWN_SCRIPT.exists(): - _logger.warning( - "Script not found: %s", - ADJUST_SHUTDOWN_SCRIPT, - ) - return False - cmd = self._build_shutdown_cmd( - mon_wed_hour, - thu_sun_hour, - morning_end_hour, - restore=restore, - ) - return self._run_shutdown_cmd(cmd, mon_wed_hour, thu_sun_hour) - - def _run_shutdown_cmd( - self, - cmd: list[str], - mon_wed_hour: int, - thu_sun_hour: int, - ) -> bool: - """Execute the shutdown adjustment command.""" - try: - result = subprocess.run( - cmd, - check=True, - capture_output=True, - text=True, - ) - except subprocess.SubprocessError as e: - _logger.warning("Failed to adjust shutdown config: %s", e) - return False - _logger.info( - "Adjusted shutdown: Mon-Wed=%d, Thu-Sun=%d. %s", - mon_wed_hour, - thu_sun_hour, - result.stdout.strip(), - ) - return True - - # ------------------------------------------------------------------ - # Lockout flow - # ------------------------------------------------------------------ - - def lockout(self) -> None: - """Display lockout screen with countdown timer.""" - self.clear_container() - self.lockout_label = self._label( - f"Go work out!\nLocked for {self.lockout_time} seconds", - font_size=48, - color="#ff4444", - pady=30, - ) - self.countdown_label = self._label( - str(self.lockout_time), - font_size=120, - pady=30, - ) - self.remaining_time = self.lockout_time - self.update_lockout_countdown() - - def update_lockout_countdown(self) -> None: - """Update the lockout countdown timer display.""" - if self.remaining_time > 0: - self.countdown_label.config(text=str(self.remaining_time)) - self.remaining_time -= 1 - self.root.after(1000, self.update_lockout_countdown) - else: - self.ask_workout_done() - - # ------------------------------------------------------------------ - # Workout type selection - # ------------------------------------------------------------------ - - def ask_workout_type(self) -> None: - """Display workout type selection dialog.""" - self.clear_container() - self._label("What type of workout?", pady=30) - frame = self._button_row() - self._button( - frame, - "STRENGTH", - bg="#cc6600", - command=self.ask_strength_details, - width=12, - ).pack(side="left", padx=20) - - # ------------------------------------------------------------------ - # Running workout - # ------------------------------------------------------------------ - - def _create_running_entries(self) -> list[tk.Entry]: - """Create running workout entry fields.""" - self.distance_entry = self._entry_row("Distance (km):") - self.time_entry = self._entry_row("Time (minutes):") - self.pace_entry = self._entry_row("Pace (min/km):") - return [self.distance_entry, self.time_entry, self.pace_entry] - - def ask_running_details(self) -> None: - """Display running workout input form.""" - self.clear_container() - self.workout_data["type"] = "running" - self._label("Running Details", pady=20) - entries = self._create_running_entries() - self._setup_form_controls( - entries, - self.verify_running_data, - self.ask_workout_type, - ) - - def _check_running_ranges( - self, - distance: float, - time_mins: float, - pace: float, - ) -> str | None: - """Check if running values are in valid ranges.""" - if distance <= 0 or distance > MAX_DISTANCE_KM: - return f"Distance seems unrealistic (0-{MAX_DISTANCE_KM} km)" - if time_mins <= 0 or time_mins > MAX_TIME_MINUTES: - return f"Time seems unrealistic (0-{MAX_TIME_MINUTES} minutes)" - if pace <= 0 or pace > MAX_PACE_MIN_PER_KM: - return f"Pace seems unrealistic (0-{MAX_PACE_MIN_PER_KM} min/km)" - expected_pace = time_mins / distance - tolerance = expected_pace * 0.15 # 15% tolerance - if abs(pace - expected_pace) > tolerance: - return ( - f"Pace doesn't match! " - f"Expected ~{expected_pace:.2f} min/km, got {pace:.2f}" - ) - return None - - def _validate_running_input(self) -> tuple[float, float, float] | None: - """Parse and validate running input fields.""" - try: - distance = float(self.distance_entry.get()) - time_mins = float(self.time_entry.get()) - pace = float(self.pace_entry.get()) - except ValueError: - self.show_error("Please enter valid numbers") - return None - error = self._check_running_ranges(distance, time_mins, pace) - if error: - self.show_error(error) - return None - return distance, time_mins, pace - - def verify_running_data(self) -> None: - """Validate running workout data and unlock if valid.""" - result = self._validate_running_input() - if result is None: - return - distance, time_mins, pace = result - self.workout_data["distance_km"] = str(distance) - self.workout_data["time_minutes"] = str(time_mins) - self.workout_data["pace_min_per_km"] = str(pace) - self._attempt_unlock() - - # ------------------------------------------------------------------ - # Strength workout - # ------------------------------------------------------------------ - - def _create_strength_entries(self) -> list[tk.Entry]: - """Create strength training entry fields.""" - entries = [ - self._entry_row(lbl, width=w, font_size=18) for lbl, w in _STRENGTH_FIELDS - ] - ( - self.exercises_entry, - self.sets_entry, - self.reps_entry, - self.weights_entry, - self.total_weight_entry, - ) = entries - return entries - - def ask_strength_details(self) -> None: - """Display strength training input form.""" - self.clear_container() - self.workout_data["type"] = "strength" - self._label("Strength Training Details", pady=20) - entries = self._create_strength_entries() - self._setup_form_controls( - entries, - self.verify_strength_data, - self.ask_workout_type, - ) - - def _parse_reps(self, reps_raw: list[str]) -> list[list[int]]: - """Parse reps input - can be single number or variable reps like '12+11+12'.""" - reps: list[list[int]] = [] - for r in reps_raw: - if "+" in r: - reps.append([int(x.strip()) for x in r.split("+")]) - else: - reps.append([int(r)]) - return reps - - def _validate_strength_inputs( - self, - exercises: list[str], - sets: list[int], - reps: list[list[int]], - weights: list[float], - ) -> str | None: - """Validate strength workout inputs. Returns error message or None if valid.""" - if not (len(exercises) == len(sets) == len(reps) == len(weights)): - return "Number of exercises, sets, reps, and weights must match" - if any(len(ex) < MIN_EXERCISE_NAME_LEN for ex in exercises): - return "Exercise names too short - be specific" - if any(s < 1 or s > MAX_SETS for s in sets): - return f"Sets should be between 1-{MAX_SETS}" - if any(w < 0 or w > MAX_WEIGHT_KG for w in weights): - return f"Weights should be between 0-{MAX_WEIGHT_KG} kg" - return self._validate_reps(exercises, sets, reps) - - def _validate_reps( - self, - exercises: list[str], - sets: list[int], - reps: list[list[int]], - ) -> str | None: - """Validate reps data. Returns error message or None if valid.""" - for i, rep_list in enumerate(reps): - if any(r < 1 or r > MAX_REPS for r in rep_list): - return f"Reps should be between 1-{MAX_REPS}" - if len(rep_list) > 1 and len(rep_list) != sets[i]: - return ( - f"For {exercises[i]!r}: variable reps count " - f"({len(rep_list)}) doesn't match sets ({sets[i]})" - ) - return None - - def _calculate_expected_total( - self, - sets: list[int], - reps: list[list[int]], - weights: list[float], - ) -> float: - """Calculate expected total weight lifted.""" - expected_total = 0.0 - for i, rep_list in enumerate(reps): - if len(rep_list) == 1: - expected_total += sets[i] * rep_list[0] * weights[i] - else: - expected_total += sum(rep_list) * weights[i] - return expected_total - - def _parse_strength_entries( - self, - ) -> tuple[list[str], list[int], list[list[int]], list[float], float]: - """Parse raw strength training input from entry widgets.""" - exercises = [e.strip() for e in self.exercises_entry.get().split(",")] - sets = [int(s.strip()) for s in self.sets_entry.get().split(",")] - reps_raw = [r.strip() for r in self.reps_entry.get().split(",")] - reps = self._parse_reps(reps_raw) - weights = [float(w.strip()) for w in self.weights_entry.get().split(",")] - total_weight = float(self.total_weight_entry.get()) - return exercises, sets, reps, weights, total_weight - - def _check_total_weight( - self, - sets: list[int], - reps: list[list[int]], - weights: list[float], - total_weight: float, - ) -> str | None: - """Verify total weight matches individual exercise calculations.""" - expected = self._calculate_expected_total(sets, reps, weights) - tolerance = expected * 0.15 # 15% tolerance - if abs(total_weight - expected) > tolerance: - return ( - f"Total weight doesn't match! " - f"Expected ~{expected:.1f} kg, got {total_weight:.1f}" - ) - return None - - def _store_strength_data( - self, - exercises: list[str], - sets: list[int], - reps: list[list[int]], - weights: list[float], - total_weight: float, - ) -> None: - """Store validated strength workout data.""" - self.workout_data["exercises"] = exercises - self.workout_data["sets"] = [str(s) for s in sets] - self.workout_data["reps"] = [ - "+".join(str(r) for r in rep_list) for rep_list in reps - ] - self.workout_data["weights_kg"] = [str(w) for w in weights] - self.workout_data["total_weight_kg"] = str(total_weight) - - def verify_strength_data(self) -> None: - """Validate strength workout data and unlock if valid.""" - try: - self._verify_strength_data_inner() - except ValueError: - self.show_error("Please enter valid data in correct format") - - def _verify_strength_data_inner(self) -> None: - """Parse, validate, and store strength data.""" - data = self._parse_strength_entries() - exercises, sets, reps, weights, total_weight = data - error = self._validate_strength_inputs(exercises, sets, reps, weights) - if error: - self.show_error(error) - return - total_err = self._check_total_weight(sets, reps, weights, total_weight) - if total_err: - self.show_error(total_err) - return - self._store_strength_data(exercises, sets, reps, weights, total_weight) - self._attempt_unlock() - - # ------------------------------------------------------------------ - # Phone workout verification via ADB + StrongLifts DB - # ------------------------------------------------------------------ - - def _run_adb(self, args: list[str]) -> tuple[bool, str]: - """Run an ADB command and return success flag and stdout.""" - adb = shutil.which("adb") or "adb" - # When multiple devices are connected (e.g. USB + wireless), pin to - # the wireless device's serial to avoid "more than one device" errors. - _discovery_cmds = {"devices", "connect", "disconnect", "kill-server"} - serial = ( - self._get_wireless_serial() - if args and args[0] not in _discovery_cmds - else None - ) - serial_args = ["-s", serial] if serial else [] - try: - result = subprocess.run( - [adb, *serial_args, *args], - capture_output=True, - text=True, - timeout=ADB_TIMEOUT, - check=False, - ) - except (FileNotFoundError, OSError) as exc: - _logger.warning("ADB not available: %s", exc) - return False, "" - except subprocess.TimeoutExpired: - _logger.warning("ADB command timed out: %s", args) - return False, "" - return result.returncode == 0, result.stdout - - def _adb_shell( - self, - command: str, - *, - root: bool = False, - ) -> tuple[bool, str]: - """Run a shell command on the connected Android device.""" - if root: - return self._run_adb(["shell", "su", "-c", command]) - return self._run_adb(["shell", command]) - - def _get_wireless_serial(self) -> str | None: - """Return the serial (ip:port) of the first connected wireless ADB device. - - Used to pin ADB commands to the wireless device when multiple devices - (e.g. USB cable + wireless debugging) are simultaneously connected. - """ - success, output = self._run_adb(["devices"]) - if not success: - return None - for line in output.strip().split("\n")[1:]: - parts = line.split() - if parts and ":" in parts[0] and "device" in line and "offline" not in line: - return parts[0] - return None - - def _has_adb_device(self) -> bool: - """Return True if adb devices shows at least one connected device.""" - success, output = self._run_adb(["devices"]) - if not success: - return False - lines = output.strip().split("\n")[1:] - return any("device" in line and "offline" not in line for line in lines) - - def _try_adb_connect(self, address: str) -> bool: - """Run adb connect to address. Returns True on success.""" - _, output = self._run_adb(["connect", address]) - lower = output.lower() - return "connected" in lower and "unable" not in lower and "failed" not in lower - - def _get_local_subnet_prefix(self) -> str | None: - """Detect the local /24 network prefix (e.g. '192.168.1').""" - with ( - contextlib.suppress(OSError), - socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock, - ): - sock.connect(("8.8.8.8", 80)) - return ".".join(sock.getsockname()[0].split(".")[:3]) - return None - - def _try_wireless_reconnect(self) -> bool: - """Scan local /24 subnet on port 5555 and attempt ADB connect to phone.""" - prefix = self._get_local_subnet_prefix() - if prefix is None: - _logger.info("Could not determine local subnet for wireless scan") - return False - - def probe(i: int) -> bool: - ip = f"{prefix}.{i}" - with ( - contextlib.suppress(OSError), - socket.create_connection((ip, 5555), timeout=0.5), - ): - if self._try_adb_connect(f"{ip}:5555"): - return self._has_adb_device() - return False - - _logger.info("Scanning %s.1-254:5555 for phone...", prefix) - with ThreadPoolExecutor(max_workers=64) as executor: - for future in as_completed( - executor.submit(probe, i) for i in range(1, 255) - ): - if future.result(): - return True - return False - - def _is_phone_connected(self) -> bool: - """Check if an Android device is connected via ADB. - - If no device is visible, attempts wireless reconnection using the - stored phone IP/port config. USB-connected devices are detected - automatically by adb devices without any extra steps. - """ - if self._has_adb_device(): - return True - _logger.info("No ADB device detected — attempting wireless reconnect...") - return self._try_wireless_reconnect() - - def _pull_stronglifts_db(self) -> Path | None: - """Pull StrongLifts database from phone to a local temp file. - - Returns: - Path to the local copy, or None on failure. - """ - tmp = Path(tempfile.gettempdir()) / "stronglifts_check.db" - success, _ = self._adb_shell( - f"cat '{STRONGLIFTS_DB_REMOTE}' > /sdcard/_sl_tmp.db", - root=True, - ) - if not success: - return None - ok, _ = self._run_adb(["pull", "/sdcard/_sl_tmp.db", str(tmp)]) - if not ok: - return None - return tmp - - def _count_today_workouts(self, db_path: Path) -> int: - """Count today's workouts in a local copy of StrongLifts DB. - - Args: - db_path: Path to the locally-pulled StrongLifts database. - - Returns: - Number of workouts started today (local time). - """ - try: - conn = sqlite3.connect(str(db_path)) - try: - cursor = conn.execute( - "SELECT COUNT(*) FROM workouts " - "WHERE date(start / 1000, 'unixepoch', 'localtime') " - "= date('now', 'localtime')", - ) - row = cursor.fetchone() - return int(row[0]) if row else 0 - finally: - conn.close() - except (sqlite3.Error, ValueError, TypeError): - _logger.warning("Failed to query StrongLifts database") - return 0 - - def _verify_phone_workout(self) -> tuple[str, str]: - """Verify workout was recorded in StrongLifts on the phone. - - Returns: - Tuple of (status, message) where status is one of: - - "verified": Workout confirmed on phone. - - "not_verified": Phone connected but no workout found. - - "no_phone": No phone connected via ADB. - - "error": Could not access StrongLifts database. - """ - if not self._is_phone_connected(): - return "no_phone", "No phone connected via ADB" - local_db = self._pull_stronglifts_db() - if local_db is None: - return "error", "StrongLifts database not found on phone" - count = self._count_today_workouts(local_db) - if count > 0: - return ( - "verified", - f"Workout verified! ({count} session(s) found on phone)", - ) - return "not_verified", "No workout found on phone today" - - def _attempt_unlock(self) -> None: - """Unlock screen after workout form submission.""" - self.unlock_screen() - - def _show_phone_penalty( - self, message: str, *, on_done: Callable[[], None] | None = None - ) -> None: - """Show penalty countdown when phone verification is unavailable.""" - self.clear_container() - self._phone_penalty_done_fn: Callable[[], None] = ( - on_done if on_done is not None else self.unlock_screen - ) - delay = ( - PHONE_PENALTY_DELAY_DEMO - if self.demo_mode - else PHONE_PENALTY_DELAY_PRODUCTION - ) - self._label( - "Cannot Verify Workout", - font_size=36, - color="#ff8800", - pady=20, - ) - self._text(message, color="#ffaa00") - self._text( - "Connect phone via ADB to skip this wait,\n" - "or wait for the penalty timer.\n\n" - "Note: Phone must be rooted and StrongLifts installed.", - font_size=18, - ) - self.phone_penalty_remaining = delay - self.phone_penalty_label = self._label( - str(delay), - font_size=80, - pady=20, - ) - self._update_phone_penalty() - - def _update_phone_penalty(self) -> None: - """Update phone penalty countdown.""" - if self.phone_penalty_remaining > 0: - self.phone_penalty_label.config( - text=str(self.phone_penalty_remaining), - ) - self.phone_penalty_remaining -= 1 - self.root.after(1000, self._update_phone_penalty) - else: - self._phone_penalty_done_fn() - - # ------------------------------------------------------------------ - # Submit timer and entry checking - # ------------------------------------------------------------------ - - def _tick_submit_timer(self) -> None: - """Decrement submit timer and schedule next tick.""" - self.timer_label.config( - text=f"Submit available in {self.submit_unlock_time} seconds...", - ) - self.submit_unlock_time -= 1 - self.root.after(1000, self.update_submit_timer) - - def _try_enable_submit(self) -> None: - """Enable submit button if all entries are filled.""" - all_filled = all(entry.get().strip() for entry in self.entries_to_check) - if all_filled: - self.submit_btn.config( - text="SUBMIT", - state="normal", - bg="#00aa00", - command=self.submit_command, - ) - self.timer_label.config(text="You can now submit!") - else: - self.timer_label.config(text="Fill all fields to enable submit") - self.root.after(1000, self.check_entries_filled) - - def update_submit_timer(self) -> None: - """Update countdown timer and check if submit can be enabled.""" - with contextlib.suppress(tk.TclError): - if self.submit_unlock_time > 0: - self._tick_submit_timer() - else: - self._try_enable_submit() - - def check_entries_filled(self) -> None: - """Continuously check if entries are filled after timer expires.""" - with contextlib.suppress(tk.TclError): - self._try_enable_submit() - # ------------------------------------------------------------------ # Error, unlock, and logging # ------------------------------------------------------------------ diff --git a/python_pkg/screen_locker/tests/conftest.py b/python_pkg/screen_locker/tests/conftest.py new file mode 100644 index 0000000..9c8cc5f --- /dev/null +++ b/python_pkg/screen_locker/tests/conftest.py @@ -0,0 +1,113 @@ +"""Shared fixtures and helpers for screen_locker tests.""" + +from __future__ import annotations + +from pathlib import Path +import tkinter as tk +from typing import TYPE_CHECKING, NamedTuple +from unittest.mock import MagicMock, patch + +import pytest + +from python_pkg.screen_locker.screen_lock import ScreenLocker + +if TYPE_CHECKING: + from collections.abc import Generator + + +class RunningData(NamedTuple): + """Running workout data for tests.""" + + distance: str + time_mins: str + pace: str + + +class StrengthData(NamedTuple): + """Strength workout data for tests.""" + + exercises: str + sets: str + reps: str + weights: str + total_weight: str + + +@pytest.fixture +def mock_tk() -> Generator[MagicMock]: + """Mock tkinter module for testing without display.""" + with patch("python_pkg.screen_locker.screen_lock.tk") as mock: + # Set up Tk root mock + mock_root = MagicMock() + mock_root.winfo_screenwidth.return_value = 1920 + mock_root.winfo_screenheight.return_value = 1080 + mock.Tk.return_value = mock_root + + # Set up Frame mock + mock_frame = MagicMock() + mock_frame.winfo_children.return_value = [] + mock.Frame.return_value = mock_frame + + # Set up TclError as actual exception class + mock.TclError = tk.TclError + + yield mock + + +@pytest.fixture +def mock_sys_exit() -> Generator[MagicMock]: + """Mock sys.exit to prevent test termination.""" + with patch("python_pkg.screen_locker.screen_lock.sys.exit") as mock: + yield mock + + +@pytest.fixture +def _mock_sys_exit(mock_sys_exit: MagicMock) -> MagicMock: + """Alias for mock_sys_exit when the return value is unused.""" + return mock_sys_exit + + +@pytest.fixture +def temp_log_file(tmp_path: Path) -> Path: + """Create a temporary log file path.""" + return tmp_path / "workout_log.json" + + +def create_locker( + _mock_tk: MagicMock, + tmp_path: Path, + *, + demo_mode: bool = True, + has_logged: bool = False, +) -> ScreenLocker: + """Create a ScreenLocker instance for testing.""" + with ( + patch.object(Path, "resolve", return_value=tmp_path), + patch.object(ScreenLocker, "has_logged_today", return_value=has_logged), + patch.object(ScreenLocker, "_start_phone_check"), + ): + return ScreenLocker(demo_mode=demo_mode) + + +def setup_running_entries(locker: ScreenLocker, data: RunningData) -> None: + """Set up mock running entry widgets.""" + locker.distance_entry = MagicMock() + locker.distance_entry.get.return_value = data.distance + locker.time_entry = MagicMock() + locker.time_entry.get.return_value = data.time_mins + locker.pace_entry = MagicMock() + locker.pace_entry.get.return_value = data.pace + + +def setup_strength_entries(locker: ScreenLocker, data: StrengthData) -> None: + """Set up mock strength entry widgets.""" + locker.exercises_entry = MagicMock() + locker.exercises_entry.get.return_value = data.exercises + locker.sets_entry = MagicMock() + locker.sets_entry.get.return_value = data.sets + locker.reps_entry = MagicMock() + locker.reps_entry.get.return_value = data.reps + locker.weights_entry = MagicMock() + locker.weights_entry.get.return_value = data.weights + locker.total_weight_entry = MagicMock() + locker.total_weight_entry.get.return_value = data.total_weight diff --git a/python_pkg/screen_locker/tests/test_adb_and_phone.py b/python_pkg/screen_locker/tests/test_adb_and_phone.py new file mode 100644 index 0000000..9c2b5e1 --- /dev/null +++ b/python_pkg/screen_locker/tests/test_adb_and_phone.py @@ -0,0 +1,411 @@ +"""Tests for ADB commands, phone connection, and database operations.""" + +from __future__ import annotations + +import sqlite3 +import subprocess +import time +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +from python_pkg.screen_locker.screen_lock import STRONGLIFTS_DB_REMOTE +from python_pkg.screen_locker.tests.conftest import create_locker + +if TYPE_CHECKING: + from pathlib import Path + + +class TestRunAdb: + """Tests for _run_adb ADB command execution.""" + + def test_run_adb_success( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test successful ADB command.""" + locker = create_locker(mock_tk, tmp_path) + mock_result = MagicMock(returncode=0, stdout="ok\n") + with patch( + "python_pkg.screen_locker._phone_verification.subprocess.run", + return_value=mock_result, + ) as mock_run: + success, output = locker._run_adb(["devices"]) + + assert success is True + assert output == "ok\n" + mock_run.assert_called_once() + + def test_run_adb_failure( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test failed ADB command.""" + locker = create_locker(mock_tk, tmp_path) + mock_result = MagicMock(returncode=1, stdout="") + with patch( + "python_pkg.screen_locker._phone_verification.subprocess.run", + return_value=mock_result, + ): + success, _output = locker._run_adb(["devices"]) + + assert success is False + + def test_run_adb_not_found( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ADB binary not found.""" + locker = create_locker(mock_tk, tmp_path) + with patch( + "python_pkg.screen_locker._phone_verification.subprocess.run", + side_effect=FileNotFoundError("adb not found"), + ): + success, output = locker._run_adb(["devices"]) + + assert success is False + assert output == "" + + def test_run_adb_oserror( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ADB OSError.""" + locker = create_locker(mock_tk, tmp_path) + with patch( + "python_pkg.screen_locker._phone_verification.subprocess.run", + side_effect=OSError("permission denied"), + ): + success, output = locker._run_adb(["devices"]) + + assert success is False + assert output == "" + + def test_run_adb_timeout( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ADB command timeout.""" + locker = create_locker(mock_tk, tmp_path) + with patch( + "python_pkg.screen_locker._phone_verification.subprocess.run", + side_effect=subprocess.TimeoutExpired("adb", 15), + ): + success, output = locker._run_adb(["devices"]) + + assert success is False + assert output == "" + + +class TestAdbShell: + """Tests for _adb_shell method.""" + + def test_adb_shell_no_root( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ADB shell without root.""" + locker = create_locker(mock_tk, tmp_path) + locker._run_adb = MagicMock( # type: ignore[method-assign] + return_value=(True, "output"), + ) + + success, output = locker._adb_shell("ls /sdcard") + + locker._run_adb.assert_called_once_with(["shell", "ls /sdcard"]) + assert success is True + assert output == "output" + + def test_adb_shell_with_root( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ADB shell with root.""" + locker = create_locker(mock_tk, tmp_path) + locker._run_adb = MagicMock( # type: ignore[method-assign] + return_value=(True, "output"), + ) + + success, _output = locker._adb_shell("ls /data", root=True) + + locker._run_adb.assert_called_once_with( + ["shell", "su", "-c", "ls /data"], + ) + assert success is True + + +class TestIsPhoneConnected: + """Tests for _is_phone_connected method.""" + + def test_phone_connected( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test phone detected as connected.""" + locker = create_locker(mock_tk, tmp_path) + locker._run_adb = MagicMock( # type: ignore[method-assign] + return_value=( + True, + "List of devices attached\nABC123\tdevice\n\n", + ), + ) + + assert locker._is_phone_connected() is True + + def test_phone_not_connected( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test no phone connected.""" + locker = create_locker(mock_tk, tmp_path) + locker._run_adb = MagicMock( # type: ignore[method-assign] + return_value=(True, "List of devices attached\n\n"), + ) + locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign] + return_value=False, + ) + + assert locker._is_phone_connected() is False + + def test_phone_offline( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test phone connected but offline.""" + locker = create_locker(mock_tk, tmp_path) + locker._run_adb = MagicMock( # type: ignore[method-assign] + return_value=( + True, + "List of devices attached\nABC123\toffline\n\n", + ), + ) + locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign] + return_value=False, + ) + + assert locker._is_phone_connected() is False + + def test_adb_command_fails( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ADB command failure.""" + locker = create_locker(mock_tk, tmp_path) + locker._run_adb = MagicMock( # type: ignore[method-assign] + return_value=(False, ""), + ) + locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign] + return_value=False, + ) + + assert locker._is_phone_connected() is False + + +class TestFindHealthConnectDb: + """Tests for _pull_stronglifts_db method.""" + + def test_db_pulled_successfully( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test StrongLifts DB pulled from device.""" + locker = create_locker(mock_tk, tmp_path) + locker._adb_shell = MagicMock( # type: ignore[method-assign] + return_value=(True, ""), + ) + locker._run_adb = MagicMock( # type: ignore[method-assign] + return_value=(True, ""), + ) + + result = locker._pull_stronglifts_db() + + assert result is not None + locker._adb_shell.assert_called_once() + locker._run_adb.assert_called_once() + call_args = locker._run_adb.call_args[0][0] + assert call_args[0] == "pull" + + def test_db_cat_fails( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test returns None when cat command fails.""" + locker = create_locker(mock_tk, tmp_path) + locker._adb_shell = MagicMock( # type: ignore[method-assign] + return_value=(False, ""), + ) + + assert locker._pull_stronglifts_db() is None + + def test_db_pull_fails( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test returns None when adb pull fails.""" + locker = create_locker(mock_tk, tmp_path) + locker._adb_shell = MagicMock( # type: ignore[method-assign] + return_value=(True, ""), + ) + locker._run_adb = MagicMock( # type: ignore[method-assign] + return_value=(False, ""), + ) + + assert locker._pull_stronglifts_db() is None + + def test_db_uses_correct_remote_path( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test uses the correct StrongLifts DB remote path.""" + locker = create_locker(mock_tk, tmp_path) + locker._adb_shell = MagicMock( # type: ignore[method-assign] + return_value=(True, ""), + ) + locker._run_adb = MagicMock( # type: ignore[method-assign] + return_value=(True, ""), + ) + + locker._pull_stronglifts_db() + + shell_cmd = locker._adb_shell.call_args[0][0] + assert STRONGLIFTS_DB_REMOTE in shell_cmd + + +class TestCountTodayWorkouts: + """Tests for _count_today_workouts method.""" + + def test_workouts_found_today( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test workouts found today.""" + locker = create_locker(mock_tk, tmp_path) + db_file = tmp_path / "sl_test.db" + conn = sqlite3.connect(str(db_file)) + conn.execute( + "CREATE TABLE workouts " + "(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)", + ) + # Insert a workout with today's timestamp (ms) + now_ms = int(time.time() * 1000) + conn.execute( + "INSERT INTO workouts VALUES (?, ?, ?)", + ("w1", now_ms, now_ms + 3600000), + ) + conn.commit() + conn.close() + + assert locker._count_today_workouts(db_file) == 1 + + def test_no_workouts_today( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test no workouts today.""" + locker = create_locker(mock_tk, tmp_path) + db_file = tmp_path / "sl_test.db" + conn = sqlite3.connect(str(db_file)) + conn.execute( + "CREATE TABLE workouts " + "(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)", + ) + # Insert a workout from yesterday (24h+ ago) + yesterday_ms = int((time.time() - 200000) * 1000) + conn.execute( + "INSERT INTO workouts VALUES (?, ?, ?)", + ("w1", yesterday_ms, yesterday_ms + 3600000), + ) + conn.commit() + conn.close() + + assert locker._count_today_workouts(db_file) == 0 + + def test_invalid_db_returns_zero( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test returns 0 for invalid database file.""" + locker = create_locker(mock_tk, tmp_path) + bad_file = tmp_path / "not_a_db.db" + bad_file.write_text("not a database") + + assert locker._count_today_workouts(bad_file) == 0 + + def test_missing_table_returns_zero( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test returns 0 when workouts table doesn't exist.""" + locker = create_locker(mock_tk, tmp_path) + db_file = tmp_path / "empty.db" + conn = sqlite3.connect(str(db_file)) + conn.execute("CREATE TABLE other (id TEXT)") + conn.commit() + conn.close() + + assert locker._count_today_workouts(db_file) == 0 + + def test_multiple_workouts_today( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test counts multiple workouts today correctly.""" + locker = create_locker(mock_tk, tmp_path) + db_file = tmp_path / "sl_test.db" + conn = sqlite3.connect(str(db_file)) + conn.execute( + "CREATE TABLE workouts " + "(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)", + ) + now_ms = int(time.time() * 1000) + conn.execute( + "INSERT INTO workouts VALUES (?, ?, ?)", + ("w1", now_ms, now_ms + 3600000), + ) + conn.execute( + "INSERT INTO workouts VALUES (?, ?, ?)", + ("w2", now_ms + 100000, now_ms + 3700000), + ) + conn.commit() + conn.close() + + assert locker._count_today_workouts(db_file) == 2 diff --git a/python_pkg/screen_locker/tests/test_init_and_log.py b/python_pkg/screen_locker/tests/test_init_and_log.py new file mode 100644 index 0000000..feaa5c8 --- /dev/null +++ b/python_pkg/screen_locker/tests/test_init_and_log.py @@ -0,0 +1,390 @@ +"""Tests for screen_locker initialization, logging, and basic operations.""" + +from __future__ import annotations + +from datetime import datetime, timezone +import json +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock + +import pytest + +from python_pkg.screen_locker.screen_lock import ( + MAX_DISTANCE_KM, + MAX_PACE_MIN_PER_KM, + MAX_REPS, + MAX_SETS, + MAX_TIME_MINUTES, + MAX_WEIGHT_KG, + MIN_EXERCISE_NAME_LEN, +) +from python_pkg.screen_locker.tests.conftest import create_locker + +if TYPE_CHECKING: + from pathlib import Path + + +class TestConstants: + """Tests for module constants.""" + + def test_max_distance_km(self) -> None: + """Test MAX_DISTANCE_KM is reasonable.""" + assert MAX_DISTANCE_KM == 100 + assert MAX_DISTANCE_KM > 0 + + def test_max_time_minutes(self) -> None: + """Test MAX_TIME_MINUTES is reasonable.""" + assert MAX_TIME_MINUTES == 600 + assert MAX_TIME_MINUTES > 0 + + def test_max_pace_min_per_km(self) -> None: + """Test MAX_PACE_MIN_PER_KM is reasonable.""" + assert MAX_PACE_MIN_PER_KM == 20 + assert MAX_PACE_MIN_PER_KM > 0 + + def test_min_exercise_name_len(self) -> None: + """Test MIN_EXERCISE_NAME_LEN is reasonable.""" + assert MIN_EXERCISE_NAME_LEN == 3 + assert MIN_EXERCISE_NAME_LEN > 0 + + def test_max_sets(self) -> None: + """Test MAX_SETS is reasonable.""" + assert MAX_SETS == 20 + assert MAX_SETS > 0 + + def test_max_reps(self) -> None: + """Test MAX_REPS is reasonable.""" + assert MAX_REPS == 100 + assert MAX_REPS > 0 + + def test_max_weight_kg(self) -> None: + """Test MAX_WEIGHT_KG is reasonable.""" + assert MAX_WEIGHT_KG == 500 + assert MAX_WEIGHT_KG > 0 + + +class TestScreenLockerInit: + """Tests for ScreenLocker initialization.""" + + def test_init_demo_mode( + self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path + ) -> None: + """Test initialization in demo mode.""" + locker = create_locker(mock_tk, tmp_path, demo_mode=True) + + assert locker.demo_mode is True + assert locker.lockout_time == 10 + mock_sys_exit.assert_not_called() + + def test_init_production_mode( + self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path + ) -> None: + """Test initialization in production mode.""" + locker = create_locker(mock_tk, tmp_path, demo_mode=False) + + assert locker.demo_mode is False + assert locker.lockout_time == 1800 + mock_sys_exit.assert_not_called() + + def test_init_exits_if_logged_today( + self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path + ) -> None: + """Test that init exits early if workout logged today.""" + mock_sys_exit.side_effect = SystemExit(0) + + with pytest.raises(SystemExit): + create_locker(mock_tk, tmp_path, has_logged=True) + + mock_sys_exit.assert_called_once_with(0) + + +class TestHasLoggedToday: + """Tests for has_logged_today method.""" + + def test_no_log_file( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test when log file doesn't exist.""" + log_file = tmp_path / "workout_log.json" + locker = create_locker(mock_tk, tmp_path) + + locker.log_file = log_file + assert locker.has_logged_today() is False + + def test_empty_log_file( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test when log file is empty/invalid JSON.""" + log_file = tmp_path / "workout_log.json" + log_file.write_text("") + + locker = create_locker(mock_tk, tmp_path) + locker.log_file = log_file + assert locker.has_logged_today() is False + + def test_invalid_json( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test when log file contains invalid JSON.""" + log_file = tmp_path / "workout_log.json" + log_file.write_text("{invalid json}") + + locker = create_locker(mock_tk, tmp_path) + locker.log_file = log_file + assert locker.has_logged_today() is False + + def test_today_logged( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test when today's workout is logged.""" + log_file = tmp_path / "workout_log.json" + today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + log_file.write_text(json.dumps({today: {"workout": "data"}})) + + locker = create_locker(mock_tk, tmp_path) + locker.log_file = log_file + assert locker.has_logged_today() is True + + def test_other_day_logged( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test when only other days are logged.""" + log_file = tmp_path / "workout_log.json" + log_file.write_text(json.dumps({"2020-01-01": {"workout": "data"}})) + + locker = create_locker(mock_tk, tmp_path) + locker.log_file = log_file + assert locker.has_logged_today() is False + + +class TestSaveWorkoutLog: + """Tests for save_workout_log method.""" + + def test_save_to_new_file( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test saving to a new log file.""" + log_file = tmp_path / "workout_log.json" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = log_file + locker.workout_data = {"type": "running"} + locker.save_workout_log() + + assert log_file.exists() + with log_file.open() as f: + data: dict[str, Any] = json.load(f) + today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + assert today in data + assert data[today]["workout_data"]["type"] == "running" + + def test_save_to_existing_file( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test saving appends to existing log file.""" + log_file = tmp_path / "workout_log.json" + log_file.write_text(json.dumps({"2020-01-01": {"old": "data"}})) + + locker = create_locker(mock_tk, tmp_path) + locker.log_file = log_file + locker.workout_data = {"type": "strength"} + locker.save_workout_log() + + with log_file.open() as f: + data: dict[str, Any] = json.load(f) + assert "2020-01-01" in data + today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + assert today in data + + def test_save_with_corrupted_existing_file( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test saving when existing file is corrupted.""" + log_file = tmp_path / "workout_log.json" + log_file.write_text("not valid json") + + locker = create_locker(mock_tk, tmp_path) + locker.log_file = log_file + locker.workout_data = {"type": "running"} + locker.save_workout_log() + + with log_file.open() as f: + data: dict[str, Any] = json.load(f) + today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + assert today in data + + def test_save_with_write_error( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test saving handles write errors gracefully.""" + log_file = tmp_path / "nonexistent_dir" / "workout_log.json" + + locker = create_locker(mock_tk, tmp_path) + locker.log_file = log_file + locker.workout_data = {"type": "running"} + # Should not raise, just log warning + locker.save_workout_log() + + +class TestShowError: + """Tests for show_error method.""" + + def test_show_error_displays_message( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test show_error clears container and displays error.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + + locker.show_error("Test error message") + + locker.clear_container.assert_called_once() + + +class TestRun: + """Tests for run method.""" + + def test_run_starts_mainloop( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test run starts the tkinter mainloop.""" + locker = create_locker(mock_tk, tmp_path) + + locker.run() + + locker.root.mainloop.assert_called_once() # type: ignore[attr-defined] + + +class TestMainEntry: + """Tests for main entry point.""" + + def test_main_demo_mode_default( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test main defaults to demo mode.""" + locker = create_locker(mock_tk, tmp_path, demo_mode=True) + + assert locker.demo_mode is True + + def test_main_production_mode_flag( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test main with --production flag.""" + locker = create_locker(mock_tk, tmp_path, demo_mode=False) + + assert locker.demo_mode is False + + +class TestAdjustShutdownTimeLater: + """Tests for _adjust_shutdown_time_later method.""" + + def test_adjust_shutdown_time_later_success( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test _adjust_shutdown_time_later adds hours successfully.""" + locker = create_locker(mock_tk, tmp_path) + locker._read_shutdown_config = MagicMock( # type: ignore[method-assign] + return_value=(21, 22, 8) + ) + locker._write_shutdown_config = MagicMock( # type: ignore[method-assign] + return_value=True + ) + + result = locker._adjust_shutdown_time_later() + + assert result is True + locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True) + + def test_adjust_shutdown_time_later_caps_at_23( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test _adjust_shutdown_time_later caps hours at 23.""" + locker = create_locker(mock_tk, tmp_path) + locker._read_shutdown_config = MagicMock( # type: ignore[method-assign] + return_value=(22, 23, 8) + ) + locker._write_shutdown_config = MagicMock( # type: ignore[method-assign] + return_value=True + ) + + result = locker._adjust_shutdown_time_later() + + assert result is True + # 22+2=24 capped to 23, 23+2=25 capped to 23 + locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True) + + def test_adjust_shutdown_time_later_no_config( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test _adjust_shutdown_time_later returns False if config missing.""" + locker = create_locker(mock_tk, tmp_path) + locker._read_shutdown_config = MagicMock( # type: ignore[method-assign] + return_value=None + ) + + result = locker._adjust_shutdown_time_later() + + assert result is False + + def test_adjust_shutdown_time_later_oserror( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test _adjust_shutdown_time_later handles OSError.""" + locker = create_locker(mock_tk, tmp_path) + locker._read_shutdown_config = MagicMock( # type: ignore[method-assign] + side_effect=OSError("permission denied") + ) + + result = locker._adjust_shutdown_time_later() + + assert result is False diff --git a/python_pkg/screen_locker/tests/test_phone_check_unlock.py b/python_pkg/screen_locker/tests/test_phone_check_unlock.py new file mode 100644 index 0000000..599324d --- /dev/null +++ b/python_pkg/screen_locker/tests/test_phone_check_unlock.py @@ -0,0 +1,430 @@ +"""Tests for phone workout verification, phone check, and unlock operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock + +from python_pkg.screen_locker.screen_lock import ( + PHONE_PENALTY_DELAY_DEMO, + PHONE_PENALTY_DELAY_PRODUCTION, +) +from python_pkg.screen_locker.tests.conftest import create_locker + +if TYPE_CHECKING: + from pathlib import Path + + +class TestVerifyPhoneWorkout: + """Tests for _verify_phone_workout method.""" + + def test_verified( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test workout verified on phone.""" + locker = create_locker(mock_tk, tmp_path) + locker._is_phone_connected = MagicMock( # type: ignore[method-assign] + return_value=True, + ) + locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign] + return_value=tmp_path / "sl.db", + ) + locker._count_today_workouts = MagicMock( # type: ignore[method-assign] + return_value=2, + ) + + status, message = locker._verify_phone_workout() + + assert status == "verified" + assert "2 session" in message + + def test_not_verified( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test no workout found on phone.""" + locker = create_locker(mock_tk, tmp_path) + locker._is_phone_connected = MagicMock( # type: ignore[method-assign] + return_value=True, + ) + locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign] + return_value=tmp_path / "sl.db", + ) + locker._count_today_workouts = MagicMock( # type: ignore[method-assign] + return_value=0, + ) + + status, message = locker._verify_phone_workout() + + assert status == "not_verified" + assert "No workout" in message + + def test_no_phone( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test no phone connected.""" + locker = create_locker(mock_tk, tmp_path) + locker._is_phone_connected = MagicMock( # type: ignore[method-assign] + return_value=False, + ) + + status, _ = locker._verify_phone_workout() + + assert status == "no_phone" + + def test_error_no_db( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test error when StrongLifts DB cannot be pulled.""" + locker = create_locker(mock_tk, tmp_path) + locker._is_phone_connected = MagicMock( # type: ignore[method-assign] + return_value=True, + ) + locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign] + return_value=None, + ) + + status, message = locker._verify_phone_workout() + + assert status == "error" + assert "database" in message.lower() + + +class TestStartPhoneCheck: + """Tests for _start_phone_check and _handle_startup_phone_result.""" + + def test_start_phone_check_shows_checking_screen( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test _start_phone_check shows checking message and starts check.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + locker._verify_phone_workout = MagicMock( # type: ignore[method-assign] + return_value=("no_phone", "No phone"), + ) + locker._poll_phone_check = MagicMock() # type: ignore[method-assign] + + locker._start_phone_check() + + locker.clear_container.assert_called() + locker._poll_phone_check.assert_called_once() + assert locker._phone_future is not None + + def test_handle_startup_verified_unlocks_directly( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test verified result shows success screen then unlocks via after().""" + locker = create_locker(mock_tk, tmp_path) + locker.unlock_screen = MagicMock() # type: ignore[method-assign] + locker.root.after = MagicMock() # type: ignore[method-assign] + + locker._handle_startup_phone_result("verified", "Workout verified! (1 session)") + + # unlock_screen is deferred via root.after, not called directly + locker.unlock_screen.assert_not_called() + assert locker.workout_data["type"] == "phone_verified" + locker.root.after.assert_called_once_with(1500, locker.unlock_screen) + + def test_handle_startup_not_verified_shows_block( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test not_verified result shows blocking screen with buttons.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + locker._handle_startup_phone_result( + "not_verified", "No workout found on phone today" + ) + + locker.clear_container.assert_called() + + def test_handle_startup_no_phone_shows_penalty( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test no_phone result triggers penalty with ask_workout_done as callback.""" + locker = create_locker(mock_tk, tmp_path) + locker._show_phone_penalty = MagicMock() # type: ignore[method-assign] + + locker._handle_startup_phone_result("no_phone", "No phone") + + locker._show_phone_penalty.assert_called_once() + _, kwargs = locker._show_phone_penalty.call_args + assert kwargs["on_done"] == locker.ask_workout_done + + def test_handle_startup_error_shows_penalty( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test error result triggers penalty with ask_workout_done as callback.""" + locker = create_locker(mock_tk, tmp_path) + locker._show_phone_penalty = MagicMock() # type: ignore[method-assign] + + locker._handle_startup_phone_result("error", "DB not found") + + locker._show_phone_penalty.assert_called_once() + _, kwargs = locker._show_phone_penalty.call_args + assert kwargs["on_done"] == locker.ask_workout_done + + def test_poll_phone_check_schedules_retry_when_pending( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test _poll_phone_check reschedules itself when future is not done.""" + locker = create_locker(mock_tk, tmp_path) + mock_future: MagicMock = MagicMock() + mock_future.done.return_value = False + locker._phone_future = mock_future # type: ignore[assignment] + locker.root.after = MagicMock() # type: ignore[method-assign] + + locker._poll_phone_check() + + locker.root.after.assert_called_once_with(500, locker._poll_phone_check) + + def test_poll_phone_check_routes_when_done( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test _poll_phone_check calls result handler when future is done.""" + locker = create_locker(mock_tk, tmp_path) + mock_future: MagicMock = MagicMock() + mock_future.done.return_value = True + mock_future.result.return_value = ("no_phone", "No phone") + locker._phone_future = mock_future # type: ignore[assignment] + locker._handle_startup_phone_result = MagicMock() # type: ignore[method-assign] + + locker._poll_phone_check() + + locker._handle_startup_phone_result.assert_called_once_with( + "no_phone", "No phone" + ) + + +class TestAttemptUnlock: + """Tests for _attempt_unlock method.""" + + def test_attempt_unlock_calls_unlock_screen( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test _attempt_unlock calls unlock_screen directly.""" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "strength"} + locker.unlock_screen = MagicMock() # type: ignore[method-assign] + + locker._attempt_unlock() + + locker.unlock_screen.assert_called_once() + + +class TestShowPhonePenalty: + """Tests for _show_phone_penalty and _update_phone_penalty methods.""" + + def test_show_phone_penalty_demo_delay( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test demo mode uses short penalty delay.""" + locker = create_locker(mock_tk, tmp_path, demo_mode=True) + locker.clear_container = MagicMock() # type: ignore[method-assign] + + locker._show_phone_penalty("test message") + + # _update_phone_penalty is called once, decrementing by 1 + assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_DEMO - 1 + + def test_show_phone_penalty_production_delay( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test production mode uses long penalty delay.""" + locker = create_locker(mock_tk, tmp_path, demo_mode=False) + locker.clear_container = MagicMock() # type: ignore[method-assign] + + locker._show_phone_penalty("test message") + + assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_PRODUCTION - 1 + + def test_update_phone_penalty_countdown( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test phone penalty countdown decrements.""" + locker = create_locker(mock_tk, tmp_path) + locker.phone_penalty_remaining = 5 + locker.phone_penalty_label = MagicMock() + + locker._update_phone_penalty() + + assert locker.phone_penalty_remaining == 4 + locker.phone_penalty_label.config.assert_called_once_with(text="5") + locker.root.after.assert_called() # type: ignore[attr-defined] + + def test_update_phone_penalty_at_zero( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test phone penalty unlocks when timer reaches zero.""" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "strength"} + locker.phone_penalty_remaining = 0 + locker.phone_penalty_label = MagicMock() + locker.unlock_screen = MagicMock() # type: ignore[method-assign] + locker._phone_penalty_done_fn = locker.unlock_screen # type: ignore[attr-defined] + + locker._update_phone_penalty() + + locker.unlock_screen.assert_called_once() + + +class TestUnlockScreenShutdownAdjustment: + """Tests for unlock_screen shutdown time adjustment.""" + + def test_unlock_screen_adjusts_for_running( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test unlock_screen adjusts shutdown for running workout.""" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "running"} + locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] + return_value=True + ) + + locker.unlock_screen() + + locker._adjust_shutdown_time_later.assert_called_once() + + def test_unlock_screen_adjusts_for_strength( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test unlock_screen adjusts shutdown for strength workout.""" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "strength"} + locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] + return_value=True + ) + + locker.unlock_screen() + + locker._adjust_shutdown_time_later.assert_called_once() + + def test_unlock_screen_adjusts_for_phone_verified( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test unlock_screen adjusts shutdown for phone-verified workout.""" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "phone_verified"} + locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] + return_value=True + ) + + locker.unlock_screen() + + locker._adjust_shutdown_time_later.assert_called_once() + + def test_unlock_screen_skips_adjustment_for_sick_day( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test unlock_screen does not adjust for sick day.""" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "sick_day"} + locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] + return_value=True + ) + + locker.unlock_screen() + + locker._adjust_shutdown_time_later.assert_not_called() + + def test_unlock_screen_skips_adjustment_no_type( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test unlock_screen does not adjust when no workout type.""" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {} + locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] + return_value=True + ) + + locker.unlock_screen() + + locker._adjust_shutdown_time_later.assert_not_called() + + def test_unlock_screen_handles_adjustment_failure( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test unlock_screen continues when adjustment fails.""" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "running"} + locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] + return_value=False + ) + + # Should not raise, should continue with unlock + locker.unlock_screen() + + locker._adjust_shutdown_time_later.assert_called_once() + locker.root.after.assert_called() # type: ignore[attr-defined] diff --git a/python_pkg/screen_locker/tests/test_screen_lock.py b/python_pkg/screen_locker/tests/test_screen_lock.py deleted file mode 100644 index 19bd424..0000000 --- a/python_pkg/screen_locker/tests/test_screen_lock.py +++ /dev/null @@ -1,2079 +0,0 @@ -"""Comprehensive tests for screen_locker module. - -Tests cover: -- ScreenLocker initialization and configuration -- Workout data validation (running and strength) -- Log file operations (reading/writing) -- UI state transitions -- Timer logic -""" - -from __future__ import annotations - -from datetime import datetime, timezone -import json -from pathlib import Path -import sqlite3 -import subprocess -import tkinter as tk -from typing import TYPE_CHECKING, Any, NamedTuple -from unittest.mock import MagicMock, patch - -import pytest - -from python_pkg.screen_locker.screen_lock import ( - MAX_DISTANCE_KM, - MAX_PACE_MIN_PER_KM, - MAX_REPS, - MAX_SETS, - MAX_TIME_MINUTES, - MAX_WEIGHT_KG, - MIN_EXERCISE_NAME_LEN, - PHONE_PENALTY_DELAY_DEMO, - PHONE_PENALTY_DELAY_PRODUCTION, - STRONGLIFTS_DB_REMOTE, - SUBMIT_DELAY_DEMO, - SUBMIT_DELAY_PRODUCTION, - ScreenLocker, -) - -if TYPE_CHECKING: - from collections.abc import Generator - -# Reference tk to avoid import-but-unused error -_TK_TCLERROR = tk.TclError - - -class RunningData(NamedTuple): - """Running workout data for tests.""" - - distance: str - time_mins: str - pace: str - - -class StrengthData(NamedTuple): - """Strength workout data for tests.""" - - exercises: str - sets: str - reps: str - weights: str - total_weight: str - - -@pytest.fixture -def mock_tk() -> Generator[MagicMock]: - """Mock tkinter module for testing without display.""" - with patch("python_pkg.screen_locker.screen_lock.tk") as mock: - # Set up Tk root mock - mock_root = MagicMock() - mock_root.winfo_screenwidth.return_value = 1920 - mock_root.winfo_screenheight.return_value = 1080 - mock.Tk.return_value = mock_root - - # Set up Frame mock - mock_frame = MagicMock() - mock_frame.winfo_children.return_value = [] - mock.Frame.return_value = mock_frame - - # Set up TclError as actual exception class - mock.TclError = _TK_TCLERROR - - yield mock - - -@pytest.fixture -def mock_sys_exit() -> Generator[MagicMock]: - """Mock sys.exit to prevent test termination.""" - with patch("python_pkg.screen_locker.screen_lock.sys.exit") as mock: - yield mock - - -@pytest.fixture -def temp_log_file(tmp_path: Path) -> Path: - """Create a temporary log file path.""" - return tmp_path / "workout_log.json" - - -def create_locker( - _mock_tk: MagicMock, - tmp_path: Path, - *, - demo_mode: bool = True, - has_logged: bool = False, -) -> ScreenLocker: - """Create a ScreenLocker instance for testing.""" - with ( - patch.object(Path, "resolve", return_value=tmp_path), - patch.object(ScreenLocker, "has_logged_today", return_value=has_logged), - patch.object(ScreenLocker, "_start_phone_check"), - ): - return ScreenLocker(demo_mode=demo_mode) - - -def setup_running_entries(locker: ScreenLocker, data: RunningData) -> None: - """Set up mock running entry widgets.""" - locker.distance_entry = MagicMock() - locker.distance_entry.get.return_value = data.distance - locker.time_entry = MagicMock() - locker.time_entry.get.return_value = data.time_mins - locker.pace_entry = MagicMock() - locker.pace_entry.get.return_value = data.pace - - -def setup_strength_entries(locker: ScreenLocker, data: StrengthData) -> None: - """Set up mock strength entry widgets.""" - locker.exercises_entry = MagicMock() - locker.exercises_entry.get.return_value = data.exercises - locker.sets_entry = MagicMock() - locker.sets_entry.get.return_value = data.sets - locker.reps_entry = MagicMock() - locker.reps_entry.get.return_value = data.reps - locker.weights_entry = MagicMock() - locker.weights_entry.get.return_value = data.weights - locker.total_weight_entry = MagicMock() - locker.total_weight_entry.get.return_value = data.total_weight - - -class TestConstants: - """Tests for module constants.""" - - def test_max_distance_km(self) -> None: - """Test MAX_DISTANCE_KM is reasonable.""" - assert MAX_DISTANCE_KM == 100 - assert MAX_DISTANCE_KM > 0 - - def test_max_time_minutes(self) -> None: - """Test MAX_TIME_MINUTES is reasonable.""" - assert MAX_TIME_MINUTES == 600 - assert MAX_TIME_MINUTES > 0 - - def test_max_pace_min_per_km(self) -> None: - """Test MAX_PACE_MIN_PER_KM is reasonable.""" - assert MAX_PACE_MIN_PER_KM == 20 - assert MAX_PACE_MIN_PER_KM > 0 - - def test_min_exercise_name_len(self) -> None: - """Test MIN_EXERCISE_NAME_LEN is reasonable.""" - assert MIN_EXERCISE_NAME_LEN == 3 - assert MIN_EXERCISE_NAME_LEN > 0 - - def test_max_sets(self) -> None: - """Test MAX_SETS is reasonable.""" - assert MAX_SETS == 20 - assert MAX_SETS > 0 - - def test_max_reps(self) -> None: - """Test MAX_REPS is reasonable.""" - assert MAX_REPS == 100 - assert MAX_REPS > 0 - - def test_max_weight_kg(self) -> None: - """Test MAX_WEIGHT_KG is reasonable.""" - assert MAX_WEIGHT_KG == 500 - assert MAX_WEIGHT_KG > 0 - - -class TestScreenLockerInit: - """Tests for ScreenLocker initialization.""" - - def test_init_demo_mode( - self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path - ) -> None: - """Test initialization in demo mode.""" - locker = create_locker(mock_tk, tmp_path, demo_mode=True) - - assert locker.demo_mode is True - assert locker.lockout_time == 10 - mock_sys_exit.assert_not_called() - - def test_init_production_mode( - self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path - ) -> None: - """Test initialization in production mode.""" - locker = create_locker(mock_tk, tmp_path, demo_mode=False) - - assert locker.demo_mode is False - assert locker.lockout_time == 1800 - mock_sys_exit.assert_not_called() - - def test_init_exits_if_logged_today( - self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path - ) -> None: - """Test that init exits early if workout logged today.""" - mock_sys_exit.side_effect = SystemExit(0) - - with pytest.raises(SystemExit): - create_locker(mock_tk, tmp_path, has_logged=True) - - mock_sys_exit.assert_called_once_with(0) - - -class TestHasLoggedToday: - """Tests for has_logged_today method.""" - - def test_no_log_file( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test when log file doesn't exist.""" - log_file = tmp_path / "workout_log.json" - locker = create_locker(mock_tk, tmp_path) - - locker.log_file = log_file - assert locker.has_logged_today() is False - - def test_empty_log_file( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test when log file is empty/invalid JSON.""" - log_file = tmp_path / "workout_log.json" - log_file.write_text("") - - locker = create_locker(mock_tk, tmp_path) - locker.log_file = log_file - assert locker.has_logged_today() is False - - def test_invalid_json( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test when log file contains invalid JSON.""" - log_file = tmp_path / "workout_log.json" - log_file.write_text("{invalid json}") - - locker = create_locker(mock_tk, tmp_path) - locker.log_file = log_file - assert locker.has_logged_today() is False - - def test_today_logged( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test when today's workout is logged.""" - log_file = tmp_path / "workout_log.json" - today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - log_file.write_text(json.dumps({today: {"workout": "data"}})) - - locker = create_locker(mock_tk, tmp_path) - locker.log_file = log_file - assert locker.has_logged_today() is True - - def test_other_day_logged( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test when only other days are logged.""" - log_file = tmp_path / "workout_log.json" - log_file.write_text(json.dumps({"2020-01-01": {"workout": "data"}})) - - locker = create_locker(mock_tk, tmp_path) - locker.log_file = log_file - assert locker.has_logged_today() is False - - -class TestSaveWorkoutLog: - """Tests for save_workout_log method.""" - - def test_save_to_new_file( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test saving to a new log file.""" - log_file = tmp_path / "workout_log.json" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = log_file - locker.workout_data = {"type": "running"} - locker.save_workout_log() - - assert log_file.exists() - with log_file.open() as f: - data: dict[str, Any] = json.load(f) - today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - assert today in data - assert data[today]["workout_data"]["type"] == "running" - - def test_save_to_existing_file( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test saving appends to existing log file.""" - log_file = tmp_path / "workout_log.json" - log_file.write_text(json.dumps({"2020-01-01": {"old": "data"}})) - - locker = create_locker(mock_tk, tmp_path) - locker.log_file = log_file - locker.workout_data = {"type": "strength"} - locker.save_workout_log() - - with log_file.open() as f: - data: dict[str, Any] = json.load(f) - assert "2020-01-01" in data - today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - assert today in data - - def test_save_with_corrupted_existing_file( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test saving when existing file is corrupted.""" - log_file = tmp_path / "workout_log.json" - log_file.write_text("not valid json") - - locker = create_locker(mock_tk, tmp_path) - locker.log_file = log_file - locker.workout_data = {"type": "running"} - locker.save_workout_log() - - with log_file.open() as f: - data: dict[str, Any] = json.load(f) - today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - assert today in data - - def test_save_with_write_error( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test saving handles write errors gracefully.""" - log_file = tmp_path / "nonexistent_dir" / "workout_log.json" - - locker = create_locker(mock_tk, tmp_path) - locker.log_file = log_file - locker.workout_data = {"type": "running"} - # Should not raise, just log warning - locker.save_workout_log() - - -class TestVerifyRunningData: - """Tests for verify_running_data method.""" - - def test_valid_running_data( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test valid running data triggers unlock attempt.""" - locker = create_locker(mock_tk, tmp_path) - setup_running_entries(locker, RunningData("5", "25", "5")) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "running"} - locker._attempt_unlock = MagicMock() # type: ignore[method-assign] - - locker.verify_running_data() - - locker._attempt_unlock.assert_called_once() - - def test_invalid_distance_zero( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test zero distance is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_running_entries(locker, RunningData("0", "25", "5")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_running_data() - - locker.show_error.assert_called_once() - assert "Distance" in locker.show_error.call_args[0][0] - - def test_invalid_distance_too_high( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test distance over max is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_running_entries(locker, RunningData("150", "600", "4")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_running_data() - - locker.show_error.assert_called_once() - assert "Distance" in locker.show_error.call_args[0][0] - - def test_invalid_time_zero( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test zero time is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_running_entries(locker, RunningData("5", "0", "5")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_running_data() - - locker.show_error.assert_called_once() - assert "Time" in locker.show_error.call_args[0][0] - - def test_invalid_time_too_high( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test time over max is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_running_entries(locker, RunningData("5", "700", "5")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_running_data() - - locker.show_error.assert_called_once() - assert "Time" in locker.show_error.call_args[0][0] - - def test_invalid_pace_zero( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test zero pace is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_running_entries(locker, RunningData("5", "25", "0")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_running_data() - - locker.show_error.assert_called_once() - assert "Pace" in locker.show_error.call_args[0][0] - - def test_invalid_pace_too_high( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test pace over max is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_running_entries(locker, RunningData("5", "25", "25")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_running_data() - - locker.show_error.assert_called_once() - assert "Pace" in locker.show_error.call_args[0][0] - - def test_pace_mismatch( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test pace mismatch is rejected.""" - # 5km in 25 min should be 5 min/km, but we say 10 min/km - locker = create_locker(mock_tk, tmp_path) - setup_running_entries(locker, RunningData("5", "25", "10")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_running_data() - - locker.show_error.assert_called_once() - assert "Pace doesn't match" in locker.show_error.call_args[0][0] - - def test_invalid_number_format( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test non-numeric input is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_running_entries(locker, RunningData("abc", "25", "5")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_running_data() - - locker.show_error.assert_called_once() - assert "valid numbers" in locker.show_error.call_args[0][0] - - -class TestVerifyStrengthData: - """Tests for verify_strength_data method.""" - - def test_valid_strength_data( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test valid strength data triggers unlock attempt.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Squat", "3", "10", "50", "1500")) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "strength"} - locker._attempt_unlock = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker._attempt_unlock.assert_called_once() - - def test_valid_multiple_exercises( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test valid data with multiple exercises.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries( - locker, - StrengthData("Squat, Bench Press", "3, 3", "10, 8", "50, 40", "2460"), - ) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "strength"} - locker._attempt_unlock = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker._attempt_unlock.assert_called_once() - - def test_mismatched_list_lengths( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test mismatched list lengths are rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries( - locker, - StrengthData("Squat, Bench", "3", "10, 8", "50, 40", "2000"), - ) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "must match" in locker.show_error.call_args[0][0] - - def test_short_exercise_name( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test short exercise names are rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Sq", "3", "10", "50", "1500")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "too short" in locker.show_error.call_args[0][0] - - def test_invalid_sets_zero( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test zero sets is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Squat", "0", "10", "50", "0")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "Sets" in locker.show_error.call_args[0][0] - - def test_invalid_sets_too_high( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test sets over max is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Squat", "25", "10", "50", "12500")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "Sets" in locker.show_error.call_args[0][0] - - def test_invalid_reps_zero( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test zero reps is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Squat", "3", "0", "50", "0")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "Reps" in locker.show_error.call_args[0][0] - - def test_invalid_reps_too_high( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test reps over max is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Squat", "3", "150", "50", "22500")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "Reps" in locker.show_error.call_args[0][0] - - def test_invalid_weight_negative( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test negative weight is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Squat", "3", "10", "-10", "-300")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "Weights" in locker.show_error.call_args[0][0] - - def test_invalid_weight_too_high( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test weight over max is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Squat", "3", "10", "600", "18000")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "Weights" in locker.show_error.call_args[0][0] - - def test_total_weight_mismatch( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test total weight mismatch is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Squat", "3", "10", "50", "3000")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "Total weight doesn't match" in locker.show_error.call_args[0][0] - - def test_invalid_format( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test invalid format is rejected.""" - locker = create_locker(mock_tk, tmp_path) - setup_strength_entries(locker, StrengthData("Squat", "abc", "10", "50", "1500")) - locker.show_error = MagicMock() # type: ignore[method-assign] - - locker.verify_strength_data() - - locker.show_error.assert_called_once() - assert "valid data" in locker.show_error.call_args[0][0] - - -class TestUITransitions: - """Tests for UI state transitions.""" - - def test_clear_container( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test clear_container destroys all child widgets.""" - locker = create_locker(mock_tk, tmp_path) - - # Set up mock children - mock_child1 = MagicMock() - mock_child2 = MagicMock() - locker.container.winfo_children.return_value = [ # type: ignore[attr-defined] - mock_child1, - mock_child2, - ] - - locker.clear_container() - - mock_child1.destroy.assert_called_once() - mock_child2.destroy.assert_called_once() - - def test_unlock_screen_saves_and_schedules_close( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test unlock_screen saves log and schedules close.""" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "running"} - - locker.unlock_screen() - - # Check that after() was called to schedule close - locker.root.after.assert_called() # type: ignore[attr-defined] - - def test_lockout_starts_countdown( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test lockout initializes countdown timer.""" - locker = create_locker(mock_tk, tmp_path) - - locker.lockout() - - # lockout() sets remaining_time to lockout_time (10 in demo mode) - # then calls update_lockout_countdown() which decrements it by 1 - assert locker.remaining_time == 9 # 10 - 1 after first update - - def test_close_destroys_root_and_exits( - self, - mock_tk: MagicMock, - mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test close destroys root window and exits.""" - locker = create_locker(mock_tk, tmp_path) - - locker.close() - - locker.root.destroy.assert_called_once() # type: ignore[attr-defined] - mock_sys_exit.assert_called_with(0) - - -class TestTimerLogic: - """Tests for timer countdown logic.""" - - def test_update_lockout_countdown_decrements( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test countdown decrements remaining time.""" - locker = create_locker(mock_tk, tmp_path) - locker.remaining_time = 5 - locker.countdown_label = MagicMock() - - locker.update_lockout_countdown() - - assert locker.remaining_time == 4 - locker.root.after.assert_called_with( # type: ignore[attr-defined] - 1000, locker.update_lockout_countdown - ) - - def test_update_lockout_countdown_at_zero( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test countdown at zero returns to workout question.""" - locker = create_locker(mock_tk, tmp_path) - locker.remaining_time = 0 - locker.countdown_label = MagicMock() - locker.ask_workout_done = MagicMock() # type: ignore[method-assign] - - locker.update_lockout_countdown() - - locker.ask_workout_done.assert_called_once() - - def test_update_submit_timer_countdown( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test submit timer counts down.""" - locker = create_locker(mock_tk, tmp_path) - locker.submit_unlock_time = 5 - locker.timer_label = MagicMock() - locker.submit_btn = MagicMock() - locker.entries_to_check = [] - - locker.update_submit_timer() - - assert locker.submit_unlock_time == 4 - locker.root.after.assert_called() # type: ignore[attr-defined] - - def test_update_submit_timer_enables_when_filled( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test submit enabled when timer done and entries filled.""" - locker = create_locker(mock_tk, tmp_path) - locker.submit_unlock_time = 0 - locker.timer_label = MagicMock() - locker.submit_btn = MagicMock() - mock_entry = MagicMock() - mock_entry.get.return_value = "some value" - locker.entries_to_check = [mock_entry] - locker.submit_command = MagicMock() - - locker.update_submit_timer() - - locker.submit_btn.config.assert_called() - - def test_update_submit_timer_waits_for_entries( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test submit waits when entries not filled.""" - locker = create_locker(mock_tk, tmp_path) - locker.submit_unlock_time = 0 - locker.timer_label = MagicMock() - locker.submit_btn = MagicMock() - mock_entry = MagicMock() - mock_entry.get.return_value = "" # Empty entry - locker.entries_to_check = [mock_entry] - - locker.update_submit_timer() - - locker.root.after.assert_called_with( # type: ignore[attr-defined] - 1000, locker.check_entries_filled - ) - - def test_update_submit_timer_handles_tcl_error( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test timer handles TclError when widgets destroyed.""" - locker = create_locker(mock_tk, tmp_path) - locker.submit_unlock_time = 5 - locker.timer_label = MagicMock() - locker.timer_label.config.side_effect = _TK_TCLERROR("widget destroyed") - - # Should not raise - locker.update_submit_timer() - - def test_check_entries_filled_enables_submit( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test check_entries_filled enables submit when all filled.""" - locker = create_locker(mock_tk, tmp_path) - locker.timer_label = MagicMock() - locker.submit_btn = MagicMock() - mock_entry = MagicMock() - mock_entry.get.return_value = "value" - locker.entries_to_check = [mock_entry] - locker.submit_command = MagicMock() - - locker.check_entries_filled() - - locker.submit_btn.config.assert_called() - - def test_check_entries_filled_continues_waiting( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test check_entries_filled continues waiting when not filled.""" - locker = create_locker(mock_tk, tmp_path) - locker.timer_label = MagicMock() - locker.submit_btn = MagicMock() - mock_entry = MagicMock() - mock_entry.get.return_value = "" - locker.entries_to_check = [mock_entry] - - locker.check_entries_filled() - - locker.root.after.assert_called_with( # type: ignore[attr-defined] - 1000, locker.check_entries_filled - ) - - def test_check_entries_filled_handles_tcl_error( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test check_entries_filled handles TclError.""" - locker = create_locker(mock_tk, tmp_path) - locker.timer_label = MagicMock() - mock_entry = MagicMock() - mock_entry.get.side_effect = _TK_TCLERROR("widget destroyed") - locker.entries_to_check = [mock_entry] - - # Should not raise - locker.check_entries_filled() - - -class TestShowError: - """Tests for show_error method.""" - - def test_show_error_displays_message( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test show_error clears container and displays error.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - - locker.show_error("Test error message") - - locker.clear_container.assert_called_once() - - -class TestRun: - """Tests for run method.""" - - def test_run_starts_mainloop( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test run starts the tkinter mainloop.""" - locker = create_locker(mock_tk, tmp_path) - - locker.run() - - locker.root.mainloop.assert_called_once() # type: ignore[attr-defined] - - -class TestMainEntry: - """Tests for main entry point.""" - - def test_main_demo_mode_default( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test main defaults to demo mode.""" - locker = create_locker(mock_tk, tmp_path, demo_mode=True) - - assert locker.demo_mode is True - - def test_main_production_mode_flag( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test main with --production flag.""" - locker = create_locker(mock_tk, tmp_path, demo_mode=False) - - assert locker.demo_mode is False - - -class TestAskWorkoutType: - """Tests for ask_workout_type method.""" - - def test_ask_workout_type_creates_buttons( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ask_workout_type creates running and strength buttons.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - - locker.ask_workout_type() - - locker.clear_container.assert_called_once() - # Verify Label and Button were called - mock_tk.Label.assert_called() - mock_tk.Button.assert_called() - - -class TestAskRunningDetails: - """Tests for ask_running_details method.""" - - def test_ask_running_details_sets_workout_type( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ask_running_details sets workout type to running.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - locker.update_submit_timer = MagicMock() # type: ignore[method-assign] - - locker.ask_running_details() - - assert locker.workout_data["type"] == "running" - locker.clear_container.assert_called_once() - - def test_ask_running_details_creates_entry_fields( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ask_running_details creates entry fields.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - locker.update_submit_timer = MagicMock() # type: ignore[method-assign] - - locker.ask_running_details() - - # Verify Entry fields were created - mock_tk.Entry.assert_called() - assert hasattr(locker, "distance_entry") - assert hasattr(locker, "time_entry") - assert hasattr(locker, "pace_entry") - - def test_ask_running_details_sets_timer( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ask_running_details initializes submit timer.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - locker.update_submit_timer = MagicMock() # type: ignore[method-assign] - - locker.ask_running_details() - - assert locker.submit_unlock_time == SUBMIT_DELAY_DEMO - locker.update_submit_timer.assert_called_once() - - -class TestAskStrengthDetails: - """Tests for ask_strength_details method.""" - - def test_ask_strength_details_sets_workout_type( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ask_strength_details sets workout type to strength.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - locker.update_submit_timer = MagicMock() # type: ignore[method-assign] - - locker.ask_strength_details() - - assert locker.workout_data["type"] == "strength" - locker.clear_container.assert_called_once() - - def test_ask_strength_details_creates_entry_fields( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ask_strength_details creates entry fields.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - locker.update_submit_timer = MagicMock() # type: ignore[method-assign] - - locker.ask_strength_details() - - # Verify Entry fields were created - mock_tk.Entry.assert_called() - assert hasattr(locker, "exercises_entry") - assert hasattr(locker, "sets_entry") - assert hasattr(locker, "reps_entry") - assert hasattr(locker, "weights_entry") - assert hasattr(locker, "total_weight_entry") - - def test_ask_strength_details_sets_timer( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ask_strength_details initializes submit timer.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - locker.update_submit_timer = MagicMock() # type: ignore[method-assign] - - locker.ask_strength_details() - - assert locker.submit_unlock_time == SUBMIT_DELAY_DEMO - locker.update_submit_timer.assert_called_once() - - def test_ask_strength_details_production_timer( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test production mode uses longer submit delay.""" - locker = create_locker(mock_tk, tmp_path, demo_mode=False) - locker.clear_container = MagicMock() # type: ignore[method-assign] - locker.update_submit_timer = MagicMock() # type: ignore[method-assign] - - locker.ask_strength_details() - - assert locker.submit_unlock_time == SUBMIT_DELAY_PRODUCTION - - -class TestAskWorkoutDone: - """Tests for ask_workout_done method.""" - - def test_ask_workout_done_creates_buttons( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ask_workout_done creates yes/no buttons.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - - locker.ask_workout_done() - - locker.clear_container.assert_called_once() - mock_tk.Label.assert_called() - mock_tk.Button.assert_called() - - -class TestAdjustShutdownTimeLater: - """Tests for _adjust_shutdown_time_later method.""" - - def test_adjust_shutdown_time_later_success( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test _adjust_shutdown_time_later adds hours successfully.""" - locker = create_locker(mock_tk, tmp_path) - locker._read_shutdown_config = MagicMock( # type: ignore[method-assign] - return_value=(21, 22, 8) - ) - locker._write_shutdown_config = MagicMock( # type: ignore[method-assign] - return_value=True - ) - - result = locker._adjust_shutdown_time_later() - - assert result is True - locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True) - - def test_adjust_shutdown_time_later_caps_at_23( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test _adjust_shutdown_time_later caps hours at 23.""" - locker = create_locker(mock_tk, tmp_path) - locker._read_shutdown_config = MagicMock( # type: ignore[method-assign] - return_value=(22, 23, 8) - ) - locker._write_shutdown_config = MagicMock( # type: ignore[method-assign] - return_value=True - ) - - result = locker._adjust_shutdown_time_later() - - assert result is True - # 22+2=24 capped to 23, 23+2=25 capped to 23 - locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True) - - def test_adjust_shutdown_time_later_no_config( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test _adjust_shutdown_time_later returns False if config missing.""" - locker = create_locker(mock_tk, tmp_path) - locker._read_shutdown_config = MagicMock( # type: ignore[method-assign] - return_value=None - ) - - result = locker._adjust_shutdown_time_later() - - assert result is False - - def test_adjust_shutdown_time_later_oserror( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test _adjust_shutdown_time_later handles OSError.""" - locker = create_locker(mock_tk, tmp_path) - locker._read_shutdown_config = MagicMock( # type: ignore[method-assign] - side_effect=OSError("permission denied") - ) - - result = locker._adjust_shutdown_time_later() - - assert result is False - - -class TestRunAdb: - """Tests for _run_adb ADB command execution.""" - - def test_run_adb_success( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test successful ADB command.""" - locker = create_locker(mock_tk, tmp_path) - mock_result = MagicMock(returncode=0, stdout="ok\n") - with patch( - "python_pkg.screen_locker.screen_lock.subprocess.run", - return_value=mock_result, - ) as mock_run: - success, output = locker._run_adb(["devices"]) - - assert success is True - assert output == "ok\n" - mock_run.assert_called_once() - - def test_run_adb_failure( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test failed ADB command.""" - locker = create_locker(mock_tk, tmp_path) - mock_result = MagicMock(returncode=1, stdout="") - with patch( - "python_pkg.screen_locker.screen_lock.subprocess.run", - return_value=mock_result, - ): - success, _output = locker._run_adb(["devices"]) - - assert success is False - - def test_run_adb_not_found( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ADB binary not found.""" - locker = create_locker(mock_tk, tmp_path) - with patch( - "python_pkg.screen_locker.screen_lock.subprocess.run", - side_effect=FileNotFoundError("adb not found"), - ): - success, output = locker._run_adb(["devices"]) - - assert success is False - assert output == "" - - def test_run_adb_oserror( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ADB OSError.""" - locker = create_locker(mock_tk, tmp_path) - with patch( - "python_pkg.screen_locker.screen_lock.subprocess.run", - side_effect=OSError("permission denied"), - ): - success, output = locker._run_adb(["devices"]) - - assert success is False - assert output == "" - - def test_run_adb_timeout( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ADB command timeout.""" - locker = create_locker(mock_tk, tmp_path) - with patch( - "python_pkg.screen_locker.screen_lock.subprocess.run", - side_effect=subprocess.TimeoutExpired("adb", 15), - ): - success, output = locker._run_adb(["devices"]) - - assert success is False - assert output == "" - - -class TestAdbShell: - """Tests for _adb_shell method.""" - - def test_adb_shell_no_root( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ADB shell without root.""" - locker = create_locker(mock_tk, tmp_path) - locker._run_adb = MagicMock( # type: ignore[method-assign] - return_value=(True, "output"), - ) - - success, output = locker._adb_shell("ls /sdcard") - - locker._run_adb.assert_called_once_with(["shell", "ls /sdcard"]) - assert success is True - assert output == "output" - - def test_adb_shell_with_root( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ADB shell with root.""" - locker = create_locker(mock_tk, tmp_path) - locker._run_adb = MagicMock( # type: ignore[method-assign] - return_value=(True, "output"), - ) - - success, _output = locker._adb_shell("ls /data", root=True) - - locker._run_adb.assert_called_once_with( - ["shell", "su", "-c", "ls /data"], - ) - assert success is True - - -class TestIsPhoneConnected: - """Tests for _is_phone_connected method.""" - - def test_phone_connected( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test phone detected as connected.""" - locker = create_locker(mock_tk, tmp_path) - locker._run_adb = MagicMock( # type: ignore[method-assign] - return_value=( - True, - "List of devices attached\nABC123\tdevice\n\n", - ), - ) - - assert locker._is_phone_connected() is True - - def test_phone_not_connected( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test no phone connected.""" - locker = create_locker(mock_tk, tmp_path) - locker._run_adb = MagicMock( # type: ignore[method-assign] - return_value=(True, "List of devices attached\n\n"), - ) - locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign] - return_value=False, - ) - - assert locker._is_phone_connected() is False - - def test_phone_offline( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test phone connected but offline.""" - locker = create_locker(mock_tk, tmp_path) - locker._run_adb = MagicMock( # type: ignore[method-assign] - return_value=( - True, - "List of devices attached\nABC123\toffline\n\n", - ), - ) - locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign] - return_value=False, - ) - - assert locker._is_phone_connected() is False - - def test_adb_command_fails( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test ADB command failure.""" - locker = create_locker(mock_tk, tmp_path) - locker._run_adb = MagicMock( # type: ignore[method-assign] - return_value=(False, ""), - ) - locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign] - return_value=False, - ) - - assert locker._is_phone_connected() is False - - -class TestFindHealthConnectDb: - """Tests for _pull_stronglifts_db method.""" - - def test_db_pulled_successfully( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test StrongLifts DB pulled from device.""" - locker = create_locker(mock_tk, tmp_path) - locker._adb_shell = MagicMock( # type: ignore[method-assign] - return_value=(True, ""), - ) - locker._run_adb = MagicMock( # type: ignore[method-assign] - return_value=(True, ""), - ) - - result = locker._pull_stronglifts_db() - - assert result is not None - locker._adb_shell.assert_called_once() - locker._run_adb.assert_called_once() - call_args = locker._run_adb.call_args[0][0] - assert call_args[0] == "pull" - - def test_db_cat_fails( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test returns None when cat command fails.""" - locker = create_locker(mock_tk, tmp_path) - locker._adb_shell = MagicMock( # type: ignore[method-assign] - return_value=(False, ""), - ) - - assert locker._pull_stronglifts_db() is None - - def test_db_pull_fails( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test returns None when adb pull fails.""" - locker = create_locker(mock_tk, tmp_path) - locker._adb_shell = MagicMock( # type: ignore[method-assign] - return_value=(True, ""), - ) - locker._run_adb = MagicMock( # type: ignore[method-assign] - return_value=(False, ""), - ) - - assert locker._pull_stronglifts_db() is None - - def test_db_uses_correct_remote_path( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test uses the correct StrongLifts DB remote path.""" - locker = create_locker(mock_tk, tmp_path) - locker._adb_shell = MagicMock( # type: ignore[method-assign] - return_value=(True, ""), - ) - locker._run_adb = MagicMock( # type: ignore[method-assign] - return_value=(True, ""), - ) - - locker._pull_stronglifts_db() - - shell_cmd = locker._adb_shell.call_args[0][0] - assert STRONGLIFTS_DB_REMOTE in shell_cmd - - -class TestCountTodayWorkouts: - """Tests for _count_today_workouts method.""" - - def test_workouts_found_today( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test workouts found today.""" - locker = create_locker(mock_tk, tmp_path) - db_file = tmp_path / "sl_test.db" - conn = sqlite3.connect(str(db_file)) - conn.execute( - "CREATE TABLE workouts " - "(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)", - ) - # Insert a workout with today's timestamp (ms) - import time - - now_ms = int(time.time() * 1000) - conn.execute( - "INSERT INTO workouts VALUES (?, ?, ?)", - ("w1", now_ms, now_ms + 3600000), - ) - conn.commit() - conn.close() - - assert locker._count_today_workouts(db_file) == 1 - - def test_no_workouts_today( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test no workouts today.""" - locker = create_locker(mock_tk, tmp_path) - db_file = tmp_path / "sl_test.db" - conn = sqlite3.connect(str(db_file)) - conn.execute( - "CREATE TABLE workouts " - "(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)", - ) - # Insert a workout from yesterday (24h+ ago) - import time - - yesterday_ms = int((time.time() - 200000) * 1000) - conn.execute( - "INSERT INTO workouts VALUES (?, ?, ?)", - ("w1", yesterday_ms, yesterday_ms + 3600000), - ) - conn.commit() - conn.close() - - assert locker._count_today_workouts(db_file) == 0 - - def test_invalid_db_returns_zero( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test returns 0 for invalid database file.""" - locker = create_locker(mock_tk, tmp_path) - bad_file = tmp_path / "not_a_db.db" - bad_file.write_text("not a database") - - assert locker._count_today_workouts(bad_file) == 0 - - def test_missing_table_returns_zero( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test returns 0 when workouts table doesn't exist.""" - locker = create_locker(mock_tk, tmp_path) - db_file = tmp_path / "empty.db" - conn = sqlite3.connect(str(db_file)) - conn.execute("CREATE TABLE other (id TEXT)") - conn.commit() - conn.close() - - assert locker._count_today_workouts(db_file) == 0 - - def test_multiple_workouts_today( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test counts multiple workouts today correctly.""" - locker = create_locker(mock_tk, tmp_path) - db_file = tmp_path / "sl_test.db" - conn = sqlite3.connect(str(db_file)) - conn.execute( - "CREATE TABLE workouts " - "(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)", - ) - import time - - now_ms = int(time.time() * 1000) - conn.execute( - "INSERT INTO workouts VALUES (?, ?, ?)", - ("w1", now_ms, now_ms + 3600000), - ) - conn.execute( - "INSERT INTO workouts VALUES (?, ?, ?)", - ("w2", now_ms + 100000, now_ms + 3700000), - ) - conn.commit() - conn.close() - - assert locker._count_today_workouts(db_file) == 2 - - -class TestVerifyPhoneWorkout: - """Tests for _verify_phone_workout method.""" - - def test_verified( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test workout verified on phone.""" - locker = create_locker(mock_tk, tmp_path) - locker._is_phone_connected = MagicMock( # type: ignore[method-assign] - return_value=True, - ) - locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign] - return_value=tmp_path / "sl.db", - ) - locker._count_today_workouts = MagicMock( # type: ignore[method-assign] - return_value=2, - ) - - status, message = locker._verify_phone_workout() - - assert status == "verified" - assert "2 session" in message - - def test_not_verified( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test no workout found on phone.""" - locker = create_locker(mock_tk, tmp_path) - locker._is_phone_connected = MagicMock( # type: ignore[method-assign] - return_value=True, - ) - locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign] - return_value=tmp_path / "sl.db", - ) - locker._count_today_workouts = MagicMock( # type: ignore[method-assign] - return_value=0, - ) - - status, message = locker._verify_phone_workout() - - assert status == "not_verified" - assert "No workout" in message - - def test_no_phone( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test no phone connected.""" - locker = create_locker(mock_tk, tmp_path) - locker._is_phone_connected = MagicMock( # type: ignore[method-assign] - return_value=False, - ) - - status, _ = locker._verify_phone_workout() - - assert status == "no_phone" - - def test_error_no_db( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test error when StrongLifts DB cannot be pulled.""" - locker = create_locker(mock_tk, tmp_path) - locker._is_phone_connected = MagicMock( # type: ignore[method-assign] - return_value=True, - ) - locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign] - return_value=None, - ) - - status, message = locker._verify_phone_workout() - - assert status == "error" - assert "database" in message.lower() - - -class TestStartPhoneCheck: - """Tests for _start_phone_check and _handle_startup_phone_result.""" - - def test_start_phone_check_shows_checking_screen( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test _start_phone_check shows checking message and starts check.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - locker._verify_phone_workout = MagicMock( # type: ignore[method-assign] - return_value=("no_phone", "No phone"), - ) - locker._poll_phone_check = MagicMock() # type: ignore[method-assign] - - locker._start_phone_check() - - locker.clear_container.assert_called() - locker._poll_phone_check.assert_called_once() - assert locker._phone_future is not None - - def test_handle_startup_verified_unlocks_directly( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test verified result shows success screen then unlocks via after().""" - locker = create_locker(mock_tk, tmp_path) - locker.unlock_screen = MagicMock() # type: ignore[method-assign] - locker.root.after = MagicMock() # type: ignore[method-assign] - - locker._handle_startup_phone_result("verified", "Workout verified! (1 session)") - - # unlock_screen is deferred via root.after, not called directly - locker.unlock_screen.assert_not_called() - assert locker.workout_data["type"] == "phone_verified" - locker.root.after.assert_called_once_with(1500, locker.unlock_screen) - - def test_handle_startup_not_verified_shows_block( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test not_verified result shows blocking screen with buttons.""" - locker = create_locker(mock_tk, tmp_path) - locker.clear_container = MagicMock() # type: ignore[method-assign] - locker._handle_startup_phone_result( - "not_verified", "No workout found on phone today" - ) - - locker.clear_container.assert_called() - - def test_handle_startup_no_phone_shows_penalty( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test no_phone result triggers penalty with ask_workout_done as callback.""" - locker = create_locker(mock_tk, tmp_path) - locker._show_phone_penalty = MagicMock() # type: ignore[method-assign] - - locker._handle_startup_phone_result("no_phone", "No phone") - - locker._show_phone_penalty.assert_called_once() - _, kwargs = locker._show_phone_penalty.call_args - assert kwargs["on_done"] == locker.ask_workout_done - - def test_handle_startup_error_shows_penalty( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test error result triggers penalty with ask_workout_done as callback.""" - locker = create_locker(mock_tk, tmp_path) - locker._show_phone_penalty = MagicMock() # type: ignore[method-assign] - - locker._handle_startup_phone_result("error", "DB not found") - - locker._show_phone_penalty.assert_called_once() - _, kwargs = locker._show_phone_penalty.call_args - assert kwargs["on_done"] == locker.ask_workout_done - - def test_poll_phone_check_schedules_retry_when_pending( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test _poll_phone_check reschedules itself when future is not done.""" - locker = create_locker(mock_tk, tmp_path) - mock_future: MagicMock = MagicMock() - mock_future.done.return_value = False - locker._phone_future = mock_future # type: ignore[assignment] - locker.root.after = MagicMock() # type: ignore[method-assign] - - locker._poll_phone_check() - - locker.root.after.assert_called_once_with(500, locker._poll_phone_check) - - def test_poll_phone_check_routes_when_done( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test _poll_phone_check calls result handler when future is done.""" - locker = create_locker(mock_tk, tmp_path) - mock_future: MagicMock = MagicMock() - mock_future.done.return_value = True - mock_future.result.return_value = ("no_phone", "No phone") - locker._phone_future = mock_future # type: ignore[assignment] - locker._handle_startup_phone_result = MagicMock() # type: ignore[method-assign] - - locker._poll_phone_check() - - locker._handle_startup_phone_result.assert_called_once_with( - "no_phone", "No phone" - ) - - -class TestAttemptUnlock: - """Tests for _attempt_unlock method.""" - - def test_attempt_unlock_calls_unlock_screen( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test _attempt_unlock calls unlock_screen directly.""" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "strength"} - locker.unlock_screen = MagicMock() # type: ignore[method-assign] - - locker._attempt_unlock() - - locker.unlock_screen.assert_called_once() - - -class TestShowPhonePenalty: - """Tests for _show_phone_penalty and _update_phone_penalty methods.""" - - def test_show_phone_penalty_demo_delay( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test demo mode uses short penalty delay.""" - locker = create_locker(mock_tk, tmp_path, demo_mode=True) - locker.clear_container = MagicMock() # type: ignore[method-assign] - - locker._show_phone_penalty("test message") - - # _update_phone_penalty is called once, decrementing by 1 - assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_DEMO - 1 - - def test_show_phone_penalty_production_delay( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test production mode uses long penalty delay.""" - locker = create_locker(mock_tk, tmp_path, demo_mode=False) - locker.clear_container = MagicMock() # type: ignore[method-assign] - - locker._show_phone_penalty("test message") - - assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_PRODUCTION - 1 - - def test_update_phone_penalty_countdown( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test phone penalty countdown decrements.""" - locker = create_locker(mock_tk, tmp_path) - locker.phone_penalty_remaining = 5 - locker.phone_penalty_label = MagicMock() - - locker._update_phone_penalty() - - assert locker.phone_penalty_remaining == 4 - locker.phone_penalty_label.config.assert_called_once_with(text="5") - locker.root.after.assert_called() # type: ignore[attr-defined] - - def test_update_phone_penalty_at_zero( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test phone penalty unlocks when timer reaches zero.""" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "strength"} - locker.phone_penalty_remaining = 0 - locker.phone_penalty_label = MagicMock() - locker.unlock_screen = MagicMock() # type: ignore[method-assign] - locker._phone_penalty_done_fn = locker.unlock_screen # type: ignore[attr-defined] - - locker._update_phone_penalty() - - locker.unlock_screen.assert_called_once() - - -class TestUnlockScreenShutdownAdjustment: - """Tests for unlock_screen shutdown time adjustment.""" - - def test_unlock_screen_adjusts_for_running( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test unlock_screen adjusts shutdown for running workout.""" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "running"} - locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] - return_value=True - ) - - locker.unlock_screen() - - locker._adjust_shutdown_time_later.assert_called_once() - - def test_unlock_screen_adjusts_for_strength( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test unlock_screen adjusts shutdown for strength workout.""" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "strength"} - locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] - return_value=True - ) - - locker.unlock_screen() - - locker._adjust_shutdown_time_later.assert_called_once() - - def test_unlock_screen_adjusts_for_phone_verified( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test unlock_screen adjusts shutdown for phone-verified workout.""" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "phone_verified"} - locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] - return_value=True - ) - - locker.unlock_screen() - - locker._adjust_shutdown_time_later.assert_called_once() - - def test_unlock_screen_skips_adjustment_for_sick_day( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test unlock_screen does not adjust for sick day.""" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "sick_day"} - locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] - return_value=True - ) - - locker.unlock_screen() - - locker._adjust_shutdown_time_later.assert_not_called() - - def test_unlock_screen_skips_adjustment_no_type( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test unlock_screen does not adjust when no workout type.""" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {} - locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] - return_value=True - ) - - locker.unlock_screen() - - locker._adjust_shutdown_time_later.assert_not_called() - - def test_unlock_screen_handles_adjustment_failure( - self, - mock_tk: MagicMock, - _mock_sys_exit: MagicMock, - tmp_path: Path, - ) -> None: - """Test unlock_screen continues when adjustment fails.""" - locker = create_locker(mock_tk, tmp_path) - locker.log_file = tmp_path / "workout_log.json" - locker.workout_data = {"type": "running"} - locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign] - return_value=False - ) - - # Should not raise, should continue with unlock - locker.unlock_screen() - - locker._adjust_shutdown_time_later.assert_called_once() - locker.root.after.assert_called() # type: ignore[attr-defined] diff --git a/python_pkg/screen_locker/tests/test_ui_and_timers.py b/python_pkg/screen_locker/tests/test_ui_and_timers.py new file mode 100644 index 0000000..209b104 --- /dev/null +++ b/python_pkg/screen_locker/tests/test_ui_and_timers.py @@ -0,0 +1,424 @@ +"""Tests for UI transitions, timer logic, and workout detail screens.""" + +from __future__ import annotations + +import tkinter as tk +from typing import TYPE_CHECKING +from unittest.mock import MagicMock + +from python_pkg.screen_locker.screen_lock import ( + SUBMIT_DELAY_DEMO, + SUBMIT_DELAY_PRODUCTION, +) +from python_pkg.screen_locker.tests.conftest import create_locker + +if TYPE_CHECKING: + from pathlib import Path + +_TK_TCLERROR = tk.TclError + + +class TestUITransitions: + """Tests for UI state transitions.""" + + def test_clear_container( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test clear_container destroys all child widgets.""" + locker = create_locker(mock_tk, tmp_path) + + # Set up mock children + mock_child1 = MagicMock() + mock_child2 = MagicMock() + locker.container.winfo_children.return_value = [ # type: ignore[attr-defined] + mock_child1, + mock_child2, + ] + + locker.clear_container() + + mock_child1.destroy.assert_called_once() + mock_child2.destroy.assert_called_once() + + def test_unlock_screen_saves_and_schedules_close( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test unlock_screen saves log and schedules close.""" + locker = create_locker(mock_tk, tmp_path) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "running"} + + locker.unlock_screen() + + # Check that after() was called to schedule close + locker.root.after.assert_called() # type: ignore[attr-defined] + + def test_lockout_starts_countdown( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test lockout initializes countdown timer.""" + locker = create_locker(mock_tk, tmp_path) + + locker.lockout() + + # lockout() sets remaining_time to lockout_time (10 in demo mode) + # then calls update_lockout_countdown() which decrements it by 1 + assert locker.remaining_time == 9 # 10 - 1 after first update + + def test_close_destroys_root_and_exits( + self, + mock_tk: MagicMock, + mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test close destroys root window and exits.""" + locker = create_locker(mock_tk, tmp_path) + + locker.close() + + locker.root.destroy.assert_called_once() # type: ignore[attr-defined] + mock_sys_exit.assert_called_with(0) + + +class TestTimerLogic: + """Tests for timer countdown logic.""" + + def test_update_lockout_countdown_decrements( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test countdown decrements remaining time.""" + locker = create_locker(mock_tk, tmp_path) + locker.remaining_time = 5 + locker.countdown_label = MagicMock() + + locker.update_lockout_countdown() + + assert locker.remaining_time == 4 + locker.root.after.assert_called_with( # type: ignore[attr-defined] + 1000, locker.update_lockout_countdown + ) + + def test_update_lockout_countdown_at_zero( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test countdown at zero returns to workout question.""" + locker = create_locker(mock_tk, tmp_path) + locker.remaining_time = 0 + locker.countdown_label = MagicMock() + locker.ask_workout_done = MagicMock() # type: ignore[method-assign] + + locker.update_lockout_countdown() + + locker.ask_workout_done.assert_called_once() + + def test_update_submit_timer_countdown( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test submit timer counts down.""" + locker = create_locker(mock_tk, tmp_path) + locker.submit_unlock_time = 5 + locker.timer_label = MagicMock() + locker.submit_btn = MagicMock() + locker.entries_to_check = [] + + locker.update_submit_timer() + + assert locker.submit_unlock_time == 4 + locker.root.after.assert_called() # type: ignore[attr-defined] + + def test_update_submit_timer_enables_when_filled( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test submit enabled when timer done and entries filled.""" + locker = create_locker(mock_tk, tmp_path) + locker.submit_unlock_time = 0 + locker.timer_label = MagicMock() + locker.submit_btn = MagicMock() + mock_entry = MagicMock() + mock_entry.get.return_value = "some value" + locker.entries_to_check = [mock_entry] + locker.submit_command = MagicMock() + + locker.update_submit_timer() + + locker.submit_btn.config.assert_called() + + def test_update_submit_timer_waits_for_entries( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test submit waits when entries not filled.""" + locker = create_locker(mock_tk, tmp_path) + locker.submit_unlock_time = 0 + locker.timer_label = MagicMock() + locker.submit_btn = MagicMock() + mock_entry = MagicMock() + mock_entry.get.return_value = "" # Empty entry + locker.entries_to_check = [mock_entry] + + locker.update_submit_timer() + + locker.root.after.assert_called_with( # type: ignore[attr-defined] + 1000, locker.check_entries_filled + ) + + def test_update_submit_timer_handles_tcl_error( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test timer handles TclError when widgets destroyed.""" + locker = create_locker(mock_tk, tmp_path) + locker.submit_unlock_time = 5 + locker.timer_label = MagicMock() + locker.timer_label.config.side_effect = _TK_TCLERROR("widget destroyed") + + # Should not raise + locker.update_submit_timer() + + def test_check_entries_filled_enables_submit( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test check_entries_filled enables submit when all filled.""" + locker = create_locker(mock_tk, tmp_path) + locker.timer_label = MagicMock() + locker.submit_btn = MagicMock() + mock_entry = MagicMock() + mock_entry.get.return_value = "value" + locker.entries_to_check = [mock_entry] + locker.submit_command = MagicMock() + + locker.check_entries_filled() + + locker.submit_btn.config.assert_called() + + def test_check_entries_filled_continues_waiting( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test check_entries_filled continues waiting when not filled.""" + locker = create_locker(mock_tk, tmp_path) + locker.timer_label = MagicMock() + locker.submit_btn = MagicMock() + mock_entry = MagicMock() + mock_entry.get.return_value = "" + locker.entries_to_check = [mock_entry] + + locker.check_entries_filled() + + locker.root.after.assert_called_with( # type: ignore[attr-defined] + 1000, locker.check_entries_filled + ) + + def test_check_entries_filled_handles_tcl_error( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test check_entries_filled handles TclError.""" + locker = create_locker(mock_tk, tmp_path) + locker.timer_label = MagicMock() + mock_entry = MagicMock() + mock_entry.get.side_effect = _TK_TCLERROR("widget destroyed") + locker.entries_to_check = [mock_entry] + + # Should not raise + locker.check_entries_filled() + + +class TestAskWorkoutType: + """Tests for ask_workout_type method.""" + + def test_ask_workout_type_creates_buttons( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ask_workout_type creates running and strength buttons.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + + locker.ask_workout_type() + + locker.clear_container.assert_called_once() + # Verify Label and Button were called + mock_tk.Label.assert_called() + mock_tk.Button.assert_called() + + +class TestAskRunningDetails: + """Tests for ask_running_details method.""" + + def test_ask_running_details_sets_workout_type( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ask_running_details sets workout type to running.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + locker.update_submit_timer = MagicMock() # type: ignore[method-assign] + + locker.ask_running_details() + + assert locker.workout_data["type"] == "running" + locker.clear_container.assert_called_once() + + def test_ask_running_details_creates_entry_fields( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ask_running_details creates entry fields.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + locker.update_submit_timer = MagicMock() # type: ignore[method-assign] + + locker.ask_running_details() + + # Verify Entry fields were created + mock_tk.Entry.assert_called() + assert hasattr(locker, "distance_entry") + assert hasattr(locker, "time_entry") + assert hasattr(locker, "pace_entry") + + def test_ask_running_details_sets_timer( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ask_running_details initializes submit timer.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + locker.update_submit_timer = MagicMock() # type: ignore[method-assign] + + locker.ask_running_details() + + assert locker.submit_unlock_time == SUBMIT_DELAY_DEMO + locker.update_submit_timer.assert_called_once() + + +class TestAskStrengthDetails: + """Tests for ask_strength_details method.""" + + def test_ask_strength_details_sets_workout_type( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ask_strength_details sets workout type to strength.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + locker.update_submit_timer = MagicMock() # type: ignore[method-assign] + + locker.ask_strength_details() + + assert locker.workout_data["type"] == "strength" + locker.clear_container.assert_called_once() + + def test_ask_strength_details_creates_entry_fields( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ask_strength_details creates entry fields.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + locker.update_submit_timer = MagicMock() # type: ignore[method-assign] + + locker.ask_strength_details() + + # Verify Entry fields were created + mock_tk.Entry.assert_called() + assert hasattr(locker, "exercises_entry") + assert hasattr(locker, "sets_entry") + assert hasattr(locker, "reps_entry") + assert hasattr(locker, "weights_entry") + assert hasattr(locker, "total_weight_entry") + + def test_ask_strength_details_sets_timer( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ask_strength_details initializes submit timer.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + locker.update_submit_timer = MagicMock() # type: ignore[method-assign] + + locker.ask_strength_details() + + assert locker.submit_unlock_time == SUBMIT_DELAY_DEMO + locker.update_submit_timer.assert_called_once() + + def test_ask_strength_details_production_timer( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test production mode uses longer submit delay.""" + locker = create_locker(mock_tk, tmp_path, demo_mode=False) + locker.clear_container = MagicMock() # type: ignore[method-assign] + locker.update_submit_timer = MagicMock() # type: ignore[method-assign] + + locker.ask_strength_details() + + assert locker.submit_unlock_time == SUBMIT_DELAY_PRODUCTION + + +class TestAskWorkoutDone: + """Tests for ask_workout_done method.""" + + def test_ask_workout_done_creates_buttons( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test ask_workout_done creates yes/no buttons.""" + locker = create_locker(mock_tk, tmp_path) + locker.clear_container = MagicMock() # type: ignore[method-assign] + + locker.ask_workout_done() + + locker.clear_container.assert_called_once() + mock_tk.Label.assert_called() + mock_tk.Button.assert_called() diff --git a/python_pkg/screen_locker/tests/test_verify_data.py b/python_pkg/screen_locker/tests/test_verify_data.py new file mode 100644 index 0000000..05a064b --- /dev/null +++ b/python_pkg/screen_locker/tests/test_verify_data.py @@ -0,0 +1,371 @@ +"""Tests for running and strength data verification.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock + +from python_pkg.screen_locker.tests.conftest import ( + RunningData, + StrengthData, + create_locker, + setup_running_entries, + setup_strength_entries, +) + +if TYPE_CHECKING: + from pathlib import Path + + +class TestVerifyRunningData: + """Tests for verify_running_data method.""" + + def test_valid_running_data( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test valid running data triggers unlock attempt.""" + locker = create_locker(mock_tk, tmp_path) + setup_running_entries(locker, RunningData("5", "25", "5")) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "running"} + locker._attempt_unlock = MagicMock() # type: ignore[method-assign] + + locker.verify_running_data() + + locker._attempt_unlock.assert_called_once() + + def test_invalid_distance_zero( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test zero distance is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_running_entries(locker, RunningData("0", "25", "5")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_running_data() + + locker.show_error.assert_called_once() + assert "Distance" in locker.show_error.call_args[0][0] + + def test_invalid_distance_too_high( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test distance over max is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_running_entries(locker, RunningData("150", "600", "4")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_running_data() + + locker.show_error.assert_called_once() + assert "Distance" in locker.show_error.call_args[0][0] + + def test_invalid_time_zero( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test zero time is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_running_entries(locker, RunningData("5", "0", "5")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_running_data() + + locker.show_error.assert_called_once() + assert "Time" in locker.show_error.call_args[0][0] + + def test_invalid_time_too_high( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test time over max is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_running_entries(locker, RunningData("5", "700", "5")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_running_data() + + locker.show_error.assert_called_once() + assert "Time" in locker.show_error.call_args[0][0] + + def test_invalid_pace_zero( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test zero pace is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_running_entries(locker, RunningData("5", "25", "0")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_running_data() + + locker.show_error.assert_called_once() + assert "Pace" in locker.show_error.call_args[0][0] + + def test_invalid_pace_too_high( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test pace over max is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_running_entries(locker, RunningData("5", "25", "25")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_running_data() + + locker.show_error.assert_called_once() + assert "Pace" in locker.show_error.call_args[0][0] + + def test_pace_mismatch( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test pace mismatch is rejected.""" + # 5km in 25 min should be 5 min/km, but we say 10 min/km + locker = create_locker(mock_tk, tmp_path) + setup_running_entries(locker, RunningData("5", "25", "10")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_running_data() + + locker.show_error.assert_called_once() + assert "Pace doesn't match" in locker.show_error.call_args[0][0] + + def test_invalid_number_format( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test non-numeric input is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_running_entries(locker, RunningData("abc", "25", "5")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_running_data() + + locker.show_error.assert_called_once() + assert "valid numbers" in locker.show_error.call_args[0][0] + + +class TestVerifyStrengthData: + """Tests for verify_strength_data method.""" + + def test_valid_strength_data( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test valid strength data triggers unlock attempt.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Squat", "3", "10", "50", "1500")) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "strength"} + locker._attempt_unlock = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker._attempt_unlock.assert_called_once() + + def test_valid_multiple_exercises( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test valid data with multiple exercises.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries( + locker, + StrengthData("Squat, Bench Press", "3, 3", "10, 8", "50, 40", "2460"), + ) + locker.log_file = tmp_path / "workout_log.json" + locker.workout_data = {"type": "strength"} + locker._attempt_unlock = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker._attempt_unlock.assert_called_once() + + def test_mismatched_list_lengths( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test mismatched list lengths are rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries( + locker, + StrengthData("Squat, Bench", "3", "10, 8", "50, 40", "2000"), + ) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "must match" in locker.show_error.call_args[0][0] + + def test_short_exercise_name( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test short exercise names are rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Sq", "3", "10", "50", "1500")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "too short" in locker.show_error.call_args[0][0] + + def test_invalid_sets_zero( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test zero sets is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Squat", "0", "10", "50", "0")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "Sets" in locker.show_error.call_args[0][0] + + def test_invalid_sets_too_high( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test sets over max is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Squat", "25", "10", "50", "12500")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "Sets" in locker.show_error.call_args[0][0] + + def test_invalid_reps_zero( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test zero reps is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Squat", "3", "0", "50", "0")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "Reps" in locker.show_error.call_args[0][0] + + def test_invalid_reps_too_high( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test reps over max is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Squat", "3", "150", "50", "22500")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "Reps" in locker.show_error.call_args[0][0] + + def test_invalid_weight_negative( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test negative weight is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Squat", "3", "10", "-10", "-300")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "Weights" in locker.show_error.call_args[0][0] + + def test_invalid_weight_too_high( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test weight over max is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Squat", "3", "10", "600", "18000")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "Weights" in locker.show_error.call_args[0][0] + + def test_total_weight_mismatch( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test total weight mismatch is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Squat", "3", "10", "50", "3000")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "Total weight doesn't match" in locker.show_error.call_args[0][0] + + def test_invalid_format( + self, + mock_tk: MagicMock, + _mock_sys_exit: MagicMock, + tmp_path: Path, + ) -> None: + """Test invalid format is rejected.""" + locker = create_locker(mock_tk, tmp_path) + setup_strength_entries(locker, StrengthData("Squat", "abc", "10", "50", "1500")) + locker.show_error = MagicMock() # type: ignore[method-assign] + + locker.verify_strength_data() + + locker.show_error.assert_called_once() + assert "valid data" in locker.show_error.call_args[0][0] diff --git a/python_pkg/steam_backlog_enforcer/game_install.py b/python_pkg/steam_backlog_enforcer/game_install.py new file mode 100644 index 0000000..733ceea --- /dev/null +++ b/python_pkg/steam_backlog_enforcer/game_install.py @@ -0,0 +1,349 @@ +"""Game installation and uninstallation management.""" + +from __future__ import annotations + +import contextlib +import logging +import os +from pathlib import Path +import pwd +import re +import shutil +import subprocess +import sys +import time + +logger = logging.getLogger(__name__) + + +def _echo(msg: str = "", *, end: str = "\n", flush: bool = False) -> None: + """Write user-facing CLI output to stdout. + + Args: + msg: Text to output. + end: String appended after the message. + flush: Whether to flush stdout immediately. + """ + sys.stdout.write(msg + end) + if flush: + sys.stdout.flush() + + +# Steam infrastructure app IDs that should NEVER be uninstalled. +PROTECTED_APP_IDS = { + 228980, # Steamworks Common Redistributables + 1070560, # Steam Linux Runtime 1.0 (scout) + 1391110, # Steam Linux Runtime 2.0 (soldier) + 1628350, # Steam Linux Runtime 3.0 (sniper) + 961940, # Steam Linux Runtime (legacy) + # Proton versions (never uninstall these) + 858280, # Proton 3.7 (Beta) + 930400, # Proton 3.16 (Beta) + 1054830, # Proton 4.2 + 1113280, # Proton 4.11 + 1245040, # Proton 5.0 + 1420170, # Proton 5.13 + 1580130, # Proton 6.3 + 1887720, # Proton 7.0 + 2230260, # Proton 7.0 (alt) + 2348590, # Proton 8.0 + 2805730, # Proton 9.0 + 3201940, # Proton 9.0 (alt) + 3658110, # Proton 10.0 + 2180100, # Proton Hotfix + 1493710, # Proton Experimental + 1161040, # Proton BattlEye Runtime + 1007020, # Proton EasyAntiCheat Runtime + # Games allowed to be installed anytime + 3949040, # RV There Yet? +} + +STEAMAPPS_PATH = Path("~/.local/share/Steam/steamapps").expanduser() + + +# ────────────────────────────────────────────────────────────── +# Game install management +# ────────────────────────────────────────────────────────────── + + +def _get_real_user() -> str | None: + """Get the real (non-root) user when running under sudo.""" + return os.environ.get("SUDO_USER") or os.environ.get("USER") + + +def _get_uid_gid_for_user(username: str) -> tuple[int, int]: + """Get (uid, gid) for a username.""" + try: + pw = pwd.getpwnam(username) + except KeyError: + return 1000, 1000 + else: + return pw.pw_uid, pw.pw_gid + + +def is_game_installed(app_id: int) -> bool: + """Check if a game is installed by looking for its appmanifest. + + A manifest with StateFlags != 4 (FullyInstalled) means the game is + still downloading or queued, which still counts as "install triggered". + """ + manifest = STEAMAPPS_PATH / f"appmanifest_{app_id}.acf" + return manifest.exists() + + +def _ensure_steam_running() -> None: + """Start the Steam client if it is not already running.""" + # Check if any steam process is running (main client, not just helpers). + try: + result = subprocess.run( + ["/usr/bin/pgrep", "-f", "steam.sh"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + logger.debug("Steam client already running") + return + except FileNotFoundError: + pass + + real_user = _get_real_user() + logger.info("Starting Steam client...") + + try: + if os.geteuid() == 0 and real_user and real_user != "root": + uid, _ = _get_uid_gid_for_user(real_user) + dbus_default = f"unix:path=/run/user/{uid}/bus" + dbus_addr = os.environ.get("DBUS_SESSION_BUS_ADDRESS", dbus_default) + xauth_default = f"/home/{real_user}/.Xauthority" + xauth = os.environ.get("XAUTHORITY", xauth_default) + cmd = [ + "sudo", + "-u", + real_user, + "env", + f"DISPLAY={os.environ.get('DISPLAY', ':0')}", + f"XAUTHORITY={xauth}", + f"DBUS_SESSION_BUS_ADDRESS={dbus_addr}", + "steam", + "-silent", + ] + else: + cmd = ["steam", "-silent"] + + subprocess.Popen( + cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + # Give Steam time to initialize and start scanning manifests. + time.sleep(15) + except FileNotFoundError: + logger.exception("Steam executable not found") + + +def install_game(app_id: int, game_name: str, steam_id: str) -> bool: + """Install a game by writing an appmanifest that triggers Steam's download. + + Creates a minimal appmanifest with StateFlags=1026 (UpdateRequired | + UpdateStarted) in the steamapps directory. The running Steam client + detects the new manifest and automatically queues the download — no + dialog or user interaction required. + + If Steam is not running it will be started in silent mode first. + + Args: + app_id: Steam application ID. + game_name: Human-readable game name. + steam_id: Steam64 ID of the account that owns the game. + + Returns True if the manifest was written successfully. + """ + label = game_name or f"AppID={app_id}" + + if is_game_installed(app_id): + logger.info("Game already installed: %s", label) + return True + + # Build a minimal appmanifest. StateFlags 1026 = UpdateRequired (2) + + # UpdateStarted (1024), which tells Steam "this app needs downloading". + manifest_content = ( + '"AppState"\n' + "{\n" + f'\t"appid"\t\t"{app_id}"\n' + '\t"universe"\t\t"1"\n' + f'\t"name"\t\t"{game_name}"\n' + '\t"StateFlags"\t\t"1026"\n' + f'\t"installdir"\t\t"{game_name}"\n' + '\t"LastUpdated"\t\t"0"\n' + '\t"LastPlayed"\t\t"0"\n' + '\t"SizeOnDisk"\t\t"0"\n' + '\t"StagingSize"\t\t"0"\n' + '\t"buildid"\t\t"0"\n' + f'\t"LastOwner"\t\t"{steam_id}"\n' + '\t"UpdateResult"\t\t"0"\n' + '\t"BytesToDownload"\t\t"0"\n' + '\t"BytesDownloaded"\t\t"0"\n' + '\t"BytesToStage"\t\t"0"\n' + '\t"BytesStaged"\t\t"0"\n' + '\t"TargetBuildID"\t\t"0"\n' + '\t"AutoUpdateBehavior"\t\t"0"\n' + '\t"AllowOtherDownloadsWhileRunning"\t\t"0"\n' + '\t"ScheduledAutoUpdate"\t\t"0"\n' + '\t"InstalledDepots"\n' + "\t{\n" + "\t}\n" + '\t"UserConfig"\n' + "\t{\n" + "\t}\n" + '\t"MountedConfig"\n' + "\t{\n" + "\t}\n" + "}\n" + ) + + manifest_path = STEAMAPPS_PATH / f"appmanifest_{app_id}.acf" + + try: + with manifest_path.open("w", encoding="utf-8") as fh: + fh.write(manifest_content) + + # Fix ownership so the Steam client (running as the real user) can + # read and update the manifest. + real_user = _get_real_user() + if os.geteuid() == 0 and real_user and real_user != "root": + uid, gid = _get_uid_gid_for_user(real_user) + os.chown(manifest_path, uid, gid) + + logger.info("Created appmanifest for %s — Steam will auto-download", label) + except OSError: + logger.exception("Failed to create appmanifest for %s", label) + return False + + # Make sure Steam is running so it picks up the manifest. + _ensure_steam_running() + + return True + + +# ────────────────────────────────────────────────────────────── +# Game uninstall management +# ────────────────────────────────────────────────────────────── + + +def get_installed_games() -> list[tuple[int, str]]: + """Parse appmanifest files to find installed games. + + Returns: list of (app_id, game_name) tuples. + """ + installed: list[tuple[int, str]] = [] + + for manifest_file in STEAMAPPS_PATH.glob("appmanifest_*.acf"): + with contextlib.suppress(OSError): + content = manifest_file.read_text(encoding="utf-8") + app_id_match = re.search(r'"appid"\s+"(\d+)"', content) + name_match = re.search(r'"name"\s+"([^"]+)"', content) + if app_id_match: + app_id = int(app_id_match.group(1)) + name = name_match.group(1) if name_match else f"Unknown ({app_id})" + installed.append((app_id, name)) + + installed.sort(key=lambda x: x[1].lower()) + return installed + + +def _read_install_dir(manifest: Path) -> Path | None: + """Read installdir from a game's appmanifest file.""" + if not manifest.exists(): + return None + try: + content = manifest.read_text(encoding="utf-8") + match = re.search(r'"installdir"\s+"([^"]+)"', content) + if match: + return STEAMAPPS_PATH / "common" / match.group(1) + except OSError: + pass + return None + + +def _remove_manifest(manifest: Path, game_name: str, app_id: int) -> bool: + """Remove a game manifest file. + + Args: + manifest: Path to the appmanifest file. + game_name: Human-readable game name for logging. + app_id: Steam application ID. + """ + try: + if manifest.exists(): + manifest.unlink() + logger.info( + "Removed manifest for %s (AppID=%d)", game_name or app_id, app_id + ) + except OSError: + logger.exception("Failed to remove manifest for AppID=%d", app_id) + return False + return True + + +def _remove_game_dirs(install_dir: Path | None, app_id: int) -> bool: + """Remove game installation directory and cache directories. + + Args: + install_dir: Path to the game's install directory, or None. + app_id: Steam application ID. + """ + success = True + if install_dir and install_dir.is_dir(): + try: + shutil.rmtree(install_dir) + logger.info("Removed game files: %s", install_dir) + except OSError: + logger.exception("Failed to remove game dir %s", install_dir) + success = False + + for subdir in ("shadercache", "compatdata"): + cache_path = STEAMAPPS_PATH / subdir / str(app_id) + if cache_path.is_dir(): + with contextlib.suppress(OSError): + shutil.rmtree(cache_path) + logger.debug("Removed %s/%d", subdir, app_id) + + return success + + +def uninstall_game(app_id: int, game_name: str = "") -> bool: + """Uninstall a single game by removing its manifest and game files. + + Uses direct file removal instead of ``steam://uninstall`` URI to avoid + GUI popups and to work when Steam is not running. + """ + manifest = STEAMAPPS_PATH / f"appmanifest_{app_id}.acf" + install_dir = _read_install_dir(manifest) + success = _remove_manifest(manifest, game_name, app_id) + if not _remove_game_dirs(install_dir, app_id): + success = False + return success + + +def uninstall_other_games(allowed_app_id: int | None) -> int: + """Uninstall all installed games except the assigned one and protected IDs. + + Returns: number of games uninstalled. + """ + installed = get_installed_games() + count = 0 + + for app_id, name in installed: + if app_id == allowed_app_id: + logger.info("KEEPING assigned game: %s (AppID=%d)", name, app_id) + continue + if app_id in PROTECTED_APP_IDS: + logger.debug("Skipping protected: %s (AppID=%d)", name, app_id) + continue + + logger.info("UNINSTALLING: %s (AppID=%d)", name, app_id) + if uninstall_game(app_id, name): + count += 1 + + return count diff --git a/python_pkg/steam_backlog_enforcer/main.py b/python_pkg/steam_backlog_enforcer/main.py index 5093a4b..3b8a7dd 100644 --- a/python_pkg/steam_backlog_enforcer/main.py +++ b/python_pkg/steam_backlog_enforcer/main.py @@ -2,29 +2,27 @@ from __future__ import annotations -import contextlib import logging -import os -from pathlib import Path -import pwd -import re -import shutil -import subprocess import sys -import time -from typing import Any from python_pkg.steam_backlog_enforcer.config import ( Config, State, interactive_setup, load_snapshot, - save_snapshot, ) from python_pkg.steam_backlog_enforcer.enforcer import ( enforce_allowed_game, send_notification, ) +from python_pkg.steam_backlog_enforcer.game_install import ( + PROTECTED_APP_IDS, + _echo, + get_installed_games, + install_game, + is_game_installed, + uninstall_other_games, +) from python_pkg.steam_backlog_enforcer.hltb import ( fetch_hltb_times_cached, load_hltb_cache, @@ -34,9 +32,13 @@ from python_pkg.steam_backlog_enforcer.library_hider import ( restart_steam, unhide_all_games, ) -from python_pkg.steam_backlog_enforcer.protondb import ( - ProtonDBRating, - fetch_protondb_ratings, +from python_pkg.steam_backlog_enforcer.scanning import ( + _pick_playable_candidate, + do_check, + do_enforce, + do_scan, + get_all_owned_app_ids, + pick_next_game, ) from python_pkg.steam_backlog_enforcer.steam_api import GameInfo, SteamAPIClient from python_pkg.steam_backlog_enforcer.store_blocker import ( @@ -52,783 +54,8 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) - -def _echo(msg: str = "", *, end: str = "\n", flush: bool = False) -> None: - """Write user-facing CLI output to stdout. - - Args: - msg: Text to output. - end: String appended after the message. - flush: Whether to flush stdout immediately. - """ - sys.stdout.write(msg + end) - if flush: - sys.stdout.flush() - - -# Steam infrastructure app IDs that should NEVER be uninstalled. -PROTECTED_APP_IDS = { - 228980, # Steamworks Common Redistributables - 1070560, # Steam Linux Runtime 1.0 (scout) - 1391110, # Steam Linux Runtime 2.0 (soldier) - 1628350, # Steam Linux Runtime 3.0 (sniper) - 961940, # Steam Linux Runtime (legacy) - # Proton versions (never uninstall these) - 858280, # Proton 3.7 (Beta) - 930400, # Proton 3.16 (Beta) - 1054830, # Proton 4.2 - 1113280, # Proton 4.11 - 1245040, # Proton 5.0 - 1420170, # Proton 5.13 - 1580130, # Proton 6.3 - 1887720, # Proton 7.0 - 2230260, # Proton 7.0 (alt) - 2348590, # Proton 8.0 - 2805730, # Proton 9.0 - 3201940, # Proton 9.0 (alt) - 3658110, # Proton 10.0 - 2180100, # Proton Hotfix - 1493710, # Proton Experimental - 1161040, # Proton BattlEye Runtime - 1007020, # Proton EasyAntiCheat Runtime - # Games allowed to be installed anytime - 3949040, # RV There Yet? -} - -STEAMAPPS_PATH = Path("~/.local/share/Steam/steamapps").expanduser() - _LIST_DISPLAY_LIMIT = 50 _MIN_CLI_ARGS = 2 -_TAMPER_CHECK_LIMIT = 3 - - -# ────────────────────────────────────────────────────────────── -# Game install management -# ────────────────────────────────────────────────────────────── - - -def _get_real_user() -> str | None: - """Get the real (non-root) user when running under sudo.""" - return os.environ.get("SUDO_USER") or os.environ.get("USER") - - -def _get_uid_gid_for_user(username: str) -> tuple[int, int]: - """Get (uid, gid) for a username.""" - try: - pw = pwd.getpwnam(username) - except KeyError: - return 1000, 1000 - else: - return pw.pw_uid, pw.pw_gid - - -def is_game_installed(app_id: int) -> bool: - """Check if a game is installed by looking for its appmanifest. - - A manifest with StateFlags != 4 (FullyInstalled) means the game is - still downloading or queued, which still counts as "install triggered". - """ - manifest = STEAMAPPS_PATH / f"appmanifest_{app_id}.acf" - return manifest.exists() - - -def _ensure_steam_running() -> None: - """Start the Steam client if it is not already running.""" - # Check if any steam process is running (main client, not just helpers). - try: - result = subprocess.run( - ["/usr/bin/pgrep", "-f", "steam.sh"], - capture_output=True, - text=True, - check=False, - ) - if result.returncode == 0: - logger.debug("Steam client already running") - return - except FileNotFoundError: - pass - - real_user = _get_real_user() - logger.info("Starting Steam client...") - - try: - if os.geteuid() == 0 and real_user and real_user != "root": - uid, _ = _get_uid_gid_for_user(real_user) - dbus_default = f"unix:path=/run/user/{uid}/bus" - dbus_addr = os.environ.get("DBUS_SESSION_BUS_ADDRESS", dbus_default) - xauth_default = f"/home/{real_user}/.Xauthority" - xauth = os.environ.get("XAUTHORITY", xauth_default) - cmd = [ - "sudo", - "-u", - real_user, - "env", - f"DISPLAY={os.environ.get('DISPLAY', ':0')}", - f"XAUTHORITY={xauth}", - f"DBUS_SESSION_BUS_ADDRESS={dbus_addr}", - "steam", - "-silent", - ] - else: - cmd = ["steam", "-silent"] - - subprocess.Popen( - cmd, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - # Give Steam time to initialize and start scanning manifests. - time.sleep(15) - except FileNotFoundError: - logger.exception("Steam executable not found") - - -def install_game(app_id: int, game_name: str, steam_id: str) -> bool: - """Install a game by writing an appmanifest that triggers Steam's download. - - Creates a minimal appmanifest with StateFlags=1026 (UpdateRequired | - UpdateStarted) in the steamapps directory. The running Steam client - detects the new manifest and automatically queues the download — no - dialog or user interaction required. - - If Steam is not running it will be started in silent mode first. - - Args: - app_id: Steam application ID. - game_name: Human-readable game name. - steam_id: Steam64 ID of the account that owns the game. - - Returns True if the manifest was written successfully. - """ - label = game_name or f"AppID={app_id}" - - if is_game_installed(app_id): - logger.info("Game already installed: %s", label) - return True - - # Build a minimal appmanifest. StateFlags 1026 = UpdateRequired (2) + - # UpdateStarted (1024), which tells Steam "this app needs downloading". - manifest_content = ( - '"AppState"\n' - "{\n" - f'\t"appid"\t\t"{app_id}"\n' - '\t"universe"\t\t"1"\n' - f'\t"name"\t\t"{game_name}"\n' - '\t"StateFlags"\t\t"1026"\n' - f'\t"installdir"\t\t"{game_name}"\n' - '\t"LastUpdated"\t\t"0"\n' - '\t"LastPlayed"\t\t"0"\n' - '\t"SizeOnDisk"\t\t"0"\n' - '\t"StagingSize"\t\t"0"\n' - '\t"buildid"\t\t"0"\n' - f'\t"LastOwner"\t\t"{steam_id}"\n' - '\t"UpdateResult"\t\t"0"\n' - '\t"BytesToDownload"\t\t"0"\n' - '\t"BytesDownloaded"\t\t"0"\n' - '\t"BytesToStage"\t\t"0"\n' - '\t"BytesStaged"\t\t"0"\n' - '\t"TargetBuildID"\t\t"0"\n' - '\t"AutoUpdateBehavior"\t\t"0"\n' - '\t"AllowOtherDownloadsWhileRunning"\t\t"0"\n' - '\t"ScheduledAutoUpdate"\t\t"0"\n' - '\t"InstalledDepots"\n' - "\t{\n" - "\t}\n" - '\t"UserConfig"\n' - "\t{\n" - "\t}\n" - '\t"MountedConfig"\n' - "\t{\n" - "\t}\n" - "}\n" - ) - - manifest_path = STEAMAPPS_PATH / f"appmanifest_{app_id}.acf" - - try: - with manifest_path.open("w", encoding="utf-8") as fh: - fh.write(manifest_content) - - # Fix ownership so the Steam client (running as the real user) can - # read and update the manifest. - real_user = _get_real_user() - if os.geteuid() == 0 and real_user and real_user != "root": - uid, gid = _get_uid_gid_for_user(real_user) - os.chown(manifest_path, uid, gid) - - logger.info("Created appmanifest for %s — Steam will auto-download", label) - except OSError: - logger.exception("Failed to create appmanifest for %s", label) - return False - - # Make sure Steam is running so it picks up the manifest. - _ensure_steam_running() - - return True - - -# ────────────────────────────────────────────────────────────── -# Game uninstall management -# ────────────────────────────────────────────────────────────── - - -def get_installed_games() -> list[tuple[int, str]]: - """Parse appmanifest files to find installed games. - - Returns: list of (app_id, game_name) tuples. - """ - installed: list[tuple[int, str]] = [] - - for manifest_file in STEAMAPPS_PATH.glob("appmanifest_*.acf"): - with contextlib.suppress(OSError): - content = manifest_file.read_text(encoding="utf-8") - app_id_match = re.search(r'"appid"\s+"(\d+)"', content) - name_match = re.search(r'"name"\s+"([^"]+)"', content) - if app_id_match: - app_id = int(app_id_match.group(1)) - name = name_match.group(1) if name_match else f"Unknown ({app_id})" - installed.append((app_id, name)) - - installed.sort(key=lambda x: x[1].lower()) - return installed - - -def _read_install_dir(manifest: Path) -> Path | None: - """Read installdir from a game's appmanifest file.""" - if not manifest.exists(): - return None - try: - content = manifest.read_text(encoding="utf-8") - match = re.search(r'"installdir"\s+"([^"]+)"', content) - if match: - return STEAMAPPS_PATH / "common" / match.group(1) - except OSError: - pass - return None - - -def _remove_manifest(manifest: Path, game_name: str, app_id: int) -> bool: - """Remove a game manifest file. - - Args: - manifest: Path to the appmanifest file. - game_name: Human-readable game name for logging. - app_id: Steam application ID. - """ - try: - if manifest.exists(): - manifest.unlink() - logger.info( - "Removed manifest for %s (AppID=%d)", game_name or app_id, app_id - ) - except OSError: - logger.exception("Failed to remove manifest for AppID=%d", app_id) - return False - return True - - -def _remove_game_dirs(install_dir: Path | None, app_id: int) -> bool: - """Remove game installation directory and cache directories. - - Args: - install_dir: Path to the game's install directory, or None. - app_id: Steam application ID. - """ - success = True - if install_dir and install_dir.is_dir(): - try: - shutil.rmtree(install_dir) - logger.info("Removed game files: %s", install_dir) - except OSError: - logger.exception("Failed to remove game dir %s", install_dir) - success = False - - for subdir in ("shadercache", "compatdata"): - cache_path = STEAMAPPS_PATH / subdir / str(app_id) - if cache_path.is_dir(): - with contextlib.suppress(OSError): - shutil.rmtree(cache_path) - logger.debug("Removed %s/%d", subdir, app_id) - - return success - - -def uninstall_game(app_id: int, game_name: str = "") -> bool: - """Uninstall a single game by removing its manifest and game files. - - Uses direct file removal instead of `steam://uninstall` URI to avoid - GUI popups and to work when Steam is not running. - """ - manifest = STEAMAPPS_PATH / f"appmanifest_{app_id}.acf" - install_dir = _read_install_dir(manifest) - success = _remove_manifest(manifest, game_name, app_id) - if not _remove_game_dirs(install_dir, app_id): - success = False - return success - - -def uninstall_other_games(allowed_app_id: int | None) -> int: - """Uninstall all installed games except the assigned one and protected IDs. - - Returns: number of games uninstalled. - """ - installed = get_installed_games() - count = 0 - - for app_id, name in installed: - if app_id == allowed_app_id: - logger.info("KEEPING assigned game: %s (AppID=%d)", name, app_id) - continue - if app_id in PROTECTED_APP_IDS: - logger.debug("Skipping protected: %s (AppID=%d)", name, app_id) - continue - - logger.info("UNINSTALLING: %s (AppID=%d)", name, app_id) - if uninstall_game(app_id, name): - count += 1 - - return count - - -# ────────────────────────────────────────────────────────────── -# Scanning & game selection -# ────────────────────────────────────────────────────────────── - - -def do_scan(config: Config, state: State) -> list[GameInfo]: - """Full library scan: Steam API + HLTB times.""" - client = SteamAPIClient(config.steam_api_key, config.steam_id) - - start = time.time() - done_count = 0 - - def progress(current: int, total: int) -> None: - nonlocal done_count - done_count = current - if current % 50 == 0 or current == total: - _echo(f"\r Scanning achievements: {current}/{total}", end="", flush=True) - - _echo("Scanning Steam library...") - games = client.build_game_list( - skip_app_ids=config.skip_app_ids, - progress_callback=progress, - ) - elapsed = time.time() - start - _echo(f"\n Scanned {len(games)} games with achievements in {elapsed:.1f}s") - - # Fetch HLTB times (cached). - incomplete = [(g.app_id, g.name) for g in games if not g.is_complete] - if incomplete: - _echo(f"Fetching HLTB completion times for {len(incomplete)} games...") - - def hltb_progress(done: int, total: int, found: int, name: str) -> None: - pct = done * 100 // total - bar_w = 30 - filled = bar_w * done // total - bar = "█" * filled + "░" * (bar_w - filled) - _echo( - f"\r HLTB [{bar}] {done}/{total} ({pct}%) " - f"| {found} found | {name[:30]:<30s}", - end="", - flush=True, - ) - - hltb_cache = fetch_hltb_times_cached(incomplete, progress_cb=hltb_progress) - _echo("") # newline after progress bar - for g in games: - hours = hltb_cache.get(g.app_id, -1) - g.completionist_hours = hours - found = sum(1 for h in hltb_cache.values() if h > 0) - _echo(f" HLTB data: {found} games have completion estimates") - - # Save snapshot. - save_snapshot([g.to_snapshot() for g in games]) - - complete = [g for g in games if g.is_complete] - incomplete_games = [g for g in games if not g.is_complete] - _echo(f"\nResults: {len(complete)} complete, {len(incomplete_games)} incomplete") - - # Auto-pick a game if none assigned. - if state.current_app_id is None: - pick_next_game(games, state, config) - - return games - - -# How many candidates to check per ProtonDB batch. -_PROTONDB_BATCH_SIZE = 20 - - -def _pick_playable_candidate( - candidates: list[GameInfo], -) -> GameInfo | None: - """Return the first candidate with an acceptable ProtonDB rating. - - Checks candidates in batches (sorted by HLTB hours, shortest first). - Games rated silver-or-worse, or gold-trending-down, are skipped. - """ - offset = 0 - while offset < len(candidates): - batch = candidates[offset : offset + _PROTONDB_BATCH_SIZE] - app_ids = [g.app_id for g in batch] - ratings = fetch_protondb_ratings(app_ids) - - for game in batch: - rating = ratings.get(game.app_id, ProtonDBRating(app_id=game.app_id)) - if rating.is_playable: - if offset > 0 or game is not batch[0]: - _echo( - f" Skipped {offset + batch.index(game)} game(s) " - f"with poor Linux compatibility" - ) - return game - logger.info( - "Skipping %s (AppID=%d): ProtonDB %s (trending %s)", - game.name, - game.app_id, - rating.tier, - rating.trending_tier, - ) - - offset += _PROTONDB_BATCH_SIZE - - return None - - -def pick_next_game(games: list[GameInfo], state: State, config: Config) -> None: - """Select the next game: shortest completionist time first. - - Games with silver-or-worse ProtonDB ratings (or gold trending - downward) are automatically skipped as unplayable on Linux. - """ - skip = set(config.skip_app_ids) | set(state.finished_app_ids) - candidates = [g for g in games if not g.is_complete and g.app_id not in skip] - - if not candidates: - _echo("\nCongratulations! All games are complete!") - state.current_app_id = None - state.current_game_name = "" - state.save() - return - - # Sort: games with known HLTB time first (shortest), then unknown. - def sort_key(g: GameInfo) -> tuple[int, float]: - if g.completionist_hours > 0: - return (0, g.completionist_hours) - return (1, g.name.lower().encode().hex().__hash__()) - - candidates.sort(key=sort_key) - - # Filter out Linux-incompatible games via ProtonDB. - chosen = _pick_playable_candidate(candidates) - - if chosen is None: - _echo("\nNo playable games left (all have poor ProtonDB ratings)!") - state.current_app_id = None - state.current_game_name = "" - state.save() - return - - state.current_app_id = chosen.app_id - state.current_game_name = chosen.name - state.save() - - hours_str = "" - if chosen.completionist_hours > 0: - hours_str = f" (~{chosen.completionist_hours:.1f}h to 100%)" - _echo(f"\n>>> ASSIGNED: {chosen.name} (AppID={chosen.app_id}){hours_str}") - _echo( - f" Progress: {chosen.unlocked_achievements}/{chosen.total_achievements}" - f" ({chosen.completion_pct:.1f}%)" - ) - - # Uninstall all other games first, then auto-install the assigned one. - if config.uninstall_other_games: - count = uninstall_other_games(chosen.app_id) - if count: - _echo(f"\n Uninstalled {count} non-assigned games") - - if not is_game_installed(chosen.app_id): - _echo(f"\n Auto-installing {chosen.name}...") - install_game(chosen.app_id, chosen.name, config.steam_id) - - -# ────────────────────────────────────────────────────────────── -# Checking & tampering detection -# ────────────────────────────────────────────────────────────── - - -def do_check(config: Config, state: State) -> None: - """Check assigned game completion status; detect tampering.""" - if state.current_app_id is None: - _echo("No game currently assigned. Run 'scan' first.") - return - - client = SteamAPIClient(config.steam_api_key, config.steam_id) - _echo(f"Checking {state.current_game_name} (AppID={state.current_app_id})...") - - game = client.refresh_single_game(state.current_app_id, state.current_game_name) - if game is None: - _echo(" Could not fetch achievement data.") - return - - _echo( - f" Progress: {game.unlocked_achievements}/{game.total_achievements}" - f" ({game.completion_pct:.1f}%)" - ) - - if game.is_complete: - _echo(f"\n COMPLETED: {state.current_game_name}!") - state.finished_app_ids.append(state.current_app_id) - send_notification( - "Game Complete!", - f"You finished {state.current_game_name}! Picking next game...", - ) - - # Load snapshot and pick next. - snapshot_data = load_snapshot() - if snapshot_data: - games = [GameInfo.from_snapshot(d) for d in snapshot_data] - pick_next_game(games, state, config) - else: - state.current_app_id = None - state.current_game_name = "" - state.save() - _echo(" Run 'scan' to pick the next game.") - else: - remaining = game.total_achievements - game.unlocked_achievements - _echo(f" {remaining} achievements remaining. Keep going!") - - # Tampering detection on snapshot. - detect_tampering(config, state) - - -def _check_game_tampering( - client: SteamAPIClient, - entry: dict[str, Any], - state: State, -) -> tuple[str, int, int] | None: - """Check if a single game has unexpected achievement progress. - - Args: - client: Steam API client. - entry: Snapshot entry for the game. - state: Current enforcer state. - - Returns: - Tuple of (name, app_id, diff) if tampering detected, else None. - """ - app_id = entry["app_id"] - if app_id == state.current_app_id: - return None - if entry["unlocked_achievements"] >= entry["total_achievements"]: - return None - if entry.get("playtime_minutes", 0) <= 0: - return None - game = client.refresh_single_game( - app_id, entry["name"], entry.get("playtime_minutes", 0) - ) - if game and game.unlocked_achievements > entry["unlocked_achievements"]: - diff = game.unlocked_achievements - entry["unlocked_achievements"] - return (entry["name"], app_id, diff) - return None - - -def detect_tampering(config: Config, state: State) -> None: - """Check if achievements were unlocked on non-assigned games.""" - old_snapshot = load_snapshot() - if old_snapshot is None: - return - - client = SteamAPIClient(config.steam_api_key, config.steam_id) - - # Quick check: only re-fetch a few random non-assigned games. - suspicious: list[tuple[str, int, int]] = [] - for entry in old_snapshot: - result = _check_game_tampering(client, entry, state) - if result: - suspicious.append(result) - if len(suspicious) >= _TAMPER_CHECK_LIMIT: - break - - if suspicious: - _echo("\n TAMPERING DETECTED:") - for name, app_id, diff in suspicious: - _echo(f" {name} (AppID={app_id}): +{diff} new achievements!") - send_notification( - "Tampering Detected!", - f"Achievements unlocked on {len(suspicious)} non-assigned games!", - ) - - -# ────────────────────────────────────────────────────────────── -# Enforce mode (daemon loop) -# ────────────────────────────────────────────────────────────── - -# How often the enforce loop runs (seconds). -ENFORCE_INTERVAL = 3 - - -def _guard_installed_games(allowed_app_id: int | None) -> int: - """Remove any unauthorized game manifests + files. Runs every loop. - - Returns number of games removed this pass. - """ - installed = get_installed_games() - count = 0 - for app_id, name in installed: - if app_id == allowed_app_id: - continue - if app_id in PROTECTED_APP_IDS: - continue - - logger.warning( - "Unauthorized game detected — removing: %s (AppID=%d)", name, app_id - ) - if uninstall_game(app_id, name): - count += 1 - send_notification( - "Game Removed!", - f"Uninstalled {name} (AppID={app_id}). " - f"Only the assigned game is allowed.", - ) - return count - - -def _enforce_setup(config: Config, state: State) -> None: - """Perform initial setup for enforcement mode. - - Args: - config: Enforcer configuration. - state: Current enforcer state. - """ - # Initial store block. - if config.block_store: - if block_store(): - _echo(" Steam store: BLOCKED") - else: - _echo(" Steam store: FAILED (need sudo?)") - - # Initial cleanup. - if config.uninstall_other_games: - _echo(" Uninstalling non-assigned games...") - count = uninstall_other_games(state.current_app_id) - _echo(f" Uninstalled {count} games") - - # Auto-install the assigned game. - _enforce_auto_install(config, state) - - # Hide all other games in the Steam library. - _enforce_hide_games(config, state) - - -def _enforce_auto_install(config: Config, state: State) -> None: - """Auto-install the assigned game if not already installed. - - Args: - config: Enforcer configuration. - state: Current enforcer state. - """ - app_id = state.current_app_id - if app_id is None: - return - if not is_game_installed(app_id): - _echo(f" Auto-installing {state.current_game_name}...") - if install_game(app_id, state.current_game_name, config.steam_id): - send_notification( - "Game Installing", - f"{state.current_game_name} is being downloaded.", - ) - else: - _echo(" Could not auto-install. Install manually from Steam.") - else: - _echo(f" Assigned game already installed: {state.current_game_name}") - - -def _enforce_hide_games(config: Config, state: State) -> None: - """Hide non-assigned games in the Steam library. - - Args: - config: Enforcer configuration. - state: Current enforcer state. - """ - owned_ids = _get_all_owned_app_ids(config) - if owned_ids: - hidden = hide_other_games(owned_ids, state.current_app_id) - if hidden > 0: - _echo(f" Library: hid {hidden} games (only assigned game visible)") - else: - _echo(" Library: games already hidden") - else: - _echo(" Library hiding: skipped (no owned game list — run 'scan' first)") - - -def _enforce_loop_iteration(config: Config, state: State) -> None: - """Perform one iteration of the enforcement loop. - - Args: - config: Enforcer configuration. - state: Current enforcer state. - """ - # A) Kill unauthorized game processes. - if config.kill_unauthorized_games: - violations = enforce_allowed_game( - state.current_app_id, - kill_unauthorized=True, - ) - for pid, app_id in violations: - _echo(f" Killed unauthorized game: AppID={app_id} (PID={pid})") - send_notification( - "Game Blocked!", - f"Killed unauthorized game (AppID={app_id}). " - f"Focus on {state.current_game_name}!", - ) - - # B) Remove any newly-installed unauthorized games. - if config.uninstall_other_games: - removed = _guard_installed_games(state.current_app_id) - if removed > 0: - _echo(f" Guard removed {removed} unauthorized game(s)") - - # C) Re-install assigned game if it was somehow removed. - app_id = state.current_app_id - if app_id is not None and not is_game_installed(app_id): - logger.info( - "Assigned game disappeared — re-installing %s", - state.current_game_name, - ) - install_game( - app_id, - state.current_game_name, - config.steam_id, - ) - - -def do_enforce(config: Config, state: State) -> None: - """Run the enforcer: block store, uninstall other games, kill processes. - - This is a persistent loop that continuously: - 1. Keeps the Steam store blocked. - 2. Removes any newly-installed unauthorized games. - 3. Auto-installs the assigned game if missing. - 4. Kills any running unauthorized game processes. - """ - if state.current_app_id is None: - _echo("No game assigned. Run 'scan' first.") - return - - _echo(f"Enforcing: {state.current_game_name} (AppID={state.current_app_id})") - _enforce_setup(config, state) - - _echo(f" Enforce loop: ACTIVE (every {ENFORCE_INTERVAL}s)") - _echo(" Guarding: processes + installs + store") - _echo(" Press Ctrl+C to stop.\n") - try: - while True: - _enforce_loop_iteration(config, state) - time.sleep(ENFORCE_INTERVAL) - except KeyboardInterrupt: - _echo("\nEnforcer stopped.") # ────────────────────────────────────────────────────────────── @@ -934,7 +161,7 @@ def cmd_reset(config: Config, state: State) -> None: # Unhide all games in the library. try: - owned = _get_all_owned_app_ids(config) + owned = get_all_owned_app_ids(config) if owned: count = unhide_all_games(owned) if count: @@ -1015,29 +242,13 @@ def cmd_install(config: Config, state: State) -> None: _echo("Failed to create install manifest.") -def _get_all_owned_app_ids(config: Config) -> list[int]: - """Get all owned game app IDs from the snapshot or Steam API.""" - snapshot = load_snapshot() - if snapshot: - return [d["app_id"] for d in snapshot] - - # Fall back to a quick API call. - try: - client = SteamAPIClient(config.steam_api_key, config.steam_id) - owned = client.get_owned_games() - return [g["appid"] for g in owned] - except (OSError, RuntimeError, ValueError): - logger.warning("Could not fetch owned game list for hiding.") - return [] - - def cmd_hide(config: Config, state: State) -> None: """Hide all non-assigned games in the Steam library.""" if state.current_app_id is None: _echo("No game assigned. Run 'scan' first.") return - owned_ids = _get_all_owned_app_ids(config) + owned_ids = get_all_owned_app_ids(config) if not owned_ids: _echo("No owned game list available. Run 'scan' first.") return @@ -1052,7 +263,7 @@ def cmd_hide(config: Config, state: State) -> None: def cmd_unhide(config: Config, _state: State) -> None: """Unhide all games in the Steam library.""" - owned_ids = _get_all_owned_app_ids(config) + owned_ids = get_all_owned_app_ids(config) if not owned_ids: _echo("No owned game list available. Run 'scan' first.") return @@ -1130,7 +341,7 @@ def _finalize_completion( _echo(" No more games to assign!") return - owned_ids = _get_all_owned_app_ids(config) + owned_ids = get_all_owned_app_ids(config) if owned_ids: hidden = hide_other_games(owned_ids, state.current_app_id) if hidden > 0: @@ -1143,6 +354,37 @@ def _finalize_completion( _echo(f"\nAll done! Go play {state.current_game_name}!") +def _enforce_on_done(config: Config, state: State) -> None: + """Run a single enforcement pass during the 'done' command. + + Kills unauthorized game processes, uninstalls unauthorized games, + and ensures the assigned game is installed. + """ + if state.current_app_id is None: + return + + if config.kill_unauthorized_games: + violations = enforce_allowed_game( + state.current_app_id, + kill_unauthorized=True, + ) + for pid, app_id in violations: + _echo(f" Killed unauthorized game: AppID={app_id} (PID={pid})") + + if config.uninstall_other_games: + count = uninstall_other_games(state.current_app_id) + if count: + _echo(f" Uninstalled {count} unauthorized game(s)") + + if not is_game_installed(state.current_app_id): + _echo(f" Re-installing {state.current_game_name}...") + install_game( + state.current_app_id, + state.current_game_name, + config.steam_id, + ) + + def cmd_done(config: Config, state: State) -> None: """Check completion, pick next game, uninstall & hide. @@ -1186,6 +428,7 @@ def cmd_done(config: Config, state: State) -> None: if not game.is_complete: remaining = game.total_achievements - game.unlocked_achievements _echo(f"\n NOT COMPLETE: {remaining} achievements remaining. Keep going!") + _enforce_on_done(config, state) return _finalize_completion(config, state, game_name, app_id) diff --git a/python_pkg/steam_backlog_enforcer/scanning.py b/python_pkg/steam_backlog_enforcer/scanning.py new file mode 100644 index 0000000..8a25b5c --- /dev/null +++ b/python_pkg/steam_backlog_enforcer/scanning.py @@ -0,0 +1,501 @@ +"""Game scanning, selection, checking, and enforcement daemon.""" + +from __future__ import annotations + +import logging +import time +from typing import Any + +from python_pkg.steam_backlog_enforcer.config import ( + Config, + State, + load_snapshot, + save_snapshot, +) +from python_pkg.steam_backlog_enforcer.enforcer import ( + enforce_allowed_game, + send_notification, +) +from python_pkg.steam_backlog_enforcer.game_install import ( + PROTECTED_APP_IDS, + _echo, + get_installed_games, + install_game, + is_game_installed, + uninstall_game, + uninstall_other_games, +) +from python_pkg.steam_backlog_enforcer.hltb import fetch_hltb_times_cached +from python_pkg.steam_backlog_enforcer.library_hider import hide_other_games +from python_pkg.steam_backlog_enforcer.protondb import ( + ProtonDBRating, + fetch_protondb_ratings, +) +from python_pkg.steam_backlog_enforcer.steam_api import GameInfo, SteamAPIClient +from python_pkg.steam_backlog_enforcer.store_blocker import block_store + +logger = logging.getLogger(__name__) + +_TAMPER_CHECK_LIMIT = 3 + + +# ────────────────────────────────────────────────────────────── +# Scanning & game selection +# ────────────────────────────────────────────────────────────── + + +def do_scan(config: Config, state: State) -> list[GameInfo]: + """Full library scan: Steam API + HLTB times.""" + client = SteamAPIClient(config.steam_api_key, config.steam_id) + + start = time.time() + done_count = 0 + + def progress(current: int, total: int) -> None: + nonlocal done_count + done_count = current + if current % 50 == 0 or current == total: + _echo(f"\r Scanning achievements: {current}/{total}", end="", flush=True) + + _echo("Scanning Steam library...") + games = client.build_game_list( + skip_app_ids=config.skip_app_ids, + progress_callback=progress, + ) + elapsed = time.time() - start + _echo(f"\n Scanned {len(games)} games with achievements in {elapsed:.1f}s") + + # Fetch HLTB times (cached). + incomplete = [(g.app_id, g.name) for g in games if not g.is_complete] + if incomplete: + _echo(f"Fetching HLTB completion times for {len(incomplete)} games...") + + def hltb_progress(done: int, total: int, found: int, name: str) -> None: + pct = done * 100 // total + bar_w = 30 + filled = bar_w * done // total + bar = "█" * filled + "░" * (bar_w - filled) + _echo( + f"\r HLTB [{bar}] {done}/{total} ({pct}%) " + f"| {found} found | {name[:30]:<30s}", + end="", + flush=True, + ) + + hltb_cache = fetch_hltb_times_cached(incomplete, progress_cb=hltb_progress) + _echo("") # newline after progress bar + for g in games: + hours = hltb_cache.get(g.app_id, -1) + g.completionist_hours = hours + found = sum(1 for h in hltb_cache.values() if h > 0) + _echo(f" HLTB data: {found} games have completion estimates") + + # Save snapshot. + save_snapshot([g.to_snapshot() for g in games]) + + complete = [g for g in games if g.is_complete] + incomplete_games = [g for g in games if not g.is_complete] + _echo(f"\nResults: {len(complete)} complete, {len(incomplete_games)} incomplete") + + # Auto-pick a game if none assigned. + if state.current_app_id is None: + pick_next_game(games, state, config) + + return games + + +# How many candidates to check per ProtonDB batch. +_PROTONDB_BATCH_SIZE = 20 + + +def _pick_playable_candidate( + candidates: list[GameInfo], +) -> GameInfo | None: + """Return the first candidate with an acceptable ProtonDB rating. + + Checks candidates in batches (sorted by HLTB hours, shortest first). + Games rated silver-or-worse, or gold-trending-down, are skipped. + """ + offset = 0 + while offset < len(candidates): + batch = candidates[offset : offset + _PROTONDB_BATCH_SIZE] + app_ids = [g.app_id for g in batch] + ratings = fetch_protondb_ratings(app_ids) + + for game in batch: + rating = ratings.get(game.app_id, ProtonDBRating(app_id=game.app_id)) + if rating.is_playable: + if offset > 0 or game is not batch[0]: + _echo( + f" Skipped {offset + batch.index(game)} game(s) " + f"with poor Linux compatibility" + ) + return game + logger.info( + "Skipping %s (AppID=%d): ProtonDB %s (trending %s)", + game.name, + game.app_id, + rating.tier, + rating.trending_tier, + ) + + offset += _PROTONDB_BATCH_SIZE + + return None + + +def pick_next_game(games: list[GameInfo], state: State, config: Config) -> None: + """Select the next game: shortest completionist time first. + + Games with silver-or-worse ProtonDB ratings (or gold trending + downward) are automatically skipped as unplayable on Linux. + """ + skip = set(config.skip_app_ids) | set(state.finished_app_ids) + candidates = [g for g in games if not g.is_complete and g.app_id not in skip] + + if not candidates: + _echo("\nCongratulations! All games are complete!") + state.current_app_id = None + state.current_game_name = "" + state.save() + return + + # Sort: games with known HLTB time first (shortest), then unknown. + def sort_key(g: GameInfo) -> tuple[int, float]: + if g.completionist_hours > 0: + return (0, g.completionist_hours) + return (1, g.name.lower().encode().hex().__hash__()) + + candidates.sort(key=sort_key) + + # Filter out Linux-incompatible games via ProtonDB. + chosen = _pick_playable_candidate(candidates) + + if chosen is None: + _echo("\nNo playable games left (all have poor ProtonDB ratings)!") + state.current_app_id = None + state.current_game_name = "" + state.save() + return + + state.current_app_id = chosen.app_id + state.current_game_name = chosen.name + state.save() + + hours_str = "" + if chosen.completionist_hours > 0: + hours_str = f" (~{chosen.completionist_hours:.1f}h to 100%)" + _echo(f"\n>>> ASSIGNED: {chosen.name} (AppID={chosen.app_id}){hours_str}") + _echo( + f" Progress: {chosen.unlocked_achievements}/{chosen.total_achievements}" + f" ({chosen.completion_pct:.1f}%)" + ) + + # Uninstall all other games first, then auto-install the assigned one. + if config.uninstall_other_games: + count = uninstall_other_games(chosen.app_id) + if count: + _echo(f"\n Uninstalled {count} non-assigned games") + + if not is_game_installed(chosen.app_id): + _echo(f"\n Auto-installing {chosen.name}...") + install_game(chosen.app_id, chosen.name, config.steam_id) + + +# ────────────────────────────────────────────────────────────── +# Checking & tampering detection +# ────────────────────────────────────────────────────────────── + + +def do_check(config: Config, state: State) -> None: + """Check assigned game completion status; detect tampering.""" + if state.current_app_id is None: + _echo("No game currently assigned. Run 'scan' first.") + return + + client = SteamAPIClient(config.steam_api_key, config.steam_id) + _echo(f"Checking {state.current_game_name} (AppID={state.current_app_id})...") + + game = client.refresh_single_game(state.current_app_id, state.current_game_name) + if game is None: + _echo(" Could not fetch achievement data.") + return + + _echo( + f" Progress: {game.unlocked_achievements}/{game.total_achievements}" + f" ({game.completion_pct:.1f}%)" + ) + + if game.is_complete: + _echo(f"\n COMPLETED: {state.current_game_name}!") + state.finished_app_ids.append(state.current_app_id) + send_notification( + "Game Complete!", + f"You finished {state.current_game_name}! Picking next game...", + ) + + # Load snapshot and pick next. + snapshot_data = load_snapshot() + if snapshot_data: + games = [GameInfo.from_snapshot(d) for d in snapshot_data] + pick_next_game(games, state, config) + else: + state.current_app_id = None + state.current_game_name = "" + state.save() + _echo(" Run 'scan' to pick the next game.") + else: + remaining = game.total_achievements - game.unlocked_achievements + _echo(f" {remaining} achievements remaining. Keep going!") + + # Tampering detection on snapshot. + detect_tampering(config, state) + + +def _check_game_tampering( + client: SteamAPIClient, + entry: dict[str, Any], + state: State, +) -> tuple[str, int, int] | None: + """Check if a single game has unexpected achievement progress. + + Args: + client: Steam API client. + entry: Snapshot entry for the game. + state: Current enforcer state. + + Returns: + Tuple of (name, app_id, diff) if tampering detected, else None. + """ + app_id = entry["app_id"] + if app_id == state.current_app_id: + return None + if entry["unlocked_achievements"] >= entry["total_achievements"]: + return None + if entry.get("playtime_minutes", 0) <= 0: + return None + game = client.refresh_single_game( + app_id, entry["name"], entry.get("playtime_minutes", 0) + ) + if game and game.unlocked_achievements > entry["unlocked_achievements"]: + diff = game.unlocked_achievements - entry["unlocked_achievements"] + return (entry["name"], app_id, diff) + return None + + +def detect_tampering(config: Config, state: State) -> None: + """Check if achievements were unlocked on non-assigned games.""" + old_snapshot = load_snapshot() + if old_snapshot is None: + return + + client = SteamAPIClient(config.steam_api_key, config.steam_id) + + # Quick check: only re-fetch a few random non-assigned games. + suspicious: list[tuple[str, int, int]] = [] + for entry in old_snapshot: + result = _check_game_tampering(client, entry, state) + if result: + suspicious.append(result) + if len(suspicious) >= _TAMPER_CHECK_LIMIT: + break + + if suspicious: + _echo("\n TAMPERING DETECTED:") + for name, app_id, diff in suspicious: + _echo(f" {name} (AppID={app_id}): +{diff} new achievements!") + send_notification( + "Tampering Detected!", + f"Achievements unlocked on {len(suspicious)} non-assigned games!", + ) + + +# ────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────── + + +def get_all_owned_app_ids(config: Config) -> list[int]: + """Get all owned game app IDs from the snapshot or Steam API.""" + snapshot = load_snapshot() + if snapshot: + return [d["app_id"] for d in snapshot] + + # Fall back to a quick API call. + try: + client = SteamAPIClient(config.steam_api_key, config.steam_id) + owned = client.get_owned_games() + return [g["appid"] for g in owned] + except (OSError, RuntimeError, ValueError): + logger.warning("Could not fetch owned game list for hiding.") + return [] + + +# ────────────────────────────────────────────────────────────── +# Enforce mode (daemon loop) +# ────────────────────────────────────────────────────────────── + +# How often the enforce loop runs (seconds). +ENFORCE_INTERVAL = 3 + + +def _guard_installed_games(allowed_app_id: int | None) -> int: + """Remove any unauthorized game manifests + files. Runs every loop. + + Returns number of games removed this pass. + """ + installed = get_installed_games() + count = 0 + for app_id, name in installed: + if app_id == allowed_app_id: + continue + if app_id in PROTECTED_APP_IDS: + continue + + logger.warning( + "Unauthorized game detected — removing: %s (AppID=%d)", name, app_id + ) + if uninstall_game(app_id, name): + count += 1 + send_notification( + "Game Removed!", + f"Uninstalled {name} (AppID={app_id}). " + f"Only the assigned game is allowed.", + ) + return count + + +def _enforce_setup(config: Config, state: State) -> None: + """Perform initial setup for enforcement mode. + + Args: + config: Enforcer configuration. + state: Current enforcer state. + """ + # Initial store block. + if config.block_store: + if block_store(): + _echo(" Steam store: BLOCKED") + else: + _echo(" Steam store: FAILED (need sudo?)") + + # Initial cleanup. + if config.uninstall_other_games: + _echo(" Uninstalling non-assigned games...") + count = uninstall_other_games(state.current_app_id) + _echo(f" Uninstalled {count} games") + + # Auto-install the assigned game. + _enforce_auto_install(config, state) + + # Hide all other games in the Steam library. + _enforce_hide_games(config, state) + + +def _enforce_auto_install(config: Config, state: State) -> None: + """Auto-install the assigned game if not already installed. + + Args: + config: Enforcer configuration. + state: Current enforcer state. + """ + app_id = state.current_app_id + if app_id is None: + return + if not is_game_installed(app_id): + _echo(f" Auto-installing {state.current_game_name}...") + if install_game(app_id, state.current_game_name, config.steam_id): + send_notification( + "Game Installing", + f"{state.current_game_name} is being downloaded.", + ) + else: + _echo(" Could not auto-install. Install manually from Steam.") + else: + _echo(f" Assigned game already installed: {state.current_game_name}") + + +def _enforce_hide_games(config: Config, state: State) -> None: + """Hide non-assigned games in the Steam library. + + Args: + config: Enforcer configuration. + state: Current enforcer state. + """ + owned_ids = get_all_owned_app_ids(config) + if owned_ids: + hidden = hide_other_games(owned_ids, state.current_app_id) + if hidden > 0: + _echo(f" Library: hid {hidden} games (only assigned game visible)") + else: + _echo(" Library: games already hidden") + else: + _echo(" Library hiding: skipped (no owned game list — run 'scan' first)") + + +def _enforce_loop_iteration(config: Config, state: State) -> None: + """Perform one iteration of the enforcement loop. + + Args: + config: Enforcer configuration. + state: Current enforcer state. + """ + # A) Kill unauthorized game processes. + if config.kill_unauthorized_games: + violations = enforce_allowed_game( + state.current_app_id, + kill_unauthorized=True, + ) + for pid, app_id in violations: + _echo(f" Killed unauthorized game: AppID={app_id} (PID={pid})") + send_notification( + "Game Blocked!", + f"Killed unauthorized game (AppID={app_id}). " + f"Focus on {state.current_game_name}!", + ) + + # B) Remove any newly-installed unauthorized games. + if config.uninstall_other_games: + removed = _guard_installed_games(state.current_app_id) + if removed > 0: + _echo(f" Guard removed {removed} unauthorized game(s)") + + # C) Re-install assigned game if it was somehow removed. + app_id = state.current_app_id + if app_id is not None and not is_game_installed(app_id): + logger.info( + "Assigned game disappeared — re-installing %s", + state.current_game_name, + ) + install_game( + app_id, + state.current_game_name, + config.steam_id, + ) + + +def do_enforce(config: Config, state: State) -> None: + """Run the enforcer: block store, uninstall other games, kill processes. + + This is a persistent loop that continuously: + 1. Keeps the Steam store blocked. + 2. Removes any newly-installed unauthorized games. + 3. Auto-installs the assigned game if missing. + 4. Kills any running unauthorized game processes. + """ + if state.current_app_id is None: + _echo("No game assigned. Run 'scan' first.") + return + + _echo(f"Enforcing: {state.current_game_name} (AppID={state.current_app_id})") + _enforce_setup(config, state) + + _echo(f" Enforce loop: ACTIVE (every {ENFORCE_INTERVAL}s)") + _echo(" Guarding: processes + installs + store") + _echo(" Press Ctrl+C to stop.\n") + try: + while True: + _enforce_loop_iteration(config, state) + time.sleep(ENFORCE_INTERVAL) + except KeyboardInterrupt: + _echo("\nEnforcer stopped.") diff --git a/python_pkg/steam_backlog_enforcer/steam-backlog-enforcer.service b/python_pkg/steam_backlog_enforcer/steam-backlog-enforcer.service index 9483de3..babf1d7 100644 --- a/python_pkg/steam_backlog_enforcer/steam-backlog-enforcer.service +++ b/python_pkg/steam_backlog_enforcer/steam-backlog-enforcer.service @@ -1,6 +1,6 @@ [Unit] Description=Steam Backlog Enforcer -After=network-online.target graphical.target +After=network-online.target Wants=network-online.target [Service] diff --git a/tests/test_file_length.py b/tests/test_file_length.py new file mode 100644 index 0000000..716d53f --- /dev/null +++ b/tests/test_file_length.py @@ -0,0 +1,47 @@ +"""Test that all Python source files are at most 500 lines long.""" + +from __future__ import annotations + +from pathlib import Path + +MAX_LINES = 500 + +# Directories to skip (vendored / generated / virtual-envs) +_SKIP_DIRS = frozenset( + { + ".venv", + "__pycache__", + "build", + "dist", + ".eggs", + "node_modules", + "sonic_pi", + ".git", + } +) + +_ROOT = Path(__file__).resolve().parents[1] + + +def _python_files() -> list[Path]: + """Collect every *.py file under the repo root, skipping vendored dirs.""" + files: list[Path] = [] + for path in _ROOT.rglob("*.py"): + if any(part in _SKIP_DIRS for part in path.parts): + continue + files.append(path) + return sorted(files) + + +def test_all_python_files_are_at_most_500_lines() -> None: + """Every Python source file must be at most 500 lines.""" + violations: list[str] = [] + for path in _python_files(): + line_count = len(path.read_text(encoding="utf-8").splitlines()) + if line_count > MAX_LINES: + rel = path.relative_to(_ROOT) + violations.append(f" {rel}: {line_count} lines") + + assert not violations, ( + f"The following files exceed {MAX_LINES} lines:\n" + "\n".join(violations) + )