WIP: Enforce 500-line limit - split batch 1

Split 16+ files. 27 files still need splitting. See session notes.
This commit is contained in:
Krzysztof kuhy Rudnicki 2026-03-16 22:46:48 +01:00
parent e51c12dd8e
commit c985160d17
106 changed files with 28081 additions and 25858 deletions

View File

@ -64,7 +64,7 @@ unfixable = []
"SLF001", # Allow private member access in tests
]
# Files using urlopen with validated URL schemes
"python_pkg/geo_data.py" = ["S310"]
"python_pkg/geo_data/_common.py" = ["S310"]
"python_pkg/steam_backlog_enforcer/library_hider.py" = ["S310"]
"poker_modifier_app/poker_modifier_app.py" = [
"FBT003", # Boolean positional values in tkinter API calls
@ -76,9 +76,12 @@ unfixable = []
"FBT003", # Boolean positional values in tkinter API calls
]
# Brother printer - optional usb.core/usb.util imports
"python_pkg/brother_printer/check_brother_printer.py" = [
"python_pkg/brother_printer/cups_service.py" = [
"PLC0415", # Late imports for optional pyusb dependency
]
"python_pkg/brother_printer/usb_query.py" = [
"PLC0415", # Late import of cups_service fallback
]
# Music generator - CLI script with intentional patterns
"python_pkg/music_gen/music_generator.py" = [
"T201", # print() is intentional for CLI feedback
@ -182,6 +185,8 @@ disable = []
[tool.pylint.design]
# A class with just run() as public API is valid for games/apps
min-public-methods = 1
# Enforce maximum file length of 500 lines
max-module-lines = 500
[tool.pylint.spelling]
# No spelling dictionary to avoid false positives

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,173 @@
"""Constants, status codes, and lookup tables for Brother printer checking."""
from __future__ import annotations
import sys
# ── Colors ───────────────────────────────────────────────────────────
RED = "\033[0;31m"
YELLOW = "\033[1;33m"
GREEN = "\033[0;32m"
CYAN = "\033[0;36m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
# ── SNMP supply level sentinel values ────────────────────────────────────────
SNMP_LEVEL_OK = -3
SNMP_LEVEL_LOW = -2
SUPPLY_LOW_PCT = 10
SUPPLY_WARN_PCT = 25
PROGRESS_BAR_WIDTH = 25
# Brother HL-1110 consumable page ratings
TONER_RATED_PAGES = 1000
DRUM_RATED_PAGES = 10000
CUPS_PAGE_LOG_PATH = "/var/log/cups/page_log"
CONSUMABLE_STATE_DIR = ".config/brother_printer"
MIN_LPSTAT_JOB_PARTS = 4
BROTHER_USB_VENDOR_ID = 0x04F9
def _out(text: str = "") -> None:
"""Write a line to stdout."""
sys.stdout.write(text + "\n")
def _prompt(text: str) -> str:
"""Read user input with a prompt."""
sys.stdout.write(text)
sys.stdout.flush()
return sys.stdin.readline().strip()
# ── Brother PJL status codes ────────────────────────────────────────
# Documented in Brother PJL Technical Reference.
# Format: code -> (severity, short_text, action)
# Severities: ok, info, warn, critical
BROTHER_STATUS_CODES: dict[int, tuple[str, str, str]] = {
10001: ("ok", "Ready", ""),
10002: ("ok", "Sleep", ""),
10003: ("info", "Self-test / Calibrating", ""),
10004: ("ok", "Warming up", ""),
10005: ("ok", "Cooling down", ""),
10006: ("info", "Processing", ""),
10007: ("info", "Printing", ""),
10014: ("ok", "Cancelling", ""),
10023: ("info", "Waiting", ""),
# Toner
30010: (
"warn",
"Toner Low",
"Order replacement toner cartridge (TN-1050/TN-1030 compatible).",
),
30038: (
"warn",
"Toner Low",
"Order replacement toner cartridge (TN-1050/TN-1030 compatible).",
),
40038: (
"warn",
"Toner Low",
"Order replacement toner cartridge (TN-1050/TN-1030 compatible).",
),
40309: (
"critical",
"Replace Toner",
"The toner cartridge needs immediate replacement"
" (TN-1050/TN-1030 compatible).",
),
40310: (
"critical",
"Toner End",
"The toner cartridge is empty." " Replace now (TN-1050/TN-1030 compatible).",
),
# Drum
30201: (
"warn",
"Drum End Soon",
"The drum unit is nearing end of life."
" Order replacement (DR-1050 compatible).",
),
40201: (
"warn",
"Drum End Soon",
"The drum unit is nearing end of life."
" Order replacement (DR-1050 compatible).",
),
40019: (
"critical",
"Replace Drum",
"The drum unit must be replaced (DR-1050 compatible).",
),
40020: (
"critical",
"Drum Stop",
"The drum unit must be replaced immediately (DR-1050 compatible).",
),
# Paper / feed
40000: ("critical", "Paper Jam", "Clear the paper jam and close all covers."),
40300: (
"critical",
"No Paper / Tray Open",
"Load paper or close the paper tray.",
),
40302: ("critical", "No Paper", "Load paper into the paper tray."),
40016: ("warn", "Paper Feed Error", "Check paper tray and re-seat paper."),
# Cover
41000: ("critical", "Cover Open", "Close the top cover of the printer."),
41001: ("critical", "Cover Open", "Close the front cover of the printer."),
# Others
35078: ("info", "Manual Feed", "Load paper in the manual feed slot."),
42000: (
"critical",
"Machine Error",
"Power-cycle the printer. If error persists, contact service.",
),
}
# ── CUPS status code mappings ────────────────────────────────────────
_CUPS_REASONS_TO_STATUS: dict[str, int] = {
"paused": 10023,
"moving-to-paused": 10023,
"toner-low": 30010,
"toner-empty": 40310,
"marker-supply-low": 30010,
"marker-supply-empty": 40310,
"media-empty": 40302,
"media-needed": 40302,
"media-jam": 40000,
"cover-open": 41000,
"door-open": 41000,
"input-tray-missing": 40300,
}
_CUPS_STATE_TO_STATUS: dict[str, int] = {
"idle": 10001,
"processing": 10007,
"stopped": 10023,
}
_ERROR_REASON_MAP: tuple[tuple[tuple[str, ...], str, str], ...] = (
(("media-jam",), "40000", "Paper Jam"),
(("cover-open", "door-open"), "41000", "Cover Open"),
(("toner-empty",), "40310", "Toner End"),
(("toner-low",), "30010", "Toner Low"),
)
def get_status_info(code: str) -> tuple[str, str, str]:
"""Look up a PJL status code. Returns (severity, text, action)."""
try:
return BROTHER_STATUS_CODES[int(code)]
except (KeyError, ValueError):
return (
"info",
f"Unknown status (code {code})",
"Check printer display for details.",
)

View File

@ -0,0 +1,459 @@
"""CUPS queue inspection, display, and interactive fix functions."""
from __future__ import annotations
from pathlib import Path
import re
import shutil
import subprocess
import sys
import time
from typing import TYPE_CHECKING
from python_pkg.brother_printer.constants import (
BOLD,
CYAN,
DIM,
GREEN,
MIN_LPSTAT_JOB_PARTS,
RED,
RESET,
YELLOW,
_out,
_prompt,
)
from python_pkg.brother_printer.cups_service import find_cups_printer_name
from python_pkg.brother_printer.data_classes import CUPSJob, CUPSQueueStatus
if TYPE_CHECKING:
from collections.abc import Callable
# ── Queue inspection ─────────────────────────────────────────────────
def _parse_lpstat_printer_line(line: str) -> tuple[bool, str]:
"""Parse an lpstat -p line. Returns (enabled, reason)."""
enabled = "disabled" not in line.lower()
reason = ""
match = re.search(r"\d{4}\s+-\s*(.+)", line)
if match:
reason = match.group(1).strip()
return enabled, reason
def _parse_lpstat_jobs(output: str, printer_name: str) -> list[CUPSJob]:
"""Parse lpstat -o output into CUPSJob list."""
jobs: list[CUPSJob] = []
for line in output.splitlines():
if not line.startswith(printer_name):
continue
parts = line.split()
if len(parts) >= MIN_LPSTAT_JOB_PARTS:
job_id = parts[0]
user = parts[1]
size = parts[2]
date = " ".join(parts[3:])
jobs.append(CUPSJob(job_id=job_id, user=user, size=size, date=date))
return jobs
def get_cups_queue_status() -> CUPSQueueStatus:
"""Check if the CUPS queue is disabled and list pending jobs."""
printer_name = find_cups_printer_name()
if not printer_name:
return CUPSQueueStatus()
result = CUPSQueueStatus(printer_name=printer_name)
lpstat_path = shutil.which("lpstat")
if not lpstat_path:
return result
try:
r = subprocess.run(
[lpstat_path, "-p", printer_name],
capture_output=True,
text=True,
timeout=5,
check=False,
)
for line in r.stdout.splitlines():
if "printer" in line.lower() and printer_name in line:
result.enabled, result.reason = _parse_lpstat_printer_line(line)
break
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
pass
try:
r = subprocess.run(
[lpstat_path, "-o", printer_name],
capture_output=True,
text=True,
timeout=5,
check=False,
)
result.jobs = _parse_lpstat_jobs(r.stdout, printer_name)
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
pass
has_errors, last_error = _check_cups_backend_errors(printer_name)
result.has_backend_errors = has_errors
result.last_backend_error = last_error
return result
# ── CUPS fix actions ─────────────────────────────────────────────────
def _cups_enable_printer(printer_name: str) -> bool:
"""Re-enable a disabled CUPS printer. Returns True on success."""
cupsenable_path = shutil.which("cupsenable")
if not cupsenable_path:
_out(f" {RED}cupsenable not found.{RESET}")
return False
try:
subprocess.run(
[cupsenable_path, printer_name],
timeout=5,
check=True,
)
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError) as e:
_out(f" {RED}Failed to enable printer: {e}{RESET}")
return False
else:
return True
def _cups_cancel_all_jobs(printer_name: str) -> bool:
"""Cancel all pending jobs. Returns True on success."""
cancel_path = shutil.which("cancel")
if not cancel_path:
_out(f" {RED}cancel command not found.{RESET}")
return False
try:
subprocess.run(
[cancel_path, "-a", printer_name],
timeout=5,
check=True,
)
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError) as e:
_out(f" {RED}Failed to cancel jobs: {e}{RESET}")
return False
else:
return True
def _cups_cancel_job(job_id: str) -> bool:
"""Cancel a specific job. Returns True on success."""
cancel_path = shutil.which("cancel")
if not cancel_path:
return False
try:
subprocess.run(
[cancel_path, job_id],
timeout=5,
check=True,
)
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError):
return False
else:
return True
def _cups_restart_service() -> bool:
"""Restart the CUPS service. Returns True on success."""
systemctl_path = shutil.which("systemctl")
if not systemctl_path:
_out(f" {RED}systemctl not found.{RESET}")
return False
sys.stdout.write(f" {DIM}Restarting CUPS...{RESET}")
sys.stdout.flush()
try:
proc = subprocess.Popen(
[systemctl_path, "restart", "cups"],
)
deadline = time.time() + 30
while proc.poll() is None:
if time.time() > deadline:
proc.kill()
proc.wait()
sys.stdout.write("\n")
_out(
f" {RED}CUPS restart timed out"
f" (stuck backend process?).{RESET}"
)
_out(
f" {DIM}Try: sudo kill -9 $(pgrep -f 'cups/backend/usb')"
f" && sudo systemctl restart cups{RESET}"
)
return False
sys.stdout.write(".")
sys.stdout.flush()
time.sleep(1)
sys.stdout.write("\n")
if proc.returncode != 0:
_out(
f" {RED}CUPS restart failed" f" (exit code {proc.returncode}).{RESET}"
)
return False
except OSError as e:
sys.stdout.write("\n")
_out(f" {RED}Failed to restart CUPS: {e}{RESET}")
return False
time.sleep(2)
return True
# ── Backend error detection ──────────────────────────────────────────
def _is_cups_printer_healthy(printer_name: str) -> bool:
"""Check live CUPS state via lpstat. Returns True if enabled with no issues."""
lpstat_path = shutil.which("lpstat")
if not lpstat_path:
return False
try:
r = subprocess.run(
[lpstat_path, "-p", printer_name],
capture_output=True,
text=True,
timeout=5,
check=False,
)
for line in r.stdout.splitlines():
if (
printer_name in line
and "idle" in line.lower()
and "enabled" in line.lower()
):
return True
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
pass
return False
def _find_backend_error_in_log(
lines: list[str],
) -> tuple[str, str, str]:
"""Scan CUPS log lines (reversed) for backend errors.
Returns:
(backend_error, error_timestamp, last_success_timestamp)
"""
backend_error = ""
error_timestamp = ""
last_success_timestamp = ""
for line in reversed(lines):
if (
"backend errors" in line or "stopped with status" in line
) and not backend_error:
backend_error = line.strip()
ts_match = re.search(r"\[([^\]]+)\]", line)
if ts_match:
error_timestamp = ts_match.group(1)
if ("Completed" in line or "total" in line) and error_timestamp:
ts_match = re.search(r"\[([^\]]+)\]", line)
if ts_match:
last_success_timestamp = ts_match.group(1)
break
return backend_error, error_timestamp, last_success_timestamp
def _check_cups_backend_errors(
printer_name: str,
) -> tuple[bool, str]:
"""Check CUPS error log for backend errors. Returns (has_errors, last_error)."""
if _is_cups_printer_healthy(printer_name):
return False, ""
log_path = Path("/var/log/cups/error_log")
if not log_path.exists():
return False, ""
try:
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
except OSError:
return False, ""
backend_error, error_timestamp, last_success_timestamp = _find_backend_error_in_log(
lines
)
if not backend_error:
return False, ""
if last_success_timestamp and last_success_timestamp > error_timestamp:
return False, ""
return True, backend_error
# ── Queue status display ────────────────────────────────────────────
def display_cups_queue_status(queue: CUPSQueueStatus) -> None:
"""Display CUPS queue status and offer interactive fixes."""
if not queue.printer_name:
return
if queue.enabled and not queue.jobs and not queue.has_backend_errors:
return
_out()
_out(f"{BOLD}── Print Queue ──{RESET}")
_out()
if queue.has_backend_errors and queue.enabled and not queue.jobs:
_out(f" {YELLOW}{BOLD}⚡ CUPS backend has stale errors{RESET}")
_out(
f" {DIM}New print jobs may silently fail."
f" A CUPS restart usually fixes this.{RESET}"
)
_out()
if not queue.enabled:
_out(f" {RED}{BOLD}⚠ Printer queue is DISABLED{RESET}")
if queue.reason:
_out(f" {DIM}Reason: {queue.reason}{RESET}")
_out()
if queue.jobs:
_out(f" {BOLD}Pending jobs ({len(queue.jobs)}):{RESET}")
for job in queue.jobs:
_out(f" {job.job_id} {DIM}{job.user} {job.size}B {job.date}{RESET}")
_out()
_offer_queue_fix(queue)
# ── Interactive queue fix ────────────────────────────────────────────
def _offer_queue_fix(queue: CUPSQueueStatus) -> None:
"""Prompt the user to fix a disabled queue / pending jobs."""
_out(f" {BOLD}Available actions:{RESET}")
options: list[str] = []
if not queue.enabled and queue.jobs:
_out(f" {CYAN}1){RESET} Re-enable printer and retry all jobs")
_out(f" {CYAN}2){RESET} Re-enable printer and cancel all jobs")
_out(f" {CYAN}3){RESET} Cancel all jobs (keep printer disabled)")
_out(f" {CYAN}4){RESET} Restart CUPS service (fixes stale backend)")
_out(f" {CYAN}5){RESET} Restart CUPS + re-enable + retry all jobs")
_out(f" {CYAN}6){RESET} Do nothing")
options = ["1", "2", "3", "4", "5", "6"]
elif not queue.enabled:
_out(f" {CYAN}1){RESET} Re-enable printer")
_out(f" {CYAN}2){RESET} Restart CUPS service (fixes stale backend)")
_out(f" {CYAN}3){RESET} Do nothing")
options = ["1", "2", "3"]
elif queue.jobs:
_out(f" {CYAN}1){RESET} Cancel all pending jobs")
_out(f" {CYAN}2){RESET} Restart CUPS service (fixes stale backend)")
_out(f" {CYAN}3){RESET} Do nothing")
options = ["1", "2", "3"]
else:
_out(f" {CYAN}1){RESET} Restart CUPS service (fixes stale backend)")
_out(f" {CYAN}2){RESET} Do nothing")
options = ["1", "2"]
_out()
choice = _prompt(f" Choose [{'/'.join(options)}]: ")
_out()
if not queue.enabled and queue.jobs:
_handle_disabled_with_jobs(queue, choice)
elif not queue.enabled:
_handle_disabled_no_jobs(queue, choice)
elif queue.jobs:
_handle_enabled_with_jobs(queue, choice)
else:
_handle_backend_errors_only(choice)
def _dwj_enable_only(printer_name: str) -> None:
"""Choice 1: re-enable printer so queued jobs are retried."""
if _cups_enable_printer(printer_name):
_out(f" {GREEN}✓ Printer re-enabled. Jobs will be retried.{RESET}")
def _dwj_cancel_and_enable(printer_name: str) -> None:
"""Choice 2: cancel all jobs then re-enable."""
_cups_cancel_all_jobs(printer_name)
if _cups_enable_printer(printer_name):
_out(f" {GREEN}✓ All jobs cancelled and printer re-enabled.{RESET}")
def _dwj_cancel_only(printer_name: str) -> None:
"""Choice 3: cancel all jobs."""
if _cups_cancel_all_jobs(printer_name):
_out(f" {GREEN}✓ All jobs cancelled.{RESET}")
def _dwj_restart_only(_printer_name: str) -> None:
"""Choice 4: restart CUPS."""
if _cups_restart_service():
_out(f" {GREEN}✓ CUPS restarted.{RESET}")
def _dwj_restart_and_enable(printer_name: str) -> None:
"""Choice 5: restart CUPS and re-enable printer."""
if _cups_restart_service():
_cups_enable_printer(printer_name)
_out(
f" {GREEN}✓ CUPS restarted, printer re-enabled."
f" Jobs will be retried.{RESET}"
)
_DWJ_ACTIONS: dict[str, Callable[[str], None]] = {
"1": _dwj_enable_only,
"2": _dwj_cancel_and_enable,
"3": _dwj_cancel_only,
"4": _dwj_restart_only,
"5": _dwj_restart_and_enable,
}
def _handle_disabled_with_jobs(queue: CUPSQueueStatus, choice: str) -> None:
"""Handle fix for disabled printer with pending jobs."""
action = _DWJ_ACTIONS.get(choice)
if action is not None:
action(queue.printer_name)
else:
_out(f" {DIM}No changes made.{RESET}")
def _handle_disabled_no_jobs(queue: CUPSQueueStatus, choice: str) -> None:
"""Handle fix for disabled printer with no pending jobs."""
if choice == "1":
if _cups_enable_printer(queue.printer_name):
_out(f" {GREEN}✓ Printer re-enabled.{RESET}")
elif choice == "2":
if _cups_restart_service():
_cups_enable_printer(queue.printer_name)
_out(f" {GREEN}✓ CUPS restarted and printer re-enabled.{RESET}")
else:
_out(f" {DIM}No changes made.{RESET}")
def _handle_enabled_with_jobs(queue: CUPSQueueStatus, choice: str) -> None:
"""Handle fix for enabled printer with stuck jobs."""
if choice == "1":
if _cups_cancel_all_jobs(queue.printer_name):
_out(f" {GREEN}✓ All jobs cancelled.{RESET}")
elif choice == "2":
if _cups_restart_service():
_out(f" {GREEN}✓ CUPS restarted.{RESET}")
else:
_out(f" {DIM}No changes made.{RESET}")
def _handle_backend_errors_only(choice: str) -> None:
"""Handle fix when only stale backend errors are detected."""
if choice == "1":
if _cups_restart_service():
_out(f" {GREEN}✓ CUPS restarted. Stale backend errors cleared.{RESET}")
else:
_out(f" {DIM}No changes made.{RESET}")

View File

@ -0,0 +1,479 @@
"""CUPS service management, USB fallback, and consumable state tracking."""
from __future__ import annotations
import json
import logging
from pathlib import Path
import re
import shutil
import subprocess
import time
import urllib.parse
from python_pkg.brother_printer.constants import (
_CUPS_REASONS_TO_STATUS,
_CUPS_STATE_TO_STATUS,
_ERROR_REASON_MAP,
BROTHER_USB_VENDOR_ID,
CONSUMABLE_STATE_DIR,
CUPS_PAGE_LOG_PATH,
DRUM_RATED_PAGES,
GREEN,
RESET,
TONER_RATED_PAGES,
_out,
)
from python_pkg.brother_printer.data_classes import (
PageCountEstimate,
USBPortStatus,
USBResult,
)
logger = logging.getLogger(__name__)
CUPS_PAGE_LOG = Path(CUPS_PAGE_LOG_PATH)
CONSUMABLE_STATE_FILE = Path.home() / CONSUMABLE_STATE_DIR / "state.json"
# ── pyusb device info ────────────────────────────────────────────────
def _get_pyusb_device_info() -> dict[str, str]:
"""Get Brother USB printer info via pyusb (no interface claim needed)."""
try:
import usb.core
dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID)
if dev is None:
return {}
except (ImportError, OSError, ValueError):
return {}
else:
return {
"product": dev.product or "",
"serial": dev.serial_number or "",
}
# ── CUPS service control ────────────────────────────────────────────
def _stop_cups() -> bool:
"""Stop CUPS service and sockets. Returns True on success."""
systemctl = shutil.which("systemctl")
if not systemctl:
return False
try:
subprocess.run(
[systemctl, "stop", "cups.service", "cups.socket", "cups.path"],
timeout=15,
check=True,
)
time.sleep(2)
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError):
return False
return True
def is_cups_scheduler_running() -> bool:
"""Check if the CUPS scheduler is currently running."""
lpstat = shutil.which("lpstat")
if not lpstat:
return False
try:
r = subprocess.run(
[lpstat, "-r"],
capture_output=True,
text=True,
timeout=3,
check=False,
)
return (
"is running" in r.stdout.lower() and "not running" not in r.stdout.lower()
)
except (subprocess.TimeoutExpired, OSError):
return False
def start_cups() -> bool:
"""Start CUPS service, socket, and path units. Returns True on success."""
systemctl = shutil.which("systemctl")
if not systemctl:
return False
try:
subprocess.run(
[systemctl, "start", "cups.service", "cups.socket", "cups.path"],
timeout=15,
check=True,
)
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError):
return False
for _ in range(10):
if is_cups_scheduler_running():
return True
time.sleep(1)
return False
def _ensure_cups_running() -> bool:
"""Make sure CUPS is running, starting it if necessary."""
if is_cups_scheduler_running():
return True
return start_cups()
# ── USB port status via pyusb ────────────────────────────────────────
def _query_usb_port_status_raw() -> USBPortStatus | None:
"""Query USB printer port status via pyusb control transfer.
Requires root and temporarily stops CUPS to access the USB device.
Returns None if the query fails.
"""
try:
import usb.core
import usb.util
except ImportError:
return None
dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID)
if dev is None:
return None
if not _stop_cups():
return None
try:
dev.reset()
time.sleep(2)
dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID)
if dev is None:
return None
try:
if dev.is_kernel_driver_active(0):
dev.detach_kernel_driver(0)
except (usb.core.USBError, NotImplementedError):
pass
usb.util.claim_interface(dev, 0)
try:
# USB Printer Class GET_PORT_STATUS (bRequest=0x01)
raw = dev.ctrl_transfer(0xA1, 0x01, 0, 0, 1, timeout=5000)
port_byte = raw[0]
return USBPortStatus(
paper_empty=bool(port_byte & 0x20),
online=bool(port_byte & 0x10),
error=not bool(port_byte & 0x08),
raw_byte=port_byte,
)
finally:
usb.util.release_interface(dev, 0)
usb.util.dispose_resources(dev)
except (OSError, ValueError):
logger.debug("USB port status query failed", exc_info=True)
return None
finally:
start_cups()
# ── Consumable state management ──────────────────────────────────────
def _get_cups_total_pages() -> int:
"""Parse CUPS page_log to get total pages printed (deduplicated by job)."""
if not CUPS_PAGE_LOG.exists():
return 0
try:
text = CUPS_PAGE_LOG.read_text(encoding="utf-8", errors="replace")
except OSError:
return 0
jobs: dict[str, int] = {}
for line in text.splitlines():
match = re.search(r"\s(\d+)\s+\[.*?\]\s+total\s+(\d+)", line)
if match:
job_id = match.group(1)
pages = int(match.group(2))
jobs[job_id] = max(jobs.get(job_id, 0), pages)
return sum(jobs.values())
def _load_consumable_state() -> dict[str, int]:
"""Load consumable replacement state from disk."""
defaults: dict[str, int] = {"toner_replaced_at": 0, "drum_replaced_at": 0}
if not CONSUMABLE_STATE_FILE.exists():
return defaults
try:
data = json.loads(
CONSUMABLE_STATE_FILE.read_text(encoding="utf-8"),
)
return {
"toner_replaced_at": int(data.get("toner_replaced_at", 0)),
"drum_replaced_at": int(data.get("drum_replaced_at", 0)),
}
except (OSError, json.JSONDecodeError, ValueError, TypeError):
return defaults
def _save_consumable_state(state: dict[str, int]) -> None:
"""Persist consumable replacement state to disk."""
CONSUMABLE_STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
CONSUMABLE_STATE_FILE.write_text(
json.dumps(state, indent=2) + "\n",
encoding="utf-8",
)
def reset_consumable(name: str) -> None:
"""Record current page count as replacement point for a consumable."""
total = _get_cups_total_pages()
state = _load_consumable_state()
key = f"{name}_replaced_at"
state[key] = total
_save_consumable_state(state)
_out(
f"{GREEN}{name.capitalize()} counter reset at page count" f" {total}.{RESET}"
)
_out(f" State saved to {CONSUMABLE_STATE_FILE}")
def estimate_consumable_life() -> PageCountEstimate:
"""Estimate toner/drum life from CUPS page count since last replacement."""
total = _get_cups_total_pages()
if total <= 0:
return PageCountEstimate()
state = _load_consumable_state()
toner_pages = max(0, total - state["toner_replaced_at"])
drum_pages = max(0, total - state["drum_replaced_at"])
toner_pct = max(0, 100 - (toner_pages * 100 // TONER_RATED_PAGES))
drum_pct = max(0, 100 - (drum_pages * 100 // DRUM_RATED_PAGES))
return PageCountEstimate(
total_pages=total,
toner_pages=toner_pages,
drum_pages=drum_pages,
toner_pct_remaining=toner_pct,
drum_pct_remaining=drum_pct,
toner_exhausted=toner_pages >= TONER_RATED_PAGES,
toner_low=toner_pages >= TONER_RATED_PAGES * 80 // 100,
drum_near_end=drum_pages >= DRUM_RATED_PAGES * 90 // 100,
)
# ── IPP / CUPS attribute queries ────────────────────────────────────
def _parse_ipp_attributes(output: str) -> dict[str, str]:
"""Parse ipptool verbose output into an attribute dict."""
attrs: dict[str, str] = {}
for line in output.splitlines():
match = re.match(r"\s+(\S+)\s+\([^)]+\)\s+=\s+(.*)", line)
if match:
attrs[match.group(1)] = match.group(2).strip()
return attrs
def _get_cups_ipp_status(printer_name: str) -> dict[str, str]:
"""Query printer attributes via CUPS IPP using ipptool."""
ipptool_path = shutil.which("ipptool")
if not ipptool_path:
return {}
uri = f"ipp://localhost/printers/{printer_name}"
try:
r = subprocess.run(
[ipptool_path, "-tv", uri, "get-printer-attributes.test"],
capture_output=True,
text=True,
timeout=10,
check=False,
)
return _parse_ipp_attributes(r.stdout)
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
return {}
def _get_cups_economode(printer_name: str) -> str:
"""Query toner save mode setting via lpoptions."""
lpoptions_path = shutil.which("lpoptions")
if not lpoptions_path:
return ""
try:
r = subprocess.run(
[lpoptions_path, "-p", printer_name, "-l"],
capture_output=True,
text=True,
timeout=5,
check=False,
)
for line in r.stdout.splitlines():
if "conomode" in line.lower():
match = re.search(r"\*(\w+)", line)
if match:
return "ON" if match.group(1).lower() == "true" else "OFF"
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
pass
return ""
# ── Status code mapping ──────────────────────────────────────────────
def _map_cups_to_status_code(state: str, reasons: str) -> str:
"""Map CUPS state + reasons to a Brother PJL status code string."""
for keyword, code in _CUPS_REASONS_TO_STATUS.items():
if keyword in reasons.lower():
return str(code)
clean_state = re.sub(r"\(.*\)", "", state).strip().lower()
return str(_CUPS_STATE_TO_STATUS.get(clean_state, 10001))
def _cups_reasons_to_error(cups_reasons: str) -> tuple[str, str]:
"""Map CUPS reason keywords to a (status_code, display) pair."""
reasons_lower = cups_reasons.lower()
for keywords, code, display in _ERROR_REASON_MAP:
if any(kw in reasons_lower for kw in keywords):
return code, display
return "42000", "Printer Error"
def _port_status_to_status_code(
ps: USBPortStatus,
cups_reasons: str,
) -> tuple[str, str]:
"""Map USB port status + CUPS reasons to (status_code, display)."""
if ps.error and ps.paper_empty:
return "40302", "No Paper"
if ps.error and not ps.online:
return "41000", "Cover Open"
if ps.error:
return _cups_reasons_to_error(cups_reasons)
if ps.paper_empty:
return "40302", "No Paper"
if not ps.online:
return "10002", "Offline / Sleep"
return "", ""
# ── CUPS printer name discovery ──────────────────────────────────────
def find_cups_printer_name() -> str:
"""Find the CUPS queue name for a Brother printer."""
lpstat_path = shutil.which("lpstat")
if not lpstat_path:
return ""
try:
r = subprocess.run(
[lpstat_path, "-v"],
capture_output=True,
text=True,
timeout=5,
check=False,
)
for line in r.stdout.splitlines():
if "brother" in line.lower():
match = re.match(r"device for (\S+):", line)
if match:
return match.group(1)
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
pass
return ""
# ── CUPS-based USB fallback query ────────────────────────────────────
def _parse_cups_usb_uri(uri: str, info: dict[str, str]) -> None:
"""Extract product and serial from a CUPS usb:// URI."""
parsed = urllib.parse.urlparse(uri)
info["product"] = urllib.parse.unquote(parsed.path.lstrip("/"))
qs = urllib.parse.parse_qs(parsed.query)
if "serial" in qs:
info["serial"] = qs["serial"][0]
def _get_printer_info_from_cups() -> dict[str, str]:
"""Get printer model/serial from lpstat."""
info: dict[str, str] = {"product": "", "serial": ""}
try:
r = subprocess.run(
["/usr/bin/lpstat", "-v"],
capture_output=True,
text=True,
timeout=5,
check=False,
)
for line in r.stdout.splitlines():
if "Brother" in line:
for part in line.split():
if part.startswith("usb://"):
_parse_cups_usb_uri(part, info)
break
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
logger.debug("Failed to query CUPS for printer info", exc_info=True)
return info
def query_usb_via_cups() -> USBResult:
"""Query USB printer status through CUPS when /dev/usb/lp* is unavailable."""
_ensure_cups_running()
printer_name = find_cups_printer_name()
if not printer_name:
return USBResult(
error="No USB printer device at /dev/usb/lp*"
" (usblp module not available)"
" and no Brother printer found in CUPS.",
)
pyusb_info = _get_pyusb_device_info()
cups_info = _get_printer_info_from_cups()
result = USBResult(
device="cups",
product=(
pyusb_info.get("product")
or cups_info.get("product")
or "Brother Laser Printer"
),
serial=pyusb_info.get("serial") or cups_info.get("serial", ""),
)
ipp = _get_cups_ipp_status(printer_name)
state = ipp.get("printer-state", "")
reasons = ipp.get("printer-state-reasons", "none")
result.economode = _get_cups_economode(printer_name)
port_status = _query_usb_port_status_raw()
if port_status is not None:
result.port_status = port_status
hw_code, hw_display = _port_status_to_status_code(
port_status,
reasons,
)
if hw_code:
result.status_code = hw_code
result.display = hw_display
result.online = "TRUE" if port_status.online else "FALSE"
return result
estimate = estimate_consumable_life()
if estimate.toner_exhausted:
result.status_code = "40310"
result.display = "Toner End (estimated from page count)"
result.online = "TRUE"
return result
if estimate.toner_low:
result.status_code = "30010"
result.display = "Toner Low (estimated from page count)"
result.online = "TRUE"
return result
result.status_code = _map_cups_to_status_code(state, reasons)
result.display = ipp.get("printer-state-message", "")
result.online = "TRUE"
return result
result.status_code = _map_cups_to_status_code(state, reasons)
result.display = ipp.get("printer-state-message", "")
result.online = "TRUE" if state.lower() in {"idle", "processing"} else "FALSE"
return result

View File

@ -0,0 +1,96 @@
"""Data classes for Brother printer status information."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class CUPSJob:
"""A single CUPS print job."""
job_id: str
user: str
size: str
date: str
@dataclass
class CUPSQueueStatus:
"""Status of the CUPS print queue for a printer."""
printer_name: str = ""
enabled: bool = True
reason: str = ""
jobs: list[CUPSJob] = field(default_factory=list)
has_backend_errors: bool = False
last_backend_error: str = ""
@dataclass
class PageCountEstimate:
"""Estimated consumable life based on CUPS page count."""
total_pages: int = 0
toner_pages: int = 0
drum_pages: int = 0
toner_pct_remaining: int = 100
drum_pct_remaining: int = 100
toner_exhausted: bool = False
toner_low: bool = False
drum_near_end: bool = False
@dataclass
class USBPortStatus:
"""IEEE 1284 USB printer port status bits."""
paper_empty: bool = False
online: bool = True
error: bool = False
raw_byte: int = 0
@dataclass
class USBResult:
"""Result from a USB PJL query."""
connection: str = "usb"
device: str = ""
product: str = "Brother Laser Printer"
serial: str = ""
status_code: str = ""
display: str = ""
online: str = ""
economode: str = ""
error: str = ""
port_status: USBPortStatus | None = None
@dataclass
class NetworkResult:
"""Result from an SNMP network query."""
connection: str = "network"
ip: str = ""
product: str = "Unknown"
serial: str = ""
printer_status: str = ""
device_status: str = ""
display: str = ""
page_count: str = ""
supply_descriptions: list[str] = field(default_factory=list)
supply_max: list[str] = field(default_factory=list)
supply_levels: list[str] = field(default_factory=list)
error: str = ""
@dataclass
class SupplyStatus:
"""Processed supply level info for display."""
color: str
bar: str
status_text: str
warning: str
needs_replacement: bool

View File

@ -0,0 +1,385 @@
"""Display and formatting functions for Brother printer status reports."""
from __future__ import annotations
import sys
from python_pkg.brother_printer.constants import (
BOLD,
CYAN,
DIM,
DRUM_RATED_PAGES,
GREEN,
PROGRESS_BAR_WIDTH,
RED,
RESET,
SNMP_LEVEL_LOW,
SNMP_LEVEL_OK,
SUPPLY_LOW_PCT,
SUPPLY_WARN_PCT,
TONER_RATED_PAGES,
YELLOW,
_out,
get_status_info,
)
from python_pkg.brother_printer.cups_queue import (
display_cups_queue_status,
get_cups_queue_status,
)
from python_pkg.brother_printer.cups_service import estimate_consumable_life
from python_pkg.brother_printer.data_classes import (
NetworkResult,
SupplyStatus,
USBResult,
)
# ── Shared display helpers ───────────────────────────────────────────
def _display_report_header() -> None:
"""Print the report banner box."""
_out()
_out(f"{BOLD}╔══════════════════════════════════════════════════╗{RESET}")
_out(f"{BOLD}║ Brother Laser Printer Status Report ║{RESET}")
_out(f"{BOLD}╚══════════════════════════════════════════════════╝{RESET}")
_out()
def _display_page_count_estimate() -> None:
"""Show estimated consumable life based on CUPS page count."""
estimate = estimate_consumable_life()
if estimate.total_pages <= 0:
return
_out(f"{BOLD}── Page Count Estimate ──{RESET}")
_out()
_out(
f" {BOLD}Total pages printed:{RESET} {estimate.total_pages}"
f" (toner: {estimate.toner_pages} since replacement,"
f" drum: {estimate.drum_pages} since replacement)"
)
_out()
# Toner bar
toner_pct = estimate.toner_pct_remaining
toner_filled = toner_pct * PROGRESS_BAR_WIDTH // 100
toner_empty = PROGRESS_BAR_WIDTH - toner_filled
toner_bar = f"[{'' * toner_filled}{'' * toner_empty}]"
if estimate.toner_exhausted:
toner_color = RED
toner_note = " ← REPLACE NOW"
elif estimate.toner_low:
toner_color = YELLOW
toner_note = " ← order soon"
else:
toner_color = GREEN
toner_note = ""
_out(
f" {BOLD}Toner:{RESET} {toner_color}{toner_bar} ~{toner_pct}%"
f"{toner_note}{RESET}"
)
# Drum bar
drum_pct = estimate.drum_pct_remaining
drum_filled = drum_pct * PROGRESS_BAR_WIDTH // 100
drum_empty = PROGRESS_BAR_WIDTH - drum_filled
drum_bar = f"[{'' * drum_filled}{'' * drum_empty}]"
if estimate.drum_near_end:
drum_color = YELLOW
drum_note = " ← nearing end"
else:
drum_color = GREEN
drum_note = ""
_out(
f" {BOLD}Drum:{RESET} {drum_color}{drum_bar} ~{drum_pct}%"
f"{drum_note}{RESET}"
)
_out(
f" {DIM}Based on pages since last replacement"
f" vs rated capacity (toner ~{TONER_RATED_PAGES},"
f" drum ~{DRUM_RATED_PAGES}).{RESET}"
)
_out(f" {DIM}Reset after replacing: --reset-toner or --reset-drum{RESET}")
if estimate.toner_exhausted:
_out()
_out(
f" {RED}{BOLD}⚠ Toner is likely exhausted."
f" This is probably why the orange light is flashing.{RESET}"
)
_out()
def _display_consumables_reference() -> None:
"""Print compatible consumables reference."""
_out(f"{BOLD}── Compatible Consumables ──{RESET}")
_out()
_out(f" {BOLD}Toner:{RESET} TN-1050 / TN-1030 (or compatible third-party)")
_out(f" {BOLD}Drum:{RESET} DR-1050 / DR-1030 (or compatible third-party)")
_out(f" {DIM} Toner rated ~1000 pages; Drum rated ~10000 pages.{RESET}")
_out()
# ── USB display helpers ──────────────────────────────────────────────
def _display_usb_device_info(result: USBResult) -> None:
"""Print device info block for USB results."""
_out(f"{BOLD}Printer:{RESET} {result.product or 'Unknown'}")
_out(f"{BOLD}Connection:{RESET} USB")
if result.serial:
_out(f"{BOLD}Serial:{RESET} {result.serial}")
if result.online == "TRUE":
_out(f"{BOLD}Online:{RESET} {GREEN}Yes{RESET}")
elif result.online == "FALSE":
_out(f"{BOLD}Online:{RESET} {YELLOW}No (needs attention){RESET}")
_out()
if result.economode:
if result.economode == "ON":
_out(
f"{BOLD}Toner Save:{RESET} {GREEN}ON{RESET}"
" (extends toner life, lighter prints)"
)
else:
_out(f"{BOLD}Toner Save:{RESET} OFF")
_SEVERITY_ICONS: dict[str, str] = {
"ok": "",
"info": "i",
"warn": "",
"critical": "",
}
_SEVERITY_COLORS: dict[str, str] = {
"ok": GREEN,
"info": CYAN,
"warn": YELLOW,
"critical": RED,
}
_SEVERITY_SUMMARIES: dict[str, str] = {
"ok": f"{GREEN}{BOLD}✓ Printer is healthy. No replacements needed.{RESET}",
"info": (
f"{CYAN}{BOLD}i Printer is busy/processing." f" No replacements needed.{RESET}"
),
"warn": (
f"{YELLOW}{BOLD}⚡ WARNING: Maintenance will be needed"
f" soon.{RESET}\n{YELLOW} Order replacement parts"
f" now to avoid interruption.{RESET}"
),
"critical": (
f"{RED}{BOLD}⚠ ACTION REQUIRED:" f" Replacement or fix needed now!{RESET}"
),
}
def _format_status_detail(
severity: str, short_text: str, action: str, result: USBResult
) -> None:
"""Print severity icon, display text, and action."""
color = _SEVERITY_COLORS.get(severity, GREEN)
icon = _SEVERITY_ICONS.get(severity, "")
_out(f" {color}{BOLD}{icon} {short_text}{RESET}")
if result.display and result.display != short_text:
_out(f" {DIM}Display: {result.display}{RESET}")
_out(f" {DIM}Status code: {result.status_code}{RESET}")
if action:
_out()
_out(f" {color}{BOLD}Action:{RESET} {color}{action}{RESET}")
_out()
_out(_SEVERITY_SUMMARIES.get(severity, ""))
def _display_pjl_status(result: USBResult) -> None:
"""Display PJL status code interpretation."""
_out()
_out(f"{BOLD}── Printer Status ──{RESET}")
_out()
if not result.status_code:
_out(f" {YELLOW}Could not read status from printer.{RESET}")
if result.display:
_out(f" Display message: {BOLD}{result.display}{RESET}")
return
severity, short_text, action = get_status_info(result.status_code)
_format_status_detail(severity, short_text, action, result)
def _display_cups_fallback_note(result: USBResult) -> None:
"""Show a note when running in CUPS fallback mode."""
_out()
if result.port_status is not None:
_out(
f" {DIM}Note: Hardware status obtained via USB port query."
f" Toner/drum percentages not available.{RESET}"
)
else:
_out(
f" {DIM}Note: pyusb not available; status obtained via"
f" CUPS only. Detailed toner/drum levels are not"
f" available in this mode.{RESET}"
)
# ── USB results display ─────────────────────────────────────────────
def display_usb_results(result: USBResult) -> None:
"""Print a formatted report for USB PJL query results."""
if result.error:
_out(f"{RED}Error: {result.error}{RESET}")
sys.exit(1)
_display_report_header()
_display_usb_device_info(result)
_display_pjl_status(result)
if result.device == "cups":
_display_cups_fallback_note(result)
_out()
_display_page_count_estimate()
_display_consumables_reference()
queue = get_cups_queue_status()
display_cups_queue_status(queue)
# ── Network supply level helpers ─────────────────────────────────────
def _classify_percentage_level(desc: str, pct: int) -> tuple[int, str, str, str, bool]:
"""Classify a supply by its calculated percentage."""
if pct <= SUPPLY_LOW_PCT:
return pct, f"{pct}%", RED, f"{desc} at {pct}%.", True
if pct <= SUPPLY_WARN_PCT:
return pct, f"{pct}%", YELLOW, f"{desc} at {pct}% -- order soon.", False
return pct, f"{pct}%", GREEN, "", False
def _classify_supply_level(
desc: str, max_val: int, level: int
) -> tuple[int, str, str, str, bool]:
"""Classify a supply level. Returns (pct, status, color, warning, replace)."""
if level == SNMP_LEVEL_OK:
return -1, "OK", GREEN, "", False
if level == SNMP_LEVEL_LOW:
return -1, "LOW", RED, f"{desc} is LOW.", True
if level == 0:
return 0, "EMPTY", RED, f"{desc} is EMPTY -- replace now!", True
if max_val > 0:
pct = min(level * 100 // max_val, 100)
return _classify_percentage_level(desc, pct)
return -1, "", GREEN, "", False
def _format_supply_bar(pct: int) -> str:
"""Build a progress bar string for a supply percentage."""
if pct < 0:
return ""
filled = pct * PROGRESS_BAR_WIDTH // 100
empty = PROGRESS_BAR_WIDTH - filled
return f"[{'' * filled}{'' * empty}]"
def _process_supply_item(desc: str, max_val: int, level: int) -> SupplyStatus:
"""Process a single supply item into display info."""
pct, status_text, color, warning, needs_replacement = _classify_supply_level(
desc, max_val, level
)
bar = _format_supply_bar(pct)
return SupplyStatus(color, bar, status_text, warning, needs_replacement)
def _display_supply_warnings(*, needs_replacement: bool, warnings: list[str]) -> None:
"""Display supply level warnings summary."""
_out()
if needs_replacement:
_out(f"{RED}{BOLD}⚠ ACTION NEEDED:{RESET}")
for w in warnings:
_out(f" {RED}{w}{RESET}")
elif warnings:
_out(f"{YELLOW}{BOLD}⚡ HEADS UP:{RESET}")
for w in warnings:
_out(f" {YELLOW}{w}{RESET}")
else:
_out(f"{GREEN}{BOLD}✓ All consumables are at healthy levels.{RESET}")
def _parse_supply_value(values: list[str], index: int) -> int:
"""Safely parse an integer from a supply value list."""
try:
return int(values[index])
except (IndexError, ValueError):
return 0
def _collect_supply_items(
result: NetworkResult,
) -> tuple[list[SupplyStatus], list[str]]:
"""Parse and collect supply items with their descriptions."""
items: list[SupplyStatus] = []
descs: list[str] = []
for i, desc in enumerate(result.supply_descriptions):
max_val = _parse_supply_value(result.supply_max, i)
level = _parse_supply_value(result.supply_levels, i)
items.append(_process_supply_item(desc, max_val, level))
descs.append(desc)
return items, descs
def _display_supply_levels(result: NetworkResult) -> None:
"""Display consumable supply levels section."""
_out()
_out(f"{BOLD}── Consumable Levels ──{RESET}")
_out()
needs_replacement = False
warnings: list[str] = []
items, descs = _collect_supply_items(result)
for desc, item in zip(descs, items, strict=True):
_out(
f" {BOLD}{desc:<25}{RESET}"
f" {item.color}{item.bar} {item.status_text}{RESET}"
)
if item.needs_replacement:
needs_replacement = True
if item.warning:
warnings.append(item.warning)
_display_supply_warnings(needs_replacement=needs_replacement, warnings=warnings)
def _display_network_device_info(result: NetworkResult) -> None:
"""Display device info section for network results."""
_out(f"{BOLD}Printer:{RESET} {result.product or 'Unknown'}")
_out(f"{BOLD}Connection:{RESET} Network ({result.ip})")
if result.serial:
_out(f"{BOLD}Serial:{RESET} {result.serial}")
if result.display:
_out(f"{BOLD}Display:{RESET} {result.display}")
if result.page_count and result.page_count.isdigit():
_out(f"{BOLD}Pages:{RESET} {result.page_count} total")
# ── Network results display ──────────────────────────────────────────
def display_network_results(result: NetworkResult) -> None:
"""Print a formatted report for SNMP network query results."""
if result.error:
_out(f"{RED}Error: {result.error}{RESET}")
sys.exit(1)
_display_report_header()
_display_network_device_info(result)
_display_supply_levels(result)
_out()
_out(
f"{CYAN}Tip: Visit http://{result.ip} for the full web management"
f" interface.{RESET}"
)
_out()

View File

@ -0,0 +1,97 @@
"""SNMP network query functions for Brother printers."""
from __future__ import annotations
import shutil
import subprocess
from python_pkg.brother_printer.data_classes import NetworkResult
def _snmpwalk_cmd(
path: str, community: str, timeout: int, ip: str, oid: str
) -> list[str]:
"""Build the snmpwalk command arguments."""
return [path, "-v", "2c", "-c", community, "-t", str(timeout), "-OQvs", ip, oid]
def snmp_walk(ip: str, oid: str, community: str, timeout: int) -> list[str]:
"""Run snmpwalk and return cleaned values."""
snmpwalk_path = shutil.which("snmpwalk")
if not snmpwalk_path:
return []
try:
r = subprocess.run(
_snmpwalk_cmd(snmpwalk_path, community, timeout, ip, oid),
capture_output=True,
text=True,
timeout=15,
check=False,
)
return [
line.strip().strip('"')
for line in r.stdout.strip().splitlines()
if line.strip()
]
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
return []
def _snmpget_cmd(
path: str, community: str, timeout: int, ip: str, oid: str
) -> list[str]:
"""Build the snmpget command arguments."""
return [path, "-v", "2c", "-c", community, "-t", str(timeout), ip, oid]
def _check_snmp_connectivity(ip: str, community: str, timeout: int) -> str | None:
"""Verify SNMP connectivity. Returns error message or None on success."""
snmpget_path = shutil.which("snmpget")
if not snmpget_path:
return "snmpget not found. Install: sudo pacman -S net-snmp"
try:
subprocess.run(
_snmpget_cmd(
snmpget_path,
community,
timeout,
ip,
"1.3.6.1.2.1.43.11.1.1.6.1.1",
),
capture_output=True,
timeout=10,
check=True,
)
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError):
return f"Cannot reach printer at {ip} via SNMP."
return None
def _build_network_result(ip: str, community: str, timeout: int) -> NetworkResult:
"""Collect all SNMP data into a NetworkResult."""
def walk(oid: str) -> list[str]:
return snmp_walk(ip, oid, community, timeout)
return NetworkResult(
ip=ip,
product=" ".join(walk("1.3.6.1.2.1.25.3.2.1.3")[:1]) or "Unknown",
serial=" ".join(walk("1.3.6.1.2.1.43.5.1.1.17")[:1]) or "",
printer_status=" ".join(walk("1.3.6.1.2.1.25.3.5.1.1")[:1]) or "",
device_status=" ".join(walk("1.3.6.1.2.1.25.3.2.1.5")[:1]) or "",
display=" ".join(walk("1.3.6.1.2.1.43.16.5.1.2")[:3]) or "",
page_count=" ".join(walk("1.3.6.1.2.1.43.10.2.1.4")[:1]) or "",
supply_descriptions=walk("1.3.6.1.2.1.43.11.1.1.6"),
supply_max=walk("1.3.6.1.2.1.43.11.1.1.8"),
supply_levels=walk("1.3.6.1.2.1.43.11.1.1.9"),
)
def query_network_snmp(ip: str) -> NetworkResult:
"""Query a Brother printer via SNMP over the network."""
community = "public"
timeout = 5
error = _check_snmp_connectivity(ip, community, timeout)
if error:
return NetworkResult(ip=ip, error=error)
return _build_network_result(ip, community, timeout)

View File

@ -0,0 +1,233 @@
"""USB printer discovery and PJL query functions."""
from __future__ import annotations
import contextlib
import fcntl
import os
from pathlib import Path
import select
import shutil
import subprocess
import time
from typing import TYPE_CHECKING
import urllib.parse
from python_pkg.brother_printer.data_classes import USBResult
if TYPE_CHECKING:
from collections.abc import Callable
import logging
logger = logging.getLogger(__name__)
# ── USB printer discovery ────────────────────────────────────────────
def find_brother_usb() -> str:
"""Look for any Brother printer on USB via lsusb. Returns the info line."""
if not shutil.which("lsusb"):
return ""
try:
r = subprocess.run(
["/usr/bin/lsusb"],
capture_output=True,
text=True,
timeout=5,
check=False,
)
for line in r.stdout.splitlines():
if "04f9:" in line.lower():
return line.split(": ", 1)[1] if ": " in line else line
except (subprocess.TimeoutExpired, OSError):
pass
return ""
def find_usb_printer_dev() -> str | None:
"""Find /dev/usb/lp* device for the Brother printer."""
devices = sorted(Path("/dev/usb").glob("lp*"))
return str(devices[0]) if devices else None
def _parse_cups_usb_uri(uri: str, info: dict[str, str]) -> None:
"""Extract product and serial from a CUPS usb:// URI."""
parsed = urllib.parse.urlparse(uri)
info["product"] = urllib.parse.unquote(parsed.path.lstrip("/"))
qs = urllib.parse.parse_qs(parsed.query)
if "serial" in qs:
info["serial"] = qs["serial"][0]
def get_printer_info_from_cups() -> dict[str, str]:
"""Get printer model/serial from lpstat."""
info: dict[str, str] = {"product": "", "serial": ""}
try:
r = subprocess.run(
["/usr/bin/lpstat", "-v"],
capture_output=True,
text=True,
timeout=5,
check=False,
)
for line in r.stdout.splitlines():
if "Brother" in line:
for part in line.split():
if part.startswith("usb://"):
_parse_cups_usb_uri(part, info)
break
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
logger.debug("Failed to query CUPS for printer info", exc_info=True)
return info
# ── PJL over USB ─────────────────────────────────────────────────────
def _drain_buffer(fd: int) -> None:
"""Read and discard any stale data from the USB buffer."""
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
with contextlib.suppress(OSError):
while os.read(fd, 4096):
pass
fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
def _read_nonblocking(fd: int, flags: int) -> bytes:
"""Read all available data from fd in non-blocking mode."""
data = b""
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
with contextlib.suppress(OSError):
while True:
chunk = os.read(fd, 4096)
if not chunk:
break
data += chunk
fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
return data
def _wait_for_pjl_response(fd: int, flags: int, deadline: float) -> bytes:
"""Poll fd until PJL data arrives or deadline expires."""
response = b""
while time.time() < deadline:
remaining = deadline - time.time()
if remaining <= 0:
break
readable, _, _ = select.select([fd], [], [], min(remaining, 1.0))
if readable:
response += _read_nonblocking(fd, flags)
if response and (b"=" in response or b"@PJL" in response):
break
return response
def pjl_query(fd: int, cmd: str, timeout_sec: float = 5.0) -> str:
"""Send a PJL command via raw fd and read the response."""
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
pjl_cmd = f"\x1b%-12345X@PJL\r\n{cmd}\r\n\x1b%-12345X"
os.write(fd, pjl_cmd.encode())
deadline = time.time() + timeout_sec
response = _wait_for_pjl_response(fd, flags, deadline)
fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
return response.decode("ascii", errors="replace")
def _parse_status(resp: str, result: USBResult) -> bool:
"""Parse STATUS response into result. Returns True if code was found."""
found = False
for raw_line in resp.splitlines():
stripped = raw_line.strip()
if stripped.startswith("CODE="):
result.status_code = stripped.split("=", 1)[1]
found = True
elif stripped.startswith("DISPLAY="):
result.display = stripped.split("=", 1)[1].strip().strip('"').strip()
elif stripped.startswith("ONLINE="):
result.online = stripped.split("=", 1)[1]
return found
def _parse_variables(resp: str, result: USBResult) -> bool:
"""Parse VARIABLES response into result. Returns True if data found."""
found = False
for raw_line in resp.splitlines():
stripped = raw_line.strip()
if stripped.startswith("ECONOMODE="):
result.economode = stripped.split("=", 1)[1].split()[0]
found = True
return found
def _retry_pjl_query(
fd: int,
cmd: str,
parser: Callable[[str, USBResult], bool],
result: USBResult,
max_retries: int,
) -> None:
"""Send a PJL query with retries, draining between attempts."""
for attempt in range(max_retries + 1):
resp = pjl_query(fd, cmd)
if parser(resp, result):
break
if attempt < max_retries:
_drain_buffer(fd)
time.sleep(0.5)
def _run_pjl_queries(fd: int, result: USBResult, max_retries: int) -> None:
"""Execute PJL query sequence on an open file descriptor."""
_drain_buffer(fd)
os.write(fd, b"\x1b%-12345X@PJL\r\n\x1b%-12345X")
time.sleep(0.5)
_drain_buffer(fd)
_retry_pjl_query(fd, "@PJL INFO STATUS", _parse_status, result, max_retries)
_drain_buffer(fd)
time.sleep(0.5)
_retry_pjl_query(fd, "@PJL INFO VARIABLES", _parse_variables, result, max_retries)
def _init_usb_result(dev_path: str) -> USBResult:
"""Create a USBResult with device info from CUPS."""
cups_info = get_printer_info_from_cups()
return USBResult(
device=dev_path,
product=cups_info.get("product") or "Brother Laser Printer",
serial=cups_info.get("serial", ""),
)
def query_usb_pjl(max_retries: int = 2) -> USBResult:
"""Query a Brother printer via PJL over /dev/usb/lp*."""
dev_path = find_usb_printer_dev()
if not dev_path:
from python_pkg.brother_printer.cups_service import query_usb_via_cups
return query_usb_via_cups()
result = _init_usb_result(dev_path)
if not os.access(dev_path, os.R_OK | os.W_OK):
result.error = f"Permission denied: {dev_path}. Run with sudo."
return result
fd: int | None = None
try:
fd = os.open(dev_path, os.O_RDWR)
fcntl.fcntl(fd, fcntl.F_GETFL)
_run_pjl_queries(fd, result, max_retries)
except OSError as e:
result.error = str(e)
finally:
if fd is not None:
os.close(fd)
return result

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,206 @@
"""Shared geographic data module for Warsaw and Poland Anki generators.
This module handles downloading and caching geographic data from various sources:
- OpenStreetMap via Overpass API
- Geofabrik OSM extracts
- GitHub repositories with pre-processed GeoJSON
All data is cached locally to avoid repeated downloads.
"""
from __future__ import annotations
import shutil
import sys
from python_pkg.geo_data._common import (
CACHE_DIR,
MAX_RETRIES,
MIN_LAKE_AREA_KM2,
MIN_LINE_COORDS,
MIN_PEAK_ELEVATION,
MIN_RING_COORDS,
MIN_RIVER_LENGTH_KM,
OVERPASS_ENDPOINTS,
POLSKA_GEOJSON_BASE,
REQUEST_TIMEOUT,
RETRY_DELAY,
WIKIDATA_SPARQL,
)
from python_pkg.geo_data._poland_admin import (
get_poland_boundary,
get_polish_gminy,
get_polish_powiaty,
get_polish_wojewodztwa,
)
from python_pkg.geo_data._poland_nature import (
get_polish_forests,
get_polish_landscape_parks,
get_polish_mountain_peaks,
get_polish_mountain_ranges,
get_polish_national_parks,
get_polish_nature_reserves,
)
from python_pkg.geo_data._poland_water import (
get_polish_coastal_features,
get_polish_islands,
get_polish_lakes,
get_polish_rivers,
get_polish_unesco_sites,
)
from python_pkg.geo_data._warsaw import (
get_vistula_river,
get_warsaw_boundary,
get_warsaw_bridges,
get_warsaw_districts,
get_warsaw_metro_stations,
get_warsaw_osiedla,
)
from python_pkg.geo_data._warsaw_places import get_warsaw_landmarks, get_warsaw_streets
__all__ = [
"CACHE_DIR",
"MAX_RETRIES",
"MIN_LAKE_AREA_KM2",
"MIN_LINE_COORDS",
"MIN_PEAK_ELEVATION",
"MIN_RING_COORDS",
"MIN_RIVER_LENGTH_KM",
"OVERPASS_ENDPOINTS",
"POLSKA_GEOJSON_BASE",
"REQUEST_TIMEOUT",
"RETRY_DELAY",
"WIKIDATA_SPARQL",
"clear_cache",
"download_all_poland_data",
"download_all_warsaw_data",
"get_poland_boundary",
"get_polish_coastal_features",
"get_polish_forests",
"get_polish_gminy",
"get_polish_islands",
"get_polish_lakes",
"get_polish_landscape_parks",
"get_polish_mountain_peaks",
"get_polish_mountain_ranges",
"get_polish_national_parks",
"get_polish_nature_reserves",
"get_polish_powiaty",
"get_polish_rivers",
"get_polish_unesco_sites",
"get_polish_wojewodztwa",
"get_vistula_river",
"get_warsaw_boundary",
"get_warsaw_bridges",
"get_warsaw_districts",
"get_warsaw_landmarks",
"get_warsaw_metro_stations",
"get_warsaw_osiedla",
"get_warsaw_streets",
]
def download_all_warsaw_data() -> None:
"""Download and cache all Warsaw geographic data.
Call this once to pre-populate the cache.
"""
sys.stdout.write("Downloading all Warsaw geographic data...\n")
sys.stdout.write("=" * 60 + "\n")
sys.stdout.write("\n1. Warsaw boundary...\n")
get_warsaw_boundary()
sys.stdout.write("\n2. Vistula river...\n")
get_vistula_river()
sys.stdout.write("\n3. Warsaw bridges...\n")
get_warsaw_bridges()
sys.stdout.write("\n4. Metro stations...\n")
get_warsaw_metro_stations()
sys.stdout.write("\n5. Major streets...\n")
get_warsaw_streets()
sys.stdout.write("\n6. Landmarks...\n")
get_warsaw_landmarks()
sys.stdout.write("\n7. Osiedla...\n")
get_warsaw_osiedla()
sys.stdout.write("\n" + "=" * 60 + "\n")
sys.stdout.write("All Warsaw data cached successfully!\n")
def download_all_poland_data() -> None:
"""Download and cache all Poland geographic data.
Call this once to pre-populate the cache.
"""
sys.stdout.write("Downloading all Poland geographic data...\n")
sys.stdout.write("=" * 60 + "\n")
sys.stdout.write("\n1. Województwa...\n")
get_polish_wojewodztwa()
sys.stdout.write("\n2. Powiaty...\n")
get_polish_powiaty()
sys.stdout.write("\n3. Gminy (this may take a while)...\n")
get_polish_gminy()
sys.stdout.write("\n4. Poland boundary...\n")
get_poland_boundary()
sys.stdout.write("\n" + "=" * 60 + "\n")
sys.stdout.write("All Poland data cached successfully!\n")
def clear_cache() -> None:
"""Clear all cached data."""
if CACHE_DIR.exists():
shutil.rmtree(CACHE_DIR)
sys.stdout.write("Cache cleared.\n")
else:
sys.stdout.write("Cache directory does not exist.\n")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Manage geographic data cache")
parser.add_argument(
"--download-warsaw",
action="store_true",
help="Download all Warsaw data",
)
parser.add_argument(
"--download-poland",
action="store_true",
help="Download all Poland data",
)
parser.add_argument(
"--download-all",
action="store_true",
help="Download all data",
)
parser.add_argument(
"--clear-cache",
action="store_true",
help="Clear cached data",
)
args = parser.parse_args()
if args.clear_cache:
clear_cache()
elif args.download_warsaw:
download_all_warsaw_data()
elif args.download_poland:
download_all_poland_data()
elif args.download_all:
download_all_warsaw_data()
download_all_poland_data()
else:
parser.print_help()

View File

@ -0,0 +1,318 @@
"""Common utilities for geographic data operations.
Shared constants, API helpers, and geometry extraction functions used
across the geo_data package.
"""
from __future__ import annotations
import json
from pathlib import Path
import sys
import time
from typing import TYPE_CHECKING
from urllib.request import urlopen
import geopandas as gpd
import requests
from shapely.geometry import (
GeometryCollection,
MultiPolygon,
Polygon,
)
if TYPE_CHECKING:
from typing import Any
# Parent directory of the geo_data package (i.e. python_pkg/)
_PKG_DIR = Path(__file__).resolve().parent.parent
# Shared cache directory for all geo data
CACHE_DIR = _PKG_DIR / "geo_cache"
# Overpass API endpoints (multiple for redundancy)
# Note: kumi.systems is more reliable, so it's first
OVERPASS_ENDPOINTS = [
"https://overpass.kumi.systems/api/interpreter",
"https://overpass-api.de/api/interpreter",
"https://maps.mail.ru/osm/tools/overpass/api/interpreter",
]
# GitHub URLs for pre-processed data
POLSKA_GEOJSON_BASE = "https://raw.githubusercontent.com/ppatrzyk/polska-geojson/master"
# Wikidata SPARQL endpoint
WIKIDATA_SPARQL = "https://query.wikidata.org/sparql"
# Request timeout and retry settings
REQUEST_TIMEOUT = 180
MAX_RETRIES = 3
RETRY_DELAY = 5
# Data thresholds for filtering
MIN_PEAK_ELEVATION = 300 # meters
MIN_LAKE_AREA_KM2 = 0.5 # km²
MIN_RIVER_LENGTH_KM = 10 # km
MIN_LINE_COORDS = 2 # minimum coordinates for a line
MIN_RING_COORDS = 4 # minimum coordinates for a polygon ring
def _ensure_cache_dir() -> None:
"""Create cache directory if it doesn't exist."""
CACHE_DIR.mkdir(parents=True, exist_ok=True)
def _extract_polygonal_geometry(
geom: Polygon | MultiPolygon | GeometryCollection,
) -> Polygon | MultiPolygon | None:
"""Extract only polygonal geometry from a geometry that may be mixed.
Some OSM data comes as GeometryCollections containing polygons mixed with
lines. This function extracts only the polygon/multipolygon parts.
Args:
geom: Input geometry (Polygon, MultiPolygon, or GeometryCollection).
Returns:
Polygon or MultiPolygon with only the polygonal parts, or None if empty.
"""
if isinstance(geom, Polygon | MultiPolygon):
return geom
if isinstance(geom, GeometryCollection):
polygons = [g for g in geom.geoms if isinstance(g, Polygon | MultiPolygon)]
if not polygons:
return None
if len(polygons) == 1:
return polygons[0]
# Flatten MultiPolygons and combine all polygons
all_polys = []
for p in polygons:
if isinstance(p, Polygon):
all_polys.append(p)
elif isinstance(p, MultiPolygon):
all_polys.extend(p.geoms)
return MultiPolygon(all_polys)
return None
def _try_single_request(
endpoint: str, query: str
) -> tuple[dict[str, Any] | None, Exception | None]:
"""Try a single request to an endpoint.
Args:
endpoint: Overpass API endpoint URL.
query: Overpass QL query string.
Returns:
Tuple of (result, error). One will be None.
"""
try:
sys.stdout.write(f" Querying {endpoint}...\n")
response = requests.post(
endpoint,
data={"data": query},
timeout=REQUEST_TIMEOUT,
)
response.raise_for_status()
result = response.json()
except (requests.RequestException, requests.Timeout, ValueError) as e:
return None, e
else:
# Check for valid response with elements
if not isinstance(result, dict) or "elements" not in result:
return None, ValueError("Invalid response format")
return result, None
def _overpass_query(query: str) -> dict[str, Any]:
"""Execute an Overpass API query with retry logic.
Args:
query: Overpass QL query string.
Returns:
JSON response from the API.
Raises:
RuntimeError: If all endpoints fail.
"""
last_error: Exception | None = None
for endpoint in OVERPASS_ENDPOINTS:
for attempt in range(MAX_RETRIES):
result, error = _try_single_request(endpoint, query)
if result is not None:
return result
last_error = error
sys.stdout.write(f" Attempt {attempt + 1} failed: {error}\n")
if attempt < MAX_RETRIES - 1:
time.sleep(RETRY_DELAY)
msg = f"All Overpass API endpoints failed. Last error: {last_error}"
raise RuntimeError(msg)
def _download_github_geojson(url: str, cache_path: Path) -> gpd.GeoDataFrame:
"""Download GeoJSON from GitHub and cache it.
Args:
url: URL to download from.
cache_path: Path to cache the data.
Returns:
GeoDataFrame with the data.
"""
if cache_path.exists():
return gpd.read_file(cache_path)
sys.stdout.write(f"Downloading from {url}...\n")
if not url.startswith(("http://", "https://")):
msg = f"Unsupported URL scheme: {url}"
raise ValueError(msg)
with urlopen(url, timeout=REQUEST_TIMEOUT) as response:
data = json.loads(response.read().decode())
_ensure_cache_dir()
cache_path.write_text(json.dumps(data))
return gpd.GeoDataFrame.from_features(data["features"], crs="EPSG:4326")
def _extract_osiedla_rings(
element: dict[str, Any], min_coords: int
) -> tuple[list[list[tuple[float, float]]], list[list[tuple[float, float]]]]:
"""Extract outer and inner rings from an OSM relation.
Args:
element: OSM relation element.
min_coords: Minimum number of coordinates for a valid ring.
Returns:
Tuple of (outer_rings, inner_rings).
"""
outer_rings: list[list[tuple[float, float]]] = []
inner_rings: list[list[tuple[float, float]]] = []
for member in element.get("members", []):
if "geometry" not in member:
continue
ring = [(p["lon"], p["lat"]) for p in member["geometry"]]
if len(ring) < min_coords:
continue
# Close the ring if not closed
if ring[0] != ring[-1]:
ring.append(ring[0])
if member.get("role") == "outer":
outer_rings.append(ring)
elif member.get("role") == "inner":
inner_rings.append(ring)
return outer_rings, inner_rings
def _build_osiedla_geometry(
outer_rings: list[list[tuple[float, float]]],
inner_rings: list[list[tuple[float, float]]],
) -> dict[str, Any]:
"""Build GeoJSON geometry from outer and inner rings.
Args:
outer_rings: List of outer ring coordinates.
inner_rings: List of inner ring coordinates.
Returns:
GeoJSON geometry dict.
"""
if len(outer_rings) == 1:
return {
"type": "Polygon",
"coordinates": [outer_rings[0], *inner_rings],
}
# Multiple outer rings - create MultiPolygon
# Each polygon in a MultiPolygon is [exterior, hole1, hole2, ...]
return {
"type": "MultiPolygon",
"coordinates": [[ring] for ring in outer_rings],
}
def _extract_polygon_from_element(
element: dict[str, Any],
) -> dict[str, Any] | None:
"""Extract polygon geometry from an OSM relation or way element.
Args:
element: OSM element (relation or way).
Returns:
GeoJSON geometry dict, or None if extraction fails.
"""
if element.get("type") == "relation":
outer_rings, inner_rings = _extract_osiedla_rings(element, MIN_RING_COORDS)
if not outer_rings:
return None
return _build_osiedla_geometry(outer_rings, inner_rings)
if element.get("type") == "way" and "geometry" in element:
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
if len(coords) < MIN_RING_COORDS:
return None
if coords[0] != coords[-1]:
coords.append(coords[0])
return {"type": "Polygon", "coordinates": [coords]}
return None
def _extract_line_from_way(element: dict[str, Any]) -> dict[str, Any] | None:
"""Extract line geometry from an OSM way element.
Args:
element: OSM way element.
Returns:
GeoJSON LineString geometry dict, or None if extraction fails.
"""
if element.get("type") != "way" or "geometry" not in element:
return None
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
if len(coords) < MIN_LINE_COORDS:
return None
return {"type": "LineString", "coordinates": coords}
def _add_area_column(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
"""Add area_km2 column to a GeoDataFrame.
Args:
gdf: GeoDataFrame with polygon geometries.
Returns:
GeoDataFrame with area_km2 column added.
"""
if len(gdf) == 0:
return gdf
gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system
gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000
return gdf
def _add_length_column(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
"""Add length_km column to a GeoDataFrame.
Args:
gdf: GeoDataFrame with line geometries.
Returns:
GeoDataFrame with length_km column added.
"""
if len(gdf) == 0:
return gdf
gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system
gdf["length_km"] = gdf_proj.geometry.length / 1000
return gdf

View File

@ -0,0 +1,226 @@
"""Polish administrative boundary data.
Functions for downloading and caching Polish administrative divisions:
województwa, powiaty, gminy, and the national boundary.
Includes Wikidata integration for population data.
"""
from __future__ import annotations
import contextlib
import json
import sys
from typing import TYPE_CHECKING
import geopandas as gpd
import requests
from python_pkg.geo_data._common import (
CACHE_DIR,
POLSKA_GEOJSON_BASE,
WIKIDATA_SPARQL,
_build_osiedla_geometry,
_download_github_geojson,
_ensure_cache_dir,
_extract_osiedla_rings,
_overpass_query,
)
if TYPE_CHECKING:
from typing import Any
def _query_wikidata(query: str) -> list[dict[str, Any]]:
"""Query Wikidata SPARQL endpoint.
Args:
query: SPARQL query string.
Returns:
List of result bindings.
"""
response = requests.get(
WIKIDATA_SPARQL,
params={"query": query, "format": "json"},
timeout=60,
)
response.raise_for_status()
return response.json()["results"]["bindings"]
def _get_powiaty_population() -> dict[str, int]:
"""Get population data for all Polish powiaty from Wikidata.
Returns:
Dictionary mapping powiat name to population.
"""
cache_path = CACHE_DIR / "powiaty_population.json"
if cache_path.exists():
return json.loads(cache_path.read_text())
# Query Wikidata for all powiaty (Q247073) in Poland (Q36) with population
# Filter to only current Polish powiaty using country=Poland filter
query = """
SELECT ?powiat ?powiatLabel ?population WHERE {
?powiat wdt:P31 wd:Q247073.
?powiat wdt:P17 wd:Q36.
?powiat wdt:P1082 ?population.
SERVICE wikibase:label { bd:serviceParam wikibase:language "pl,en". }
}
ORDER BY DESC(?population)
"""
sys.stdout.write("Fetching powiaty population data from Wikidata...\n")
results = _query_wikidata(query)
population_map: dict[str, int] = {}
for item in results:
label = item.get("powiatLabel", {}).get("value", "")
pop = item.get("population", {}).get("value", "0")
if label and pop:
# Remove "powiat" prefix if present for matching
clean_label = label.replace("powiat ", "").strip()
with contextlib.suppress(ValueError):
population_map[clean_label] = int(pop)
_ensure_cache_dir()
cache_path.write_text(json.dumps(population_map, ensure_ascii=False, indent=2))
sys.stdout.write(f"Cached population data for {len(population_map)} powiaty.\n")
return population_map
def get_polish_wojewodztwa() -> gpd.GeoDataFrame:
"""Get Polish województwa (voivodeships).
Returns:
GeoDataFrame with województwa boundaries.
"""
url = f"{POLSKA_GEOJSON_BASE}/wojewodztwa/wojewodztwa-min.geojson"
cache_path = CACHE_DIR / "polish_wojewodztwa.geojson"
return _download_github_geojson(url, cache_path)
def get_polish_powiaty() -> gpd.GeoDataFrame:
"""Get Polish powiaty (counties), sorted by population descending.
Returns:
GeoDataFrame with powiat boundaries and population.
"""
url = f"{POLSKA_GEOJSON_BASE}/powiaty/powiaty-min.geojson"
cache_path = CACHE_DIR / "polish_powiaty.geojson"
gdf = _download_github_geojson(url, cache_path)
# Get population data from Wikidata
population_map = _get_powiaty_population()
# Add population column
def get_population(nazwa: str) -> int:
"""Match powiat name to population data."""
if not nazwa:
return 0
# Remove "powiat " prefix for matching
clean_name = nazwa.replace("powiat ", "").strip()
# Try direct match
if clean_name in population_map:
return population_map[clean_name]
# Try lowercase
name_lower = clean_name.lower()
for pop_name, pop in population_map.items():
if pop_name.lower() == name_lower:
return pop
return 0
gdf["population"] = gdf["nazwa"].apply(get_population)
# Sort by population descending
return gdf.sort_values("population", ascending=False).reset_index(drop=True)
def get_polish_gminy() -> gpd.GeoDataFrame:
"""Get Polish gminy (municipalities) from OSM, sorted by area descending.
Returns:
GeoDataFrame with gminy boundaries.
"""
cache_path = CACHE_DIR / "polish_gminy.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
if "area_km2" in gdf.columns:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write("Fetching gminy data from OSM (this may take a while)...\n")
# Polish gminy are admin_level=7 in OSM
query = """
[out:json][timeout:300];
area["ISO3166-1"="PL"]->.pl;
relation["boundary"="administrative"]["admin_level"="7"]["name"](area.pl);
out geom;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
min_ring_coords = 4
for element in data.get("elements", []):
if element.get("type") != "relation":
continue
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
if not outer_rings:
continue
seen_names.add(name)
features.append(
{
"type": "Feature",
"properties": {"name": name},
"geometry": _build_osiedla_geometry(outer_rings, inner_rings),
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson))
sys.stdout.write(f"Cached {len(features)} gminy.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
# Add area column
from python_pkg.geo_data._common import _add_area_column
gdf = _add_area_column(gdf)
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
def get_poland_boundary() -> gpd.GeoDataFrame:
"""Get Poland country boundary.
Returns:
GeoDataFrame with Poland boundary.
"""
cache_path = CACHE_DIR / "poland_boundary.geojson"
if cache_path.exists():
return gpd.read_file(cache_path)
# Dissolve from województwa
woj = get_polish_wojewodztwa()
# Fix invalid geometries with buffer(0)
woj["geometry"] = woj["geometry"].buffer(0)
poland = gpd.GeoDataFrame(geometry=[woj.union_all()], crs=woj.crs)
_ensure_cache_dir()
poland.to_file(cache_path, driver="GeoJSON")
return poland

View File

@ -0,0 +1,446 @@
"""Polish natural land features.
Functions for downloading and caching data about Polish mountains,
national parks, forests, nature reserves, and landscape parks.
"""
from __future__ import annotations
import contextlib
import json
import sys
from typing import TYPE_CHECKING
import geopandas as gpd
from python_pkg.geo_data._common import (
CACHE_DIR,
MIN_PEAK_ELEVATION,
_add_area_column,
_build_osiedla_geometry,
_ensure_cache_dir,
_extract_osiedla_rings,
_extract_polygon_from_element,
_extract_polygonal_geometry,
_overpass_query,
)
if TYPE_CHECKING:
from typing import Any
def get_polish_mountain_peaks() -> gpd.GeoDataFrame:
"""Get Polish mountain peaks, sorted by elevation descending.
Returns:
GeoDataFrame with mountain peak points and elevation.
"""
cache_path = CACHE_DIR / "polish_mountain_peaks.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
return gdf.sort_values("elevation", ascending=False).reset_index(drop=True)
sys.stdout.write("Fetching mountain peaks data from OSM...\n")
query = """
[out:json][timeout:120];
area["ISO3166-1"="PL"]->.pl;
(
node["natural"="peak"]["name"]["ele"](area.pl);
);
out;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
for element in data.get("elements", []):
if element.get("type") != "node":
continue
name = element.get("tags", {}).get("name", "")
ele_str = element.get("tags", {}).get("ele", "")
if not name or not ele_str or name in seen_names:
continue
with contextlib.suppress(ValueError):
elevation = float(ele_str.replace(",", ".").split()[0])
if elevation < MIN_PEAK_ELEVATION:
continue
seen_names.add(name)
features.append(
{
"type": "Feature",
"properties": {"name": name, "elevation": elevation},
"geometry": {
"type": "Point",
"coordinates": [element["lon"], element["lat"]],
},
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} mountain peaks.\n")
if not features:
msg = "No mountain peaks found in OSM data"
raise ValueError(msg)
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
return gdf.sort_values("elevation", ascending=False).reset_index(drop=True)
def get_polish_mountain_ranges() -> gpd.GeoDataFrame:
"""Get Polish mountain ranges, sorted by area descending.
Returns:
GeoDataFrame with mountain range polygons.
"""
cache_path = CACHE_DIR / "polish_mountain_ranges.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
# Fix invalid geometries from OSM data and extract only polygons
gdf["geometry"] = gdf.geometry.make_valid()
gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry)
gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]
if "area_km2" in gdf.columns:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write("Fetching mountain ranges data from OSM...\n")
query = """
[out:json][timeout:180];
area["ISO3166-1"="PL"]->.pl;
(
relation["natural"="mountain_range"]["name"](area.pl);
way["natural"="mountain_range"]["name"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
features: list[dict[str, Any]] = []
seen_names: set[str] = set()
min_ring_coords = 4
for element in data.get("elements", []):
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
if element.get("type") == "relation":
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
if not outer_rings:
continue
geometry = _build_osiedla_geometry(outer_rings, inner_rings)
elif element.get("type") == "way" and "geometry" in element:
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
if len(coords) < min_ring_coords:
continue
if coords[0] != coords[-1]:
coords.append(coords[0])
geometry = {"type": "Polygon", "coordinates": [coords]}
else:
continue
seen_names.add(name)
features.append(
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} mountain ranges.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
# Fix invalid geometries from OSM data and extract only polygons
gdf["geometry"] = gdf.geometry.make_valid()
gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry)
gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]
# Calculate area in km²
gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system
gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
def get_polish_national_parks() -> gpd.GeoDataFrame:
"""Get all 23 Polish national parks, sorted by area descending.
Returns:
GeoDataFrame with national park polygons.
"""
cache_path = CACHE_DIR / "polish_national_parks.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
if "area_km2" in gdf.columns:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write("Fetching national parks data from OSM...\n")
query = """
[out:json][timeout:180];
area["ISO3166-1"="PL"]->.pl;
(
relation["boundary"="national_park"]["name"](area.pl);
relation["leisure"="nature_reserve"]["name"]["protect_class"="2"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
min_ring_coords = 4
for element in data.get("elements", []):
if element.get("type") != "relation":
continue
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
# Filter to only include "Park Narodowy" in name
if "Narodowy" not in name and "narodowy" not in name.lower():
continue
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
if not outer_rings:
continue
seen_names.add(name)
features.append(
{
"type": "Feature",
"properties": {"name": name},
"geometry": _build_osiedla_geometry(outer_rings, inner_rings),
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} national parks.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
# Calculate area in km²
gdf_proj = gdf.to_crs("EPSG:2180")
gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
def get_polish_forests() -> gpd.GeoDataFrame:
"""Get major Polish forests (puszcze), sorted by area descending.
Returns:
GeoDataFrame with forest polygons.
"""
cache_path = CACHE_DIR / "polish_forests.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
if "area_km2" in gdf.columns:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write("Fetching forests data from OSM...\n")
# Query for named forests, especially "Puszcza" type
query = """
[out:json][timeout:300];
area["ISO3166-1"="PL"]->.pl;
(
relation["natural"="wood"]["name"](area.pl);
relation["landuse"="forest"]["name"~"Puszcza|Bory|Las"](area.pl);
way["natural"="wood"]["name"~"Puszcza|Bory"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
forest_keywords = ("Puszcza", "Bory", "Las ", "Lasy ")
features = []
seen_names: set[str] = set()
for element in data.get("elements", []):
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
if not any(keyword in name for keyword in forest_keywords):
continue
geometry = _extract_polygon_from_element(element)
if geometry is None:
continue
seen_names.add(name)
features.append(
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} forests.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
gdf = _add_area_column(gdf)
if len(gdf) > 0:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
def get_polish_nature_reserves() -> gpd.GeoDataFrame:
"""Get Polish nature reserves, sorted by area descending.
Returns:
GeoDataFrame with nature reserve polygons.
"""
cache_path = CACHE_DIR / "polish_nature_reserves.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
if "area_km2" in gdf.columns:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write(
"Fetching nature reserves data from OSM (this may take a while)...\n"
)
query = """
[out:json][timeout:600];
area["ISO3166-1"="PL"]->.pl;
(
relation["leisure"="nature_reserve"]["name"](area.pl);
way["leisure"="nature_reserve"]["name"](area.pl);
relation["boundary"="protected_area"]["protect_class"="4"]["name"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
for element in data.get("elements", []):
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
geometry = _extract_polygon_from_element(element)
if geometry is None:
continue
seen_names.add(name)
features.append(
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} nature reserves.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
gdf = _add_area_column(gdf)
if len(gdf) > 0:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
def get_polish_landscape_parks() -> gpd.GeoDataFrame:
"""Get Polish landscape parks, sorted by area descending.
Returns:
GeoDataFrame with landscape park polygons.
"""
cache_path = CACHE_DIR / "polish_landscape_parks.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
# Fix invalid geometries from OSM data and extract only polygons
gdf["geometry"] = gdf.geometry.make_valid()
gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry)
# Remove any rows where geometry extraction failed
gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]
if "area_km2" in gdf.columns:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write("Fetching landscape parks data from OSM...\n")
query = """
[out:json][timeout:300];
area["ISO3166-1"="PL"]->.pl;
(
relation["boundary"="protected_area"]["protect_class"="5"]["name"](area.pl);
relation["leisure"="nature_reserve"]["name"~"Park Krajobrazowy"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
min_ring_coords = 4
for element in data.get("elements", []):
if element.get("type") != "relation":
continue
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
if not outer_rings:
continue
seen_names.add(name)
features.append(
{
"type": "Feature",
"properties": {"name": name},
"geometry": _build_osiedla_geometry(outer_rings, inner_rings),
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} landscape parks.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
# Fix invalid geometries from OSM data and extract only polygons
gdf["geometry"] = gdf.geometry.make_valid()
gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry)
# Remove any rows where geometry extraction failed
gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]
if len(gdf) > 0:
gdf_proj = gdf.to_crs("EPSG:2180")
gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf

View File

@ -0,0 +1,437 @@
"""Polish water features and cultural sites.
Functions for downloading and caching data about Polish lakes, rivers,
islands, coastal features, and UNESCO World Heritage sites.
"""
from __future__ import annotations
import json
import sys
from typing import TYPE_CHECKING
import geopandas as gpd
from python_pkg.geo_data._common import (
CACHE_DIR,
MIN_LAKE_AREA_KM2,
MIN_LINE_COORDS,
MIN_RING_COORDS,
MIN_RIVER_LENGTH_KM,
_add_area_column,
_add_length_column,
_build_osiedla_geometry,
_ensure_cache_dir,
_extract_osiedla_rings,
_extract_polygon_from_element,
_overpass_query,
)
if TYPE_CHECKING:
from typing import Any
def _extract_coastal_geometry(
element: dict[str, Any],
natural_type: str,
line_types: tuple[str, ...],
) -> dict[str, Any] | None:
"""Extract geometry from a coastal feature element.
For cliffs and beaches, returns LineString. For others, returns Polygon.
Args:
element: OSM element.
natural_type: The natural= tag value.
line_types: Tuple of natural types that should be lines.
Returns:
GeoJSON geometry dict, or None if extraction fails.
"""
if element.get("type") == "relation":
return _extract_polygon_from_element(element)
if element.get("type") != "way" or "geometry" not in element:
return None
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
if len(coords) < MIN_LINE_COORDS:
return None
# For cliffs and beaches, keep as linestring
if natural_type in line_types:
return {"type": "LineString", "coordinates": coords}
# Otherwise try to make a polygon
if len(coords) >= MIN_RING_COORDS:
if coords[0] != coords[-1]:
coords.append(coords[0])
return {"type": "Polygon", "coordinates": [coords]}
return None
def _extract_river_coords_from_element(
element: dict[str, Any],
) -> list[list[tuple[float, float]]]:
"""Extract coordinate lists from a river element.
Args:
element: OSM element (way or relation).
Returns:
List of coordinate lists (line segments).
"""
coord_lists: list[list[tuple[float, float]]] = []
if element.get("type") == "way" and "geometry" in element:
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
if len(coords) >= MIN_LINE_COORDS:
coord_lists.append(coords)
elif element.get("type") == "relation":
for member in element.get("members", []):
if member.get("type") == "way" and "geometry" in member:
coords = [(p["lon"], p["lat"]) for p in member["geometry"]]
if len(coords) >= MIN_LINE_COORDS:
coord_lists.append(coords)
return coord_lists
def get_polish_lakes() -> gpd.GeoDataFrame:
"""Get Polish lakes, sorted by area descending.
Returns:
GeoDataFrame with lake polygons.
"""
cache_path = CACHE_DIR / "polish_lakes.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
if "area_km2" in gdf.columns:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write("Fetching lakes data from OSM...\n")
query = """
[out:json][timeout:300];
area["ISO3166-1"="PL"]->.pl;
(
relation["natural"="water"]["water"="lake"]["name"](area.pl);
way["natural"="water"]["water"="lake"]["name"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
for element in data.get("elements", []):
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
geometry = _extract_polygon_from_element(element)
if geometry is None:
continue
seen_names.add(name)
features.append(
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} lakes.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
gdf = _add_area_column(gdf)
if len(gdf) > 0:
# Filter to lakes > MIN_LAKE_AREA_KM2 to exclude tiny ponds
gdf = gdf[gdf["area_km2"] > MIN_LAKE_AREA_KM2]
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
def get_polish_rivers() -> gpd.GeoDataFrame:
"""Get Polish rivers, sorted by length descending.
Rivers with the same name but in different locations are kept separate
by using unique IDs from OSM when available.
Returns:
GeoDataFrame with river linestrings.
"""
cache_path = CACHE_DIR / "polish_rivers.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
if "length_km" in gdf.columns:
return gdf.sort_values("length_km", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write("Fetching rivers data from OSM...\n")
query = """
[out:json][timeout:300];
area["ISO3166-1"="PL"]->.pl;
(
relation["waterway"="river"]["name"](area.pl);
way["waterway"="river"]["name"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
# Group ways by river name AND wikidata ID (or OSM ID for uniqueness)
# This prevents merging different rivers with the same name
rivers_by_key: dict[str, list[list[tuple[float, float]]]] = {}
river_names: dict[str, str] = {} # key -> display name
for element in data.get("elements", []):
name = element.get("tags", {}).get("name", "")
if not name:
continue
# Use wikidata ID if available, otherwise use element type+id
wikidata = element.get("tags", {}).get("wikidata", "")
if wikidata:
key = f"{name}_{wikidata}"
else:
# Fall back to element ID for grouping related ways
key = f"{name}_{element.get('type')}_{element.get('id')}"
coord_lists = _extract_river_coords_from_element(element)
if coord_lists:
rivers_by_key.setdefault(key, []).extend(coord_lists)
river_names[key] = name
features = []
for key, coord_lists in rivers_by_key.items():
name = river_names[key]
geometry: dict[str, Any]
if len(coord_lists) == 1:
geometry = {"type": "LineString", "coordinates": coord_lists[0]}
else:
geometry = {"type": "MultiLineString", "coordinates": coord_lists}
features.append(
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} rivers.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
gdf = _add_length_column(gdf)
if len(gdf) > 0:
gdf = gdf[gdf["length_km"] > MIN_RIVER_LENGTH_KM]
return gdf.sort_values("length_km", ascending=False).reset_index(drop=True)
return gdf
def get_polish_islands() -> gpd.GeoDataFrame:
"""Get Polish islands, sorted by area descending.
Returns:
GeoDataFrame with island polygons.
"""
cache_path = CACHE_DIR / "polish_islands.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
if "area_km2" in gdf.columns:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write("Fetching islands data from OSM...\n")
query = """
[out:json][timeout:180];
area["ISO3166-1"="PL"]->.pl;
(
relation["place"="island"]["name"](area.pl);
way["place"="island"]["name"](area.pl);
relation["place"="islet"]["name"](area.pl);
way["place"="islet"]["name"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
for element in data.get("elements", []):
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
geometry = _extract_polygon_from_element(element)
if geometry is None:
continue
seen_names.add(name)
features.append(
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} islands.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
gdf = _add_area_column(gdf)
if len(gdf) > 0:
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
return gdf
def get_polish_coastal_features() -> gpd.GeoDataFrame:
"""Get Polish coastal features (peninsulas, spits, cliffs), sorted by length.
Returns:
GeoDataFrame with coastal feature geometries.
"""
cache_path = CACHE_DIR / "polish_coastal_features.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
if "length_km" in gdf.columns:
return gdf.sort_values("length_km", ascending=False).reset_index(drop=True)
return gdf
sys.stdout.write("Fetching coastal features data from OSM...\n")
query = """
[out:json][timeout:180];
area["ISO3166-1"="PL"]->.pl;
(
relation["natural"="peninsula"]["name"](area.pl);
way["natural"="peninsula"]["name"](area.pl);
relation["natural"="spit"]["name"](area.pl);
way["natural"="spit"]["name"](area.pl);
relation["natural"="cliff"]["name"](area.pl);
way["natural"="cliff"]["name"](area.pl);
relation["natural"="coastline"]["name"](area.pl);
way["natural"="beach"]["name"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
line_types = ("cliff", "beach", "coastline")
features = []
seen_names: set[str] = set()
for element in data.get("elements", []):
name = element.get("tags", {}).get("name", "")
natural_type = element.get("tags", {}).get("natural", "")
if not name or name in seen_names:
continue
geometry = _extract_coastal_geometry(element, natural_type, line_types)
if geometry is None:
continue
seen_names.add(name)
features.append(
{
"type": "Feature",
"properties": {"name": name, "type": natural_type},
"geometry": geometry,
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} coastal features.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
gdf = _add_length_column(gdf)
if len(gdf) > 0:
return gdf.sort_values("length_km", ascending=False).reset_index(drop=True)
return gdf
def get_polish_unesco_sites() -> gpd.GeoDataFrame:
"""Get Polish UNESCO World Heritage Sites, sorted by inscription year.
Returns:
GeoDataFrame with UNESCO site geometries.
"""
cache_path = CACHE_DIR / "polish_unesco_sites.geojson"
if cache_path.exists():
return gpd.read_file(cache_path)
sys.stdout.write("Fetching UNESCO sites data from OSM...\n")
query = """
[out:json][timeout:180];
area["ISO3166-1"="PL"]->.pl;
(
relation["heritage"="world_heritage_site"]["name"](area.pl);
way["heritage"="world_heritage_site"]["name"](area.pl);
node["heritage"="world_heritage_site"]["name"](area.pl);
relation["heritage:operator"="whc"]["name"](area.pl);
way["heritage:operator"="whc"]["name"](area.pl);
node["heritage:operator"="whc"]["name"](area.pl);
);
out geom;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
min_ring_coords = 4
for element in data.get("elements", []):
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
if element.get("type") == "node":
geometry: dict[str, Any] = {
"type": "Point",
"coordinates": [element["lon"], element["lat"]],
}
elif element.get("type") == "relation":
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
if not outer_rings:
continue
geometry = _build_osiedla_geometry(outer_rings, inner_rings)
elif element.get("type") == "way" and "geometry" in element:
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
if len(coords) < min_ring_coords:
continue
if coords[0] != coords[-1]:
coords.append(coords[0])
geometry = {"type": "Polygon", "coordinates": [coords]}
else:
continue
seen_names.add(name)
features.append(
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
sys.stdout.write(f"Cached {len(features)} UNESCO sites.\n")
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")

View File

@ -0,0 +1,407 @@
"""Warsaw geographic data functions.
Functions for downloading and caching Warsaw-specific geographic data:
boundaries, districts, Vistula river, bridges, metro stations, and osiedla.
"""
from __future__ import annotations
import json
import sys
import geopandas as gpd
from shapely.geometry import LineString
from python_pkg.geo_data._common import (
_PKG_DIR,
CACHE_DIR,
_build_osiedla_geometry,
_ensure_cache_dir,
_extract_osiedla_rings,
_overpass_query,
)
def get_warsaw_boundary() -> gpd.GeoDataFrame:
"""Get Warsaw city boundary.
Returns:
GeoDataFrame with Warsaw boundary polygon.
"""
cache_path = CACHE_DIR / "warsaw_boundary.geojson"
if cache_path.exists():
return gpd.read_file(cache_path)
# Try to use districts file first
districts_path = (
_PKG_DIR / "anki_decks" / "warsaw_districts" / "warszawa-dzielnice.geojson"
)
if districts_path.exists():
warsaw_gdf = gpd.read_file(districts_path)
warsaw_boundary = warsaw_gdf[warsaw_gdf["name"] == "Warszawa"]
if len(warsaw_boundary) == 0:
warsaw_boundary = gpd.GeoDataFrame(
geometry=[warsaw_gdf.union_all()], crs=warsaw_gdf.crs
)
_ensure_cache_dir()
warsaw_boundary.to_file(cache_path, driver="GeoJSON")
return warsaw_boundary
# Fallback to Overpass query
sys.stdout.write("Fetching Warsaw boundary from OpenStreetMap...\n")
query = """
[out:json][timeout:60];
relation["name"="Warszawa"]["admin_level"="6"];
out geom;
"""
data = _overpass_query(query)
features = []
for element in data.get("elements", []):
if element.get("type") == "relation":
coords = []
for member in element.get("members", []):
if member.get("role") == "outer" and "geometry" in member:
coords.extend([(p["lon"], p["lat"]) for p in member["geometry"]])
if coords:
features.append(
{
"type": "Feature",
"properties": {"name": "Warszawa"},
"geometry": {"type": "Polygon", "coordinates": [coords]},
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson))
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
def get_warsaw_districts() -> gpd.GeoDataFrame:
"""Get Warsaw districts (dzielnice).
Returns:
GeoDataFrame with district boundaries.
"""
districts_path = (
_PKG_DIR / "anki_decks" / "warsaw_districts" / "warszawa-dzielnice.geojson"
)
if districts_path.exists():
gdf = gpd.read_file(districts_path)
return gdf[gdf["name"] != "Warszawa"].copy()
msg = "Warsaw districts GeoJSON not found"
raise FileNotFoundError(msg)
def get_vistula_river() -> gpd.GeoDataFrame:
"""Get Vistula river in Warsaw.
Returns:
GeoDataFrame with river geometry.
"""
cache_path = CACHE_DIR / "warsaw_vistula.geojson"
if cache_path.exists():
return gpd.read_file(cache_path)
sys.stdout.write("Fetching Vistula river data...\n")
query = """
[out:json][timeout:60];
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
(
way["waterway"="river"]["name"="Wisła"](area.warsaw);
);
out geom;
"""
data = _overpass_query(query)
features = []
min_coords = 2
for element in data.get("elements", []):
if element.get("type") == "way" and "geometry" in element:
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
if len(coords) >= min_coords:
features.append(
{
"type": "Feature",
"properties": {"name": "Wisła"},
"geometry": {"type": "LineString", "coordinates": coords},
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson))
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
def get_warsaw_bridges() -> gpd.GeoDataFrame:
"""Get Warsaw bridges over the Vistula.
Returns:
GeoDataFrame with bridge geometries.
"""
cache_path = CACHE_DIR / "warsaw_bridges.geojson"
if cache_path.exists():
return gpd.read_file(cache_path)
sys.stdout.write("Fetching Warsaw bridges data...\n")
# First get the Vistula to filter bridges
vistula = get_vistula_river()
vistula_union = vistula.union_all()
vistula_buffer = vistula_union.buffer(0.002) # ~200m buffer
# Query for bridges with "Most" in name - smaller query
query = """
[out:json][timeout:90];
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
way["bridge"="yes"]["name"~"^Most"](area.warsaw);
out geom;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
min_coords = 2
for element in data.get("elements", []):
if element.get("type") != "way" or "geometry" not in element:
continue
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
if len(coords) < min_coords:
continue
line = LineString(coords)
# Check if bridge crosses/is near Vistula
if line.intersects(vistula_buffer):
seen_names.add(name)
features.append(
{
"type": "Feature",
"properties": {"name": name, "osm_id": element.get("id")},
"geometry": {"type": "LineString", "coordinates": coords},
}
)
# Merge segments of the same bridge
merged_features = _merge_bridge_segments(features)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": merged_features}
cache_path.write_text(json.dumps(geojson))
sys.stdout.write(f"Cached {len(merged_features)} bridges.\n")
return gpd.GeoDataFrame.from_features(merged_features, crs="EPSG:4326")
def _merge_bridge_segments(features: list[dict]) -> list[dict]:
"""Merge bridge segments with the same name.
Args:
features: List of GeoJSON features.
Returns:
List of merged features.
"""
by_name: dict[str, list[list[tuple[float, float]]]] = {}
for feature in features:
name = feature["properties"]["name"]
coords = feature["geometry"]["coordinates"]
if name not in by_name:
by_name[name] = []
by_name[name].append(coords)
merged = []
for name, coord_lists in by_name.items():
if len(coord_lists) == 1:
geom = {"type": "LineString", "coordinates": coord_lists[0]}
else:
geom = {"type": "MultiLineString", "coordinates": coord_lists}
merged.append(
{"type": "Feature", "properties": {"name": name}, "geometry": geom}
)
return merged
def get_warsaw_metro_stations() -> gpd.GeoDataFrame:
"""Get Warsaw metro stations with line information.
Returns:
GeoDataFrame with station points and line info (M1, M2, or M1/M2).
"""
cache_path = CACHE_DIR / "warsaw_metro.geojson"
if cache_path.exists():
return gpd.read_file(cache_path)
# Known stations for each line (as of 2024)
m1_stations = {
"Kabaty",
"Natolin",
"Imielin",
"Stokłosy",
"Ursynów",
"Służew",
"Wilanowska",
"Wierzbno",
"Racławicka",
"Pole Mokotowskie",
"Politechnika",
"Centrum",
"Świętokrzyska", # Also M2
"Ratusz-Arsenał",
"Dworzec Gdański",
"Plac Wilsona",
"Marymont",
"Słodowiec",
"Stare Bielany",
"Wawrzyszew",
"Młociny",
}
m2_stations = {
"Bródno",
"Kondratowicza",
"Zacisze",
"Targówek Mieszkaniowy",
"Trocka",
"Szwedzka",
"Dworzec Wileński",
"Świętokrzyska", # Also M1
"Nowy Świat-Uniwersytet",
"Centrum Nauki Kopernik",
"Stadion Narodowy",
"Rondo ONZ",
"Rondo Daszyńskiego",
"Płocka",
"Młynów",
"Księcia Janusza",
"Ulrychów",
"Bemowo",
}
sys.stdout.write("Fetching metro station data...\n")
query = """
[out:json][timeout:60];
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
(
node["railway"="station"]["station"="subway"](area.warsaw);
node["railway"="station"]["network"~"Metro"](area.warsaw);
);
out body;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
for element in data.get("elements", []):
if element.get("type") == "node":
name = element.get("tags", {}).get("name", "")
if name and name not in seen_names:
seen_names.add(name)
# Determine line from known station lists
in_m1 = name in m1_stations
in_m2 = name in m2_stations
if in_m1 and in_m2:
line = "M1/M2"
elif in_m1:
line = "M1"
elif in_m2:
line = "M2"
else:
line = "?" # Unknown station
features.append(
{
"type": "Feature",
"properties": {
"name": name,
"line": line,
},
"geometry": {
"type": "Point",
"coordinates": [element["lon"], element["lat"]],
},
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson))
sys.stdout.write(f"Cached {len(features)} metro stations.\n")
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
def get_warsaw_osiedla() -> gpd.GeoDataFrame:
"""Get Warsaw osiedla (neighborhoods).
Returns:
GeoDataFrame with osiedla boundaries.
"""
cache_path = CACHE_DIR / "warsaw_osiedla.geojson"
if cache_path.exists():
return gpd.read_file(cache_path)
sys.stdout.write("Fetching osiedla data...\n")
query = """
[out:json][timeout:180];
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
relation["boundary"="administrative"]["admin_level"="11"]["name"](area.warsaw);
out geom;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
min_ring_coords = 4
for element in data.get("elements", []):
if element.get("type") != "relation":
continue
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
if not outer_rings:
continue
seen_names.add(name)
features.append(
{
"type": "Feature",
"properties": {"name": name},
"geometry": _build_osiedla_geometry(outer_rings, inner_rings),
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson))
sys.stdout.write(f"Cached {len(features)} osiedla.\n")
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")

View File

@ -0,0 +1,186 @@
"""Warsaw streets, landmarks, and place data.
Functions for downloading and caching Warsaw streets, landmarks,
and other place-related geographic data.
"""
from __future__ import annotations
import json
import sys
import geopandas as gpd
from shapely.geometry import MultiLineString
from python_pkg.geo_data._common import CACHE_DIR, _ensure_cache_dir, _overpass_query
def get_warsaw_streets(min_length: int = 500) -> gpd.GeoDataFrame:
"""Get major Warsaw streets.
Args:
min_length: Minimum street length in meters.
Returns:
GeoDataFrame with street geometries.
"""
cache_path = CACHE_DIR / "warsaw_streets.geojson"
if cache_path.exists():
gdf = gpd.read_file(cache_path)
# Filter by length if needed
return _filter_streets_by_length(gdf, min_length)
sys.stdout.write("Fetching street data from OpenStreetMap...\n")
query = """
[out:json][timeout:120];
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
(
way["highway"="primary"]["name"](area.warsaw);
way["highway"="secondary"]["name"](area.warsaw);
way["highway"="tertiary"]["name"](area.warsaw);
);
out geom;
"""
data = _overpass_query(query)
features = []
min_coords = 2
for element in data.get("elements", []):
if element.get("type") == "way" and "geometry" in element:
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
if len(coords) >= min_coords:
features.append(
{
"type": "Feature",
"properties": {
"name": element.get("tags", {}).get("name", "Unknown"),
"highway": element.get("tags", {}).get("highway", ""),
},
"geometry": {"type": "LineString", "coordinates": coords},
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson))
sys.stdout.write(f"Cached {len(features)} street segments.\n")
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
return _filter_streets_by_length(gdf, min_length)
def _filter_streets_by_length(
gdf: gpd.GeoDataFrame, min_length: int
) -> gpd.GeoDataFrame:
"""Filter and merge streets by name, keeping only those above min_length.
Args:
gdf: GeoDataFrame with street segments.
min_length: Minimum length in meters.
Returns:
GeoDataFrame with merged streets, sorted by length (longest first).
"""
# Group by street name
streets: dict[str, list] = {}
for _, row in gdf.iterrows():
name = row.get("name", "Unknown")
if name and name != "Unknown":
if name not in streets:
streets[name] = []
streets[name].append(row.geometry)
# Merge and filter
result_rows = []
for name, geometries in streets.items():
merged = geometries[0] if len(geometries) == 1 else MultiLineString(geometries)
# Create temp GeoDataFrame for length calculation
temp_gdf = gpd.GeoDataFrame(geometry=[merged], crs="EPSG:4326")
temp_proj = temp_gdf.to_crs("EPSG:2180") # Polish coordinate system
length = temp_proj.geometry.length.iloc[0]
if length >= min_length:
result_rows.append({"name": name, "geometry": merged, "length_m": length})
# Sort by length (longest first)
result_rows.sort(key=lambda x: x["length_m"], reverse=True)
return gpd.GeoDataFrame(result_rows, crs="EPSG:4326")
def get_warsaw_landmarks() -> gpd.GeoDataFrame:
"""Get Warsaw landmarks (museums, monuments, parks, etc.).
Returns:
GeoDataFrame with landmark points.
"""
cache_path = CACHE_DIR / "warsaw_landmarks.geojson"
if cache_path.exists():
return gpd.read_file(cache_path)
sys.stdout.write("Fetching landmark data...\n")
# Simplified query - just museums and major attractions
query = """
[out:json][timeout:60];
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
(
node["tourism"="museum"]["name"](area.warsaw);
node["tourism"="attraction"]["name"](area.warsaw);
node["historic"="monument"]["name"](area.warsaw);
way["tourism"="museum"]["name"](area.warsaw);
way["tourism"="attraction"]["name"](area.warsaw);
);
out center;
"""
data = _overpass_query(query)
features = []
seen_names: set[str] = set()
for element in data.get("elements", []):
name = element.get("tags", {}).get("name", "")
if not name or name in seen_names:
continue
# Get coordinates
if element.get("type") == "node":
lon, lat = element["lon"], element["lat"]
elif "center" in element:
lon, lat = element["center"]["lon"], element["center"]["lat"]
else:
continue
seen_names.add(name)
landmark_type = (
element.get("tags", {}).get("tourism")
or element.get("tags", {}).get("historic")
or element.get("tags", {}).get("leisure")
or "landmark"
)
features.append(
{
"type": "Feature",
"properties": {"name": name, "type": landmark_type},
"geometry": {"type": "Point", "coordinates": [lon, lat]},
}
)
_ensure_cache_dir()
geojson = {"type": "FeatureCollection", "features": features}
cache_path.write_text(json.dumps(geojson))
sys.stdout.write(f"Cached {len(features)} landmarks.\n")
if not features:
return gpd.GeoDataFrame(
{"name": [], "type": [], "geometry": []}, crs="EPSG:4326"
)
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")

View File

@ -0,0 +1,14 @@
"""Shared fixtures for keyboard_coop tests."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture(autouse=True)
def mock_pygame() -> MagicMock:
"""Mock pygame to prevent display initialization."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
yield

View File

@ -0,0 +1,148 @@
"""Tests for keyboard_coop constants and dataclasses."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
class TestConstants:
"""Tests for module constants."""
def test_screen_dimensions(self) -> None:
"""Test screen dimension constants."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import SCREEN_HEIGHT, SCREEN_WIDTH
expected_width = 1366
expected_height = 768
assert expected_width == SCREEN_WIDTH
assert expected_height == SCREEN_HEIGHT
def test_min_word_length(self) -> None:
"""Test minimum word length constant."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import MIN_WORD_LENGTH
expected_min = 3
assert expected_min == MIN_WORD_LENGTH
def test_keyboard_layout_structure(self) -> None:
"""Test KEYBOARD_LAYOUT has correct structure."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import KEYBOARD_LAYOUT
expected_rows = 3
assert len(KEYBOARD_LAYOUT) == expected_rows
expected_first_row_len = 10
expected_second_row_len = 9
expected_third_row_len = 7
assert len(KEYBOARD_LAYOUT[0]) == expected_first_row_len
assert len(KEYBOARD_LAYOUT[1]) == expected_second_row_len
assert len(KEYBOARD_LAYOUT[2]) == expected_third_row_len
class TestKeyAdjacency:
"""Tests for KEY_ADJACENCY mapping."""
def test_q_adjacents(self) -> None:
"""Test Q key has correct adjacent keys."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import KEY_ADJACENCY
assert set(KEY_ADJACENCY["q"]) == {"w", "a", "s"}
def test_all_letters_have_adjacents(self) -> None:
"""Test all 26 letters have adjacency entries."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import KEY_ADJACENCY
alphabet = "qwertyuiopasdfghjklzxcvbnm"
for letter in alphabet:
assert letter in KEY_ADJACENCY
assert len(KEY_ADJACENCY[letter]) > 0
class TestGameState:
"""Tests for GameState dataclass."""
def test_default_values(self) -> None:
"""Test GameState default values."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import GameState
state = GameState()
assert state.current_player == 0
assert state.current_word == ""
assert state.selected_letters == []
assert state.score == 0
assert state.game_over is False
assert "Player 1" in state.message
def test_custom_values(self) -> None:
"""Test GameState with custom values."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import GameState
state = GameState(
current_player=1,
current_word="test",
selected_letters=["t", "e", "s", "t"],
score=100,
game_over=True,
message="Game Over!",
)
assert state.current_player == 1
assert state.current_word == "test"
expected_score = 100
assert state.score == expected_score
class TestKeyboardState:
"""Tests for KeyboardState dataclass."""
def test_default_values(self) -> None:
"""Test KeyboardState default values."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import KeyboardState
kb_state = KeyboardState()
assert kb_state.layout == []
assert kb_state.available_letters == set()
assert kb_state.adjacency == {}
assert kb_state.positions == {}
class TestFontSet:
"""Tests for FontSet dataclass."""
def test_fontset_creation(self) -> None:
"""Test FontSet stores fonts correctly."""
mock_font = MagicMock()
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import FontSet
fonts = FontSet(normal=mock_font, large=mock_font, small=mock_font)
assert fonts.normal == mock_font
assert fonts.large == mock_font
assert fonts.small == mock_font
class TestColors:
"""Tests for color constants."""
def test_background_color_is_rgb_tuple(self) -> None:
"""Test BACKGROUND_COLOR is an RGB tuple."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import BACKGROUND_COLOR
expected_len = 3
assert len(BACKGROUND_COLOR) == expected_len
assert all(isinstance(c, int) for c in BACKGROUND_COLOR)
def test_player_colors_list(self) -> None:
"""Test PLAYER_COLORS has colors for 2 players."""
with patch.dict("sys.modules", {"pygame": MagicMock()}):
from python_pkg.keyboard_coop.main import PLAYER_COLORS
expected_players = 2
assert len(PLAYER_COLORS) == expected_players

View File

@ -0,0 +1,371 @@
"""Tests for keyboard_coop game logic methods."""
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import pytest
if TYPE_CHECKING:
from python_pkg.keyboard_coop.main import KeyboardCoopGame
class TestKeyboardCoopGame:
"""Tests for KeyboardCoopGame class methods."""
@pytest.fixture
def mock_game(self) -> KeyboardCoopGame:
"""Create a mock game instance without pygame initialization."""
mock_pg = MagicMock()
mock_pg.font.Font.return_value = MagicMock()
mock_pg.Rect = MagicMock()
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
GameState,
KeyboardCoopGame,
KeyboardState,
)
# Create game without calling __init__ directly
game = object.__new__(KeyboardCoopGame)
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.layout = [["a", "b", "c"], ["d", "e", "f"]]
game.keyboard.adjacency = {
"a": ["b", "d"],
"b": ["a", "c", "e"],
"c": ["b", "f"],
"d": ["a", "e"],
"e": ["b", "d", "f"],
"f": ["c", "e"],
}
game.keyboard.available_letters = {"a", "b", "c", "d", "e", "f"}
game.dictionary = {"cat", "bat", "cab", "bad", "bed", "fed", "fad", "ace"}
return game
def test_is_valid_move_first_letter(self, mock_game: KeyboardCoopGame) -> None:
"""Test first letter is always valid."""
mock_game.state.selected_letters = []
assert mock_game._is_valid_move("a") is True
assert mock_game._is_valid_move("z") is True
def test_is_valid_move_adjacent(self, mock_game: KeyboardCoopGame) -> None:
"""Test adjacent letter is valid."""
mock_game.state.selected_letters = ["a"]
# "b" and "d" are adjacent to "a"
assert mock_game._is_valid_move("b") is True
assert mock_game._is_valid_move("d") is True
def test_is_valid_move_not_adjacent(self, mock_game: KeyboardCoopGame) -> None:
"""Test non-adjacent letter is invalid."""
mock_game.state.selected_letters = ["a"]
# "f" is not adjacent to "a"
assert mock_game._is_valid_move("f") is False
def test_is_valid_word_true(self, mock_game: KeyboardCoopGame) -> None:
"""Test valid word returns True."""
assert mock_game._is_valid_word("cat") is True
assert mock_game._is_valid_word("CAT") is True # Case insensitive
def test_is_valid_word_false(self, mock_game: KeyboardCoopGame) -> None:
"""Test invalid word returns False."""
assert mock_game._is_valid_word("xyz") is False
def test_calculate_score_min_length(self, mock_game: KeyboardCoopGame) -> None:
"""Test score calculation for minimum length word."""
# 3-letter word: 2^(3-2) = 2
assert mock_game._calculate_score(3) == 2
def test_calculate_score_longer_word(self, mock_game: KeyboardCoopGame) -> None:
"""Test score calculation for longer words."""
# 4-letter: 2^(4-2) = 4
assert mock_game._calculate_score(4) == 4
# 5-letter: 2^(5-2) = 8
assert mock_game._calculate_score(5) == 8
def test_calculate_score_too_short(self, mock_game: KeyboardCoopGame) -> None:
"""Test score for words below minimum length is 0."""
assert mock_game._calculate_score(2) == 0
assert mock_game._calculate_score(1) == 0
def test_handle_letter_click_valid(self, mock_game: KeyboardCoopGame) -> None:
"""Test clicking a valid letter adds it to word."""
mock_game.state.selected_letters = []
mock_game.state.current_word = ""
mock_game.state.current_player = 0
mock_game._handle_letter_click("a")
assert mock_game.state.selected_letters == ["a"]
assert mock_game.state.current_word == "a"
assert mock_game.state.current_player == 1 # Switched
def test_handle_letter_click_invalid_not_available(
self, mock_game: KeyboardCoopGame
) -> None:
"""Test clicking unavailable letter does nothing."""
mock_game.keyboard.available_letters = {"b", "c"}
mock_game.state.selected_letters = []
mock_game.state.current_word = ""
mock_game._handle_letter_click("a")
assert mock_game.state.selected_letters == []
assert mock_game.state.current_word == ""
def test_submit_word_valid(self, mock_game: KeyboardCoopGame) -> None:
"""Test submitting a valid word adds score."""
mock_game._generate_random_keyboard = MagicMock()
mock_game.state.current_word = "cat"
mock_game.state.selected_letters = ["c", "a", "t"]
mock_game.state.score = 0
mock_game._submit_word()
assert mock_game.state.score == 2 # 2^(3-2) = 2
assert mock_game.state.current_word == ""
assert mock_game.state.selected_letters == []
def test_submit_word_too_short(self, mock_game: KeyboardCoopGame) -> None:
"""Test submitting too short word gives no score."""
mock_game.state.current_word = "ca"
mock_game.state.selected_letters = ["c", "a"]
mock_game.state.score = 0
mock_game._submit_word()
assert mock_game.state.score == 0
assert "too short" in mock_game.state.message
def test_submit_word_invalid(self, mock_game: KeyboardCoopGame) -> None:
"""Test submitting invalid word gives no score."""
mock_game.state.current_word = "xyz"
mock_game.state.selected_letters = ["x", "y", "z"]
mock_game.state.score = 0
mock_game._submit_word()
assert mock_game.state.score == 0
assert "not a valid word" in mock_game.state.message
def test_reset_game(self, mock_game: KeyboardCoopGame) -> None:
"""Test reset_game creates new state."""
mock_game._generate_random_keyboard = MagicMock()
mock_game.state.score = 100
mock_game.state.current_word = "test"
mock_game._reset_game()
# After reset, state should be fresh
assert mock_game.state.score == 0
assert mock_game.state.current_word == ""
assert mock_game._generate_random_keyboard.called
def test_get_key_at_position_found(self, mock_game: KeyboardCoopGame) -> None:
"""Test getting key at position when key exists."""
mock_rect = MagicMock()
mock_rect.collidepoint.return_value = True
mock_game.keyboard.positions = {"a": mock_rect}
result = mock_game._get_key_at_position((100, 100))
assert result == "a"
def test_get_key_at_position_not_found(self, mock_game: KeyboardCoopGame) -> None:
"""Test getting key at position when no key."""
mock_rect = MagicMock()
mock_rect.collidepoint.return_value = False
mock_game.keyboard.positions = {"a": mock_rect}
result = mock_game._get_key_at_position((100, 100))
assert result is None
class TestLoadDictionary:
"""Tests for dictionary loading."""
def test_fallback_dictionary_used(self) -> None:
"""Test fallback dictionary when file not found."""
mock_pg = MagicMock()
mock_pg.font.Font.return_value = MagicMock()
mock_pg.display.set_mode.return_value = MagicMock()
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("pathlib.Path.open", side_effect=FileNotFoundError),
):
from python_pkg.keyboard_coop.main import KeyboardCoopGame
game = object.__new__(KeyboardCoopGame)
dictionary = game._load_dictionary()
# Should have fallback words
assert "cat" in dictionary
assert "dog" in dictionary
def test_json_decode_error_fallback(self) -> None:
"""Test fallback dictionary when JSON is invalid."""
import json
mock_pg = MagicMock()
mock_pg.font.Font.return_value = MagicMock()
mock_pg.display.set_mode.return_value = MagicMock()
mock_file = MagicMock()
mock_file.__enter__ = MagicMock(return_value=mock_file)
mock_file.__exit__ = MagicMock(return_value=False)
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("pathlib.Path.open", return_value=mock_file),
patch("json.load", side_effect=json.JSONDecodeError("err", "doc", 0)),
):
from python_pkg.keyboard_coop.main import KeyboardCoopGame
game = object.__new__(KeyboardCoopGame)
dictionary = game._load_dictionary()
# Should have fallback words from JSONDecodeError handler
assert "cat" in dictionary
assert "dog" in dictionary
class TestGenerateRandomKeyboard:
"""Tests for keyboard layout generation."""
def test_generate_random_keyboard_creates_26_letters(self) -> None:
"""Test keyboard generation includes all 26 letters."""
mock_pg = MagicMock()
mock_pg.Rect = MagicMock(return_value=MagicMock())
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.keyboard = KeyboardState()
game._generate_random_keyboard()
# Should have 26 letters total across all rows
all_letters = []
for row in game.keyboard.layout:
all_letters.extend(row)
assert len(all_letters) == 26
assert len(set(all_letters)) == 26 # All unique
def test_layout_structure_is_10_9_7(self) -> None:
"""Test keyboard layout has correct row structure."""
mock_pg = MagicMock()
mock_pg.Rect = MagicMock(return_value=MagicMock())
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.keyboard = KeyboardState()
game._generate_random_keyboard()
assert len(game.keyboard.layout) == 3
assert len(game.keyboard.layout[0]) == 10
assert len(game.keyboard.layout[1]) == 9
assert len(game.keyboard.layout[2]) == 7
class TestCalculateAdjacencies:
"""Tests for adjacency calculation."""
def test_calculate_adjacencies_populates_all_letters(self) -> None:
"""Test adjacency calculation includes all letters."""
mock_pg = MagicMock()
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.keyboard = KeyboardState()
game.keyboard.layout = [
["a", "b", "c"],
["d", "e", "f"],
["g", "h"],
]
game._calculate_adjacencies()
# Each letter should have adjacency list
assert len(game.keyboard.adjacency) == 8
# Corner letter should have fewer adjacents
assert "b" in game.keyboard.adjacency["a"]
assert "d" in game.keyboard.adjacency["a"]
assert "e" in game.keyboard.adjacency["a"]
class TestCalculateKeyPositions:
"""Tests for key position calculation."""
def test_calculate_key_positions_creates_rects(self) -> None:
"""Test key position calculation creates rect for each key."""
mock_pg = MagicMock()
mock_pg.Rect = MagicMock(return_value=MagicMock())
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.keyboard = KeyboardState()
game.keyboard.layout = [["a", "b"], ["c", "d"]]
positions = game._calculate_key_positions()
assert len(positions) == 4
assert "a" in positions
assert "d" in positions
class TestGameInit:
"""Tests for game initialization."""
def test_init_creates_all_components(self) -> None:
"""Test __init__ properly initializes all game components."""
mock_pg = MagicMock()
mock_pg.font.Font.return_value = MagicMock()
mock_pg.display.set_mode.return_value = MagicMock()
mock_pg.Rect = MagicMock(return_value=MagicMock())
mock_file = MagicMock()
mock_file.__enter__ = MagicMock(return_value=mock_file)
mock_file.__exit__ = MagicMock(return_value=False)
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("pathlib.Path.open", return_value=mock_file),
patch("json.load", return_value={"cat": 1, "dog": 1}),
):
from python_pkg.keyboard_coop.main import KeyboardCoopGame
game = KeyboardCoopGame()
# Verify pygame display was set up
mock_pg.display.set_mode.assert_called()
mock_pg.display.set_caption.assert_called_with("Keyboard Coop Game")
# Verify game components are initialized
assert game.screen is not None
assert game.clock is not None
assert game.fonts is not None
assert game.dictionary is not None
assert game.state is not None
assert game.keyboard is not None

View File

@ -0,0 +1,426 @@
"""Tests for keyboard_coop game loop and forced submission."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
class TestForceSubmitWhenNoMoves:
"""Tests for forced word submission when no moves available."""
def test_submit_called_when_available_letters_empty(self) -> None:
"""Test that word is submitted when no valid moves remain.
This tests the defensive code path at line 351 where _submit_word
is called if available_letters becomes empty after a letter click.
"""
mock_pg = MagicMock()
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.layout = [["a", "b"]]
game.keyboard.adjacency = {}
game.keyboard.available_letters = {"a"}
game.keyboard.positions = {}
game.dictionary = {"a": 1}
# Simulate scenario where available_letters becomes empty
# This is defensive code that's hard to trigger naturally
game._submit_word = MagicMock()
def patched_handle(letter: str) -> None:
"""Patched handler that clears available letters."""
if letter in game.keyboard.available_letters:
game.state.selected_letters.append(letter)
game.state.current_word += letter
# Force empty to trigger the check
game.keyboard.available_letters = set()
if not game.keyboard.available_letters:
game._submit_word()
patched_handle("a")
# Should have triggered submit_word
game._submit_word.assert_called()
class TestGameLoop:
"""Tests for the main game loop."""
def test_run_quit_event(self) -> None:
"""Test game loop exits on QUIT event."""
mock_pg = MagicMock()
# Create quit event
quit_event = MagicMock()
quit_event.type = "QUIT"
mock_pg.QUIT = "QUIT"
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
mock_pg.KEYDOWN = "KEYDOWN"
mock_pg.event.get.return_value = [quit_event]
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("sys.exit") as mock_exit,
):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.clock = MagicMock()
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
game._draw_keyboard = MagicMock()
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
game.run()
mock_pg.quit.assert_called()
mock_exit.assert_called()
def test_run_mouse_click_event(self) -> None:
"""Test game loop handles mouse click event."""
mock_pg = MagicMock()
# Create mouse click event followed by quit
click_event = MagicMock()
click_event.type = "MOUSEDOWN"
click_event.button = 1
click_event.pos = (100, 100)
quit_event = MagicMock()
quit_event.type = "QUIT"
mock_pg.QUIT = "QUIT"
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
mock_pg.KEYDOWN = "KEYDOWN"
# Return click event first, then quit event
mock_pg.event.get.side_effect = [[click_event], [quit_event]]
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("sys.exit"),
):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.clock = MagicMock()
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
game._draw_keyboard = MagicMock()
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
game._handle_click = MagicMock()
game.run()
game._handle_click.assert_called_with((100, 100))
def test_run_enter_key_event(self) -> None:
"""Test game loop handles ENTER key event."""
mock_pg = MagicMock()
key_event = MagicMock()
key_event.type = "KEYDOWN"
key_event.key = "K_RETURN"
quit_event = MagicMock()
quit_event.type = "QUIT"
mock_pg.QUIT = "QUIT"
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
mock_pg.KEYDOWN = "KEYDOWN"
mock_pg.K_RETURN = "K_RETURN"
mock_pg.K_r = "K_r"
mock_pg.event.get.side_effect = [[key_event], [quit_event]]
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("sys.exit"),
):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.clock = MagicMock()
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
game._draw_keyboard = MagicMock()
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
game._submit_word = MagicMock()
game.run()
game._submit_word.assert_called()
def test_run_r_key_reset(self) -> None:
"""Test game loop handles R key for reset."""
mock_pg = MagicMock()
key_event = MagicMock()
key_event.type = "KEYDOWN"
key_event.key = "K_r"
quit_event = MagicMock()
quit_event.type = "QUIT"
mock_pg.QUIT = "QUIT"
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
mock_pg.KEYDOWN = "KEYDOWN"
mock_pg.K_RETURN = "K_RETURN"
mock_pg.K_r = "K_r"
mock_pg.event.get.side_effect = [[key_event], [quit_event]]
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("sys.exit"),
):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.clock = MagicMock()
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
game._draw_keyboard = MagicMock()
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
game._reset_game = MagicMock()
game.run()
game._reset_game.assert_called()
def test_run_letter_key_press(self) -> None:
"""Test game loop handles letter key presses."""
mock_pg = MagicMock()
key_event = MagicMock()
key_event.type = "KEYDOWN"
key_event.key = "some_key"
quit_event = MagicMock()
quit_event.type = "QUIT"
mock_pg.QUIT = "QUIT"
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
mock_pg.KEYDOWN = "KEYDOWN"
mock_pg.K_RETURN = "K_RETURN"
mock_pg.K_r = "K_r"
mock_pg.key.name.return_value = "a" # Single letter key
mock_pg.event.get.side_effect = [[key_event], [quit_event]]
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("sys.exit"),
):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.clock = MagicMock()
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
game._draw_keyboard = MagicMock()
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
game._handle_letter_click = MagicMock()
game.run()
game._handle_letter_click.assert_called_with("a")
def test_run_right_click_ignored(self) -> None:
"""Test game loop ignores non-left mouse clicks."""
mock_pg = MagicMock()
click_event = MagicMock()
click_event.type = "MOUSEDOWN"
click_event.button = 3 # Right click
click_event.pos = (100, 100)
quit_event = MagicMock()
quit_event.type = "QUIT"
mock_pg.QUIT = "QUIT"
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
mock_pg.KEYDOWN = "KEYDOWN"
mock_pg.event.get.side_effect = [[click_event], [quit_event]]
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("sys.exit"),
):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.clock = MagicMock()
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
game._draw_keyboard = MagicMock()
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
game._handle_click = MagicMock()
game.run()
# handle_click should NOT be called for right click
game._handle_click.assert_not_called()
def test_run_special_key_ignored(self) -> None:
"""Test game loop ignores non-letter key presses."""
mock_pg = MagicMock()
key_event = MagicMock()
key_event.type = "KEYDOWN"
key_event.key = "some_key"
quit_event = MagicMock()
quit_event.type = "QUIT"
mock_pg.QUIT = "QUIT"
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
mock_pg.KEYDOWN = "KEYDOWN"
mock_pg.K_RETURN = "K_RETURN"
mock_pg.K_r = "K_r"
mock_pg.key.name.return_value = "escape" # Multi-char, not a letter
mock_pg.event.get.side_effect = [[key_event], [quit_event]]
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("sys.exit"),
):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.clock = MagicMock()
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
game._draw_keyboard = MagicMock()
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
game._handle_letter_click = MagicMock()
game.run()
# handle_letter_click should NOT be called for special keys
game._handle_letter_click.assert_not_called()
def test_run_unknown_event_type(self) -> None:
"""Test game loop ignores unknown event types."""
mock_pg = MagicMock()
unknown_event = MagicMock()
unknown_event.type = "UNKNOWN"
quit_event = MagicMock()
quit_event.type = "QUIT"
mock_pg.QUIT = "QUIT"
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
mock_pg.KEYDOWN = "KEYDOWN"
mock_pg.event.get.side_effect = [[unknown_event], [quit_event]]
with (
patch.dict("sys.modules", {"pygame": mock_pg}),
patch("sys.exit"),
):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.clock = MagicMock()
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
game._draw_keyboard = MagicMock()
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
game._handle_click = MagicMock()
game._submit_word = MagicMock()
game._reset_game = MagicMock()
game._handle_letter_click = MagicMock()
game.run()
# None of the handlers should be called for unknown event
game._handle_click.assert_not_called()
game._submit_word.assert_not_called()
game._reset_game.assert_not_called()
game._handle_letter_click.assert_not_called()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,311 @@
"""Tests for keyboard_coop UI drawing and click handling."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
class TestHandleClick:
"""Tests for click handling."""
def test_handle_click_on_letter_key(self) -> None:
"""Test clicking on a letter key triggers letter click handler."""
mock_pg = MagicMock()
mock_pg.Rect = MagicMock()
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.available_letters = {"a"}
game.keyboard.adjacency = {"a": []}
# Mock _get_key_at_position to return "a"
game._get_key_at_position = MagicMock(return_value="a")
game._handle_letter_click = MagicMock()
game._submit_word = MagicMock()
game._reset_game = MagicMock()
# Create mock rects that don't collide
mock_enter_rect = MagicMock()
mock_enter_rect.collidepoint.return_value = False
mock_reset_rect = MagicMock()
mock_reset_rect.collidepoint.return_value = False
# Patch pygame.Rect to return our mocks
mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect]
game._handle_click((100, 100))
game._handle_letter_click.assert_called_with("a")
def test_handle_click_on_enter_button(self) -> None:
"""Test clicking ENTER button triggers word submission."""
mock_pg = MagicMock()
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
# Mock methods
game._get_key_at_position = MagicMock(return_value=None)
game._submit_word = MagicMock()
game._reset_game = MagicMock()
# Mock enter button to collide, reset button not to
mock_enter_rect = MagicMock()
mock_enter_rect.collidepoint.return_value = True
mock_reset_rect = MagicMock()
mock_reset_rect.collidepoint.return_value = False
mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect]
game._handle_click((750, 200))
game._submit_word.assert_called()
game._reset_game.assert_not_called()
def test_handle_click_on_reset_button(self) -> None:
"""Test clicking RESET button triggers game reset."""
mock_pg = MagicMock()
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.state = GameState()
game.keyboard = KeyboardState()
game.keyboard.positions = {}
# Mock methods
game._get_key_at_position = MagicMock(return_value=None)
game._submit_word = MagicMock()
game._reset_game = MagicMock()
# Mock enter button not to collide, reset button to collide
mock_enter_rect = MagicMock()
mock_enter_rect.collidepoint.return_value = False
mock_reset_rect = MagicMock()
mock_reset_rect.collidepoint.return_value = True
mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect]
game._handle_click((900, 200))
game._reset_game.assert_called()
game._submit_word.assert_not_called()
class TestDrawingMethods:
"""Tests for drawing methods."""
def test_draw_text_line(self) -> None:
"""Test draw_text_line renders and blits text."""
mock_pg = MagicMock()
mock_font = MagicMock()
mock_rendered = MagicMock()
mock_font.render.return_value = mock_rendered
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
KeyboardCoopGame,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game._draw_text_line("Test text", (10, 20), mock_font)
mock_font.render.assert_called()
game.screen.blit.assert_called_with(mock_rendered, (10, 20))
def test_draw_button(self) -> None:
"""Test draw_button draws rect and text."""
mock_pg = MagicMock()
mock_pg.draw = MagicMock()
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
FontSet,
KeyboardCoopGame,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
mock_rect = MagicMock()
mock_rect.center = (50, 50)
game._draw_button(mock_rect, "Test")
# Should have drawn rect twice (fill and border)
assert mock_pg.draw.rect.call_count == 2
class TestDrawKeyboard:
"""Tests for keyboard drawing."""
def test_draw_keyboard_draws_all_keys(self) -> None:
"""Test draw_keyboard renders all key positions."""
mock_pg = MagicMock()
mock_pg.draw = MagicMock()
mock_pg.mouse.get_pos.return_value = (0, 0)
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.state = GameState()
game.keyboard = KeyboardState()
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
# Set up some positions
mock_rect_a = MagicMock()
mock_rect_a.collidepoint.return_value = False
mock_rect_a.center = (100, 100)
mock_rect_b = MagicMock()
mock_rect_b.collidepoint.return_value = False
mock_rect_b.center = (150, 100)
game.keyboard.positions = {"a": mock_rect_a, "b": mock_rect_b}
game.keyboard.available_letters = {"a", "b"}
game._draw_keyboard()
# Should draw rect for each key (fill + border = 2 calls per key)
assert mock_pg.draw.rect.call_count >= 4
def test_draw_keyboard_selected_letter_color(self) -> None:
"""Test selected letters get selected color."""
mock_pg = MagicMock()
mock_pg.draw = MagicMock()
mock_pg.mouse.get_pos.return_value = (0, 0)
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
KEY_SELECTED_COLOR,
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.state = GameState()
game.state.selected_letters = ["a"] # 'a' is selected
game.keyboard = KeyboardState()
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
mock_rect_a = MagicMock()
mock_rect_a.collidepoint.return_value = False
mock_rect_a.center = (100, 100)
game.keyboard.positions = {"a": mock_rect_a}
game.keyboard.available_letters = {"a"}
game._draw_keyboard()
# Check that KEY_SELECTED_COLOR was used
calls = mock_pg.draw.rect.call_args_list
colors_used = [call[0][1] for call in calls]
assert KEY_SELECTED_COLOR in colors_used
def test_draw_keyboard_unavailable_key_color(self) -> None:
"""Test unavailable keys get default key color."""
mock_pg = MagicMock()
mock_pg.draw = MagicMock()
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
KEY_COLOR,
FontSet,
GameState,
KeyboardCoopGame,
KeyboardState,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.state = GameState()
game.state.selected_letters = [] # Not selected
game.keyboard = KeyboardState()
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
mock_rect_a = MagicMock()
mock_rect_a.center = (100, 100)
game.keyboard.positions = {"a": mock_rect_a}
# Key is NOT available - should get KEY_COLOR
game.keyboard.available_letters = set()
game._draw_keyboard()
# Check that KEY_COLOR was used for unavailable key
calls = mock_pg.draw.rect.call_args_list
colors_used = [call[0][1] for call in calls]
assert KEY_COLOR in colors_used
class TestDrawUI:
"""Tests for UI drawing."""
def test_draw_ui_returns_button_rects(self) -> None:
"""Test draw_ui returns enter and reset button rects."""
mock_pg = MagicMock()
mock_pg.draw = MagicMock()
mock_rect_instance = MagicMock()
mock_pg.Rect.return_value = mock_rect_instance
with patch.dict("sys.modules", {"pygame": mock_pg}):
from python_pkg.keyboard_coop.main import (
FontSet,
GameState,
KeyboardCoopGame,
)
game = object.__new__(KeyboardCoopGame)
game.screen = MagicMock()
game.state = GameState()
game.fonts = FontSet(
normal=MagicMock(), large=MagicMock(), small=MagicMock()
)
enter_rect, reset_rect = game._draw_ui()
# Should return pygame.Rect instances
assert enter_rect is not None
assert reset_rect is not None

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,409 @@
"""Tests for lichess_bot main module: game events and analysis."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, PropertyMock, patch
import chess
import pytest
from python_pkg.lichess_bot.main import (
BotContext,
GameMeta,
GameState,
_collect_analysis_lines,
_finalize_game,
_insert_analysis_into_log,
_log_analysis_progress,
_process_analysis_output,
_process_game_event,
_run_analysis_subprocess,
_write_pgn_to_log,
)
if TYPE_CHECKING:
from pathlib import Path
# Type alias to make mypy happy with test event dicts
Event = dict[str, Any]
class TestProcessGameEvent:
"""Tests for _process_game_event."""
def test_process_game_event_unhandled_type(self) -> None:
"""Test processing unhandled event type."""
ctx = MagicMock()
state = GameState()
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {"type": "chatLine", "text": "hello"}
result = _process_game_event(event, ctx, state, meta)
assert result is True
def test_process_game_event_game_full(self) -> None:
"""Test processing gameFull event."""
api = MagicMock()
api.get_my_user_id.return_value = "mybot"
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (
chess.Move.from_uci("e2e4"),
"opening",
)
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState()
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {
"type": "gameFull",
"state": {"moves": "", "status": "started"},
"white": {"id": "mybot"},
"black": {"id": "opp"},
}
result = _process_game_event(event, ctx, state, meta)
assert result is True
def test_process_game_event_game_end(self) -> None:
"""Test processing game end event."""
api = MagicMock()
api.get_my_user_id.return_value = "mybot"
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (None, "no moves")
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState(color="white", last_handled_len=-1)
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {
"type": "gameState",
"moves": "e2e4 e7e5",
"status": "mate",
}
result = _process_game_event(event, ctx, state, meta)
assert result is False
def test_process_game_event_game_end_after_move(self) -> None:
"""Test game ends with status after handling move.
This covers the case where _handle_move_if_needed returns True
but status indicates game end.
"""
api = MagicMock()
api.get_my_user_id.return_value = "mybot"
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (
chess.Move.from_uci("d2d4"),
"response",
)
ctx = BotContext(api=api, engine=engine, bot_version=1)
# Black's turn - it's opponent's move, so we don't need to move
state = GameState(color="black", last_handled_len=-1)
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {
"type": "gameState",
"moves": "e2e4", # One move - now it's black's turn
"status": "resign", # Game ended with resign
}
result = _process_game_event(event, ctx, state, meta)
assert result is False # Game should end
def test_process_game_event_unchanged_position(self) -> None:
"""Test processing event with unchanged position."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
state = GameState(last_handled_len=2, color="white")
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {"type": "gameState", "moves": "e2e4 e7e5"}
result = _process_game_event(event, ctx, state, meta)
assert result is True
def test_process_game_event_color_unknown(self) -> None:
"""Test processing event with unknown color."""
api = MagicMock()
api.get_my_user_id.return_value = "mybot"
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
state = GameState(last_handled_len=-1)
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {"type": "gameState", "moves": "e2e4"}
result = _process_game_event(event, ctx, state, meta)
assert result is True
assert state.last_handled_len == 1
def test_process_game_event_color_unknown_on_gamefull(self) -> None:
"""Test processing gameFull event with still unknown color.
This covers the branch where event_type is gameFull but color
is not determined (e.g., spectator watching game).
"""
api = MagicMock()
# Return a user id that doesn't match either player
api.get_my_user_id.return_value = "spectator"
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
state = GameState(last_handled_len=-1)
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {
"type": "gameFull",
"state": {"moves": "e2e4", "status": "started"},
"white": {"id": "player1"},
"black": {"id": "player2"},
}
result = _process_game_event(event, ctx, state, meta)
assert result is True
# last_handled_len should NOT be updated for gameFull with unknown color
assert state.last_handled_len == -1
class TestWritePgnToLog:
"""Tests for _write_pgn_to_log."""
def test_write_pgn_to_log(self, tmp_path: Path) -> None:
"""Test writing PGN to log file."""
log_path = tmp_path / "game.log"
log_path.write_text("header\n")
board = chess.Board()
board.push_uci("e2e4")
meta = GameMeta(
game_id="game1",
bot_version=1,
site_url="https://lichess.org/game1",
date_iso="2021.01.01",
white_name="White",
black_name="Black",
)
_write_pgn_to_log(log_path, board, meta)
content = log_path.read_text()
assert "PGN:" in content
assert "e4" in content
class TestRunAnalysisSubprocess:
"""Tests for _run_analysis_subprocess."""
def test_run_analysis_subprocess_script_not_found(self, tmp_path: Path) -> None:
"""Test analysis when script not found."""
log_path = tmp_path / "game.log"
with patch("python_pkg.lichess_bot.main.Path") as mock_path:
mock_script = MagicMock()
mock_script.is_file.return_value = False
resolve = mock_path.return_value.resolve.return_value
resolve.parent.parent.__truediv__.return_value.__truediv__.return_value = (
mock_script
)
result = _run_analysis_subprocess("game1", log_path, 10)
assert result is None
def test_run_analysis_subprocess_success(self, tmp_path: Path) -> None:
"""Test successful analysis subprocess."""
log_path = tmp_path / "game.log"
log_path.write_text("test")
mock_proc = MagicMock()
mock_proc.stdout = iter([" 1 e4\n", " 2 e5\n"])
mock_proc.stderr.read.return_value = ""
mock_proc.wait.return_value = 0
mock_proc.__enter__ = MagicMock(return_value=mock_proc)
mock_proc.__exit__ = MagicMock(return_value=False)
with (
patch("python_pkg.lichess_bot.main.Path") as mock_path,
patch("subprocess.Popen", return_value=mock_proc),
):
mock_script = MagicMock()
mock_script.is_file.return_value = True
resolve = mock_path.return_value.resolve.return_value
resolve.parent.parent.__truediv__.return_value.__truediv__.return_value = (
mock_script
)
result = _run_analysis_subprocess("game1", log_path, 2)
assert result is not None
class TestProcessAnalysisOutput:
"""Tests for _process_analysis_output."""
def test_process_analysis_output_success(self) -> None:
"""Test processing analysis output successfully."""
mock_proc = MagicMock()
mock_proc.stdout = iter([" 1 e4\n", " 2 e5\n"])
mock_proc.stderr.read.return_value = ""
mock_proc.wait.return_value = 0
result = _process_analysis_output(mock_proc, "game1", 2)
assert result is not None
assert "e4" in result
def test_process_analysis_output_error_exit(self) -> None:
"""Test processing analysis output with error exit."""
mock_proc = MagicMock()
mock_proc.stdout = iter(["output\n"])
mock_proc.stderr.read.return_value = "error message"
mock_proc.wait.return_value = 1
result = _process_analysis_output(mock_proc, "game1", 1)
assert result is not None
assert "stderr" in result
def test_process_analysis_output_error_exit_no_stderr(self) -> None:
"""Test processing analysis output with error exit but no stderr."""
mock_proc = MagicMock()
mock_proc.stdout = iter(["output\n"])
mock_proc.stderr.read.return_value = ""
mock_proc.wait.return_value = 1
result = _process_analysis_output(mock_proc, "game1", 1)
assert result is not None
assert "stderr" not in result
def test_process_analysis_output_none_pipes(self) -> None:
"""Test processing analysis output with None pipes."""
mock_proc = MagicMock()
mock_proc.stdout = None
mock_proc.stderr = None
with pytest.raises(RuntimeError, match="pipes unexpectedly None"):
_process_analysis_output(mock_proc, "game1", 1)
class TestCollectAnalysisLines:
"""Tests for _collect_analysis_lines helper."""
def test_collect_analysis_lines_empty_iterator(self) -> None:
"""Test collecting lines from empty iterator."""
empty_iter: list[str] = []
analyzed, lines = _collect_analysis_lines(iter(empty_iter), "game1", 10)
assert analyzed == 0
assert lines == []
def test_collect_analysis_lines_with_content(self) -> None:
"""Test collecting lines from iterator with content."""
content = [" 1 e4\n", " 2 e5\n", "not a ply line\n"]
analyzed, lines = _collect_analysis_lines(iter(content), "game1", 3)
assert analyzed == 2
assert lines == content
def test_collect_analysis_lines_full_iteration(self) -> None:
"""Test that all lines are collected."""
content = ["line1\n", " 3 Nf3\n", "line3\n"]
analyzed, lines = _collect_analysis_lines(iter(content), "game1", 1)
assert analyzed == 1
assert len(lines) == 3
class TestLogAnalysisProgress:
"""Tests for _log_analysis_progress."""
def test_log_analysis_progress_with_total(self) -> None:
"""Test logging progress with known total."""
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
_log_analysis_progress("game1", 5, 10)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args[0]
assert "50%" in call_args[0] % call_args[1:]
def test_log_analysis_progress_zero_total(self) -> None:
"""Test logging progress with zero total."""
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
_log_analysis_progress("game1", 5, 0)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args[0]
assert "unknown" in call_args[0]
class TestInsertAnalysisIntoLog:
"""Tests for _insert_analysis_into_log."""
def test_insert_analysis_before_pgn(self, tmp_path: Path) -> None:
"""Test inserting analysis before PGN section."""
log_path = tmp_path / "game.log"
log_path.write_text("header\n\nPGN:\n1. e4\n")
meta = GameMeta(
game_id="game1",
bot_version=1,
date_iso="2021.01.01",
white_name="White",
black_name="Black",
)
_insert_analysis_into_log(log_path, "Analysis here", meta)
content = log_path.read_text()
assert "ANALYSIS:" in content
assert content.index("ANALYSIS:") < content.index("PGN:")
def test_insert_analysis_at_start(self, tmp_path: Path) -> None:
"""Test inserting analysis when PGN at start."""
log_path = tmp_path / "game.log"
log_path.write_text("PGN:\n1. e4\n")
meta = GameMeta(game_id="game1", bot_version=1)
_insert_analysis_into_log(log_path, "Analysis here", meta)
content = log_path.read_text()
assert "ANALYSIS:" in content
def test_insert_analysis_no_pgn(self, tmp_path: Path) -> None:
"""Test inserting analysis when no PGN section."""
log_path = tmp_path / "game.log"
log_path.write_text("header\n")
meta = GameMeta(game_id="game1", bot_version=1)
_insert_analysis_into_log(log_path, "Analysis here", meta)
content = log_path.read_text()
assert "ANALYSIS:" in content
def test_insert_analysis_oserror(self, tmp_path: Path) -> None:
"""Test inserting analysis with OSError."""
log_path = tmp_path / "nonexistent" / "game.log"
meta = GameMeta(game_id="game1", bot_version=1)
# Should not raise, just log debug
_insert_analysis_into_log(log_path, "Analysis", meta)
class TestFinalizeGame:
"""Tests for _finalize_game."""
def test_finalize_game_no_log_path(self) -> None:
"""Test finalize game with no log path."""
state = GameState(log_path=None)
meta = GameMeta(game_id="game1", bot_version=1)
_finalize_game(state, meta) # Should not raise
def test_finalize_game_write_error(self, tmp_path: Path) -> None:
"""Test finalize game with write error."""
log_path = tmp_path / "game.log"
log_path.write_text("header")
state = GameState(log_path=log_path)
meta = GameMeta(game_id="game1", bot_version=1)
with patch(
"python_pkg.lichess_bot.main._write_pgn_to_log",
side_effect=OSError("error"),
):
_finalize_game(state, meta) # Should not raise
def test_finalize_game_type_error_on_move_stack(self, tmp_path: Path) -> None:
"""Test finalize game with TypeError on move_stack."""
log_path = tmp_path / "game.log"
log_path.write_text("header\n")
state = GameState(log_path=log_path)
meta = GameMeta(game_id="game1", bot_version=1)
mock_board = MagicMock()
# Use PropertyMock to raise TypeError when move_stack is accessed
type(mock_board).move_stack = PropertyMock(side_effect=TypeError())
state.board = mock_board
with patch("python_pkg.lichess_bot.main._write_pgn_to_log"):
_finalize_game(state, meta) # Should not raise
def test_finalize_game_analysis_error(self, tmp_path: Path) -> None:
"""Test finalize game with analysis error."""
log_path = tmp_path / "game.log"
log_path.write_text("header\n")
state = GameState(log_path=log_path)
meta = GameMeta(game_id="game1", bot_version=1)
with (
patch("python_pkg.lichess_bot.main._write_pgn_to_log"),
patch(
"python_pkg.lichess_bot.main._run_analysis_subprocess",
side_effect=OSError("error"),
),
):
_finalize_game(state, meta) # Should not raise

View File

@ -0,0 +1,474 @@
"""Tests for lichess_bot main module: bot event loop."""
from __future__ import annotations
import os
import threading
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, patch
import chess
import pytest
import requests
from python_pkg.lichess_bot.main import (
BotContext,
GameMeta,
GameState,
_handle_challenge,
_handle_game,
_process_bot_event,
_process_game_events_loop,
_run_event_loop,
_run_event_loop_iteration,
_safe_event_loop_iteration,
_stream_bot_events,
main,
run_bot,
)
if TYPE_CHECKING:
from pathlib import Path
# Type aliases to make mypy happy with test event dicts
Event = dict[str, Any]
GameThreads = dict[str, threading.Thread]
class TestHandleGame:
"""Tests for _handle_game."""
def test_handle_game_success(self, tmp_path: Path) -> None:
"""Test handling a game successfully."""
api = MagicMock()
api.get_my_user_id.return_value = "mybot"
api.stream_game_events.return_value = iter(
[
{
"type": "gameFull",
"state": {"moves": "", "status": "started"},
"white": {"id": "mybot"},
"black": {"id": "opp"},
},
{"type": "gameState", "moves": "e2e4", "status": "mate"},
]
)
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (None, "no moves")
ctx = BotContext(api=api, engine=engine, bot_version=1)
with (
patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path),
patch(
"python_pkg.lichess_bot.main._run_analysis_subprocess",
return_value=None,
),
):
_handle_game("game1", ctx, None)
def test_handle_game_request_error(self, tmp_path: Path) -> None:
"""Test handling a game with request error."""
api = MagicMock()
api.stream_game_events.side_effect = requests.RequestException("error")
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
with (
patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path),
patch(
"python_pkg.lichess_bot.main._run_analysis_subprocess",
return_value=None,
),
):
_handle_game("game1", ctx, None) # Should not raise
def test_handle_game_skips_chat_events(self, tmp_path: Path) -> None:
"""Test handling a game skips chat events."""
api = MagicMock()
api.stream_game_events.return_value = iter(
[
{"type": "chatLine", "text": "hello"},
{"type": "opponentGone", "gone": True},
]
)
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
with (
patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path),
patch(
"python_pkg.lichess_bot.main._run_analysis_subprocess",
return_value=None,
),
):
_handle_game("game1", ctx, None)
class TestProcessGameEventsLoop:
"""Tests for _process_game_events_loop."""
def test_empty_events_iterator(self) -> None:
"""Test processing empty events iterator."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
state = GameState(color="white")
meta = GameMeta(game_id="game1", bot_version=1)
empty_iter: list[Event] = []
# Should complete without error when iterator is empty
_process_game_events_loop(iter(empty_iter), ctx, state, meta)
def test_processes_all_events(self) -> None:
"""Test that all events are processed until break condition."""
api = MagicMock()
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (None, "no moves")
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState(color="white")
meta = GameMeta(game_id="game1", bot_version=1)
events: list[Event] = [
{"type": "chatLine", "text": "hello"}, # skipped
{"type": "gameState", "moves": "e2e4", "status": "resign"}, # game end
]
_process_game_events_loop(iter(events), ctx, state, meta)
def test_processes_multiple_game_events(self) -> None:
"""Test processing multiple game events that continue the game."""
api = MagicMock()
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (
chess.Move.from_uci("e2e4"),
"e4",
)
api.make_move.return_value = None
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState(color="white")
state.board = chess.Board()
meta = GameMeta(game_id="game1", bot_version=1)
events: list[Event] = [
# First event - game state, game continues
{"type": "gameState", "moves": "", "status": "started"},
# Second event - opponent moves, game continues
{"type": "gameState", "moves": "e2e4 e7e5", "status": "started"},
# Third event - game ends
{"type": "gameState", "moves": "e2e4 e7e5", "status": "mate"},
]
_process_game_events_loop(iter(events), ctx, state, meta)
class TestRunEventLoop:
"""Tests for _run_event_loop."""
def test_run_event_loop_zero_iterations(self) -> None:
"""Test running event loop with zero iterations."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
# Should complete immediately with 0 iterations
_run_event_loop(ctx, game_threads, 0, 0)
def test_run_event_loop_limited_iterations(self) -> None:
"""Test running event loop with limited iterations."""
api = MagicMock()
api.stream_bot_events.return_value = iter([])
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
with patch(
"python_pkg.lichess_bot.main._safe_event_loop_iteration", return_value=0
) as mock_iter:
_run_event_loop(ctx, game_threads, 0, 3)
assert mock_iter.call_count == 3
def test_run_event_loop_none_iterations_needs_interrupt(self) -> None:
"""Test that None iterations runs until interrupted."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
call_count = 0
def stop_after_calls(*_args: object, **_kwargs: object) -> int:
nonlocal call_count
call_count += 1
if call_count >= 5:
raise KeyboardInterrupt
return 0
with (
patch(
"python_pkg.lichess_bot.main._safe_event_loop_iteration",
side_effect=stop_after_calls,
),
pytest.raises(KeyboardInterrupt),
):
_run_event_loop(ctx, game_threads, 0, None)
assert call_count == 5
class TestHandleChallenge:
"""Tests for _handle_challenge."""
def test_accept_standard_blitz(self) -> None:
"""Test accepting standard blitz challenge."""
api = MagicMock()
challenge: Event = {
"id": "ch1",
"variant": {"key": "standard"},
"speed": "blitz",
}
_handle_challenge(challenge, api, decline_correspondence=False)
api.accept_challenge.assert_called_once_with("ch1")
def test_decline_variant(self) -> None:
"""Test declining non-standard variant."""
api = MagicMock()
challenge: Event = {
"id": "ch1",
"variant": {"key": "chess960"},
"speed": "blitz",
}
_handle_challenge(challenge, api, decline_correspondence=False)
api.decline_challenge.assert_called_once()
def test_decline_correspondence(self) -> None:
"""Test declining correspondence when flag set."""
api = MagicMock()
challenge: Event = {
"id": "ch1",
"variant": {"key": "standard"},
"speed": "correspondence",
}
_handle_challenge(challenge, api, decline_correspondence=True)
api.decline_challenge.assert_called_once()
def test_accept_correspondence_when_allowed(self) -> None:
"""Test accepting correspondence when flag not set."""
api = MagicMock()
challenge: Event = {
"id": "ch1",
"variant": {"key": "standard"},
"speed": "correspondence",
}
_handle_challenge(challenge, api, decline_correspondence=False)
api.decline_challenge.assert_called_once() # Still declined due to perf_ok
def test_invalid_variant_data(self) -> None:
"""Test handling invalid variant data."""
api = MagicMock()
challenge: Event = {
"id": "ch1",
"variant": "invalid",
"speed": "blitz",
}
_handle_challenge(challenge, api, decline_correspondence=False)
api.accept_challenge.assert_called_once()
class TestProcessBotEvent:
"""Tests for _process_bot_event."""
def test_process_challenge_event(self) -> None:
"""Test processing challenge event."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
event: Event = {
"type": "challenge",
"challenge": {
"id": "ch1",
"variant": {"key": "standard"},
"speed": "blitz",
},
}
_process_bot_event(event, ctx, game_threads)
api.accept_challenge.assert_called_once()
def test_process_game_start_event(self) -> None:
"""Test processing gameStart event."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
event: Event = {"type": "gameStart", "game": {"id": "game1"}}
with patch("python_pkg.lichess_bot.main.threading.Thread") as mock_thread_class:
mock_thread = MagicMock()
mock_thread_class.return_value = mock_thread
_process_bot_event(event, ctx, game_threads)
assert "game1" in game_threads
mock_thread.start.assert_called_once()
def test_process_game_start_existing_thread(self) -> None:
"""Test processing gameStart with existing alive thread."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
mock_thread = MagicMock(spec=threading.Thread)
mock_thread.is_alive.return_value = True
game_threads: GameThreads = {"game1": mock_thread}
event: Event = {"type": "gameStart", "game": {"id": "game1"}}
_process_bot_event(event, ctx, game_threads)
# Should not create new thread
assert game_threads["game1"] is mock_thread
def test_process_game_finish_event(self) -> None:
"""Test processing gameFinish event."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
event: Event = {"type": "gameFinish", "game": {"id": "game1"}}
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
_process_bot_event(event, ctx, game_threads)
mock_logger.info.assert_called()
def test_process_game_finish_invalid_data(self) -> None:
"""Test processing gameFinish event with non-dict game data."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
event: Event = {"type": "gameFinish", "game": "not_a_dict"}
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
_process_bot_event(event, ctx, game_threads)
# Should not log info since game data is invalid
mock_logger.info.assert_not_called()
def test_process_unknown_event(self) -> None:
"""Test processing unknown event."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
event: Event = {"type": "unknown", "data": "test"}
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
_process_bot_event(event, ctx, game_threads)
mock_logger.debug.assert_called()
def test_process_challenge_invalid_data(self) -> None:
"""Test processing challenge with invalid data."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
event: Event = {"type": "challenge", "challenge": "invalid"}
_process_bot_event(event, ctx, game_threads)
api.accept_challenge.assert_not_called()
def test_process_game_start_invalid_data(self) -> None:
"""Test processing gameStart with invalid data."""
api = MagicMock()
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
event: Event = {"type": "gameStart", "game": "invalid"}
_process_bot_event(event, ctx, game_threads)
assert len(game_threads) == 0
class TestStreamBotEvents:
"""Tests for _stream_bot_events."""
def test_stream_bot_events(self) -> None:
"""Test streaming bot events."""
api = MagicMock()
api.stream_events.return_value = iter([{"type": "test"}])
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
events = list(_stream_bot_events(ctx))
assert len(events) == 1
class TestRunEventLoopIteration:
"""Tests for _run_event_loop_iteration."""
def test_run_event_loop_iteration(self) -> None:
"""Test running event loop iteration."""
api = MagicMock()
api.stream_events.return_value = iter([{"type": "unknown"}])
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
result = _run_event_loop_iteration(ctx, game_threads)
assert result == 0
class TestSafeEventLoopIteration:
"""Tests for _safe_event_loop_iteration."""
def test_safe_event_loop_iteration_success(self) -> None:
"""Test safe event loop iteration success."""
api = MagicMock()
api.stream_events.return_value = iter([])
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
result = _safe_event_loop_iteration(ctx, game_threads, 0)
assert result == 0
def test_safe_event_loop_iteration_error(self) -> None:
"""Test safe event loop iteration with error."""
api = MagicMock()
api.stream_events.side_effect = requests.RequestException("error")
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
game_threads: GameThreads = {}
with patch("python_pkg.lichess_bot.main.backoff_sleep", return_value=5):
result = _safe_event_loop_iteration(ctx, game_threads, 2)
assert result == 5
class TestRunBot:
"""Tests for run_bot."""
def test_run_bot_no_token(self) -> None:
"""Test run_bot without token raises error."""
with (
patch.dict(os.environ, {}, clear=True),
pytest.raises(RuntimeError, match="LICHESS_TOKEN"),
):
run_bot()
def test_run_bot_with_token(self) -> None:
"""Test run_bot with token starts event loop."""
class _StopLoopError(Exception):
"""Custom exception to stop the loop."""
def stop_loop(*_args: object, **_kwargs: object) -> None:
raise _StopLoopError
with (
patch.dict(os.environ, {"LICHESS_TOKEN": "test_token"}),
patch(
"python_pkg.lichess_bot.main.get_and_increment_version",
return_value=1,
),
patch("python_pkg.lichess_bot.main.LichessAPI"),
patch("python_pkg.lichess_bot.main.RandomEngine"),
patch(
"python_pkg.lichess_bot.main._safe_event_loop_iteration",
side_effect=stop_loop,
),
pytest.raises(_StopLoopError),
):
run_bot("DEBUG", decline_correspondence=True)
class TestMain:
"""Tests for main function."""
def test_main_parses_args(self) -> None:
"""Test main parses command line arguments."""
class _StopExecutionError(Exception):
"""Custom exception to stop execution."""
with (
patch(
"sys.argv",
["main.py", "--log-level", "DEBUG", "--decline-correspondence"],
),
patch(
"python_pkg.lichess_bot.main.run_bot", side_effect=_StopExecutionError
) as mock_run_bot,
pytest.raises(_StopExecutionError),
):
main()
mock_run_bot.assert_called_once_with("DEBUG", decline_correspondence=True)

View File

@ -0,0 +1,403 @@
"""Tests for lichess_bot main module: game state helpers."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, patch
import chess
import requests
from python_pkg.lichess_bot.main import (
BotContext,
GameMeta,
GameState,
_apply_move_to_board,
_attempt_move,
_calculate_time_budget,
_extract_game_full_data,
_extract_game_state_data,
_extract_player_info,
_handle_move_if_needed,
_init_game_log,
_is_my_turn,
_log_move_to_file,
_rebuild_board_from_moves,
_update_clocks_from_state,
)
if TYPE_CHECKING:
from pathlib import Path
# Type alias to make mypy happy with test event dicts
Event = dict[str, Any]
class TestApplyMoveToBoard:
"""Tests for _apply_move_to_board."""
def test_apply_valid_move(self) -> None:
"""Test applying a valid move."""
board = chess.Board()
_apply_move_to_board(board, "e2e4", "game1")
assert board.fen() != chess.STARTING_FEN
def test_apply_invalid_move(self) -> None:
"""Test applying an invalid move logs debug."""
board = chess.Board()
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
_apply_move_to_board(board, "invalid", "game1")
mock_logger.debug.assert_called_once()
class TestInitGameLog:
"""Tests for _init_game_log."""
def test_init_game_log_success(self, tmp_path: Path) -> None:
"""Test successful log initialization."""
with patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path):
result = _init_game_log("game123", 42)
assert result is not None
assert result.exists()
content = result.read_text()
assert "game game123 started" in content
assert "bot_version v42" in content
def test_init_game_log_oserror(self) -> None:
"""Test log initialization with OSError."""
with patch("python_pkg.lichess_bot.main.Path.cwd") as mock_cwd:
mock_path = MagicMock()
mock_path.__truediv__ = MagicMock(return_value=mock_path)
mock_path.open.side_effect = OSError("Permission denied")
mock_cwd.return_value = mock_path
result = _init_game_log("game123", 42)
assert result is None
class TestUpdateClocksFromState:
"""Tests for _update_clocks_from_state."""
def test_update_clocks_white(self) -> None:
"""Test clock update when playing as white."""
state = GameState(color="white")
state_data: Event = {"wtime": 60000, "btime": 55000, "winc": 1000}
_update_clocks_from_state(state_data, state)
assert state.my_ms == 60000
assert state.opp_ms == 55000
assert state.inc_ms == 1000
def test_update_clocks_black(self) -> None:
"""Test clock update when playing as black."""
state = GameState(color="black")
state_data: Event = {"wtime": 60000, "btime": 55000, "binc": 2000}
_update_clocks_from_state(state_data, state)
assert state.my_ms == 55000
assert state.opp_ms == 60000
assert state.inc_ms == 2000
def test_update_clocks_float_values(self) -> None:
"""Test clock update with float values."""
state = GameState(color="white")
state_data: Event = {"wtime": 60000.5, "btime": 55000.5}
_update_clocks_from_state(state_data, state)
assert state.my_ms == 60000
assert state.opp_ms == 55000
def test_update_clocks_none_values(self) -> None:
"""Test clock update with None values."""
state = GameState(color="white")
state_data: Event = {"wtime": None, "btime": None}
_update_clocks_from_state(state_data, state)
assert state.my_ms is None
assert state.opp_ms is None
class TestExtractPlayerInfo:
"""Tests for _extract_player_info."""
def test_extract_player_info_white(self) -> None:
"""Test extracting player info when bot is white."""
api = MagicMock()
api.get_my_user_id.return_value = "mybot"
state = GameState()
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {
"white": {"id": "mybot", "name": "MyBot"},
"black": {"id": "opp", "name": "Opponent"},
}
_extract_player_info(event, state, meta, api)
assert state.color == "white"
assert meta.white_name == "MyBot"
assert meta.black_name == "Opponent"
def test_extract_player_info_black(self) -> None:
"""Test extracting player info when bot is black."""
api = MagicMock()
api.get_my_user_id.return_value = "mybot"
state = GameState()
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {
"white": {"id": "opp", "name": "Opponent"},
"black": {"id": "mybot", "name": "MyBot"},
}
_extract_player_info(event, state, meta, api)
assert state.color == "black"
def test_extract_player_info_invalid_data(self) -> None:
"""Test extracting player info with invalid data."""
api = MagicMock()
state = GameState()
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {"white": "invalid", "black": "invalid"}
_extract_player_info(event, state, meta, api)
assert state.color is None
def test_extract_player_info_missing_name(self) -> None:
"""Test extracting player info with missing name uses id."""
api = MagicMock()
api.get_my_user_id.return_value = "mybot"
state = GameState()
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {
"white": {"id": "mybot"},
"black": {"id": "opponent"},
}
_extract_player_info(event, state, meta, api)
assert meta.white_name == "mybot"
assert meta.black_name == "opponent"
class TestExtractGameFullData:
"""Tests for _extract_game_full_data."""
def test_extract_game_full_data(self) -> None:
"""Test extracting gameFull data."""
api = MagicMock()
api.get_my_user_id.return_value = "mybot"
state = GameState()
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {
"state": {"moves": "e2e4 e7e5", "status": "started", "wtime": 60000},
"white": {"id": "mybot"},
"black": {"id": "opp"},
"createdAt": 1609459200000, # 2021-01-01
}
moves, status = _extract_game_full_data(event, state, meta, api)
assert moves == "e2e4 e7e5"
assert status == "started"
assert meta.site_url == "https://lichess.org/game1"
assert meta.date_iso == "2021.01.01"
def test_extract_game_full_data_invalid_state(self) -> None:
"""Test extracting gameFull data with invalid state."""
api = MagicMock()
state = GameState()
meta = GameMeta(game_id="game1", bot_version=1)
event: Event = {"state": "invalid"}
moves, status = _extract_game_full_data(event, state, meta, api)
assert moves == ""
assert status is None
class TestExtractGameStateData:
"""Tests for _extract_game_state_data."""
def test_extract_game_state_as_white(self) -> None:
"""Test extracting gameState data as white."""
state = GameState(color="white", my_ms=60000)
event: Event = {
"moves": "e2e4",
"status": "started",
"wtime": 59000,
"btime": 60000,
}
moves, status = _extract_game_state_data(event, state)
assert moves == "e2e4"
assert status == "started"
assert state.my_ms == 59000
assert state.opp_ms == 60000
def test_extract_game_state_as_black(self) -> None:
"""Test extracting gameState data as black."""
state = GameState(color="black")
event: Event = {
"moves": "e2e4 e7e5",
"wtime": 60000,
"btime": 59000,
"binc": 1000,
}
moves, __status = _extract_game_state_data(event, state)
assert moves == "e2e4 e7e5"
assert state.my_ms == 59000
assert state.opp_ms == 60000
assert state.inc_ms == 1000
class TestCalculateTimeBudget:
"""Tests for _calculate_time_budget."""
def test_calculate_time_budget_normal(self) -> None:
"""Test time budget calculation."""
state = GameState(my_ms=60000, inc_ms=1000)
board = chess.Board()
budget = _calculate_time_budget(state, board, 10.0)
assert 0.05 <= budget <= 10.0
def test_calculate_time_budget_low_time(self) -> None:
"""Test time budget with low time."""
state = GameState(my_ms=1000, inc_ms=0)
board = chess.Board()
budget = _calculate_time_budget(state, board, 10.0)
assert budget >= 0.05
class TestLogMoveToFile:
"""Tests for _log_move_to_file."""
def test_log_move_to_file(self, tmp_path: Path) -> None:
"""Test logging a move to file."""
log_path = tmp_path / "game.log"
log_path.write_text("header\n")
move = chess.Move.from_uci("e2e4")
_log_move_to_file(log_path, 1, move, "best move")
content = log_path.read_text()
assert "ply 1: e2e4" in content
assert "best move" in content
def test_log_move_to_file_none_path(self) -> None:
"""Test logging with None path does nothing."""
move = chess.Move.from_uci("e2e4")
_log_move_to_file(None, 1, move, "reason") # Should not raise
class TestAttemptMove:
"""Tests for _attempt_move."""
def test_attempt_move_success(self) -> None:
"""Test successful move attempt."""
api = MagicMock()
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (
chess.Move.from_uci("e2e4"),
"opening",
)
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState(my_ms=60000)
meta = GameMeta(game_id="game1", bot_version=1)
board = chess.Board()
result = _attempt_move(ctx, state, meta, board)
assert result is True
api.make_move.assert_called_once()
def test_attempt_move_no_moves(self) -> None:
"""Test move attempt with no legal moves."""
api = MagicMock()
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (None, "no moves")
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState()
meta = GameMeta(game_id="game1", bot_version=1)
board = chess.Board()
result = _attempt_move(ctx, state, meta, board)
assert result is False
def test_attempt_move_illegal(self) -> None:
"""Test move attempt with illegal move."""
api = MagicMock()
engine = MagicMock()
engine.max_time_sec = 5.0
# Return a move that's not legal (e.g., random square move)
engine.choose_move_with_explanation.return_value = (
chess.Move.from_uci("a1a8"),
"bad",
)
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState(my_ms=60000)
meta = GameMeta(game_id="game1", bot_version=1)
board = chess.Board()
result = _attempt_move(ctx, state, meta, board)
assert result is True
api.make_move.assert_not_called()
def test_attempt_move_request_error(self) -> None:
"""Test move attempt with request error."""
api = MagicMock()
api.make_move.side_effect = requests.RequestException("Network error")
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (
chess.Move.from_uci("e2e4"),
"opening",
)
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState(my_ms=60000)
meta = GameMeta(game_id="game1", bot_version=1)
board = chess.Board()
result = _attempt_move(ctx, state, meta, board)
assert result is True # Still returns True
class TestIsMyTurn:
"""Tests for _is_my_turn."""
def test_is_my_turn_white_to_move(self) -> None:
"""Test checking turn when white to move."""
board = chess.Board() # White to move
assert _is_my_turn(board, "white") is True
assert _is_my_turn(board, "black") is False
def test_is_my_turn_black_to_move(self) -> None:
"""Test checking turn when black to move."""
board = chess.Board()
board.push_uci("e2e4") # Black to move
assert _is_my_turn(board, "white") is False
assert _is_my_turn(board, "black") is True
class TestRebuildBoardFromMoves:
"""Tests for _rebuild_board_from_moves."""
def test_rebuild_board_from_moves(self) -> None:
"""Test rebuilding board from moves list."""
moves_list = ["e2e4", "e7e5", "g1f3"]
board = _rebuild_board_from_moves(moves_list, "game1")
assert len(board.move_stack) == 3
class TestHandleMoveIfNeeded:
"""Tests for _handle_move_if_needed."""
def test_handle_move_game_state_my_turn(self) -> None:
"""Test handling move on gameState when my turn."""
api = MagicMock()
engine = MagicMock()
engine.max_time_sec = 5.0
engine.choose_move_with_explanation.return_value = (
chess.Move.from_uci("e2e4"),
"opening",
)
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState(color="white", my_ms=60000, board=chess.Board())
meta = GameMeta(game_id="game1", bot_version=1)
result = _handle_move_if_needed(ctx, state, meta, "gameState", 0)
assert result is True
def test_handle_move_game_full_with_moves(self) -> None:
"""Test handling move on gameFull with existing moves (opponent's turn)."""
api = MagicMock()
engine = MagicMock()
ctx = BotContext(api=api, engine=engine, bot_version=1)
state = GameState(color="white", my_ms=60000, board=chess.Board())
meta = GameMeta(game_id="game1", bot_version=1)
# gameFull with moves - don't move
result = _handle_move_if_needed(ctx, state, meta, "gameFull", 1)
assert result is True
engine.choose_move_with_explanation.assert_not_called()

View File

@ -0,0 +1,445 @@
"""Classical segmentation methods: concept, thresholding, region growing, watershed."""
from __future__ import annotations
from moviepy import (
CompositeVideoClip,
VideoClip,
)
from moviepy.video.fx import FadeIn, FadeOut
import numpy as np
from python_pkg.praca_magisterska_video._q23_helpers import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_tc,
)
# ── Segmentation concept ─────────────────────────────────────────
def _segmentation_concept() -> list[CompositeVideoClip]:
"""Show what segmentation is: pixel-level labeling."""
slides = []
# Synthetic image: grid of colored pixels
def make_image_frame(_t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
# Draw a small "image" grid
grid_x, grid_y = 100, 150
cell = 40
# Sky (top rows)
colors_map = [
[(135, 206, 235)] * 8, # sky
[(135, 206, 235)] * 5 + [(34, 139, 34)] * 3, # sky + tree
[(34, 139, 34)] * 3
+ [(128, 128, 128)] * 3
+ [(34, 139, 34)] * 2, # tree+road+tree
[(128, 128, 128)] * 3
+ [(200, 50, 50)] * 2
+ [(128, 128, 128)] * 3, # road+car+road
]
labels_map = [
["niebo"] * 8,
["niebo"] * 5 + ["drzewo"] * 3,
["drzewo"] * 3 + ["droga"] * 3 + ["drzewo"] * 2,
["droga"] * 3 + ["samochód"] * 2 + ["droga"] * 3,
]
label_colors = {
"niebo": (100, 180, 255),
"drzewo": (50, 200, 50),
"droga": (180, 180, 180),
"samochód": (255, 80, 80),
}
for r, row in enumerate(colors_map):
for c, col in enumerate(row):
y = grid_y + r * cell
x = grid_x + c * cell
frame[y : y + cell - 2, x : x + cell - 2] = col
# Draw segmentation map on the right
seg_x = 600
for r, row in enumerate(labels_map):
for c, lab in enumerate(row):
y = grid_y + r * cell
x = seg_x + c * cell
frame[y : y + cell - 2, x : x + cell - 2] = label_colors[lab]
return frame
image_clip = VideoClip(make_image_frame, duration=STEP_DUR).with_fps(FPS)
labels_text = [
("Obraz wejściowy", 22, "white", FONT_B, (170, 100)),
("Mapa segmentacji", 22, "white", FONT_B, (660, 100)),
("", 50, "#FFE082", FONT_B, (450, 250)),
("Każdy piksel → etykieta klasy", 20, "#B0BEC5", FONT_R, (100, 420)),
("niebo | drzewo | droga | samochód", 18, "#90CAF9", FONT_R, (600, 420)),
("Segmentacja = klasyfikacja per-piksel", 24, "#FFE082", FONT_B, (100, 500)),
(
"Semantic: klasy bez instancji | Instance: "
"rozróżnia obiekty | Panoptic: oba",
16,
"#78909C",
FONT_R,
(100, 560),
),
]
clips: list[VideoClip] = [image_clip]
for text, fs, color, font, pos in labels_text:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR)
.with_position(pos)
)
clips.append(tc)
slides.append(
CompositeVideoClip(clips, size=(W, H)).with_effects([FadeIn(0.3), FadeOut(0.3)])
)
return slides
# ── Thresholding / Otsu ───────────────────────────────────────────
def _thresholding_demo() -> list[CompositeVideoClip]:
"""Animate thresholding and Otsu concept."""
slides = []
# Show histogram & threshold
def make_threshold_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
# Draw bimodal histogram bars
bar_start_x = 80
bar_y = 500
bar_w = 4
for i in range(256):
# Bimodal: peaks at 60 and 190
h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2))
h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2))
bar_h = int(h1 + h2)
x = bar_start_x + i * bar_w
if x + bar_w < W:
frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (150, 150, 170)
# Animated threshold line
threshold = int(60 + (190 - 60) * min(t / (STEP_DUR * 0.7), 1.0))
tx = bar_start_x + threshold * bar_w
if tx < W:
frame[bar_y - 250 : bar_y + 10, tx : tx + 3] = (255, 80, 80)
# Color the two sides
for i in range(threshold):
x = bar_start_x + i * bar_w
h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2))
h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2))
bar_h = int(h1 + h2)
if x + bar_w < W and bar_h > 0:
frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (70, 130, 200)
for i in range(threshold, 256):
x = bar_start_x + i * bar_w
h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2))
h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2))
bar_h = int(h1 + h2)
if x + bar_w < W and bar_h > 0:
frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (200, 100, 80)
return frame
hist_clip = VideoClip(make_threshold_frame, duration=STEP_DUR).with_fps(FPS)
text_clips: list[VideoClip] = [hist_clip]
labels = [
("Progowanie (Thresholding) z metodą Otsu", 28, "#FFE082", FONT_B, (80, 30)),
(
"Histogram jasności pikseli — dwumodalny (bimodal)",
20,
"#B0BEC5",
FONT_R,
(80, 80),
),
("Garb 1: piksele obiektu (ciemne ~60)", 16, "#64B5F6", FONT_R, (80, 120)),
("Garb 2: piksele tła (jasne ~190)", 16, "#EF9A9A", FONT_R, (80, 150)),
(
"Próg T (czerwona linia) dzieli piksele na 2 klasy",
18,
"white",
FONT_R,
(80, 540),
),
(
"Otsu: automatycznie testuje T=0..255, minimalizuje σ² wewnątrzklasową",
16,
"#A5D6A7",
FONT_R,
(80, 580),
),
(
"Piksel ≤ T → klasa 0 (tło) | Piksel > T → klasa 1 (obiekt)",
16,
"#78909C",
FONT_R,
(80, 620),
),
]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides
# ── Region Growing ────────────────────────────────────────────────
def _region_growing_demo() -> list[CompositeVideoClip]:
"""Animate region growing BFS from a seed pixel."""
slides = []
grid_size = 10
cell_size = 40
rng = np.random.default_rng(42)
# Create a simple grid: dark region (30-80) and bright region (160-220)
grid = np.zeros((grid_size, grid_size), dtype=np.uint8)
grid[:] = 60 # dark background
grid[2:7, 3:8] = 180 # bright rectangle
# Add some noise
noise = rng.integers(-15, 15, (grid_size, grid_size))
grid = np.clip(grid.astype(int) + noise, 0, 255).astype(np.uint8)
# BFS steps from seed (4, 5)
seed = (4, 5)
threshold_val = 50
visited_order: list[tuple[int, int]] = []
queue = [seed]
visited_set = {seed}
while queue:
r, c = queue.pop(0)
visited_order.append((r, c))
for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
nr, nc = r + dr, c + dc
if (
0 <= nr < grid_size
and 0 <= nc < grid_size
and (nr, nc) not in visited_set
) and abs(int(grid[nr, nc]) - int(grid[seed])) < threshold_val:
visited_set.add((nr, nc))
queue.append((nr, nc))
def make_region_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
ox, oy = 100, 180
# How many cells to show as visited
progress = min(t / (STEP_DUR * 0.8), 1.0)
n_visited = int(progress * len(visited_order))
for r in range(grid_size):
for c in range(grid_size):
x = ox + c * cell_size
y = oy + r * cell_size
val = grid[r, c]
color = (val, val, val)
# Highlight visited
if (r, c) in visited_order[:n_visited]:
color = (80, 200, 120) # green for region
elif (r, c) == seed:
color = (255, 200, 50) # yellow seed
frame[y : y + cell_size - 2, x : x + cell_size - 2] = color
# Mark the seed with a bright border
sx = ox + seed[1] * cell_size
sy = ox + seed[0] * cell_size + 80
frame[sy : sy + cell_size, sx : sx + 2] = (255, 200, 50)
frame[sy : sy + cell_size, sx + cell_size - 2 : sx + cell_size] = (255, 200, 50)
frame[sy : sy + 2, sx : sx + cell_size] = (255, 200, 50)
frame[sy + cell_size - 2 : sy + cell_size, sx : sx + cell_size] = (255, 200, 50)
return frame
region_clip = VideoClip(make_region_frame, duration=STEP_DUR).with_fps(FPS)
text_clips: list[VideoClip] = [region_clip]
labels = [
("Region Growing — rozrastanie regionu", 28, "#FFE082", FONT_B, (100, 30)),
("Seed (ziarno) → BFS do podobnych sąsiadów", 20, "#B0BEC5", FONT_R, (100, 80)),
(
"Żółty = seed | Zielony = region | Szary = nieodwiedzone",
16,
"#78909C",
FONT_R,
(100, 120),
),
(
"Sąsiad PODOBNY (|jasność - jasność_regionu| < próg) → dodaj do regionu",
16,
"#A5D6A7",
FONT_R,
(100, 600),
),
(
"Algorytm zatrzymuje się gdy brak podobnych sąsiadów",
16,
"#90CAF9",
FONT_R,
(100, 640),
),
(
"Mnemonik: PLAMA atramentu — rozlewa się na podobne piksele",
18,
"#EF9A9A",
FONT_R,
(100, 670),
),
]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides
# ── Watershed ─────────────────────────────────────────────────────
def _watershed_demo() -> list[CompositeVideoClip]:
"""Animate watershed flooding concept."""
slides = []
def make_watershed_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
# Draw terrain profile (1D cross-section)
ox, oy = 100, 450
terrain_w = 900
terrain_points = 100
xs = np.linspace(0, 1, terrain_points)
# Two valleys with a ridge
terrain = (
120 * np.exp(-((xs - 0.25) ** 2) / 0.005)
+ 80 * np.exp(-((xs - 0.75) ** 2) / 0.008)
+ 30
)
terrain = 250 - terrain # invert for visual (valleys at bottom)
# Water level rises over time
water_level = int(160 + 80 * min(t / (STEP_DUR * 0.7), 1.0))
for i in range(terrain_points - 1):
x1 = ox + int(xs[i] * terrain_w)
x2 = ox + int(xs[i + 1] * terrain_w)
y1 = oy - int(terrain[i])
y2 = oy - int(terrain[i + 1])
# Fill terrain
for x in range(x1, min(x2 + 1, W)):
top = min(y1, y2) - 5
frame[top:oy, x : x + 1] = (100, 80, 60)
# Fill water
water_y = oy - water_level
for x in range(x1, min(x2 + 1, W)):
t_y = oy - int(terrain[i])
if water_y < t_y:
# Water fills below terrain surface
fill_top = max(water_y, 0)
fill_bot = min(t_y, oy)
if fill_top < fill_bot:
frame[fill_top:fill_bot, x : x + 1] = (70, 130, 220)
# Dam marker at ridge
ridge_x = ox + int(0.5 * terrain_w)
dam_visible_threshold = 160
if water_level > dam_visible_threshold:
frame[oy - water_level : oy - 140, ridge_x - 2 : ridge_x + 2] = (
255,
80,
80,
)
return frame
ws_clip = VideoClip(make_watershed_frame, duration=STEP_DUR).with_fps(FPS)
text_clips: list[VideoClip] = [ws_clip]
labels = [
("Watershed — metoda zlewiska", 28, "#FFE082", FONT_B, (100, 20)),
(
"Obraz = mapa topograficzna (jasność = wysokość)",
20,
"#B0BEC5",
FONT_R,
(100, 65),
),
(
"Brązowy = teren (ciemne=doliny, jasne=szczyty)",
16,
"#8D6E63",
FONT_R,
(100, 100),
),
("Niebieski = woda zalewająca od minimów", 16, "#64B5F6", FONT_R, (100, 130)),
(
"Czerwony = TAMA (granica segmentu) — gdy woda z 2 dolin się spotka",
16,
"#EF9A9A",
FONT_R,
(100, 160),
),
(
"Problem: over-segmentation "
"(za dużo regionów). "
"Rozwiązanie: marker-controlled.",
16,
"#A5D6A7",
FONT_R,
(100, 560),
),
(
"Mnemonik: ZALEWANIE terenu — granie gór = granice segmentów",
18,
"#FFE082",
FONT_R,
(100, 600),
),
]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides

View File

@ -0,0 +1,248 @@
"""DeepLab architecture animations for Q23 segmentation video."""
from __future__ import annotations
from moviepy import (
CompositeVideoClip,
VideoClip,
)
import numpy as np
from python_pkg.praca_magisterska_video._q23_helpers import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_compose_slide,
)
# ── DeepLab Architecture ─────────────────────────────────────────
def _make_dilated_frame(t: float) -> np.ndarray:
"""Render a dilated convolution comparison frame."""
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
cell = 36
grids = [
(
"rate=1",
60,
[
(0, 0),
(0, 1),
(0, 2),
(1, 0),
(1, 1),
(1, 2),
(2, 0),
(2, 1),
(2, 2),
],
),
(
"rate=2",
420,
[
(0, 0),
(0, 2),
(0, 4),
(2, 0),
(2, 2),
(2, 4),
(4, 0),
(4, 2),
(4, 4),
],
),
(
"rate=3",
820,
[
(0, 0),
(0, 3),
(0, 6),
(3, 0),
(3, 3),
(3, 6),
(6, 0),
(6, 3),
(6, 6),
],
),
]
for gi, (_label, gx, positions) in enumerate(grids):
if progress < gi * 0.3:
break
gy = 180
grid_size = 7
for r in range(grid_size):
for c in range(grid_size):
x = gx + c * cell
y = gy + r * cell
frame[y : y + cell - 2, x : x + cell - 2] = (35, 40, 55)
for r, c in positions:
x = gx + c * cell
y = gy + r * cell
frame[y : y + cell - 2, x : x + cell - 2] = (70, 130, 200)
frame[y : y + 2, x : x + cell - 2] = (120, 180, 255)
frame[y + cell - 4 : y + cell - 2, x : x + cell - 2] = (120, 180, 255)
return frame
def _make_aspp_frame(t: float) -> np.ndarray:
"""Render a single ASPP module animation frame."""
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
frame[250:330, 50:130] = (70, 130, 200)
frame[250:252, 50:130] = (120, 180, 255)
frame[328:330, 50:130] = (120, 180, 255)
branches = [
("1x1 conv", 250, (200, 170), (100, 40), (80, 200, 120)),
("rate=6", 310, (200, 250), (100, 40), (200, 160, 80)),
("rate=12", 370, (200, 330), (100, 40), (200, 120, 60)),
("rate=18", 430, (200, 410), (100, 40), (180, 100, 80)),
("GAP", 490, (200, 490), (100, 40), (160, 80, 160)),
]
n_branches = min(int(progress * 5) + 1, 5)
for i, (_lbl, _h, (bx, by), (bw, bh), color) in enumerate(branches):
if i < n_branches:
frame[by : by + bh, bx : bx + bw] = color
frame[by : by + 2, bx : bx + bw] = tuple(min(c + 50, 255) for c in color)
ay = by + bh // 2
frame[ay - 1 : ay + 2, 133:197] = (150, 150, 170)
concat_phase = 0.6
if progress > concat_phase:
frame[250:530, 380:420] = (50, 60, 80)
frame[250:252, 380:420] = (200, 200, 100)
frame[528:530, 380:420] = (200, 200, 100)
for i, (_lbl, _h, (bx, by), (bw, bh), _c) in enumerate(branches):
if i < n_branches:
ay = by + bh // 2
frame[ay - 1 : ay + 2, bx + bw + 3 : 378] = (150, 150, 170)
final_conv_phase = 0.8
if progress > final_conv_phase:
frame[350:420, 450:550] = (100, 200, 100)
frame[350:352, 450:550] = (150, 230, 150)
frame[418:420, 450:550] = (150, 230, 150)
frame[388:391, 423:448] = (150, 150, 170)
return frame
def _deeplab_demo() -> list[CompositeVideoClip]:
"""Animate DeepLab: dilated convolution + ASPP step by step."""
dur = STEP_DUR + 1
# Slide 1: Regular vs Dilated convolution
dil_clip = VideoClip(_make_dilated_frame, duration=dur).with_fps(FPS)
labels = [
("DeepLab: Atrous (Dilated) Convolution", 26, "#FFE082", FONT_B, (80, 20)),
(
"KROK 1: Zrozum dilated convolution — filtr z DZIURAMI",
18,
"#A5D6A7",
FONT_R,
(80, 60),
),
("rate=1 (zwykła)", 14, "#64B5F6", FONT_B, (60, 160)),
("RF = 3x3", 14, "#64B5F6", FONT_R, (60, 440)),
("9 wag, kontekst 3px", 12, "#78909C", FONT_R, (60, 470)),
("rate=2 (dilated)", 14, "#FFE082", FONT_B, (420, 160)),
("RF = 5x5", 14, "#FFE082", FONT_R, (420, 440)),
("9 wag, kontekst 5px!", 12, "#78909C", FONT_R, (420, 470)),
("rate=3 (dilated)", 14, "#A5D6A7", FONT_B, (820, 160)),
("RF = 7x7", 14, "#A5D6A7", FONT_R, (820, 440)),
("9 wag, kontekst 7px!", 12, "#78909C", FONT_R, (820, 470)),
(
"Niebieski = pozycja wag filtra 3x3 | Szary = pominięte (dziury)",
15,
"#B0BEC5",
FONT_R,
(80, 510),
),
(
"TE SAME 9 wag → WIĘKSZE pole widzenia "
"→ lepszy kontekst BEZ dodatkowych parametrów!",
16,
"white",
FONT_R,
(80, 550),
),
(
"Mnemonik: DZIURY w filtrze — à trous = z dziurami (fr.)",
16,
"#FFE082",
FONT_R,
(80, 600),
),
]
slides = [_compose_slide(dil_clip, labels, dur)]
# Slide 2: ASPP module step by step
aspp_clip = VideoClip(_make_aspp_frame, duration=dur).with_fps(FPS)
labels2 = [
(
"DeepLab: ASPP (Atrous Spatial Pyramid Pooling)",
24,
"#FFE082",
FONT_B,
(80, 20),
),
(
"KROK 2: Multi-scale — analizuj obraz na WIELU skalach naraz",
17,
"#A5D6A7",
FONT_R,
(80, 60),
),
("Wejście", 13, "#64B5F6", FONT_B, (55, 235)),
("Conv 1x1", 12, "white", FONT_R, (210, 178)),
("Dilated r=6", 12, "white", FONT_R, (205, 258)),
("Dilated r=12", 12, "white", FONT_R, (203, 338)),
("Dilated r=18", 12, "white", FONT_R, (203, 418)),
("GAP (global)", 12, "white", FONT_R, (205, 498)),
("Concat", 13, "#FFE082", FONT_B, (381, 537)),
("Conv", 13, "#A5D6A7", FONT_B, (470, 425)),
(
"5 gałęzi RÓWNOLEGŁYCH → różne skale kontekstu:",
16,
"#B0BEC5",
FONT_R,
(550, 170),
),
(" 1x1: kontekst punktowy (piksel)", 14, "#A5D6A7", FONT_R, (560, 210)),
(" r=6: kontekst lokalny (~13px)", 14, "#FFE082", FONT_R, (560, 245)),
(" r=12: kontekst średni (~25px)", 14, "#FFE082", FONT_R, (560, 280)),
(" r=18: kontekst szeroki (~37px)", 14, "#FFE082", FONT_R, (560, 315)),
(" GAP: kontekst GLOBALNY (cały obraz)", 14, "#CE93D8", FONT_R, (560, 350)),
("Concat → 1x1 conv → mapa segmentacji", 16, "#A5D6A7", FONT_R, (550, 400)),
(
"Efekt: sieć widzi OD piksela DO całego obrazu naraz!",
17,
"white",
FONT_R,
(80, 600),
),
(
"Mnemonik: ASPP = Piramida z DZIURAMI, patrzy na 5 skal jednocześnie",
15,
"#FFE082",
FONT_R,
(80, 645),
),
]
slides.append(_compose_slide(aspp_clip, labels2, dur))
return slides

View File

@ -0,0 +1,116 @@
"""Shared constants and helper functions for Q23 segmentation video."""
from __future__ import annotations
import logging
import os
from pathlib import Path
import numpy as np
os.environ["FFMPEG_BINARY"] = "/usr/bin/ffmpeg"
from moviepy import (
ColorClip,
CompositeVideoClip,
TextClip,
VideoClip,
)
from moviepy.video.fx import FadeIn, FadeOut
# ── Constants ─────────────────────────────────────────────────────
W, H = 1280, 720
FPS = 24
STEP_DUR = 7.0
HEADER_DUR = 4.0
FONT_B = "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf"
FONT_R = "/usr/share/fonts/TTF/DejaVuSans.ttf"
OUTPUT_DIR = Path(__file__).resolve().parent / "videos"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT = str(OUTPUT_DIR / "q23_segmentation.mp4")
logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)
BG_COLOR = (15, 20, 35)
rng = np.random.default_rng(42)
def _tc(**kwargs: object) -> TextClip:
"""TextClip wrapper that adds enough bottom margin to prevent clipping."""
fs = kwargs.get("font_size", 24)
m = int(fs) // 3 + 2
kwargs["margin"] = (0, m)
return TextClip(**kwargs)
def _make_header(
title: str, subtitle: str, duration: float = HEADER_DUR
) -> CompositeVideoClip:
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration)
t = (
_tc(
text=title,
font_size=48,
color="white",
font=FONT_B,
)
.with_duration(duration)
.with_position(("center", 260))
)
s = (
_tc(
text=subtitle,
font_size=24,
color="#90CAF9",
font=FONT_R,
)
.with_duration(duration)
.with_position(("center", 340))
)
return CompositeVideoClip([bg, t, s], size=(W, H)).with_effects(
[FadeIn(0.5), FadeOut(0.5)]
)
def _text_slide(
lines: list[tuple[str, int, str, str, tuple[str | int, str | int]]],
duration: float = STEP_DUR,
) -> CompositeVideoClip:
"""Create a slide with multiple text elements."""
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration)
clips: list[VideoClip] = [bg]
for text, font_size, color, font, pos in lines:
tc = (
_tc(
text=text,
font_size=font_size,
color=color,
font=font,
)
.with_duration(duration)
.with_position(pos)
)
clips.append(tc)
return CompositeVideoClip(clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
def _compose_slide(
base_clip: VideoClip,
labels: list[tuple[str, int, str, str, tuple[int, int]]],
duration: float,
) -> CompositeVideoClip:
"""Overlay text labels on an animated base clip."""
text_clips: list[VideoClip] = [base_clip]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(duration)
.with_position(pos)
)
text_clips.append(tc)
return CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)

View File

@ -0,0 +1,430 @@
"""Transformer segmentation and methods comparison for Q23 video."""
from __future__ import annotations
from moviepy import (
ColorClip,
CompositeVideoClip,
VideoClip,
)
from moviepy.video.fx import FadeIn, FadeOut
import numpy as np
from python_pkg.praca_magisterska_video._q23_helpers import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_compose_slide,
_tc,
_text_slide,
)
# ── Transformer Segmentation ────────────────────────────────────
def _draw_base_grid(
frame: np.ndarray,
gx: int,
gy: int,
grid_n: int,
cell: int,
) -> None:
"""Draw an empty grid of cells."""
for r in range(grid_n):
for c in range(grid_n):
x = gx + c * cell
y = gy + r * cell
frame[y : y + cell - 2, x : x + cell - 2] = (35, 40, 55)
def _draw_cnn_kernel(
frame: np.ndarray,
lx: int,
ly: int,
cell: int,
progress: float,
) -> None:
"""Highlight a 3x3 CNN kernel on the grid."""
cnn_phase = 0.2
if progress <= cnn_phase:
return
cx, cy = 2, 2
for dr in range(-1, 2):
for dc in range(-1, 2):
r, c = cy + dr, cx + dc
x = lx + c * cell
y = ly + r * cell
frame[y : y + cell - 2, x : x + cell - 2] = (70, 130, 200)
x = lx + cx * cell
y = ly + cy * cell
frame[y : y + cell - 2, x : x + cell - 2] = (120, 180, 255)
def _draw_conn_line(
frame: np.ndarray,
x0: int,
y0: int,
x1: int,
y1: int,
) -> None:
"""Draw a dashed connection line between two points."""
steps = max(abs(x1 - x0), abs(y1 - y0))
if steps <= 0:
return
for s in range(0, steps, 3):
px = x0 + int((x1 - x0) * s / steps)
py = y0 + int((y1 - y0) * s / steps)
if 0 <= px < W - 1 and 0 <= py < H - 1:
frame[py : py + 1, px : px + 1] = (200, 180, 50)
def _draw_attention_connections(
frame: np.ndarray,
origin: tuple[int, int],
grid_n: int,
cell: int,
progress: float,
) -> None:
"""Draw transformer self-attention connections on the grid."""
rx, ry = origin
transformer_phase = 0.4
if progress <= transformer_phase:
return
cx_t, cy_t = 2, 2
x0 = rx + cx_t * cell + cell // 2
y0 = ry + cy_t * cell + cell // 2
n_connections = int(progress * 36)
conn_idx = 0
for r in range(grid_n):
for c in range(grid_n):
conn_idx += 1
if conn_idx > n_connections:
break
x = rx + c * cell
y = ry + r * cell
dist = abs(r - cy_t) + abs(c - cx_t)
strength = max(30, 200 - dist * 30)
frame[y : y + cell - 2, x : x + cell - 2] = (
strength // 3,
strength // 2,
strength,
)
_draw_conn_line(frame, x0, y0, x + cell // 2, y + cell // 2)
else:
continue
break
x = rx + cx_t * cell
y = ry + cy_t * cell
frame[y : y + cell - 2, x : x + cell - 2] = (255, 200, 50)
def _make_attention_frame(t: float) -> np.ndarray:
"""Render a CNN-vs-Transformer attention comparison frame."""
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
cell = 40
grid_n = 6
lx, ly = 60, 200
_draw_base_grid(frame, lx, ly, grid_n, cell)
_draw_cnn_kernel(frame, lx, ly, cell, progress)
rx, ry = 680, 200
_draw_base_grid(frame, rx, ry, grid_n, cell)
_draw_attention_connections(frame, (rx, ry), grid_n, cell, progress)
return frame
def _transformer_seg_demo() -> list[CompositeVideoClip]:
"""Animate transformer-based segmentation: self-attention concept."""
dur = STEP_DUR + 1
# Slide 1: CNN local vs Transformer global
att_clip = VideoClip(_make_attention_frame, duration=dur).with_fps(FPS)
labels = [
("Transformer: Self-Attention w segmentacji", 26, "#FFE082", FONT_B, (80, 20)),
("CNN = LOKALNY kontekst", 18, "#64B5F6", FONT_B, (60, 160)),
("Transformer = GLOBALNY kontekst", 18, "#FFE082", FONT_B, (680, 160)),
("Filtr 3x3 widzi", 14, "#64B5F6", FONT_R, (60, 460)),
("TYLKO 9 sąsiadów", 14, "#64B5F6", FONT_R, (60, 485)),
("Self-attention: każdy", 14, "#FFE082", FONT_R, (680, 460)),
("piksel widzi WSZYSTKIE!", 14, "#FFE082", FONT_R, (680, 485)),
("vs", 28, "#B0BEC5", FONT_B, (450, 300)),
]
slides = [_compose_slide(att_clip, labels, dur)]
# Slide 2: Self-attention Q/K/V step by step
qkv_lines = [
("Self-Attention: Q / K / V krok po kroku", 26, "#FFE082", FONT_B, (80, 30)),
("Każdy piksel (token) tworzy 3 wektory:", 18, "#B0BEC5", FONT_R, (100, 100)),
(
" Q (Query) = 'czego szukam?' - pytanie piksela",
17,
"#64B5F6",
FONT_R,
(120, 145),
),
(
" K (Key) = 'co oferuj\u0119?' - odpowied\u017a piksela",
17,
"#A5D6A7",
FONT_R,
(120, 185),
),
(
" V (Value) = 'moja warto\u015b\u0107' - informacja do przekazania",
17,
"#FFE082",
FONT_R,
(120, 225),
),
("Algorytm attention:", 18, "#B0BEC5", FONT_R, (100, 285)),
(
" 1. Mnożenie Q x K\u1d40 → macierz NxN (kto ważny dla kogo)",
16,
"white",
FONT_R,
(120, 320),
),
(
" 2. Skalowanie: / \u221ad (stabilno\u015b\u0107 gradient\u00f3w)",
16,
"white",
FONT_R,
(120, 355),
),
(
" 3. Softmax \u2192 wagi attention (sumuj\u0105 si\u0119 do 1)",
16,
"white",
FONT_R,
(120, 390),
),
(
" 4. Mno\u017cenie wag x V \u2192 wa\u017cona suma warto\u015bci",
16,
"white",
FONT_R,
(120, 425),
),
(
"Attention(Q,K,V) = softmax(Q \u00b7 K\u1d40 / \u221ad) \u00b7 V",
20,
"#FFE082",
FONT_B,
(100, 480),
),
(
"Z\u0142o\u017cono\u015b\u0107: O(n\u00b2) pami\u0119ci \u2014 n = liczba pikseli/token\u00f3w",
16,
"#EF9A9A",
FONT_R,
(100, 535),
),
(
"Dlatego SegFormer u\u017cywa efficient attention (liniowa z\u0142o\u017cono\u015b\u0107)",
15,
"#78909C",
FONT_R,
(100, 570),
),
(
"SegFormer (2021): lightweight + hierarchiczny encoder",
16,
"#A5D6A7",
FONT_R,
(100, 610),
),
(
"Mask2Former (2022): masked attention + "
"unified (semantic+instance+panoptic)",
16,
"#CE93D8",
FONT_R,
(100, 645),
),
]
slides.append(_text_slide(qkv_lines, duration=STEP_DUR + 1))
# Slide 3: Encoder-Decoder in DL summary
summary_lines = [
(
"Podsumowanie: Encoder-Decoder w segmentacji DL",
24,
"#FFE082",
FONT_B,
(80, 30),
),
(
"Wsp\u00f3lna idea WSZYSTKICH sieci segmentacji:",
18,
"#B0BEC5",
FONT_R,
(80, 90),
),
(
"Encoder: obraz \u2192 cechy (zmniejsza rozdzielczo\u015b\u0107, wyci\u0105ga CO)",
16,
"#64B5F6",
FONT_R,
(100, 140),
),
(
"Decoder: cechy \u2192 mapa (zwi\u0119ksza rozdzielczo\u015b\u0107, odtwarza GDZIE)",
16,
"#A5D6A7",
FONT_R,
(100, 175),
),
(
"Skip: przenosi detale z encodera do decodera",
16,
"#FFE082",
FONT_R,
(100, 210),
),
("", 10, "white", FONT_R, (100, 240)),
(
"FCN (2015): Conv1x1 + skip \u2192 pierwsza end-to-end",
16,
"#64B5F6",
FONT_R,
(100, 275),
),
(
"U-Net (2015): U-shape + skip concat \u2192 segmentacja medyczna",
16,
"#A5D6A7",
FONT_R,
(100, 310),
),
(
"DeepLab (2018): dilated conv + ASPP \u2192 multi-scale kontekst",
16,
"#FFE082",
FONT_R,
(100, 345),
),
(
"SegFormer: transformer encoder (globalny kontekst)",
16,
"#CE93D8",
FONT_R,
(100, 380),
),
(
"Mask2Former: masked attention (unified, SOTA)",
16,
"#CE93D8",
FONT_R,
(100, 415),
),
("", 10, "white", FONT_R, (100, 440)),
(
"Ewolucja: wi\u0119cej kontekstu + lepsze skip connections:",
17,
"white",
FONT_R,
(80, 465),
),
(
" CNN lokal. \u2192 dilated (szersze RF) \u2192 transformer (global) \u2192 masked att.",
16,
"#B0BEC5",
FONT_R,
(80, 505),
),
(
" addition skip \u2192 concat skip \u2192 cross-attention skip",
16,
"#B0BEC5",
FONT_R,
(80, 540),
),
(
"Metryki: mIoU (standard), Dice (medycyna), Focal Loss (imbalance)",
16,
"#90CAF9",
FONT_R,
(80, 590),
),
(
"Loss: Cross-Entropy per piksel + opcjonalnie Dice/Focal",
15,
"#78909C",
FONT_R,
(80, 625),
),
]
slides.append(_text_slide(summary_lines, duration=STEP_DUR + 1))
return slides
# ── Methods comparison ────────────────────────────────────────────
def _methods_comparison() -> CompositeVideoClip:
"""Create a comparison table of all segmentation methods."""
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(10.0)
title = (
_tc(
text="Por\u00f3wnanie metod segmentacji",
font_size=36,
color="white",
font=FONT_B,
)
.with_duration(10.0)
.with_position(("center", 20))
)
rows = [
("Metoda", "Typ", "Idea", "Mnemonik"),
(
"Thresholding",
"Klasyczna",
"piksel > T \u2192 klasa 1",
"PR\u00d3G na bramce",
),
("Otsu", "Klasyczna", "auto-pr\u00f3g, min \u03c3\u00b2", "AUTO-bramkarz"),
("Region Growing", "Klasyczna", "BFS od seeda", "PLAMA atramentu"),
("Watershed", "Klasyczna", "zalewanie minim\u00f3w", "ZALEWANIE terenu"),
(
"Mean Shift",
"Klasyczna",
"j\u0105dro \u2192 max g\u0119sto\u015bci",
"KULKI do do\u0142k\u00f3w",
),
("U-Net", "Deep Learning", "encoder-decoder + skip", "Litera U + mosty"),
("DeepLab", "Deep Learning", "dilated conv + ASPP", "DZIURY w filtrze"),
]
clips: list[VideoClip] = [bg, title]
mnemonic_col = 3
for i, row in enumerate(rows):
y_pos = 75 + i * 72
col_x = [40, 210, 340, 660]
for j, cell in enumerate(row):
fs = 16 if i > 0 else 18
color = (
"#64B5F6" if i == 0 else ("#E0E0E0" if j < mnemonic_col else "#FFE082")
)
tc = (
_tc(
text=cell,
font_size=fs,
color=color,
font=FONT_B if i == 0 else FONT_R,
)
.with_duration(10.0)
.with_position((col_x[j], y_pos))
)
clips.append(tc)
return CompositeVideoClip(clips, size=(W, H)).with_effects(
[FadeIn(0.5), FadeOut(0.5)]
)

View File

@ -0,0 +1,399 @@
"""U-Net and FCN architecture animations for Q23 segmentation video."""
from __future__ import annotations
from moviepy import (
CompositeVideoClip,
VideoClip,
)
import numpy as np
from python_pkg.praca_magisterska_video._q23_helpers import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_compose_slide,
_text_slide,
)
# ── U-Net Architecture ───────────────────────────────────────────
def _draw_unet_skips(
frame: np.ndarray,
enc_positions: list[tuple[int, int, int, int]],
n_blocks: int,
dec_x: int,
skip_threshold: int,
) -> None:
"""Draw horizontal dashed skip-connection lines."""
if n_blocks <= skip_threshold:
return
for i in range(min(n_blocks - 5, 4)):
ey = enc_positions[i][1] + enc_positions[i][3] // 2
ex_end = enc_positions[i][0] + enc_positions[i][2]
for dash_x in range(ex_end + 10, dec_x - 10, 15):
frame[ey : ey + 2, dash_x : dash_x + 8] = (255, 200, 50)
def _make_unet_frame(t: float) -> np.ndarray:
"""Render a single U-Net animation frame."""
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
enc_sizes = [(80, 120), (60, 100), (45, 80), (30, 60)]
dec_sizes = list(reversed(enc_sizes))
enc_x = 150
dec_x = 850
progress = min(t / (STEP_DUR * 0.6), 1.0)
n_blocks = int(progress * 8) + 1
enc_positions: list[tuple[int, int, int, int]] = []
y_offset = 120
for i, (bw, bh) in enumerate(enc_sizes):
x = enc_x
y = y_offset + i * 130
enc_positions.append((x, y, bw, bh))
if i < n_blocks:
frame[y : y + bh, x : x + bw] = (70, 130, 200)
frame[y : y + 2, x : x + bw] = (100, 180, 255)
frame[y + bh - 2 : y + bh, x : x + bw] = (100, 180, 255)
frame[y : y + bh, x : x + 2] = (100, 180, 255)
frame[y : y + bh, x + bw - 2 : x + bw] = (100, 180, 255)
if i < len(enc_sizes) - 1:
ax = x + bw // 2
ay = y + bh + 10
frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170)
bx, by = 500, y_offset + 3 * 130 + 30
encoder_count = 4
if n_blocks > encoder_count:
frame[by : by + 50, bx : bx + 25] = (200, 100, 80)
frame[by : by + 2, bx : bx + 25] = (255, 140, 100)
frame[by + 48 : by + 50, bx : bx + 25] = (255, 140, 100)
for i, (bw, bh) in enumerate(dec_sizes):
x = dec_x
y = y_offset + (3 - i) * 130
if n_blocks > 4 + i + 1:
frame[y : y + bh, x : x + bw] = (80, 200, 120)
frame[y : y + 2, x : x + bw] = (120, 230, 150)
frame[y + bh - 2 : y + bh, x : x + bw] = (120, 230, 150)
frame[y : y + bh, x : x + 2] = (120, 230, 150)
frame[y : y + bh, x + bw - 2 : x + bw] = (120, 230, 150)
if i < len(dec_sizes) - 1:
ax = x + bw // 2
ay = y - 30
frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170)
skip_threshold = 5
_draw_unet_skips(frame, enc_positions, n_blocks, dec_x, skip_threshold)
return frame
def _unet_demo() -> list[CompositeVideoClip]:
"""Animate U-Net encoder-decoder architecture."""
dur = STEP_DUR + 1
unet_clip = VideoClip(_make_unet_frame, duration=dur).with_fps(FPS)
labels = [
("U-Net: Encoder-Decoder + Skip Connections", 28, "#FFE082", FONT_B, (80, 20)),
(
"Niebieski = Encoder (↓ zmniejsza rozdzielczość, wyciąga cechy)",
16,
"#64B5F6",
FONT_R,
(80, 65),
),
(
"Zielony = Decoder (↑ zwiększa rozdzielczość, odtwarza mapę)",
16,
"#A5D6A7",
FONT_R,
(80, 90),
),
(
"Żółte przerywane = Skip connections (przenoszą detale z encodera)",
16,
"#FFE082",
FONT_R,
(80, 115),
),
(
"Czerwony = Bottleneck (najgłębsza warstwa, max abstrakcja)",
16,
"#EF9A9A",
FONT_R,
(450, 570),
),
(
"Kształt U: encoder ↓ decoder ↑, mosty pośrodku",
18,
"white",
FONT_R,
(80, 640),
),
(
"Concatenation: skip łączy kanały (więcej informacji niż dodawanie)",
16,
"#78909C",
FONT_R,
(80, 670),
),
]
return [_compose_slide(unet_clip, labels, dur)]
# ── FCN Architecture ─────────────────────────────────────────────
def _draw_pipeline_blocks(
frame: np.ndarray,
blocks: list[tuple[tuple[int, int], tuple[int, int], tuple[int, int, int]]],
n_visible: int,
arrow_limit: int,
) -> None:
"""Draw coloured blocks with connecting arrows."""
for i, ((bx, by), (bw, bh), color) in enumerate(blocks):
if i < n_visible:
frame[by : by + bh, bx : bx + bw] = color
frame[by : by + 2, bx : bx + bw] = tuple(min(c + 50, 255) for c in color)
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
min(c + 50, 255) for c in color
)
if i < arrow_limit:
ax = bx + bw + 3
ay = by + bh // 2
frame[ay - 1 : ay + 2, ax : ax + 12] = (150, 150, 170)
def _draw_red_cross(
frame: np.ndarray,
x_start: int,
width: int,
top_y: int,
height: int,
) -> None:
"""Draw a red X across the given rectangle."""
for d in range(-2, 3):
for step in range(height):
x1 = x_start + int(step * width / height)
y1 = top_y + step + d
if 0 <= y1 < H and 0 <= x1 < W:
frame[y1, x1] = (255, 80, 80)
y2 = top_y + height - step + d
if 0 <= y2 < H and 0 <= x1 < W:
frame[y2, x1] = (255, 80, 80)
def _make_fcn_frame(t: float) -> np.ndarray:
"""Render a single FCN comparison frame."""
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.8), 1.0)
top_y = 140
blocks_classic = [
((80, top_y), (70, 50), (70, 130, 200)),
((170, top_y), (50, 40), (50, 100, 160)),
((240, top_y), (60, 50), (70, 130, 200)),
((320, top_y), (40, 35), (50, 100, 160)),
((385, top_y), (55, 50), (160, 80, 60)),
((465, top_y), (55, 50), (180, 60, 60)),
((545, top_y), (80, 50), (200, 80, 80)),
]
n_top = min(int(progress * 7) + 1, 7)
arrow_limit = 6
_draw_pipeline_blocks(frame, blocks_classic, n_top, arrow_limit)
cross_phase = 0.6
if progress > cross_phase:
_draw_red_cross(frame, 385, 135, top_y, 50)
bot_y = 380
blocks_fcn = [
((80, bot_y), (70, 50), (70, 130, 200)),
((170, bot_y), (50, 40), (50, 100, 160)),
((240, bot_y), (60, 50), (70, 130, 200)),
((320, bot_y), (40, 35), (50, 100, 160)),
((385, bot_y), (70, 50), (80, 200, 120)),
((480, bot_y), (75, 50), (200, 160, 80)),
((580, bot_y), (80, 50), (100, 200, 100)),
]
fcn_phase = 0.4
if progress > fcn_phase:
n_bot = min(int((progress - fcn_phase) / 0.6 * 7) + 1, 7)
_draw_pipeline_blocks(frame, blocks_fcn, n_bot, arrow_limit)
return frame
def _fcn_demo() -> list[CompositeVideoClip]:
"""Animate FCN step-by-step: FC → Conv 1x1 transformation."""
dur = STEP_DUR + 1
fcn_clip = VideoClip(_make_fcn_frame, duration=dur).with_fps(FPS)
labels = [
("FCN: Fully Convolutional Network (2015)", 26, "#FFE082", FONT_B, (80, 20)),
("KROK 1: Zamień FC → Conv 1x1", 18, "#A5D6A7", FONT_R, (80, 60)),
("Klasyczny CNN:", 16, "#EF9A9A", FONT_B, (80, 105)),
("Conv", 11, "white", FONT_R, (92, 148)),
("Pool", 11, "white", FONT_R, (178, 148)),
("Conv", 11, "white", FONT_R, (250, 148)),
("Pool", 11, "white", FONT_R, (325, 148)),
("Flatten", 11, "#EF9A9A", FONT_R, (390, 148)),
("FC", 11, "#EF9A9A", FONT_R, (480, 148)),
("1 label", 11, "#EF9A9A", FONT_R, (555, 148)),
("FCN:", 16, "#A5D6A7", FONT_B, (80, 350)),
("Conv", 11, "white", FONT_R, (92, 388)),
("Pool", 11, "white", FONT_R, (178, 388)),
("Conv", 11, "white", FONT_R, (250, 388)),
("Pool", 11, "white", FONT_R, (325, 388)),
("Conv1x1", 11, "#A5D6A7", FONT_R, (390, 388)),
("Upsample", 11, "#FFE082", FONT_R, (486, 388)),
("Mapa", 11, "#A5D6A7", FONT_R, (595, 388)),
(
"FC: spłaszcza 3D→1D, wymusza stały rozmiar → 1 etykieta",
16,
"#EF9A9A",
FONT_R,
(80, 250),
),
(
"Conv1x1: działa per piksel x kanały → DOWOLNY rozmiar → mapa klasy",
16,
"#A5D6A7",
FONT_R,
(80, 460),
),
(
"KROK 2: Skip connections — łączą wczesne detale z późną abstrakcją",
17,
"#64B5F6",
FONT_R,
(80, 510),
),
(
"Wczesne warstwy = krawędzie, tekstury | Późne = koncepty obiektów",
15,
"#78909C",
FONT_R,
(80, 545),
),
(
"FCN = PIERWSZA sieć end-to-end do segmentacji per-piksel!",
18,
"white",
FONT_R,
(80, 590),
),
(
"Mnemonik: FC → Conv 1x1 = otwieramy bramkę dla DOWOLNEGO rozmiaru",
16,
"#FFE082",
FONT_R,
(80, 640),
),
]
slides = [_compose_slide(fcn_clip, labels, dur)]
# Slide 2: FCN skip connections step by step
skip_lines = [
("FCN: Skip Connections — krok po kroku", 26, "#FFE082", FONT_B, (80, 30)),
(
"1. Encoder zmniejsza: 224→112→56→28→14 (pooling)",
18,
"#64B5F6",
FONT_R,
(100, 100),
),
(
" Każdy pooling traci detale przestrzenne (dokładne krawędzie)",
15,
"#78909C",
FONT_R,
(100, 135),
),
(
"2. Decoder powiększa: 14→28→56→112→224 (upsample/deconv)",
18,
"#A5D6A7",
FONT_R,
(100, 190),
),
(
" Upsample ODGADUJE piksele — rozmyty wynik!",
15,
"#78909C",
FONT_R,
(100, 225),
),
(
"3. Skip connections: dodaj cechy z encodera do decodera",
18,
"#FFE082",
FONT_R,
(100, 280),
),
(
" Wczesne cechy = GDZIE (precyzyjne krawędzie)",
15,
"#64B5F6",
FONT_R,
(100, 315),
),
(
" Późne cechy = CO (abstrakcyjne koncepty)",
15,
"#A5D6A7",
FONT_R,
(100, 345),
),
(
" Skip = daje decoderowi OBA → ostry wynik!",
15,
"#FFE082",
FONT_R,
(100, 375),
),
(
"Warianty: FCN-32s (brak skip, rozmyty) → FCN-16s → FCN-8s (najlepszy)",
16,
"#B0BEC5",
FONT_R,
(80, 440),
),
(
"FCN-32s: upsample 32x naraz → ROZMYTE granice",
15,
"#EF9A9A",
FONT_R,
(100, 485),
),
(
"FCN-16s: skip z pool4 + upsample 16x → lepiej",
15,
"#FFE082",
FONT_R,
(100, 520),
),
(
"FCN-8s: skip z pool3+pool4 + upsample 8x → OSTRE granice!",
15,
"#A5D6A7",
FONT_R,
(100, 555),
),
(
"Im więcej skip connections → tym więcej "
"detali z encodera → ostrzejszy wynik",
17,
"white",
FONT_R,
(80, 620),
),
]
slides.append(_text_slide(skip_lines, duration=STEP_DUR + 1))
return slides

View File

@ -0,0 +1,332 @@
"""Classical detection methods: detection concept, HOG+SVM, Viola-Jones."""
from __future__ import annotations
from _q24_common import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_tc,
)
from moviepy import CompositeVideoClip, VideoClip
from moviepy.video.fx import FadeIn, FadeOut
import numpy as np
# ── Detection concept ────────────────────────────────────────────
def _detection_concept() -> list[CompositeVideoClip]:
"""Show what detection is: bounding box + class + confidence."""
slides = []
def make_det_frame(_t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
# Draw a "scene" with colored rectangles representing objects
# Sky background area
frame[140:500, 100:700] = (40, 50, 70)
# "Car" object
frame[350:430, 150:320] = (180, 60, 60)
# "Person" object
frame[280:440, 450:520] = (60, 120, 180)
# "Tree" object
frame[200:400, 580:650] = (40, 130, 50)
# Bounding boxes (with labels drawn as colored borders)
# Car bbox
for thickness in range(3):
t = thickness
frame[348 - t : 432 + t, 148 - t : 148 - t + 2] = (255, 80, 80)
frame[348 - t : 432 + t, 322 + t - 2 : 322 + t] = (255, 80, 80)
frame[348 - t : 348 - t + 2, 148 - t : 322 + t] = (255, 80, 80)
frame[432 + t - 2 : 432 + t, 148 - t : 322 + t] = (255, 80, 80)
# Person bbox
for thickness in range(3):
t = thickness
frame[278 - t : 442 + t, 448 - t : 448 - t + 2] = (80, 180, 255)
frame[278 - t : 442 + t, 522 + t - 2 : 522 + t] = (80, 180, 255)
frame[278 - t : 278 - t + 2, 448 - t : 522 + t] = (80, 180, 255)
frame[442 + t - 2 : 442 + t, 448 - t : 522 + t] = (80, 180, 255)
# Tree bbox
for thickness in range(3):
t = thickness
frame[198 - t : 402 + t, 578 - t : 578 - t + 2] = (80, 220, 100)
frame[198 - t : 402 + t, 652 + t - 2 : 652 + t] = (80, 220, 100)
frame[198 - t : 198 - t + 2, 578 - t : 652 + t] = (80, 220, 100)
frame[402 + t - 2 : 402 + t, 578 - t : 652 + t] = (80, 220, 100)
# Comparison boxes on right side
# Classification
frame[180:260, 800:1150] = (35, 45, 65)
# Detection
frame[290:370, 800:1150] = (35, 45, 65)
# Segmentation
frame[400:480, 800:1150] = (35, 45, 65)
return frame
det_clip = VideoClip(make_det_frame, duration=STEP_DUR).with_fps(FPS)
text_clips: list[VideoClip] = [det_clip]
labels = [
("Detekcja obiektów — co to jest?", 28, "#FFE082", FONT_B, (100, 20)),
("Wynik: (klasa, bounding box, pewność)", 20, "#B0BEC5", FONT_R, (100, 65)),
("samochód 95%", 14, "#EF9A9A", FONT_B, (150, 340)),
("osoba 88%", 14, "#64B5F6", FONT_B, (450, 268)),
("drzewo 72%", 14, "#A5D6A7", FONT_B, (580, 188)),
("Klasyfikacja: cały obraz → 1 etykieta", 15, "#78909C", FONT_R, (810, 210)),
("Detekcja: bbox + klasa + pewność", 15, "#FFE082", FONT_R, (810, 320)),
("Segmentacja: maska per piksel", 15, "#78909C", FONT_R, (810, 430)),
("← granulacja rośnie →", 14, "#90CAF9", FONT_R, (810, 520)),
]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides
# ── HOG + SVM pipeline ───────────────────────────────────────────
def _hog_svm_demo() -> list[CompositeVideoClip]:
"""Animate HOG feature computation and SVM classification."""
slides = []
def make_hog_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.8), 1.0)
# Pipeline stages as boxes with arrows
stages = [
("Gradient", (80, 250), (130, 80), (100, 160, 220)),
("Orientacja", (260, 250), (130, 80), (80, 180, 140)),
("Komórki 8x8", (440, 250), (130, 80), (200, 160, 80)),
("Bloki 2x2", (620, 250), (130, 80), (200, 120, 60)),
("Normalizacja", (800, 250), (130, 80), (180, 100, 80)),
("SVM", (980, 250), (130, 80), (220, 80, 80)),
]
n_active = int(progress * len(stages)) + 1
for i, (_label, (sx, sy), (sw, sh), color) in enumerate(stages):
if i < n_active:
frame[sy : sy + sh, sx : sx + sw] = color
# Border
frame[sy : sy + 2, sx : sx + sw] = tuple(
min(c + 60, 255) for c in color
)
frame[sy + sh - 2 : sy + sh, sx : sx + sw] = tuple(
min(c + 60, 255) for c in color
)
# Arrow to next
if i < len(stages) - 1:
ax = sx + sw + 5
ay = sy + sh // 2
frame[ay - 1 : ay + 2, ax : ax + 20] = (150, 150, 170)
# Show gradient computation example at bottom
gradient_phase = 0.2
if progress > gradient_phase:
# Mini pixel grid showing gradient computation
gx, gy = 100, 430
pixels = [50, 50, 200]
for idx, val in enumerate(pixels):
x = gx + idx * 50
frame[gy : gy + 40, x : x + 40] = (val, val, val)
return frame
hog_clip = VideoClip(make_hog_frame, duration=STEP_DUR).with_fps(FPS)
text_clips: list[VideoClip] = [hog_clip]
labels = [
("HOG + SVM — pipeline detekcji pieszych", 28, "#FFE082", FONT_B, (80, 20)),
(
"Mnemonik: GOKBN = Gradienty→Orientacja→Komórki→Bloki→Normalizacja",
16,
"#A5D6A7",
FONT_R,
(80, 65),
),
("Gradient: siła i kierunek zmiany jasności", 14, "#64B5F6", FONT_R, (80, 95)),
(
"Histogram: 9 binów (0°-180°, co 20°) per komórka 8x8",
14,
"#78909C",
FONT_R,
(80, 120),
),
(
"[50][50][200] → Gx = 200-50 = 150 = silna krawędź!",
16,
"#EF9A9A",
FONT_R,
(80, 490),
),
(
"Wektor HOG (3780 cech) → SVM: pieszy (+1) / tło (-1)",
16,
"white",
FONT_R,
(80, 540),
),
(
"Sliding window 64x128 przesuwa się po obrazie → NMS → wynik",
16,
"#90CAF9",
FONT_R,
(80, 580),
),
(
"SVM = LINIA MAKSYMALNEGO ODDECHU (max margines, support vectors)",
16,
"#FFE082",
FONT_R,
(80, 620),
),
]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides
# ── Viola-Jones ───────────────────────────────────────────────────
def _viola_jones_demo() -> list[CompositeVideoClip]:
"""Animate Viola-Jones cascade concept."""
slides = []
def make_cascade_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.8), 1.0)
# Draw cascade "funnel" — stages filtering out non-faces
stages = 5
start_width = 1000
start_count = 10000
x_center = W // 2
for i in range(stages):
stage_progress = min(progress * stages - i, 1.0)
if stage_progress <= 0:
break
width = int(start_width * (1 - i * 0.18))
int(start_count * (0.3**i))
y = 150 + i * 100
h_box = 60
# Stage box
x1 = x_center - width // 2
frame[y : y + h_box, x1 : x1 + width] = (
50 + i * 10,
60 + i * 10,
80 + i * 10,
)
# Border
frame[y : y + 2, x1 : x1 + width] = (100 + i * 20, 130 + i * 15, 200)
frame[y + h_box - 2 : y + h_box, x1 : x1 + width] = (
100 + i * 20,
130 + i * 15,
200,
)
# Arrow down to next
if i < stages - 1:
frame[y + h_box + 5 : y + h_box + 25, x_center - 1 : x_center + 2] = (
150,
150,
170,
)
# Red "rejected" arrows on sides
if i > 0:
# Left reject arrow
rx = x1 - 30
ry = y + h_box // 2
frame[ry - 1 : ry + 2, rx : rx + 25] = (200, 80, 80)
return frame
cascade_clip = VideoClip(make_cascade_frame, duration=STEP_DUR).with_fps(FPS)
text_clips: list[VideoClip] = [cascade_clip]
labels = [
(
"Viola-Jones — kaskada klasyfikatorów (2001)",
28,
"#FFE082",
FONT_B,
(80, 20),
),
(
"3 innowacje: HIC = Haar + Integral Image + Cascade",
20,
"#B0BEC5",
FONT_R,
(80, 65),
),
("Etap 1: 2 cechy Haar", 14, "#64B5F6", FONT_R, (170, 170)),
("Etap 2: 10 cech", 14, "#64B5F6", FONT_R, (210, 270)),
("Etap 3: 25 cech", 14, "#64B5F6", FONT_R, (240, 370)),
("Etap 4: 50 cech", 14, "#64B5F6", FONT_R, (260, 470)),
("→ TWARZ!", 16, "#A5D6A7", FONT_B, (590, 560)),
(
"SITO: 99% okien odpada w pierwszych 3 etapach → REAL-TIME!",
16,
"#EF9A9A",
FONT_R,
(80, 620),
),
(
"Haar: kontrast jasna/ciemna | Integral Image: "
"suma prostokąta O(1) = 4 odczyty",
14,
"#78909C",
FONT_R,
(80, 655),
),
("odrzucone →", 12, "#EF9A9A", FONT_R, (60, 275)),
("odrzucone →", 12, "#EF9A9A", FONT_R, (60, 375)),
]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides

View File

@ -0,0 +1,115 @@
"""Shared constants and helpers for Q24 object detection visualization."""
from __future__ import annotations
import logging
import os
from pathlib import Path
import numpy as np
os.environ["FFMPEG_BINARY"] = "/usr/bin/ffmpeg"
from moviepy import (
ColorClip,
CompositeVideoClip,
TextClip,
VideoClip,
)
from moviepy.video.fx import FadeIn, FadeOut
# ── Constants ─────────────────────────────────────────────────────
W, H = 1280, 720
FPS = 24
STEP_DUR = 7.0
HEADER_DUR = 4.0
FONT_B = "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf"
FONT_R = "/usr/share/fonts/TTF/DejaVuSans.ttf"
OUTPUT_DIR = Path(__file__).resolve().parent / "videos"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT = str(OUTPUT_DIR / "q24_object_detection.mp4")
BG_COLOR = (15, 20, 35)
_logger = logging.getLogger(__name__)
# Re-export numpy for sub-modules that need it alongside constants.
__all__ = [
"BG_COLOR",
"FONT_B",
"FONT_R",
"FPS",
"HEADER_DUR",
"OUTPUT",
"OUTPUT_DIR",
"STEP_DUR",
"H",
"W",
"_logger",
"_make_header",
"_tc",
"_text_slide",
"np",
]
def _tc(**kwargs: object) -> TextClip:
"""TextClip wrapper that adds enough bottom margin to prevent clipping."""
fs = kwargs.get("font_size", 24)
m = int(fs) // 3 + 2
kwargs["margin"] = (0, m)
return TextClip(**kwargs)
def _make_header(
title: str, subtitle: str, duration: float = HEADER_DUR
) -> CompositeVideoClip:
"""Create a title/subtitle header slide."""
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration)
t = (
_tc(
text=title,
font_size=48,
color="white",
font=FONT_B,
)
.with_duration(duration)
.with_position(("center", 260))
)
s = (
_tc(
text=subtitle,
font_size=24,
color="#90CAF9",
font=FONT_R,
)
.with_duration(duration)
.with_position(("center", 340))
)
return CompositeVideoClip([bg, t, s], size=(W, H)).with_effects(
[FadeIn(0.5), FadeOut(0.5)]
)
def _text_slide(
lines: list[tuple[str, int, str, str, tuple[str | int, str | int]]],
duration: float = STEP_DUR,
) -> CompositeVideoClip:
"""Create a text-only slide from a list of (text, size, color, font, pos)."""
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration)
clips: list[VideoClip] = [bg]
for text, font_size, color, font, pos in lines:
tc = (
_tc(
text=text,
font_size=font_size,
color=color,
font=font,
)
.with_duration(duration)
.with_position(pos)
)
clips.append(tc)
return CompositeVideoClip(clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)

View File

@ -0,0 +1,239 @@
"""NMS/IoU, detector-from-classifier, and methods comparison."""
from __future__ import annotations
from _q24_common import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_tc,
_text_slide,
)
from moviepy import ColorClip, CompositeVideoClip, VideoClip
from moviepy.video.fx import FadeIn, FadeOut
import numpy as np
# ── NMS + IoU ─────────────────────────────────────────────────────
def _nms_iou_demo() -> list[CompositeVideoClip]:
"""Animate NMS and IoU concepts."""
slides = []
def make_nms_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
# Draw overlapping bounding boxes
ox, oy = 100, 200
obj_w, obj_h = 150, 120
# Multiple overlapping detections for same object
boxes = [
(ox, oy, obj_w, obj_h, 0.95, (255, 80, 80)), # best
(ox + 15, oy - 10, obj_w + 10, obj_h + 5, 0.90, (200, 60, 60)),
(ox - 10, oy + 5, obj_w - 5, obj_h + 10, 0.85, (160, 50, 50)),
]
# Different object far away
boxes.append((ox + 350, oy + 50, 100, 100, 0.40, (80, 180, 255)))
for i, (bx, by, bw, bh, _conf, color) in enumerate(boxes):
dc = color
nms_phase = 0.4
nms_limit = 3
if progress > nms_phase and i > 0 and i < nms_limit:
# After NMS, these get removed (shown as faded/crossed)
dc = (60, 40, 40)
for tt in range(2):
frame[by - tt : by + bh + tt, bx - tt : bx - tt + 2] = dc
frame[by - tt : by + bh + tt, bx + bw + tt - 2 : bx + bw + tt] = dc
frame[by - tt : by - tt + 2, bx - tt : bx + bw + tt] = dc
frame[by + bh + tt - 2 : by + bh + tt, bx - tt : bx + bw + tt] = dc
# IoU visualization on right side
iou_x, iou_y = 700, 200
# Box A
frame[iou_y : iou_y + 100, iou_x : iou_x + 100] = (80, 80, 200)
# Box B (overlapping)
frame[iou_y + 40 : iou_y + 140, iou_x + 40 : iou_x + 140] = (200, 80, 80)
# Intersection highlighted
frame[iou_y + 40 : iou_y + 100, iou_x + 40 : iou_x + 100] = (200, 150, 200)
return frame
nms_clip = VideoClip(make_nms_frame, duration=STEP_DUR).with_fps(FPS)
text_clips: list[VideoClip] = [nms_clip]
labels = [
("NMS (Non-Maximum Suppression) + IoU", 28, "#FFE082", FONT_B, (80, 20)),
(
"NMS = Najlepszy Ma Się dobrze — zachowaj najlepszą, usuń duplikaty",
18,
"#B0BEC5",
FONT_R,
(80, 65),
),
("conf=0.95 ✓", 14, "#A5D6A7", FONT_B, (100, 340)),
("0.90 ✗ IoU>0.5", 13, "#EF9A9A", FONT_R, (100, 365)),
("0.85 ✗ IoU>0.5", 13, "#EF9A9A", FONT_R, (100, 390)),
("0.40 ✓ INNY obiekt", 13, "#64B5F6", FONT_R, (100, 420)),
("IoU = Intersection over Union", 18, "#FFE082", FONT_B, (700, 160)),
("IoU = pole(∩) / pole(AUB)", 16, "white", FONT_R, (700, 380)),
("Fioletowy = intersection", 14, "#CE93D8", FONT_R, (700, 410)),
("IoU > 0.5 → TEN SAM obiekt → usuń", 14, "#EF9A9A", FONT_R, (700, 440)),
("IoU < 0.5 → INNY obiekt → zachowaj", 14, "#A5D6A7", FONT_R, (700, 470)),
(
"DETR: jedyny detektor BEZ NMS (Hungarian matching zamiast tego)",
14,
"#78909C",
FONT_R,
(80, 620),
),
]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides
# ── Detector from Classifier ─────────────────────────────────────
def _detector_from_classifier() -> list[CompositeVideoClip]:
"""Show 3 approaches to building a detector from a classifier."""
slides = []
approaches = [
(
"Podejście 1: Sliding Window (NAJWOLNIEJSZE)",
[
("Okno przesuwa się po obrazie w wielu skalach", "#B0BEC5"),
("Każde okno → klasyfikator (np. ResNet) → klasa + pewność", "#B0BEC5"),
("~18 000 okien x 10ms = ~3 minuty na obraz!", "#EF9A9A"),
("Mnemonik: WYCINAJ i PYTAJ — jak wycinanie ciasteczek", "#FFE082"),
],
"SRF",
),
(
"Podejście 2: Region Proposals (= R-CNN)",
[
("Selective Search → ~2000 inteligentnych regionów", "#B0BEC5"),
("Każdy region → CNN → wektor cech → SVM klasyfikuje", "#B0BEC5"),
("~2000 x 10ms = ~20 sec — 9x szybciej!", "#64B5F6"),
(
"Mnemonik: INTELIGENTNE CIĘCIE — wytnij tylko tam gdzie wiśnie",
"#FFE082",
),
],
"SRF",
),
(
"Podejście 3: Fine-tune backbone (NAJLEPSZE)",
[
(
"Pretrained backbone (ResNet) → odetnij FC → dodaj detection head",
"#B0BEC5",
),
(
"Detection head = głowica klasyfikacji + głowica regresji bbox",
"#B0BEC5",
),
("~0.2 sec/obraz, najlepsza jakość (mAP ~42%)", "#A5D6A7"),
("Mnemonik: PRZESZCZEP GŁOWY — ten sam silnik, nowa głowa", "#FFE082"),
],
"SRF",
),
]
for title, points, _mnem in approaches:
lines = [
(title, 24, "#FFE082", FONT_B, (80, 140)),
]
for i, (text, color) in enumerate(points):
lines.append((f"{text}", 18, color, FONT_R, (100, 220 + i * 50)))
lines.append(
(
"Detektor z klasyfikatora: SRF = Sliding → Region → Fine-tune",
16,
"#78909C",
FONT_R,
(80, 520),
)
)
lines.append(
(
"= Szukaj Ręcznie, Finalnie optymalizuj!",
16,
"#90CAF9",
FONT_R,
(80, 550),
)
)
slides.append(_text_slide(lines, duration=STEP_DUR))
return slides
# ── Methods comparison ────────────────────────────────────────────
def _methods_comparison() -> CompositeVideoClip:
"""Create a comparison table of all detection methods."""
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(10.0)
title = (
_tc(
text="Porównanie detektorów",
font_size=36,
color="white",
font=FONT_B,
)
.with_duration(10.0)
.with_position(("center", 20))
)
rows = [
("Model", "Rok", "Typ", "Szybkość", "Kluczowe"),
("HOG+SVM", "2005", "Klasyczny", "~1 fps", "Gradient histogramy"),
("Viola-Jones", "2001", "Klasyczny", "30+ fps", "Haar+Cascade"),
("R-CNN", "2014", "Two-stage", "50 sec!", "CNN per region"),
("Fast R-CNN", "2015", "Two-stage", "2 sec", "ROI Pooling"),
("Faster R-CNN", "2015", "Two-stage", "5 fps", "RPN w sieci"),
("YOLO", "2016", "One-stage", "45+ fps", "Siatka SxS"),
("DETR", "2020", "Transformer", "~40 fps", "Bez NMS!"),
]
clips: list[VideoClip] = [bg, title]
for i, row in enumerate(rows):
y_pos = 75 + i * 72
col_x = [40, 200, 280, 400, 530]
for j, cell in enumerate(row):
fs = 16 if i > 0 else 18
color = "#64B5F6" if i == 0 else "#E0E0E0"
tc = (
_tc(
text=cell,
font_size=fs,
color=color,
font=FONT_B if i == 0 else FONT_R,
)
.with_duration(10.0)
.with_position((col_x[j], y_pos))
)
clips.append(tc)
return CompositeVideoClip(clips, size=(W, H)).with_effects(
[FadeIn(0.5), FadeOut(0.5)]
)

View File

@ -0,0 +1,405 @@
"""R-CNN family: evolution, detailed pipeline, ROI pooling."""
from __future__ import annotations
from _q24_common import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_tc,
)
from moviepy import CompositeVideoClip, VideoClip
from moviepy.video.fx import FadeIn, FadeOut
import numpy as np
# ── R-CNN Evolution ───────────────────────────────────────────────
def _rcnn_evolution() -> list[CompositeVideoClip]:
"""Animate R-CNN → Fast R-CNN → Faster R-CNN evolution."""
slides = []
def make_evolution_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.8), 1.0)
# Three rows: R-CNN, Fast R-CNN, Faster R-CNN
models = [
(
"R-CNN (2014)",
50,
[
("Selective\nSearch", (200, 150), (100, 50), (120, 100, 60)),
("2000x\nCNN", (350, 150), (80, 50), (180, 60, 60)),
("2000x\nSVM", (480, 150), (80, 50), (180, 60, 60)),
("NMS", (610, 150), (60, 50), (100, 140, 100)),
],
"50 sec/obraz!",
),
(
"Fast R-CNN (2015)",
300,
[
("Selective\nSearch", (200, 150), (100, 50), (120, 100, 60)),
("1x CNN\n(cały obraz)", (350, 150), (100, 50), (80, 140, 200)),
("ROI Pool\n(2000)", (500, 150), (90, 50), (200, 160, 80)),
("FC", (640, 150), (50, 50), (100, 140, 100)),
],
"2 sec/obraz",
),
(
"Faster R-CNN (2015)",
300,
[
("CNN\nbackbone", (200, 150), (90, 50), (80, 140, 200)),
("RPN\n(~300)", (340, 150), (80, 50), (200, 120, 60)),
("ROI Pool", (470, 150), (80, 50), (200, 160, 80)),
("FC", (600, 150), (50, 50), (100, 140, 100)),
],
"0.2 sec → 5 fps!",
),
]
n_models = int(progress * 3) + 1
for mi, (_name, base_y, stages, _speed) in enumerate(models):
if mi >= n_models:
break
for _label, (bx, by_off), (bw, bh), color in stages:
by = base_y + by_off - 150
frame[by : by + bh, bx : bx + bw] = color
frame[by : by + 2, bx : bx + bw] = tuple(
min(c + 50, 255) for c in color
)
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
min(c + 50, 255) for c in color
)
# Arrows between stages
for si in range(len(stages) - 1):
sx = stages[si][1][0] + stages[si][2][0]
ex = stages[si + 1][1][0]
ay = base_y + 25
frame[ay - 1 : ay + 2, sx + 3 : ex - 3] = (150, 150, 170)
return frame
evo_clip = VideoClip(make_evolution_frame, duration=STEP_DUR + 1).with_fps(FPS)
text_clips: list[VideoClip] = [evo_clip]
labels = [
("Ewolucja R-CNN — CORAZ MNIEJ MARNOWANIA", 28, "#FFE082", FONT_B, (80, 20)),
("R-CNN (2014)", 20, "#EF9A9A", FONT_B, (50, 80)),
("50 sec/obraz (2000x forward pass!)", 14, "#EF9A9A", FONT_R, (720, 100)),
("Fast R-CNN (2015)", 20, "#64B5F6", FONT_B, (50, 330)),
("2 sec/obraz (CNN raz + ROI Pool)", 14, "#64B5F6", FONT_R, (720, 350)),
("Faster R-CNN (2015)", 20, "#A5D6A7", FONT_B, (50, 580)),
("0.2 sec → 5 fps (RPN w sieci!)", 14, "#A5D6A7", FONT_R, (720, 600)),
(
"Kluczowe innowacje: ROI Pooling → stały rozmiar "
"| RPN → propozycje w sieci",
14,
"#78909C",
FONT_R,
(80, 660),
),
]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR + 1)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides
# ── R-CNN Detailed Pipeline ──────────────────────────────────────
def _rcnn_detailed() -> list[CompositeVideoClip]:
"""Animate R-CNN step-by-step pipeline in detail."""
slides = []
# Slide 1: R-CNN pipeline step by step
def make_rcnn_pipeline(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.8), 1.0)
# Step boxes arranged vertically with arrows
steps = [
((80, 130), (200, 55), (120, 100, 60), "1. Selective Search"),
((80, 230), (200, 55), (180, 60, 60), "2. Wytnij 2000 regionów"),
((80, 330), (200, 55), (70, 130, 200), "3. CNN per region"),
((80, 430), (200, 55), (200, 100, 80), "4. SVM klasyfikuje"),
((80, 530), (200, 55), (100, 180, 100), "5. Bbox regresja + NMS"),
]
n_steps = min(int(progress * 5) + 1, 5)
for i, ((bx, by), (bw, bh), color, _lbl) in enumerate(steps):
if i < n_steps:
frame[by : by + bh, bx : bx + bw] = color
frame[by : by + 2, bx : bx + bw] = tuple(
min(c + 50, 255) for c in color
)
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
min(c + 50, 255) for c in color
)
# Arrow down
arrow_limit = 4
if i < arrow_limit:
ax = bx + bw // 2
ay = by + bh + 5
frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170)
# Illustration: many overlapping regions from Selective Search
overlay_phase = 0.2
if progress > overlay_phase:
rng_local = np.random.default_rng(42)
n_boxes = min(int((progress - 0.2) * 15), 8)
for i in range(n_boxes):
rx = 500 + rng_local.integers(-30, 100)
ry = 200 + rng_local.integers(-20, 120)
rw = 60 + rng_local.integers(0, 80)
rh = 50 + rng_local.integers(0, 70)
c = (80 + i * 15, 100 + i * 10, 60 + i * 20)
for tt in range(2):
frame[ry - tt : ry + rh + tt, rx - tt : rx - tt + 2] = c
frame[ry - tt : ry + rh + tt, rx + rw + tt - 2 : rx + rw + tt] = c
frame[ry - tt : ry - tt + 2, rx - tt : rx + rw + tt] = c
frame[ry + rh + tt - 2 : ry + rh + tt, rx - tt : rx + rw + tt] = c
return frame
rcnn_clip = VideoClip(make_rcnn_pipeline, duration=STEP_DUR + 1).with_fps(FPS)
dur = STEP_DUR + 1
labels = [
("R-CNN: krok po kroku (2014, Girshick)", 26, "#FFE082", FONT_B, (80, 20)),
("Pipeline detekcji two-stage", 16, "#B0BEC5", FONT_R, (80, 60)),
("Selective Search", 11, "white", FONT_R, (105, 145)),
("2000 regionów", 11, "white", FONT_R, (105, 245)),
("CNN per region", 11, "white", FONT_R, (105, 345)),
("SVM klasyfikuje", 11, "white", FONT_R, (105, 445)),
("Regresja + NMS", 11, "white", FONT_R, (105, 545)),
("~2000 propozycji regionów", 14, "#78909C", FONT_R, (500, 155)),
("(inteligentne łączenie", 13, "#78909C", FONT_R, (500, 180)),
("podobnych fragmentów)", 13, "#78909C", FONT_R, (500, 200)),
("Problem: 2000 x CNN forward pass", 16, "#EF9A9A", FONT_R, (400, 400)),
("= 50 SEKUND na obraz!", 18, "#EF9A9A", FONT_B, (400, 430)),
("CNN liczy cechy per region OSOBNO", 14, "#EF9A9A", FONT_R, (400, 470)),
(
"→ regiony się nakładają → obliczenia się powtarzają!",
14,
"#EF9A9A",
FONT_R,
(400, 495),
),
(
"Rozwiązanie: CNN raz na cały obraz → Fast R-CNN →",
16,
"#A5D6A7",
FONT_R,
(80, 620),
),
]
text_clips: list[VideoClip] = [rcnn_clip]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(dur)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides
# ── ROI Pooling ──────────────────────────────────────────────────
def _draw_roi_pool_grid(frame: np.ndarray) -> None:
"""Draw the 3x3 ROI pool grid with max-pooled feature values."""
out_x, out_y = 400, 220
out_cell = 50
out_n = 3
roi_r1, roi_c1 = 2, 1
roi_r2, roi_c2 = 6, 5
roi_h = roi_r2 - roi_r1
roi_w = roi_c2 - roi_c1
for r in range(out_n):
for c in range(out_n):
x = out_x + c * out_cell
y = out_y + r * out_cell
# Compute the max from corresponding region
src_r1 = roi_r1 + r * roi_h // out_n
src_r2 = roi_r1 + (r + 1) * roi_h // out_n
src_c1 = roi_c1 + c * roi_w // out_n
src_c2 = roi_c1 + (c + 1) * roi_w // out_n
max_val = 0
for sr in range(src_r1, src_r2):
for sc in range(src_c1, src_c2):
v = 30 + ((sr * 7 + sc * 13 + 42) % 40)
max_val = max(max_val, v)
frame[y : y + out_cell - 2, x : x + out_cell - 2] = (
max_val,
max_val + 20,
max_val + 40,
)
frame[y : y + 2, x : x + out_cell - 2] = (80, 200, 120)
frame[y + out_cell - 4 : y + out_cell - 2, x : x + out_cell - 2] = (
80,
200,
120,
)
def _make_roi_frame(t: float) -> np.ndarray:
"""Render a single frame for the ROI pooling animation."""
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
# Left: feature map with ROI highlighted
fm_x, fm_y = 60, 180
fm_cell = 30
fm_grid = 8
for r in range(fm_grid):
for c in range(fm_grid):
x = fm_x + c * fm_cell
y = fm_y + r * fm_cell
# Random-looking feature values
val = 30 + ((r * 7 + c * 13 + 42) % 40)
frame[y : y + fm_cell - 1, x : x + fm_cell - 1] = (
val,
val + 10,
val + 20,
)
# ROI region highlighted
roi_r1, roi_c1 = 2, 1
roi_r2, roi_c2 = 6, 5
for tt in range(3):
ry1 = fm_y + roi_r1 * fm_cell - tt
ry2 = fm_y + roi_r2 * fm_cell + tt
rx1 = fm_x + roi_c1 * fm_cell - tt
rx2 = fm_x + roi_c2 * fm_cell + tt
frame[ry1:ry2, rx1 : rx1 + 2] = (255, 200, 50)
frame[ry1:ry2, rx2 - 2 : rx2] = (255, 200, 50)
frame[ry1 : ry1 + 2, rx1:rx2] = (255, 200, 50)
frame[ry2 - 2 : ry2, rx1:rx2] = (255, 200, 50)
# Arrow
arrow_phase = 0.3
if progress > arrow_phase:
frame[300:303, 310:380] = (150, 150, 170)
# Middle: ROI divided into 3x3 grid (output_size)
grid_phase = 0.3
if progress > grid_phase:
_draw_roi_pool_grid(frame)
# Arrow to FC
fc_phase = 0.6
if progress > fc_phase:
frame[300:303, 560:630] = (150, 150, 170)
# FC box
frame[270:340, 650:730] = (200, 100, 80)
frame[270:272, 650:730] = (240, 140, 120)
frame[338:340, 650:730] = (240, 140, 120)
return frame
def _roi_pooling_demo() -> list[CompositeVideoClip]:
"""Animate ROI Pooling: key Fast R-CNN innovation."""
slides = []
roi_clip = VideoClip(_make_roi_frame, duration=STEP_DUR + 1).with_fps(FPS)
dur = STEP_DUR + 1
labels = [
("ROI Pooling: kluczowa innowacja Fast R-CNN", 26, "#FFE082", FONT_B, (80, 20)),
(
"KROK 1: CNN raz na CAŁY obraz → feature mapa",
17,
"#64B5F6",
FONT_R,
(80, 60),
),
(
"KROK 2: Wytnij ROI z feature mapy (nie z obrazu!)",
17,
"#FFE082",
FONT_R,
(80, 90),
),
(
"KROK 3: Siatkuj ROI na 3x3 → max pool per komórka → stały rozmiar",
17,
"#A5D6A7",
FONT_R,
(80, 120),
),
("Feature mapa", 14, "#64B5F6", FONT_B, (60, 160)),
("ROI (żółta ramka)", 13, "#FFE082", FONT_R, (60, 440)),
("ROI Pool 3x3", 14, "#A5D6A7", FONT_B, (400, 195)),
("(max z komórki)", 13, "#78909C", FONT_R, (400, 380)),
("FC", 14, "white", FONT_B, (670, 280)),
(
"Problem: ROI mają RÓŻNE rozmiary, FC wymaga STAŁEGO",
15,
"#B0BEC5",
FONT_R,
(80, 500),
),
(
"ROI Pooling: dzieli ROI na siatkę, max pool → STAŁY rozmiar!",
16,
"white",
FONT_R,
(80, 535),
),
(
"Fast R-CNN: CNN raz → 1 feature mapa → "
"ROI Pool 2000 regionów → 25x szybciej!",
16,
"#A5D6A7",
FONT_R,
(80, 580),
),
(
"(R-CNN: 2000x CNN = 50s | Fast R-CNN: 1xCNN + ROI Pool = 2s)",
15,
"#EF9A9A",
FONT_R,
(80, 620),
),
]
text_clips: list[VideoClip] = [roi_clip]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(dur)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides

View File

@ -0,0 +1,383 @@
"""RPN anchor boxes and YOLO grid detection."""
from __future__ import annotations
from _q24_common import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_tc,
_text_slide,
)
from moviepy import CompositeVideoClip, VideoClip
from moviepy.video.fx import FadeIn, FadeOut
import numpy as np
# ── RPN + Anchor Boxes ───────────────────────────────────────────
def _rpn_anchors_demo() -> list[CompositeVideoClip]:
"""Animate RPN and anchor boxes: Faster R-CNN innovation."""
slides = []
# Slide 1: Anchor boxes concept
def make_anchors_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
# Draw feature map grid point with multiple anchors
cx, cy = 350, 360 # center point on feature map
# Draw a "feature map" grid background
cell = 60
for r in range(-3, 4):
for c in range(-3, 4):
x = cx + c * cell - cell // 2
y = cy + r * cell - cell // 2
frame[y : y + cell - 1, x : x + cell - 1] = (30, 35, 48)
# Center point highlighted
frame[cy - 5 : cy + 5, cx - 5 : cx + 5] = (255, 200, 50)
# Draw anchors around center: 3 sizes x 3 ratios = 9
anchor_specs = [
(30, 30, (200, 80, 80)), # small 1:1
(20, 40, (200, 60, 60)), # small 1:2
(40, 20, (180, 60, 60)), # small 2:1
(60, 60, (80, 200, 80)), # medium 1:1
(40, 80, (60, 180, 60)), # medium 1:2
(80, 40, (60, 160, 60)), # medium 2:1
(90, 90, (80, 80, 200)), # large 1:1
(60, 120, (60, 60, 180)), # large 1:2
(120, 60, (60, 60, 160)), # large 2:1
]
n_anchors = min(int(progress * 9) + 1, 9)
for i in range(n_anchors):
hw, hh, color = anchor_specs[i]
x1 = max(0, cx - hw)
y1 = max(0, cy - hh)
x2 = min(W - 1, cx + hw)
y2 = min(H - 1, cy + hh)
for tt in range(2):
frame[y1 - tt : y2 + tt, x1 - tt : x1 - tt + 2] = color
frame[y1 - tt : y2 + tt, x2 + tt - 2 : x2 + tt] = color
frame[y1 - tt : y1 - tt + 2, x1 - tt : x2 + tt] = color
frame[y2 + tt - 2 : y2 + tt, x1 - tt : x2 + tt] = color
return frame
anch_clip = VideoClip(make_anchors_frame, duration=STEP_DUR + 1).with_fps(FPS)
dur = STEP_DUR + 1
labels = [
("Anchor Boxes + RPN (Faster R-CNN)", 26, "#FFE082", FONT_B, (80, 20)),
(
"KROK 1: Anchory = predefiniowane kształty w każdej pozycji",
17,
"#A5D6A7",
FONT_R,
(80, 60),
),
(
"3 rozmiary x 3 proporcje = 9 anchorów per punkt",
16,
"#B0BEC5",
FONT_R,
(80, 90),
),
("Małe (1:1, 1:2, 2:1)", 14, "#EF9A9A", FONT_R, (750, 170)),
("Średnie (1:1, 1:2, 2:1)", 14, "#A5D6A7", FONT_R, (750, 210)),
("Duże (1:1, 1:2, 2:1)", 14, "#64B5F6", FONT_R, (750, 250)),
("Żółty punkt = pozycja", 14, "#FFE082", FONT_R, (750, 310)),
("na feature mapie", 14, "#FFE082", FONT_R, (750, 335)),
("Sieć NIE predykuje bbox od zera!", 16, "white", FONT_R, (80, 530)),
(
"Predykuje OFFSET od najbliższego anchora: (Δx, Δy, Δw, Δh)",
16,
"#FFE082",
FONT_R,
(80, 565),
),
(
"+ P(obiekt) = 'czy w tym anchorze jest coś?'",
16,
"#A5D6A7",
FONT_R,
(80, 600),
),
(
"Mnemonik: Anchor = KOTWICA — sieć dopasowuje bbox do kotwicy",
15,
"#78909C",
FONT_R,
(80, 645),
),
]
text_clips: list[VideoClip] = [anch_clip]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(dur)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
# Slide 2: RPN step by step
rpn_lines = [
(
"RPN: Region Proposal Network — krok po kroku",
24,
"#FFE082",
FONT_B,
(80, 30),
),
(
"Zastępuje Selective Search SIECIĄ NEURONOWĄ (end-to-end!)",
17,
"#B0BEC5",
FONT_R,
(80, 85),
),
("", 10, "white", FONT_R, (80, 110)),
(
"1. Backbone (ResNet) przetwarza obraz → feature mapa [40x60x256]",
16,
"#64B5F6",
FONT_R,
(100, 140),
),
(
"2. Filtr 3x3 przesuwa się po feature mapie",
16,
"#A5D6A7",
FONT_R,
(100, 180),
),
(
"3. W KAŻDEJ pozycji (x,y) rozważ k=9 anchorów:",
16,
"#FFE082",
FONT_R,
(100, 220),
),
(" → P(obiekt) — 'czy tu jest coś?'", 15, "white", FONT_R, (120, 255)),
(" → (Δx, Δy, Δw, Δh) — poprawka pozycji", 15, "white", FONT_R, (120, 285)),
(
"4. 40x60 pozycji x 9 anchorów = 21 600 kandydatów!",
16,
"#EF9A9A",
FONT_R,
(100, 325),
),
(
"5. Weź ~300 z najwyższym P(obiekt) → ROI Pool → FC",
16,
"#A5D6A7",
FONT_R,
(100, 365),
),
("", 10, "white", FONT_R, (100, 395)),
("Porównanie generowania propozycji:", 17, "white", FONT_B, (80, 420)),
(
" Selective Search: ~2000 regionów, osobny algorytm, ~2 sec",
15,
"#EF9A9A",
FONT_R,
(100, 460),
),
(
" RPN: ~300 regionów, W SIECI, ~10 ms → 200x szybciej!",
15,
"#A5D6A7",
FONT_R,
(100, 495),
),
("", 10, "white", FONT_R, (100, 520)),
(
"Faster R-CNN = Backbone + RPN + ROI Pool + FC — WSZYSTKO end-to-end",
17,
"#FFE082",
FONT_R,
(80, 545),
),
(
"→ 5 fps (0.2 sec/obraz) vs R-CNN 50 sec = 250x szybciej!",
17,
"#A5D6A7",
FONT_R,
(80, 585),
),
(
"Wciąż two-stage: (1) RPN generuje propozycje, (2) FC klasyfikuje",
15,
"#78909C",
FONT_R,
(80, 630),
),
]
slides.append(_text_slide(rpn_lines, duration=STEP_DUR + 1))
return slides
# ── YOLO ──────────────────────────────────────────────────────────
def _yolo_demo() -> list[CompositeVideoClip]:
"""Animate YOLO grid detection concept."""
slides = []
def make_yolo_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
# Draw image with grid overlay
img_x, img_y = 100, 140
img_size = 420
grid_n = 7
# Background "image"
frame[img_y : img_y + img_size, img_x : img_x + img_size] = (50, 55, 70)
# Objects in the image
frame[img_y + 80 : img_y + 200, img_x + 50 : img_x + 180] = (
180,
60,
60,
) # "car"
frame[img_y + 150 : img_y + 350, img_x + 250 : img_x + 330] = (
60,
120,
180,
) # "person"
# Grid lines
cell = img_size // grid_n
for i in range(grid_n + 1):
# Vertical
x = img_x + i * cell
frame[img_y : img_y + img_size, x : x + 1] = (100, 100, 120)
# Horizontal
y = img_y + i * cell
frame[y : y + 1, img_x : img_x + img_size] = (100, 100, 120)
# Highlight cells containing object centers
car_phase = 0.3
if progress > car_phase:
# Car center ~ cell (1, 1)
cx, cy = 1, 2
hx = img_x + cx * cell
hy = img_y + cy * cell
frame[hy : hy + cell, hx : hx + cell] = np.clip(
frame[hy : hy + cell, hx : hx + cell].astype(int) + 40, 0, 255
).astype(np.uint8)
person_phase = 0.5
if progress > person_phase:
# Person center ~ cell (4, 4)
cx, cy = 4, 4
hx = img_x + cx * cell
hy = img_y + cy * cell
frame[hy : hy + cell, hx : hx + cell] = np.clip(
frame[hy : hy + cell, hx : hx + cell].astype(int) + 40, 0, 255
).astype(np.uint8)
# Bounding boxes predictions from cells
bbox_phase = 0.6
if progress > bbox_phase:
# Car bbox
for tt in range(2):
frame[
img_y + 78 - tt : img_y + 202 + tt,
img_x + 48 - tt : img_x + 48 - tt + 2,
] = (255, 80, 80)
frame[
img_y + 78 - tt : img_y + 202 + tt,
img_x + 182 + tt - 2 : img_x + 182 + tt,
] = (255, 80, 80)
frame[
img_y + 78 - tt : img_y + 78 - tt + 2,
img_x + 48 - tt : img_x + 182 + tt,
] = (255, 80, 80)
frame[
img_y + 202 + tt - 2 : img_y + 202 + tt,
img_x + 48 - tt : img_x + 182 + tt,
] = (255, 80, 80)
# Person bbox
for tt in range(2):
frame[
img_y + 148 - tt : img_y + 352 + tt,
img_x + 248 - tt : img_x + 248 - tt + 2,
] = (80, 180, 255)
frame[
img_y + 148 - tt : img_y + 352 + tt,
img_x + 332 + tt - 2 : img_x + 332 + tt,
] = (80, 180, 255)
frame[
img_y + 148 - tt : img_y + 148 - tt + 2,
img_x + 248 - tt : img_x + 332 + tt,
] = (80, 180, 255)
frame[
img_y + 352 + tt - 2 : img_y + 352 + tt,
img_x + 248 - tt : img_x + 332 + tt,
] = (80, 180, 255)
return frame
yolo_clip = VideoClip(make_yolo_frame, duration=STEP_DUR).with_fps(FPS)
text_clips: list[VideoClip] = [yolo_clip]
labels = [
("YOLO — You Only Look Once", 28, "#FFE082", FONT_B, (80, 20)),
(
"Jednoetapowy detektor: siatka SxS → wszystkie detekcje naraz!",
18,
"#B0BEC5",
FONT_R,
(80, 65),
),
("Siatka 7x7 = 49 komórek", 16, "#64B5F6", FONT_R, (600, 180)),
("Każda komórka predykuje:", 16, "white", FONT_R, (600, 220)),
(" • B bbox (x, y, w, h, conf)", 14, "#B0BEC5", FONT_R, (600, 255)),
(" • C klas (prawdopodobieństwa)", 14, "#B0BEC5", FONT_R, (600, 285)),
("Komórka odpowiada za obiekt", 14, "#A5D6A7", FONT_R, (600, 325)),
("którego ŚRODEK w niej wpada", 14, "#A5D6A7", FONT_R, (600, 350)),
("45-155 fps! (vs 5 fps Faster R-CNN)", 18, "#EF9A9A", FONT_B, (600, 400)),
(
"Jedno przejście przez sieć → WSZYSTKIE detekcje naraz → NMS → wynik",
14,
"#78909C",
FONT_R,
(80, 620),
),
(
"Two-stage (R-CNN): propozycje+klasyfikacja "
"| One-stage (YOLO): bez propozycji!",
14,
"#90CAF9",
FONT_R,
(80, 655),
),
]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(STEP_DUR)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides

View File

@ -0,0 +1,459 @@
"""YOLO architecture detail and DETR transformer detection."""
from __future__ import annotations
from _q24_common import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_tc,
_text_slide,
)
from moviepy import CompositeVideoClip, VideoClip
from moviepy.video.fx import FadeIn, FadeOut
import numpy as np
# ── YOLO Architecture Detail ──────────────────────────────────────
def _yolo_architecture() -> list[CompositeVideoClip]:
"""Show YOLO architecture: backbone → head, output tensor."""
slides = []
# Slide 1: YOLO architecture breakdown
def make_yolo_arch(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
# Pipeline: Image → Backbone → Neck → Head → SxSx(B*5+C) tensor
blocks = [
((60, 280), (100, 80), (50, 70, 90), "Obraz"),
((200, 280), (100, 80), (70, 130, 200), "Backbone"),
((340, 280), (100, 80), (200, 160, 80), "Neck"),
((480, 280), (100, 80), (200, 100, 60), "Head"),
((620, 280), (160, 80), (80, 200, 120), "SxSx(B*5+C)"),
]
n_blocks = min(int(progress * 5) + 1, 5)
for i, ((bx, by), (bw, bh), color, _lbl) in enumerate(blocks):
if i < n_blocks:
frame[by : by + bh, bx : bx + bw] = color
frame[by : by + 2, bx : bx + bw] = tuple(
min(c + 50, 255) for c in color
)
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
min(c + 50, 255) for c in color
)
arrow_limit = 4
if i < arrow_limit:
ax = bx + bw + 5
ay = by + bh // 2
frame[ay - 1 : ay + 2, ax : ax + 25] = (150, 150, 170)
# Output tensor breakdown (right side)
tensor_phase = 0.6
if progress > tensor_phase:
# Show SxS grid
gx, gy = 850, 180
gs = 120
gn = 4 # simplified from 7
gc = gs // gn
for r in range(gn):
for c in range(gn):
x = gx + c * gc
y = gy + r * gc
frame[y : y + gc - 1, x : x + gc - 1] = (40, 50, 65)
# Highlight one cell
frame[gy + gc : gy + 2 * gc - 1, gx + gc : gx + 2 * gc - 1] = (
80,
200,
120,
)
return frame
arch_clip = VideoClip(make_yolo_arch, duration=STEP_DUR + 1).with_fps(FPS)
dur = STEP_DUR + 1
labels = [
("YOLO: Architektura — krok po kroku", 26, "#FFE082", FONT_B, (80, 20)),
(
"One-stage: JEDEN forward pass → WSZYSTKIE detekcje naraz",
17,
"#B0BEC5",
FONT_R,
(80, 60),
),
("Obraz", 13, "white", FONT_R, (85, 295)),
("Backbone", 13, "white", FONT_R, (215, 295)),
("(ResNet/", 11, "#78909C", FONT_R, (210, 370)),
("Darknet)", 11, "#78909C", FONT_R, (210, 390)),
("Neck", 13, "white", FONT_R, (365, 295)),
("(FPN/", 11, "#78909C", FONT_R, (360, 370)),
("PANet)", 11, "#78909C", FONT_R, (360, 390)),
("Head", 13, "white", FONT_R, (505, 295)),
("(conv)", 11, "#78909C", FONT_R, (500, 370)),
("Tensor wyjścia", 13, "#A5D6A7", FONT_R, (640, 295)),
("Każda komórka SxS predykuje:", 15, "#FFE082", FONT_R, (830, 320)),
(" B bbox x (x,y,w,h,conf)", 13, "#B0BEC5", FONT_R, (830, 350)),
(" + C klas (prob.)", 13, "#B0BEC5", FONT_R, (830, 375)),
("= SxSx(Bx5+C) tensor", 13, "#A5D6A7", FONT_R, (830, 400)),
("Np. 7x7x(2x5+20) = 7x7x30", 13, "#78909C", FONT_R, (830, 430)),
(
"Two-stage (R-CNN): (1) propozycje → (2) klasyfikacja = 2 przejścia",
15,
"#EF9A9A",
FONT_R,
(80, 470),
),
(
"One-stage (YOLO): siatka → predykcja all-in-one = 1 przejście!",
15,
"#A5D6A7",
FONT_R,
(80, 505),
),
(
"Ewolucja YOLO: v1(2016)→v3→v5→v8(2023, anchor-free, SOTA)",
16,
"#FFE082",
FONT_R,
(80, 555),
),
(
"SSD (2016): multi-scale feature maps → lepsza detekcja małych obiektów",
15,
"#64B5F6",
FONT_R,
(80, 595),
),
(
"FPN: łączy wczesne warstwy (małe obiekty) + późne (duże obiekty)",
15,
"#78909C",
FONT_R,
(80, 630),
),
]
text_clips: list[VideoClip] = [arch_clip]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(dur)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
return slides
# ── DETR ──────────────────────────────────────────────────────────
def _detr_demo() -> list[CompositeVideoClip]:
"""Animate DETR: transformer detection, object queries, no NMS."""
slides = []
# Slide 1: DETR pipeline
def make_detr_frame(t: float) -> np.ndarray:
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
# DETR pipeline: Image → Backbone → Encoder → Decoder → N predictions
blocks = [
((50, 260), (80, 60), (50, 70, 90)),
((170, 260), (90, 60), (70, 130, 200)),
((300, 260), (110, 60), (200, 120, 60)),
((450, 260), (110, 60), (200, 80, 160)),
((600, 260), (120, 60), (80, 200, 120)),
]
n_blocks = min(int(progress * 5) + 1, 5)
for i, ((bx, by), (bw, bh), color) in enumerate(blocks):
if i < n_blocks:
frame[by : by + bh, bx : bx + bw] = color
frame[by : by + 2, bx : bx + bw] = tuple(
min(c + 50, 255) for c in color
)
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
min(c + 50, 255) for c in color
)
arrow_limit = 4
if i < arrow_limit:
ax = bx + bw + 5
ay = by + bh // 2
frame[ay - 1 : ay + 2, ax : ax + 25] = (150, 150, 170)
# Object queries illustration (right side)
query_phase = 0.5
if progress > query_phase:
qx, qy = 800, 140
for i in range(6):
y = qy + i * 50
w = 130
active_limit = 3
active = i < active_limit
color = (80, 180, 120) if active else (60, 50, 50)
frame[y : y + 35, qx : qx + w] = color
frame[y : y + 1, qx : qx + w] = tuple(min(c + 40, 255) for c in color)
# Arrow from decoder to queries
frame[285:288, 723:798] = (150, 150, 170)
return frame
detr_clip = VideoClip(make_detr_frame, duration=STEP_DUR + 1).with_fps(FPS)
dur = STEP_DUR + 1
labels = [
("DETR: DEtection TRansformer (2020)", 26, "#FFE082", FONT_B, (80, 20)),
(
"Radykalnie prostszy pipeline: BEZ anchorów, BEZ NMS!",
17,
"#B0BEC5",
FONT_R,
(80, 60),
),
("Obraz", 12, "white", FONT_R, (65, 275)),
("Backbone", 12, "white", FONT_R, (185, 275)),
("Transformer", 12, "white", FONT_R, (310, 275)),
("Encoder", 12, "white", FONT_R, (325, 295)),
("Transformer", 12, "white", FONT_R, (460, 275)),
("Decoder", 12, "white", FONT_R, (478, 295)),
("N predykcji", 12, "white", FONT_R, (615, 275)),
("Object Queries:", 14, "#FFE082", FONT_B, (800, 115)),
("samochód 95%", 11, "white", FONT_R, (810, 148)),
("pies 88%", 11, "white", FONT_R, (810, 198)),
("rower 72%", 11, "white", FONT_R, (810, 248)),
("brak", 11, "#78909C", FONT_R, (810, 298)),
("brak", 11, "#78909C", FONT_R, (810, 348)),
("brak", 11, "#78909C", FONT_R, (810, 398)),
("100 wyuczonych queries", 13, "#FFE082", FONT_R, (800, 440)),
("→ każdy 'szuka' obiektu", 13, "#FFE082", FONT_R, (800, 465)),
]
text_clips: list[VideoClip] = [detr_clip]
for text, fs, color, font, pos in labels:
tc = (
_tc(text=text, font_size=fs, color=color, font=font)
.with_duration(dur)
.with_position(pos)
)
text_clips.append(tc)
slides.append(
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
[FadeIn(0.3), FadeOut(0.3)]
)
)
# Slide 2: Why no NMS + Hungarian matching
detr_details = [
("DETR: Dlaczego bez NMS? — krok po kroku", 24, "#FFE082", FONT_B, (80, 30)),
(
"Problem NMS: duplikaty detekcji → ręcznie usuwaj post-hoc",
16,
"#EF9A9A",
FONT_R,
(80, 90),
),
(
"DETR rozwiązanie: Hungarian matching (dopasowanie węgierskie)",
17,
"#A5D6A7",
FONT_R,
(80, 130),
),
("", 10, "white", FONT_R, (80, 155)),
("Jak to działa podczas TRENINGU:", 17, "white", FONT_B, (80, 180)),
(
" 1. Sieć daje N=100 predykcji (queries)",
15,
"#64B5F6",
FONT_R,
(100, 220),
),
(
" 2. Na obrazie jest np. 5 obiektów (ground truth)",
15,
"#64B5F6",
FONT_R,
(100, 255),
),
(
" 3. Hungarian matching: optymalne dopasowanie 1:1",
15,
"#FFE082",
FONT_R,
(100, 290),
),
(
" → query_1 ↔ gt_samochód (najlepsze dopasowanie)",
14,
"#A5D6A7",
FONT_R,
(120, 325),
),
(" → query_7 ↔ gt_pies", 14, "#A5D6A7", FONT_R, (120, 355)),
(" → query_3 ↔ gt_rower", 14, "#A5D6A7", FONT_R, (120, 385)),
(
" → pozostałe 97 queries ↔ klasa 'brak obiektu'",
14,
"#78909C",
FONT_R,
(120, 415),
),
(
" 4. Każdy obiekt ma DOKŁADNIE 1 predykcję → BRAK duplikatów!",
15,
"#A5D6A7",
FONT_R,
(100, 455),
),
("", 10, "white", FONT_R, (100, 475)),
(
"Self-attention w encoderze: cechy obrazu 'rozmawiają' ze sobą",
15,
"#64B5F6",
FONT_R,
(80, 500),
),
(
"Cross-attention w decoderze: queries 'pytają' cechy obrazu",
15,
"#CE93D8",
FONT_R,
(80, 535),
),
(
"→ query 'rozumie' który fragment obrazu to 'jego' obiekt",
15,
"#FFE082",
FONT_R,
(80, 570),
),
(
"DETR = Detekcja Eliminująca Trikowe Redundancje (NMS, anchory)",
16,
"#FFE082",
FONT_R,
(80, 620),
),
(
"Wada: wolniejszy trening (O(n²) attention) | Zaleta: prostszy pipeline!",
15,
"#78909C",
FONT_R,
(80, 660),
),
]
slides.append(_text_slide(detr_details, duration=STEP_DUR + 1))
# Slide 3: Two-stage vs One-stage vs Transformer summary
summary_lines = [
(
"Podsumowanie: Two-stage vs One-stage vs Transformer",
22,
"#FFE082",
FONT_B,
(80, 30),
),
("", 10, "white", FONT_R, (80, 55)),
("TWO-STAGE (R-CNN family):", 18, "#EF9A9A", FONT_B, (80, 90)),
(
" (1) Generuj propozycje → (2) Klasyfikuj per region",
15,
"white",
FONT_R,
(100, 125),
),
(
" + Wysoka precyzja | - Wolniejsze (2 przejścia)",
15,
"#78909C",
FONT_R,
(100, 155),
),
(
" R-CNN → Fast R-CNN → Faster R-CNN (0.2s)",
15,
"#B0BEC5",
FONT_R,
(100, 185),
),
("", 10, "white", FONT_R, (80, 210)),
("ONE-STAGE (YOLO, SSD):", 18, "#A5D6A7", FONT_B, (80, 240)),
(
" Siatka → predykcja all-in-one (1 przejście)",
15,
"white",
FONT_R,
(100, 275),
),
(
" + Bardzo szybkie (45-155 fps) | - Historycznie mniej precyzyjne",
15,
"#78909C",
FONT_R,
(100, 305),
),
(
" YOLOv8 (2023): anchor-free, dorównuje two-stage!",
15,
"#B0BEC5",
FONT_R,
(100, 335),
),
("", 10, "white", FONT_R, (80, 360)),
("TRANSFORMER (DETR):", 18, "#CE93D8", FONT_B, (80, 390)),
(
" Object queries + self-attention (globalny kontekst)",
15,
"white",
FONT_R,
(100, 425),
),
(
" + Brak NMS/anchorów | - Wolniejszy trening (O(n²))",
15,
"#78909C",
FONT_R,
(100, 455),
),
(
" Hungarian matching → 1:1 obiekt↔predykcja → brak duplikatów",
15,
"#B0BEC5",
FONT_R,
(100, 485),
),
("", 10, "white", FONT_R, (80, 510)),
(
"Trend: coraz prostsze pipeline, mniej ręcznych komponentów",
17,
"white",
FONT_R,
(80, 540),
),
(
" R-CNN (SS+CNN+SVM+NMS) → YOLO "
"(backbone+head+NMS) → DETR (backbone+transformer)",
14,
"#90CAF9",
FONT_R,
(80, 580),
),
(
"Metryki: mAP@0.5 (standard), mAP@0.5:0.95 (surowsza), "
"IoU do dopasowania",
15,
"#78909C",
FONT_R,
(80, 630),
),
]
slides.append(_text_slide(summary_lines, duration=STEP_DUR + 1))
return slides

View File

@ -0,0 +1,235 @@
"""Common drawing primitives and constants for Pub/Sub diagrams."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib as mpl
mpl.use("Agg")
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
import matplotlib.pyplot as plt
if TYPE_CHECKING:
from matplotlib.axes import Axes
DPI = 300
BG = "white"
LN = "black"
FS = 9
FS_TITLE = 13
FIG_W = 8.27 # A4 width in inches
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
GRAY1 = "#E8E8E8"
GRAY2 = "#D0D0D0"
GRAY3 = "#B8B8B8"
GRAY4 = "#F5F5F5"
GRAY5 = "#C0C0C0"
@dataclass(frozen=True)
class BoxStyle:
"""Optional styling for boxes."""
fill: str = "white"
lw: float = 1.2
fontsize: float = FS
fontweight: str = "normal"
ha: str = "center"
va: str = "center"
rounded: bool = True
@dataclass(frozen=True)
class ArrowCfg:
"""Config for arrows."""
lw: float = 1.2
style: str = "->"
color: str = LN
label: str = ""
label_offset: float = 0.15
label_fs: float = 8
@dataclass(frozen=True)
class DashedCfg:
"""Config for dashed arrows."""
lw: float = 1.0
color: str = LN
label: str = ""
label_offset: float = 0.15
label_fs: float = 8
def draw_box(
ax: Axes,
pos: tuple[float, float],
size: tuple[float, float],
text: str,
style: BoxStyle | None = None,
) -> None:
"""Draw box."""
s = style or BoxStyle()
x, y = pos
w, h = size
if s.rounded:
rect = FancyBboxPatch(
(x, y),
w,
h,
boxstyle="round,pad=0.05",
lw=s.lw,
edgecolor=LN,
facecolor=s.fill,
)
else:
rect = mpatches.Rectangle(
(x, y),
w,
h,
lw=s.lw,
edgecolor=LN,
facecolor=s.fill,
)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + h / 2,
text,
ha=s.ha,
va=s.va,
fontsize=s.fontsize,
fontweight=s.fontweight,
wrap=True,
)
def draw_arrow(
ax: Axes,
start: tuple[float, float],
end: tuple[float, float],
cfg: ArrowCfg | None = None,
) -> None:
"""Draw arrow."""
c = cfg or ArrowCfg()
ax.annotate(
"",
xy=end,
xytext=start,
arrowprops={
"arrowstyle": c.style,
"color": c.color,
"lw": c.lw,
},
)
if c.label:
mx = (start[0] + end[0]) / 2
my = (start[1] + end[1]) / 2 + c.label_offset
ax.text(
mx,
my,
c.label,
ha="center",
va="bottom",
fontsize=c.label_fs,
color=c.color,
)
def draw_dashed_arrow(
ax: Axes,
start: tuple[float, float],
end: tuple[float, float],
cfg: DashedCfg | None = None,
) -> None:
"""Draw dashed arrow."""
c = cfg or DashedCfg()
ax.annotate(
"",
xy=end,
xytext=start,
arrowprops={
"arrowstyle": "->",
"color": c.color,
"lw": c.lw,
"linestyle": "dashed",
},
)
if c.label:
mx = (start[0] + end[0]) / 2
my = (start[1] + end[1]) / 2 + c.label_offset
ax.text(
mx,
my,
c.label,
ha="center",
va="bottom",
fontsize=c.label_fs,
color=c.color,
)
def draw_cross(
ax: Axes,
pos: tuple[float, float],
size: float = 0.15,
lw: float = 2.5,
color: str = "black",
) -> None:
"""Draw cross."""
x, y = pos
ax.plot(
[x - size, x + size],
[y - size, y + size],
color=color,
lw=lw,
)
ax.plot(
[x - size, x + size],
[y + size, y - size],
color=color,
lw=lw,
)
def draw_check(
ax: Axes,
pos: tuple[float, float],
size: float = 0.15,
lw: float = 2.5,
color: str = "black",
) -> None:
"""Draw check."""
x, y = pos
ax.plot(
[x - size, x - size * 0.2],
[y, y - size * 0.7],
color=color,
lw=lw,
)
ax.plot(
[x - size * 0.2, x + size],
[y - size * 0.7, y + size * 0.5],
color=color,
lw=lw,
)
def save(fig: plt.Figure, name: str) -> None:
"""Save."""
plt.tight_layout()
fig.savefig(
str(Path(OUTPUT_DIR) / name),
dpi=DPI,
bbox_inches="tight",
facecolor=BG,
)
plt.close(fig)

View File

@ -0,0 +1,430 @@
"""QoS delivery guarantee diagrams: at-most-once, at-least-once, exactly-once."""
from __future__ import annotations
from _pubsub_common import (
FIG_W,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
LN,
ArrowCfg,
BoxStyle,
draw_arrow,
draw_box,
draw_check,
draw_cross,
draw_dashed_arrow,
save,
)
import matplotlib.pyplot as plt
# ============================================================
# 5. At-most-once (QoS 0)
# ============================================================
def draw_qos_at_most_once() -> None:
"""Draw qos at most once."""
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 6)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"QoS: At-most-once"
" \u2014 \u201ewy\u015blij i zapomnij\u201d"
" (0 lub 1 dostarczenie)",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
px, bx, sx = 1.0, 4.8, 8.5
pw, bw, sw = 2.0, 2.2, 2.0
bh = 0.8
bold10_g1 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold")
bold10_g2 = BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold")
draw_box(ax, (px, 5.0), (pw, bh), "Publisher", bold10_g1)
draw_box(ax, (bx, 5.0), (bw, bh), "Broker", bold10_g2)
draw_box(
ax,
(sx, 5.0),
(sw, bh),
"Subscriber",
bold10_g1,
)
for xc in [px + pw / 2, bx + bw / 2, sx + sw / 2]:
ax.plot(
[xc, xc],
[5.0, 1.2],
color=GRAY3,
lw=1,
linestyle=":",
)
# Scenario A: success
y = 4.3
ax.text(
0.2,
y + 0.15,
"Scenariusz A:",
fontsize=8.5,
fontweight="bold",
)
msg9 = ArrowCfg(label="MSG", label_fs=9)
draw_arrow(ax, (px + pw / 2, y), (bx + bw / 2, y), msg9)
draw_arrow(
ax,
(bx + bw / 2, y - 0.6),
(sx + sw / 2, y - 0.6),
msg9,
)
draw_check(ax, (sx + sw / 2 + 0.4, y - 0.6), size=0.18)
ax.text(
sx + sw / 2 + 0.7,
y - 0.6,
"OK",
fontsize=9,
fontweight="bold",
)
# Scenario B: lost
y = 2.6
ax.text(
0.2,
y + 0.15,
"Scenariusz B:",
fontsize=8.5,
fontweight="bold",
)
draw_arrow(ax, (px + pw / 2, y), (bx + bw / 2, y), msg9)
draw_dashed_arrow(ax, (bx + bw / 2, y - 0.6), (7.5, y - 0.6))
draw_cross(ax, (7.8, y - 0.6), size=0.2)
ax.text(
8.2,
y - 0.55,
"UTRACONA",
fontsize=9,
fontweight="bold",
)
ax.text(
8.2,
y - 1.0,
"(brak retransmisji)",
fontsize=8,
style="italic",
)
ax.text(
6.0,
0.5,
"Brak ACK, brak retransmisji."
" Najszybszy. Use case:"
" logi, metryki, telemetria.",
ha="center",
va="center",
fontsize=9,
bbox={
"boxstyle": "round,pad=0.4",
"facecolor": GRAY4,
"edgecolor": GRAY3,
},
)
save(fig, "pubsub_qos_at_most_once.png")
# ============================================================
# 6. At-least-once (QoS 1)
# ============================================================
def draw_qos_at_least_once() -> None:
"""Draw qos at least once."""
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.0))
ax.set_xlim(0, 12)
ax.set_ylim(0, 6.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"QoS: At-least-once"
" \u2014 \u201epowtarzaj a\u017c potwierdz\u0105\u201d"
" (\u22651 dostarczenie)",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
bx, bw = 3.5, 2.2
sx, sw = 8.0, 2.2
bh = 0.8
draw_box(
ax,
(bx, 5.5),
(bw, bh),
"Broker",
BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"),
)
draw_box(
ax,
(sx, 5.5),
(sw, bh),
"Subscriber",
BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold"),
)
for xc in [bx + bw / 2, sx + sw / 2]:
ax.plot(
[xc, xc],
[5.5, 0.8],
color=GRAY3,
lw=1,
linestyle=":",
)
# Step 1: send MSG
y1 = 4.8
draw_arrow(
ax,
(bx + bw / 2, y1),
(sx + sw / 2, y1),
ArrowCfg(label="MSG #1", label_fs=9),
)
draw_check(ax, (sx + sw + 0.2, y1), size=0.15)
ax.text(sx + sw + 0.5, y1, "odebrano", fontsize=8)
# Step 2: ACK lost
y2 = 3.9
draw_dashed_arrow(
ax,
(sx + sw / 2, y2),
(bx + bw + 1.2, y2),
)
ax.text(
(bx + bw / 2 + sx + sw / 2) / 2,
y2 + 0.18,
"ACK",
fontsize=9,
)
draw_cross(ax, (bx + bw + 0.8, y2), size=0.18)
ax.text(
bx + 0.3,
y2 - 0.35,
"ACK utracony!",
fontsize=8.5,
style="italic",
)
# Step 3: timeout -> retry
y3 = 2.9
ax.text(
bx + bw / 2,
y3 + 0.45,
"timeout...",
fontsize=8.5,
style="italic",
ha="center",
)
draw_arrow(
ax,
(bx + bw / 2, y3),
(sx + sw / 2, y3),
ArrowCfg(label="MSG #1 (retry)", label_fs=9),
)
draw_check(ax, (sx + sw + 0.2, y3), size=0.15)
ax.text(
sx + sw + 0.5,
y3,
"odebrano\n(ponownie!)",
fontsize=8,
)
# Step 4: ACK ok
y4 = 2.0
draw_arrow(
ax,
(sx + sw / 2, y4),
(bx + bw / 2, y4),
ArrowCfg(label="ACK", label_fs=9),
)
draw_check(ax, (bx + bw / 2 - 0.5, y4), size=0.18)
# Duplicate bracket
ax.annotate(
"",
xy=(sx + sw + 1.3, y1),
xytext=(sx + sw + 1.3, y3),
arrowprops={
"arrowstyle": "<->",
"color": "black",
"lw": 1.2,
},
)
ax.text(
sx + sw + 1.6,
(y1 + y3) / 2,
"DUPLIKAT!\nSubscriber\notrzyma\u0142 2x",
fontsize=9,
ha="left",
va="center",
fontweight="bold",
bbox={
"boxstyle": "round,pad=0.25",
"facecolor": GRAY4,
"edgecolor": GRAY3,
},
)
ax.text(
6.0,
0.5,
"Broker czeka na ACK, retransmituje"
" po timeout. Mog\u0105 by\u0107 duplikaty!\n"
"Use case: zam\u00f3wienia, p\u0142atno\u015bci"
" (subscriber musi by\u0107 idempotentny).",
ha="center",
va="center",
fontsize=9,
bbox={
"boxstyle": "round,pad=0.4",
"facecolor": GRAY4,
"edgecolor": GRAY3,
},
)
save(fig, "pubsub_qos_at_least_once.png")
# ============================================================
# 7. Exactly-once (QoS 2)
# ============================================================
def draw_qos_exactly_once() -> None:
"""Draw qos exactly once."""
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 7)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"QoS: Exactly-once \u2014 4-krokowy"
" handshake (dok\u0142adnie 1 dostarczenie)",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
bx, bw = 2.5, 2.2
sx, sw = 7.5, 2.2
bh = 0.8
draw_box(
ax,
(bx, 6.0),
(bw, bh),
"Broker",
BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"),
)
draw_box(
ax,
(sx, 6.0),
(sw, bh),
"Subscriber",
BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold"),
)
for xc in [bx + bw / 2, sx + sw / 2]:
ax.plot(
[xc, xc],
[6.0, 1.0],
color=GRAY3,
lw=1,
linestyle=":",
)
steps = [
(
5.2,
"right",
"PUBLISH (msg_id=42)",
"Broker wysy\u0142a wiadomo\u015b\u0107",
),
(
4.2,
"left",
"PUBREC (otrzyma\u0142em id=42)",
"Sub potwierdza odbi\u00f3r," " zapisuje id",
),
(
3.2,
"right",
"PUBREL (mo\u017cesz przetworzy\u0107)",
"Broker zwalnia wiadomo\u015b\u0107",
),
(
2.2,
"left",
"PUBCOMP (zako\u0144czone)",
"Sub potwierdza przetworzenie",
),
]
for i, (y, direction, label, desc) in enumerate(steps):
ax.text(
bx + bw / 2 - 0.7,
y,
f"{i + 1}",
fontsize=9,
fontweight="bold",
ha="center",
va="center",
bbox={
"boxstyle": "circle,pad=0.18",
"facecolor": GRAY3,
"edgecolor": LN,
},
)
if direction == "right":
draw_arrow(
ax,
(bx + bw / 2, y),
(sx + sw / 2, y),
ArrowCfg(label=label, label_fs=9),
)
else:
draw_arrow(
ax,
(sx + sw / 2, y),
(bx + bw / 2, y),
ArrowCfg(label=label, label_fs=9),
)
ax.text(
sx + sw + 0.3,
y,
desc,
fontsize=8,
ha="left",
va="center",
style="italic",
)
ax.text(
6.0,
0.6,
"Deduplikacja po msg_id."
" Sub nie przetwarza przed PUBREL.\n"
"Najkosztowniejszy (4 pakiety)."
" Use case: transakcje finansowe,"
" krytyczne zdarzenia.",
ha="center",
va="center",
fontsize=9,
bbox={
"boxstyle": "round,pad=0.4",
"facecolor": GRAY4,
"edgecolor": GRAY3,
},
)
save(fig, "pubsub_qos_exactly_once.png")

View File

@ -0,0 +1,239 @@
"""Subscription-type diagrams: topic-based and content-based."""
from __future__ import annotations
from _pubsub_common import (
FIG_W,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
ArrowCfg,
BoxStyle,
DashedCfg,
draw_arrow,
draw_box,
draw_dashed_arrow,
save,
)
import matplotlib.pyplot as plt
# ============================================================
# 1. Topic-based subscription
# ============================================================
def draw_sub_topic() -> None:
"""Draw sub topic."""
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.0))
ax.set_xlim(0, 12)
ax.set_ylim(0, 5.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Subskrypcja topic-based" " \u2014 routing po nazwie tematu",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
bold10 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold")
fs85 = BoxStyle(fill=GRAY1, fontsize=8.5)
draw_box(ax, (0.2, 3.2), (2.4, 1.1), "Publisher", bold10)
draw_box(
ax,
(0.3, 1.8),
(2.2, 0.8),
'topic: "orders"',
BoxStyle(fill=GRAY4, fontsize=8),
)
draw_box(
ax,
(0.3, 0.7),
(2.2, 0.8),
'topic: "payments"',
BoxStyle(fill=GRAY4, fontsize=8),
)
draw_box(
ax,
(4.2, 1.5),
(2.8, 2.2),
"BROKER\n\ntopic routing",
BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"),
)
draw_box(
ax,
(8.5, 3.8),
(3.0, 1.0),
'Subscriber A\nsubskrybuje: "orders"',
fs85,
)
draw_box(
ax,
(8.5, 2.2),
(3.0, 1.0),
'Subscriber B\nsubskrybuje: "payments"',
fs85,
)
draw_box(
ax,
(8.5, 0.6),
(3.0, 1.0),
'Subscriber C\nsubskrybuje: "orders"',
fs85,
)
fs8 = ArrowCfg(label_fs=8)
draw_arrow(ax, (2.6, 2.2), (4.2, 2.8), fs8)
draw_arrow(ax, (2.6, 1.1), (4.2, 2.2), fs8)
draw_arrow(
ax,
(7.0, 3.4),
(8.5, 4.2),
ArrowCfg(label='"orders"', label_fs=8),
)
draw_arrow(
ax,
(7.0, 2.6),
(8.5, 2.7),
ArrowCfg(label='"payments"', label_fs=8),
)
draw_arrow(
ax,
(7.0, 2.2),
(8.5, 1.2),
ArrowCfg(label='"orders"', label_fs=8),
)
ax.text(
6.0,
0.1,
"Subscriber deklaruje nazw\u0119 tematu."
" Broker kieruje wiadomo\u015bci\n"
"do WSZYSTKICH subscriber\u00f3w"
" danego tematu. Najprostszy model.",
ha="center",
va="bottom",
fontsize=8.5,
style="italic",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": GRAY4,
"edgecolor": GRAY3,
},
)
save(fig, "pubsub_sub_topic.png")
# ============================================================
# 2. Content-based subscription
# ============================================================
def draw_sub_content() -> None:
"""Draw sub content."""
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 6)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Subskrypcja content-based"
" \u2014 filtrowanie po tre\u015bci wiadomo\u015bci",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
bold10 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold")
draw_box(ax, (0.2, 3.5), (2.4, 1.1), "Publisher", bold10)
draw_box(
ax,
(0.2, 1.8),
(2.4, 1.2),
'price: 150\ntype: "book"\ncategory: "IT"',
BoxStyle(fill=GRAY4, fontsize=8.5),
)
draw_box(
ax,
(4.0, 2.0),
(3.0, 2.5),
"BROKER\n\newaluuje filtry\n" "ka\u017cdego subscribera",
BoxStyle(fill=GRAY2, fontsize=9, fontweight="bold"),
)
fs9 = BoxStyle(fill=GRAY1, fontsize=9)
draw_box(
ax,
(8.5, 4.2),
(3.2, 1.0),
"Sub A\nfiltr: price > 100",
fs9,
)
draw_box(
ax,
(8.5, 2.6),
(3.2, 1.0),
'Sub B\nfiltr: type = "food"',
fs9,
)
draw_box(
ax,
(8.5, 1.0),
(3.2, 1.0),
"Sub C\nfiltr: price < 50",
fs9,
)
draw_arrow(ax, (2.6, 2.4), (4.0, 3.0))
draw_arrow(
ax,
(7.0, 4.0),
(8.5, 4.6),
ArrowCfg(
label="150 > 100 \u2713 dostarczono",
label_fs=8,
),
)
draw_dashed_arrow(
ax,
(7.0, 3.2),
(8.5, 3.1),
DashedCfg(
label='"book" \u2260 "food"' " \u2717 odrzucono",
label_fs=8,
),
)
draw_dashed_arrow(
ax,
(7.0, 2.5),
(8.5, 1.6),
DashedCfg(
label="150 < 50 \u2717 odrzucono",
label_fs=8,
),
)
ax.text(
6.0,
0.2,
"Broker analizuje TRE\u015a\u0106 wiadomo\u015bci"
" i ewaluuje predykaty.\n"
"Bardziej elastyczny ni\u017c topic-based,"
" ale wolniejszy (koszt ewaluacji).",
ha="center",
va="bottom",
fontsize=8.5,
style="italic",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": GRAY4,
"edgecolor": GRAY3,
},
)
save(fig, "pubsub_sub_content.png")

View File

@ -0,0 +1,279 @@
"""Subscription-type diagrams: type-based and hierarchical."""
from __future__ import annotations
from _pubsub_common import (
FIG_W,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
ArrowCfg,
BoxStyle,
draw_arrow,
draw_box,
save,
)
import matplotlib.pyplot as plt
# ============================================================
# 3. Type-based subscription
# ============================================================
def draw_sub_type() -> None:
"""Draw sub type."""
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.0))
ax.set_xlim(0, 12)
ax.set_ylim(0, 6.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Subskrypcja type-based" " \u2014 routing po typie (klasie) obiektu",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
bold10 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold")
draw_box(ax, (0.2, 4.2), (2.4, 1.1), "Publisher", bold10)
fs9_g4 = BoxStyle(fill=GRAY4, fontsize=9)
draw_box(
ax,
(0.1, 2.8),
(2.6, 0.9),
"new OrderEvent()",
fs9_g4,
)
draw_box(
ax,
(0.1, 1.5),
(2.6, 0.9),
"new PaymentEvent()",
fs9_g4,
)
draw_box(
ax,
(4.0, 2.3),
(3.0, 2.4),
"BROKER\n\nrouting po\ntypie klasy",
BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"),
)
fs9 = BoxStyle(fill=GRAY1, fontsize=9)
draw_box(
ax,
(8.5, 4.8),
(3.2, 1.0),
"Sub A\n\u2192 OrderEvent",
fs9,
)
draw_box(
ax,
(8.5, 3.2),
(3.2, 1.0),
"Sub B\n\u2192 PaymentEvent",
fs9,
)
draw_box(
ax,
(8.5, 1.6),
(3.2, 1.0),
"Sub C\n\u2192 Event (base)",
fs9,
)
draw_arrow(ax, (2.7, 3.2), (4.0, 3.8))
draw_arrow(ax, (2.7, 2.0), (4.0, 3.0))
draw_arrow(
ax,
(7.0, 4.3),
(8.5, 5.2),
ArrowCfg(label="OrderEvent", label_fs=8),
)
draw_arrow(
ax,
(7.0, 3.5),
(8.5, 3.7),
ArrowCfg(label="PaymentEvent", label_fs=8),
)
draw_arrow(
ax,
(7.0, 3.0),
(8.5, 2.2),
ArrowCfg(label="oba (dziedziczenie!)", label_fs=8),
)
hx, hy = 0.5, 0.0
draw_box(
ax,
(hx + 2.0, hy + 0.2),
(1.8, 0.6),
"Event",
BoxStyle(fill=GRAY3, fontsize=8, fontweight="bold"),
)
draw_box(
ax,
(hx, hy + 0.2),
(1.8, 0.6),
"OrderEvent",
BoxStyle(fill=GRAY4, fontsize=7.5),
)
draw_box(
ax,
(hx + 4.0, hy + 0.2),
(2.0, 0.6),
"PaymentEvent",
BoxStyle(fill=GRAY4, fontsize=7.5),
)
draw_arrow(
ax,
(hx + 2.9, hy + 0.2),
(hx + 0.9, hy + 0.2),
ArrowCfg(
lw=1.0,
label="extends",
label_offset=-0.3,
label_fs=7,
),
)
draw_arrow(
ax,
(hx + 2.9, hy + 0.2),
(hx + 5.0, hy + 0.2),
ArrowCfg(
lw=1.0,
label="extends",
label_offset=-0.3,
label_fs=7,
),
)
ax.text(
9.5,
0.5,
"Sub C subskrybuje bazowy Event\n" "\u2192 otrzymuje WSZYSTKIE podtypy",
ha="center",
va="center",
fontsize=8.5,
style="italic",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": GRAY4,
"edgecolor": GRAY3,
},
)
save(fig, "pubsub_sub_type.png")
# ============================================================
# 4. Hierarchical / Wildcards subscription
# ============================================================
def draw_sub_hierarchical() -> None:
"""Draw sub hierarchical."""
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 7)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Subskrypcja hierarchiczna (wildcards)" " \u2014 wzorce temat\u00f3w",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
bold10 = BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold")
draw_box(ax, (4.5, 5.8), (2.4, 0.8), "sensors/", bold10)
fs9_g3 = BoxStyle(fill=GRAY3, fontsize=9)
draw_box(
ax,
(1.5, 4.2),
(2.4, 0.8),
"temperature/",
fs9_g3,
)
draw_box(
ax,
(7.5, 4.2),
(2.4, 0.8),
"humidity/",
fs9_g3,
)
fs85_g4 = BoxStyle(fill=GRAY4, fontsize=8.5)
draw_box(ax, (0.2, 2.8), (1.8, 0.7), "room1", fs85_g4)
draw_box(ax, (2.4, 2.8), (1.8, 0.7), "room2", fs85_g4)
draw_box(ax, (6.8, 2.8), (1.8, 0.7), "room1", fs85_g4)
draw_box(ax, (9.0, 2.8), (1.8, 0.7), "room2", fs85_g4)
thin = ArrowCfg(lw=1.0)
draw_arrow(ax, (5.7, 5.8), (2.7, 5.0), thin)
draw_arrow(ax, (5.7, 5.8), (8.7, 5.0), thin)
draw_arrow(ax, (2.2, 4.2), (1.1, 3.5), thin)
draw_arrow(ax, (3.2, 4.2), (3.3, 3.5), thin)
draw_arrow(ax, (8.2, 4.2), (7.7, 3.5), thin)
draw_arrow(ax, (9.2, 4.2), (9.9, 3.5), thin)
ax.text(
1.1,
2.4,
"sensors/temperature/room1",
fontsize=7,
ha="center",
fontfamily="monospace",
style="italic",
)
ax.text(
3.3,
2.4,
"sensors/temperature/room2",
fontsize=7,
ha="center",
fontfamily="monospace",
style="italic",
)
ax.text(
0.3,
1.5,
"Wzorce subskrypcji (MQTT-style):",
fontsize=10,
fontweight="bold",
)
patterns = [
(
'"sensors/temperature/room1"',
"\u2192 TYLKO room1",
"(dok\u0142adne dopasowanie)",
),
(
'"sensors/temperature/*"',
"\u2192 room1, room2",
"( * = jeden poziom)",
),
(
'"sensors/#"',
"\u2192 WSZYSTKO",
"( # = dowolna g\u0142\u0119boko\u015b\u0107)",
),
]
for i, (pat, result, note) in enumerate(patterns):
yy = 0.9 - i * 0.55
ax.text(
0.5,
yy,
pat,
fontsize=9,
fontweight="bold",
fontfamily="monospace",
)
ax.text(7.0, yy, result, fontsize=9, fontweight="bold")
ax.text(9.5, yy, note, fontsize=8, style="italic")
save(fig, "pubsub_sub_hierarchical.png")

View File

@ -0,0 +1,421 @@
"""Spark Streaming, Lambda/Kappa architecture, and exactly-once diagrams for Q20."""
from __future__ import annotations
from _q20_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
draw_arrow,
draw_box,
draw_table,
plt,
save_fig,
)
# ============================================================
# 12. Spark Streaming architecture
# ============================================================
def gen_spark_streaming_arch() -> None:
"""Gen spark streaming arch."""
fig, ax = plt.subplots(figsize=(9, 5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Spark Streaming — architektura (micro-batch)",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Cluster border
draw_box(ax, 0.3, 0.5, 11.4, 5.8, "", fill=GRAY4, rounded=True, lw=2.5)
ax.text(
6.0, 6.0, "SPARK CLUSTER", fontsize=FS_LABEL, ha="center", fontweight="bold"
)
# Driver
draw_box(
ax,
1.0,
4.5,
3.0,
1.2,
"Driver\n(planuje mini-batche)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 2.5, 4.5, 6.0, 4.0, lw=1.5)
# Batches
batches = ["batch 1\n(e1,e2,e3)", "batch 2\n(e4,e5,e6)", "batch 3\n(e7,e8,e9)"]
for i, b in enumerate(batches):
y = 2.8 - i * 1.0
draw_box(
ax, 4.5, y, 2.5, 0.8, b, fill=GRAY1, fontsize=FS_SMALL, fontweight="bold"
)
# map → reduce
draw_arrow(ax, 7.0, y + 0.4, 7.5, y + 0.4, lw=1)
draw_box(ax, 7.5, y, 1.3, 0.8, "map→\nreduce", fill=GRAY3, fontsize=5.5)
draw_arrow(ax, 8.8, y + 0.4, 9.3, y + 0.4, lw=1)
draw_box(
ax, 9.3, y, 1.5, 0.8, f"result {i + 1}", fill="white", fontsize=FS_SMALL
)
# Spark ecosystem
draw_box(
ax,
1.0,
1.0,
3.0,
1.0,
"Spark SQL / MLlib\n(ten sam ekosystem!)",
fill=GRAY5,
fontsize=FS,
fontweight="bold",
)
ax.text(
6.0,
0.3,
"ZALETA: batch API | WADA: latencja ≥ batch interval (~100ms)",
fontsize=FS,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "edgecolor": LN},
)
save_fig(fig, "q20_spark_streaming_arch.png")
# ============================================================
# 13. Lambda vs Kappa architecture
# ============================================================
def gen_lambda_vs_kappa() -> None:
"""Gen lambda vs kappa."""
fig, axes = plt.subplots(2, 1, figsize=(10, 7))
fig.suptitle("Architektura Lambda vs Kappa", fontsize=FS_TITLE, fontweight="bold")
# --- Lambda ---
ax = axes[0]
ax.set_xlim(0, 12)
ax.set_ylim(0, 5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"LAMBDA — 2 ścieżki (batch + speed)", fontsize=FS_LABEL, fontweight="bold"
)
# Source
draw_box(
ax,
0.3,
1.8,
2.0,
1.5,
"Źródło\ndanych",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
# Batch layer (top)
draw_box(
ax,
3.5,
3.3,
3.0,
1.2,
"Batch Layer\n(Spark)\nprzelicza co godzinę",
fill=GRAY1,
fontsize=FS_SMALL,
fontweight="bold",
)
draw_arrow(ax, 2.3, 3.0, 3.5, 3.9, lw=1.5)
# Speed layer (bottom)
draw_box(
ax,
3.5,
0.8,
3.0,
1.2,
"Speed Layer\n(Flink)\nreal-time",
fill=GRAY3,
fontsize=FS_SMALL,
fontweight="bold",
)
draw_arrow(ax, 2.3, 2.2, 3.5, 1.4, lw=1.5)
# Results
draw_box(
ax,
7.5,
3.3,
2.0,
1.2,
"Dokładne\nwyniki\n(wolne)",
fill=GRAY4,
fontsize=FS_SMALL,
)
draw_arrow(ax, 6.5, 3.9, 7.5, 3.9, lw=1.5)
draw_box(
ax,
7.5,
0.8,
2.0,
1.2,
"Przybliżone\nwyniki\n(szybkie)",
fill=GRAY4,
fontsize=FS_SMALL,
)
draw_arrow(ax, 6.5, 1.4, 7.5, 1.4, lw=1.5)
# Merge
draw_box(
ax,
10.0,
2.0,
1.5,
1.5,
"MERGE\n→ UI",
fill=GRAY5,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 9.5, 3.5, 10.0, 3.0, lw=1.5)
draw_arrow(ax, 9.5, 1.8, 10.0, 2.5, lw=1.5)
ax.text(
6.0,
0.1,
"2 systemy, 2 kody — złożone ale pewne",
fontsize=FS,
ha="center",
style="italic",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
)
# --- Kappa ---
ax = axes[1]
ax.set_xlim(0, 12)
ax.set_ylim(0, 4)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"KAPPA — 1 ścieżka (streaming only)", fontsize=FS_LABEL, fontweight="bold"
)
# Source
draw_box(
ax,
0.3,
1.3,
2.0,
1.5,
"Źródło\ndanych",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
# Single streaming layer
draw_box(
ax,
3.5,
1.3,
3.5,
1.5,
"Streaming Layer\n(Flink)\n+ replay z Kafka log",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 2.3, 2.05, 3.5, 2.05, lw=2)
# Output
draw_box(
ax,
8.0,
1.3,
2.5,
1.5,
"Wyniki\n→ UI",
fill=GRAY4,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 7.0, 2.05, 8.0, 2.05, lw=2)
# Replay arrow
ax.annotate(
"",
xy=(3.5, 1.0),
xytext=(7.0, 1.0),
arrowprops={
"arrowstyle": "<-",
"lw": 1.5,
"color": LN,
"connectionstyle": "arc3,rad=0.3",
"linestyle": "--",
},
)
ax.text(
5.25,
0.3,
"Replay z Kafka\n(przetwórz historię od nowa)",
fontsize=FS_SMALL,
ha="center",
style="italic",
)
ax.text(
6.0,
3.3,
"1 system, 1 kod — prostsze, ale replay = dużo I/O",
fontsize=FS,
ha="center",
style="italic",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
)
fig.tight_layout(rect=[0, 0, 1, 0.92])
save_fig(fig, "q20_lambda_vs_kappa.png")
# ============================================================
# 14. Lambda vs Kappa comparison table
# ============================================================
def gen_lambda_kappa_table() -> None:
"""Gen lambda kappa table."""
fig, ax = plt.subplots(figsize=(8, 3.5))
ax.set_xlim(0, 10)
ax.set_ylim(-4.5, 1)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Lambda vs Kappa — porównanie", fontsize=FS_TITLE, fontweight="bold", pad=10
)
headers = ["Cecha", "Lambda", "Kappa"]
col_w = [2.5, 3.5, 3.5]
rows = [
["Ścieżki", "2 (batch + speed)", "1 (streaming)"],
["Kod", "2 implementacje", "1 implementacja"],
["Złożoność", "wysoka", "niska"],
["Replay", "batch przelicza", "Kafka replay"],
["Spójność", "merge wymagany", "natywna"],
["Przykład", "Netflix, LinkedIn", "Uber, Confluent"],
]
draw_table(
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.55, fontsize=7.5
)
save_fig(fig, "q20_lambda_kappa_table.png")
# ============================================================
# 15. Exactly-once comparison
# ============================================================
def gen_exactly_once() -> None:
"""Gen exactly once."""
fig, ax = plt.subplots(figsize=(9, 5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Exactly-Once — mechanizmy na 3 platformach",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Flink
draw_box(ax, 0.3, 4.3, 11.0, 2.0, "", fill=GRAY4, rounded=True, lw=1.5)
ax.text(
1.0,
5.9,
"Flink — Distributed Snapshots (Chandy-Lamport)",
fontsize=FS,
fontweight="bold",
)
flink_steps = ["source", "|B|", "map()", "|B|", "sink()"]
bx = 1.0
for s in flink_steps:
if s == "|B|":
ax.text(
bx + 0.25,
4.85,
s,
fontsize=FS,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY5, "edgecolor": LN},
)
draw_arrow(ax, bx - 0.1, 4.85, bx + 0.05, 4.85, lw=1)
bx += 0.7
else:
draw_box(
ax,
bx,
4.6,
1.5,
0.55,
s,
fill=GRAY1,
fontsize=FS_SMALL,
fontweight="bold",
)
bx += 1.8
ax.text(
8.5,
5.0,
"barrier → save state\n→ checkpoint (HDFS/S3)",
fontsize=FS_SMALL,
style="italic",
)
# Kafka Streams
draw_box(ax, 0.3, 2.3, 11.0, 1.5, "", fill=GRAY1, rounded=True, lw=1.5)
ax.text(
1.0, 3.5, "Kafka Streams — Transakcje Kafka", fontsize=FS, fontweight="bold"
)
ax.text(
1.5,
2.85,
"idempotent producer + begin TX → produce → commit TX → consumer offsets w TX",
fontsize=FS_SMALL,
)
# Spark
draw_box(ax, 0.3, 0.5, 11.0, 1.5, "", fill=GRAY3, rounded=True, lw=1.5)
ax.text(
1.0,
1.7,
"Spark Streaming — Write-Ahead Log (WAL)",
fontsize=FS,
fontweight="bold",
)
ax.text(
1.5,
1.05,
"WAL + checkpointing micro-batchów + idempotent sinks (np. upsert do DB)",
fontsize=FS_SMALL,
)
save_fig(fig, "q20_exactly_once.png")

View File

@ -0,0 +1,449 @@
"""Batch vs streaming concept and window type diagrams for Q20."""
from __future__ import annotations
from typing import TYPE_CHECKING
from _q20_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
draw_arrow,
draw_box,
plt,
save_fig,
)
import matplotlib.patches as mpatches
if TYPE_CHECKING:
from matplotlib.axes import Axes
# ============================================================
# 1. Batch vs Streaming concept
# ============================================================
def gen_batch_vs_streaming() -> None:
"""Gen batch vs streaming."""
fig, axes = plt.subplots(2, 1, figsize=(9, 5))
fig.suptitle(
"Batch vs Streaming — dwa modele przetwarzania",
fontsize=FS_TITLE,
fontweight="bold",
)
# Batch
ax = axes[0]
ax.set_xlim(0, 12)
ax.set_ylim(0, 3)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("BATCH (wsadowe)", fontsize=FS_LABEL, fontweight="bold")
# Data collected
draw_box(
ax,
0.5,
0.8,
3.0,
1.4,
"Zbierz WSZYSTKIE\ndane\n(godziny / dni)",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 3.5, 1.5, 4.5, 1.5, lw=2)
draw_box(
ax,
4.5,
0.8,
2.5,
1.4,
"Analiza\n(batch job)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 7.0, 1.5, 8.0, 1.5, lw=2)
draw_box(
ax,
8.0,
0.8,
2.5,
1.4,
"Wynik\n(jednorazowy)",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
ax.text(11.0, 1.5, "min-h", fontsize=FS, va="center", fontweight="bold")
# Streaming
ax = axes[1]
ax.set_xlim(0, 12)
ax.set_ylim(0, 3)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("STREAMING (strumieniowe)", fontsize=FS_LABEL, fontweight="bold")
# Events flowing
events_x = [0.5, 1.5, 2.5, 3.5]
for i, ex in enumerate(events_x):
draw_box(
ax,
ex,
1.0,
0.8,
0.8,
f"e{i + 1}",
fill=GRAY4,
fontsize=FS,
fontweight="bold",
rounded=False,
)
if i < len(events_x) - 1:
draw_arrow(ax, ex + 0.8, 1.4, ex + 1.0, 1.4, lw=1)
ax.text(4.8, 1.4, "...", fontsize=FS_LABEL, va="center")
draw_arrow(ax, 5.2, 1.4, 5.8, 1.4, lw=2)
draw_box(
ax,
5.8,
0.8,
2.8,
1.4,
"Analiza\nCIĄGŁA\n(event-by-event)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 8.6, 1.5, 9.3, 1.5, lw=2)
draw_box(
ax,
9.3,
0.8,
2.0,
1.4,
"Wyniki\nciągłe",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
ax.text(11.5, 0.5, "ms-s", fontsize=FS, va="center", fontweight="bold")
# Arrow marking infinity
ax.annotate(
"",
xy=(0.2, 1.4),
xytext=(-0.3, 1.4),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
ax.text(0.0, 2.3, "∞ zdarzeń", fontsize=FS_SMALL, ha="center", style="italic")
fig.tight_layout(rect=[0, 0, 1, 0.92])
save_fig(fig, "q20_batch_vs_streaming.png")
# ============================================================
# 2. All 4 window types (TSSG)
# ============================================================
def _draw_tumbling_window(ax: Axes, events: list[int]) -> None:
"""Draw tumbling window section."""
ax.set_xlim(0, 14)
ax.set_ylim(0, 4)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Tumbling Window (okno przerzutne) — rozłączne, stały rozmiar",
fontsize=FS_LABEL,
fontweight="bold",
)
# Time axis
ax.annotate(
"",
xy=(13.5, 1.0),
xytext=(0.3, 1.0),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center")
# Events
for i, e in enumerate(events):
x = 1.0 + i * 1.0
ax.plot(x, 1.0, "ko", markersize=5)
ax.text(x, 0.5, f"e{e}", fontsize=FS_SMALL, ha="center")
# Windows
colors_w = [GRAY1, GRAY3, GRAY1, GRAY3]
for w in range(4):
x_start = 1.0 + w * 3.0 - 0.3
rect = mpatches.FancyBboxPatch(
(x_start, 1.5),
3.0,
1.2,
boxstyle="round,pad=0.1",
facecolor=colors_w[w],
edgecolor=LN,
lw=1.5,
)
ax.add_patch(rect)
ax.text(
x_start + 1.5,
2.1,
f"Okno {w + 1}",
fontsize=FS,
ha="center",
fontweight="bold",
)
# Braces down to events
for j in range(3):
ex = 1.0 + w * 3.0 + j * 1.0
ax.plot([ex, ex], [1.0, 1.5], color=LN, lw=0.8, linestyle="--")
ax.text(
7.0,
3.2,
"Każde zdarzenie → DOKŁADNIE 1 okno. Zero nakładania.",
fontsize=FS,
ha="center",
style="italic",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
)
def _draw_sliding_window(ax: Axes, events: list[int]) -> None:
"""Draw sliding window section."""
ax.set_xlim(0, 14)
ax.set_ylim(0, 5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Sliding Window (okno przesuwne) — nakładające, stały rozmiar + krok",
fontsize=FS_LABEL,
fontweight="bold",
)
ax.annotate(
"",
xy=(13.5, 1.0),
xytext=(0.3, 1.0),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center")
for i, e in enumerate(events[:8]):
x = 1.0 + i * 1.0
ax.plot(x, 1.0, "ko", markersize=5)
ax.text(x, 0.5, f"e{e}", fontsize=FS_SMALL, ha="center")
# Sliding windows: size=4, slide=2
slide_colors = [GRAY1, GRAY2, GRAY3]
for w in range(3):
x_start = 0.7 + w * 2.0
y_base = 1.5 + w * 0.9
rect = mpatches.FancyBboxPatch(
(x_start, y_base),
4.0,
0.7,
boxstyle="round,pad=0.08",
facecolor=slide_colors[w],
edgecolor=LN,
lw=1.5,
alpha=0.7,
)
ax.add_patch(rect)
ax.text(
x_start + 2.0,
y_base + 0.35,
f"Okno {w + 1} (size=4)",
fontsize=FS_SMALL,
ha="center",
fontweight="bold",
)
ax.text(
10.5,
3.5,
"krok=2\nNakładanie!\ne3,e4 → w oknie 1 i 2",
fontsize=FS,
ha="center",
style="italic",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
)
def _draw_session_window(ax: Axes) -> None:
"""Draw session window section."""
ax.set_xlim(0, 14)
ax.set_ylim(0, 4)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Session Window (okno sesji) — dynamiczny rozmiar, gap = przerwa",
fontsize=FS_LABEL,
fontweight="bold",
)
ax.annotate(
"",
xy=(13.5, 1.0),
xytext=(0.3, 1.0),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center")
# Cluster 1: events close together
cluster1 = [1.0, 1.8, 2.3, 3.0]
for x in cluster1:
ax.plot(x, 1.0, "ko", markersize=5)
# Gap
ax.annotate(
"",
xy=(7.0, 0.7),
xytext=(4.0, 0.7),
arrowprops={"arrowstyle": "<->", "lw": 1, "color": LN},
)
ax.text(
5.5,
0.3,
"GAP > timeout",
fontsize=FS,
ha="center",
fontweight="bold",
style="italic",
)
# Cluster 2
cluster2 = [8.0, 8.8, 9.5]
for x in cluster2:
ax.plot(x, 1.0, "ko", markersize=5)
# Session boxes
rect1 = mpatches.FancyBboxPatch(
(0.7, 1.4),
2.6,
1.0,
boxstyle="round,pad=0.1",
facecolor=GRAY1,
edgecolor=LN,
lw=1.5,
)
ax.add_patch(rect1)
ax.text(
2.0, 1.9, "Sesja 1\n(4 zdarzenia)", fontsize=FS, ha="center", fontweight="bold"
)
rect2 = mpatches.FancyBboxPatch(
(7.7, 1.4),
2.1,
1.0,
boxstyle="round,pad=0.1",
facecolor=GRAY3,
edgecolor=LN,
lw=1.5,
)
ax.add_patch(rect2)
ax.text(
8.75, 1.9, "Sesja 2\n(3 zdarzenia)", fontsize=FS, ha="center", fontweight="bold"
)
ax.text(
5.5,
3.0,
"Nowa sesja po przerwie > gap",
fontsize=FS,
ha="center",
style="italic",
)
def _draw_global_window(ax: Axes) -> None:
"""Draw global window section."""
ax.set_xlim(0, 14)
ax.set_ylim(0, 4)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Global Window — jedno okno na cały strumień + trigger",
fontsize=FS_LABEL,
fontweight="bold",
)
ax.annotate(
"",
xy=(13.5, 1.0),
xytext=(0.3, 1.0),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center")
for i in range(12):
x = 1.0 + i * 1.0
ax.plot(x, 1.0, "ko", markersize=5)
# One big window
rect = mpatches.FancyBboxPatch(
(0.5, 1.4),
12.5,
1.0,
boxstyle="round,pad=0.1",
facecolor=GRAY1,
edgecolor=LN,
lw=2,
)
ax.add_patch(rect)
ax.text(
6.75,
1.9,
"GLOBAL WINDOW (cały strumień)",
fontsize=FS,
ha="center",
fontweight="bold",
)
# Trigger markers
for tx in [4.0, 8.0, 12.0]:
ax.plot([tx, tx], [1.4, 2.4], color=LN, lw=2, linestyle="--")
ax.text(
tx,
2.7,
"EMIT",
fontsize=FS_SMALL,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY3, "edgecolor": LN},
)
ax.text(
6.75,
3.3,
"Trigger decyduje kiedy emitować (np. co N zdarzeń)",
fontsize=FS,
ha="center",
style="italic",
)
def gen_window_types() -> None:
"""Gen window types."""
fig, axes = plt.subplots(4, 1, figsize=(9, 10))
fig.suptitle("4 typy okien — TSSG", fontsize=FS_TITLE, fontweight="bold")
events = list(range(1, 13))
_draw_tumbling_window(axes[0], events)
_draw_sliding_window(axes[1], events)
_draw_session_window(axes[2])
_draw_global_window(axes[3])
fig.tight_layout(rect=[0, 0, 1, 0.94])
save_fig(fig, "q20_window_types.png")

View File

@ -0,0 +1,180 @@
"""Common utilities and constants for Q20 diagram generation.
Monochrome, A4-printable PNGs (300 DPI).
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib as mpl
mpl.use("Agg")
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
import matplotlib.pyplot as plt
import numpy as np
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
_logger = logging.getLogger(__name__)
rng = np.random.default_rng(42)
DPI = 300
BG = "white"
LN = "black"
FS = 8
FS_TITLE = 11
FS_SMALL = 6.5
FS_LABEL = 9
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
GRAY1 = "#E8E8E8"
GRAY2 = "#D0D0D0"
GRAY3 = "#B8B8B8"
GRAY4 = "#F5F5F5"
GRAY5 = "#C0C0C0"
def draw_box(
ax: Axes,
x: float,
y: float,
w: float,
h: float,
text: str,
*,
fill: str = "white",
lw: float = 1.2,
fontsize: float = FS,
fontweight: str = "normal",
ha: str = "center",
va: str = "center",
rounded: bool = True,
edgecolor: str = LN,
linestyle: str = "-",
) -> None:
"""Draw box."""
if rounded:
rect = FancyBboxPatch(
(x, y),
w,
h,
boxstyle="round,pad=0.05",
lw=lw,
edgecolor=edgecolor,
facecolor=fill,
linestyle=linestyle,
)
else:
rect = mpatches.Rectangle(
(x, y),
w,
h,
lw=lw,
edgecolor=edgecolor,
facecolor=fill,
linestyle=linestyle,
)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + h / 2,
text,
ha=ha,
va=va,
fontsize=fontsize,
fontweight=fontweight,
wrap=True,
)
def draw_arrow(
ax: Axes,
x1: float,
y1: float,
x2: float,
y2: float,
lw: float = 1.2,
style: str = "->",
color: str = LN,
) -> None:
"""Draw arrow."""
ax.annotate(
"",
xy=(x2, y2),
xytext=(x1, y1),
arrowprops={"arrowstyle": style, "color": color, "lw": lw},
)
def save_fig(fig: Figure, name: str) -> None:
"""Save fig."""
path = str(Path(OUTPUT_DIR) / name)
fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15)
plt.close(fig)
_logger.info(" Saved: %s", path)
def draw_table(
ax: Axes,
headers: list[str],
rows: list[list[str]],
x0: float,
y0: float,
col_widths: list[float],
row_h: float = 0.4,
header_fill: str = GRAY2,
row_fills: list[str] | None = None,
fontsize: float = FS,
header_fontsize: float | None = None,
) -> None:
"""Draw table."""
if header_fontsize is None:
header_fontsize = fontsize
len(headers)
# Header
cx = x0
for j, hdr in enumerate(headers):
draw_box(
ax,
cx,
y0,
col_widths[j],
row_h,
hdr,
fill=header_fill,
fontsize=header_fontsize,
fontweight="bold",
rounded=False,
)
cx += col_widths[j]
# Rows
for i, row in enumerate(rows):
cy = y0 - (i + 1) * row_h
cx = x0
fill = GRAY4 if (i % 2 == 0) else "white"
if row_fills and i < len(row_fills):
fill = row_fills[i]
for j, cell in enumerate(row):
fw = "bold" if j == 0 else "normal"
draw_box(
ax,
cx,
cy,
col_widths[j],
row_h,
cell,
fill=fill,
fontsize=fontsize,
fontweight=fw,
rounded=False,
)
cx += col_widths[j]

View File

@ -0,0 +1,240 @@
"""Late data strategies and decision tree diagrams for Q20."""
from __future__ import annotations
from _q20_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
draw_arrow,
draw_box,
plt,
save_fig,
)
# ============================================================
# 16. Late data strategies (DRAS)
# ============================================================
def gen_late_data_strategies() -> None:
"""Gen late data strategies."""
fig, ax = plt.subplots(figsize=(9, 5.5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Late Data — 4 strategie (mnemonik DRAS)",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Setup: window closed, late event arrives
draw_box(
ax,
0.5,
5.5,
4.5,
1.0,
"Okno [14:00-14:05]\nZAMKNIĘTE o 14:05",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
draw_box(
ax,
6.0,
5.5,
4.5,
1.0,
"Spóźnione zdarzenie\nevent_time=14:00:03\narrives=14:05:30",
fill="#F8D7DA",
fontsize=FS_SMALL,
fontweight="bold",
)
draw_arrow(ax, 10.5, 6.0, 5.0, 6.0, lw=2, color="#C62828", style="->")
ax.text(
7.5,
5.2,
"LATE!",
fontsize=FS_LABEL,
ha="center",
fontweight="bold",
color="#C62828",
)
# 4 strategies
strategies = [
("D — Drop", "Odrzuć spóźnione", "/dev/null", GRAY4),
("R — Recompute", "Przelicz okno ponownie", "poprawne ale kosztowne", GRAY1),
(
"A — Allowed lateness",
"Czekaj dodatkowy czas\n(np. +2 min)",
"kompromis pamięci",
GRAY2,
),
(
"S — Side output",
"Przekieruj do osobnej\nkolejki",
"elastyczne, ręczna analiza",
GRAY3,
),
]
for i, (name, desc, tradeoff, color) in enumerate(strategies):
y = 3.8 - i * 1.1
draw_box(ax, 0.5, y, 2.5, 0.9, name, fill=color, fontsize=FS, fontweight="bold")
ax.text(3.3, y + 0.45, desc, fontsize=FS_SMALL, va="center")
ax.text(
8.5,
y + 0.45,
tradeoff,
fontsize=FS_SMALL,
va="center",
style="italic",
color="#555",
)
save_fig(fig, "q20_late_data_strategies.png")
# ============================================================
# 17. Decision tree — which platform
# ============================================================
def gen_decision_tree() -> None:
"""Gen decision tree."""
fig, ax = plt.subplots(figsize=(10, 5.5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Drzewo decyzyjne — wybór platformy",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Root question
draw_box(
ax,
3.5,
5.5,
4.5,
1.0,
"Latencja < 10ms\nwymagana?",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
# TAK branch
draw_arrow(ax, 3.5, 5.7, 2.0, 5.0, lw=1.5)
ax.text(2.3, 5.3, "TAK", fontsize=FS, fontweight="bold")
draw_box(
ax,
0.3,
3.5,
3.5,
1.0,
"Dane już w Kafce?\nProste transformacje?",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
)
# TAK → Kafka Streams
draw_arrow(ax, 0.3, 3.7, -0.1, 3.0, lw=1.5)
ax.text(0.0, 3.3, "TAK", fontsize=FS_SMALL, fontweight="bold")
draw_box(
ax,
-0.3,
1.8,
2.5,
1.0,
"Kafka\nStreams",
fill=GRAY5,
fontsize=FS_LABEL,
fontweight="bold",
)
# NIE → Flink
draw_arrow(ax, 3.8, 3.7, 4.5, 3.0, lw=1.5)
ax.text(4.0, 3.3, "NIE\n(złożona logika)", fontsize=FS_SMALL)
draw_box(
ax,
3.0,
1.8,
2.5,
1.0,
"Apache\nFlink",
fill=GRAY5,
fontsize=FS_LABEL,
fontweight="bold",
)
# NIE branch
draw_arrow(ax, 8.0, 5.7, 9.5, 5.0, lw=1.5)
ax.text(8.7, 5.3, "NIE", fontsize=FS, fontweight="bold")
draw_box(
ax,
7.5,
3.5,
4.2,
1.0,
"~100ms-1s OK?\nPotrzeba ML / SQL?",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
)
# TAK + ML → Spark
draw_arrow(ax, 9.5, 3.5, 9.5, 3.0, lw=1.5)
ax.text(10.0, 3.3, "TAK + ML/SQL", fontsize=FS_SMALL)
draw_box(
ax,
8.0,
1.8,
2.5,
1.0,
"Spark\nStreaming",
fill=GRAY5,
fontsize=FS_LABEL,
fontweight="bold",
)
# TAK + proste → Kafka Streams too
draw_arrow(ax, 7.5, 3.7, 6.5, 3.0, lw=1.5)
ax.text(6.3, 3.3, "proste + TAK", fontsize=FS_SMALL)
draw_box(
ax,
5.8,
1.8,
2.0,
1.0,
"Kafka\nStreams",
fill=GRAY5,
fontsize=FS,
fontweight="bold",
)
# Legend
ax.text(
6.0,
0.7,
"Reguła: Kafka Streams = najprostsze (library) | "
"Flink = najpotężniejszy (true streaming) | Spark = ekosystem ML",
fontsize=FS,
ha="center",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
)
save_fig(fig, "q20_decision_tree.png")

View File

@ -0,0 +1,471 @@
"""Streaming ecosystem, micro-batch, platform comparison, and engine diagrams for Q20."""
from __future__ import annotations
from _q20_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
draw_arrow,
draw_box,
draw_table,
plt,
save_fig,
)
# ============================================================
# 7. Streaming ecosystem overview
# ============================================================
def gen_streaming_ecosystem() -> None:
"""Gen streaming ecosystem."""
fig, ax = plt.subplots(figsize=(10, 5.5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Ekosystem przetwarzania strumieniowego",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Source
draw_box(
ax,
0.3,
2.5,
2.0,
3.0,
"Kafka\nTopics\n(źródło)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
# Engines
engines = [
("Kafka Streams\n(library w JVM)", GRAY1, 4.7),
("Apache Flink\n(klaster)", GRAY3, 3.2),
("Spark Streaming\n(klaster)", GRAY5, 1.7),
]
for label, color, y in engines:
draw_box(
ax, 4.0, y, 3.0, 1.2, label, fill=color, fontsize=FS, fontweight="bold"
)
draw_arrow(ax, 2.3, 4.0, 4.0, y + 0.6, lw=1.5)
# Sinks
sinks = [
("Kafka topic\n/ baza danych", GRAY4, 4.7),
("DB / Kafka\n/ S3", GRAY4, 3.2),
("HDFS / DB\n/ dashboard", GRAY4, 1.7),
]
for label, color, y in sinks:
draw_box(ax, 8.5, y, 2.5, 1.2, label, fill=color, fontsize=FS)
draw_arrow(ax, 7.0, y + 0.6, 8.5, y + 0.6, lw=1.5)
# Labels
ax.text(1.3, 6.0, "ŹRÓDŁO", fontsize=FS_LABEL, ha="center", fontweight="bold")
ax.text(5.5, 6.2, "SILNIK", fontsize=FS_LABEL, ha="center", fontweight="bold")
ax.text(9.75, 6.2, "WYNIK", fontsize=FS_LABEL, ha="center", fontweight="bold")
# Latency annotations
ax.text(5.5, 5.95, "~1-10 ms", fontsize=FS_SMALL, ha="center", style="italic")
ax.text(5.5, 4.5, "<10 ms", fontsize=FS_SMALL, ha="center", style="italic")
ax.text(5.5, 3.0, "~100 ms", fontsize=FS_SMALL, ha="center", style="italic")
save_fig(fig, "q20_streaming_ecosystem.png")
# ============================================================
# 8. True streaming vs Micro-batch
# ============================================================
def gen_true_vs_microbatch() -> None:
"""Gen true vs microbatch."""
fig, axes = plt.subplots(2, 1, figsize=(10, 5.5))
fig.suptitle("True Streaming vs Micro-Batch", fontsize=FS_TITLE, fontweight="bold")
# True streaming
ax = axes[0]
ax.set_xlim(0, 12)
ax.set_ylim(0, 3.5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"TRUE STREAMING (Flink, Kafka Streams) — event-by-event",
fontsize=FS_LABEL,
fontweight="bold",
)
for i in range(6):
x = 1.0 + i * 1.8
# Event
draw_box(
ax,
x,
2.0,
0.8,
0.7,
f"e{i + 1}",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
rounded=False,
)
# Arrow down
draw_arrow(ax, x + 0.4, 2.0, x + 0.4, 1.4, lw=1)
# Result
draw_box(
ax,
x,
0.5,
0.8,
0.7,
f"r{i + 1}",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
rounded=False,
)
# Latency label
ax.text(x + 0.4, 1.6, "~ms", fontsize=5, ha="center", color="#555")
ax.text(
11.5,
1.3,
"Latencja:\n< 10 ms",
fontsize=FS,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
)
# Micro-batch
ax = axes[1]
ax.set_xlim(0, 12)
ax.set_ylim(0, 3.5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"MICRO-BATCH (Spark Streaming) — grupami co ~100ms",
fontsize=FS_LABEL,
fontweight="bold",
)
batch_colors = [GRAY1, GRAY2, GRAY3]
for b in range(3):
bx = 0.8 + b * 3.5
# Batch boundary
draw_box(ax, bx, 1.8, 3.0, 1.0, "", fill=batch_colors[b], rounded=True, lw=1.5)
ax.text(
bx + 1.5, 2.6, f"Batch {b + 1}", fontsize=FS, ha="center", fontweight="bold"
)
for j in range(3):
ex = bx + 0.3 + j * 0.9
draw_box(
ax,
ex,
2.0,
0.7,
0.5,
f"e{b * 3 + j + 1}",
fill="white",
fontsize=FS_SMALL,
rounded=False,
)
# Arrow down
draw_arrow(ax, bx + 1.5, 1.8, bx + 1.5, 1.2, lw=1.5)
# Result
draw_box(
ax,
bx + 0.5,
0.4,
2.0,
0.7,
f"result {b + 1}",
fill=GRAY4,
fontsize=FS,
fontweight="bold",
)
ax.text(
11.5,
1.3,
"Latencja:\n~100ms-s",
fontsize=FS,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
)
fig.tight_layout(rect=[0, 0, 1, 0.92])
save_fig(fig, "q20_true_vs_microbatch.png")
# ============================================================
# 9. Platform comparison table
# ============================================================
def gen_platform_comparison() -> None:
"""Gen platform comparison."""
fig, ax = plt.subplots(figsize=(9, 5))
ax.set_xlim(0, 11.5)
ax.set_ylim(-6, 1)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Porównanie platform strumieniowych",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
headers = ["Cecha", "Kafka Streams", "Apache Flink", "Spark Streaming"]
col_w = [2.5, 2.8, 2.8, 2.8]
rows = [
["Model", "event-by-event", "event-by-event", "micro-batch (~100ms)"],
["Deployment", "library (w JVM)", "klaster", "klaster"],
["Latencja", "~1-10 ms", "< 10 ms", "100 ms - sekundy"],
["Exactly-once", "Kafka TXN", "checkpointing", "WAL"],
["State", "RocksDB local", "RocksDB + ckpt", "in-memory / ext"],
["Okna", "T, S, Session", "wszystkie + custom", "T, S"],
["Use case", "Kafka → Kafka", "złożona analityka", "ETL + ML / SQL"],
]
draw_table(
ax,
headers,
rows,
x0=0.25,
y0=0.5,
col_widths=col_w,
row_h=0.6,
fontsize=7,
header_fontsize=8,
)
save_fig(fig, "q20_platform_comparison.png")
# ============================================================
# 10. Kafka Streams architecture
# ============================================================
def gen_kafka_streams_arch() -> None:
"""Gen kafka streams arch."""
fig, ax = plt.subplots(figsize=(9, 5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Kafka Streams — architektura (library w JVM)",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Outer box: Your Java application
draw_box(ax, 0.5, 0.5, 11.0, 5.5, "", fill=GRAY4, rounded=True, lw=2.5)
ax.text(
6.0,
5.7,
"Twoja aplikacja Java (JVM)",
fontsize=FS_LABEL,
ha="center",
fontweight="bold",
)
# Kafka Consumer
draw_box(
ax,
1.0,
3.0,
2.5,
1.5,
"Kafka\nConsumer\n(input topic)",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
)
# Processing
draw_box(
ax,
4.5,
3.0,
2.5,
1.5,
"Kafka Streams\n(logika\nbiznesowa)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
# Kafka Producer
draw_box(
ax,
8.0,
3.0,
2.5,
1.5,
"Kafka\nProducer\n(output topic)",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
)
# Arrows
draw_arrow(ax, 3.5, 3.75, 4.5, 3.75, lw=2)
draw_arrow(ax, 7.0, 3.75, 8.0, 3.75, lw=2)
# RocksDB state store
draw_box(
ax,
4.5,
1.0,
2.5,
1.3,
"RocksDB\n(stan lokalny)",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
ax.plot([5.75, 5.75], [3.0, 2.3], color=LN, lw=1.5)
ax.text(
7.3,
1.6,
"okna, joiny,\nagregacje",
fontsize=FS_SMALL,
style="italic",
va="center",
)
# Key message
ax.text(
6.0,
0.2,
"NIE potrzebujesz osobnego klastra! Skalujesz = więcej instancji JVM.",
fontsize=FS,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "edgecolor": LN},
)
save_fig(fig, "q20_kafka_streams_arch.png")
# ============================================================
# 11. Flink architecture + checkpointing
# ============================================================
def gen_flink_arch() -> None:
"""Gen flink arch."""
fig, ax = plt.subplots(figsize=(9, 6))
ax.set_xlim(0, 12)
ax.set_ylim(0, 8)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Apache Flink — architektura klastra + checkpointing",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Cluster border
draw_box(ax, 0.3, 1.0, 11.4, 6.2, "", fill=GRAY4, rounded=True, lw=2.5)
ax.text(
6.0, 6.95, "FLINK CLUSTER", fontsize=FS_LABEL, ha="center", fontweight="bold"
)
# Job Manager
draw_box(
ax,
1.0,
5.5,
3.0,
1.2,
"Job Manager\n(koordynacja,\ncheckpointy)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
# Task Managers
draw_box(ax, 1.0, 3.0, 10.0, 2.0, "", fill="white", rounded=True, lw=1.5)
ax.text(
6.0, 4.7, "Task Managers (workery)", fontsize=FS, ha="center", fontweight="bold"
)
slots = ["source\n& map()", "map()", "window()\n& reduce", "sink()"]
for i, s in enumerate(slots):
x = 1.5 + i * 2.4
draw_box(
ax,
x,
3.3,
2.0,
1.2,
f"Slot {i + 1}\n{s}",
fill=GRAY1,
fontsize=FS_SMALL,
fontweight="bold",
)
draw_arrow(ax, 2.5, 5.5, 6.0, 5.0, lw=1.5, style="->")
ax.text(5.0, 5.5, "przydziela\npodzadania", fontsize=FS_SMALL, style="italic")
# Checkpoint storage
draw_box(
ax,
5.5,
1.2,
3.5,
1.2,
"Checkpoint Storage\n(HDFS / S3)",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
ax.plot([7.25, 7.25], [2.4, 3.3], color=LN, lw=1.5, linestyle="--")
ax.text(8.0, 2.7, "snapshoty\nstanu", fontsize=FS_SMALL, style="italic")
# Barrier concept at bottom
ax.text(3.0, 1.6, "Barrier:", fontsize=FS, fontweight="bold")
barrier_boxes = ["source", "|B|", "map", "|B|", "sink"]
bx = 0.8
for _i, b in enumerate(barrier_boxes):
if b == "|B|":
ax.text(
bx + 0.3,
1.5,
b,
fontsize=FS,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY5, "edgecolor": LN},
)
draw_arrow(ax, bx, 1.5, bx + 0.1, 1.5, lw=1)
bx += 0.7
else:
draw_box(
ax,
bx,
1.3,
1.0,
0.45,
b,
fill=GRAY1,
fontsize=FS_SMALL,
fontweight="bold",
)
bx += 1.2
save_fig(fig, "q20_flink_arch.png")

View File

@ -0,0 +1,464 @@
"""Event time, fraud detection, SLA monitoring, and session diagrams for Q20."""
from __future__ import annotations
from _q20_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY3,
GRAY4,
GRAY5,
LN,
draw_box,
np,
plt,
rng,
save_fig,
)
import matplotlib.patches as mpatches
# ============================================================
# 3. Event Time vs Processing Time scatter + watermark
# ============================================================
def gen_event_vs_processing_time() -> None:
"""Gen event vs processing time."""
fig, axes = plt.subplots(1, 2, figsize=(11, 5))
fig.suptitle(
"Event Time vs Processing Time + Watermark",
fontsize=FS_TITLE,
fontweight="bold",
)
# --- Panel 1: Ideal vs Real ---
ax = axes[0]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect("equal")
ax.set_xlabel("Event Time", fontsize=FS_LABEL)
ax.set_ylabel("Processing Time", fontsize=FS_LABEL)
ax.set_title("Idealny vs Realny świat", fontsize=FS_LABEL, fontweight="bold")
ax.set_xticks([])
ax.set_yticks([])
# Ideal line
ax.plot([0, 9], [0, 9], "k--", lw=1.5, label="ideał (brak opóźnień)")
# Real scattered points (processing >= event, some out of order)
event_times = np.sort(rng.uniform(1, 8, 15))
proc_times = event_times + rng.exponential(0.5, 15)
# Make some out of order
idx = [3, 7, 11]
for i in idx:
proc_times[i] += 1.5
ax.scatter(
event_times, proc_times, c="black", s=30, zorder=5, label="zdarzenia (realne)"
)
# Highlight out-of-order
for i in idx:
ax.annotate(
"out-of-order",
xy=(event_times[i], proc_times[i]),
xytext=(event_times[i] + 0.8, proc_times[i] + 0.5),
fontsize=FS_SMALL,
ha="left",
arrowprops={"arrowstyle": "->", "lw": 0.8, "color": "#555"},
)
ax.legend(fontsize=FS_SMALL, loc="upper left")
ax.text(
7,
2,
"Opóźnienie\nsieciowe ↑",
fontsize=FS,
ha="center",
style="italic",
color="#555",
)
# --- Panel 2: Watermark concept ---
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect("equal")
ax.set_xlabel("Event Time", fontsize=FS_LABEL)
ax.set_ylabel("Processing Time", fontsize=FS_LABEL)
ax.set_title("Watermark — granica postępu", fontsize=FS_LABEL, fontweight="bold")
ax.set_xticks([])
ax.set_yticks([])
# Events
ax.scatter(event_times, proc_times, c="black", s=30, zorder=5)
# Watermark line (below most points, tracks progress)
wm_x = np.linspace(0, 9, 50)
wm_y = wm_x + 0.3 # watermark slightly above ideal
ax.plot(wm_x, wm_y, "k-", lw=2.5, label="Watermark")
ax.fill_between(wm_x, 0, wm_y, alpha=0.15, color="gray")
ax.text(
2.0,
1.0,
'PONIŻEJ watermark:\n„na pewno dotarło"',
fontsize=FS,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
)
# Late event
late_x, late_y = event_times[7], proc_times[7]
ax.scatter(
[late_x], [late_y], c="white", s=80, zorder=6, edgecolors="black", linewidths=2
)
ax.annotate(
"LATE DATA!\n(po watermarku)",
xy=(late_x, late_y),
xytext=(late_x + 1.2, late_y + 0.8),
fontsize=FS_SMALL,
ha="left",
fontweight="bold",
arrowprops={"arrowstyle": "->", "lw": 1, "color": LN},
)
ax.legend(fontsize=FS_SMALL, loc="upper left")
fig.tight_layout(rect=[0, 0, 1, 0.92])
save_fig(fig, "q20_event_vs_processing_time.png")
# ============================================================
# 4. Tumbling window example — fraud detection
# ============================================================
def gen_tumbling_fraud() -> None:
"""Gen tumbling fraud."""
fig, ax = plt.subplots(figsize=(9, 4))
ax.set_xlim(0, 12)
ax.set_ylim(0, 5.5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Tumbling Window — fraud detection (okno = 1 min)",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Time axis
ax.annotate(
"",
xy=(11.5, 1.0),
xytext=(0.5, 1.0),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
ax.text(6.0, 0.4, "czas", fontsize=FS, ha="center")
# Window 1: normal
draw_box(ax, 1.0, 1.5, 4.5, 3.0, "", fill=GRAY4, rounded=True, lw=2)
ax.text(
3.25, 4.2, "[14:00 — 14:01]", fontsize=FS_LABEL, ha="center", fontweight="bold"
)
# Transactions
txns1 = ["Sklep A: 50 zł", "Sklep B: 30 zł", "Stacja: 80 zł"]
for i, t in enumerate(txns1):
draw_box(
ax,
1.3,
3.3 - i * 0.55,
4.0,
0.45,
t,
fill=GRAY1,
fontsize=FS_SMALL,
rounded=False,
)
ax.text(
3.25,
1.7,
"count = 3 → OK",
fontsize=FS,
ha="center",
fontweight="bold",
color="#2E7D32",
bbox={
"boxstyle": "round,pad=0.15",
"facecolor": "#E8F5E9",
"edgecolor": "#2E7D32",
},
)
# Window 2: fraud!
draw_box(ax, 6.0, 1.5, 4.5, 3.0, "", fill=GRAY1, rounded=True, lw=2)
ax.text(
8.25, 4.2, "[14:01 — 14:02]", fontsize=FS_LABEL, ha="center", fontweight="bold"
)
txns2 = ["ATM Warszawa: 500 zł", "ATM Kraków: 500 zł", "... +45 transakcji"]
for i, t in enumerate(txns2):
draw_box(
ax,
6.3,
3.3 - i * 0.55,
4.0,
0.45,
t,
fill=GRAY3,
fontsize=FS_SMALL,
rounded=False,
)
ax.text(
8.25,
1.7,
"count = 47 → ALERT!",
fontsize=FS,
ha="center",
fontweight="bold",
color="#C62828",
bbox={
"boxstyle": "round,pad=0.15",
"facecolor": "#F8D7DA",
"edgecolor": "#C62828",
},
)
save_fig(fig, "q20_tumbling_fraud.png")
# ============================================================
# 5. Sliding window — SLA monitoring
# ============================================================
def gen_sliding_sla() -> None:
"""Gen sliding sla."""
fig, ax = plt.subplots(figsize=(9, 4.5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 6)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Sliding Window — monitoring SLA (okno=5min, krok=1min)",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Time axis
ax.annotate(
"",
xy=(11.5, 0.5),
xytext=(0.5, 0.5),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
times = ["14:05", "14:06", "14:07", "14:08", "14:09"]
latencies = [120, 180, 340, 290, 150]
sla = 200
for i, (t, lat) in enumerate(zip(times, latencies, strict=False)):
x = 1.5 + i * 2.0
ax.text(x, 0.1, t, fontsize=FS, ha="center")
# Bar proportional to latency
bar_h = lat / 100.0
is_breach = lat > sla
fill = "#F8D7DA" if is_breach else GRAY1
edge = "#C62828" if is_breach else LN
draw_box(
ax,
x - 0.5,
1.0,
1.0,
bar_h,
"",
fill=fill,
rounded=False,
edgecolor=edge,
lw=1.5,
)
ax.text(
x,
1.0 + bar_h + 0.15,
f"{lat}ms",
fontsize=FS,
ha="center",
fontweight="bold",
color="#C62828" if is_breach else LN,
)
# Status
status = "ALERT!" if is_breach else "OK"
ax.text(
x,
1.0 + bar_h + 0.55,
status,
fontsize=FS_SMALL,
ha="center",
fontweight="bold",
color="#C62828" if is_breach else "#2E7D32",
)
# SLA line
sla_y = 1.0 + sla / 100.0
ax.plot([0.8, 11.2], [sla_y, sla_y], "k--", lw=1.5)
ax.text(11.3, sla_y, f"SLA={sla}ms", fontsize=FS, va="center", fontweight="bold")
# Sliding window bracket
ax.annotate(
"",
xy=(1.0, 5.3),
xytext=(5.0, 5.3),
arrowprops={"arrowstyle": "<->", "lw": 1.5, "color": LN},
)
ax.text(3.0, 5.6, "okno = 5 min", fontsize=FS, ha="center", fontweight="bold")
ax.annotate(
"",
xy=(3.0, 4.8),
xytext=(5.0, 4.8),
arrowprops={"arrowstyle": "<->", "lw": 1, "color": "#555"},
)
ax.text(
4.0,
4.4,
"krok = 1 min\n(nakładanie!)",
fontsize=FS_SMALL,
ha="center",
style="italic",
)
save_fig(fig, "q20_sliding_sla.png")
# ============================================================
# 6. Session window — user sessions
# ============================================================
def gen_session_users() -> None:
"""Gen session users."""
fig, axes = plt.subplots(2, 1, figsize=(10, 5))
fig.suptitle(
"Session Window — sesje użytkowników (gap = 30 min)",
fontsize=FS_TITLE,
fontweight="bold",
)
# Anna: 2 sessions
ax = axes[0]
ax.set_xlim(0, 14)
ax.set_ylim(0, 3.5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Użytkownik Anna", fontsize=FS_LABEL, fontweight="bold")
ax.annotate(
"",
xy=(13.5, 1.0),
xytext=(0.3, 1.0),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
# Clicks cluster 1
for x in [1.0, 1.8, 2.5, 3.2]:
ax.plot(x, 1.0, "ko", markersize=6)
# Clicks cluster 2
for x in [9.0, 9.8, 10.5]:
ax.plot(x, 1.0, "ko", markersize=6)
# Sessions
rect1 = mpatches.FancyBboxPatch(
(0.7, 1.5),
2.8,
1.2,
boxstyle="round,pad=0.1",
facecolor=GRAY1,
edgecolor=LN,
lw=1.5,
)
ax.add_patch(rect1)
ax.text(
2.1,
2.1,
"Sesja 1\n4 kliknięcia, 12 min",
fontsize=FS,
ha="center",
fontweight="bold",
)
rect2 = mpatches.FancyBboxPatch(
(8.7, 1.5),
2.1,
1.2,
boxstyle="round,pad=0.1",
facecolor=GRAY3,
edgecolor=LN,
lw=1.5,
)
ax.add_patch(rect2)
ax.text(
9.75,
2.1,
"Sesja 2\n3 kliknięcia, 8 min",
fontsize=FS,
ha="center",
fontweight="bold",
)
# Gap
ax.annotate(
"",
xy=(8.5, 0.5),
xytext=(3.8, 0.5),
arrowprops={"arrowstyle": "<->", "lw": 1.5, "color": LN},
)
ax.text(
6.15,
0.1,
"cisza 45 min > gap(30)",
fontsize=FS,
ha="center",
fontweight="bold",
style="italic",
)
# Bob: 1 session
ax = axes[1]
ax.set_xlim(0, 14)
ax.set_ylim(0, 3.5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Użytkownik Bob", fontsize=FS_LABEL, fontweight="bold")
ax.annotate(
"",
xy=(13.5, 1.0),
xytext=(0.3, 1.0),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
# Clicks spread evenly
bobs = [1.0, 2.5, 4.0, 5.5, 7.0, 8.5, 10.0]
for x in bobs:
ax.plot(x, 1.0, "ko", markersize=6)
rect = mpatches.FancyBboxPatch(
(0.7, 1.5),
9.6,
1.2,
boxstyle="round,pad=0.1",
facecolor=GRAY1,
edgecolor=LN,
lw=2,
)
ax.add_patch(rect)
ax.text(
5.5,
2.1,
"Sesja 1 (ciągła) — 7 kliknięć, każde < 30 min od poprzedniego",
fontsize=FS,
ha="center",
fontweight="bold",
)
fig.tight_layout(rect=[0, 0, 1, 0.92])
save_fig(fig, "q20_session_users.png")

View File

@ -0,0 +1,467 @@
"""FCN and U-Net architecture diagram generators."""
from __future__ import annotations
from typing import TYPE_CHECKING
from _q23_common import (
ACCENT,
ACCENT_LIGHT,
BLACK,
FS,
FS_SMALL,
FS_TINY,
FS_TITLE,
GRAY1,
GRAY2,
GRAY5,
GREEN_ACCENT,
RED_ACCENT,
_save_figure,
plt,
)
from matplotlib.patches import FancyBboxPatch
if TYPE_CHECKING:
from matplotlib.axes import Axes
def generate_fcn() -> None:
"""Generate fcn."""
_fig, axes = plt.subplots(2, 1, figsize=(10, 7))
# --- Panel 1: FC vs Conv 1x1 ---
ax = axes[0]
ax.set_xlim(0, 20)
ax.set_ylim(0, 6)
ax.axis("off")
ax.set_title(
"FC (Fully Connected) vs Conv 1x1", fontsize=FS_TITLE, fontweight="bold"
)
# Classic CNN with FC
layer_info_fc = [
(1.5, "Obraz\n224x224x3", 2.2, GRAY2),
(4.5, "Conv+Pool\n112x112x64", 1.8, GRAY2),
(7.5, "Conv+Pool\n7x7x512", 1.0, GRAY2),
(10, "Flatten\n25088", 0.5, ACCENT_LIGHT),
(12, "FC\n4096", 0.5, ACCENT_LIGHT),
(14, "FC\n1000", 0.3, ACCENT_LIGHT),
(16, '"Kot"', 0.3, "#FFCDD2"),
]
y_fc = 4.5
for i, (x, label, w, color) in enumerate(layer_info_fc):
rect = FancyBboxPatch(
(x - w / 2, y_fc - 0.6),
w,
1.2,
boxstyle="round,pad=0.05",
facecolor=color,
edgecolor=BLACK,
linewidth=0.8,
)
ax.add_patch(rect)
ax.text(x, y_fc, label, ha="center", va="center", fontsize=FS_TINY)
if i < len(layer_info_fc) - 1:
next_x = layer_info_fc[i + 1][0]
ax.annotate(
"",
xy=(next_x - layer_info_fc[i + 1][2] / 2, y_fc),
xytext=(x + w / 2, y_fc),
arrowprops={"arrowstyle": "->", "color": GRAY5, "lw": 1},
)
ax.text(
0.3, y_fc, "CNN:", fontsize=FS, fontweight="bold", color=RED_ACCENT, va="center"
)
ax.text(
12,
y_fc + 1,
"PROBLEM: FC wymaga\nSTAŁEGO rozmiaru\n(np. 224x224)",
ha="center",
fontsize=FS_SMALL,
color=RED_ACCENT,
fontweight="bold",
bbox={
"boxstyle": "round",
"facecolor": "#FFCDD2",
"edgecolor": RED_ACCENT,
"alpha": 0.3,
},
)
# FCN with Conv 1x1
layer_info_fcn = [
(1.5, "Obraz\nHxWx3", 2.2, GRAY2),
(4.5, "Conv+Pool\nH/2 x W/2\nx64", 1.8, GRAY2),
(7.5, "Conv+Pool\nH/32 x W/32\nx512", 1.0, GRAY2),
(10.5, "Conv 1x1\nH/32 x W/32\nxC", 0.8, "#C8E6C9"),
(13.5, "Upsample\nHxWxC", 1.8, "#C8E6C9"),
(16.5, "Mapa\nsegmentacji", 1.5, "#C8E6C9"),
]
y_fcn = 1.5
for i, (x, label, w, color) in enumerate(layer_info_fcn):
rect = FancyBboxPatch(
(x - w / 2, y_fcn - 0.7),
w,
1.4,
boxstyle="round,pad=0.05",
facecolor=color,
edgecolor=BLACK,
linewidth=0.8,
)
ax.add_patch(rect)
ax.text(x, y_fcn, label, ha="center", va="center", fontsize=FS_TINY)
if i < len(layer_info_fcn) - 1:
next_x = layer_info_fcn[i + 1][0]
ax.annotate(
"",
xy=(next_x - layer_info_fcn[i + 1][2] / 2, y_fcn),
xytext=(x + w / 2, y_fcn),
arrowprops={"arrowstyle": "->", "color": GRAY5, "lw": 1},
)
ax.text(
0.3,
y_fcn,
"FCN:",
fontsize=FS,
fontweight="bold",
color=GREEN_ACCENT,
va="center",
)
ax.text(
10.5,
y_fcn + 1.2,
"Conv 1x1:\nkażdy piksel\nosobno x wagi\n(jak FC ale\nzachowuje HxW)",
ha="center",
fontsize=FS_TINY,
color=GREEN_ACCENT,
bbox={
"boxstyle": "round",
"facecolor": "#C8E6C9",
"edgecolor": GREEN_ACCENT,
"alpha": 0.3,
},
)
# --- Panel 2: What FC and Conv do ---
ax = axes[1]
ax.set_xlim(0, 20)
ax.set_ylim(0, 6)
ax.axis("off")
ax.set_title(
"Co robi warstwa FC? Co robi konwolucja?", fontsize=FS_TITLE, fontweight="bold"
)
# FC explanation
rect = FancyBboxPatch(
(0.3, 3.2),
9,
2.5,
boxstyle="round,pad=0.15",
facecolor=ACCENT_LIGHT,
edgecolor=ACCENT,
linewidth=1,
)
ax.add_patch(rect)
ax.text(
4.8, 5.2, "Fully Connected (FC)", fontsize=FS, fontweight="bold", ha="center"
)
ax.text(
4.8,
4.5,
"KAŻDY neuron połączony z KAŻDYM wejściem\n"
"25 088 wejść x 4 096 neuronów = ~103 MLN wag!\n"
"Traci informację GDZIE (przestrzenną)\n"
"Wymaga STAŁEGO rozmiaru wejścia",
fontsize=FS_TINY,
ha="center",
va="top",
)
# Conv explanation
rect = FancyBboxPatch(
(10.3, 3.2),
9,
2.5,
boxstyle="round,pad=0.15",
facecolor="#C8E6C9",
edgecolor=GREEN_ACCENT,
linewidth=1,
)
ax.add_patch(rect)
ax.text(14.8, 5.2, "Konwolucja (Conv)", fontsize=FS, fontweight="bold", ha="center")
ax.text(
14.8,
4.5,
'Filtr (np. 3x3) „jedzie" po obrazie\n'
"Te same wagi dla KAŻDEJ pozycji\n"
"Zachowuje informację GDZIE\n"
"Akceptuje DOWOLNY rozmiar wejścia",
fontsize=FS_TINY,
ha="center",
va="top",
)
# Conv 1x1 explanation
rect = FancyBboxPatch(
(3, 0.3),
14,
2.2,
boxstyle="round,pad=0.15",
facecolor=GRAY1,
edgecolor=BLACK,
linewidth=1,
)
ax.add_patch(rect)
ax.text(
10,
2.1,
'Conv 1x1 = „FC per piksel"',
fontsize=FS,
fontweight="bold",
ha="center",
)
ax.text(
10,
1.5,
"Filtr 1x1: patrzy na JEDEN piksel, ale WSZYSTKIE kanały (512→C klas)\n"
"Działa jak FC ale zachowuje mapę HxW → każdy piksel osobno klasyfikowany\n"
"FCN: zamień FC na Conv1x1 → koniec z wymogiem stałego rozmiaru!",
fontsize=FS_TINY,
ha="center",
va="top",
)
_save_figure("q23_fc_vs_conv1x1.png")
def generate_unet() -> None:
"""Generate unet."""
_fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.set_xlim(-1, 21)
ax.set_ylim(-1, 12)
ax.axis("off")
ax.set_title(
"U-Net: architektura w kształcie litery U",
fontsize=FS_TITLE + 1,
fontweight="bold",
)
# Encoder layers (going DOWN-LEFT)
encoder_layers = [
(2, 10, 2.5, 1.5, "572x572x1\n(wejście)", 64),
(2, 7.5, 2.2, 1.3, "284x284\nx64", 64),
(2, 5, 1.8, 1.1, "140x140\nx128", 128),
(2, 2.5, 1.5, 1.0, "68x68\nx256", 256),
]
# Bottleneck
bottleneck = (8, 0.5, 2.5, 1.2, "32x32x512\n(bottleneck)", 512)
# Decoder layers (going UP-RIGHT)
decoder_layers = [
(14, 2.5, 1.5, 1.0, "68x68\nx256", 256),
(14, 5, 1.8, 1.1, "140x140\nx128", 128),
(14, 7.5, 2.2, 1.3, "284x284\nx64", 64),
(14, 10, 2.5, 1.5, "572x572xC\n(mapa seg.)", "C"),
]
def draw_block(
ax: Axes,
x: float,
y: float,
w: float,
h: float,
label: str,
color: str,
) -> None:
"""Draw block."""
rect = FancyBboxPatch(
(x - w / 2, y - h / 2),
w,
h,
boxstyle="round,pad=0.05",
facecolor=color,
edgecolor=BLACK,
linewidth=1.2,
)
ax.add_patch(rect)
ax.text(x, y, label, ha="center", va="center", fontsize=FS_TINY)
# Draw encoder
for x, y, w, h, label, _channels in encoder_layers:
draw_block(ax, x, y, w, h, label, ACCENT_LIGHT)
# Draw arrows down (encoder)
for i in range(len(encoder_layers) - 1):
x1, y1 = encoder_layers[i][0], encoder_layers[i][1] - encoder_layers[i][3] / 2
x2, y2 = (
encoder_layers[i + 1][0],
encoder_layers[i + 1][1] + encoder_layers[i + 1][3] / 2,
)
ax.annotate(
"",
xy=(x2, y2),
xytext=(x1, y1),
arrowprops={"arrowstyle": "->", "color": ACCENT, "lw": 2},
)
ax.text(
x1 - 1.7,
(y1 + y2) / 2,
"MaxPool\n2x2\n↓ zmniejsz",
fontsize=FS_TINY,
ha="center",
color=ACCENT,
fontweight="bold",
)
# Encoder to bottleneck
x1, y1 = encoder_layers[-1][0], encoder_layers[-1][1] - encoder_layers[-1][3] / 2
draw_block(
ax,
bottleneck[0],
bottleneck[1],
bottleneck[2],
bottleneck[3],
bottleneck[4],
GRAY2,
)
ax.annotate(
"",
xy=(bottleneck[0] - bottleneck[2] / 2, bottleneck[1] + bottleneck[3] / 2),
xytext=(x1, y1),
arrowprops={"arrowstyle": "->", "color": ACCENT, "lw": 2},
)
# Bottleneck to decoder
ax.annotate(
"",
xy=(
decoder_layers[0][0] - decoder_layers[0][2] / 2,
decoder_layers[0][1] - decoder_layers[0][3] / 2,
),
xytext=(bottleneck[0] + bottleneck[2] / 2, bottleneck[1] + bottleneck[3] / 2),
arrowprops={"arrowstyle": "->", "color": RED_ACCENT, "lw": 2},
)
# Draw decoder
for x, y, w, h, label, channels in decoder_layers:
color = "#C8E6C9" if channels != "C" else "#A5D6A7"
draw_block(ax, x, y, w, h, label, color)
# Draw arrows up (decoder)
for i in range(len(decoder_layers) - 1):
x1, y1 = decoder_layers[i][0], decoder_layers[i][1] + decoder_layers[i][3] / 2
x2, y2 = (
decoder_layers[i + 1][0],
decoder_layers[i + 1][1] - decoder_layers[i + 1][3] / 2,
)
ax.annotate(
"",
xy=(x2, y2),
xytext=(x1, y1),
arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT, "lw": 2},
)
ax.text(
x1 + 2,
(y1 + y2) / 2,
"UpConv\n2x2\n↑ zwiększ",
fontsize=FS_TINY,
ha="center",
color=GREEN_ACCENT,
fontweight="bold",
)
# Skip connections (horizontal arrows)
for i in range(len(encoder_layers)):
enc = encoder_layers[i]
dec = decoder_layers[len(decoder_layers) - 1 - i]
ax.annotate(
"",
xy=(dec[0] - dec[2] / 2, dec[1]),
xytext=(enc[0] + enc[2] / 2, enc[1]),
arrowprops={
"arrowstyle": "->",
"color": GRAY5,
"lw": 1.5,
"linestyle": "dashed",
},
)
mid_x = (enc[0] + enc[2] / 2 + dec[0] - dec[2] / 2) / 2
ax.text(
mid_x,
enc[1] + 0.6,
"skip\n(concat)",
fontsize=FS_TINY,
ha="center",
color=GRAY5,
fontweight="bold",
)
# Labels
ax.text(
0,
11.5,
"ENCODER\n(↓ zmniejsza)",
fontsize=FS,
fontweight="bold",
color=ACCENT,
ha="center",
)
ax.text(
17,
11.5,
"DECODER\n(↑ zwiększa)",
fontsize=FS,
fontweight="bold",
color=GREEN_ACCENT,
ha="center",
)
ax.text(
8,
-0.8,
'Kształt litery „U": encoder schodzi ↓ → bottleneck na dnie → decoder wraca ↑',
fontsize=FS_SMALL,
ha="center",
color=GRAY5,
fontweight="bold",
)
# Concatenation explanation
rect = FancyBboxPatch(
(17.5, 3),
3,
5,
boxstyle="round,pad=0.15",
facecolor=GRAY1,
edgecolor=GRAY5,
linewidth=1,
linestyle="--",
)
ax.add_patch(rect)
ax.text(
19, 7.5, "Concatenation:", fontsize=FS_SMALL, ha="center", fontweight="bold"
)
ax.text(
19,
6.5,
"Encoder: 64 kanały\nDecoder: 64 kanały\n→ concat → 128 kanałów\n\n"
"Jak sklejenie\ndwóch stosów\nkart:",
fontsize=FS_TINY,
ha="center",
)
ax.text(
19,
3.7,
"[enc₁|enc₂|...|dec₁|dec₂|...]",
fontsize=FS_TINY - 1,
ha="center",
fontweight="bold",
color=ACCENT,
)
_save_figure("q23_unet_arch.png")

View File

@ -0,0 +1,96 @@
"""Common utilities and constants for Q23 diagram generation.
A4-compatible, monochrome-friendly (grays + one accent), 300 DPI.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib as mpl
mpl.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
if TYPE_CHECKING:
from matplotlib.axes import Axes
_logger = logging.getLogger(__name__)
rng = np.random.default_rng(42)
DPI = 300
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
# Color palette — monochrome-friendly
BLACK = "#000000"
WHITE = "#FFFFFF"
GRAY1 = "#F5F5F5"
GRAY2 = "#E0E0E0"
GRAY3 = "#BDBDBD"
GRAY4 = "#9E9E9E"
GRAY5 = "#757575"
GRAY6 = "#424242"
ACCENT = "#4A90D9" # single blue accent for highlights
ACCENT_LIGHT = "#B3D4FC"
RED_ACCENT = "#D32F2F"
GREEN_ACCENT = "#388E3C"
FS = 9
FS_TITLE = 11
FS_SMALL = 7
FS_TINY = 6
_RIDGE_X = 5
_VALLEY2_END = 9
_DARK_PIXEL_THRESHOLD = 100
_GRID_LAST_IDX = 3
_HIGHLIGHT_START = 3
_HIGHLIGHT_END = 5
_BRIGHT_THRESHOLD = 170
_OTSU_THRESHOLD = 128
def _save_figure(name: str) -> None:
"""Save current figure and log."""
plt.tight_layout()
plt.savefig(
str(Path(OUTPUT_DIR) / name),
dpi=DPI,
bbox_inches="tight",
facecolor="white",
)
plt.close()
_logger.info("%s", name)
def _render_text_lines(
ax: Axes,
lines: list[tuple[str, int, str, str]],
*,
x_pos: float = 0.5,
start_y: float,
y_step: float = 0.5,
y_empty_step: float = 0.2,
) -> None:
"""Render a list of styled text lines on an axis."""
y = start_y
for txt, size, color, weight in lines:
if txt == "":
y -= y_empty_step
continue
ax.text(
x_pos,
y,
txt,
fontsize=size,
color=color,
fontweight=weight,
va="top",
)
y -= y_step

View File

@ -0,0 +1,251 @@
"""DIY U-Net step-by-step diagram generator."""
from __future__ import annotations
from typing import TYPE_CHECKING
from _q23_common import (
ACCENT,
ACCENT_LIGHT,
BLACK,
FS,
FS_SMALL,
FS_TINY,
GRAY1,
GRAY3,
GRAY5,
GREEN_ACCENT,
WHITE,
_save_figure,
np,
plt,
rng,
)
from matplotlib.patches import FancyBboxPatch
if TYPE_CHECKING:
from matplotlib.axes import Axes
def _draw_unet_layer_stack(
ax: Axes,
layer_sizes: list[tuple[int, int]],
*,
face_color: str,
edge_color: str,
arrow_color: str,
arrow_label: str,
add_skip: bool = False,
) -> None:
"""Draw encoder or decoder layer stack for DIY U-Net."""
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
y_pos = 8.5
for i, (s, c) in enumerate(layer_sizes):
w = s / 64 * 4
h = 0.8
rect = FancyBboxPatch(
(5 - w / 2, y_pos),
w,
h,
boxstyle="round,pad=0.05",
facecolor=face_color,
edgecolor=edge_color,
linewidth=1,
)
ax.add_patch(rect)
label = f"{s}x{s}x{c}"
if add_skip and i < len(layer_sizes) - 1:
label += " + skip!"
ax.text(
5,
y_pos + h / 2,
label,
ha="center",
va="center",
fontsize=FS_SMALL,
fontweight="bold",
)
if i < len(layer_sizes) - 1:
ax.annotate(
"",
xy=(5, y_pos - 0.3),
xytext=(5, y_pos),
arrowprops={
"arrowstyle": "->",
"color": arrow_color,
"lw": 1.5,
},
)
ax.text(
7,
y_pos - 0.15,
arrow_label,
fontsize=FS_TINY,
color=arrow_color,
)
y_pos -= 2.2
def _draw_unet_pseudocode(ax: Axes) -> None:
"""Draw panel 6: U-Net pseudocode."""
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title("Pseudokod U-Net", fontsize=FS, fontweight="bold")
code_lines = [
"# ENCODER",
"e1 = conv_block(input, 64) # 64x64",
"e2 = conv_block(pool(e1), 128) # 32x32",
"e3 = conv_block(pool(e2), 256) # 16x16",
"",
"# BOTTLENECK",
"b = conv_block(pool(e3), 512) # 8x8",
"",
"# DECODER + SKIP",
"d3 = conv_block(concat(",
" upconv(b), e3), 256) # 16x16",
"d2 = conv_block(concat(",
" upconv(d3), e2), 128) # 32x32",
"d1 = conv_block(concat(",
" upconv(d2), e1), 64) # 64x64",
"",
"output = conv_1x1(d1, n_classes)",
]
for i, line in enumerate(code_lines):
txt_color = (
ACCENT
if "concat" in line
else (GREEN_ACCENT if "output" in line else BLACK)
)
ax.text(
0.3,
9.5 - i * 0.55,
line,
fontsize=FS_TINY,
fontfamily="monospace",
color=txt_color,
)
def generate_diy_unet() -> None:
"""Generate diy unet."""
fig, axes = plt.subplots(2, 3, figsize=(11, 7))
size = 64
# Create synthetic image with two regions
img = np.ones((size, size, 3), dtype=np.uint8) * 200 # bright bg
# Dark region (object 1)
yy, xx = np.mgrid[:size, :size]
mask1 = ((xx - 20) ** 2 + (yy - 30) ** 2) < 12**2
img[mask1] = [60, 60, 60]
# Medium region (object 2)
mask2 = ((xx - 45) ** 2 + (yy - 25) ** 2) < 8**2
img[mask2] = [120, 120, 120]
gt = np.zeros((size, size), dtype=np.uint8)
gt[mask1] = 1 # class 1
gt[mask2] = 2 # class 2
# --- Panel 1: Input image ---
ax = axes[0, 0]
ax.imshow(img)
ax.set_title("Krok 1: obraz RGB\n64x64x3", fontsize=FS, fontweight="bold")
ax.axis("off")
# --- Panel 2: Encoder shrinks ---
ax = axes[0, 1]
ax.set_title("Krok 2: Encoder ZMNIEJSZA", fontsize=FS, fontweight="bold")
_draw_unet_layer_stack(
ax,
[(64, 3), (32, 64), (16, 128), (8, 256)],
face_color=ACCENT_LIGHT,
edge_color=ACCENT,
arrow_color=ACCENT,
arrow_label="Conv+Pool",
)
ax.text(
5,
0.3,
"Wyciąga cechy:\nkrawędzie → tekstury → obiekty",
ha="center",
fontsize=FS_TINY,
color=GRAY5,
)
# --- Panel 3: Bottleneck ---
ax = axes[0, 2]
# Show feature maps at bottleneck (abstract)
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title(
"Krok 3: Bottleneck\n(najbardziej abstrakcyjne cechy)",
fontsize=FS,
fontweight="bold",
)
# Show small abstract feature maps
for k in range(4):
small = rng.random((4, 4))
ax_inset = fig.add_axes(
[0.68 + (k % 2) * 0.08, 0.72 - (k // 2) * 0.1, 0.06, 0.06]
)
ax_inset.imshow(small, cmap="gray")
ax_inset.axis("off")
ax.text(
5,
5,
'8x8x256\n\nMałe mapy, ale DUŻO kanałów\nKażdy kanał = jedna „cecha"\n'
'(np. kanał 42 = „wykrył koło"\n kanał 78 = „wykrył krawędź")\n\n'
"Wie CO jest na obrazie\nale nie wie GDZIE dokładnie",
ha="center",
va="center",
fontsize=FS_SMALL,
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3},
)
# --- Panel 4: Decoder enlarges ---
ax = axes[1, 0]
ax.set_title(
"Krok 4: Decoder ZWIĘKSZA\n(+ skip connections!)",
fontsize=FS,
fontweight="bold",
)
_draw_unet_layer_stack(
ax,
[(8, 256), (16, 128), (32, 64), (64, 3)],
face_color="#C8E6C9",
edge_color=GREEN_ACCENT,
arrow_color=GREEN_ACCENT,
arrow_label="UpConv+Concat",
add_skip=True,
)
ax.text(
5,
0.3,
"Odtwarza rozdzielczość:\nskip → przywraca krawędzie",
ha="center",
fontsize=FS_TINY,
color=GRAY5,
)
# --- Panel 5: Output segmentation map ---
ax = axes[1, 1]
cmap = plt.cm.colors.ListedColormap([WHITE, ACCENT_LIGHT, "#FFCDD2"])
ax.imshow(gt, cmap=cmap, interpolation="nearest")
ax.set_title(
"Krok 5: mapa segmentacji\n64x64 (3 klasy)", fontsize=FS, fontweight="bold"
)
ax.axis("off")
ax.text(20, -3, "Tło=0, obiekt A=1, obiekt B=2", fontsize=FS_TINY, ha="center")
# --- Panel 6: Summary pseudocode ---
_draw_unet_pseudocode(axes[1, 2])
_save_figure("q23_diy_unet.png")

View File

@ -0,0 +1,380 @@
"""Mean shift and normalized cuts diagram generators."""
from __future__ import annotations
from typing import TYPE_CHECKING
from _q23_common import (
_DARK_PIXEL_THRESHOLD,
_GRID_LAST_IDX,
ACCENT,
ACCENT_LIGHT,
BLACK,
FS,
FS_SMALL,
FS_TINY,
FS_TITLE,
GRAY1,
GRAY3,
GRAY4,
GRAY5,
GRAY6,
GREEN_ACCENT,
RED_ACCENT,
_render_text_lines,
_save_figure,
np,
plt,
rng,
)
from matplotlib import patches
if TYPE_CHECKING:
from matplotlib.axes import Axes
def generate_mean_shift() -> None:
"""Generate mean shift."""
_fig, axes = plt.subplots(1, 3, figsize=(11, 4))
# --- Panel 1: Feature space concept ---
ax = axes[0]
# Three clusters in 2D feature space (brightness, x-position)
c1x = rng.normal(2, 0.5, 40)
c1y = rng.normal(2, 0.5, 40)
c2x = rng.normal(6, 0.6, 35)
c2y = rng.normal(7, 0.5, 35)
c3x = rng.normal(8, 0.4, 25)
c3y = rng.normal(3, 0.6, 25)
ax.scatter(c1x, c1y, c=GRAY4, s=15, alpha=0.7, zorder=3)
ax.scatter(c2x, c2y, c=GRAY4, s=15, alpha=0.7, zorder=3)
ax.scatter(c3x, c3y, c=GRAY4, s=15, alpha=0.7, zorder=3)
# Label peaks
ax.scatter([2], [2], c=RED_ACCENT, s=80, marker="*", zorder=5, label="Max gęstości")
ax.scatter([6], [7], c=RED_ACCENT, s=80, marker="*", zorder=5)
ax.scatter([8], [3], c=RED_ACCENT, s=80, marker="*", zorder=5)
ax.set_xlabel("Cecha 1: jasność", fontsize=FS)
ax.set_ylabel("Cecha 2: pozycja x", fontsize=FS)
ax.set_title("Przestrzeń cech", fontsize=FS_TITLE, fontweight="bold")
for lx, ly, ltxt in [
(2, 0.3, "Klaster 1\n(ciemne, lewo)"),
(6, 5.3, "Klaster 2\n(jasne, prawo)"),
(8, 1.3, "Klaster 3\n(jasne, dół)"),
]:
ax.text(lx, ly, ltxt, ha="center", fontsize=FS_TINY, color=GRAY6)
ax.legend(fontsize=FS_SMALL, loc="upper left")
# --- Panel 2: Kernel/window moving ---
ax = axes[1]
ax.scatter(c1x, c1y, c=ACCENT_LIGHT, s=15, alpha=0.7, zorder=3)
ax.scatter(c2x, c2y, c=GRAY3, s=15, alpha=0.7, zorder=3)
ax.scatter(c3x, c3y, c=GRAY3, s=15, alpha=0.7, zorder=3)
# Show kernel movement
path_x = [4.5, 3.8, 3.0, 2.3, 2.05]
path_y = [4.0, 3.3, 2.7, 2.2, 2.03]
for i, (px, py) in enumerate(zip(path_x, path_y, strict=False)):
alpha = 0.3 + 0.15 * i
circle = plt.Circle(
(px, py),
1.2,
fill=False,
edgecolor=ACCENT,
linewidth=1.5,
linestyle="--" if i < len(path_x) - 1 else "-",
alpha=alpha,
)
ax.add_patch(circle)
if i < len(path_x) - 1:
ax.annotate(
"",
xy=(path_x[i + 1], path_y[i + 1]),
xytext=(px, py),
arrowprops={"arrowstyle": "->", "color": RED_ACCENT, "lw": 1.5},
)
ax.scatter([path_x[0]], [path_y[0]], c=ACCENT, s=50, marker="o", zorder=5)
ax.scatter([path_x[-1]], [path_y[-1]], c=RED_ACCENT, s=80, marker="*", zorder=5)
ax.text(
4.5, 5.2, "Start: losowy\npiksel", fontsize=FS_SMALL, ha="center", color=ACCENT
)
ax.text(
2.05,
0.5,
"Koniec: max\ngęstości",
fontsize=FS_SMALL,
ha="center",
color=RED_ACCENT,
fontweight="bold",
)
ax.text(
7,
8,
"Okno (jądro)\nprzesuwa się\ndo skupiska",
fontsize=FS_SMALL,
ha="center",
color=GRAY6,
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3},
)
ax.set_xlabel("Cecha 1", fontsize=FS)
ax.set_ylabel("Cecha 2", fontsize=FS)
ax.set_title("Jądro → max gęstości", fontsize=FS_TITLE, fontweight="bold")
ax.set_xlim(0, 10)
ax.set_ylim(0, 9)
# --- Panel 3: Why no K parameter ---
ax = axes[2]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title("Dlaczego bez K?", fontsize=FS_TITLE, fontweight="bold")
lines = [
("K-means wymaga:", FS, RED_ACCENT, "bold"),
(' „Podaj K=3 klastry"', FS_SMALL, "black", "normal"),
(" Problem: skąd wiesz ile klastrów?", FS_SMALL, GRAY5, "normal"),
("", 0, "", ""),
("Mean Shift NIE wymaga K:", FS, GREEN_ACCENT, "bold"),
(" Każdy piksel startuje → toczy się", FS_SMALL, "black", "normal"),
(" → trafia do najbliższego szczytu", FS_SMALL, "black", "normal"),
(" → ile szczytów = tyle segmentów", FS_SMALL, "black", "normal"),
(" → automatycznie!", FS_SMALL, GREEN_ACCENT, "bold"),
("", 0, "", ""),
("Parametr: bandwidth (szerokość okna)", FS, "black", "bold"),
(" Duże okno → mało segmentów", FS_SMALL, "black", "normal"),
(" Małe okno → dużo segmentów", FS_SMALL, "black", "normal"),
("", 0, "", ""),
("Okno = jądro (kernel):", FS, "black", "bold"),
(" Koło o promieniu h wokół punktu.", FS_SMALL, "black", "normal"),
(" Oblicz średnią pikseli W oknie.", FS_SMALL, "black", "normal"),
(" Przesuń okno na tę średnią.", FS_SMALL, "black", "normal"),
(" Powtórz aż się zatrzyma.", FS_SMALL, "black", "normal"),
]
_render_text_lines(ax, lines, start_y=9.0)
_save_figure("q23_mean_shift.png")
def _draw_ncuts_pixel_grid(
ax: Axes,
pixel_vals: np.ndarray,
) -> None:
"""Draw 4x4 pixel grid with value labels and edge weights."""
for i in range(4):
for j in range(4):
v = pixel_vals[i, j]
gray_val = v / 255.0
str(gray_val)
rect = patches.Rectangle(
(j - 0.4, 3 - i - 0.4),
0.8,
0.8,
facecolor=(gray_val, gray_val, gray_val),
edgecolor=BLACK,
linewidth=0.8,
)
ax.add_patch(rect)
text_color = "white" if v < _DARK_PIXEL_THRESHOLD else "black"
ax.text(
j,
3 - i,
str(v),
ha="center",
va="center",
fontsize=FS_SMALL,
color=text_color,
fontweight="bold",
)
def _draw_ncuts_edges(
ax: Axes,
pixel_vals: np.ndarray,
) -> None:
"""Draw weighted edges between adjacent pixels."""
for i in range(4):
for j in range(4):
if j < _GRID_LAST_IDX:
similarity = max(
0,
1 - abs(pixel_vals[i, j] - pixel_vals[i, j + 1]) / 255,
)
lw = similarity * 2.5 + 0.3
alpha = similarity * 0.8 + 0.2
ax.plot(
[j + 0.4, j + 0.6],
[3 - i, 3 - i],
color=GRAY5,
linewidth=lw,
alpha=alpha,
)
if i < _GRID_LAST_IDX:
similarity = max(
0,
1 - abs(pixel_vals[i, j] - pixel_vals[i + 1, j]) / 255,
)
lw = similarity * 2.5 + 0.3
alpha = similarity * 0.8 + 0.2
ax.plot(
[j, j],
[3 - i - 0.4, 3 - i - 0.6],
color=GRAY5,
linewidth=lw,
alpha=alpha,
)
def generate_normalized_cuts() -> None:
"""Generate normalized cuts."""
_fig, axes = plt.subplots(1, 3, figsize=(11, 4))
# --- Panel 1: Image as graph ---
ax = axes[0]
ax.set_xlim(-0.5, 4.5)
ax.set_ylim(-0.5, 4.5)
ax.set_aspect("equal")
ax.set_title("Obraz → graf", fontsize=FS_TITLE, fontweight="bold")
pixel_vals = np.array(
[
[30, 35, 180, 190],
[40, 30, 185, 200],
[170, 180, 40, 35],
[190, 175, 30, 45],
]
)
_draw_ncuts_pixel_grid(ax, pixel_vals)
_draw_ncuts_edges(ax, pixel_vals)
ax.text(
2,
-0.8,
"Grube linie = duże podobieństwo\n(silna krawędź grafu)",
ha="center",
fontsize=FS_TINY,
color=GRAY5,
)
ax.axis("off")
# --- Panel 2: Cut concept ---
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title("Cięcie grafu (graph cut)", fontsize=FS_TITLE, fontweight="bold")
# Draw two groups of nodes
# Group A (dark pixels)
positions_a = [(2, 7), (3, 8), (2, 5), (3, 6)]
positions_b = [(7, 7), (8, 8), (7, 5), (8, 6)]
# Intra-group edges (thick = similar)
for i, (x1, y1) in enumerate(positions_a):
for x2, y2 in positions_a[i + 1 :]:
ax.plot([x1, x2], [y1, y2], color=ACCENT, linewidth=2, alpha=0.5)
for i, (x1, y1) in enumerate(positions_b):
for x2, y2 in positions_b[i + 1 :]:
ax.plot([x1, x2], [y1, y2], color=RED_ACCENT, linewidth=2, alpha=0.5)
# Inter-group edges (thin = dissimilar) — these get cut
cut_edges = [((3, 8), (7, 7)), ((3, 6), (7, 5)), ((2, 5), (7, 5))]
for (x1, y1), (x2, y2) in cut_edges:
ax.plot([x1, x2], [y1, y2], color=GRAY4, linewidth=0.8, linestyle="--")
# Draw nodes
for x, y in positions_a:
ax.scatter(x, y, c=ACCENT, s=120, zorder=5, edgecolors=BLACK, linewidth=0.8)
for x, y in positions_b:
ax.scatter(x, y, c="#FFCDD2", s=120, zorder=5, edgecolors=BLACK, linewidth=0.8)
# Cut line
ax.plot(
[5, 5], [3.5, 9.5], color=RED_ACCENT, linewidth=2.5, linestyle="-", zorder=4
)
ax.text(
5, 9.8, "CIĘCIE", ha="center", fontsize=FS, fontweight="bold", color=RED_ACCENT
)
ax.text(
2.5,
3.8,
"Segment A\n(ciemne piksele)",
ha="center",
fontsize=FS_SMALL,
color=ACCENT,
)
ax.text(
7.5,
3.8,
"Segment B\n(jasne piksele)",
ha="center",
fontsize=FS_SMALL,
color=RED_ACCENT,
)
# Formula
ax.text(
5,
1.8,
"Ncut(A,B) = cut(A,B)/assoc(A,V)\n + cut(A,B)/assoc(B,V)",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3},
)
ax.text(
5,
0.5,
"Minimalizuj Ncut → tnij SŁABE krawędzie\nzachowuj SILNE (wewnątrz grupy)",
ha="center",
fontsize=FS_TINY,
color=GRAY5,
)
# --- Panel 3: Algorithm summary ---
ax = axes[2]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title("Algorytm Normalized Cuts", fontsize=FS_TITLE, fontweight="bold")
steps = [
(
"1. Zbuduj graf",
"Piksele = węzły\nKrawędzie = podobieństwo"
" sąsiadów\n(kolor, jasność, odległość)",
),
(
"2. Macierz podobieństwa W",
"W[i,j] = exp(-|kolori - kolorj|² / σ²)"
"\n→ im podobniejsze, tym wyższa waga",
),
("3. Macierz stopni D", "D[i,i] = Σ W[i,j]\n(suma wszystkich wag z węzła i)"),
("4. Rozwiąż problem własny", "(D-W)·y = λ·D·y\n→ drugi najm. wektor własny y"),
("5. Podziel wg y", "y[i] > 0 → segment A\ny[i] ≤ 0 → segment B"),
]
y = 9.5
for title, desc in steps:
ax.text(0.5, y, title, fontsize=FS, fontweight="bold", va="top")
y -= 0.4
ax.text(0.8, y, desc, fontsize=FS_TINY, va="top", color=GRAY6)
y -= 1.2
ax.text(
5,
0.3,
"Złożoność: O(n³) — wymaga eigen decomposition!",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
color=RED_ACCENT,
)
_save_figure("q23_normalized_cuts.png")

View File

@ -0,0 +1,327 @@
"""Mnemonic summary diagram generator."""
from __future__ import annotations
from typing import TYPE_CHECKING
from _q23_common import (
ACCENT,
ACCENT_LIGHT,
BLACK,
FS,
FS_SMALL,
FS_TINY,
FS_TITLE,
GRAY1,
GRAY5,
GRAY6,
GREEN_ACCENT,
RED_ACCENT,
_save_figure,
plt,
)
from matplotlib.patches import FancyBboxPatch
if TYPE_CHECKING:
from matplotlib.axes import Axes
def generate_mnemonics() -> None:
"""Generate mnemonics."""
_fig, ax = plt.subplots(1, 1, figsize=(10, 8))
ax.set_xlim(0, 20)
ax.set_ylim(0, 16)
ax.axis("off")
ax.set_title(
"Mnemoniki — segmentacja obrazu", fontsize=FS_TITLE + 2, fontweight="bold"
)
def draw_card(
ax: Axes,
x: float,
y: float,
w: float,
h: float,
title: str,
mnemonic: str,
color: str,
detail: str = "",
) -> None:
"""Draw card."""
rect = FancyBboxPatch(
(x, y),
w,
h,
boxstyle="round,pad=0.15",
facecolor=color,
edgecolor=BLACK,
linewidth=1,
)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + h - 0.3,
title,
ha="center",
va="top",
fontsize=FS,
fontweight="bold",
)
ax.text(
x + w / 2,
y + h / 2 - 0.1,
mnemonic,
ha="center",
va="center",
fontsize=FS_SMALL,
fontstyle="italic",
color=GRAY6,
)
if detail:
ax.text(
x + w / 2,
y + 0.4,
detail,
ha="center",
va="bottom",
fontsize=FS_TINY,
color=GRAY5,
)
# Title: STRATEGIE KLASYCZNE
ax.text(
5,
15.5,
"STRATEGIE KLASYCZNE",
fontsize=FS_TITLE,
fontweight="bold",
color=ACCENT,
ha="center",
)
cards_classic = [
(
0.2,
12.5,
4.5,
2.5,
"Thresholding",
'„PRÓG na bramce"\nPrzepuszcza > T,\nblokuje ≤ T',
ACCENT_LIGHT,
"jasne=1, ciemne=0",
),
(
5,
12.5,
4.5,
2.5,
"Otsu",
'„AUTO-bramkarz"\nSam dobiera próg\nmin σ² wewnątrz',
ACCENT_LIGHT,
"histogram bimodalny",
),
(
0.2,
9.5,
4.5,
2.5,
"Region Growing",
'„PLAMA rozlana"\nSeed → BFS po\npodobnych sąsiadach',
ACCENT_LIGHT,
"jak atrament na papierze",
),
(
5,
9.5,
4.5,
2.5,
"Watershed",
'„ZALEWANIE terenu"\nDoliny=obiekty\nGranie=granice',
ACCENT_LIGHT,
"woda + geography",
),
(
0.2,
6.5,
4.5,
2.5,
"Mean Shift",
'„KULKI toczą się"\nKażda → max gęstości\nBez K!',
ACCENT_LIGHT,
"bandwidth = okno",
),
(
5,
6.5,
4.5,
2.5,
"Normalized Cuts",
'„CIĘCIE sznurków"\nGraf: tnij słabe\nkrawędzie (O(n³)!)',
ACCENT_LIGHT,
"eigenvector problem",
),
]
for args in cards_classic:
draw_card(ax, *args)
# Title: SIECI NEURONOWE
ax.text(
15,
15.5,
"SIECI NEURONOWE",
fontsize=FS_TITLE,
fontweight="bold",
color=GREEN_ACCENT,
ha="center",
)
cards_nn = [
(
10.5,
12.5,
4.5,
2.5,
"FCN (2015)",
'„FC → Conv 1x1"\nPierwsza end-to-end\nDowolny rozmiar',
"#C8E6C9",
"skip connections",
),
(
15.3,
12.5,
4.5,
2.5,
"U-Net (2015)",
'„Litera U"\nEncoder↓ Decoder↑\nSkip = concat',
"#C8E6C9",
"medycyna, małe dane",
),
(
10.5,
9.5,
4.5,
2.5,
"DeepLab v3+",
'„DZIURY w filtrze"\nAtrous conv (rate)\nASPP multi-scale',
"#C8E6C9",
"à trous = z dziurami",
),
(
15.3,
9.5,
4.5,
2.5,
"Transformer",
'„WSZYSCY ze\nWSZYSTKIMI"\nSelf-attention O(n²)',
"#C8E6C9",
"SegFormer, Mask2Former",
),
]
for args in cards_nn:
draw_card(ax, *args)
# Metryki
ax.text(
10,
8.3,
"METRYKI I LOSS",
fontsize=FS_TITLE,
fontweight="bold",
color=RED_ACCENT,
ha="center",
)
cards_metrics = [
(
10.5,
6.5,
4.5,
1.6,
"mIoU",
'„Nakładka / Suma"\nIoU = A∩B / A\u222aB',
"#FFCDD2",
"",
),
(
15.3,
6.5,
4.5,
1.6,
"Dice / Focal",
'„Dice=2·nakładka"\nFocal=trudne px',
"#FFCDD2",
"",
),
]
for args in cards_metrics:
draw_card(ax, *args)
# Master mnemonic at bottom
rect = FancyBboxPatch(
(1, 0.3),
18,
5.5,
boxstyle="round,pad=0.2",
facecolor=GRAY1,
edgecolor=BLACK,
linewidth=1.5,
)
ax.add_patch(rect)
ax.text(
10,
5.3,
"SUPER-MNEMONIK: kolejność algorytmów segmentacji",
ha="center",
fontsize=FS,
fontweight="bold",
)
ax.text(
10,
4.5,
'„TORW-MN FUD-T"',
ha="center",
fontsize=FS_TITLE + 2,
fontweight="bold",
color=RED_ACCENT,
)
ax.text(
10,
3.5,
"Klasyczne: Thresholding → Otsu → Region"
" growing → Watershed → Mean shift → Norm. cuts",
ha="center",
fontsize=FS_SMALL,
)
ax.text(
10,
2.8,
"Neuronowe: FCN → U-Net → DeepLab → Transformer",
ha="center",
fontsize=FS_SMALL,
)
ax.text(
10,
1.8,
"„Turyści Oglądają Rzekę, Wodospad,"
" Morze, Nurt — Fotografują Uroczy"
' Dwór Tajemnic"',
ha="center",
fontsize=FS_SMALL,
fontstyle="italic",
color=ACCENT,
)
ax.text(
10,
1.0,
"Klasyczne: proste→auto→BFS→flood→"
"gęstość→graf | Neuronowe:"
" FC→U-skip→dilated→attention",
ha="center",
fontsize=FS_TINY,
color=GRAY5,
)
_save_figure("q23_mnemonics.png")

View File

@ -0,0 +1,293 @@
"""ReLU and dot product diagram generators."""
from __future__ import annotations
from _q23_common import (
ACCENT,
ACCENT_LIGHT,
BLACK,
FS,
FS_SMALL,
FS_TINY,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY5,
GREEN_ACCENT,
RED_ACCENT,
_save_figure,
np,
plt,
)
from matplotlib import patches
def generate_relu() -> None:
"""Generate relu."""
_fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))
# --- Panel 1: ReLU plot ---
ax = axes[0]
x = np.linspace(-5, 5, 200)
relu = np.maximum(0, x)
ax.plot(x, relu, color=ACCENT, linewidth=2.5, label="ReLU(x) = max(0, x)")
ax.axhline(y=0, color=GRAY3, linewidth=0.5)
ax.axvline(x=0, color=GRAY3, linewidth=0.5)
ax.fill_between(x[x < 0], 0, 0, color=RED_ACCENT, alpha=0.1)
ax.fill_between(x[x >= 0], 0, relu[x >= 0], color=ACCENT, alpha=0.1)
# Annotations
ax.annotate(
'x < 0 → output = 0\n(neuron „wyłączony")',
xy=(-3, 0),
fontsize=FS_SMALL,
ha="center",
va="bottom",
color=RED_ACCENT,
arrowprops={"arrowstyle": "->", "color": RED_ACCENT},
xytext=(-3, 2),
)
ax.annotate(
'x ≥ 0 → output = x\n(neuron „włączony")',
xy=(3, 3),
fontsize=FS_SMALL,
ha="center",
va="bottom",
color=ACCENT,
arrowprops={"arrowstyle": "->", "color": ACCENT},
xytext=(3, 4.5),
)
ax.scatter([0], [0], c=BLACK, s=40, zorder=5)
ax.text(0.3, -0.5, "(0,0)", fontsize=FS_SMALL, color=GRAY5)
ax.set_xlabel("x (wejście neuronu)", fontsize=FS)
ax.set_ylabel("ReLU(x)", fontsize=FS)
ax.set_title("ReLU — Rectified Linear Unit", fontsize=FS_TITLE, fontweight="bold")
ax.legend(fontsize=FS_SMALL, loc="upper left")
ax.set_ylim(-1, 6)
ax.grid(visible=True, alpha=0.2)
# --- Panel 2: Why ReLU ---
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title("Dlaczego ReLU?", fontsize=FS_TITLE, fontweight="bold")
y = 9.0
lines = [
("Neuron oblicza:", FS, BLACK, "bold"),
(" z = w₁·x₁ + w₂·x₂ + ... + bias", FS_SMALL, BLACK, "normal"),
(" output = ReLU(z) = max(0, z)", FS_SMALL, ACCENT, "bold"),
("", 0, "", ""),
("Przykład:", FS, BLACK, "bold"),
(" wagi: w₁=0.5, w₂=-0.3, bias=0.1", FS_SMALL, BLACK, "normal"),
(" wejścia: x₁=2.0, x₂=4.0", FS_SMALL, BLACK, "normal"),
(" z = 0.5·2 + (-0.3)·4 + 0.1 = -0.1", FS_SMALL, BLACK, "normal"),
(" ReLU(-0.1) = max(0, -0.1) = 0", FS_SMALL, RED_ACCENT, "bold"),
(" → neuron milczy (wejście nieistotne)", FS_SMALL, GRAY5, "normal"),
("", 0, "", ""),
("Gdyby z = 2.3:", FS, BLACK, "bold"),
(" ReLU(2.3) = max(0, 2.3) = 2.3", FS_SMALL, GREEN_ACCENT, "bold"),
(" → neuron aktywny! Przekazuje sygnał", FS_SMALL, GRAY5, "normal"),
("", 0, "", ""),
("Szybsza niż sigmoid/tanh", FS_SMALL, GRAY5, "normal"),
("(brak exp() → szybkie obliczenia)", FS_SMALL, GRAY5, "normal"),
]
for txt, size, color, weight in lines:
if txt == "":
y -= 0.2
continue
ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
y -= 0.5
_save_figure("q23_relu.png")
def generate_dot_product() -> None:
"""Generate dot product."""
_fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
# --- Panel 1: Concept ---
ax = axes[0]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title(
"Iloczyn skalarny\n(dot product)", fontsize=FS_TITLE, fontweight="bold"
)
y = 8.5
lines = [
("Dwa wektory (listy liczb) → JEDNA liczba", FS, BLACK, "bold"),
("", 0, "", ""),
("a = [a₁, a₂, a₃] b = [b₁, b₂, b₃]", FS, ACCENT, "normal"),
("", 0, "", ""),
("a · b = a₁·b₁ + a₂·b₂ + a₃·b₃", FS, BLACK, "bold"),
("", 0, "", ""),
("Przykład:", FS, BLACK, "bold"),
("a = [1, 3, -2] b = [4, -1, 5]", FS_SMALL, BLACK, "normal"),
("a·b = 1·4 + 3·(-1) + (-2)·5", FS_SMALL, BLACK, "normal"),
(" = 4 + (-3) + (-10) = -9", FS_SMALL, RED_ACCENT, "bold"),
("", 0, "", ""),
(
'Duży wynik → wektory „podobne" (w tym samym kierunku)',
FS_SMALL,
GREEN_ACCENT,
"normal",
),
('Mały/ujemny → wektory „różne"', FS_SMALL, RED_ACCENT, "normal"),
]
for txt, size, color, weight in lines:
if txt == "":
y -= 0.25
continue
ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
y -= 0.55
# --- Panel 2: Convolution as dot product ---
ax = axes[1]
ax.set_xlim(-0.5, 5.5)
ax.set_ylim(-0.5, 5.5)
ax.set_aspect("equal")
ax.set_title(
"Konwolucja = iloczyn skalarny\nfiltra x fragment obrazu",
fontsize=FS_TITLE,
fontweight="bold",
)
# Filter 3x3
filter_vals = [[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]
for i in range(3):
for j in range(3):
rect = patches.Rectangle(
(j - 0.4, 4 - i - 0.4),
0.8,
0.8,
facecolor=ACCENT_LIGHT,
edgecolor=BLACK,
linewidth=0.8,
)
ax.add_patch(rect)
ax.text(
j,
4 - i,
str(filter_vals[i][j]),
ha="center",
va="center",
fontsize=FS,
fontweight="bold",
)
ax.text(1, 1.5, "Filtr", ha="center", fontsize=FS, fontweight="bold", color=ACCENT)
# Image patch
img_vals = [[50, 50, 200], [50, 50, 200], [50, 50, 200]]
for i in range(3):
for j in range(3):
rect = patches.Rectangle(
(j + 2.6, 4 - i - 0.4),
0.8,
0.8,
facecolor=GRAY2,
edgecolor=BLACK,
linewidth=0.8,
)
ax.add_patch(rect)
ax.text(
j + 3,
4 - i,
str(img_vals[i][j]),
ha="center",
va="center",
fontsize=FS,
fontweight="bold",
)
ax.text(
4,
1.5,
"Fragment\nobrazu",
ha="center",
fontsize=FS,
fontweight="bold",
color=GRAY5,
)
ax.text(
2.5,
0.5,
"(-1)·50 + 0·50 + 1·200 +\n"
"(-1)·50 + 0·50 + 1·200 +\n"
"(-1)·50 + 0·50 + 1·200\n= 450 (krawędź!)",
ha="center",
fontsize=FS_TINY,
fontweight="bold",
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GREEN_ACCENT},
)
ax.axis("off")
# --- Panel 3: Vector visualization ---
ax = axes[2]
# Draw two vectors
ax.quiver(
0,
0,
3,
4,
angles="xy",
scale_units="xy",
scale=1,
color=ACCENT,
width=0.025,
label="a = [3, 4]",
)
ax.quiver(
0,
0,
4,
1,
angles="xy",
scale_units="xy",
scale=1,
color=RED_ACCENT,
width=0.025,
label="b = [4, 1]",
)
# Show angle
theta = np.linspace(np.arctan2(1, 4), np.arctan2(4, 3), 30)
r = 1.5
ax.plot(r * np.cos(theta), r * np.sin(theta), color=GREEN_ACCENT, linewidth=1.5)
ax.text(1.8, 1.3, "θ", fontsize=FS, color=GREEN_ACCENT, fontweight="bold")
ax.text(3.2, 4.2, "a", fontsize=FS, color=ACCENT, fontweight="bold")
ax.text(4.2, 1.2, "b", fontsize=FS, color=RED_ACCENT, fontweight="bold")
ax.text(
2.5,
-1.0,
"a · b = |a|·|b|·cos(θ)\n= 3·4 + 4·1 = 16",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3},
)
ax.text(
2.5,
-2.0,
'Mały kąt θ → duży dot product\n= wektory „zgadają się"',
ha="center",
fontsize=FS_TINY,
color=GRAY5,
)
ax.set_xlim(-0.5, 5.5)
ax.set_ylim(-2.5, 5.5)
ax.set_aspect("equal")
ax.grid(visible=True, alpha=0.2)
ax.legend(fontsize=FS_SMALL, loc="upper left")
ax.set_title("Geometrycznie: kąt", fontsize=FS_TITLE, fontweight="bold")
_save_figure("q23_dot_product.png")

View File

@ -0,0 +1,408 @@
"""Otsu thresholding and watershed diagram generators."""
from __future__ import annotations
from typing import TYPE_CHECKING
from _q23_common import (
_RIDGE_X,
_VALLEY2_END,
ACCENT,
ACCENT_LIGHT,
BLACK,
FS,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
GREEN_ACCENT,
RED_ACCENT,
_render_text_lines,
_save_figure,
np,
plt,
rng,
)
from matplotlib.patches import FancyBboxPatch
if TYPE_CHECKING:
from matplotlib.axes import Axes
def _draw_otsu_variance_panel(ax: Axes) -> None:
"""Draw panel 2: within-class variance explanation."""
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title("Wariancja wewnątrzklasowa", fontsize=FS_TITLE, fontweight="bold")
texts = [
(
"Wariancja = jak bardzo wartości\nróżnią się od średniej",
FS,
"black",
"normal",
),
("", 0, "black", "normal"),
("Klasa 0 (piksele ≤ T):", FS, ACCENT, "bold"),
(" wartości: 30, 50, 45, 60, 55", FS_SMALL, "black", "normal"),
(" średnia μ₀ = 48", FS_SMALL, "black", "normal"),
(" σ₀² = ((30-48)²+(50-48)²+...)/5 = 108", FS_SMALL, "black", "normal"),
("", 0, "black", "normal"),
("Klasa 1 (piksele > T):", FS, RED_ACCENT, "bold"),
(" wartości: 180, 200, 190, 210, 195", FS_SMALL, "black", "normal"),
(" średnia μ₁ = 195", FS_SMALL, "black", "normal"),
(" σ₁² = ((180-195)²+...)/5 = 100", FS_SMALL, "black", "normal"),
("", 0, "black", "normal"),
("σ²_wewnątrz = w₀·σ₀² + w₁·σ₁²", FS, BLACK, "bold"),
("= 0.6·108 + 0.4·100 = 104.8", FS_SMALL, "black", "normal"),
("", 0, "black", "normal"),
("Otsu próbuje KAŻDE T: 0,1,...,255", FS_SMALL, GREEN_ACCENT, "bold"),
("Wybiera T dające MINIMUM σ²_wewnątrz", FS_SMALL, GREEN_ACCENT, "bold"),
]
_render_text_lines(
ax,
texts,
x_pos=0.3,
start_y=9.2,
y_step=0.55,
y_empty_step=0.25,
)
def generate_otsu_bimodal() -> None:
"""Generate otsu bimodal."""
_fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
# --- Panel 1: Bimodal histogram ---
ax = axes[0]
dark = rng.normal(60, 20, 3000).clip(0, 255)
bright = rng.normal(190, 25, 2000).clip(0, 255)
all_pixels = np.concatenate([dark, bright])
counts, _bins, _bars = ax.hist(
all_pixels, bins=64, color=GRAY3, edgecolor=GRAY5, linewidth=0.5
)
ax.axvline(
x=128, color=RED_ACCENT, linewidth=2, linestyle="--", label="Próg Otsu T=128"
)
ax.fill_betweenx([0, max(counts) * 1.1], 0, 128, alpha=0.12, color=ACCENT)
ax.fill_betweenx([0, max(counts) * 1.1], 128, 255, alpha=0.12, color=RED_ACCENT)
ax.text(
45,
max(counts) * 0.85,
"Klasa 0\n(tło)",
ha="center",
fontsize=FS,
fontweight="bold",
color=ACCENT,
)
ax.text(
195,
max(counts) * 0.85,
"Klasa 1\n(obiekt)",
ha="center",
fontsize=FS,
fontweight="bold",
color=RED_ACCENT,
)
ax.annotate(
"Garb 1",
xy=(60, max(counts) * 0.6),
fontsize=FS_SMALL,
ha="center",
arrowprops={"arrowstyle": "->", "color": GRAY5},
xytext=(30, max(counts) * 0.45),
)
ax.annotate(
"Garb 2",
xy=(190, max(counts) * 0.5),
fontsize=FS_SMALL,
ha="center",
arrowprops={"arrowstyle": "->", "color": GRAY5},
xytext=(220, max(counts) * 0.35),
)
ax.set_xlabel("Jasność piksela (0-255)", fontsize=FS)
ax.set_ylabel("Liczba pikseli", fontsize=FS)
ax.set_title("Histogram bimodalny", fontsize=FS_TITLE, fontweight="bold")
ax.legend(fontsize=FS_SMALL, loc="upper right")
ax.set_xlim(0, 255)
# --- Panel 2: Within-class variance explanation ---
_draw_otsu_variance_panel(axes[1])
# --- Panel 3: Jednorodność explanation ---
ax = axes[2]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title('"Jednorodne" = małe σ²', fontsize=FS_TITLE, fontweight="bold")
# Draw two clusters
# Good separation
c0 = rng.normal(2, 0.4, 15)
c1 = rng.normal(7, 0.4, 15)
y_pos_0 = rng.uniform(6, 8, 15)
y_pos_1 = rng.uniform(6, 8, 15)
ax.scatter(c0, y_pos_0, c=ACCENT, s=30, zorder=5, label="Klasa 0")
ax.scatter(c1, y_pos_1, c=RED_ACCENT, s=30, zorder=5, label="Klasa 1")
ax.axvline(x=4.5, color=GREEN_ACCENT, linewidth=2, linestyle="--")
ax.text(
4.5,
8.8,
"T optymalny",
ha="center",
fontsize=FS_SMALL,
color=GREEN_ACCENT,
fontweight="bold",
)
ax.text(
2, 5.3, "σ₀² mała\n(skupione)", ha="center", fontsize=FS_SMALL, color=ACCENT
)
ax.text(
7, 5.3, "σ₁² mała\n(skupione)", ha="center", fontsize=FS_SMALL, color=RED_ACCENT
)
ax.text(
5,
4,
"σ²_wewnątrz MINIMALNA\n→ klasy JEDNORODNE\n→ dobra segmentacja!",
ha="center",
fontsize=FS,
fontweight="bold",
color=GREEN_ACCENT,
)
# Bad separation
c0b = rng.normal(3.5, 1.5, 15)
c1b = rng.normal(6, 1.5, 15)
y_pos_0b = rng.uniform(1, 3, 15)
y_pos_1b = rng.uniform(1, 3, 15)
ax.scatter(c0b, y_pos_0b, c=ACCENT, s=30, marker="x", zorder=5)
ax.scatter(c1b, y_pos_1b, c=RED_ACCENT, s=30, marker="x", zorder=5)
ax.axvline(x=4.5, color=GRAY4, linewidth=1, linestyle=":", ymin=0, ymax=0.35)
ax.text(
5,
0.3,
"σ²_wewnątrz DUŻA → klasy mieszają się → zły próg",
ha="center",
fontsize=FS_SMALL,
color=GRAY5,
)
ax.legend(fontsize=FS_SMALL, loc="upper left")
_save_figure("q23_otsu_bimodal.png")
def _draw_watershed_result_panel(ax: Axes) -> None:
"""Draw panel 3: watershed result with over-segmentation problem."""
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title("Krok 3: wynik", fontsize=FS_TITLE, fontweight="bold")
rect1 = FancyBboxPatch(
(0.5, 6),
3.5,
3.2,
boxstyle="round,pad=0.1",
facecolor=ACCENT_LIGHT,
edgecolor=BLACK,
linewidth=1,
)
ax.add_patch(rect1)
ax.text(2.25, 8.8, "Ideał: 2 segmenty", fontsize=FS, ha="center", fontweight="bold")
ax.text(2.25, 7.5, "Segment A Segment B", fontsize=FS_SMALL, ha="center")
ax.text(
2.25,
6.7,
"(po marker-controlled)",
fontsize=FS_SMALL,
ha="center",
color=GREEN_ACCENT,
)
rect2 = FancyBboxPatch(
(5.5, 6),
4,
3.2,
boxstyle="round,pad=0.1",
facecolor="#FFCDD2",
edgecolor=BLACK,
linewidth=1,
)
ax.add_patch(rect2)
ax.text(
7.5,
8.8,
"Problem: over-segmentation",
fontsize=FS,
ha="center",
fontweight="bold",
color=RED_ACCENT,
)
ax.text(
7.5,
7.8,
"47 regionów zamiast 2!",
fontsize=FS_SMALL,
ha="center",
color=RED_ACCENT,
)
ax.text(7.5, 7.1, "Każde mini-minimum", fontsize=FS_SMALL, ha="center")
ax.text(7.5, 6.5, '→ osobna „dolina"', fontsize=FS_SMALL, ha="center")
# Apply marker-controlled solution
rect3 = FancyBboxPatch(
(1, 0.5),
8,
4.5,
boxstyle="round,pad=0.15",
facecolor=GRAY1,
edgecolor=GREEN_ACCENT,
linewidth=1.5,
)
ax.add_patch(rect3)
ax.text(
5,
4.3,
"Rozwiązanie: Marker-controlled watershed",
fontsize=FS,
ha="center",
fontweight="bold",
color=GREEN_ACCENT,
)
ax.text(
5,
3.4,
'1. Zaznacz ręcznie „seeds" (markery) w każdym obiekcie',
fontsize=FS_SMALL,
ha="center",
)
ax.text(
5,
2.7,
"2. Zalewaj TYLKO od tych markerów (nie od wszystkich minimów)",
fontsize=FS_SMALL,
ha="center",
)
ax.text(
5,
2.0,
"3. Eliminuje fałszywe doliny z szumu",
fontsize=FS_SMALL,
ha="center",
)
ax.text(
5,
1.2,
"Wynik: tyle segmentów, ile podano markerów",
fontsize=FS_SMALL,
ha="center",
fontweight="bold",
)
def generate_watershed() -> None:
"""Generate watershed."""
_fig, axes = plt.subplots(1, 3, figsize=(11, 3.8))
# --- Panel 1: Image as topographic surface ---
ax = axes[0]
x = np.linspace(0, 10, 200)
# Create a surface with two valleys and a ridge
surface = (
3 * np.exp(-((x - 3) ** 2) / 1.5)
+ 4 * np.exp(-((x - 7) ** 2) / 1.2)
+ 0.5 * np.sin(x * 2)
+ 1
)
# Invert: valleys at objects (dark), peaks at boundaries (bright)
surface_inv = 6 - surface + 1
ax.fill_between(x, 0, surface_inv, color=GRAY2, alpha=0.7)
ax.plot(x, surface_inv, color=BLACK, linewidth=1.5)
# Mark valleys
ax.annotate(
"Dolina 1\n(obiekt A)",
xy=(3, surface_inv[60]),
fontsize=FS_SMALL,
ha="center",
va="bottom",
arrowprops={"arrowstyle": "->", "color": ACCENT},
xytext=(1.5, 5.5),
)
ax.annotate(
"Dolina 2\n(obiekt B)",
xy=(7, surface_inv[140]),
fontsize=FS_SMALL,
ha="center",
va="bottom",
arrowprops={"arrowstyle": "->", "color": RED_ACCENT},
xytext=(8.5, 5.5),
)
# Mark ridge
ax.annotate(
"Grań\n(granica)",
xy=(5, surface_inv[100]),
fontsize=FS_SMALL,
ha="center",
va="bottom",
arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT},
xytext=(5, 6.5),
)
ax.set_xlabel("Pozycja piksela", fontsize=FS)
ax.set_ylabel("Jasność (= wysokość)", fontsize=FS)
ax.set_title("Krok 1: obraz → teren", fontsize=FS_TITLE, fontweight="bold")
ax.set_ylim(0, 7)
# --- Panel 2: Flooding ---
ax = axes[1]
ax.fill_between(x, 0, surface_inv, color=GRAY2, alpha=0.7)
ax.plot(x, surface_inv, color=BLACK, linewidth=1.5)
# Water level
water_level = 3.2
# Fill water in valley 1
x_v1 = x[(x > 1) & (x < _RIDGE_X)]
s_v1 = surface_inv[(x > 1) & (x < _RIDGE_X)]
ax.fill_between(
x_v1, s_v1, water_level, where=s_v1 < water_level, color=ACCENT_LIGHT, alpha=0.6
)
# Fill water in valley 2
x_v2 = x[(x > _RIDGE_X) & (x < _VALLEY2_END)]
s_v2 = surface_inv[(x > _RIDGE_X) & (x < _VALLEY2_END)]
ax.fill_between(
x_v2, s_v2, water_level, where=s_v2 < water_level, color="#FFCDD2", alpha=0.6
)
ax.axhline(y=water_level, color=ACCENT, linewidth=1, linestyle="--", alpha=0.5)
ax.text(3, 2.5, "Woda A", fontsize=FS, ha="center", color=ACCENT, fontweight="bold")
ax.text(
7, 2.2, "Woda B", fontsize=FS, ha="center", color=RED_ACCENT, fontweight="bold"
)
ax.annotate(
"Tu się spotkają!\n→ GRANICA",
xy=(5, surface_inv[100]),
fontsize=FS_SMALL,
ha="center",
color=GREEN_ACCENT,
fontweight="bold",
arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT},
xytext=(5, 6.2),
)
ax.set_xlabel("Pozycja piksela", fontsize=FS)
ax.set_title("Krok 2: zalewanie", fontsize=FS_TITLE, fontweight="bold")
ax.set_ylim(0, 7)
# --- Panel 3: Result with problem ---
_draw_watershed_result_panel(axes[2])
_save_figure("q23_watershed.png")

View File

@ -0,0 +1,286 @@
"""Receptive field and transformer diagram generators."""
from __future__ import annotations
from _q23_common import (
_HIGHLIGHT_END,
_HIGHLIGHT_START,
ACCENT,
ACCENT_LIGHT,
BLACK,
FS,
FS_SMALL,
FS_TITLE,
GRAY3,
GRAY5,
GREEN_ACCENT,
RED_ACCENT,
WHITE,
_save_figure,
plt,
)
from matplotlib import patches
def generate_receptive_field() -> None:
"""Generate receptive field."""
_fig, axes = plt.subplots(1, 3, figsize=(11, 4))
def draw_grid(
ax: patches.Axes,
size: int,
highlight_cells: list[tuple[int, int]],
highlight_color: str,
title: str,
grid_offset: tuple[int, int] = (0, 0),
) -> None:
"""Draw grid."""
ox, oy = grid_offset
for i in range(size):
for j in range(size):
color = WHITE
if (i, j) in highlight_cells:
color = highlight_color
rect = patches.Rectangle(
(ox + j, oy + size - 1 - i),
1,
1,
facecolor=color,
edgecolor=GRAY3, # Use GRAY3 instead of GRAY4 since unused
linewidth=0.5,
)
ax.add_patch(rect)
ax.set_title(title, fontsize=FS_TITLE, fontweight="bold")
# --- Panel 1: Standard 3x3 conv receptive field ---
ax = axes[0]
ax.set_xlim(-0.5, 7.5)
ax.set_ylim(-1, 8)
ax.set_aspect("equal")
ax.axis("off")
# 7x7 input grid
highlight_3x3 = [
(2, 2),
(2, 3),
(2, 4),
(3, 2),
(3, 3),
(3, 4),
(4, 2),
(4, 3),
(4, 4),
]
draw_grid(ax, 7, highlight_3x3, ACCENT_LIGHT, "Zwykła conv 3x3")
ax.text(
3.5,
-0.5,
"RF = 3x3 pikseli",
fontsize=FS,
ha="center",
fontweight="bold",
color=ACCENT,
)
# --- Panel 2: Dilated conv (rate=2) ---
ax = axes[1]
ax.set_xlim(-0.5, 7.5)
ax.set_ylim(-1, 8)
ax.set_aspect("equal")
ax.axis("off")
# 7x7 input grid with dilated highlights
highlight_dilated = [
(1, 1),
(1, 3),
(1, 5),
(3, 1),
(3, 3),
(3, 5),
(5, 1),
(5, 3),
(5, 5),
]
draw_grid(ax, 7, highlight_dilated, "#FFCDD2", "Dilated conv 3x3\n(rate=2)")
ax.text(
3.5,
-0.5,
"RF = 5x5, ale 9 parametrów!",
fontsize=FS,
ha="center",
fontweight="bold",
color=RED_ACCENT,
)
# Connect dots to show pattern
dots_x = [1.5, 3.5, 5.5, 1.5, 3.5, 5.5, 1.5, 3.5, 5.5]
dots_y = [5.5, 5.5, 5.5, 3.5, 3.5, 3.5, 1.5, 1.5, 1.5]
ax.scatter(dots_x, dots_y, c=RED_ACCENT, s=30, zorder=5)
# --- Panel 3: Comparison ---
ax = axes[2]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title(
"Receptive Field\n(pole widzenia neuronu)", fontsize=FS_TITLE, fontweight="bold"
)
y = 8.5
lines = [
("RF = ile pikseli WEJŚCIOWYCH", FS, BLACK, "bold"),
("wpływa na JEDEN piksel wyjścia", FS, BLACK, "bold"),
("", 0, "", ""),
("Rate (współczynnik dylatacji):", FS, BLACK, "bold"),
(' rate=1: filtr „dotyka" sąsiadów', FS_SMALL, BLACK, "normal"),
(" rate=2: co drugi piksel → RF = 5x5", FS_SMALL, BLACK, "normal"),
(" rate=3: co trzeci → RF = 7x7", FS_SMALL, BLACK, "normal"),
(" WIĘCEJ kontekstu, TE SAME wagi!", FS_SMALL, GREEN_ACCENT, "bold"),
("", 0, "", ""),
("Dlaczego ważne w segmentacji?", FS, BLACK, "bold"),
(" Piksel sam nie wie czym jest.", FS_SMALL, BLACK, "normal"),
(" Potrzebuje KONTEKSTU (otoczenia).", FS_SMALL, BLACK, "normal"),
(" Większe RF → widzi obok budynki", FS_SMALL, BLACK, "normal"),
(' → wie, że TEN piksel to „droga"', FS_SMALL, GREEN_ACCENT, "bold"),
("", 0, "", ""),
("Global Average Pooling:", FS, BLACK, "bold"),
(" Mapa HxWxC → 1x1xC", FS_SMALL, BLACK, "normal"),
(" Średnia z CAŁEGO feature map", FS_SMALL, BLACK, "normal"),
(" RF = nieskończone (cały obraz)", FS_SMALL, GREEN_ACCENT, "bold"),
]
for txt, size, color, weight in lines:
if txt == "":
y -= 0.2
continue
ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
y -= 0.45
_save_figure("q23_receptive_field.png")
def generate_transformer() -> None:
"""Generate transformer."""
_fig, axes = plt.subplots(1, 3, figsize=(11, 4))
# --- Panel 1: CNN local vs Transformer global ---
ax = axes[0]
ax.set_xlim(-0.5, 8.5)
ax.set_ylim(-1.5, 8.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("CNN: widzi LOKALNIE", fontsize=FS_TITLE, fontweight="bold")
# Draw 8x8 grid
for i in range(8):
for j in range(8):
color = WHITE
if (
_HIGHLIGHT_START <= i <= _HIGHLIGHT_END
and _HIGHLIGHT_START <= j <= _HIGHLIGHT_END
):
color = ACCENT_LIGHT
rect = patches.Rectangle(
(j, 7 - i), 1, 1, facecolor=color, edgecolor=GRAY3, linewidth=0.3
)
ax.add_patch(rect)
# Highlight center
rect = patches.Rectangle(
(4, 4), 1, 1, facecolor=RED_ACCENT, edgecolor=BLACK, linewidth=1.5, alpha=0.7
)
ax.add_patch(rect)
ax.text(
4.5,
4.5,
"?",
ha="center",
va="center",
fontsize=FS,
fontweight="bold",
color=WHITE,
)
ax.text(
4.5,
-0.8,
"Filtr 3x3 widzi tylko\n9 sąsiednich pikseli",
fontsize=FS_SMALL,
ha="center",
color=ACCENT,
)
# --- Panel 2: Transformer global ---
ax = axes[1]
ax.set_xlim(-0.5, 8.5)
ax.set_ylim(-1.5, 8.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Transformer: widzi GLOBALNIE", fontsize=FS_TITLE, fontweight="bold")
# Draw 8x8 grid all highlighted
for i in range(8):
for j in range(8):
color = "#FFCDD2"
rect = patches.Rectangle(
(j, 7 - i), 1, 1, facecolor=color, edgecolor=GRAY3, linewidth=0.3
)
ax.add_patch(rect)
rect = patches.Rectangle(
(4, 4), 1, 1, facecolor=RED_ACCENT, edgecolor=BLACK, linewidth=1.5, alpha=0.9
)
ax.add_patch(rect)
ax.text(
4.5,
4.5,
"?",
ha="center",
va="center",
fontsize=FS,
fontweight="bold",
color=WHITE,
)
ax.text(
4.5,
-0.8,
'Self-attention „pyta"\nALL 64 piksele naraz',
fontsize=FS_SMALL,
ha="center",
color=RED_ACCENT,
)
# --- Panel 3: SOTA + Transformer explanation ---
ax = axes[2]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title("Transformer & SOTA", fontsize=FS_TITLE, fontweight="bold")
y = 9.2
lines = [
("Transformer:", FS, BLACK, "bold"),
(" Architektura z 2017 (Vaswani et al.)", FS_SMALL, BLACK, "normal"),
(" Oryginalnie do NLP (tłumaczenie)", FS_SMALL, BLACK, "normal"),
(" Kluczowy mechanizm: SELF-ATTENTION", FS_SMALL, ACCENT, "bold"),
("", 0, "", ""),
("Self-attention w skrócie:", FS, BLACK, "bold"),
(" Każdy piksel tworzy trzy wektory:", FS_SMALL, BLACK, "normal"),
(' Q (Query — „czego szukam?")', FS_SMALL, ACCENT, "normal"),
(' K (Key — „co oferuję innych")', FS_SMALL, RED_ACCENT, "normal"),
(' V (Value — „moja wartość")', FS_SMALL, GREEN_ACCENT, "normal"),
(" Attention = softmax(Q·Kᵀ/√d)·V", FS_SMALL, BLACK, "bold"),
(" Koszt: O(n²) — n=liczba pikseli", FS_SMALL, RED_ACCENT, "normal"),
("", 0, "", ""),
("SOTA = State Of The Art:", FS, BLACK, "bold"),
(" Najlepszy znany wynik na benchmarku", FS_SMALL, BLACK, "normal"),
(' Np. „mIoU 85.1% na ADE20K = SOTA"', FS_SMALL, BLACK, "normal"),
(" Ciągle się zmienia (nowy paper", FS_SMALL, GRAY5, "normal"),
(" → nowy SOTA)", FS_SMALL, GRAY5, "normal"),
]
for txt, size, color, weight in lines:
if txt == "":
y -= 0.15
continue
ax.text(0.3, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
y -= 0.45
_save_figure("q23_transformer_attention.png")

View File

@ -0,0 +1,408 @@
"""Region growing and DIY thresholding diagram generators."""
from __future__ import annotations
from typing import TYPE_CHECKING
from _q23_common import (
_BRIGHT_THRESHOLD,
_OTSU_THRESHOLD,
ACCENT,
ACCENT_LIGHT,
BLACK,
FS,
FS_SMALL,
FS_TINY,
FS_TITLE,
GRAY3,
GRAY4,
GRAY5,
GREEN_ACCENT,
RED_ACCENT,
WHITE,
_save_figure,
np,
plt,
rng,
)
from matplotlib import patches
if TYPE_CHECKING:
from matplotlib.axes import Axes
def _draw_region_growing_grid(ax: Axes) -> None:
"""Draw panel 2: region growing step-by-step grid."""
ax.set_xlim(-0.5, 6.5)
ax.set_ylim(-1.5, 7.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Region Growing: krok po kroku",
fontsize=FS_TITLE,
fontweight="bold",
)
pixel_grid = np.array(
[
[150, 153, 148, 200, 210, 205],
[147, 155, 152, 195, 208, 200],
[145, 148, 160, 190, 195, 210],
[200, 195, 190, 155, 148, 150],
[210, 205, 200, 150, 152, 145],
[215, 208, 195, 148, 147, 155],
]
)
region_mask = np.array(
[
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
]
)
for i in range(6):
for j in range(6):
v = pixel_grid[i, j]
if region_mask[i, j] == 1 and v < _BRIGHT_THRESHOLD:
cell_color = ACCENT_LIGHT
elif region_mask[i, j] == 1:
cell_color = "#E0E0E0"
else:
cell_color = WHITE
if i == 1 and j == 1:
cell_color = "#FFD54F"
rect = patches.Rectangle(
(j, 5 - i),
1,
1,
facecolor=cell_color,
edgecolor=GRAY4,
linewidth=0.5,
)
ax.add_patch(rect)
ax.text(
j + 0.5,
5 - i + 0.5,
str(v),
ha="center",
va="center",
fontsize=FS_TINY,
fontweight="bold",
)
ax.annotate(
"SEED\n(155)",
xy=(1.5, 4.5),
fontsize=FS_SMALL,
ha="center",
color=RED_ACCENT,
fontweight="bold",
arrowprops={"arrowstyle": "->", "color": RED_ACCENT},
xytext=(-0.5, 7),
)
ax.text(
3,
-0.8,
"Próg = 20\nNiebieski = region (|val - seed| < 20)",
fontsize=FS_TINY,
ha="center",
color=ACCENT,
)
def _draw_bfs_expansion(ax: Axes) -> None:
"""Draw panel 3: BFS expansion visualization."""
ax.set_xlim(-0.5, 6.5)
ax.set_ylim(-1.5, 7.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Rosnący region (BFS)",
fontsize=FS_TITLE,
fontweight="bold",
)
wave_colors = ["#FFD54F", "#FFF176", "#FFF9C4", ACCENT_LIGHT, "#B3D4FC"]
wave_labels = ["Seed", "Fala 1", "Fala 2", "Fala 3", "Fala 4"]
waves = [
[(1, 1)],
[(0, 1), (1, 0), (1, 2), (2, 1)],
[(0, 0), (0, 2), (2, 0), (2, 2)],
]
for i in range(6):
for j in range(6):
cell_color = WHITE
for w_idx, wave in enumerate(waves):
if (i, j) in wave:
cell_color = wave_colors[w_idx]
rect = patches.Rectangle(
(j, 5 - i),
1,
1,
facecolor=cell_color,
edgecolor=GRAY4,
linewidth=0.5,
)
ax.add_patch(rect)
seed_x, seed_y = 1.5, 4.5
for dx, dy, _label in [
(0, 1, ""),
(0, -1, ""),
(1, 0, ""),
(-1, 0, ""),
]:
ax.annotate(
"",
xy=(seed_x + dx * 0.7, seed_y + dy * 0.7),
xytext=(seed_x, seed_y),
arrowprops={
"arrowstyle": "->",
"color": RED_ACCENT,
"lw": 1.2,
},
)
ax.text(
3,
-0.5,
"BFS: sprawdzaj sąsiadów,\ndodawaj podobne do kolejki",
fontsize=FS_TINY,
ha="center",
color=GRAY5,
)
for w_idx, (wave_color, label) in enumerate(
zip(wave_colors[:3], wave_labels[:3], strict=False)
):
rect = patches.Rectangle(
(4, 6.5 - w_idx * 0.7),
0.5,
0.5,
facecolor=wave_color,
edgecolor=GRAY4,
linewidth=0.5,
)
ax.add_patch(rect)
ax.text(
4.8,
6.75 - w_idx * 0.7,
label,
fontsize=FS_TINY,
va="center",
)
def generate_region_growing() -> None:
"""Generate region growing."""
_fig, axes = plt.subplots(1, 3, figsize=(11, 4.2))
# --- Panel 1: Manual vs automatic seed ---
ax = axes[0]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.axis("off")
ax.set_title("Seed: ręcznie vs automatycznie", fontsize=FS_TITLE, fontweight="bold")
y = 9.2
lines = [
("Ręczny seed:", FS, ACCENT, "bold"),
(" Użytkownik klika na obraz", FS_SMALL, BLACK, "normal"),
(' → „tu jest obiekt, od tego zacznij"', FS_SMALL, BLACK, "normal"),
(" Użycie: segmentacja interaktywna", FS_SMALL, GRAY5, "normal"),
(" (np. Photoshop — magic wand tool)", FS_SMALL, GRAY5, "normal"),
("", 0, "", ""),
("Automatyczny seed:", FS, RED_ACCENT, "bold"),
(" 1. Histogram → lokalne maxima", FS_SMALL, BLACK, "normal"),
(" (najczęstsza jasność → seed)", FS_SMALL, GRAY5, "normal"),
(" 2. Grid: siatka co N pikseli", FS_SMALL, BLACK, "normal"),
(" (np. seed co 50 px → 100 seedów)", FS_SMALL, GRAY5, "normal"),
(" 3. Losowe próbkowanie", FS_SMALL, BLACK, "normal"),
(" 4. Ekstrema lokalne gradientu", FS_SMALL, BLACK, "normal"),
("", 0, "", ""),
("Dlaczego OR?", FS, GREEN_ACCENT, "bold"),
(" Ręczny → precyzyjny, ale wolny", FS_SMALL, BLACK, "normal"),
(" Auto → szybki, ale over-segmentation", FS_SMALL, BLACK, "normal"),
]
for txt, size, color, weight in lines:
if txt == "":
y -= 0.15
continue
ax.text(0.3, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
y -= 0.45
# --- Panel 2: Region growing step by step ---
_draw_region_growing_grid(axes[1])
# --- Panel 3: BFS expansion ---
_draw_bfs_expansion(axes[2])
_save_figure("q23_region_growing.png")
def _draw_otsu_variance_and_pseudocode(
ax_var: Axes,
ax_code: Axes,
img: np.ndarray,
) -> int:
"""Draw panels 4 and 5: Otsu variance plot and pseudocode."""
thresholds = range(10, 245)
variances = []
for t in thresholds:
c0 = img[img <= t].ravel()
c1 = img[img > t].ravel()
if len(c0) == 0 or len(c1) == 0:
variances.append(np.nan)
continue
w0 = len(c0) / len(img.ravel())
w1 = len(c1) / len(img.ravel())
var = w0 * np.var(c0) + w1 * np.var(c1)
variances.append(var)
ax_var.plot(list(thresholds), variances, color=ACCENT, linewidth=1.5)
best_t = list(thresholds)[np.nanargmin(variances)]
ax_var.axvline(
x=best_t,
color=RED_ACCENT,
linewidth=1.5,
linestyle="--",
label=f"Otsu T={best_t}",
)
ax_var.scatter(
[best_t],
[np.nanmin(variances)],
c=RED_ACCENT,
s=60,
zorder=5,
)
ax_var.set_xlabel("Próg T", fontsize=FS_SMALL)
ax_var.set_ylabel("σ² wewnątrzklasowa", fontsize=FS_SMALL)
ax_var.set_title(
"Krok 4: Otsu szuka min σ²",
fontsize=FS,
fontweight="bold",
)
ax_var.legend(fontsize=FS_TINY)
ax_code.set_xlim(0, 10)
ax_code.set_ylim(0, 10)
ax_code.axis("off")
ax_code.set_title("Pseudokod Otsu", fontsize=FS, fontweight="bold")
code_lines = [
"best_T = 0",
"min_var = ∞",
"",
"for T in 0..255:",
" c0 = piksele z jasność ≤ T",
" c1 = piksele z jasność > T",
" w0 = len(c0) / len(all)",
" w1 = len(c1) / len(all)",
" var = w0·var(c0) + w1·var(c1)",
" if var < min_var:",
" min_var = var",
" best_T = T",
"",
"return best_T # optymalny próg",
]
for i, line in enumerate(code_lines):
txt_color = ACCENT if "best_T = T" in line or "return" in line else BLACK
ax_code.text(
0.5,
9.5 - i * 0.65,
line,
fontsize=FS_TINY,
fontfamily="monospace",
color=txt_color,
fontweight="bold" if txt_color == ACCENT else "normal",
)
return int(best_t)
def generate_diy_thresholding() -> None:
"""Generate diy thresholding."""
_fig, axes = plt.subplots(2, 3, figsize=(11, 7))
# Create a simple synthetic image: dark circle on bright background
size = 64
img = np.ones((size, size)) * 200 # bright background
yy, xx = np.mgrid[:size, :size]
mask = ((xx - 32) ** 2 + (yy - 32) ** 2) < 15**2
img[mask] = 60 # dark circle
# Add some noise
img += rng.normal(0, 10, img.shape)
img = np.clip(img, 0, 255)
# --- Panel 1: Original image ---
ax = axes[0, 0]
ax.imshow(img, cmap="gray", vmin=0, vmax=255)
ax.set_title("Krok 1: obraz wejściowy", fontsize=FS, fontweight="bold")
ax.axis("off")
ax.text(32, -3, "64x64 pikseli, szare", fontsize=FS_TINY, ha="center")
# --- Panel 2: Histogram ---
ax = axes[0, 1]
counts, _bins, _ = ax.hist(
img.ravel(), bins=50, color=GRAY3, edgecolor=GRAY5, linewidth=0.5
)
ax.axvline(
x=128, color=RED_ACCENT, linewidth=2, linestyle="--", label="T=128 (Otsu)"
)
ax.set_xlabel("Jasność", fontsize=FS_SMALL)
ax.set_ylabel("Piksele", fontsize=FS_SMALL)
ax.set_title("Krok 2: histogram\n(bimodalny!)", fontsize=FS, fontweight="bold")
ax.legend(fontsize=FS_TINY)
ax.annotate(
"Garb 1\n(obiekt)",
xy=(60, max(counts) * 0.5),
fontsize=FS_TINY,
ha="center",
color=ACCENT,
fontweight="bold",
)
ax.annotate(
"Garb 2\n(tło)",
xy=(200, max(counts) * 0.5),
fontsize=FS_TINY,
ha="center",
color=RED_ACCENT,
fontweight="bold",
)
# --- Panel 3: Thresholding result ---
ax = axes[0, 2]
binary = (img > _OTSU_THRESHOLD).astype(float)
ax.imshow(binary, cmap="gray", vmin=0, vmax=1)
ax.set_title("Krok 3: progowanie T=128", fontsize=FS, fontweight="bold")
ax.axis("off")
ax.text(32, -3, "Biały = tło, Czarny = obiekt", fontsize=FS_TINY, ha="center")
# --- Panels 4+5: Otsu variance plot + pseudocode ---
best_t = _draw_otsu_variance_and_pseudocode(
axes[1, 0],
axes[1, 1],
img,
)
# --- Panel 6: Final result with Otsu ---
ax = axes[1, 2]
binary_otsu = (img > best_t).astype(float)
ax.imshow(binary_otsu, cmap="gray", vmin=0, vmax=1)
ax.set_title(f"Krok 5: wynik Otsu (T={best_t})", fontsize=FS, fontweight="bold")
ax.axis("off")
ax.text(
32,
-3,
"Automatyczny próg!",
fontsize=FS_TINY,
ha="center",
color=GREEN_ACCENT,
fontweight="bold",
)
_save_figure("q23_diy_thresholding.png")

View File

@ -0,0 +1,186 @@
"""Common utilities and constants for Q24 diagram generation.
Monochrome, A4-printable PNGs (300 DPI).
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib as mpl
mpl.use("Agg")
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
import matplotlib.pyplot as plt
import numpy as np
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
_logger = logging.getLogger(__name__)
rng = np.random.default_rng(42)
DPI = 300
BG = "white"
LN = "black"
FS = 8
FS_TITLE = 11
FS_SMALL = 6.5
FS_LABEL = 9
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
GRAY1 = "#E8E8E8"
GRAY2 = "#D0D0D0"
GRAY3 = "#B8B8B8"
GRAY4 = "#F5F5F5"
GRAY5 = "#C0C0C0"
_PIXEL_BRIGHT_THRESH = 127
_GRADIENT_BRIGHT_THRESH = 100
_DATA_BRIGHT_THRESH = 5
_II_BRIGHT_THRESH = 25
_DOTS_STAGE_IDX = 2
def draw_box(
ax: Axes,
x: float,
y: float,
w: float,
h: float,
text: str,
*,
fill: str = "white",
lw: float = 1.2,
fontsize: float = FS,
fontweight: str = "normal",
ha: str = "center",
va: str = "center",
rounded: bool = True,
edgecolor: str = LN,
linestyle: str = "-",
) -> None:
"""Draw box."""
if rounded:
rect = FancyBboxPatch(
(x, y),
w,
h,
boxstyle="round,pad=0.05",
lw=lw,
edgecolor=edgecolor,
facecolor=fill,
linestyle=linestyle,
)
else:
rect = mpatches.Rectangle(
(x, y),
w,
h,
lw=lw,
edgecolor=edgecolor,
facecolor=fill,
linestyle=linestyle,
)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + h / 2,
text,
ha=ha,
va=va,
fontsize=fontsize,
fontweight=fontweight,
wrap=True,
)
def draw_arrow(
ax: Axes,
x1: float,
y1: float,
x2: float,
y2: float,
*,
lw: float = 1.2,
style: str = "->",
color: str = LN,
) -> None:
"""Draw arrow."""
ax.annotate(
"",
xy=(x2, y2),
xytext=(x1, y1),
arrowprops={"arrowstyle": style, "color": color, "lw": lw},
)
def save_fig(fig: Figure, name: str) -> None:
"""Save fig."""
path = str(Path(OUTPUT_DIR) / name)
fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15)
plt.close(fig)
_logger.info(" Saved: %s", path)
def draw_table(
ax: Axes,
headers: list[str],
rows: list[list[str]],
x0: float,
y0: float,
col_widths: list[float],
*,
row_h: float = 0.4,
header_fill: str = GRAY2,
row_fills: list[str] | None = None,
fontsize: float = FS,
header_fontsize: float | None = None,
) -> None:
"""Draw table."""
if header_fontsize is None:
header_fontsize = fontsize
len(headers)
cx = x0
for j, hdr in enumerate(headers):
draw_box(
ax,
cx,
y0,
col_widths[j],
row_h,
hdr,
fill=header_fill,
fontsize=header_fontsize,
fontweight="bold",
rounded=False,
)
cx += col_widths[j]
for i, row in enumerate(rows):
cy = y0 - (i + 1) * row_h
cx = x0
fill = GRAY4 if (i % 2 == 0) else "white"
if row_fills and i < len(row_fills):
fill = row_fills[i]
for j, cell in enumerate(row):
fw = "bold" if j == 0 else "normal"
draw_box(
ax,
cx,
cy,
col_widths[j],
row_h,
cell,
fill=fill,
fontsize=fontsize,
fontweight=fw,
rounded=False,
)
cx += col_widths[j]

View File

@ -0,0 +1,412 @@
"""FPN, anchor boxes, detection tasks, and CNN architecture diagrams."""
from __future__ import annotations
from _q24_common import (
FS,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
LN,
draw_arrow,
draw_box,
np,
plt,
save_fig,
)
import matplotlib.patches as mpatches
# ============================================================
# 16. FPN (Feature Pyramid Network)
# ============================================================
def draw_fpn() -> None:
"""Draw fpn."""
fig, ax = plt.subplots(figsize=(9, 5))
ax.set_xlim(-0.5, 9.5)
ax.set_ylim(-0.5, 5.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"FPN (Feature Pyramid Network) — detekcja obiektów wszystkich rozmiarów",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
levels = [
(0, 0, 2.0, 2.0, "C2\n56x56", "duże\ndetale"),
(0, 2.2, 1.5, 1.5, "C3\n28x28", ""),
(0, 3.9, 1.0, 1.0, "C4\n14x14", ""),
(0, 5.1, 0.6, 0.6, "C5\n7x7", "kontekst"),
]
for x, y, w, h, label, note in levels:
ax.add_patch(
mpatches.Rectangle((x, y - h), w, h, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
ax.text(
x + w / 2,
y - h / 2,
label,
ha="center",
va="center",
fontsize=FS_SMALL,
fontweight="bold",
)
if note:
ax.text(
x + w + 0.15,
y - h / 2,
note,
ha="left",
va="center",
fontsize=5,
style="italic",
)
ax.text(
1.0, -0.3, "Bottom-up\n(backbone)", ha="center", fontsize=FS, fontweight="bold"
)
# Top-down + lateral
td_levels = [
(4.5, 5.1, 0.6, 0.6, "P5"),
(4.5, 3.9, 1.0, 1.0, "P4"),
(4.5, 2.2, 1.5, 1.5, "P3"),
(4.5, 0, 2.0, 2.0, "P2"),
]
for x, y, w, h, label in td_levels:
ax.add_patch(
mpatches.Rectangle(
(x, y - h + h), w, h, facecolor=GRAY2, edgecolor=LN, lw=1.5
)
)
ax.text(
x + w / 2,
y - h / 2 + h,
label,
ha="center",
va="center",
fontsize=FS_SMALL,
fontweight="bold",
)
# Lateral connections
for (_, y1, w1, h1, _, _), (x2, y2, _w2, h2, _) in zip(
levels, td_levels, strict=False
):
draw_arrow(ax, w1 + 0.2, y1 - h1 / 2, x2 - 0.1, y2 + h2 / 2, lw=1, style="->")
# Top-down arrows
for i in range(len(td_levels) - 1):
x2, y2, w2, h2, _ = td_levels[i]
x3, y3, w3, h3, _ = td_levels[i + 1]
draw_arrow(
ax,
x2 + w2 / 2,
y2,
x3 + w3 / 2,
y3 + h3 + 0.1,
lw=1.2,
style="->",
color=GRAY3,
)
ax.text(
5.5,
-0.3,
"Top-down + lateral\n(FPN)",
ha="center",
fontsize=FS,
fontweight="bold",
)
# Detection outputs
det_labels = ["małe obj.", "średnie", "duże", "b. duże"]
for i, (x, y, w, h, _label) in enumerate(td_levels):
draw_arrow(ax, x + w + 0.1, y + h / 2, 7.5, y + h / 2, lw=0.8)
ax.text(
7.7,
y + h / 2,
f"detekcja:\n{det_labels[3 - i]}",
fontsize=FS_SMALL,
va="center",
)
save_fig(fig, "q24_fpn.png")
# ============================================================
# 17. Anchor boxes
# ============================================================
def draw_anchor_boxes() -> None:
"""Draw anchor boxes."""
fig, ax = plt.subplots(figsize=(7, 5))
ax.set_title(
"Anchor boxes — predefiniowane kształty",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1))
# Center point
cx, cy = 3, 2.5
ax.plot(cx, cy, "ko", markersize=8, zorder=5)
ax.text(cx + 0.15, cy + 0.15, "(x, y)", fontsize=FS, fontweight="bold")
# 9 anchors: 3 sizes x 3 ratios
anchors = [
(0.8, 0.8, "-", "1:1 small"),
(1.6, 1.6, "-", "1:1 medium"),
(2.4, 2.4, "-", "1:1 large"),
(0.6, 1.2, "--", "1:2 small"),
(1.2, 2.4, "--", "1:2 medium"),
(1.8, 3.6, "--", "1:2 large"),
(1.2, 0.6, ":", "2:1 small"),
(2.4, 1.2, ":", "2:1 medium"),
(3.6, 1.8, ":", "2:1 large"),
]
for w, h, ls, _label in anchors:
rect = mpatches.Rectangle(
(cx - w / 2, cy - h / 2),
w,
h,
facecolor="none",
edgecolor=LN,
lw=1.2,
linestyle=ls,
)
ax.add_patch(rect)
# Legend-style labels
ax.text(
3,
-0.5,
"9 anchorów = 3 rozmiary x 3 proporcje (1:1, 1:2, 2:1)\n"
"Sieć predykuje PRZESUNIĘCIE od najbliższego anchora",
ha="center",
fontsize=FS,
style="italic",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
ax.set_xlim(-0.5, 6.5)
ax.set_ylim(-1.2, 5.5)
ax.set_aspect("equal")
ax.axis("off")
save_fig(fig, "q24_anchor_boxes.png")
# ============================================================
# 18. Detection task comparison
# ============================================================
def draw_detection_tasks() -> None:
"""Draw detection tasks."""
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
fig.suptitle(
"Klasyfikacja vs Detekcja vs Segmentacja",
fontsize=FS_TITLE,
fontweight="bold",
y=1.02,
)
# Classification
ax = axes[0]
ax.add_patch(
mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
# Simple cat silhouette
ax.add_patch(mpatches.Ellipse((2, 2), 2, 1.5, facecolor=GRAY3, edgecolor=LN, lw=1))
ax.add_patch(mpatches.Ellipse((2, 3), 1, 0.8, facecolor=GRAY3, edgecolor=LN, lw=1))
# Ears
ax.plot([1.6, 1.5, 1.8], [3.3, 3.8, 3.4], color=LN, lw=1.5)
ax.plot([2.2, 2.5, 2.4], [3.3, 3.8, 3.4], color=LN, lw=1.5)
ax.text(
2, -0.4, '"KOT" (jedna etykieta)', ha="center", fontsize=FS, fontweight="bold"
)
ax.set_xlim(-0.5, 4.5)
ax.set_ylim(-0.8, 4.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Klasyfikacja\n(co?)", fontsize=FS, fontweight="bold")
# Detection
ax = axes[1]
ax.add_patch(
mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
# Cat
ax.add_patch(
mpatches.Ellipse((1.2, 2), 1.2, 1, facecolor=GRAY3, edgecolor=LN, lw=1)
)
ax.add_patch(
mpatches.Ellipse((1.2, 2.8), 0.7, 0.5, facecolor=GRAY3, edgecolor=LN, lw=1)
)
# Dog
ax.add_patch(
mpatches.Ellipse((3, 1.5), 1.2, 1, facecolor=GRAY2, edgecolor=LN, lw=1)
)
ax.add_patch(
mpatches.Ellipse((3, 2.3), 0.7, 0.5, facecolor=GRAY2, edgecolor=LN, lw=1)
)
# Bounding boxes
ax.add_patch(
mpatches.Rectangle((0.3, 1.2), 1.8, 2.2, facecolor="none", edgecolor=LN, lw=2.5)
)
ax.text(1.2, 3.5, "KOT", ha="center", fontsize=FS_SMALL, fontweight="bold")
ax.add_patch(
mpatches.Rectangle((2.1, 0.8), 1.7, 2.0, facecolor="none", edgecolor=LN, lw=2.5)
)
ax.text(3.0, 2.9, "PIES", ha="center", fontsize=FS_SMALL, fontweight="bold")
ax.text(
2,
-0.4,
"→ bbox + klasa (N obiektów)",
ha="center",
fontsize=FS,
fontweight="bold",
)
ax.set_xlim(-0.5, 4.5)
ax.set_ylim(-0.8, 4.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Detekcja\n(co? + gdzie?)", fontsize=FS, fontweight="bold")
# Segmentation
ax = axes[2]
ax.add_patch(
mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
# Cat mask (detailed)
theta = np.linspace(0, 2 * np.pi, 30)
cat_x = 1.2 + 0.6 * np.cos(theta) + 0.1 * np.sin(3 * theta)
cat_y = 2 + 0.5 * np.sin(theta) + 0.1 * np.cos(2 * theta)
ax.fill(cat_x, cat_y, facecolor=GRAY3, edgecolor=LN, lw=1.5)
# Dog mask
dog_x = 3.0 + 0.6 * np.cos(theta) + 0.05 * np.sin(4 * theta)
dog_y = 1.5 + 0.5 * np.sin(theta) + 0.08 * np.cos(3 * theta)
ax.fill(dog_x, dog_y, facecolor=GRAY2, edgecolor=LN, lw=1.5)
ax.text(1.2, 2, "KOT", ha="center", fontsize=FS_SMALL, fontweight="bold")
ax.text(3.0, 1.5, "PIES", ha="center", fontsize=FS_SMALL, fontweight="bold")
ax.text(
2,
-0.4,
"→ maska pikseli (per piksel)",
ha="center",
fontsize=FS,
fontweight="bold",
)
ax.set_xlim(-0.5, 4.5)
ax.set_ylim(-0.8, 4.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Segmentacja\n(dokładna maska)", fontsize=FS, fontweight="bold")
fig.tight_layout()
save_fig(fig, "q24_detection_tasks.png")
# ============================================================
# 19. CNN Architecture overview
# ============================================================
def draw_cnn_architecture() -> None:
"""Draw cnn architecture."""
fig, ax = plt.subplots(figsize=(12, 4))
ax.set_xlim(-0.5, 12.5)
ax.set_ylim(-1, 4.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"CNN — od obrazu do predykcji (architektura)",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
# Input image
draw_box(ax, 0, 0.5, 1.5, 3, "Obraz\n224x224x3", fill=GRAY1, fontsize=FS)
# Conv1
draw_arrow(ax, 1.6, 2.0, 2.1, 2.0, lw=1.2)
draw_box(
ax, 2.2, 0.8, 1.2, 2.4, "Conv1\n+ReLU\n55x55x96", fill=GRAY4, fontsize=FS_SMALL
)
# Pool1
draw_arrow(ax, 3.5, 2.0, 3.9, 2.0, lw=1.2)
draw_box(ax, 4.0, 1.0, 1.0, 2.0, "Pool\n27x27\nx96", fill=GRAY2, fontsize=FS_SMALL)
# Conv2
draw_arrow(ax, 5.1, 2.0, 5.5, 2.0, lw=1.2)
draw_box(
ax,
5.6,
0.8,
1.2,
2.4,
"Conv2\n+ReLU\n27x27\nx256",
fill=GRAY4,
fontsize=FS_SMALL,
)
# Pool2
draw_arrow(ax, 6.9, 2.0, 7.3, 2.0, lw=1.2)
draw_box(ax, 7.4, 1.2, 0.8, 1.6, "Pool\n13x13\nx256", fill=GRAY2, fontsize=FS_SMALL)
# More conv...
draw_arrow(ax, 8.3, 2.0, 8.7, 2.0, lw=1.2)
ax.text(9.0, 2.0, "...", fontsize=14, ha="center", va="center")
draw_arrow(ax, 9.3, 2.0, 9.7, 2.0, lw=1.2)
# FC
draw_box(ax, 9.8, 1.2, 1.0, 1.6, "FC\n4096", fill=GRAY3, fontsize=FS)
draw_arrow(ax, 10.9, 2.0, 11.3, 2.0, lw=1.2)
# Output
draw_box(
ax, 11.4, 1.5, 1.0, 1.0, "Softmax\n1000 klas", fill=GRAY1, fontsize=FS_SMALL
)
# Annotations below
ax.text(
3.0,
0.0,
"rozmiar maleje\n224→55→27→13→6",
ha="center",
fontsize=FS_SMALL,
style="italic",
)
ax.text(
6.0,
0.0,
"kanały rosną\n3→96→256→384",
ha="center",
fontsize=FS_SMALL,
style="italic",
)
ax.text(
10.0, 0.0, "decyzja\nkońcowa", ha="center", fontsize=FS_SMALL, style="italic"
)
# hierarchy
ax.text(
6.0,
4.0,
"Hierarchia: krawędzie → rogi → fragmenty → obiekty (K-R-F-O)",
ha="center",
fontsize=FS,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
save_fig(fig, "q24_cnn_architecture.png")

View File

@ -0,0 +1,342 @@
"""Haar features, integral image, and SVM hyperplane diagrams."""
from __future__ import annotations
from typing import TYPE_CHECKING
from _q24_common import (
_DATA_BRIGHT_THRESH,
_II_BRIGHT_THRESH,
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY3,
GRAY4,
LN,
np,
plt,
rng,
save_fig,
)
import matplotlib.patches as mpatches
if TYPE_CHECKING:
from matplotlib.axes import Axes
# ============================================================
# 4. Haar Features
# ============================================================
def draw_haar_features() -> None:
"""Draw haar features."""
fig, axes = plt.subplots(1, 4, figsize=(11, 3))
fig.suptitle(
"Cechy Haar — typy i zastosowanie na twarzy",
fontsize=FS_TITLE,
fontweight="bold",
y=1.02,
)
# Feature 1: Vertical edge
ax = axes[0]
ax.add_patch(
mpatches.Rectangle((0, 0), 1, 2, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
ax.add_patch(
mpatches.Rectangle((1, 0), 1, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5)
)
ax.text(
0.5, 1, "+Σ₁", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold"
)
ax.text(
1.5, 1, "-Σ₂", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold"
)
ax.set_xlim(-0.2, 2.2)
ax.set_ylim(-0.5, 2.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Krawędź pionowa\nwartość = Σ₁ - Σ₂", fontsize=FS)
# Feature 2: Horizontal edge
ax = axes[1]
ax.add_patch(
mpatches.Rectangle((0, 1), 2, 1, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
ax.add_patch(
mpatches.Rectangle((0, 0), 2, 1, facecolor=GRAY3, edgecolor=LN, lw=1.5)
)
ax.text(
1, 1.5, "+Σ₁", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold"
)
ax.text(
1, 0.5, "-Σ₂", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold"
)
ax.set_xlim(-0.2, 2.2)
ax.set_ylim(-0.5, 2.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Krawędź pozioma\n(oczy vs czoło)", fontsize=FS)
# Feature 3: Three-rectangle (line)
ax = axes[2]
ax.add_patch(
mpatches.Rectangle((0, 0), 0.7, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5)
)
ax.add_patch(
mpatches.Rectangle((0.7, 0), 0.7, 2, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
ax.add_patch(
mpatches.Rectangle((1.4, 0), 0.7, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5)
)
ax.text(
0.35, 1, "-Σ₁", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold"
)
ax.text(
1.05, 1, "+Σ₂", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold"
)
ax.text(
1.75, 1, "-Σ₃", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold"
)
ax.set_xlim(-0.2, 2.3)
ax.set_ylim(-0.5, 2.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Linia (3 prostokąty)\n(nos vs policzki)", fontsize=FS)
_draw_haar_face_panel(axes[3])
fig.tight_layout()
save_fig(fig, "q24_haar_features.png")
def _draw_haar_face_panel(ax: Axes) -> None:
"""Draw Haar feature application on face schematic."""
face = mpatches.Ellipse(
(1.2, 1.2),
2.0,
2.4,
facecolor=GRAY4,
edgecolor=LN,
lw=1.5,
)
ax.add_patch(face)
ax.add_patch(
mpatches.Ellipse((0.7, 1.6), 0.4, 0.2, facecolor=GRAY3, edgecolor=LN, lw=1)
)
ax.add_patch(
mpatches.Ellipse((1.7, 1.6), 0.4, 0.2, facecolor=GRAY3, edgecolor=LN, lw=1)
)
ax.plot([1.2, 1.1, 1.3], [1.3, 0.9, 0.9], color=LN, lw=1)
ax.plot([0.8, 1.0, 1.2, 1.4, 1.6], [0.55, 0.5, 0.55, 0.5, 0.55], color=LN, lw=1)
ax.add_patch(
mpatches.Rectangle(
(0.3, 1.4),
1.8,
0.4,
facecolor="none",
edgecolor=LN,
lw=2,
linestyle="--",
)
)
ax.annotate(
"cechy Haar\n(oczy ciemne\nvs czoło jasne)",
xy=(1.2, 1.85),
xytext=(2.2, 2.3),
fontsize=FS_SMALL,
ha="center",
arrowprops={"arrowstyle": "->", "color": LN, "lw": 1},
)
ax.set_xlim(-0.2, 3.0)
ax.set_ylim(-0.2, 2.8)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Zastosowanie na twarzy", fontsize=FS)
# ============================================================
# 5. Integral Image
# ============================================================
def draw_integral_image() -> None:
"""Draw integral image."""
fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
fig.suptitle(
"Integral Image — suma prostokąta w O(1)",
fontsize=FS_TITLE,
fontweight="bold",
y=1.02,
)
# Original image
ax = axes[0]
data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
ax.imshow(data, cmap="gray", vmin=0, vmax=10)
for i in range(3):
for j in range(3):
ax.text(
j,
i,
str(data[i, j]),
ha="center",
va="center",
fontsize=12,
fontweight="bold",
color="white" if data[i, j] > _DATA_BRIGHT_THRESH else "black",
)
ax.set_title("① Obraz oryginalny", fontsize=FS, fontweight="bold")
ax.set_xticks([])
ax.set_yticks([])
# Integral image
ax = axes[1]
ii = np.array([[1, 3, 6], [5, 12, 21], [12, 27, 45]])
ax.imshow(ii, cmap="gray", vmin=0, vmax=50)
for i in range(3):
for j in range(3):
ax.text(
j,
i,
str(ii[i, j]),
ha="center",
va="center",
fontsize=12,
fontweight="bold",
color="white" if ii[i, j] > _II_BRIGHT_THRESH else "black",
)
ax.set_title("② Integral Image\n(sumy kumulatywne)", fontsize=FS, fontweight="bold")
ax.set_xticks([])
ax.set_yticks([])
# Formula illustration
ax = axes[2]
ax.axis("off")
ax.set_xlim(0, 4)
ax.set_ylim(0, 4)
# Draw rectangle
ax.add_patch(
mpatches.Rectangle((0.5, 0.5), 3, 3, facecolor="white", edgecolor=LN, lw=1)
)
ax.add_patch(
mpatches.Rectangle((1.5, 0.5), 2, 2, facecolor=GRAY3, edgecolor=LN, lw=2)
)
# Labels
ax.text(0.3, 3.7, "A", fontsize=12, fontweight="bold")
ax.text(3.6, 3.7, "B", fontsize=12, fontweight="bold")
ax.text(0.3, 0.3, "C", fontsize=12, fontweight="bold")
ax.text(3.6, 0.3, "D", fontsize=12, fontweight="bold")
ax.text(
2.5,
1.5,
"SZUKANA\nSUMA",
ha="center",
va="center",
fontsize=FS,
fontweight="bold",
)
ax.text(
2.0,
-0.3,
"Suma = D - B - C + A\n= 4 odczyty → O(1) ZAWSZE!",
ha="center",
fontsize=FS,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
ax.set_title(
"③ Formuła: 4 odczyty\n= O(1) niezależnie od rozmiaru",
fontsize=FS,
fontweight="bold",
)
fig.tight_layout()
save_fig(fig, "q24_integral_image.png")
# ============================================================
# 11. SVM Hyperplane
# ============================================================
def draw_svm_hyperplane() -> None:
"""Draw svm hyperplane."""
fig, ax = plt.subplots(figsize=(6, 5))
ax.set_title(
"SVM — hiperpłaszczyzna i margines",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
x_pos = rng.standard_normal(15) * 0.5 + 3
y_pos = rng.standard_normal(15) * 0.5 + 3
ax.scatter(
x_pos,
y_pos,
marker="o",
s=50,
facecolors="white",
edgecolors=LN,
linewidths=1.5,
label="klasa +1 (pieszy)",
zorder=3,
)
x_neg = rng.standard_normal(15) * 0.5 + 1
y_neg = rng.standard_normal(15) * 0.5 + 1
ax.scatter(
x_neg,
y_neg,
marker="x",
s=50,
c=LN,
linewidths=1.5,
label="klasa -1 (tło)",
zorder=3,
)
# Hyperplane (decision boundary)
x_line = np.linspace(-0.5, 5, 100)
y_line = -x_line + 4.0
ax.plot(x_line, y_line, "k-", lw=2, label="hiperpłaszczyzna")
# Margin lines
ax.plot(x_line, y_line + 0.7, "k--", lw=1, alpha=0.5)
ax.plot(x_line, y_line - 0.7, "k--", lw=1, alpha=0.5)
# Margin annotation
ax.annotate(
"",
xy=(2.5, 1.5 + 0.7),
xytext=(2.5, 1.5 - 0.7),
arrowprops={"arrowstyle": "<->", "color": LN, "lw": 1.5},
)
ax.text(2.8, 1.5, "margines\n(MAX!)", fontsize=FS, fontweight="bold")
# Support vectors (highlight closest points)
ax.scatter(
[2.5],
[2.2],
marker="o",
s=120,
facecolors="none",
edgecolors=LN,
linewidths=2.5,
zorder=4,
)
ax.scatter([1.5], [1.8], marker="x", s=120, c=LN, linewidths=2.5, zorder=4)
ax.annotate(
"support\nvectors",
xy=(1.5, 1.8),
xytext=(0.2, 3.0),
fontsize=FS,
fontweight="bold",
arrowprops={"arrowstyle": "->", "color": LN, "lw": 1},
)
ax.set_xlim(-0.5, 5)
ax.set_ylim(-0.5, 5)
ax.set_xlabel("cecha 1 (np. gradient pionowy)", fontsize=FS)
ax.set_ylabel("cecha 2 (np. gradient poziomy)", fontsize=FS)
ax.legend(fontsize=FS_SMALL, loc="lower right")
ax.set_aspect("equal")
save_fig(fig, "q24_svm_hyperplane.png")

View File

@ -0,0 +1,380 @@
"""HOG + SVM pipeline, HOG gradient steps, Viola-Jones cascade."""
from __future__ import annotations
from _q24_common import (
_DOTS_STAGE_IDX,
_GRADIENT_BRIGHT_THRESH,
_PIXEL_BRIGHT_THRESH,
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
draw_arrow,
draw_box,
np,
plt,
save_fig,
)
import matplotlib.patches as mpatches
# ============================================================
# 1. HOG + SVM Pipeline
# ============================================================
def draw_hog_svm_pipeline() -> None:
"""Draw hog svm pipeline."""
fig, ax = plt.subplots(figsize=(10, 4.5))
ax.set_xlim(-0.5, 10.5)
ax.set_ylim(-1, 4.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"HOG + SVM — pipeline detekcji pieszych",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
# Step 1: Image with sliding window
ax.add_patch(
mpatches.Rectangle((0, 1.5), 2, 2, lw=1.5, edgecolor=LN, facecolor=GRAY1)
)
ax.text(1, 2.5, "Obraz\nwejściowy", ha="center", va="center", fontsize=FS)
# sliding window overlay
ax.add_patch(
mpatches.Rectangle(
(0.3, 1.8),
0.8,
1.2,
lw=1.5,
edgecolor="black",
facecolor="none",
linestyle="--",
)
)
ax.text(
0.7,
1.35,
"okno 64x128",
ha="center",
va="center",
fontsize=FS_SMALL,
style="italic",
)
draw_arrow(ax, 2.1, 2.5, 2.8, 2.5, lw=1.5)
ax.text(2.45, 2.75, "", ha="center", fontsize=FS_LABEL, fontweight="bold")
# Step 2: Gradient computation
draw_box(
ax, 2.9, 1.8, 1.6, 1.4, "Oblicz\ngradienty\nGx, Gy", fill=GRAY4, fontsize=FS
)
ax.text(
3.7, 1.55, "kierunek + siła", ha="center", fontsize=FS_SMALL, style="italic"
)
draw_arrow(ax, 4.6, 2.5, 5.2, 2.5, lw=1.5)
ax.text(4.9, 2.75, "", ha="center", fontsize=FS_LABEL, fontweight="bold")
# Step 3: HOG histogram
draw_box(
ax,
5.3,
1.8,
1.6,
1.4,
"Histogramy\nkierunkowe\n9 binów/cel",
fill=GRAY4,
fontsize=FS,
)
ax.text(6.1, 1.55, "komórki 8x8 px", ha="center", fontsize=FS_SMALL, style="italic")
draw_arrow(ax, 7.0, 2.5, 7.6, 2.5, lw=1.5)
ax.text(7.3, 2.75, "", ha="center", fontsize=FS_LABEL, fontweight="bold")
# Step 4: SVM
draw_box(
ax,
7.7,
1.8,
1.4,
1.4,
"SVM\nklasyfikator\npieszy/tło",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 9.2, 2.5, 9.7, 2.5, lw=1.5)
ax.text(9.45, 2.75, "", ha="center", fontsize=FS_LABEL, fontweight="bold")
# Step 5: NMS + output
draw_box(ax, 9.3, 2.0, 1.0, 1.0, "NMS\n→ wynik", fill=GRAY1, fontsize=FS)
# Bottom: HOG feature vector illustration
ax.text(
5.0,
0.7,
"Wektor HOG: 3780 cech = 105 bloków x 4 komórki x 9 binów",
ha="center",
fontsize=FS,
style="italic",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
# Show small histogram bars
bar_x = 3.2
bar_y = 0.0
angles = [0, 20, 40, 60, 80, 100, 120, 140, 160]
values = [0.3, 0.1, 0.5, 0.8, 0.2, 0.6, 0.15, 0.4, 0.25]
for i, (_a, v) in enumerate(zip(angles, values, strict=False)):
ax.add_patch(
mpatches.Rectangle(
(bar_x + i * 0.18, bar_y),
0.15,
v * 0.6,
facecolor=GRAY3,
edgecolor=LN,
lw=0.5,
)
)
ax.text(bar_x + 0.8, -0.2, "9 binów (0°-160°)", ha="center", fontsize=FS_SMALL)
save_fig(fig, "q24_hog_svm_pipeline.png")
# ============================================================
# 2. HOG Gradient Step-by-Step
# ============================================================
def draw_hog_gradient_steps() -> None:
"""Draw hog gradient steps."""
fig, axes = plt.subplots(1, 4, figsize=(12, 3.5))
fig.suptitle(
"HOG — kroki obliczania cech", fontsize=FS_TITLE, fontweight="bold", y=1.02
)
# Step 1: Original patch
ax = axes[0]
patch = np.array([[50, 50, 200], [50, 50, 200], [50, 50, 200]])
ax.imshow(patch, cmap="gray", vmin=0, vmax=255)
for i in range(3):
for j in range(3):
ax.text(
j,
i,
str(patch[i, j]),
ha="center",
va="center",
fontsize=FS_LABEL,
fontweight="bold",
color="white" if patch[i, j] > _PIXEL_BRIGHT_THRESH else "black",
)
ax.set_title("① Fragment obrazu\n(jasność pikseli)", fontsize=FS, fontweight="bold")
ax.set_xticks([])
ax.set_yticks([])
# Step 2: Gradient magnitude
ax = axes[1]
gx = np.array([[0, 150, 0], [0, 150, 0], [0, 150, 0]])
ax.imshow(gx, cmap="gray", vmin=0, vmax=255)
for i in range(3):
for j in range(3):
ax.text(
j,
i,
str(gx[i, j]),
ha="center",
va="center",
fontsize=FS_LABEL,
fontweight="bold",
color="white" if gx[i, j] > _GRADIENT_BRIGHT_THRESH else "black",
)
ax.set_title("② Gradient Gx\n(krawędź pionowa!)", fontsize=FS, fontweight="bold")
ax.set_xticks([])
ax.set_yticks([])
# Step 3: Cell histogram
ax = axes[2]
angles = ["", "20°", "40°", "60°", "80°", "100°", "120°", "140°", "160°"]
values = [150, 0, 0, 0, 0, 0, 0, 0, 0]
bars = ax.bar(range(9), values, color=GRAY3, edgecolor=LN, linewidth=0.5)
bars[0].set_facecolor(GRAY5)
ax.set_xticks(range(9))
ax.set_xticklabels(angles, fontsize=5, rotation=45)
ax.set_title(
"③ Histogram komórki\n(bin 0° = krawędź pionowa)",
fontsize=FS,
fontweight="bold",
)
ax.set_ylabel("siła", fontsize=FS_SMALL)
# Step 4: Block normalization
ax = axes[3]
# 2x2 block of cells
for i in range(2):
for j in range(2):
rect = mpatches.Rectangle(
(j * 1.2, (1 - i) * 1.2),
1.0,
1.0,
lw=1.2,
edgecolor=LN,
facecolor=GRAY4,
)
ax.add_patch(rect)
ax.text(
j * 1.2 + 0.5,
(1 - i) * 1.2 + 0.5,
f"hist\n{i * 2 + j + 1}",
ha="center",
va="center",
fontsize=FS_SMALL,
)
ax.add_patch(
mpatches.Rectangle(
(-0.1, -0.1), 2.6, 2.6, lw=2, edgecolor=LN, facecolor="none", linestyle="--"
)
)
ax.text(
1.2,
-0.4,
"blok 2x2 → L2-norm",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
)
ax.set_xlim(-0.3, 2.8)
ax.set_ylim(-0.7, 2.8)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"④ Normalizacja bloków\n(odporność na oświetlenie)",
fontsize=FS,
fontweight="bold",
)
fig.tight_layout()
save_fig(fig, "q24_hog_gradient_steps.png")
# ============================================================
# 3. Viola-Jones Cascade
# ============================================================
def draw_viola_jones_cascade() -> None:
"""Draw viola jones cascade."""
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xlim(-0.5, 10.5)
ax.set_ylim(-1.5, 5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Viola-Jones — kaskada klasyfikatorów (SITO)",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
# Input
draw_box(
ax,
-0.3,
2.5,
1.5,
1.2,
"500 000\nokien",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
)
stages = [
("Etap 1\n2 cechy", "50%\nodrzucone", "250 000", GRAY4),
("Etap 2\n10 cech", "80%\nodrzucone", "50 000", GRAY4),
("Etap 3\n25 cech", "90%\nodrzucone", "5 000", GRAY4),
("Etap 25\n200 cech", "99%\nodrzucone", "50", GRAY3),
]
x_pos = 1.6
for i, (label, reject, remain, col) in enumerate(stages):
# Stage box
draw_box(
ax, x_pos, 2.5, 1.6, 1.2, label, fill=col, fontsize=FS, fontweight="bold"
)
# Arrow from previous
draw_arrow(ax, x_pos - 0.3, 3.1, x_pos - 0.05, 3.1, lw=1.5)
# Reject arrow down
draw_arrow(ax, x_pos + 0.8, 2.45, x_pos + 0.8, 1.6, lw=1.2)
ax.text(
x_pos + 0.8,
1.3,
reject,
ha="center",
fontsize=FS_SMALL,
color="black",
style="italic",
)
ax.text(
x_pos + 0.8,
0.8,
"✗ NIE-TWARZ",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
)
# Remaining count above
if i < len(stages) - 1:
ax.text(
x_pos + 2.0,
3.9,
f"{remain}",
ha="center",
fontsize=FS_SMALL,
style="italic",
)
# Dots between stage 3 and stage 25
if i == _DOTS_STAGE_IDX:
ax.text(
x_pos + 2.0, 3.1, "· · ·", ha="center", fontsize=12, fontweight="bold"
)
x_pos += 2.5
else:
x_pos += 2.1
# Final output
draw_arrow(ax, x_pos + 0.3, 3.1, x_pos + 0.9, 3.1, lw=1.5)
draw_box(
ax,
x_pos + 0.5,
2.5,
1.3,
1.2,
"~50\nTWARZE\n",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
# Timing info
ax.text(
5.0,
-0.5,
"Czas: 99% okien odrzucone w etapach 1-3 (~5 μs każde)\n"
"Tylko 0.01% dochodzi do etapu 25 → cały obraz w ~30 ms = 30+ fps",
ha="center",
fontsize=FS,
style="italic",
bbox={"boxstyle": "round,pad=0.4", "facecolor": GRAY4, "edgecolor": GRAY3},
)
save_fig(fig, "q24_viola_jones_cascade.png")

View File

@ -0,0 +1,413 @@
"""IoU diagram, NMS steps, and detector-from-classifier diagrams."""
from __future__ import annotations
from _q24_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
LN,
draw_arrow,
draw_box,
plt,
save_fig,
)
import matplotlib.patches as mpatches
# ============================================================
# 8. IoU Diagram
# ============================================================
def draw_iou_diagram() -> None:
"""Draw iou diagram."""
fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
fig.suptitle(
"IoU (Intersection over Union) — miara nakładania bboxów",
fontsize=FS_TITLE,
fontweight="bold",
y=1.02,
)
# Low IoU
ax = axes[0]
ax.add_patch(
mpatches.Rectangle(
(0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5, label="A"
)
)
ax.add_patch(
mpatches.Rectangle(
(2.5, 2.5),
3,
3,
facecolor=GRAY2,
edgecolor=LN,
lw=1.5,
alpha=0.7,
label="B",
)
)
# Intersection
ax.add_patch(
mpatches.Rectangle((2.5, 2.5), 0.5, 0.5, facecolor=GRAY3, edgecolor=LN, lw=2)
)
ax.text(1.5, 1.5, "A", ha="center", va="center", fontsize=12, fontweight="bold")
ax.text(4, 4, "B", ha="center", va="center", fontsize=12, fontweight="bold")
ax.set_xlim(-0.5, 6)
ax.set_ylim(-0.5, 6)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"IoU ≈ 0.04\n(prawie się nie nakładają)", fontsize=FS, fontweight="bold"
)
# Medium IoU
ax = axes[1]
ax.add_patch(
mpatches.Rectangle((0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
ax.add_patch(
mpatches.Rectangle(
(1.5, 1.5), 3, 3, facecolor=GRAY2, edgecolor=LN, lw=1.5, alpha=0.7
)
)
ax.add_patch(
mpatches.Rectangle((1.5, 1.5), 1.5, 1.5, facecolor=GRAY3, edgecolor=LN, lw=2)
)
ax.text(0.7, 0.7, "A", ha="center", va="center", fontsize=12, fontweight="bold")
ax.text(3.5, 3.5, "B", ha="center", va="center", fontsize=12, fontweight="bold")
ax.text(2.25, 2.25, "", ha="center", va="center", fontsize=14, fontweight="bold")
ax.set_xlim(-0.5, 5)
ax.set_ylim(-0.5, 5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("IoU ≈ 0.14\n(częściowe nakładanie)", fontsize=FS, fontweight="bold")
# High IoU
ax = axes[2]
ax.add_patch(
mpatches.Rectangle((0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
ax.add_patch(
mpatches.Rectangle(
(0.3, 0.3), 3, 3, facecolor=GRAY2, edgecolor=LN, lw=1.5, alpha=0.7
)
)
ax.add_patch(
mpatches.Rectangle((0.3, 0.3), 2.7, 2.7, facecolor=GRAY3, edgecolor=LN, lw=2)
)
ax.text(-0.3, -0.3, "A", ha="center", va="center", fontsize=12, fontweight="bold")
ax.text(3.5, 3.5, "B", ha="center", va="center", fontsize=12, fontweight="bold")
ax.text(1.65, 1.65, "", ha="center", va="center", fontsize=14, fontweight="bold")
ax.set_xlim(-0.8, 4)
ax.set_ylim(-0.8, 4)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"IoU ≈ 0.74\n(duże nakładanie → duplikat!)", fontsize=FS, fontweight="bold"
)
fig.tight_layout()
save_fig(fig, "q24_iou_diagram.png")
# ============================================================
# 9. NMS Step-by-Step
# ============================================================
def draw_nms_steps() -> None:
"""Draw nms steps."""
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
fig.suptitle(
"NMS (Non-Maximum Suppression) — usuwanie duplikatów",
fontsize=FS_TITLE,
fontweight="bold",
y=1.02,
)
# Before NMS
ax = axes[0]
ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1))
# Multiple overlapping boxes for same object
ax.add_patch(
mpatches.Rectangle((1, 1), 2.5, 3, facecolor="none", edgecolor=LN, lw=2)
)
ax.text(2.25, 4.2, "conf=0.95", ha="center", fontsize=FS_SMALL, fontweight="bold")
ax.add_patch(
mpatches.Rectangle(
(1.2, 1.3), 2.3, 2.8, facecolor="none", edgecolor=LN, lw=1.5, linestyle="--"
)
)
ax.text(2.35, 1.1, "conf=0.90", ha="center", fontsize=FS_SMALL)
ax.add_patch(
mpatches.Rectangle(
(0.8, 0.8), 2.7, 3.2, facecolor="none", edgecolor=LN, lw=1, linestyle=":"
)
)
ax.text(2.15, 0.6, "conf=0.85", ha="center", fontsize=FS_SMALL)
# Different object
ax.add_patch(
mpatches.Rectangle((4, 2), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=1.5)
)
ax.text(4.75, 3.7, "conf=0.80", ha="center", fontsize=FS_SMALL)
ax.text(
2,
0.2,
"⚠ 4 detekcje (3 duplikaty!)",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
)
ax.set_xlim(-0.3, 6.3)
ax.set_ylim(-0.3, 5.3)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"① Przed NMS\n(wiele nakładających się)", fontsize=FS, fontweight="bold"
)
# NMS process
ax = axes[1]
ax.axis("off")
ax.set_xlim(0, 6)
ax.set_ylim(0, 5)
steps = [
("1. Sortuj: [0.95, 0.90, 0.85, 0.80]", 4.5),
("2. Weź najlepszą (0.95) → ZACHOWAJ", 3.7),
("3. IoU(0.95, 0.90)=0.82 > 0.5 → USUŃ", 2.9),
("4. IoU(0.95, 0.85)=0.75 > 0.5 → USUŃ", 2.1),
("5. IoU(0.95, 0.80)=0.10 < 0.5 → ZACHOWAJ", 1.3),
]
colors = [GRAY4, GRAY2, GRAY4, GRAY4, GRAY2]
for (text, yp), c in zip(steps, colors, strict=False):
ax.text(
3.0,
yp,
text,
ha="center",
fontsize=FS,
bbox={"boxstyle": "round,pad=0.2", "facecolor": c, "edgecolor": GRAY3},
)
ax.set_title("② Algorytm NMS\n(próg IoU = 0.5)", fontsize=FS, fontweight="bold")
# After NMS
ax = axes[2]
ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1))
# Only best box for each object
ax.add_patch(
mpatches.Rectangle((1, 1), 2.5, 3, facecolor="none", edgecolor=LN, lw=2.5)
)
ax.text(2.25, 4.2, "conf=0.95 ✓", ha="center", fontsize=FS_SMALL, fontweight="bold")
ax.add_patch(
mpatches.Rectangle((4, 2), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=2.5)
)
ax.text(4.75, 3.7, "conf=0.80 ✓", ha="center", fontsize=FS_SMALL, fontweight="bold")
ax.text(
3,
0.2,
"✓ 2 unikalne obiekty",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
)
ax.set_xlim(-0.3, 6.3)
ax.set_ylim(-0.3, 5.3)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("③ Po NMS\n(1 bbox na obiekt)", fontsize=FS, fontweight="bold")
fig.tight_layout()
save_fig(fig, "q24_nms_steps.png")
# ============================================================
# 10. Detector from Classifier — 3 approaches
# ============================================================
def draw_detector_from_classifier() -> None:
"""Draw detector from classifier."""
fig, ax = plt.subplots(figsize=(11, 9))
ax.set_xlim(-0.5, 11)
ax.set_ylim(-1, 9.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Jak zbudować detektor z klasyfikatora? — 3 podejścia",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
# ---- Approach 1: Sliding Window ----
y = 7.0
ax.text(
0,
y + 1.5,
"① Sliding Window (NAJWOLNIEJSZE)",
fontsize=FS_LABEL,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
# Image with sliding window
ax.add_patch(
mpatches.Rectangle(
(0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5
)
)
ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL)
# Sliding windows
for dx, dy in [(0.1, 0.1), (0.4, 0.1), (0.7, 0.1), (0.1, 0.5), (0.4, 0.5)]:
ax.add_patch(
mpatches.Rectangle(
(dx, y - 0.5 + dy),
0.5,
0.5,
facecolor="none",
edgecolor=LN,
lw=0.8,
linestyle="--",
)
)
draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2)
ax.text(2.35, y + 0.6, "xmiliony", fontsize=FS_SMALL, style="italic")
draw_box(
ax,
2.8,
y - 0.3,
1.8,
1.2,
'Klasyfikator\n(ResNet)\n"kot? pies? tło?"',
fill=GRAY4,
fontsize=FS,
)
draw_arrow(ax, 4.7, y + 0.3, 5.3, y + 0.3, lw=1.2)
draw_box(ax, 5.4, y - 0.3, 1.2, 1.2, "NMS", fill=GRAY1, fontsize=FS)
draw_arrow(ax, 6.7, y + 0.3, 7.3, y + 0.3, lw=1.2)
ax.text(
8.5,
y + 0.3,
"~3.3h / obraz!\n⚠ NIEPRAKTYCZNE",
ha="center",
fontsize=FS,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
# ---- Approach 2: Region Proposals ----
y = 3.8
ax.text(
0,
y + 1.5,
"② Region Proposals + Klasyfikator (= R-CNN)",
fontsize=FS_LABEL,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
ax.add_patch(
mpatches.Rectangle(
(0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5
)
)
ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL)
# A few smart regions
ax.add_patch(
mpatches.Rectangle(
(0.1, y - 0.4), 0.7, 0.9, facecolor="none", edgecolor=LN, lw=1.5
)
)
ax.add_patch(
mpatches.Rectangle(
(0.9, y + 0.0), 0.7, 0.6, facecolor="none", edgecolor=LN, lw=1.5
)
)
draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2)
draw_box(
ax,
2.8,
y - 0.3,
1.6,
1.2,
"Selective\nSearch\n~2000 regionów",
fill=GRAY2,
fontsize=FS,
)
draw_arrow(ax, 4.5, y + 0.3, 5.1, y + 0.3, lw=1.2)
ax.text(4.8, y + 0.6, "x2000", fontsize=FS_SMALL, style="italic")
draw_box(ax, 5.2, y - 0.3, 1.5, 1.2, "Klasyfikator\n(CNN)", fill=GRAY4, fontsize=FS)
draw_arrow(ax, 6.8, y + 0.3, 7.4, y + 0.3, lw=1.2)
draw_box(ax, 7.5, y - 0.3, 1.0, 1.2, "NMS", fill=GRAY1, fontsize=FS)
draw_arrow(ax, 8.6, y + 0.3, 9.0, y + 0.3, lw=1.2)
ax.text(
10.0,
y + 0.3,
"~20-50 s/obraz\n(250x szybciej)",
ha="center",
fontsize=FS,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
# ---- Approach 3: Fine-tune backbone ----
y = 0.5
ax.text(
0,
y + 1.5,
"③ Fine-tune backbone + detection head (NAJLEPSZE)",
fontsize=FS_LABEL,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY2, "edgecolor": GRAY3},
)
ax.add_patch(
mpatches.Rectangle(
(0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5
)
)
ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL)
draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2)
draw_box(
ax,
2.8,
y - 0.3,
1.8,
1.2,
"Pretrained\nbackbone\n(ResNet)",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 4.7, y + 0.3, 5.3, y + 0.3, lw=1.2)
# Two heads from feature map
draw_box(ax, 5.4, y + 0.3, 1.6, 0.6, "cls head\nP(klasa)", fill=GRAY4, fontsize=FS)
draw_box(
ax, 5.4, y - 0.5, 1.6, 0.6, "bbox head\nΔx,Δy,Δw,Δh", fill=GRAY4, fontsize=FS
)
draw_arrow(ax, 7.1, y + 0.6, 7.7, y + 0.6, lw=1.0)
draw_arrow(ax, 7.1, y - 0.2, 7.7, y - 0.2, lw=1.0)
draw_box(ax, 7.8, y - 0.3, 1.0, 1.2, "NMS", fill=GRAY1, fontsize=FS)
draw_arrow(ax, 8.9, y + 0.3, 9.3, y + 0.3, lw=1.2)
ax.text(
10.2,
y + 0.3,
"5-155 fps!\n✓ NAJLEPSZE",
ha="center",
fontsize=FS,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY2, "edgecolor": GRAY3},
)
save_fig(fig, "q24_detector_from_classifier.png")

View File

@ -0,0 +1,365 @@
"""Two-stage vs one-stage table, ROI pooling, DETR, and sliding window."""
from __future__ import annotations
from _q24_common import (
_DATA_BRIGHT_THRESH,
FS,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
draw_arrow,
draw_box,
draw_table,
np,
plt,
rng,
save_fig,
)
import matplotlib.patches as mpatches
# ============================================================
# 12. Two-stage vs One-stage comparison table
# ============================================================
def draw_two_vs_one_stage() -> None:
"""Draw two vs one stage."""
fig, ax = plt.subplots(figsize=(10, 3.5))
ax.set_xlim(0, 10)
ax.set_ylim(-0.5, 4.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Two-stage vs One-stage — porównanie",
fontsize=FS_TITLE,
fontweight="bold",
pad=8,
)
headers = ["Cecha", "Two-stage\n(Faster R-CNN)", "One-stage\n(YOLO)"]
rows = [
["Szybkość", "~5 fps", "45-155 fps"],
["Dokładność (mAP)", "wyższa (historycznie)", "dorównuje (YOLOv8)"],
["Małe obiekty", "lepszy", "gorszy (SSD/FPN pomaga)"],
["Architektura", "2 etapy + NMS", "1 etap + NMS"],
["Real-time?", "NIE", "TAK"],
]
col_widths = [2.5, 3.5, 3.5]
draw_table(
ax,
headers,
rows,
0.2,
3.8,
col_widths,
row_h=0.65,
fontsize=FS,
header_fontsize=FS,
)
save_fig(fig, "q24_two_vs_one_stage.png")
# ============================================================
# 13. ROI Pooling illustration
# ============================================================
def draw_roi_pooling() -> None:
"""Draw roi pooling."""
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
fig.suptitle(
"ROI Pooling — dowolny rozmiar → stały rozmiar",
fontsize=FS_TITLE,
fontweight="bold",
y=1.02,
)
# Feature map with ROI
ax = axes[0]
# Draw feature map grid
fm = rng.integers(0, 10, (8, 8))
ax.imshow(fm, cmap="gray", vmin=0, vmax=10, alpha=0.3)
for i in range(9):
ax.axhline(y=i - 0.5, color=LN, lw=0.3)
ax.axvline(x=i - 0.5, color=LN, lw=0.3)
# ROI rectangle
ax.add_patch(
mpatches.Rectangle(
(1.5, 1.5), 4, 4, facecolor="none", edgecolor=LN, lw=3, linestyle="-"
)
)
ax.text(3.5, 0.8, "ROI", ha="center", fontsize=FS, fontweight="bold")
ax.set_xlim(-0.5, 7.5)
ax.set_ylim(7.5, -0.5)
ax.set_title("① Feature map\nz zaznaczonym ROI", fontsize=FS, fontweight="bold")
ax.set_xticks([])
ax.set_yticks([])
# ROI divided into grid
ax = axes[1]
roi_data = np.array(
[
[1, 3, 2, 1],
[0, 5, 1, 6],
[0, 4, 1, 0],
[7, 2, 9, 1],
]
)
ax.imshow(roi_data, cmap="gray", vmin=0, vmax=10)
for i in range(5):
ax.axhline(y=i - 0.5, color=LN, lw=1)
ax.axvline(x=i - 0.5, color=LN, lw=1)
# Grid lines for 2x2 pooling
ax.axhline(y=0.5, color=LN, lw=3, linestyle="--")
ax.axvline(x=0.5, color=LN, lw=3, linestyle="--")
for i in range(4):
for j in range(4):
ax.text(
j,
i,
str(roi_data[i, j]),
ha="center",
va="center",
fontsize=10,
fontweight="bold",
color="white" if roi_data[i, j] > _DATA_BRIGHT_THRESH else "black",
)
ax.set_title("② ROI podzielony\nna siatkę 2x2", fontsize=FS, fontweight="bold")
ax.set_xticks([])
ax.set_yticks([])
# Output after pooling
ax = axes[2]
out = np.array([[5, 6], [7, 9]])
ax.imshow(out, cmap="gray", vmin=0, vmax=10)
for i in range(3):
ax.axhline(y=i - 0.5, color=LN, lw=1.5)
ax.axvline(x=i - 0.5, color=LN, lw=1.5)
for i in range(2):
for j in range(2):
ax.text(
j,
i,
str(out[i, j]),
ha="center",
va="center",
fontsize=14,
fontweight="bold",
color="white" if out[i, j] > _DATA_BRIGHT_THRESH else "black",
)
ax.set_title(
"③ Po ROI Pool 2x2\n(max z każdej komórki)", fontsize=FS, fontweight="bold"
)
ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout()
save_fig(fig, "q24_roi_pooling.png")
# ============================================================
# 14. DETR Pipeline
# ============================================================
def draw_detr_pipeline() -> None:
"""Draw detr pipeline."""
fig, ax = plt.subplots(figsize=(11, 4.5))
ax.set_xlim(-0.5, 11.5)
ax.set_ylim(-1, 4.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"DETR — Transformer do detekcji (bez NMS, bez anchorów)",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
# Pipeline
draw_box(ax, 0, 1.5, 1.5, 1.5, "Obraz\nwejściowy", fill=GRAY1, fontsize=FS)
draw_arrow(ax, 1.6, 2.25, 2.1, 2.25, lw=1.5)
draw_box(
ax,
2.2,
1.5,
1.5,
1.5,
"CNN\nBackbone\n(ResNet)",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 3.8, 2.25, 4.3, 2.25, lw=1.5)
draw_box(
ax,
4.4,
1.5,
1.8,
1.5,
"Transformer\nEncoder\n(self-attention)",
fill=GRAY2,
fontsize=FS,
)
draw_arrow(ax, 6.3, 2.25, 6.8, 2.25, lw=1.5)
draw_box(
ax,
6.9,
1.5,
1.8,
1.5,
"Transformer\nDecoder\n(N=100 queries)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
# Output branches
draw_arrow(ax, 8.8, 2.5, 9.5, 3.0, lw=1.2)
draw_box(ax, 9.6, 2.7, 1.5, 0.7, "klasa₁...klasa₁₀₀", fill=GRAY4, fontsize=FS_SMALL)
draw_arrow(ax, 8.8, 2.0, 9.5, 1.5, lw=1.2)
draw_box(ax, 9.6, 1.2, 1.5, 0.7, "bbox₁...bbox₁₀₀", fill=GRAY4, fontsize=FS_SMALL)
# Annotations
ax.text(
7.8,
0.5,
'100 object queries → 5 obiektów + 95x "brak"',
ha="center",
fontsize=FS,
style="italic",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
ax.text(
5.5,
0.0,
"Hungarian matching (trening): optymalne dopasowanie predykcji do GT",
ha="center",
fontsize=FS_SMALL,
style="italic",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
)
# Big benefit box
ax.text(
5.5,
4.0,
"BEZ anchorów • BEZ NMS • end-to-end • prosty pipeline",
ha="center",
fontsize=FS,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY2, "edgecolor": GRAY3},
)
save_fig(fig, "q24_detr_pipeline.png")
# ============================================================
# 15. Sliding Window illustration
# ============================================================
def draw_sliding_window() -> None:
"""Draw sliding window."""
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
fig.suptitle(
"Sliding Window — najprostsze podejście do detekcji",
fontsize=FS_TITLE,
fontweight="bold",
y=1.02,
)
# Multi-position
ax = axes[0]
ax.add_patch(
mpatches.Rectangle((0, 0), 8, 6, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
# Grid of sliding windows
for i in range(4):
for j in range(3):
ax.add_patch(
mpatches.Rectangle(
(i * 1.8 + 0.2, j * 1.8 + 0.2),
1.5,
1.5,
facecolor="none",
edgecolor=LN,
lw=0.6,
linestyle="--",
)
)
# Highlight current window
ax.add_patch(
mpatches.Rectangle((2.0, 2.0), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=2.5)
)
ax.set_xlim(-0.5, 8.5)
ax.set_ylim(-0.5, 6.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("① Wiele pozycji\n(krok co 8 px)", fontsize=FS, fontweight="bold")
# Multi-scale
ax = axes[1]
ax.add_patch(
mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1.5)
)
sizes = [(0.8, 0.8), (1.5, 1.5), (2.5, 2.5), (3.5, 3.5)]
for i, (w, h) in enumerate(sizes):
ax.add_patch(
mpatches.Rectangle(
(0.3 + i * 0.3, 0.3 + i * 0.3),
w,
h,
facecolor="none",
edgecolor=LN,
lw=1 + i * 0.3,
linestyle=[":", "--", "-.", "-"][i],
)
)
ax.text(3, 0, "4+ skal", ha="center", fontsize=FS_SMALL, fontweight="bold")
ax.set_xlim(-0.5, 6.5)
ax.set_ylim(-0.5, 5.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"② Wiele skal\n(obiekty mają różne rozmiary)", fontsize=FS, fontweight="bold"
)
# Count
ax = axes[2]
ax.axis("off")
ax.set_xlim(0, 6)
ax.set_ylim(0, 5)
lines = [
("Obraz: 640 x 480 px", 4.5),
("Okno: 64 x 64 px, krok 8 px", 3.8),
("Pozycje: ~72 x 52 = 3 744", 3.1),
("x 5 skal = 18 720 okien", 2.4),
("x klasyfikacja = WOLNE!", 1.7),
("→ ~3h na jeden obraz", 0.8),
]
for text, yp in lines:
fw = "bold" if "~3h" in text or "WOLNE" in text else "normal"
col = GRAY2 if "WOLNE" in text or "~3h" in text else GRAY4
ax.text(
3.0,
yp,
text,
ha="center",
fontsize=FS,
fontweight=fw,
bbox={"boxstyle": "round,pad=0.2", "facecolor": col, "edgecolor": GRAY3},
)
ax.set_title(
"③ Dlaczego wolne?\n(miliony klasyfikacji)", fontsize=FS, fontweight="bold"
)
fig.tight_layout()
save_fig(fig, "q24_sliding_window.png")

View File

@ -0,0 +1,344 @@
"""R-CNN evolution and YOLO grid diagrams."""
from __future__ import annotations
from typing import TYPE_CHECKING
from _q24_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
LN,
draw_arrow,
draw_box,
plt,
save_fig,
)
import matplotlib.patches as mpatches
if TYPE_CHECKING:
from matplotlib.axes import Axes
def _draw_yolo_cell_prediction(ax: Axes) -> None:
"""Draw YOLO cell prediction vector panel."""
ax.axis("off")
ax.set_xlim(0, 6)
ax.set_ylim(-1, 5)
labels = [
"x",
"y",
"w",
"h",
"conf",
"x",
"y",
"w",
"h",
"conf",
"P(c₁)",
"...",
"P(c₂₀)",
]
colors_vec = [GRAY4] * 5 + [GRAY2] * 5 + [GRAY1] * 3
bw = 0.42
for i, (label, col) in enumerate(
zip(labels, colors_vec, strict=False),
):
x_pos = 0.3 + i * bw
ax.add_patch(
mpatches.Rectangle(
(x_pos, 2.5),
bw - 0.02,
0.6,
facecolor=col,
edgecolor=LN,
lw=0.8,
)
)
ax.text(
x_pos + bw / 2,
2.8,
label,
ha="center",
va="center",
fontsize=5,
fontweight="bold",
)
ax.annotate(
"",
xy=(0.3, 2.4),
xytext=(2.4, 2.4),
arrowprops={"arrowstyle": "-", "lw": 1},
)
ax.text(1.35, 2.15, "bbox 1 (5 wartości)", ha="center", fontsize=FS_SMALL)
ax.annotate(
"",
xy=(2.4, 2.4),
xytext=(4.5, 2.4),
arrowprops={"arrowstyle": "-", "lw": 1},
)
ax.text(3.45, 2.15, "bbox 2 (5 wartości)", ha="center", fontsize=FS_SMALL)
ax.annotate(
"",
xy=(4.5, 2.4),
xytext=(5.8, 2.4),
arrowprops={"arrowstyle": "-", "lw": 1},
)
ax.text(5.15, 2.15, "20 klas", ha="center", fontsize=FS_SMALL)
ax.text(
3.0,
3.5,
"Każda komórka → 30 wartości\n= 2x(x,y,w,h,conf) + 20 klas",
ha="center",
fontsize=FS,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
ax.set_title(
"② Predykcja jednej komórki\n(S=7, B=2, C=20)",
fontsize=FS,
fontweight="bold",
)
# ============================================================
# 6. R-CNN Evolution
# ============================================================
def draw_rcnn_evolution() -> None:
"""Draw rcnn evolution."""
fig, ax = plt.subplots(figsize=(11, 7))
ax.set_xlim(-0.5, 11)
ax.set_ylim(-0.5, 7.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Ewolucja R-CNN: od 50s do 0.2s na obraz",
fontsize=FS_TITLE,
fontweight="bold",
pad=12,
)
y_positions = [5.5, 3.0, 0.5]
labels = [
"R-CNN (2014) — 50 s/obraz",
"Fast R-CNN (2015) — 2 s/obraz",
"Faster R-CNN (2015) — 0.2 s/obraz",
]
# R-CNN
y = y_positions[0]
ax.text(0, y + 1.3, labels[0], fontsize=FS_LABEL, fontweight="bold")
draw_box(ax, 0, y, 2, 0.9, "Selective\nSearch", fill=GRAY2, fontsize=FS)
draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45)
ax.text(2.3, y + 0.8, "~2000", ha="center", fontsize=FS_SMALL, style="italic")
draw_box(ax, 2.6, y, 1.5, 0.9, "Resize\n224x224", fill=GRAY4, fontsize=FS)
draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45)
draw_box(
ax, 4.7, y, 1.5, 0.9, "CNN\nx2000!", fill=GRAY3, fontsize=FS, fontweight="bold"
)
draw_arrow(ax, 6.3, y + 0.45, 6.7, y + 0.45)
draw_box(ax, 6.8, y, 1.3, 0.9, "SVM\nklasyf.", fill=GRAY4, fontsize=FS)
draw_arrow(ax, 8.2, y + 0.45, 8.6, y + 0.45)
draw_box(ax, 8.7, y, 1.0, 0.9, "NMS", fill=GRAY1, fontsize=FS)
# Problem annotation
ax.text(
5.5,
y - 0.4,
"⚠ CNN uruchamiane 2000x → 50 sek!",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
# Fast R-CNN
y = y_positions[1]
ax.text(0, y + 1.3, labels[1], fontsize=FS_LABEL, fontweight="bold")
draw_box(ax, 0, y, 2, 0.9, "Selective\nSearch", fill=GRAY2, fontsize=FS)
draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45)
draw_box(
ax,
2.6,
y,
1.5,
0.9,
"CNN\nx1 (RAZ!)",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45)
draw_box(
ax, 4.7, y, 1.5, 0.9, "ROI\nPooling", fill=GRAY1, fontsize=FS, fontweight="bold"
)
draw_arrow(ax, 6.3, y + 0.45, 6.7, y + 0.45)
draw_box(ax, 6.8, y, 1.3, 0.9, "FC\nklasa+bbox", fill=GRAY4, fontsize=FS)
draw_arrow(ax, 8.2, y + 0.45, 8.6, y + 0.45)
draw_box(ax, 8.7, y, 1.0, 0.9, "NMS", fill=GRAY1, fontsize=FS)
ax.text(
3.8,
y - 0.4,
"✓ CNN RAZ na cały obraz → 25x szybciej",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
# Faster R-CNN
y = y_positions[2]
ax.text(0, y + 1.3, labels[2], fontsize=FS_LABEL, fontweight="bold")
draw_box(
ax,
0.5,
y,
1.5,
0.9,
"CNN\nBackbone",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45)
draw_box(ax, 2.6, y, 1.5, 0.9, "Feature\nMap", fill=GRAY1, fontsize=FS)
draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45)
draw_box(
ax,
4.7,
y,
1.3,
0.9,
"RPN\n(w sieci!)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 6.1, y + 0.45, 6.5, y + 0.45)
draw_box(ax, 6.6, y, 1.3, 0.9, "ROI\nPooling", fill=GRAY1, fontsize=FS)
draw_arrow(ax, 8.0, y + 0.45, 8.4, y + 0.45)
draw_box(ax, 8.5, y, 1.3, 0.9, "FC\nklasa+bbox", fill=GRAY4, fontsize=FS)
ax.text(
5.0,
y - 0.4,
"✓ RPN zastępuje Selective Search → end-to-end",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
save_fig(fig, "q24_rcnn_evolution.png")
# ============================================================
# 7. YOLO Grid
# ============================================================
def draw_yolo_grid() -> None:
"""Draw yolo grid."""
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
fig.suptitle(
"YOLO — detekcja jednoetapowa (siatka SxS)",
fontsize=FS_TITLE,
fontweight="bold",
y=1.02,
)
# Grid on image
ax = axes[0]
grid_size = 7
ax.set_xlim(0, grid_size)
ax.set_ylim(0, grid_size)
for i in range(grid_size + 1):
ax.axhline(y=i, color=LN, lw=0.5, alpha=0.5)
ax.axvline(x=i, color=LN, lw=0.5, alpha=0.5)
ax.add_patch(
mpatches.Rectangle(
(0, 0),
grid_size,
grid_size,
facecolor=GRAY4,
edgecolor=LN,
lw=1.5,
)
)
# Highlight one cell
ax.add_patch(mpatches.Rectangle((3, 3), 1, 1, facecolor=GRAY2, edgecolor=LN, lw=2))
# Object center dot
ax.plot(3.5, 3.5, "ko", markersize=8)
# Bounding box from that cell
ax.add_patch(
mpatches.Rectangle(
(2.0, 2.2), 3.0, 2.6, facecolor="none", edgecolor=LN, lw=2, linestyle="--"
)
)
ax.text(
3.5,
1.8,
"bbox z komórki (3,3)",
ha="center",
fontsize=FS_SMALL,
fontweight="bold",
)
ax.set_aspect("equal")
ax.invert_yaxis()
ax.set_title("① Siatka 7x7\nna obrazie", fontsize=FS, fontweight="bold")
ax.set_xticks([])
ax.set_yticks([])
_draw_yolo_cell_prediction(axes[1])
# Speed comparison
ax = axes[2]
ax.axis("off")
ax.set_xlim(0, 5)
ax.set_ylim(0, 5)
methods = ["R-CNN", "Fast R-CNN", "Faster R-CNN", "YOLO", "YOLOv8"]
fps_vals = [0.02, 0.5, 5, 45, 100]
bar_colors = [GRAY3, GRAY3, GRAY3, GRAY2, GRAY1]
for i, (m, f, c) in enumerate(zip(methods, fps_vals, bar_colors, strict=False)):
bar_w = f / 100 * 4.0
y_pos = 4.0 - i * 0.8
ax.add_patch(
mpatches.Rectangle(
(0.5, y_pos), max(bar_w, 0.1), 0.5, facecolor=c, edgecolor=LN, lw=0.8
)
)
ax.text(
0.4,
y_pos + 0.25,
m,
ha="right",
va="center",
fontsize=FS,
fontweight="bold",
)
ax.text(
max(0.7, 0.5 + bar_w + 0.1),
y_pos + 0.25,
f"{f} fps",
ha="left",
va="center",
fontsize=FS,
)
ax.set_title(
"③ Porównanie szybkości\n(fps = klatki/sek)", fontsize=FS, fontweight="bold"
)
fig.tight_layout()
save_fig(fig, "q24_yolo_grid.png")

View File

@ -0,0 +1,102 @@
"""Common constants and utilities for Q31 diagrams."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
if TYPE_CHECKING:
from matplotlib.axes import Axes
_logger = logging.getLogger(__name__)
DPI = 300
BG = "white"
LN = "black"
FS = 8
FS_TITLE = 11
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
GRAY1 = "#E8E8E8"
GRAY2 = "#D0D0D0"
GRAY3 = "#B8B8B8"
GRAY4 = "#F5F5F5"
GRAY5 = "#C0C0C0"
# Number of regret table header columns before the max-regret column
_REGRET_HEADER_COLS = 4
# Number of data state columns
_DATA_STATE_COLS = 3
# Expected-value for the winning alternative
_WINNING_EV = 95
def draw_box(
ax: Axes,
x: float,
y: float,
w: float,
h: float,
text: str,
*,
fill: str = "white",
lw: float = 1.2,
fontsize: float = FS,
fontweight: str = "normal",
ha: str = "center",
va: str = "center",
rounded: bool = True,
) -> None:
"""Draw a labeled box on the axes."""
if rounded:
rect = FancyBboxPatch(
(x, y),
w,
h,
boxstyle="round,pad=0.05",
lw=lw,
edgecolor=LN,
facecolor=fill,
)
else:
rect = mpatches.Rectangle((x, y), w, h, lw=lw, edgecolor=LN, facecolor=fill)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + h / 2,
text,
ha=ha,
va=va,
fontsize=fontsize,
fontweight=fontweight,
wrap=True,
)
def draw_arrow(
ax: Axes,
x1: float,
y1: float,
x2: float,
y2: float,
*,
lw: float = 1.2,
style: str = "->",
color: str = LN,
) -> None:
"""Draw an arrow between two points."""
ax.annotate(
"",
xy=(x2, y2),
xytext=(x1, y1),
arrowprops={
"arrowstyle": style,
"color": color,
"lw": lw,
},
)

View File

@ -0,0 +1,256 @@
"""Q31 Diagram 1: Payoff matrix + all criteria bar chart."""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from python_pkg.praca_magisterska_video.generate_images._q31_common import (
_DATA_STATE_COLS,
BG,
DPI,
FS,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
OUTPUT_DIR,
_logger,
)
if TYPE_CHECKING:
from matplotlib.axes import Axes
def _draw_payoff_table(ax: Axes) -> None:
"""Draw the payoff matrix table on the left panel."""
ax.axis("off")
ax.set_xlim(0, 6)
ax.set_ylim(0, 6)
ax.set_title(
"Macierz wypłat (tys. zł)",
fontsize=FS_TITLE,
fontweight="bold",
pad=8,
)
headers_col = ["", "S₁\n(dobra)", "S₂\n(średnia)", "S₃\n(zła)"]
rows = [
["A₁ (fabryka)", "200", "50", "-100"],
["A₂ (sklep)", "80", "70", "40"],
["A₃ (obligacje)", "30", "30", "30"],
]
col_w = [1.8, 1.2, 1.2, 1.2]
row_h = 0.7
start_y = 4.5
start_x = 0.2
# Draw header row
x = start_x
for j, h in enumerate(headers_col):
fill = GRAY2 if j > 0 else GRAY3
rect = mpatches.Rectangle(
(x, start_y),
col_w[j],
row_h,
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
x + col_w[j] / 2,
start_y + row_h / 2,
h,
ha="center",
va="center",
fontsize=FS,
fontweight="bold",
)
x += col_w[j]
# Draw data rows
for i, row in enumerate(rows):
x = start_x
y = start_y - (i + 1) * row_h
for j, val in enumerate(row):
fill = GRAY4 if j == 0 else ("white" if i % 2 == 0 else GRAY1)
if val.startswith("-"):
fill = "#D8D8D8"
rect = mpatches.Rectangle(
(x, y),
col_w[j],
row_h,
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
fw = "bold" if j == 0 else "normal"
ax.text(
x + col_w[j] / 2,
y + row_h / 2,
val,
ha="center",
va="center",
fontsize=FS,
fontweight=fw,
)
x += col_w[j]
# Probability row for EV
x = start_x
y = start_y - 4 * row_h
probs = ["p (dla E[X]):", "0.5", "0.3", "0.2"]
for j, val in enumerate(probs):
fill = GRAY5 if j > 0 else GRAY3
rect = mpatches.Rectangle(
(x, y),
col_w[j],
row_h * 0.7,
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
x + col_w[j] / 2,
y + row_h * 0.35,
val,
ha="center",
va="center",
fontsize=7,
fontweight="bold",
style="italic",
)
x += col_w[j]
def _draw_criteria_bars(ax2: Axes) -> None:
"""Draw the criteria comparison bar chart on the right panel."""
criteria = [
"E[X]",
"Laplace",
"Maximax",
"Maximin",
"Hurwicz\n\u03b1=0.6",
"Savage",
]
ev = [95, 69, 30]
laplace = [50, 63.3, 30]
maximax = [200, 80, 30]
maximin = [-100, 40, 30]
hurwicz = [80, 64, 30]
savage_maxregret = [140, 120, 170]
winners = [0, 1, 0, 1, 0, 1]
x_pos = np.arange(len(criteria))
width = 0.22
hatches = ["///", "...", "xxx"]
labels = ["A₁ (fabryka)", "A₂ (sklep)", "A₃ (obligacje)"]
all_vals = [
[
ev[0],
laplace[0],
maximax[0],
maximin[0],
hurwicz[0],
savage_maxregret[0],
],
[
ev[1],
laplace[1],
maximax[1],
maximin[1],
hurwicz[1],
savage_maxregret[1],
],
[
ev[2],
laplace[2],
maximax[2],
maximin[2],
hurwicz[2],
savage_maxregret[2],
],
]
for i in range(_DATA_STATE_COLS):
ax2.bar(
x_pos + (i - 1) * width,
all_vals[i],
width,
label=labels[i],
color="white",
edgecolor=LN,
hatch=hatches[i],
lw=0.8,
)
for c_idx in range(len(criteria)):
w = winners[c_idx]
val = all_vals[w][c_idx]
ax2.text(
x_pos[c_idx] + (w - 1) * width,
val + 5,
"",
ha="center",
va="bottom",
fontsize=10,
fontweight="bold",
)
ax2.set_xticks(x_pos)
ax2.set_xticklabels(criteria, fontsize=7)
ax2.set_ylabel("Wartość kryterium", fontsize=8)
ax2.set_title(
"Porównanie kryteriów",
fontsize=FS_TITLE,
fontweight="bold",
pad=8,
)
ax2.legend(fontsize=7, loc="upper right")
ax2.axhline(y=0, color=LN, lw=0.5, ls="-")
ax2.spines["top"].set_visible(False)
ax2.spines["right"].set_visible(False)
ax2.tick_params(labelsize=7)
ax2.text(
5,
-30,
"(Savage: niżej\n= lepiej)",
fontsize=6,
ha="center",
va="top",
style="italic",
)
def draw_criteria_comparison() -> None:
"""Draw payoff matrix and criteria comparison chart."""
fig, axes = plt.subplots(
1,
2,
figsize=(8.27, 4.5),
gridspec_kw={"width_ratios": [1.2, 1]},
)
_draw_payoff_table(axes[0])
_draw_criteria_bars(axes[1])
plt.tight_layout()
outpath = str(Path(OUTPUT_DIR) / "q31_criteria_comparison.png")
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
plt.close(fig)
_logger.info(" Saved: %s", outpath)

View File

@ -0,0 +1,289 @@
"""Q31 Diagrams 5 & 6: Expected value + decision conditions spectrum."""
from __future__ import annotations
from pathlib import Path
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
import matplotlib.pyplot as plt
from python_pkg.praca_magisterska_video.generate_images._q31_common import (
_WINNING_EV,
BG,
DPI,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
LN,
OUTPUT_DIR,
_logger,
)
def draw_expected_value() -> None:
"""Draw expected value criterion with probability-weighted bars."""
fig, axes = plt.subplots(1, 3, figsize=(8.27, 3.5), sharey=True)
fig.suptitle(
"Kryterium wartości oczekiwanej E[X]" " \u2014 rozkład wyników per alternatywa",
fontsize=FS_TITLE,
fontweight="bold",
y=1.02,
)
probs = [0.5, 0.3, 0.2]
alts = [
("A₁ (fabryka)", [200, 50, -100], 95),
("A₂ (sklep)", [80, 70, 40], 69),
("A₃ (obligacje)", [30, 30, 30], 30),
]
hatches = ["///", "...", "xxx"]
for _idx, (ax, (name, vals, ev)) in enumerate(zip(axes, alts, strict=False)):
x_positions = [0, 0.6, 1.0]
widths = [p * 0.9 for p in probs]
for i, (v, p, h) in enumerate(zip(vals, probs, hatches, strict=False)):
color = "white" if v >= 0 else GRAY2
ax.bar(
x_positions[i],
v,
width=widths[i],
color=color,
edgecolor=LN,
hatch=h,
lw=0.8,
align="edge",
)
offset = 8 if v >= 0 else -12
ax.text(
x_positions[i] + widths[i] / 2,
v + offset,
f"{v}",
ha="center",
va="center",
fontsize=8,
fontweight="bold",
)
contrib = v * p
ax.text(
x_positions[i] + widths[i] / 2,
v / 2,
f"{v}x{p}\n={contrib:.0f}",
ha="center",
va="center",
fontsize=6,
style="italic",
)
# Expected value line
ax.axhline(y=ev, color=LN, lw=2, ls="--")
ax.text(
1.35,
ev,
f"E[X]={ev}",
fontsize=8,
fontweight="bold",
va="center",
ha="left",
bbox={
"boxstyle": "round,pad=0.15",
"facecolor": GRAY1,
"edgecolor": LN,
},
)
ax.set_title(name, fontsize=9, fontweight="bold")
ax.set_xticks([0.225, 0.735, 1.09])
ax.set_xticklabels(["S₁", "S₂", "S₃"], fontsize=7)
ax.axhline(y=0, color=LN, lw=0.5)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.tick_params(labelsize=7)
# Star on winner
if ev == _WINNING_EV:
ax.text(
0.7,
ev + 20,
"★ MAX",
fontsize=9,
fontweight="bold",
ha="center",
va="bottom",
)
axes[0].set_ylabel("Wypłata (tys. zł)", fontsize=8)
plt.tight_layout()
outpath = str(Path(OUTPUT_DIR) / "q31_expected_value.png")
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
plt.close(fig)
_logger.info(" Saved: %s", outpath)
def draw_conditions_spectrum() -> None:
"""Draw decision conditions spectrum diagram."""
fig, ax = plt.subplots(1, 1, figsize=(8.27, 3.5))
ax.set_xlim(0, 10)
ax.set_ylim(0, 5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Warunki decyzyjne" " \u2014 spektrum wiedzy decydenta",
fontsize=FS_TITLE + 1,
fontweight="bold",
pad=10,
)
# Three zones
zones = [
(
0.3,
1.5,
2.8,
2.5,
"PEWNOŚĆ",
"white",
[
"Znamy dokładny wynik",
"Przykład: lokata 5%",
"Metoda: po prostu wybierz",
"najlepszy wynik",
],
),
(
3.5,
1.5,
2.8,
2.5,
"RYZYKO",
GRAY1,
[
"Znamy wyniki I prawdop.",
"Przykład: gra w kości",
"Metoda: wartość",
"oczekiwana E[X]",
],
),
(
6.7,
1.5,
2.8,
2.5,
"NIEPEWNOŚĆ",
GRAY3,
[
"Znamy wyniki, ale",
"NIE znamy prawdop.",
"Metody: Laplace, maximax,",
"maximin, Hurwicz, Savage",
],
),
]
for x, y, w, h, title, fill, lines in zones:
rect = FancyBboxPatch(
(x, y),
w,
h,
boxstyle="round,pad=0.1",
lw=2,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + h - 0.3,
title,
ha="center",
va="center",
fontsize=11,
fontweight="bold",
)
for i, line in enumerate(lines):
ax.text(
x + w / 2,
y + h - 0.7 - i * 0.4,
line,
ha="center",
va="center",
fontsize=7,
)
# Arrows between zones
ax.annotate(
"",
xy=(3.4, 2.75),
xytext=(3.15, 2.75),
arrowprops={"arrowstyle": "->", "color": LN, "lw": 2},
)
ax.annotate(
"",
xy=(6.6, 2.75),
xytext=(6.35, 2.75),
arrowprops={"arrowstyle": "->", "color": LN, "lw": 2},
)
# Bottom: knowledge gradient bar
gradient_y = 0.5
gradient_h = 0.5
n_steps = 50
for i in range(n_steps):
x = 0.3 + i * (9.2 / n_steps)
w = 9.2 / n_steps + 0.01
gray_val = 1 - (i / n_steps) * 0.7
rect = mpatches.Rectangle(
(x, gradient_y),
w,
gradient_h,
lw=0,
facecolor=str(gray_val),
)
ax.add_patch(rect)
rect = mpatches.Rectangle(
(0.3, gradient_y),
9.2,
gradient_h,
lw=1.5,
edgecolor=LN,
facecolor="none",
)
ax.add_patch(rect)
ax.text(
0.3,
gradient_y - 0.15,
"Dużo wiedzy",
fontsize=7,
ha="left",
va="top",
)
ax.text(
9.5,
gradient_y - 0.15,
"Mało wiedzy",
fontsize=7,
ha="right",
va="top",
)
ax.text(
4.95,
gradient_y + gradient_h / 2,
"POZIOM WIEDZY DECYDENTA",
fontsize=8,
fontweight="bold",
ha="center",
va="center",
color="white",
)
plt.tight_layout()
outpath = str(Path(OUTPUT_DIR) / "q31_conditions_spectrum.png")
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
plt.close(fig)
_logger.info(" Saved: %s", outpath)

View File

@ -0,0 +1,344 @@
"""Q31 Diagrams 3 & 4: Hurwicz interpolation + criteria mnemonic map."""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from matplotlib.patches import FancyBboxPatch
import matplotlib.pyplot as plt
import numpy as np
from python_pkg.praca_magisterska_video.generate_images._q31_common import (
BG,
DPI,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
LN,
OUTPUT_DIR,
_logger,
draw_box,
)
if TYPE_CHECKING:
from matplotlib.axes import Axes
def draw_hurwicz_interpolation() -> None:
"""Draw Hurwicz alpha interpolation diagram."""
fig, ax = plt.subplots(1, 1, figsize=(8.27, 4))
ax.set_title(
"Kryterium Hurwicza" " \u2014 wpływ \u03b1 na wybór alternatywy",
fontsize=FS_TITLE + 1,
fontweight="bold",
pad=10,
)
alphas = np.linspace(0, 1, 200)
v1 = alphas * 200 + (1 - alphas) * (-100)
v2 = alphas * 80 + (1 - alphas) * 40
v3 = alphas * 30 + (1 - alphas) * 30
ax.plot(
alphas,
v1,
"k-",
lw=2,
label="A₁ (fabryka): V = 300\u03b1 - 100",
)
ax.plot(
alphas,
v2,
"k--",
lw=2,
label="A₂ (sklep): V = 40\u03b1 + 40",
)
ax.plot(
alphas,
v3,
"k:",
lw=2,
label="A₃ (obligacje): V = 30",
)
# Crossover A2=A1
alpha_cross_12 = 140 / 260
v_cross_12 = 40 * alpha_cross_12 + 40
ax.plot(alpha_cross_12, v_cross_12, "ko", markersize=8, zorder=5)
ax.annotate(
f"\u03b1{alpha_cross_12:.2f}\nA₁ = A₂",
xy=(alpha_cross_12, v_cross_12),
xytext=(alpha_cross_12 + 0.12, v_cross_12 - 30),
fontsize=8,
fontweight="bold",
arrowprops={
"arrowstyle": "->",
"color": LN,
"lw": 1,
},
)
# Shade winning regions
ax.axvspan(0, alpha_cross_12, alpha=0.08, color="black")
ax.axvspan(alpha_cross_12, 1, alpha=0.15, color="black")
ax.text(
alpha_cross_12 / 2,
-60,
"A₂ wygrywa\n(pesymistycznie)",
fontsize=8,
ha="center",
va="center",
bbox={
"boxstyle": "round",
"facecolor": "white",
"edgecolor": LN,
},
)
ax.text(
(alpha_cross_12 + 1) / 2,
160,
"A₁ wygrywa\n(optymistycznie)",
fontsize=8,
ha="center",
va="center",
bbox={
"boxstyle": "round",
"facecolor": "white",
"edgecolor": LN,
},
)
# Special alpha values
ax.axvline(x=0, color=LN, lw=0.5, ls=":")
ax.axvline(x=1, color=LN, lw=0.5, ls=":")
ax.text(
0,
-115,
"\u03b1=0\nmaximin",
fontsize=7,
ha="center",
va="top",
fontweight="bold",
)
ax.text(
1,
-115,
"\u03b1=1\nmaximax",
fontsize=7,
ha="center",
va="top",
fontweight="bold",
)
ax.set_xlabel("Współczynnik optymizmu \u03b1", fontsize=9)
ax.set_ylabel("V(Aᵢ) = \u03b1·max + (1-\u03b1)·min", fontsize=9)
ax.legend(fontsize=8, loc="upper left")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlim(-0.05, 1.05)
ax.axhline(y=0, color=LN, lw=0.3, ls="-")
ax.tick_params(labelsize=8)
plt.tight_layout()
outpath = str(Path(OUTPUT_DIR) / "q31_hurwicz_alpha.png")
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
plt.close(fig)
_logger.info(" Saved: %s", outpath)
def _draw_mnemonic_criteria_boxes(ax: Axes) -> None:
"""Draw the 6 criteria boxes around the center."""
criteria = [
(
0,
6.5,
3,
1.2,
"WARTOŚĆ OCZEKIWANA",
"\u201eMam prawdopodobieństwa\u201d",
"E[Aᵢ] = Σ pⱼ·aᵢⱼ",
),
(
3.5,
6.5,
3,
1.2,
"LAPLACE",
"\u201eWszystko po równo\u201d",
"V = Σaᵢⱼ / n",
),
(
7,
6.5,
3,
1.2,
"MAXIMAX",
"\u201eOptymista: max z max\u201d",
"max maxⱼ aᵢⱼ",
),
(
0,
0.5,
3,
1.2,
"MAXIMIN (Wald)",
"\u201ePesymista: max z min\u201d",
"max minⱼ aᵢⱼ",
),
(
3.5,
0.5,
3,
1.2,
"HURWICZ",
"\u201e\u03b1 pomiędzy\u201d",
"\u03b1·max + (1-\u03b1)·min",
),
(
7,
0.5,
3,
1.2,
"SAVAGE",
"\u201eMin max żalu\u201d",
"min maxⱼ rᵢⱼ",
),
]
fills = [GRAY3, GRAY1, "white", "white", GRAY1, GRAY3]
for i, (x, y, w, h, title, mnem, formula) in enumerate(criteria):
rect = FancyBboxPatch(
(x, y),
w,
h,
boxstyle="round,pad=0.08",
lw=1.5,
edgecolor=LN,
facecolor=fills[i],
)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + h * 0.78,
title,
ha="center",
va="center",
fontsize=8,
fontweight="bold",
)
ax.text(
x + w / 2,
y + h * 0.45,
mnem,
ha="center",
va="center",
fontsize=7,
style="italic",
)
ax.text(
x + w / 2,
y + h * 0.15,
formula,
ha="center",
va="center",
fontsize=7,
fontweight="bold",
family="monospace",
)
# Arrows from center to each box
cx, cy = 5, 4
bx = x + w / 2
by_center = y + h / 2
if by_center > cy:
ax.annotate(
"",
xy=(bx, y),
xytext=(cx, 4.5),
arrowprops={
"arrowstyle": "->",
"color": LN,
"lw": 1,
"connectionstyle": "arc3,rad=0",
},
)
else:
ax.annotate(
"",
xy=(bx, y + h),
xytext=(cx, 3.5),
arrowprops={
"arrowstyle": "->",
"color": LN,
"lw": 1,
"connectionstyle": "arc3,rad=0",
},
)
# Labels on arrows
arrow_labels = [
(1.2, 5.6, "znane p"),
(5, 5.6, "p = 1/n"),
(8.7, 5.6, "max ↑"),
(1.2, 2.5, "min ↑"),
(5, 2.5, "podaj \u03b1"),
(8.7, 2.5, "macierz\nżalu"),
]
for lx, ly, ltext in arrow_labels:
ax.text(
lx,
ly,
ltext,
fontsize=7,
ha="center",
va="center",
bbox={
"boxstyle": "round,pad=0.15",
"facecolor": "white",
"edgecolor": GRAY3,
"lw": 0.5,
},
)
def draw_criteria_mnemonic() -> None:
"""Draw decision criteria mnemonic map diagram."""
fig, ax = plt.subplots(1, 1, figsize=(8.27, 6))
ax.set_xlim(0, 10)
ax.set_ylim(0, 8)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Mapa mnemoniczna \u2014 6 kryteriów decyzyjnych",
fontsize=FS_TITLE + 2,
fontweight="bold",
pad=10,
)
# Central node
draw_box(
ax,
3.5,
3.5,
3,
1,
"MACIERZ\nWYPŁAT",
fill=GRAY2,
lw=2,
fontsize=11,
fontweight="bold",
)
_draw_mnemonic_criteria_boxes(ax)
plt.tight_layout()
outpath = str(Path(OUTPUT_DIR) / "q31_criteria_mnemonic.png")
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
plt.close(fig)
_logger.info(" Saved: %s", outpath)

View File

@ -0,0 +1,322 @@
"""Q31 Diagram 2: Regret matrix construction step-by-step."""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from python_pkg.praca_magisterska_video.generate_images._q31_common import (
_DATA_STATE_COLS,
_REGRET_HEADER_COLS,
BG,
DPI,
FS,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
LN,
OUTPUT_DIR,
_logger,
)
if TYPE_CHECKING:
from matplotlib.axes import Axes
def _draw_original_payoff(
ax: Axes,
start_y: float,
row_h: float,
) -> None:
"""Draw the original payoff matrix (left side of regret fig)."""
ax.text(
2.2,
6.3,
"Krok 1: Macierz wypłat",
fontsize=9,
fontweight="bold",
ha="center",
va="center",
)
col_w = 1.0
headers = ["", "S₁", "S₂", "S₃"]
data = [
["A₁", "200", "50", "-100"],
["A₂", "80", "70", "40"],
["A₃", "30", "30", "30"],
]
start_x = 0.3
for j, h in enumerate(headers):
w = 0.7 if j == 0 else col_w
x = start_x + (0 if j == 0 else 0.7 + (j - 1) * col_w)
rect = mpatches.Rectangle(
(x, start_y),
w,
row_h,
lw=1,
edgecolor=LN,
facecolor=GRAY2,
)
ax.add_patch(rect)
ax.text(
x + w / 2,
start_y + row_h / 2,
h,
ha="center",
va="center",
fontsize=FS,
fontweight="bold",
)
for i, row in enumerate(data):
y = start_y - (i + 1) * row_h
for j, val in enumerate(row):
w = 0.7 if j == 0 else col_w
x = start_x + (0 if j == 0 else 0.7 + (j - 1) * col_w)
fill = GRAY4 if j == 0 else "white"
rect = mpatches.Rectangle(
(x, y),
w,
row_h,
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + row_h / 2,
val,
ha="center",
va="center",
fontsize=FS,
)
# Max per column annotation
max_y = start_y - _DATA_STATE_COLS * row_h - 0.1
col_maxes = ["max=200", "max=70", "max=40"]
for idx, label in enumerate(col_maxes):
ax.text(
start_x + 0.7 + (idx + 0.5) * col_w,
max_y,
label,
fontsize=7,
ha="center",
va="top",
fontweight="bold",
color="#333",
)
# Arrow
ax.annotate(
"",
xy=(5.0, 4.8),
xytext=(4.2, 4.8),
arrowprops={"arrowstyle": "->", "color": LN, "lw": 2},
)
ax.text(
4.6,
5.0,
"rᵢⱼ = max - aᵢⱼ",
fontsize=8,
ha="center",
va="bottom",
fontweight="bold",
)
def _draw_regret_table(
ax: Axes,
start_y: float,
row_h: float,
) -> None:
"""Draw the regret matrix (right side of regret fig)."""
ax.text(
7.5,
6.3,
"Krok 2: Macierz żalu",
fontsize=9,
fontweight="bold",
ha="center",
va="center",
)
regret_data = [
["A₁", "0", "20", "140"],
["A₂", "120", "0", "0"],
["A₃", "170", "40", "10"],
]
headers2 = ["", "S₁", "S₂", "S₃", "max rᵢ"]
start_x2 = 5.3
for j, h in enumerate(headers2):
w = 0.7 if j == 0 else (0.9 if j < _REGRET_HEADER_COLS else 1.0)
x = start_x2
if j == 0:
x = start_x2
elif j <= _DATA_STATE_COLS:
x = start_x2 + 0.7 + (j - 1) * 0.9
else:
x = start_x2 + 0.7 + _DATA_STATE_COLS * 0.9
rect = mpatches.Rectangle(
(x, start_y),
w,
row_h,
lw=1,
edgecolor=LN,
facecolor=(GRAY2 if j < _REGRET_HEADER_COLS else GRAY3),
)
ax.add_patch(rect)
ax.text(
x + w / 2,
start_y + row_h / 2,
h,
ha="center",
va="center",
fontsize=FS,
fontweight="bold",
)
max_regrets = [140, 120, 170]
for i, row in enumerate(regret_data):
y = start_y - (i + 1) * row_h
for j, val in enumerate(row):
w = 0.7 if j == 0 else 0.9
x = start_x2 + (0 if j == 0 else 0.7 + (j - 1) * 0.9)
fill = GRAY4 if j == 0 else "white"
if j > 0 and int(val) == max_regrets[i]:
fill = GRAY2
rect = mpatches.Rectangle(
(x, y),
w,
row_h,
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
fw = "bold" if (j > 0 and int(val) == max_regrets[i]) else "normal"
ax.text(
x + w / 2,
y + row_h / 2,
val,
ha="center",
va="center",
fontsize=FS,
fontweight=fw,
)
# Max regret column
x = start_x2 + 0.7 + _DATA_STATE_COLS * 0.9
w = 1.0
is_winner = max_regrets[i] == min(max_regrets)
fill = "#C8C8C8" if is_winner else GRAY1
rect = mpatches.Rectangle(
(x, y),
w,
row_h,
lw=1.5 if is_winner else 1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
marker = "" if is_winner else ""
ax.text(
x + w / 2,
y + row_h / 2,
f"{max_regrets[i]}{marker}",
ha="center",
va="center",
fontsize=FS,
fontweight="bold",
)
def draw_regret_matrix() -> None:
"""Draw the regret matrix construction diagram."""
fig, ax = plt.subplots(1, 1, figsize=(8.27, 5))
ax.axis("off")
ax.set_xlim(0, 10)
ax.set_ylim(0, 7)
ax.set_title(
"Kryterium Savage'a \u2014 budowa macierzy żalu",
fontsize=FS_TITLE + 1,
fontweight="bold",
pad=10,
)
start_y = 5.5
row_h = 0.55
_draw_original_payoff(ax, start_y, row_h)
_draw_regret_table(ax, start_y, row_h)
# Bottom conclusion
ax.text(
5.0,
2.8,
"Krok 3: Wybierz min z max żalu" " → A₂ (max żal = 120)",
fontsize=10,
ha="center",
va="center",
fontweight="bold",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": GRAY1,
"edgecolor": LN,
"lw": 1.5,
},
)
# Interpretation examples
ax.text(
5.0,
2.0,
"Interpretacja żalu: r₁₃ = 140 oznacza:\n"
"\u201eGdyby nastąpił S₃ (zła koniunktura),"
" a wybrałbym A₁,\n"
"żałowałbym, bo najlepszą opcją byłoby"
" A₂ z wynikiem 40 \u2014 traciłbym 140\u201d",
fontsize=7.5,
ha="center",
va="center",
style="italic",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": GRAY4,
"edgecolor": GRAY3,
"lw": 0.8,
},
)
# Mnemonic
ax.text(
5.0,
0.8,
"Mnemonik: Savage = \u201eŻal jak nóż\u201d\n"
"Maksymalny żal to nóż "
"\u2014 wybierz opcję z NAJMNIEJSZYM nożem",
fontsize=8,
ha="center",
va="center",
fontweight="bold",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": "white",
"edgecolor": LN,
"lw": 1,
},
)
plt.tight_layout()
outpath = str(Path(OUTPUT_DIR) / "q31_regret_matrix.png")
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
plt.close(fig)
_logger.info(" Saved: %s", outpath)

View File

@ -0,0 +1,448 @@
"""Q9 diagrams 1-6: process/thread basics, memory, states, PCB, speed."""
from __future__ import annotations
from _q9_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
draw_arrow,
draw_box,
draw_table,
save_fig,
)
import matplotlib.pyplot as plt
# ============================================================
# 1. Process vs Thread comparison table
# ============================================================
def gen_process_vs_thread() -> None:
"""Gen process vs thread."""
fig, ax = plt.subplots(figsize=(7.5, 4.5))
ax.set_xlim(0, 10)
ax.set_ylim(-4.5, 1.5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Proces vs Wątek — porównanie", fontsize=FS_TITLE, fontweight="bold", pad=10
)
headers = ["Cecha", "Proces", "Wątek"]
col_w = [2.5, 3.5, 3.5]
rows = [
["Pamięć", "Własna, izolowana", "Współdzielona (heap)"],
["Tworzenie", "~1-10 ms", "~10-100 μs (100x szybciej)"],
["Przełączanie", "~1-5 μs (TLB flush)", "~0.1-0.5 μs (10x)"],
["Komunikacja", "IPC (pipe, socket, shm)", "Bezpośrednia (wspólna pam.)"],
["Izolacja", "Pełna — awaria izolowana", "Brak — może zabić proces"],
["Zastosowanie", "Bezpieczeństwo, izolacja", "Wydajność, współdzielenie"],
]
draw_table(
ax,
headers,
rows,
x0=0.25,
y0=0.8,
col_widths=col_w,
row_h=0.55,
fontsize=7.5,
header_fontsize=FS_LABEL,
)
# Analogy at bottom
ax.text(
5.0,
-4.2,
"Analogia: Proces = mieszkanie (własny adres) "
"Wątek = pokój w mieszkaniu (wspólna kuchnia = heap)",
ha="center",
fontsize=FS,
style="italic",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
save_fig(fig, "q9_process_vs_thread.png")
# ============================================================
# 2. Memory segments layout
# ============================================================
def gen_memory_layout() -> None:
"""Gen memory layout."""
fig, ax = plt.subplots(figsize=(6, 5))
ax.set_xlim(0, 10)
ax.set_ylim(0, 8)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Segmenty pamięci procesu", fontsize=FS_TITLE, fontweight="bold", pad=10
)
segments = [
("STACK ↓", "zmienne lokalne, adresy\npowrotu (każdy wątek WŁASNY)", GRAY1),
("...", "(wolna przestrzeń)", "white"),
("HEAP ↑", "malloc/new — dynamiczna\nalokacja (współdzielony)", GRAY4),
("BSS", "zmienne globalne\nniezainicjalizowane (zerowane)", GRAY2),
("DATA", "zmienne globalne\nzainicjalizowane", GRAY3),
("TEXT", "kod maszynowy\n(read-only, współdzielony)", GRAY5),
]
bx, bw = 2.0, 2.5
seg_h = 0.9
gap = 0.05
top_y = 7.0
for i, (name, desc, color) in enumerate(segments):
y = top_y - i * (seg_h + gap)
draw_box(
ax,
bx,
y,
bw,
seg_h,
name,
fill=color,
fontsize=FS_LABEL,
fontweight="bold",
rounded=False,
)
ax.text(bx + bw + 0.3, y + seg_h / 2, desc, fontsize=7.5, va="center")
# Address labels
ax.text(
bx - 0.2,
top_y + seg_h / 2,
"wysoki\nadres",
fontsize=FS_SMALL,
va="center",
ha="right",
style="italic",
)
bottom_y = top_y - 5 * (seg_h + gap)
ax.text(
bx - 0.2,
bottom_y + seg_h / 2,
"niski\nadres",
fontsize=FS_SMALL,
va="center",
ha="right",
style="italic",
)
# Arrows for growth
ax.annotate(
"",
xy=(bx - 0.5, top_y - 0.1),
xytext=(bx - 0.5, top_y + seg_h + 0.1),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
ax.text(bx - 0.9, top_y + 0.4, "rośnie\nw dół", fontsize=FS_SMALL, ha="center")
heap_y = top_y - 2 * (seg_h + gap)
ax.annotate(
"",
xy=(bx - 0.5, heap_y + seg_h + 0.1),
xytext=(bx - 0.5, heap_y - 0.1),
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
)
ax.text(bx - 0.9, heap_y + 0.5, "rośnie\nw górę", fontsize=FS_SMALL, ha="center")
save_fig(fig, "q9_memory_layout.png")
# ============================================================
# 3. Process states diagram
# ============================================================
def gen_process_states() -> None:
"""Gen process states."""
fig, ax = plt.subplots(figsize=(7, 3.5))
ax.set_xlim(0, 12)
ax.set_ylim(0, 5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Stany procesu — diagram przejść", fontsize=FS_TITLE, fontweight="bold", pad=10
)
states = {
"NEW": (1.0, 2.5),
"READY": (3.5, 2.5),
"RUNNING": (6.5, 2.5),
"BLOCKED": (6.5, 0.5),
"TERMINATED": (10.0, 2.5),
}
fills = {
"NEW": GRAY4,
"READY": GRAY1,
"RUNNING": GRAY3,
"BLOCKED": GRAY2,
"TERMINATED": GRAY5,
}
bw, bh = 1.8, 0.9
for name, (x, y) in states.items():
draw_box(
ax,
x,
y,
bw,
bh,
name,
fill=fills[name],
fontsize=FS_LABEL,
fontweight="bold",
)
# Transitions
transitions = [
("NEW", "READY", "admit"),
("READY", "RUNNING", "dispatch\n(scheduler)"),
("RUNNING", "TERMINATED", "exit"),
("RUNNING", "BLOCKED", "I/O wait"),
]
for src, dst, label in transitions:
sx, sy = states[src]
dx, dy = states[dst]
if sy == dy: # horizontal
draw_arrow(ax, sx + bw, sy + bh / 2, dx, dy + bh / 2, lw=1.5)
mx = (sx + bw + dx) / 2
ax.text(
mx,
sy + bh / 2 + 0.25,
label,
fontsize=FS_SMALL,
ha="center",
va="bottom",
)
else: # vertical
draw_arrow(ax, sx + bw / 2, sy, dx + bw / 2, dy + bh, lw=1.5)
ax.text(
sx + bw + 0.2,
(sy + dy + bh) / 2,
label,
fontsize=FS_SMALL,
ha="left",
va="center",
)
# BLOCKED → READY
bx, by = states["BLOCKED"]
rx, ry = states["READY"]
ax.annotate(
"",
xy=(rx + bw / 2, ry),
xytext=(bx - 0.3, by + bh / 2),
arrowprops={
"arrowstyle": "->",
"lw": 1.5,
"color": LN,
"connectionstyle": "arc3,rad=0.3",
},
)
ax.text(3.5, 0.7, "I/O done", fontsize=FS_SMALL, ha="center")
# RUNNING → READY (preemption)
rux, ruy = states["RUNNING"]
draw_arrow(ax, rux, ruy + bh, rx + bw, ry + bh, lw=1.2)
ax.text(5.0, 3.7, "preempt /\ntimeout", fontsize=FS_SMALL, ha="center")
save_fig(fig, "q9_process_states.png")
# ============================================================
# 4. Thread structure within process
# ============================================================
def gen_thread_structure() -> None:
"""Gen thread structure."""
fig, ax = plt.subplots(figsize=(8, 4.5))
ax.set_xlim(0, 10)
ax.set_ylim(0, 6)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Wątki wewnątrz procesu (PID=42)", fontsize=FS_TITLE, fontweight="bold", pad=10
)
# Shared memory region
draw_box(ax, 0.5, 3.5, 9.0, 1.8, "", fill=GRAY1, rounded=False, lw=2)
ax.text(5.0, 5.0, "WSPÓŁDZIELONE", fontsize=FS, fontweight="bold", ha="center")
labels_shared = ["TEXT", "DATA", "BSS", "HEAP", "pliki", "PID"]
for i, lab in enumerate(labels_shared):
x = 1.0 + i * 1.4
draw_box(
ax,
x,
3.8,
1.1,
0.6,
lab,
fill=GRAY3,
fontsize=FS,
fontweight="bold",
rounded=False,
)
ax.text(x + 0.55, 4.6, lab, fontsize=FS_SMALL, ha="center", color="#555555")
# Per-thread regions
draw_box(
ax, 0.5, 0.5, 9.0, 2.7, "", fill="white", rounded=False, lw=2, linestyle="--"
)
ax.text(
5.0, 2.95, "PRYWATNE (każdy wątek)", fontsize=FS, fontweight="bold", ha="center"
)
for i in range(3):
x = 1.0 + i * 3.0
tid = i + 1
draw_box(ax, x, 0.7, 2.3, 2.0, "", fill=GRAY4, rounded=False)
ax.text(
x + 1.15,
2.4,
f"Wątek {tid}",
fontsize=FS_LABEL,
fontweight="bold",
ha="center",
)
items = [f"stos_{tid}", f"rejestry_{tid}", f"PC_{tid}", f"TID={40 + tid}"]
for j, item in enumerate(items):
ax.text(
x + 1.15,
2.0 - j * 0.35,
item,
fontsize=FS_SMALL,
ha="center",
family="monospace",
)
save_fig(fig, "q9_thread_structure.png")
# ============================================================
# 5. PCB structure
# ============================================================
def gen_pcb_structure() -> None:
"""Gen pcb structure."""
fig, ax = plt.subplots(figsize=(5, 3.5))
ax.set_xlim(0, 8)
ax.set_ylim(0, 5.5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"PCB (Process Control Block)", fontsize=FS_TITLE, fontweight="bold", pad=10
)
fields = [
("PID", "42"),
("Stan", "READY / RUNNING / BLOCKED"),
("Rejestry CPU", "EAX, EBX, ESP, EIP ..."),
("Tablice stron", "mapowanie wirtualne → fizyczne"),
("Otwarte pliki", "fd[0], fd[1], fd[2] ..."),
("Priorytety", "nice value, scheduling class"),
("Statystyki", "CPU time, I/O count"),
]
top_y = 4.8
for i, (field, value) in enumerate(fields):
y = top_y - i * 0.55
draw_box(
ax,
0.5,
y,
2.2,
0.45,
field,
fill=GRAY2,
fontsize=FS,
fontweight="bold",
rounded=False,
)
draw_box(ax, 2.7, y, 4.5, 0.45, value, fill=GRAY4, fontsize=FS, rounded=False)
ax.text(
4.0,
0.3,
"Context switch = zapisz PCB starego → wczytaj PCB nowego",
fontsize=FS_SMALL,
ha="center",
style="italic",
)
save_fig(fig, "q9_pcb_structure.png")
# ============================================================
# 6. Speed comparison
# ============================================================
def gen_speed_comparison() -> None:
"""Gen speed comparison."""
fig, axes = plt.subplots(1, 2, figsize=(9, 3.5))
fig.suptitle(
"Szybkość — procesy vs wątki (benchmarki Linux)",
fontsize=FS_TITLE,
fontweight="bold",
)
# Creation time panel
ax = axes[0]
ops = ["fork()\n(nowy proces)", "pthread_create()\n(nowy wątek)"]
times = [3.0, 0.05] # ms
colors = [GRAY3, GRAY1]
bars = ax.barh(ops, times, color=colors, edgecolor=LN, height=0.5, linewidth=1.2)
ax.set_xlabel("Czas [ms]", fontsize=FS)
ax.set_title("Tworzenie", fontsize=FS_LABEL, fontweight="bold")
ax.set_xlim(0, 4.5)
for bar, t in zip(bars, times, strict=False):
ax.text(
bar.get_width() + 0.1,
bar.get_y() + bar.get_height() / 2,
f"{t} ms",
va="center",
fontsize=FS,
)
ax.text(
2.5,
-0.6,
"~100x szybciej",
fontsize=FS,
ha="center",
fontweight="bold",
transform=ax.get_xaxis_transform(),
)
ax.tick_params(labelsize=FS)
# Right: context switch
ax = axes[1]
ops2 = ["Proces→Proces\n(TLB flush)", "Wątek→Wątek\n(TLB warm)"]
times2 = [3000, 300] # ns
bars2 = ax.barh(ops2, times2, color=colors, edgecolor=LN, height=0.5, linewidth=1.2)
ax.set_xlabel("Czas [ns]", fontsize=FS)
ax.set_title("Przełączanie kontekstu", fontsize=FS_LABEL, fontweight="bold")
ax.set_xlim(0, 4500)
for bar, t in zip(bars2, times2, strict=False):
ax.text(
bar.get_width() + 50,
bar.get_y() + bar.get_height() / 2,
f"{t} ns",
va="center",
fontsize=FS,
)
ax.text(
2500,
-0.6,
"~10x szybciej",
fontsize=FS,
ha="center",
fontweight="bold",
transform=ax.get_xaxis_transform(),
)
ax.tick_params(labelsize=FS)
fig.tight_layout(rect=[0, 0.05, 1, 0.92])
save_fig(fig, "q9_speed_comparison.png")

View File

@ -0,0 +1,420 @@
"""Q9 diagrams 14-16: classic sync problems, mechanism comparison, semaphore."""
from __future__ import annotations
from _q9_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
OCCUPIED_SLOTS,
draw_arrow,
draw_box,
draw_table,
save_fig,
)
import matplotlib.pyplot as plt
import numpy as np
# ============================================================
# 14. Bounded buffer + readers-writers + philosophers
# ============================================================
def _draw_bounded_buffer_panel(ax: plt.Axes) -> None:
"""Draw the bounded-buffer (producer-consumer) panel."""
ax.set_xlim(0, 8)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Producent-Konsument\n(Bounded Buffer, N=4)", fontsize=FS, fontweight="bold"
)
draw_box(
ax,
0.2,
4.0,
2.0,
1.2,
"Producent\nP(empty)\nP(mutex)\nwstaw()\nV(mutex)\nV(full)",
fill=GRAY1,
fontsize=5.5,
)
items = ["A", "B", "", ""]
for i, item in enumerate(items):
x = 2.8 + i * 0.9
fill = GRAY3 if item else "white"
draw_box(
ax,
x,
4.3,
0.9,
0.7,
item,
fill=fill,
fontsize=FS,
fontweight="bold",
rounded=False,
)
ax.text(4.6, 5.2, "Bufor (N=4)", fontsize=FS_SMALL, ha="center", fontweight="bold")
draw_box(
ax,
6.0,
4.0,
2.0,
1.2,
"Konsument\nP(full)\nP(mutex)\npobierz()\nV(mutex)\nV(empty)",
fill=GRAY4,
fontsize=5.5,
)
draw_arrow(ax, 2.2, 4.6, 2.8, 4.65, lw=1.2)
draw_arrow(ax, 6.4, 4.65, 6.0, 4.6, lw=1.2)
sems = [("mutex = 1", GRAY2), ("empty = N", GRAY1), ("full = 0", GRAY3)]
for i, (s, c) in enumerate(sems):
draw_box(
ax,
2.0,
2.5 - i * 0.6,
4.0,
0.45,
s,
fill=c,
fontsize=FS_SMALL,
fontweight="bold",
)
ax.text(
4.0,
0.5,
"KOLEJNOŚĆ: P(empty/full)\nPRZED P(mutex)!\nOdwrotnie = DEADLOCK",
fontsize=5.5,
ha="center",
fontweight="bold",
color="#C62828",
bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"},
)
def _draw_readers_writers_panel(ax: plt.Axes) -> None:
"""Draw the readers-writers panel."""
ax.set_xlim(0, 8)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Czytelnicy-Pisarze\n(Readers-Writers)", fontsize=FS, fontweight="bold"
)
draw_box(
ax,
2.5,
3.5,
3.0,
1.5,
"Dane\n(współdzielone)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
for i in range(3):
x = 0.3 + i * 1.0
draw_box(
ax,
x,
5.5,
0.8,
0.7,
f"R{i + 1}",
fill=GRAY4,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, x + 0.4, 5.5, 3.0 + i * 0.5, 5.0, lw=1)
ax.text(
1.5,
6.5,
"Czytelnicy (wielu naraz)",
fontsize=FS_SMALL,
ha="center",
fontweight="bold",
)
draw_box(
ax, 5.5, 5.5, 1.5, 0.7, "Pisarz", fill=GRAY5, fontsize=FS, fontweight="bold"
)
draw_arrow(ax, 6.25, 5.5, 5.0, 5.0, lw=1.5)
ax.text(
6.25,
6.5,
"WYŁĄCZNY",
fontsize=FS_SMALL,
ha="center",
fontweight="bold",
color="#C62828",
)
rules = [
"Wielu czytelników = OK",
"Jeden pisarz = wyłączny",
"Czytelnik + Pisarz = ✗",
"Problem: pisarze głodują",
]
for i, r in enumerate(rules):
ax.text(4.0, 2.5 - i * 0.45, r, fontsize=FS_SMALL, ha="center")
ax.text(
4.0,
0.5,
"Rozwiązanie:\nrw_mutex + count_mutex\n+ zmienna readers",
fontsize=5.5,
ha="center",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
def _draw_philosophers_panel(ax: plt.Axes) -> None:
"""Draw the dining-philosophers panel."""
ax.set_xlim(0, 8)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Ucztujący filozofowie\n(Dining Philosophers)", fontsize=FS, fontweight="bold"
)
cx, cy, r = 4.0, 3.8, 1.8
table = plt.Circle((cx, cy), 0.8, fill=True, facecolor=GRAY2, edgecolor=LN, lw=1.5)
ax.add_patch(table)
ax.text(cx, cy, "Stół", fontsize=FS, ha="center", fontweight="bold")
for i in range(5):
angle = np.pi / 2 + i * 2 * np.pi / 5
px = cx + r * np.cos(angle)
py = cy + r * np.sin(angle)
circle = plt.Circle(
(px, py), 0.35, fill=True, facecolor=GRAY1, edgecolor=LN, lw=1.2
)
ax.add_patch(circle)
ax.text(
px, py, f"F{i}", ha="center", va="center", fontsize=FS, fontweight="bold"
)
fork_angle = np.pi / 2 + (i + 0.5) * 2 * np.pi / 5
fx = cx + (r * 0.6) * np.cos(fork_angle)
fy = cy + (r * 0.6) * np.sin(fork_angle)
ax.plot(
[fx - 0.1, fx + 0.1],
[fy - 0.15, fy + 0.15],
color=LN,
lw=2.5,
solid_capstyle="round",
)
ax.text(fx + 0.2, fy + 0.15, f"w{i}", fontsize=5, color="#555555")
rules = [
"Jedzenie = 2 widelce",
"Naiwne → DEADLOCK",
"Fix: F4 bierze odwrotnie",
"Alt: semafor(4)",
]
for i, r in enumerate(rules):
ax.text(4.0, 1.2 - i * 0.35, r, fontsize=FS_SMALL, ha="center")
def gen_classic_problems() -> None:
"""Gen classic problems."""
fig, axes = plt.subplots(1, 3, figsize=(12, 5))
fig.suptitle(
"Klasyczne problemy synchronizacji", fontsize=FS_TITLE, fontweight="bold"
)
_draw_bounded_buffer_panel(axes[0])
_draw_readers_writers_panel(axes[1])
_draw_philosophers_panel(axes[2])
fig.tight_layout(rect=[0, 0, 1, 0.88])
save_fig(fig, "q9_classic_problems.png")
# ============================================================
# 15. Sync mechanisms comparison + mutex/sem/spinlock
# ============================================================
def gen_sync_comparison() -> None:
"""Gen sync comparison."""
fig, axes = plt.subplots(2, 1, figsize=(9, 7))
# Top: comparison table
ax = axes[0]
ax.set_xlim(0, 11.5)
ax.set_ylim(-5, 1)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Mechanizmy synchronizacji — porównanie",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
headers = ["Mechanizm", "Opis", "Kiedy używać"]
col_w = [2.5, 4.5, 4.0]
rows = [
["Mutex", "Zamek: 1 wątek w sekcji", "Sekcja krytyczna"],
["Semafor(n)", "Licznik: max n wątków", "Ograniczone zasoby (n miejsc)"],
["Monitor", "Obiekt z wbudowanym mutex", "Java synchronized"],
["Cond. Variable", "wait()/signal() na warunek", "Producent-konsument"],
["Spinlock", "Aktywne czekanie (busy-wait)", "Bardzo krótkie sekcje (<1 μs)"],
["RW Lock", "Wielu czytelników LUB 1 pisarz", "Bazy danych, cache"],
["Barrier", "Czekaj aż wszyscy dotrą", "Obliczenia równoległe"],
]
draw_table(
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7
)
# Bottom: mutex vs semafor vs spinlock
ax = axes[1]
ax.set_xlim(0, 12)
ax.set_ylim(0, 5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Mutex vs Semafor vs Spinlock", fontsize=FS_TITLE, fontweight="bold", pad=5
)
# Mutex
draw_box(ax, 0.3, 2.5, 3.5, 2.0, "", fill=GRAY4)
ax.text(2.05, 4.2, "MUTEX", fontsize=FS_LABEL, fontweight="bold", ha="center")
ax.text(2.05, 3.6, "= klucz do łazienki\n(1 osoba)", fontsize=FS, ha="center")
ax.text(
2.05,
2.8,
"Wątek ZASYPIA gdy czeka\nOS go obudzi (~μs)",
fontsize=FS_SMALL,
ha="center",
style="italic",
)
# Semafor
draw_box(ax, 4.3, 2.5, 3.5, 2.0, "", fill=GRAY1)
ax.text(6.05, 4.2, "SEMAFOR(n)", fontsize=FS_LABEL, fontweight="bold", ha="center")
ax.text(
6.05, 3.6, "= parking na n miejsc\n(n wątków naraz)", fontsize=FS, ha="center"
)
ax.text(
6.05,
2.8,
"Semafor(1) = mutex\nP() = zmniejsz, V() = zwiększ",
fontsize=FS_SMALL,
ha="center",
style="italic",
)
# Spinlock
draw_box(ax, 8.3, 2.5, 3.5, 2.0, "", fill=GRAY2)
ax.text(10.05, 4.2, "SPINLOCK", fontsize=FS_LABEL, fontweight="bold", ha="center")
ax.text(
10.05, 3.6, "= obrotowe drzwi\n(kręcisz się w kółko)", fontsize=FS, ha="center"
)
ax.text(
10.05,
2.8,
"Wątek KRĘCI się w pętli\nLepszy gdy sekcja < 1 μs",
fontsize=FS_SMALL,
ha="center",
style="italic",
)
# Dividing rule
ax.text(
6.0,
1.5,
"Reguła kciuka: sekcja > 1 μs → MUTEX | "
"sekcja < 1 μs → SPINLOCK | n jednocześnie → SEMAFOR(n)",
fontsize=FS,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
fig.tight_layout()
save_fig(fig, "q9_sync_comparison.png")
# ============================================================
# 16. Semaphore concept diagram
# ============================================================
def gen_semaphore_concept() -> None:
"""Gen semaphore concept."""
fig, ax = plt.subplots(figsize=(6, 3))
ax.set_xlim(0, 10)
ax.set_ylim(0, 5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Semafor — koncepcja (parking na 3 miejsca)",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Parking slots
for i in range(3):
x = 2.0 + i * 2.0
occupied = i < OCCUPIED_SLOTS
fill = GRAY3 if occupied else "white"
label = f"Wątek {i + 1}" if occupied else "(wolne)"
draw_box(
ax,
x,
2.5,
1.5,
1.2,
label,
fill=fill,
fontsize=FS,
fontweight="bold" if occupied else "normal",
rounded=False,
)
ax.text(
5.0,
4.2,
"semafor(3): counter = 1 (jedno wolne miejsce)",
fontsize=FS,
ha="center",
fontweight="bold",
)
# Waiting thread
draw_box(
ax,
0.2,
0.5,
1.5,
0.8,
"Wątek 4\nP() → czeka",
fill="#F8D7DA",
fontsize=FS_SMALL,
)
draw_arrow(ax, 1.7, 0.9, 2.0, 2.5, lw=1.2, color="#C62828")
ax.text(
5.0,
0.6,
"P() = counter-- (jeśli 0 → czekaj)\nV() = counter++ (obudź czekającego)",
fontsize=FS,
ha="center",
family="monospace",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
save_fig(fig, "q9_semaphore_concept.png")

View File

@ -0,0 +1,200 @@
"""Common utilities and constants for Q9 diagram generation.
Monochrome, A4-printable PNGs (300 DPI).
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib as mpl
mpl.use("Agg")
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
import matplotlib.pyplot as plt
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
_logger = logging.getLogger(__name__)
DPI = 300
BG = "white"
LN = "black"
FS = 8
FS_TITLE = 11
FS_SMALL = 6.5
FS_LABEL = 9
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
GRAY1 = "#E8E8E8"
GRAY2 = "#D0D0D0"
GRAY3 = "#B8B8B8"
GRAY4 = "#F5F5F5"
GRAY5 = "#C0C0C0"
OCCUPIED_SLOTS = 2
def draw_box(
ax: Axes,
x: float,
y: float,
w: float,
h: float,
text: str,
fill: str = "white",
lw: float = 1.2,
fontsize: float = FS,
fontweight: str = "normal",
ha: str = "center",
va: str = "center",
*,
rounded: bool = True,
edgecolor: str = LN,
linestyle: str = "-",
) -> None:
"""Draw box."""
if rounded:
rect = FancyBboxPatch(
(x, y),
w,
h,
boxstyle="round,pad=0.05",
lw=lw,
edgecolor=edgecolor,
facecolor=fill,
linestyle=linestyle,
)
else:
rect = mpatches.Rectangle(
(x, y),
w,
h,
lw=lw,
edgecolor=edgecolor,
facecolor=fill,
linestyle=linestyle,
)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + h / 2,
text,
ha=ha,
va=va,
fontsize=fontsize,
fontweight=fontweight,
wrap=True,
)
def draw_arrow(
ax: Axes,
x1: float,
y1: float,
x2: float,
y2: float,
lw: float = 1.2,
style: str = "->",
color: str = LN,
) -> None:
"""Draw arrow."""
ax.annotate(
"",
xy=(x2, y2),
xytext=(x1, y1),
arrowprops={"arrowstyle": style, "color": color, "lw": lw},
)
def draw_double_arrow(
ax: Axes,
x1: float,
y1: float,
x2: float,
y2: float,
lw: float = 1.2,
color: str = LN,
) -> None:
"""Draw double arrow."""
ax.annotate(
"",
xy=(x2, y2),
xytext=(x1, y1),
arrowprops={"arrowstyle": "<->", "color": color, "lw": lw},
)
def save_fig(fig: Figure, name: str) -> None:
"""Save fig."""
path = str(Path(OUTPUT_DIR) / name)
fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15)
plt.close(fig)
_logger.info(" Saved: %s", path)
def draw_table(
ax: Axes,
headers: list[str],
rows: list[list[str]],
x0: float,
y0: float,
col_widths: list[float],
row_h: float = 0.4,
header_fill: str = GRAY2,
row_fills: list[str] | None = None,
fontsize: float = FS,
header_fontsize: float | None = None,
) -> None:
"""Draw a clean table on axes."""
if header_fontsize is None:
header_fontsize = fontsize
len(headers)
len(rows)
sum(col_widths)
# Header
cx = x0
for j, hdr in enumerate(headers):
draw_box(
ax,
cx,
y0,
col_widths[j],
row_h,
hdr,
fill=header_fill,
fontsize=header_fontsize,
fontweight="bold",
rounded=False,
)
cx += col_widths[j]
# Rows
for i, row in enumerate(rows):
cy = y0 - (i + 1) * row_h
cx = x0
fill = GRAY4 if (i % 2 == 0) else "white"
if row_fills and i < len(row_fills):
fill = row_fills[i]
for j, cell in enumerate(row):
fw = "bold" if j == 0 else "normal"
draw_box(
ax,
cx,
cy,
col_widths[j],
row_h,
cell,
fill=fill,
fontsize=fontsize,
fontweight=fw,
rounded=False,
)
cx += col_widths[j]

View File

@ -0,0 +1,212 @@
"""Q9 diagrams 7-9: IPC mechanisms and scenario tables."""
from __future__ import annotations
from _q9_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
draw_arrow,
draw_box,
draw_double_arrow,
draw_table,
save_fig,
)
import matplotlib.pyplot as plt
# ============================================================
# 7. Scenario table (when to use process vs thread)
# ============================================================
def gen_scenario_table() -> None:
"""Gen scenario table."""
fig, ax = plt.subplots(figsize=(8.5, 4.5))
ax.set_xlim(0, 11)
ax.set_ylim(-5.5, 1)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Kiedy proces, kiedy wątek? — typowe scenariusze",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
headers = ["Scenariusz", "Wybór", "Dlaczego?"]
col_w = [3.5, 2.5, 4.5]
rows = [
["Serwer WWW (Apache)", "Proces", "izolacja klientów"],
["Serwer WWW (nginx)", "Wątek / async", "szybkość, cooperacja"],
["Przeglądarka (karty)", "Proces", "crash isolation"],
["Przeglądarka (JS+render)", "Wątek", "współdzielony DOM"],
["Gra (fizyka+rendering)", "Wątek", "współdzielony świat gry"],
["Kompilacja (make -j8)", "Proces", "izolacja, prostota"],
["Baza danych (zapytania)", "Wątek", "współdzielony cache"],
["Microservices", "Proces (kontener)", "izolacja, deployment"],
]
draw_table(
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7
)
save_fig(fig, "q9_scenario_table.png")
# ============================================================
# 8. IPC details: pipe, shared memory, socket (3-panel)
# ============================================================
def gen_ipc_details() -> None:
"""Gen ipc details."""
fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
fig.suptitle("Mechanizmy IPC — szczegóły", fontsize=FS_TITLE, fontweight="bold")
# Panel 1: Pipe
ax = axes[0]
ax.set_xlim(0, 8)
ax.set_ylim(0, 5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Pipe (potok)", fontsize=FS_LABEL, fontweight="bold")
draw_box(ax, 0.2, 2.0, 1.8, 1.2, "Proces A\n(ls)\nstdout", fill=GRAY1, fontsize=FS)
draw_box(
ax,
3.0,
2.0,
1.8,
1.2,
"Bufor\njądra\n(4 KB)",
fill=GRAY2,
fontsize=FS,
fontweight="bold",
)
draw_box(ax, 5.8, 2.0, 1.8, 1.2, "Proces B\n(grep)\nstdin", fill=GRAY1, fontsize=FS)
draw_arrow(ax, 2.0, 2.6, 3.0, 2.6, lw=1.5)
ax.text(2.5, 3.0, "write()\nfd[1]", fontsize=FS_SMALL, ha="center")
draw_arrow(ax, 4.8, 2.6, 5.8, 2.6, lw=1.5)
ax.text(5.3, 3.0, "read()\nfd[0]", fontsize=FS_SMALL, ha="center")
ax.text(
4.0,
0.8,
"Jednokierunkowy\nBufor pełny → write() blokuje",
fontsize=FS_SMALL,
ha="center",
style="italic",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
# Panel 2: Shared Memory
ax = axes[1]
ax.set_xlim(0, 8)
ax.set_ylim(0, 5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Shared Memory", fontsize=FS_LABEL, fontweight="bold")
draw_box(ax, 0.3, 3.0, 2.2, 1.2, "Proces A\nstrona 7", fill=GRAY1, fontsize=FS)
draw_box(ax, 5.5, 3.0, 2.2, 1.2, "Proces B\nstrona 3", fill=GRAY1, fontsize=FS)
draw_box(
ax,
2.8,
1.0,
2.4,
1.2,
"RAM\nramka 42",
fill=GRAY3,
fontsize=FS,
fontweight="bold",
)
draw_arrow(ax, 2.0, 3.0, 3.5, 2.2, lw=1.5)
draw_arrow(ax, 6.0, 3.0, 4.5, 2.2, lw=1.5)
ax.text(
4.0,
0.3,
"Zero kopiowania!\nA pisze → B widzi od razu\nWymaga synchronizacji (semafor)",
fontsize=FS_SMALL,
ha="center",
style="italic",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
# Panel 3: Socket
ax = axes[2]
ax.set_xlim(0, 8)
ax.set_ylim(0, 5)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Socket", fontsize=FS_LABEL, fontweight="bold")
# Network socket
draw_box(
ax, 0.3, 3.2, 1.8, 0.9, "Klient", fill=GRAY1, fontsize=FS, fontweight="bold"
)
draw_box(
ax, 5.5, 3.2, 1.8, 0.9, "Serwer", fill=GRAY1, fontsize=FS, fontweight="bold"
)
draw_double_arrow(ax, 2.1, 3.65, 5.5, 3.65, lw=1.5)
ax.text(3.8, 4.3, "TCP/IP (sieciowy)", fontsize=FS, ha="center", fontweight="bold")
# Unix socket
draw_box(
ax, 0.3, 1.3, 1.8, 0.9, "Proces A", fill=GRAY4, fontsize=FS, fontweight="bold"
)
draw_box(
ax, 5.5, 1.3, 1.8, 0.9, "Proces B", fill=GRAY4, fontsize=FS, fontweight="bold"
)
draw_double_arrow(ax, 2.1, 1.75, 5.5, 1.75, lw=1.5)
ax.text(
3.8,
2.4,
"Unix domain socket\n(/tmp/app.sock)",
fontsize=FS,
ha="center",
fontweight="bold",
)
ax.text(
3.8,
0.5,
"Dwukierunkowy\nNajbardziej uniwersalny IPC",
fontsize=FS_SMALL,
ha="center",
style="italic",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
fig.tight_layout(rect=[0, 0, 1, 0.9])
save_fig(fig, "q9_ipc_details.png")
# ============================================================
# 9. IPC comparison table
# ============================================================
def gen_ipc_table() -> None:
"""Gen ipc table."""
fig, ax = plt.subplots(figsize=(8.5, 3.5))
ax.set_xlim(0, 11)
ax.set_ylim(-4.5, 1)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Porównanie mechanizmów IPC", fontsize=FS_TITLE, fontweight="bold", pad=10
)
headers = ["Mechanizm", "Kierunek", "Szybkość", "Zastosowanie"]
col_w = [2.5, 2.0, 2.5, 3.5]
rows = [
["Pipe", "jednokierunkowy", "średnia", "ls | grep"],
["Named Pipe", "jednokierunkowy", "średnia", "demon → klient"],
["Shared Memory", "dwukierunkowy", "NAJSZYBSZA", "video, bazy danych"],
["Message Queue", "dwukierunkowy", "średnia", "wieloproducentowe"],
["Socket", "dwukierunkowy", "wolna (sieć)", "klient-serwer"],
["Signal", "jednokierunkowy", "natychmiastowa", "powiadomienia (nr)"],
]
draw_table(
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7.5
)
save_fig(fig, "q9_ipc_table.png")

View File

@ -0,0 +1,404 @@
"""Q9 diagrams 10-13: race conditions, deadlock, Coffman, starvation."""
from __future__ import annotations
from _q9_common import (
FS,
FS_LABEL,
FS_SMALL,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
draw_arrow,
draw_box,
draw_table,
save_fig,
)
import matplotlib.pyplot as plt
# ============================================================
# 10. Race condition (simple x + bank timeline)
# ============================================================
def gen_race_condition() -> None:
"""Gen race condition."""
fig, axes = plt.subplots(1, 2, figsize=(11, 5))
fig.suptitle(
"Wyścig (Race Condition) — przykłady", fontsize=FS_TITLE, fontweight="bold"
)
# Panel 1: simple x increment
ax = axes[0]
ax.set_xlim(0, 8)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Prosty wyścig: x = x + 1", fontsize=FS_LABEL, fontweight="bold")
# Timeline
steps_a = ["czytaj x (=0)", "dodaj 1", "zapisz x (=1)"]
steps_b = ["czytaj x (=0)", "dodaj 1", "zapisz x (=1)"]
ax.text(2.0, 6.3, "Wątek A", fontsize=FS_LABEL, ha="center", fontweight="bold")
ax.text(6.0, 6.3, "Wątek B", fontsize=FS_LABEL, ha="center", fontweight="bold")
ax.plot([2, 2], [0.8, 6.0], color=LN, lw=1)
ax.plot([6, 6], [0.8, 6.0], color=LN, lw=1)
for i, (sa, sb) in enumerate(zip(steps_a, steps_b, strict=False)):
y = 5.3 - i * 1.2
draw_box(ax, 0.5, y, 3.0, 0.6, sa, fill=GRAY4, fontsize=FS)
draw_box(ax, 4.5, y - 0.3, 3.0, 0.6, sb, fill=GRAY1, fontsize=FS)
ax.text(
4.0,
0.4,
"Wynik: x = 1 (powinno 2!)",
fontsize=FS,
ha="center",
fontweight="bold",
color="#C62828",
bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"},
)
# Panel 2: bank account
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 7)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Konto bankowe: saldo = 1000 zł", fontsize=FS_LABEL, fontweight="bold")
ax.text(2.5, 6.3, "Wątek A (+500)", fontsize=FS, ha="center", fontweight="bold")
ax.text(7.5, 6.3, "Wątek B (-200)", fontsize=FS, ha="center", fontweight="bold")
ax.plot([2.5, 2.5], [0.8, 6.0], color=LN, lw=1)
ax.plot([7.5, 7.5], [0.8, 6.0], color=LN, lw=1)
events = [
("t1", "czytaj → 1000", "", 5.3),
("t2", "", "czytaj → 1000", 4.6),
("t3", "1000+500=1500", "", 3.9),
("t4", "", "1000-200=800", 3.2),
("t5", "zapisz 1500", "", 2.5),
("t6", "", "zapisz 800 ✗", 1.8),
]
for t, a, b, y in events:
ax.text(0.3, y + 0.15, t, fontsize=FS_SMALL, fontweight="bold", va="center")
if a:
draw_box(ax, 1.0, y, 3.0, 0.45, a, fill=GRAY4, fontsize=FS_SMALL)
if b:
fill = "#F8D7DA" if "" in b else GRAY1
draw_box(ax, 6.0, y, 3.0, 0.45, b, fill=fill, fontsize=FS_SMALL)
ax.text(
5.0,
0.4,
"Wynik: 800 zł (powinno 1300!)",
fontsize=FS,
ha="center",
fontweight="bold",
color="#C62828",
bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"},
)
fig.tight_layout(rect=[0, 0, 1, 0.9])
save_fig(fig, "q9_race_condition.png")
# ============================================================
# 11. Deadlock scenario + cycle
# ============================================================
def gen_deadlock_scenario() -> None:
"""Gen deadlock scenario."""
fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
fig.suptitle("Zakleszczenie (Deadlock)", fontsize=FS_TITLE, fontweight="bold")
# Panel 1: timeline
ax = axes[0]
ax.set_xlim(0, 8)
ax.set_ylim(0, 6)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Scenariusz z 2 mutexami", fontsize=FS_LABEL, fontweight="bold")
ax.text(2.5, 5.3, "Wątek A", fontsize=FS_LABEL, ha="center", fontweight="bold")
ax.text(6.0, 5.3, "Wątek B", fontsize=FS_LABEL, ha="center", fontweight="bold")
steps = [
("lock(mutex1) OK", "", "trzyma", False, 4.5),
("", "lock(mutex2) OK", "trzyma", False, 3.7),
("lock(mutex2) ...WAIT", "", "CZEKA!", True, 2.9),
("", "lock(mutex1) ...WAIT", "CZEKA!", True, 2.1),
]
for a_text, b_text, _note, is_wait, y in steps:
if a_text:
fill = "#F8D7DA" if is_wait else GRAY4
draw_box(ax, 0.5, y, 3.3, 0.55, a_text, fill=fill, fontsize=FS_SMALL)
if b_text:
fill = "#F8D7DA" if is_wait else GRAY4
draw_box(ax, 4.3, y, 3.3, 0.55, b_text, fill=fill, fontsize=FS_SMALL)
ax.text(
4.0,
1.2,
"DEADLOCK!\nŻaden nie odpuści",
fontsize=FS,
ha="center",
fontweight="bold",
color="#C62828",
bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"},
)
# Panel 2: cycle diagram
ax = axes[1]
ax.set_xlim(0, 8)
ax.set_ylim(0, 6)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Cykl oczekiwania", fontsize=FS_LABEL, fontweight="bold")
# Thread boxes
draw_box(
ax,
0.5,
3.5,
2.2,
1.2,
"Wątek A\ntrzyma Mutex 1",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
)
draw_box(
ax,
5.3,
3.5,
2.2,
1.2,
"Wątek B\ntrzyma Mutex 2",
fill=GRAY1,
fontsize=FS,
fontweight="bold",
)
# Mutex boxes
draw_box(
ax, 0.5, 1.0, 2.2, 1.0, "Mutex 1", fill=GRAY3, fontsize=FS, fontweight="bold"
)
draw_box(
ax, 5.3, 1.0, 2.2, 1.0, "Mutex 2", fill=GRAY3, fontsize=FS, fontweight="bold"
)
# holds arrows (down)
draw_arrow(ax, 1.6, 3.5, 1.6, 2.0, lw=2)
ax.text(0.9, 2.7, "trzyma", fontsize=FS_SMALL, rotation=90, va="center")
draw_arrow(ax, 6.4, 3.5, 6.4, 2.0, lw=2)
ax.text(7.0, 2.7, "trzyma", fontsize=FS_SMALL, rotation=90, va="center")
# waits-for arrows (across, red)
draw_arrow(ax, 2.7, 4.3, 5.3, 4.3, lw=2.5, color="#C62828")
ax.text(
4.0,
4.7,
"czeka na Mutex 2",
fontsize=FS_SMALL,
ha="center",
fontweight="bold",
color="#C62828",
)
draw_arrow(ax, 5.3, 3.7, 2.7, 3.7, lw=2.5, color="#C62828")
ax.text(
4.0,
3.1,
"czeka na Mutex 1",
fontsize=FS_SMALL,
ha="center",
fontweight="bold",
color="#C62828",
)
fig.tight_layout(rect=[0, 0, 1, 0.9])
save_fig(fig, "q9_deadlock_scenario.png")
# ============================================================
# 12. Coffman conditions + prevention strategies
# ============================================================
def gen_coffman_strategies() -> None:
"""Gen coffman strategies."""
fig, ax = plt.subplots(figsize=(9, 4))
ax.set_xlim(0, 11.5)
ax.set_ylim(-3.5, 1)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title(
"Warunki Coffmana — zapobieganie deadlockowi",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
headers = ["Warunek", "Opis", "Jak złamać", "Przykład"]
col_w = [2.5, 2.5, 3.0, 3.0]
rows = [
[
"1. Mutual Exclusion",
"zasób wyłączny",
"współdzielony zasób",
"Read-write lock",
],
[
"2. Hold and Wait",
"trzymaj + czekaj",
"bierz WSZYSTKIE naraz",
"lock(m1,m2) atomowo",
],
[
"3. No Preemption",
"nie zabierzesz siłą",
"timeout / trylock",
"pthread_mutex_trylock()",
],
[
"4. Circular Wait",
"cykliczne oczekiw.",
"porządek liniowy",
"zawsze m1 przed m2",
],
]
draw_table(
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.6, fontsize=7
)
ax.text(
5.75,
-3.1,
"▸ Najczęstsza strategia: PORZĄDEK LINIOWY — "
"numeruj mutexy, zawsze blokuj rosnąco",
fontsize=FS,
ha="center",
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
)
save_fig(fig, "q9_coffman_strategies.png")
# ============================================================
# 13. Starvation + Priority Inversion (2-panel)
# ============================================================
def gen_starvation_priority() -> None:
"""Gen starvation priority."""
fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
fig.suptitle(
"Zagłodzenie i Inwersja priorytetów", fontsize=FS_TITLE, fontweight="bold"
)
# Panel 1: Starvation + aging
ax = axes[0]
ax.set_xlim(0, 8)
ax.set_ylim(0, 6)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Zagłodzenie (Starvation)", fontsize=FS_LABEL, fontweight="bold")
threads = [
("Wątek HIGH", "prio=10", GRAY5, 3.0),
("Wątek HIGH", "prio=9", GRAY3, 2.2),
("Wątek MED", "prio=5", GRAY2, 1.4),
("Wątek LOW", "prio=1 → głoduje!", "#F8D7DA", 0.6),
]
for name, prio, color, y in threads:
draw_box(
ax, 0.5, y, 2.0, 0.6, name, fill=color, fontsize=FS_SMALL, fontweight="bold"
)
ax.text(2.8, y + 0.3, prio, fontsize=FS_SMALL, va="center")
ax.text(
1.5,
4.2,
"CPU zawsze\ndostaje HIGH!",
fontsize=FS,
ha="center",
fontweight="bold",
)
draw_arrow(ax, 1.5, 3.9, 1.5, 3.65, lw=1.5)
# Aging solution
draw_box(ax, 4.5, 1.5, 3.2, 2.5, "", fill=GRAY4, rounded=True)
ax.text(6.1, 3.7, "Rozwiązanie: AGING", fontsize=FS, fontweight="bold", ha="center")
aging = [
"t=0: prio=1",
"t=100ms: prio=2",
"t=200ms: prio=3",
"...",
"w końcu → CPU!",
]
for i, line in enumerate(aging):
ax.text(
6.1, 3.2 - i * 0.4, line, fontsize=FS_SMALL, ha="center", family="monospace"
)
# Panel 2: Priority Inversion
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 6)
ax.set_aspect("auto")
ax.axis("off")
ax.set_title("Inwersja priorytetów", fontsize=FS_LABEL, fontweight="bold")
# Timeline
labels = ["H (wysoki)", "M (średni)", "L (niski)"]
ys = [4.2, 2.8, 1.4]
for label, y in zip(labels, ys, strict=False):
ax.text(0.3, y + 0.2, label, fontsize=FS, fontweight="bold", va="center")
# L runs and locks mutex
draw_box(ax, 2.0, ys[2], 1.2, 0.5, "lock(m)", fill=GRAY1, fontsize=FS_SMALL)
# M preempts L
draw_box(ax, 3.5, ys[1], 3.0, 0.5, "M pracuje...", fill=GRAY3, fontsize=FS_SMALL)
# H waits for mutex
draw_box(
ax,
3.5,
ys[0],
3.0,
0.5,
"CZEKA na mutex!",
fill="#F8D7DA",
fontsize=FS_SMALL,
fontweight="bold",
)
# M finishes, L continues, unlocks
draw_box(ax, 6.8, ys[2], 1.5, 0.5, "unlock(m)", fill=GRAY1, fontsize=FS_SMALL)
draw_box(ax, 8.5, ys[0], 1.2, 0.5, "H runs", fill=GRAY4, fontsize=FS_SMALL)
# Explanation
ax.text(
5.0,
0.5,
"H czeka na M (mimo H > M)!\n"
"Rozwiązanie: Priority Inheritance\n"
"L dziedziczy priorytet H → M nie wypycha L",
fontsize=FS_SMALL,
ha="center",
style="italic",
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
ax.text(
5.0,
0.0,
"Mars Pathfinder (1997) — klasyczny bug!",
fontsize=FS_SMALL,
ha="center",
fontweight="bold",
)
fig.tight_layout(rect=[0, 0, 1, 0.9])
save_fig(fig, "q9_starvation_priority.png")

View File

@ -0,0 +1,87 @@
"""Common constants and utilities for scheduling diagrams."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
if TYPE_CHECKING:
from matplotlib.axes import Axes
_logger = logging.getLogger(__name__)
DPI = 300
BG = "white"
LN = "black"
FS = 8
FS_TITLE = 11
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
GRAY1 = "#E8E8E8"
GRAY2 = "#D0D0D0"
GRAY3 = "#B8B8B8"
GRAY4 = "#F5F5F5"
GRAY5 = "#C0C0C0"
MIN_COLUMN_INDEX = 3
FONTWEIGHT_THRESHOLD = 3
def draw_box(
ax: Axes,
x: float,
y: float,
w: float,
h: float,
text: str,
fill: str = "white",
lw: float = 1.2,
fontsize: float = FS,
fontweight: str = "normal",
ha: str = "center",
va: str = "center",
*,
rounded: bool = True,
) -> None:
"""Draw box."""
if rounded:
rect = FancyBboxPatch(
(x, y), w, h, boxstyle="round,pad=0.05", lw=lw, edgecolor=LN, facecolor=fill
)
else:
rect = mpatches.Rectangle((x, y), w, h, lw=lw, edgecolor=LN, facecolor=fill)
ax.add_patch(rect)
ax.text(
x + w / 2,
y + h / 2,
text,
ha=ha,
va=va,
fontsize=fontsize,
fontweight=fontweight,
wrap=True,
)
def draw_arrow(
ax: Axes,
x1: float,
y1: float,
x2: float,
y2: float,
lw: float = 1.2,
style: str = "->",
color: str = LN,
) -> None:
"""Draw arrow."""
ax.annotate(
"",
xy=(x2, y2),
xytext=(x1, y1),
arrowprops={"arrowstyle": style, "color": color, "lw": lw},
)

View File

@ -0,0 +1,309 @@
"""Scheduling complexity landscape and EDD example diagrams."""
from __future__ import annotations
import logging
from pathlib import Path
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
import matplotlib.pyplot as plt
from python_pkg.praca_magisterska_video.generate_images._sched_common import (
BG,
DPI,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
OUTPUT_DIR,
draw_arrow,
)
_logger = logging.getLogger(__name__)
# ============================================================
# SCHEDULING COMPLEXITY LANDSCAPE
# ============================================================
def draw_complexity_map() -> None:
"""Draw complexity map."""
_fig, ax = plt.subplots(1, 1, figsize=(8.27, 5))
ax.set_xlim(0, 10)
ax.set_ylim(0, 7)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Złożoność problemów szeregowania — od łatwych do NP-trudnych",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Gradient arrow at the top
ax.annotate(
"",
xy=(9.5, 6.2),
xytext=(0.5, 6.2),
arrowprops={"arrowstyle": "->", "color": LN, "lw": 2},
)
ax.text(5, 6.5, "Rosnąca złożoność", ha="center", fontsize=9, fontweight="bold")
# Easy (polynomial) region
easy_rect = FancyBboxPatch(
(0.3, 2.8),
4.0,
3.0,
boxstyle="round,pad=0.15",
lw=1.5,
edgecolor="#666666",
facecolor=GRAY4,
linestyle="-",
)
ax.add_patch(easy_rect)
ax.text(
2.3,
5.5,
"WIELOMIANOWE O(n log n)",
ha="center",
fontsize=9,
fontweight="bold",
color="#444444",
)
easy_problems = [
("1 || ΣCⱼ", "SPT", GRAY1, 4.8),
("1 || Lmax", "EDD", GRAY2, 4.0),
("F2 || Cmax", "Johnson", GRAY1, 3.2),
]
for prob, method, fill, y in easy_problems:
rect = FancyBboxPatch(
(0.6, y),
3.5,
0.6,
boxstyle="round,pad=0.05",
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
1.2,
y + 0.3,
prob,
ha="center",
va="center",
fontsize=8,
fontweight="bold",
fontfamily="monospace",
)
ax.text(3.0, y + 0.3, f"{method}", ha="center", va="center", fontsize=8)
# Hard (NP) region
hard_rect = FancyBboxPatch(
(5.3, 2.8),
4.3,
3.0,
boxstyle="round,pad=0.15",
lw=1.5,
edgecolor="#444444",
facecolor=GRAY3,
linestyle="-",
)
ax.add_patch(hard_rect)
ax.text(
7.45,
5.5,
"NP-TRUDNE",
ha="center",
fontsize=9,
fontweight="bold",
color="#333333",
)
hard_problems = [
("Pm || Cmax\n(m≥2)", "LPT heuryst.", GRAY2, 4.5),
("1 || ΣTⱼ", "branch&bound", GRAY4, 3.7),
("Jm || Cmax\n(m≥3)", "metaheuryst.", GRAY5, 2.9),
]
for prob, method, fill, y in hard_problems:
rect = FancyBboxPatch(
(5.6, y),
3.7,
0.7,
boxstyle="round,pad=0.05",
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
6.5,
y + 0.35,
prob,
ha="center",
va="center",
fontsize=7,
fontweight="bold",
fontfamily="monospace",
)
ax.text(8.2, y + 0.35, f"{method}", ha="center", va="center", fontsize=7)
# Arrow connecting
draw_arrow(ax, 4.4, 4.0, 5.2, 4.0, lw=2, color="#888888")
ax.text(4.8, 4.25, "+1\nmaszyna", ha="center", fontsize=6, color="#888888")
# Bottom: key insight
ax.text(
5.0,
1.8,
"„Dodanie jednej maszyny lub jednego ograniczenia\n"
'może zmienić problem z łatwego na NP-trudny!"',
ha="center",
fontsize=8,
fontweight="bold",
style="italic",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": GRAY4,
"edgecolor": GRAY3,
"lw": 1,
},
)
# Bottom examples
ax.text(
5.0,
0.8,
"1 maszyna → łatwe (sortuj) | ≥2 maszyny równoległe → NP-trudne\n"
"Flow shop 2 maszyny → Johnson O(n log n) | Flow shop 3 maszyny → NP-trudne",
ha="center",
fontsize=7,
color="#555555",
)
plt.tight_layout()
plt.savefig(
str(Path(OUTPUT_DIR) / "scheduling_complexity_map.png"),
dpi=DPI,
bbox_inches="tight",
facecolor=BG,
)
plt.close()
_logger.info(" ✓ scheduling_complexity_map.png")
# ============================================================
# EDD EXAMPLE (1 || Lmax)
# ============================================================
def draw_edd_example() -> None:
"""Draw edd example."""
_fig, ax = plt.subplots(1, 1, figsize=(8.27, 4))
ax.set_xlim(-2, 28)
ax.set_ylim(-2, 4)
ax.axis("off")
ax.set_title(
"EDD (Earliest Due Date) — 1 || Lmax — Przykład",
fontsize=FS_TITLE,
fontweight="bold",
pad=8,
)
# Tasks: name, processing time, due date
tasks = [("J1", 4, 10), ("J2", 2, 6), ("J3", 6, 15), ("J4", 3, 8), ("J5", 5, 18)]
# EDD: sort by due date
edd_order = sorted(tasks, key=lambda x: x[2])
bar_y = 1.5
bar_h = 0.8
t = 0
fills_edd = [GRAY1, GRAY2, GRAY4, GRAY3, GRAY5]
lateness_vals = []
for i, (name, p, d) in enumerate(edd_order):
rect = mpatches.Rectangle(
(t, bar_y), p, bar_h, lw=1.2, edgecolor=LN, facecolor=fills_edd[i]
)
ax.add_patch(rect)
ax.text(
t + p / 2,
bar_y + bar_h / 2,
f"{name}\np={p}, d={d}",
ha="center",
va="center",
fontsize=6.5,
fontweight="bold",
)
t += p
lateness = t - d
lateness_vals.append(lateness)
# Due date marker
ax.plot(
[d, d], [bar_y - 0.4, bar_y - 0.1], color="#888888", lw=0.8, linestyle="--"
)
ax.text(
d,
bar_y - 0.5,
f"d={d}",
ha="center",
va="top",
fontsize=5.5,
color="#888888",
)
# Completion + lateness
ax.plot([t, t], [bar_y + bar_h, bar_y + bar_h + 0.15], color=LN, lw=0.8)
ax.text(
t,
bar_y + bar_h + 0.2,
f"C={t}\nL={lateness}",
ha="center",
va="bottom",
fontsize=5.5,
)
# Time axis
ax.plot([0, 22], [bar_y - 0.05, bar_y - 0.05], color=LN, lw=0.5)
lmax = max(lateness_vals)
ax.text(
22,
bar_y + bar_h / 2,
f"Lmax = {lmax}",
ha="left",
va="center",
fontsize=10,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY1, "edgecolor": LN},
)
# Bottom mnemonic
ax.text(
10,
-1.3,
'„Early Due Date Does it first" — najpilniejszy deadline idzie pierwszy',
ha="center",
fontsize=8,
fontweight="bold",
style="italic",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": GRAY4,
"edgecolor": GRAY3,
"lw": 0.8,
},
)
plt.tight_layout()
plt.savefig(
str(Path(OUTPUT_DIR) / "scheduling_edd_example.png"),
dpi=DPI,
bbox_inches="tight",
facecolor=BG,
)
plt.close()
_logger.info(" ✓ scheduling_edd_example.png")

View File

@ -0,0 +1,484 @@
"""Graham notation α|β|γ visual mnemonic map diagram."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from matplotlib.patches import FancyBboxPatch
import matplotlib.pyplot as plt
from python_pkg.praca_magisterska_video.generate_images._sched_common import (
BG,
DPI,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
OUTPUT_DIR,
)
if TYPE_CHECKING:
from matplotlib.axes import Axes
_logger = logging.getLogger(__name__)
def draw_graham_notation() -> None:
"""Draw graham notation."""
_fig, ax = plt.subplots(1, 1, figsize=(8.27, 10))
ax.set_xlim(0, 10)
ax.set_ylim(0, 14)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Notacja Grahama \u03b1 | β | \u03b3 — Mapa mnemoniczna",
fontsize=FS_TITLE + 1,
fontweight="bold",
pad=12,
)
_draw_graham_formula_bar(ax)
_draw_graham_alpha_beta(ax)
_draw_graham_lower(ax)
plt.tight_layout()
plt.savefig(
str(Path(OUTPUT_DIR) / "scheduling_graham_notation.png"),
dpi=DPI,
bbox_inches="tight",
facecolor=BG,
)
plt.close()
_logger.info(" ✓ scheduling_graham_notation.png")
def _draw_graham_formula_bar(ax: Axes) -> None:
"""Draw the top alpha|beta|gamma formula bar."""
bar_y = 12.5
bar_h = 1.0
# alpha box
rect = FancyBboxPatch(
(0.5, bar_y),
2.5,
bar_h,
boxstyle="round,pad=0.08",
lw=2,
edgecolor=LN,
facecolor=GRAY1,
)
ax.add_patch(rect)
ax.text(
1.75,
bar_y + bar_h / 2,
"\u03b1",
fontsize=20,
fontweight="bold",
ha="center",
va="center",
)
ax.text(
1.75,
bar_y - 0.25,
"MASZYNY",
fontsize=8,
fontweight="bold",
ha="center",
va="top",
color="#444444",
)
# separator |
ax.text(
3.3,
bar_y + bar_h / 2,
"|",
fontsize=24,
fontweight="bold",
ha="center",
va="center",
)
# β box
rect = FancyBboxPatch(
(3.7, bar_y),
2.5,
bar_h,
boxstyle="round,pad=0.08",
lw=2,
edgecolor=LN,
facecolor=GRAY2,
)
ax.add_patch(rect)
ax.text(
4.95,
bar_y + bar_h / 2,
"β",
fontsize=20,
fontweight="bold",
ha="center",
va="center",
)
ax.text(
4.95,
bar_y - 0.25,
"OGRANICZENIA",
fontsize=8,
fontweight="bold",
ha="center",
va="top",
color="#444444",
)
# separator |
ax.text(
6.5,
bar_y + bar_h / 2,
"|",
fontsize=24,
fontweight="bold",
ha="center",
va="center",
)
# gamma box
rect = FancyBboxPatch(
(6.9, bar_y),
2.5,
bar_h,
boxstyle="round,pad=0.08",
lw=2,
edgecolor=LN,
facecolor=GRAY3,
)
ax.add_patch(rect)
ax.text(
8.15,
bar_y + bar_h / 2,
"\u03b3",
fontsize=20,
fontweight="bold",
ha="center",
va="center",
)
ax.text(
8.15,
bar_y - 0.25,
"CEL",
fontsize=8,
fontweight="bold",
ha="center",
va="top",
color="#444444",
)
def _draw_graham_alpha_beta(ax: Axes) -> None:
"""Draw alpha (machines) and beta (constraints) sections."""
start_x = 0.3
col_w = 1.28
# === SECTION alpha: MACHINES ===
sec_y = 11.5
ax.text(
0.3,
sec_y,
'\u03b1 — „1 Prawdziwy Quasi-Rycerz Forsuje Jaskinię Orków"',
fontsize=8,
fontweight="bold",
va="top",
style="italic",
color="#333333",
)
alpha_items = [
("1", "jedna maszyna", "", GRAY4),
("Pm", "identyczne Parallel", "●●●", GRAY1),
("Qm", "Quasi-uniform\n(różne prędkości)", "●●◐", GRAY4),
("Rm", "Random unrelated\n(czasy per para)", "●◆▲", GRAY1),
("Fm", "Flow shop\n(ta sama kolejność)", "→→→", GRAY2),
("Jm", "Job shop\n(indyw. trasy)", "↗↙↘", GRAY4),
("Om", "Open shop\n(dowolna kolej.)", "?→?", GRAY1),
]
col_w = 1.28
box_h_a = 1.1
start_x = 0.3
start_y = 9.6
for i, (symbol, desc, icon, fill) in enumerate(alpha_items):
x = start_x + i * col_w
y = start_y
rect = FancyBboxPatch(
(x, y),
col_w - 0.1,
box_h_a,
boxstyle="round,pad=0.04",
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
x + (col_w - 0.1) / 2,
y + box_h_a - 0.15,
symbol,
ha="center",
va="top",
fontsize=9,
fontweight="bold",
)
ax.text(
x + (col_w - 0.1) / 2,
y + box_h_a / 2 - 0.1,
desc,
ha="center",
va="center",
fontsize=5.5,
)
ax.text(
x + (col_w - 0.1) / 2, y + 0.12, icon, ha="center", va="bottom", fontsize=7
)
# Complexity arrow under alpha
arr_y = 9.35
ax.annotate(
"",
xy=(9.0, arr_y),
xytext=(0.5, arr_y),
arrowprops={"arrowstyle": "->", "color": "#666666", "lw": 1.5},
)
ax.text(
4.8,
arr_y - 0.18,
"rosnąca złożoność →",
ha="center",
fontsize=6,
color="#666666",
)
# === SECTION β: CONSTRAINTS ===
sec_y2 = 8.9
ax.text(
0.3,
sec_y2,
"β — „Robak Daje Deadline: Przerwy Poprzedzają Pojedyncze Setup'y\"",
fontsize=8,
fontweight="bold",
va="top",
style="italic",
color="#333333",
)
beta_items = [
("rⱼ", "release\ndates", "Robak\ndostępne\nod czasu rⱼ", GRAY1),
("dⱼ", "due\ndates", "Daje\ntermin soft\n(kara za spóźn.)", GRAY4),
("d̄ⱼ", "dead-\nlines", "Deadline\ntermin hard\n(musi dotrzymać)", GRAY1),
("pmtn", "preemp-\ntion", "Przerwy\nmożna\nprzerwać", GRAY2),
("prec", "prece-\ndencje", "Poprzedzają\nA->B (DAG)", GRAY4),
("pⱼ=1", "unit\ntime", "Pojedyncze\nwszystkie = 1", GRAY1),
("sⱼₖ", "setup\ntimes", "Setup'y\nprzezbrojenie\nmiędzy j->k", GRAY4),
]
start_y2 = 7.0
box_h_b = 1.4
for i, (symbol, _label, desc, fill) in enumerate(beta_items):
x = start_x + i * col_w
y = start_y2
rect = FancyBboxPatch(
(x, y),
col_w - 0.1,
box_h_b,
boxstyle="round,pad=0.04",
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
x + (col_w - 0.1) / 2,
y + box_h_b - 0.12,
symbol,
ha="center",
va="top",
fontsize=9,
fontweight="bold",
)
ax.text(
x + (col_w - 0.1) / 2,
y + box_h_b / 2 - 0.05,
desc,
ha="center",
va="center",
fontsize=5,
)
def _draw_graham_lower(ax: Axes) -> None:
"""Draw gamma criteria, examples, and footer sections."""
start_x = 0.3
# === SECTION gamma: CRITERIA ===
sec_y3 = 6.5
ax.text(
0.3,
sec_y3,
'\u03b3 — „Ciężki Sum Ważony Lata, Tardiness Uderza"',
fontsize=8,
fontweight="bold",
va="top",
style="italic",
color="#333333",
)
gamma_items = [
("Cmax", "makespan\nmax(Cⱼ)", "Jak długo\ntrwa WSZYSTKO?", GRAY2),
("ΣCⱼ", "suma\nukończeń", "Średni czas\noczekiwania?", GRAY4),
("ΣwⱼCⱼ", "ważona\nsuma", "Priorytety\nzadań?", GRAY1),
("Lmax", "max\nopóźnienie", "Najgorsze\nspóźnienie?", GRAY2),
("ΣTⱼ", "suma\nspóźnień", "Łączne\nspóźnienia?", GRAY4),
("ΣUⱼ", "liczba\nspóźnionych", "Ile spóźnionych\nzadań?", GRAY1),
]
start_y3 = 4.5
box_h_g = 1.4
col_w_g = 1.5
for i, (symbol, label, question, fill) in enumerate(gamma_items):
x = start_x + i * col_w_g
y = start_y3
rect = FancyBboxPatch(
(x, y),
col_w_g - 0.1,
box_h_g,
boxstyle="round,pad=0.04",
lw=1,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
x + (col_w_g - 0.1) / 2,
y + box_h_g - 0.1,
symbol,
ha="center",
va="top",
fontsize=9,
fontweight="bold",
)
ax.text(
x + (col_w_g - 0.1) / 2,
y + box_h_g / 2 - 0.05,
label,
ha="center",
va="center",
fontsize=6,
)
ax.text(
x + (col_w_g - 0.1) / 2,
y + 0.15,
f'{question}"',
ha="center",
va="bottom",
fontsize=5,
style="italic",
)
# === BOTTOM: Example + Optimal methods ===
ex_y = 3.5
ax.text(
0.3,
ex_y,
"Przykłady zapisu i optymalne metody:",
fontsize=8,
fontweight="bold",
va="top",
)
examples = [
("1 || ΣCⱼ", "SPT (najkrótsze\nnajpierw)", "O(n log n)", GRAY1),
("1 || Lmax", "EDD (najwcześniejszy\ntermin)", "O(n log n)", GRAY4),
("F2 || Cmax", "Algorytm\nJohnsona", "O(n log n)", GRAY2),
("Pm || Cmax", "LPT heurystyka\n(NP-trudny!)", "NP-hard", GRAY3),
("Jm || Cmax", "Branch & Bound\n(NP-trudny!)", "NP-hard", GRAY5),
]
ex_start_y = 1.8
ex_box_w = 1.72
ex_box_h = 1.4
for i, (notation, method, complexity, fill) in enumerate(examples):
x = start_x + i * (ex_box_w + 0.1)
y = ex_start_y
rect = FancyBboxPatch(
(x, y),
ex_box_w,
ex_box_h,
boxstyle="round,pad=0.04",
lw=1.2,
edgecolor=LN,
facecolor=fill,
)
ax.add_patch(rect)
ax.text(
x + ex_box_w / 2,
y + ex_box_h - 0.12,
notation,
ha="center",
va="top",
fontsize=8,
fontweight="bold",
fontfamily="monospace",
)
ax.text(
x + ex_box_w / 2,
y + ex_box_h / 2 - 0.05,
method,
ha="center",
va="center",
fontsize=6,
)
ax.text(
x + ex_box_w / 2,
y + 0.12,
complexity,
ha="center",
va="bottom",
fontsize=6.5,
fontweight="bold",
color="#555555",
)
# Footer mnemonic summary
ax.text(
5.0,
0.8,
'\u03b1|β|\u03b3 = Maszyny | Ograniczenia | Cel"',
ha="center",
fontsize=9,
fontweight="bold",
style="italic",
color="#333333",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": GRAY4,
"edgecolor": GRAY3,
"lw": 1,
},
)
ax.text(
5.0,
0.2,
"\u03b1: ILE maszyn i JAKIE? "
"β: JAKIE ograniczenia zadań? "
"\u03b3: CO minimalizujemy?",
ha="center",
fontsize=7,
color="#555555",
)

View File

@ -0,0 +1,318 @@
"""Johnson's algorithm Gantt chart diagram (F2||Cmax)."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from python_pkg.praca_magisterska_video.generate_images._sched_common import (
BG,
DPI,
FONTWEIGHT_THRESHOLD,
FS_TITLE,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
MIN_COLUMN_INDEX,
OUTPUT_DIR,
)
if TYPE_CHECKING:
from matplotlib.axes import Axes
_logger = logging.getLogger(__name__)
def draw_johnson_gantt() -> None:
"""Draw johnson gantt."""
_fig, axes = plt.subplots(
2, 1, figsize=(8.27, 7), gridspec_kw={"height_ratios": [1, 1.8]}
)
_draw_johnson_decision_table(axes[0])
_draw_johnson_gantt_chart(axes[1])
plt.tight_layout()
plt.savefig(
str(Path(OUTPUT_DIR) / "scheduling_johnson_gantt.png"),
dpi=DPI,
bbox_inches="tight",
facecolor=BG,
)
plt.close()
_logger.info(" ✓ scheduling_johnson_gantt.png")
def _draw_johnson_decision_table(ax: Axes) -> None:
"""Draw the Johnson algorithm decision table."""
ax.set_xlim(0, 10)
ax.set_ylim(0, 5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title(
"Algorytm Johnsona (F2 || Cmax) — Decyzja + Diagram Gantta",
fontsize=FS_TITLE,
fontweight="bold",
pad=10,
)
# Task table
tasks = ["J1", "J2", "J3", "J4", "J5"]
a_times = [4, 2, 6, 1, 3]
b_times = [5, 3, 2, 7, 4]
min_vals = [min(a, b) for a, b in zip(a_times, b_times, strict=False)]
min_on = ["M1" if a <= b else "M2" for a, b in zip(a_times, b_times, strict=False)]
assign = ["POCZątek" if m == "M1" else "KONIEC" for m in min_on]
# Draw table
col_w_t = 1.3
row_h = 0.55
headers = ["Zadanie", "aⱼ (M1)", "bⱼ (M2)", "min", "min na", "Przydziel"]
table_x = 0.8
table_y = 3.8
for j, hdr in enumerate(headers):
x = table_x + j * col_w_t
rect = mpatches.Rectangle(
(x, table_y), col_w_t, row_h, lw=1, edgecolor=LN, facecolor=GRAY2
)
ax.add_patch(rect)
ax.text(
x + col_w_t / 2,
table_y + row_h / 2,
hdr,
ha="center",
va="center",
fontsize=6.5,
fontweight="bold",
)
for i in range(5):
row_data = [
tasks[i],
str(a_times[i]),
str(b_times[i]),
str(min_vals[i]),
min_on[i],
assign[i],
]
for j, val in enumerate(row_data):
x = table_x + j * col_w_t
y = table_y - (i + 1) * row_h
fill_c = GRAY1 if min_on[i] == "M1" else GRAY4
if j == MIN_COLUMN_INDEX: # min column - highlight
fill_c = GRAY3
rect = mpatches.Rectangle(
(x, y), col_w_t, row_h, lw=0.8, edgecolor=LN, facecolor=fill_c
)
ax.add_patch(rect)
fw = "bold" if j >= FONTWEIGHT_THRESHOLD else "normal"
ax.text(
x + col_w_t / 2,
y + row_h / 2,
val,
ha="center",
va="center",
fontsize=6.5,
fontweight=fw,
)
# Sorting result
result_y = 0.7
ax.text(
5.0,
result_y + 0.4,
"Sortuj → POCZĄTEK ↑aⱼ: J4(1), J2(2), J5(3), J1(4) | KONIEC ↓bⱼ: J3(2)",
ha="center",
fontsize=7,
color="#333333",
)
ax.text(
5.0,
result_y,
"Optymalna kolejność: J4 → J2 → J5 → J1 → J3",
ha="center",
fontsize=9,
fontweight="bold",
bbox={
"boxstyle": "round,pad=0.2",
"facecolor": GRAY1,
"edgecolor": LN,
"lw": 1.2,
},
)
def _draw_johnson_gantt_chart(ax2: Axes) -> None:
"""Draw the Johnson algorithm Gantt chart."""
ax2.set_xlim(-1, 24)
ax2.set_ylim(-1, 4)
ax2.axis("off")
# Machines labels
m1_y = 2.5
m2_y = 0.8
bar_h = 0.9
ax2.text(
-0.8,
m1_y + bar_h / 2,
"M1",
ha="center",
va="center",
fontsize=11,
fontweight="bold",
)
ax2.text(
-0.8,
m2_y + bar_h / 2,
"M2",
ha="center",
va="center",
fontsize=11,
fontweight="bold",
)
# Schedule: J4 → J2 → J5 → J1 → J3
order = ["J4", "J2", "J5", "J1", "J3"]
a_ord = [1, 2, 3, 4, 6] # M1 times in order
b_ord = [7, 3, 4, 5, 2] # M2 times in order
fills = [GRAY1, GRAY2, GRAY4, GRAY3, GRAY5]
hatches = ["", "///", "", "\\\\\\", "xxx"]
# M1 schedule
m1_starts = []
t = 0
for a in a_ord:
m1_starts.append(t)
t += a
m1_ends = [s + a for s, a in zip(m1_starts, a_ord, strict=False)]
# M2 schedule (must wait for M1 finish AND previous M2 finish)
m2_starts = []
m2_ends = []
prev_m2_end = 0
for i, b in enumerate(b_ord):
start = max(m1_ends[i], prev_m2_end)
m2_starts.append(start)
m2_ends.append(start + b)
prev_m2_end = start + b
# Draw M1 bars
for i in range(5):
rect = mpatches.Rectangle(
(m1_starts[i], m1_y),
a_ord[i],
bar_h,
lw=1.2,
edgecolor=LN,
facecolor=fills[i],
hatch=hatches[i],
)
ax2.add_patch(rect)
ax2.text(
m1_starts[i] + a_ord[i] / 2,
m1_y + bar_h / 2,
f"{order[i]}\n({a_ord[i]})",
ha="center",
va="center",
fontsize=7,
fontweight="bold",
)
# Draw M2 bars
for i in range(5):
rect = mpatches.Rectangle(
(m2_starts[i], m2_y),
b_ord[i],
bar_h,
lw=1.2,
edgecolor=LN,
facecolor=fills[i],
hatch=hatches[i],
)
ax2.add_patch(rect)
ax2.text(
m2_starts[i] + b_ord[i] / 2,
m2_y + bar_h / 2,
f"{order[i]}\n({b_ord[i]})",
ha="center",
va="center",
fontsize=7,
fontweight="bold",
)
# Draw idle regions on M2
idle_starts = [0]
idle_ends = [m2_starts[0]]
for i in range(1, 5):
if m2_starts[i] > m2_ends[i - 1]:
idle_starts.append(m2_ends[i - 1])
idle_ends.append(m2_starts[i])
for s, e in zip(idle_starts, idle_ends, strict=False):
if e > s:
rect = mpatches.Rectangle(
(s, m2_y),
e - s,
bar_h,
lw=0.5,
edgecolor="#AAAAAA",
facecolor="white",
linestyle="--",
)
ax2.add_patch(rect)
ax2.text(
s + (e - s) / 2,
m2_y + bar_h / 2,
"idle",
ha="center",
va="center",
fontsize=5,
color="#999999",
)
# Time axis
ax_y = m2_y - 0.15
ax2.plot([0, 23], [ax_y, ax_y], color=LN, lw=0.8)
for t in range(0, 24, 2):
ax2.plot([t, t], [ax_y - 0.08, ax_y + 0.08], color=LN, lw=0.8)
ax2.text(t, ax_y - 0.25, str(t), ha="center", va="top", fontsize=6)
ax2.text(11.5, ax_y - 0.55, "czas", ha="center", fontsize=7)
# Cmax annotation
ax2.annotate(
f"Cmax = {m2_ends[-1]}",
xy=(m2_ends[-1], m2_y + bar_h),
xytext=(m2_ends[-1] + 0.5, m2_y + bar_h + 0.6),
fontsize=10,
fontweight="bold",
color="#333333",
arrowprops={"arrowstyle": "->", "color": "#333333", "lw": 1.5},
)
# Mnemonic at bottom
ax2.text(
11,
-0.7,
"„Krótki na M1 → START (szybko karmi M2)"
" Krótki na M2 → KONIEC"
' (szybko kończy)"',
ha="center",
fontsize=7.5,
fontweight="bold",
style="italic",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": GRAY4,
"edgecolor": GRAY3,
"lw": 0.8,
},
)

View File

@ -0,0 +1,352 @@
"""SPT vs LPT comparison and Flow Shop vs Job Shop diagrams."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from python_pkg.praca_magisterska_video.generate_images._sched_common import (
BG,
DPI,
GRAY1,
GRAY2,
GRAY3,
GRAY4,
GRAY5,
LN,
OUTPUT_DIR,
draw_arrow,
)
if TYPE_CHECKING:
from matplotlib.axes import Axes
_logger = logging.getLogger(__name__)
# ============================================================
# SPT vs LPT COMPARISON (1 || ΣCⱼ)
# ============================================================
def draw_spt_comparison() -> None:
"""Draw spt comparison."""
fig, axes = plt.subplots(2, 1, figsize=(8.27, 5.5))
tasks_orig = [("J1", 5), ("J2", 3), ("J3", 8), ("J4", 2), ("J5", 6)]
spt_order = sorted(tasks_orig, key=lambda x: x[1])
lpt_order = sorted(tasks_orig, key=lambda x: -x[1])
fills_map = {"J1": GRAY1, "J2": GRAY2, "J3": GRAY3, "J4": GRAY4, "J5": GRAY5}
hatch_map = {"J1": "", "J2": "///", "J3": "xxx", "J4": "", "J5": "\\\\\\"}
for _idx, (ax, order_list, title, is_optimal) in enumerate(
[
(axes[0], spt_order, "SPT (Shortest Processing Time) — OPTYMALNE", True),
(axes[1], lpt_order, "LPT (Longest Processing Time) — gorsze!", False),
]
):
ax.set_xlim(-2, 26)
ax.set_ylim(-0.5, 2.5)
ax.axis("off")
color = "#222222" if is_optimal else "#666666"
marker = "" if is_optimal else ""
ax.set_title(
f"{marker} {title}",
fontsize=9,
fontweight="bold",
loc="left",
color=color,
pad=5,
)
bar_y = 1.0
bar_h = 0.8
t = 0
completions = []
for name, duration in order_list:
rect = mpatches.Rectangle(
(t, bar_y),
duration,
bar_h,
lw=1.2,
edgecolor=LN,
facecolor=fills_map[name],
hatch=hatch_map[name],
)
ax.add_patch(rect)
ax.text(
t + duration / 2,
bar_y + bar_h / 2,
f"{name}\n({duration})",
ha="center",
va="center",
fontsize=7,
fontweight="bold",
)
t += duration
completions.append(t)
# Completion time marker
ax.plot([t, t], [bar_y - 0.15, bar_y], color=LN, lw=0.8)
ax.text(
t,
bar_y - 0.25,
f"C={t}",
ha="center",
va="top",
fontsize=6,
color="#555555",
)
total = sum(completions)
# Time axis
ax.plot([0, 25], [bar_y - 0.05, bar_y - 0.05], color=LN, lw=0.5)
# Sum annotation
comp_str = " + ".join(str(c) for c in completions)
ax.text(
25,
bar_y + bar_h / 2,
f"ΣCⱼ = {comp_str}\n = {total}",
ha="left",
va="center",
fontsize=7,
fontweight="bold" if is_optimal else "normal",
color=color,
bbox={
"boxstyle": "round,pad=0.2",
"facecolor": GRAY1 if is_optimal else "white",
"edgecolor": color,
"lw": 1,
},
)
# Bottom annotation
fig.text(
0.5,
0.02,
'„Short People To the front"'
" — krótkie najpierw,"
" jak niskie osoby w zdjęciu klasowym",
ha="center",
fontsize=8,
fontweight="bold",
style="italic",
color="#444444",
)
plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.savefig(
str(Path(OUTPUT_DIR) / "scheduling_spt_comparison.png"),
dpi=DPI,
bbox_inches="tight",
facecolor=BG,
)
plt.close()
_logger.info(" ✓ scheduling_spt_comparison.png")
# ============================================================
# FLOW SHOP vs JOB SHOP
# ============================================================
def draw_flow_vs_job() -> None:
"""Draw flow vs job."""
_fig, axes = plt.subplots(1, 2, figsize=(8.27, 4.5))
_draw_flow_shop(axes[0])
_draw_job_shop(axes[1])
plt.tight_layout()
plt.savefig(
str(Path(OUTPUT_DIR) / "scheduling_flow_vs_job.png"),
dpi=DPI,
bbox_inches="tight",
facecolor=BG,
)
plt.close()
_logger.info(" ✓ scheduling_flow_vs_job.png")
def _draw_flow_shop(ax: Axes) -> None:
"""Draw the Flow Shop diagram."""
ax.set_xlim(0, 6)
ax.set_ylim(0, 6)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Flow Shop (Fm)", fontsize=10, fontweight="bold", pad=8)
# Machines in a row
machines_x = [1, 3, 5]
machines_y = 3
mach_r = 0.4
for i, mx in enumerate(machines_x):
circle = plt.Circle(
(mx, machines_y), mach_r, facecolor=GRAY2, edgecolor=LN, lw=1.5
)
ax.add_patch(circle)
ax.text(
mx,
machines_y,
f"M{i + 1}",
ha="center",
va="center",
fontsize=9,
fontweight="bold",
)
# Arrows between machines
for i in range(len(machines_x) - 1):
draw_arrow(
ax,
machines_x[i] + mach_r + 0.05,
machines_y,
machines_x[i + 1] - mach_r - 0.05,
machines_y,
lw=2,
)
# Jobs all flowing the same way
jobs_flow = ["J1", "J2", "J3"]
for _j, (job, y_off) in enumerate(zip(jobs_flow, [0.8, 0, -0.8], strict=False)):
ax.text(
0.2,
machines_y + y_off,
job,
ha="center",
va="center",
fontsize=7,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.15", "facecolor": GRAY1, "edgecolor": LN},
)
# Dashed flow line
ax.annotate(
"",
xy=(5.5, machines_y + y_off * 0.3),
xytext=(0.5, machines_y + y_off),
arrowprops={
"arrowstyle": "->",
"color": "#888888",
"lw": 0.8,
"linestyle": "dashed",
},
)
ax.text(
3,
1.2,
"Wszystkie zadania:\nM1 → M2 → M3",
ha="center",
va="center",
fontsize=8,
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
ax.text(
3,
0.4,
"Jak taśma montażowa",
ha="center",
fontsize=7,
style="italic",
color="#666666",
)
def _draw_job_shop(ax: Axes) -> None:
"""Draw the Job Shop diagram."""
mach_r = 0.4
ax.set_xlim(0, 6)
ax.set_ylim(0, 6)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Job Shop (Jm)", fontsize=10, fontweight="bold", pad=8)
# Machines scattered
m_positions = [(1.5, 4.2), (4.5, 4.2), (3, 2.5)]
for i, (mx, my) in enumerate(m_positions):
circle = plt.Circle((mx, my), mach_r, facecolor=GRAY2, edgecolor=LN, lw=1.5)
ax.add_patch(circle)
ax.text(
mx, my, f"M{i + 1}", ha="center", va="center", fontsize=9, fontweight="bold"
)
# J1: M1 → M2 → M3 (solid)
route1 = [(1.5, 4.2), (4.5, 4.2), (3, 2.5)]
for i in range(len(route1) - 1):
x1, y1 = route1[i]
x2, y2 = route1[i + 1]
dx = x2 - x1
dy = y2 - y1
d = (dx**2 + dy**2) ** 0.5
draw_arrow(
ax,
x1 + mach_r * dx / d + 0.05,
y1 + mach_r * dy / d,
x2 - mach_r * dx / d - 0.05,
y2 - mach_r * dy / d,
lw=1.5,
)
ax.text(
0.3,
4.8,
"J1: M1→M2→M3",
fontsize=7,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY1, "edgecolor": LN},
)
# J2: M2 → M3 → M1 (dashed)
route2 = [(4.5, 4.2), (3, 2.5), (1.5, 4.2)]
for i in range(len(route2) - 1):
x1, y1 = route2[i]
x2, y2 = route2[i + 1]
dx = x2 - x1
dy = y2 - y1
d = (dx**2 + dy**2) ** 0.5
off = 0.15 # offset to avoid overlap
ax.annotate(
"",
xy=(x2 - mach_r * dx / d - 0.05, y2 - mach_r * dy / d + off),
xytext=(x1 + mach_r * dx / d + 0.05, y1 + mach_r * dy / d + off),
arrowprops={
"arrowstyle": "->",
"color": "#555555",
"lw": 1.5,
"linestyle": "dashed",
},
)
ax.text(
3.8,
5.2,
"J2: M2→M3→M1",
fontsize=7,
fontweight="bold",
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY4, "edgecolor": LN},
)
ax.text(
3,
1.2,
"Każde zadanie:\nwłasna trasa!",
ha="center",
va="center",
fontsize=8,
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
)
ax.text(
3,
0.4,
"NP-trudny już dla 3 maszyn",
ha="center",
fontsize=7,
style="italic",
color="#666666",
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,36 @@
"""Constants for the screen locker module."""
from __future__ import annotations
from pathlib import Path
# Validation limits for workout data
MAX_DISTANCE_KM = 100
MAX_TIME_MINUTES = 600
MAX_PACE_MIN_PER_KM = 20
MIN_EXERCISE_NAME_LEN = 3
MAX_SETS = 20
MAX_REPS = 100
MAX_WEIGHT_KG = 500
SICK_LOCKOUT_SECONDS = 120 # 2 minutes wait when sick
SUBMIT_DELAY_DEMO = 30
SUBMIT_DELAY_PRODUCTION = 180
PHONE_PENALTY_DELAY_DEMO = 10
PHONE_PENALTY_DELAY_PRODUCTION = 600
ADB_TIMEOUT = 15
STRONGLIFTS_DB_REMOTE = (
"/data/data/com.stronglifts.app/databases/StrongLifts-Database-3"
)
SHUTDOWN_CONFIG_FILE = Path("/etc/shutdown-schedule.conf")
# Helper script path (relative to this file)
ADJUST_SHUTDOWN_SCRIPT = Path(__file__).resolve().parent / "adjust_shutdown_schedule.sh"
# State file to track sick day usage and original config values
SICK_DAY_STATE_FILE = Path(__file__).resolve().parent / "sick_day_state.json"
STRENGTH_FIELDS: list[tuple[str, int]] = [
("Exercises (comma-separated):", 50),
("Sets per exercise (comma-separated):", 20),
("Reps (comma-sep, + for variable: 12+11+12):", 30),
("Weight per exercise kg (comma-separated):", 20),
("Total weight lifted (kg):", 15),
]

View File

@ -0,0 +1,203 @@
"""Phone workout verification mixin using ADB and StrongLifts."""
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextlib
import logging
from pathlib import Path
import shutil
import socket
import sqlite3
import subprocess
import tempfile
from python_pkg.screen_locker._constants import ADB_TIMEOUT, STRONGLIFTS_DB_REMOTE
_logger = logging.getLogger(__name__)
class PhoneVerificationMixin:
"""Mixin providing phone-based workout verification via ADB."""
def _run_adb(self, args: list[str]) -> tuple[bool, str]:
"""Run an ADB command and return success flag and stdout."""
adb = shutil.which("adb") or "adb"
# When multiple devices are connected (e.g. USB + wireless), pin to
# the wireless device's serial to avoid "more than one device" errors.
_discovery_cmds = {"devices", "connect", "disconnect", "kill-server"}
serial = (
self._get_wireless_serial()
if args and args[0] not in _discovery_cmds
else None
)
serial_args = ["-s", serial] if serial else []
try:
result = subprocess.run(
[adb, *serial_args, *args],
capture_output=True,
text=True,
timeout=ADB_TIMEOUT,
check=False,
)
except (FileNotFoundError, OSError) as exc:
_logger.warning("ADB not available: %s", exc)
return False, ""
except subprocess.TimeoutExpired:
_logger.warning("ADB command timed out: %s", args)
return False, ""
return result.returncode == 0, result.stdout
def _adb_shell(
self,
command: str,
*,
root: bool = False,
) -> tuple[bool, str]:
"""Run a shell command on the connected Android device."""
if root:
return self._run_adb(["shell", "su", "-c", command])
return self._run_adb(["shell", command])
def _get_wireless_serial(self) -> str | None:
"""Return the serial (ip:port) of the first connected wireless ADB device.
Used to pin ADB commands to the wireless device when multiple devices
(e.g. USB cable + wireless debugging) are simultaneously connected.
"""
success, output = self._run_adb(["devices"])
if not success:
return None
for line in output.strip().split("\n")[1:]:
parts = line.split()
if parts and ":" in parts[0] and "device" in line and "offline" not in line:
return parts[0]
return None
def _has_adb_device(self) -> bool:
"""Return True if adb devices shows at least one connected device."""
success, output = self._run_adb(["devices"])
if not success:
return False
lines = output.strip().split("\n")[1:]
return any("device" in line and "offline" not in line for line in lines)
def _try_adb_connect(self, address: str) -> bool:
"""Run adb connect to address. Returns True on success."""
_, output = self._run_adb(["connect", address])
lower = output.lower()
return "connected" in lower and "unable" not in lower and "failed" not in lower
def _get_local_subnet_prefix(self) -> str | None:
"""Detect the local /24 network prefix (e.g. '192.168.1')."""
with (
contextlib.suppress(OSError),
socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock,
):
sock.connect(("8.8.8.8", 80))
return ".".join(sock.getsockname()[0].split(".")[:3])
return None
def _try_wireless_reconnect(self) -> bool:
"""Scan local /24 subnet on port 5555 and attempt ADB connect to phone."""
prefix = self._get_local_subnet_prefix()
if prefix is None:
_logger.info("Could not determine local subnet for wireless scan")
return False
def probe(i: int) -> bool:
ip = f"{prefix}.{i}"
with (
contextlib.suppress(OSError),
socket.create_connection((ip, 5555), timeout=0.5),
):
if self._try_adb_connect(f"{ip}:5555"):
return self._has_adb_device()
return False
_logger.info("Scanning %s.1-254:5555 for phone...", prefix)
with ThreadPoolExecutor(max_workers=64) as executor:
for future in as_completed(
executor.submit(probe, i) for i in range(1, 255)
):
if future.result():
return True
return False
def _is_phone_connected(self) -> bool:
"""Check if an Android device is connected via ADB.
If no device is visible, attempts wireless reconnection using the
stored phone IP/port config. USB-connected devices are detected
automatically by adb devices without any extra steps.
"""
if self._has_adb_device():
return True
_logger.info("No ADB device detected — attempting wireless reconnect...")
return self._try_wireless_reconnect()
def _pull_stronglifts_db(self) -> Path | None:
"""Pull StrongLifts database from phone to a local temp file.
Returns:
Path to the local copy, or None on failure.
"""
tmp = Path(tempfile.gettempdir()) / "stronglifts_check.db"
success, _ = self._adb_shell(
f"cat '{STRONGLIFTS_DB_REMOTE}' > /sdcard/_sl_tmp.db",
root=True,
)
if not success:
return None
ok, _ = self._run_adb(["pull", "/sdcard/_sl_tmp.db", str(tmp)])
if not ok:
return None
return tmp
def _count_today_workouts(self, db_path: Path) -> int:
"""Count today's workouts in a local copy of StrongLifts DB.
Args:
db_path: Path to the locally-pulled StrongLifts database.
Returns:
Number of workouts started today (local time).
"""
try:
conn = sqlite3.connect(str(db_path))
try:
cursor = conn.execute(
"SELECT COUNT(*) FROM workouts "
"WHERE date(start / 1000, 'unixepoch', 'localtime') "
"= date('now', 'localtime')",
)
row = cursor.fetchone()
return int(row[0]) if row else 0
finally:
conn.close()
except (sqlite3.Error, ValueError, TypeError):
_logger.warning("Failed to query StrongLifts database")
return 0
def _verify_phone_workout(self) -> tuple[str, str]:
"""Verify workout was recorded in StrongLifts on the phone.
Returns:
Tuple of (status, message) where status is one of:
- "verified": Workout confirmed on phone.
- "not_verified": Phone connected but no workout found.
- "no_phone": No phone connected via ADB.
- "error": Could not access StrongLifts database.
"""
if not self._is_phone_connected():
return "no_phone", "No phone connected via ADB"
local_db = self._pull_stronglifts_db()
if local_db is None:
return "error", "StrongLifts database not found on phone"
count = self._count_today_workouts(local_db)
if count > 0:
return (
"verified",
f"Workout verified! ({count} session(s) found on phone)",
)
return "not_verified", "No workout found on phone today"

View File

@ -0,0 +1,262 @@
"""Shutdown schedule adjustment mixin for the screen locker."""
from __future__ import annotations
from datetime import datetime, timezone
import json
import logging
import subprocess
from python_pkg.screen_locker._constants import (
ADJUST_SHUTDOWN_SCRIPT,
SHUTDOWN_CONFIG_FILE,
SICK_DAY_STATE_FILE,
)
_logger = logging.getLogger(__name__)
class ShutdownMixin:
"""Mixin providing shutdown schedule adjustment functionality."""
def _apply_earlier_shutdown(self, today: str) -> bool:
"""Read config, save state, and write earlier shutdown hours."""
config_values = self._read_shutdown_config()
if config_values is None:
return False
mon_wed_hour, thu_sun_hour, morning_end_hour = config_values
if not self._save_sick_day_state(today, mon_wed_hour, thu_sun_hour):
_logger.error("Failed to save state - aborting adjustment")
return False
new_mon_wed = max(18, mon_wed_hour - 1)
new_thu_sun = max(18, thu_sun_hour - 1)
return self._write_shutdown_config(
new_mon_wed,
new_thu_sun,
morning_end_hour,
)
def _adjust_shutdown_time_earlier(self) -> bool:
"""Adjust shutdown schedule 1.5 hours earlier (stricter).
This can only be used once per day. Original values are saved and
automatically restored when checked the next day.
Returns True if successful, False otherwise.
"""
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
self._restore_original_config_if_needed()
if self._sick_mode_used_today():
_logger.warning("Sick mode already used today")
return False
try:
return self._apply_earlier_shutdown(today)
except (OSError, ValueError) as e:
_logger.warning("Failed to adjust shutdown time: %s", e)
return False
def _adjust_shutdown_time_later(self) -> bool:
"""Adjust shutdown schedule 2 hours later as workout reward.
Returns True if successful, False otherwise.
"""
try:
config_values = self._read_shutdown_config()
if config_values is None:
return False
mon_wed_hour, thu_sun_hour, morning_end_hour = config_values
new_mon_wed = min(23, mon_wed_hour + 2)
new_thu_sun = min(23, thu_sun_hour + 2)
return self._write_shutdown_config(
new_mon_wed,
new_thu_sun,
morning_end_hour,
restore=True,
)
except (OSError, ValueError) as e:
_logger.warning("Failed to adjust shutdown time for workout: %s", e)
return False
def _sick_mode_used_today(self) -> bool:
"""Check if sick mode was already used today."""
if not SICK_DAY_STATE_FILE.exists():
return False
try:
with SICK_DAY_STATE_FILE.open() as f:
state = json.load(f)
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
return state.get("date") == today
except (OSError, json.JSONDecodeError):
return False
def _save_sick_day_state(
self,
date: str,
orig_mon_wed: int,
orig_thu_sun: int,
) -> bool:
"""Save sick day state with original config values.
Returns True if saved successfully, False otherwise.
"""
state = {
"date": date,
"original_mon_wed_hour": orig_mon_wed,
"original_thu_sun_hour": orig_thu_sun,
}
try:
with SICK_DAY_STATE_FILE.open("w") as f:
json.dump(state, f, indent=2)
except OSError as e:
_logger.warning("Failed to save sick day state: %s", e)
return False
_logger.info("Saved sick day state for %s", date)
return True
def _load_sick_day_state(self) -> tuple[str, int, int] | None:
"""Load sick day state file.
Returns (date, orig_mon_wed_hour, orig_thu_sun_hour) or None.
"""
with SICK_DAY_STATE_FILE.open() as f:
state = json.load(f)
date = state.get("date")
orig_mw = state.get("original_mon_wed_hour")
orig_ts = state.get("original_thu_sun_hour")
if date is None or orig_mw is None or orig_ts is None:
return None
return (str(date), int(orig_mw), int(orig_ts))
def _write_restored_config(
self,
orig_mw: int,
orig_ts: int,
state_date: str,
) -> None:
"""Write restored config values and clean up state file."""
config_values = self._read_shutdown_config()
if config_values:
_, _, morning_end = config_values
_logger.info(
"Restoring original shutdown config from %s",
state_date,
)
self._write_shutdown_config(
orig_mw,
orig_ts,
morning_end,
restore=True,
)
SICK_DAY_STATE_FILE.unlink()
_logger.info("Removed stale sick day state from %s", state_date)
def _restore_original_config_if_needed(self) -> None:
"""Restore original config if sick day state is from a previous day."""
if not SICK_DAY_STATE_FILE.exists():
return
try:
loaded = self._load_sick_day_state()
if loaded is None:
return
state_date, orig_mw, orig_ts = loaded
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
if state_date != today:
self._write_restored_config(orig_mw, orig_ts, state_date)
except (OSError, json.JSONDecodeError) as e:
_logger.warning("Error checking sick day state: %s", e)
def _read_shutdown_config(self) -> tuple[int, int, int] | None:
"""Read shutdown config. Returns (mw_hour, ts_hour, me_hour) or None."""
if not SHUTDOWN_CONFIG_FILE.exists():
_logger.warning("Config not found: %s", SHUTDOWN_CONFIG_FILE)
return None
parsed: dict[str, int] = {}
keys = ("MON_WED_HOUR", "THU_SUN_HOUR", "MORNING_END_HOUR")
with SHUTDOWN_CONFIG_FILE.open() as f:
for line in f:
stripped = line.strip()
for key in keys:
if stripped.startswith(f"{key}="):
parsed[key] = int(stripped.split("=")[1])
if len(parsed) < len(keys):
_logger.warning("Shutdown config missing required values")
return None
return (
parsed["MON_WED_HOUR"],
parsed["THU_SUN_HOUR"],
parsed["MORNING_END_HOUR"],
)
def _build_shutdown_cmd(
self,
mon_wed: int,
thu_sun: int,
morning: int,
*,
restore: bool,
) -> list[str]:
"""Build the shutdown adjustment command."""
cmd = ["/usr/bin/sudo", str(ADJUST_SHUTDOWN_SCRIPT)]
if restore:
cmd.append("--restore")
cmd.extend([str(mon_wed), str(thu_sun), str(morning)])
return cmd
def _write_shutdown_config(
self,
mon_wed_hour: int,
thu_sun_hour: int,
morning_end_hour: int,
*,
restore: bool = False,
) -> bool:
"""Write new shutdown config values using helper script.
Args:
mon_wed_hour: Shutdown hour for Monday-Wednesday.
thu_sun_hour: Shutdown hour for Thursday-Sunday.
morning_end_hour: Morning end hour.
restore: If True, allows restoring to later times.
Returns True if successful, False otherwise.
"""
if not ADJUST_SHUTDOWN_SCRIPT.exists():
_logger.warning(
"Script not found: %s",
ADJUST_SHUTDOWN_SCRIPT,
)
return False
cmd = self._build_shutdown_cmd(
mon_wed_hour,
thu_sun_hour,
morning_end_hour,
restore=restore,
)
return self._run_shutdown_cmd(cmd, mon_wed_hour, thu_sun_hour)
def _run_shutdown_cmd(
self,
cmd: list[str],
mon_wed_hour: int,
thu_sun_hour: int,
) -> bool:
"""Execute the shutdown adjustment command."""
try:
result = subprocess.run(
cmd,
check=True,
capture_output=True,
text=True,
)
except subprocess.SubprocessError as e:
_logger.warning("Failed to adjust shutdown config: %s", e)
return False
_logger.info(
"Adjusted shutdown: Mon-Wed=%d, Thu-Sun=%d. %s",
mon_wed_hour,
thu_sun_hour,
result.stdout.strip(),
)
return True

View File

@ -0,0 +1,294 @@
"""UI flow methods mixin for the screen locker."""
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
import contextlib
import tkinter as tk
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
from python_pkg.screen_locker._constants import (
PHONE_PENALTY_DELAY_DEMO,
PHONE_PENALTY_DELAY_PRODUCTION,
SICK_LOCKOUT_SECONDS,
)
class UIFlowsMixin:
"""Mixin providing UI flow logic for the screen locker."""
def ask_workout_done(self) -> None:
"""Display the initial workout question dialog."""
self.clear_container()
self._label("Did you work out today?", pady=30)
frame = self._button_row()
self._button(
frame,
"YES",
bg="#00aa00",
command=self.ask_workout_type,
).pack(side="left", padx=20)
self._button(
frame,
"NO",
bg="#aa0000",
command=self.ask_if_sick,
).pack(side="left", padx=20)
def _start_phone_check(self) -> None:
"""Check phone for today's workout immediately at startup."""
self.clear_container()
self._label("Checking phone...", font_size=36, color="#ffaa00", pady=30)
self._text("Looking for today's workout in StrongLifts...", font_size=18)
executor = ThreadPoolExecutor(max_workers=1)
self._phone_future = executor.submit(self._verify_phone_workout)
executor.shutdown(wait=False)
self._poll_phone_check()
def _poll_phone_check(self) -> None:
"""Poll background phone check and route to result handler when done."""
if self._phone_future is not None and self._phone_future.done():
status, message = self._phone_future.result()
self._handle_startup_phone_result(status, message)
else:
self.root.after(500, self._poll_phone_check)
def _handle_startup_phone_result(self, status: str, message: str) -> None:
"""Route to appropriate screen based on startup phone check result."""
if status == "verified":
self.workout_data["type"] = "phone_verified"
self.workout_data["source"] = message
self.clear_container()
self._label(
"\u2713 Workout Verified!", font_size=42, color="#00cc44", pady=30
)
self._text(message, font_size=20, color="#aaffaa")
self._text("Unlocking...", font_size=18, color="#888888")
unlock_delay = 1500 if self.demo_mode else 2000
self.root.after(unlock_delay, self.unlock_screen)
elif status == "not_verified":
self.clear_container()
self._label("No Workout Found", font_size=36, color="#ff4444", pady=20)
self._text(
f"\u274c {message}\n\n"
"StrongLifts shows no workout today.\n"
"Go do your workout first!",
color="#ffaa00",
)
frame = self._button_row()
self._button(
frame,
"TRY AGAIN",
bg="#0066cc",
command=self._start_phone_check,
width=12,
).pack(side="left", padx=10)
self._button(
frame,
"I'm sick",
bg="#cc6600",
command=self.ask_if_sick,
width=12,
).pack(side="left", padx=10)
else:
# no_phone or error — penalty timer, then proceed to logging form
self._show_phone_penalty(message, on_done=self.ask_workout_done)
def ask_if_sick(self) -> None:
"""Display sick day question dialog."""
self.clear_container()
self._label("Are you sick?", pady=30)
self._text(
"If yes, shutdown time will be moved 1.5 hours earlier",
color="#ffaa00",
)
self._sick_question_buttons()
def _sick_question_buttons(self) -> None:
"""Create the sick day yes/no buttons."""
frame = self._button_row()
self._button(
frame,
"YES (sick)",
bg="#cc6600",
command=self.handle_sick_day,
width=12,
).pack(side="left", padx=20)
self._button(
frame,
"NO",
bg="#aa0000",
command=self.lockout,
width=12,
).pack(side="left", padx=20)
def _get_sick_day_status(self) -> tuple[str, str]:
"""Determine sick day status text and color."""
if self._sick_mode_used_today():
return "Shutdown time already adjusted today", "#ffaa00"
if self._adjust_shutdown_time_earlier():
return (
"Shutdown time moved 1.5 hours earlier \u2713\n(Will revert tomorrow)"
), "#00aa00"
return "Could not adjust shutdown time (check permissions)", "#ff4444"
def handle_sick_day(self) -> None:
"""Handle sick day: adjust shutdown time and start 2-minute wait."""
self.clear_container()
status_text, status_color = self._get_sick_day_status()
self._show_sick_day_ui(status_text, status_color)
self.sick_remaining_time = SICK_LOCKOUT_SECONDS
self._update_sick_countdown()
def _show_sick_day_ui(self, status_text: str, status_color: str) -> None:
"""Display sick day UI labels and countdown."""
self._label("Sick Day Mode", color="#cc6600", pady=20)
self._text(status_text, color=status_color)
self._text(
"Please wait 2 minutes before unlocking...",
font_size=24,
pady=20,
)
self.sick_countdown_label = self._label(
str(SICK_LOCKOUT_SECONDS),
font_size=80,
pady=30,
)
def _update_sick_countdown(self) -> None:
"""Update the sick day countdown timer."""
if self.sick_remaining_time > 0:
self.sick_countdown_label.config(text=str(self.sick_remaining_time))
self.sick_remaining_time -= 1
self.root.after(1000, self._update_sick_countdown)
else:
# Record sick day and unlock
self.workout_data["type"] = "sick_day"
self.workout_data["note"] = "Sick day - shutdown moved earlier"
self.unlock_screen()
# ------------------------------------------------------------------
# Lockout flow
# ------------------------------------------------------------------
def lockout(self) -> None:
"""Display lockout screen with countdown timer."""
self.clear_container()
self.lockout_label = self._label(
f"Go work out!\nLocked for {self.lockout_time} seconds",
font_size=48,
color="#ff4444",
pady=30,
)
self.countdown_label = self._label(
str(self.lockout_time),
font_size=120,
pady=30,
)
self.remaining_time = self.lockout_time
self.update_lockout_countdown()
def update_lockout_countdown(self) -> None:
"""Update the lockout countdown timer display."""
if self.remaining_time > 0:
self.countdown_label.config(text=str(self.remaining_time))
self.remaining_time -= 1
self.root.after(1000, self.update_lockout_countdown)
else:
self.ask_workout_done()
# ------------------------------------------------------------------
# Phone penalty
# ------------------------------------------------------------------
def _attempt_unlock(self) -> None:
"""Unlock screen after workout form submission."""
self.unlock_screen()
def _show_phone_penalty(
self, message: str, *, on_done: Callable[[], None] | None = None
) -> None:
"""Show penalty countdown when phone verification is unavailable."""
self.clear_container()
self._phone_penalty_done_fn: Callable[[], None] = (
on_done if on_done is not None else self.unlock_screen
)
delay = (
PHONE_PENALTY_DELAY_DEMO
if self.demo_mode
else PHONE_PENALTY_DELAY_PRODUCTION
)
self._label(
"Cannot Verify Workout",
font_size=36,
color="#ff8800",
pady=20,
)
self._text(message, color="#ffaa00")
self._text(
"Connect phone via ADB to skip this wait,\n"
"or wait for the penalty timer.\n\n"
"Note: Phone must be rooted and StrongLifts installed.",
font_size=18,
)
self.phone_penalty_remaining = delay
self.phone_penalty_label = self._label(
str(delay),
font_size=80,
pady=20,
)
self._update_phone_penalty()
def _update_phone_penalty(self) -> None:
"""Update phone penalty countdown."""
if self.phone_penalty_remaining > 0:
self.phone_penalty_label.config(
text=str(self.phone_penalty_remaining),
)
self.phone_penalty_remaining -= 1
self.root.after(1000, self._update_phone_penalty)
else:
self._phone_penalty_done_fn()
# ------------------------------------------------------------------
# Submit timer and entry checking
# ------------------------------------------------------------------
def _tick_submit_timer(self) -> None:
"""Decrement submit timer and schedule next tick."""
self.timer_label.config(
text=f"Submit available in {self.submit_unlock_time} seconds...",
)
self.submit_unlock_time -= 1
self.root.after(1000, self.update_submit_timer)
def _try_enable_submit(self) -> None:
"""Enable submit button if all entries are filled."""
all_filled = all(entry.get().strip() for entry in self.entries_to_check)
if all_filled:
self.submit_btn.config(
text="SUBMIT",
state="normal",
bg="#00aa00",
command=self.submit_command,
)
self.timer_label.config(text="You can now submit!")
else:
self.timer_label.config(text="Fill all fields to enable submit")
self.root.after(1000, self.check_entries_filled)
def update_submit_timer(self) -> None:
"""Update countdown timer and check if submit can be enabled."""
with contextlib.suppress(tk.TclError):
if self.submit_unlock_time > 0:
self._tick_submit_timer()
else:
self._try_enable_submit()
def check_entries_filled(self) -> None:
"""Continuously check if entries are filled after timer expires."""
with contextlib.suppress(tk.TclError):
self._try_enable_submit()

View File

@ -0,0 +1,269 @@
"""Workout form methods mixin for the screen locker."""
from __future__ import annotations
from typing import TYPE_CHECKING
from python_pkg.screen_locker._constants import (
MAX_DISTANCE_KM,
MAX_PACE_MIN_PER_KM,
MAX_REPS,
MAX_SETS,
MAX_TIME_MINUTES,
MAX_WEIGHT_KG,
MIN_EXERCISE_NAME_LEN,
STRENGTH_FIELDS,
)
if TYPE_CHECKING:
import tkinter as tk
class WorkoutFormsMixin:
"""Mixin providing workout form creation and validation."""
# ------------------------------------------------------------------
# Workout type selection
# ------------------------------------------------------------------
def ask_workout_type(self) -> None:
"""Display workout type selection dialog."""
self.clear_container()
self._label("What type of workout?", pady=30)
frame = self._button_row()
self._button(
frame,
"STRENGTH",
bg="#cc6600",
command=self.ask_strength_details,
width=12,
).pack(side="left", padx=20)
# ------------------------------------------------------------------
# Running workout
# ------------------------------------------------------------------
def _create_running_entries(self) -> list[tk.Entry]:
"""Create running workout entry fields."""
self.distance_entry = self._entry_row("Distance (km):")
self.time_entry = self._entry_row("Time (minutes):")
self.pace_entry = self._entry_row("Pace (min/km):")
return [self.distance_entry, self.time_entry, self.pace_entry]
def ask_running_details(self) -> None:
"""Display running workout input form."""
self.clear_container()
self.workout_data["type"] = "running"
self._label("Running Details", pady=20)
entries = self._create_running_entries()
self._setup_form_controls(
entries,
self.verify_running_data,
self.ask_workout_type,
)
def _check_running_ranges(
self,
distance: float,
time_mins: float,
pace: float,
) -> str | None:
"""Check if running values are in valid ranges."""
if distance <= 0 or distance > MAX_DISTANCE_KM:
return f"Distance seems unrealistic (0-{MAX_DISTANCE_KM} km)"
if time_mins <= 0 or time_mins > MAX_TIME_MINUTES:
return f"Time seems unrealistic (0-{MAX_TIME_MINUTES} minutes)"
if pace <= 0 or pace > MAX_PACE_MIN_PER_KM:
return f"Pace seems unrealistic (0-{MAX_PACE_MIN_PER_KM} min/km)"
expected_pace = time_mins / distance
tolerance = expected_pace * 0.15 # 15% tolerance
if abs(pace - expected_pace) > tolerance:
return (
f"Pace doesn't match! "
f"Expected ~{expected_pace:.2f} min/km, got {pace:.2f}"
)
return None
def _validate_running_input(self) -> tuple[float, float, float] | None:
"""Parse and validate running input fields."""
try:
distance = float(self.distance_entry.get())
time_mins = float(self.time_entry.get())
pace = float(self.pace_entry.get())
except ValueError:
self.show_error("Please enter valid numbers")
return None
error = self._check_running_ranges(distance, time_mins, pace)
if error:
self.show_error(error)
return None
return distance, time_mins, pace
def verify_running_data(self) -> None:
"""Validate running workout data and unlock if valid."""
result = self._validate_running_input()
if result is None:
return
distance, time_mins, pace = result
self.workout_data["distance_km"] = str(distance)
self.workout_data["time_minutes"] = str(time_mins)
self.workout_data["pace_min_per_km"] = str(pace)
self._attempt_unlock()
# ------------------------------------------------------------------
# Strength workout
# ------------------------------------------------------------------
def _create_strength_entries(self) -> list[tk.Entry]:
"""Create strength training entry fields."""
entries = [
self._entry_row(lbl, width=w, font_size=18) for lbl, w in STRENGTH_FIELDS
]
(
self.exercises_entry,
self.sets_entry,
self.reps_entry,
self.weights_entry,
self.total_weight_entry,
) = entries
return entries
def ask_strength_details(self) -> None:
"""Display strength training input form."""
self.clear_container()
self.workout_data["type"] = "strength"
self._label("Strength Training Details", pady=20)
entries = self._create_strength_entries()
self._setup_form_controls(
entries,
self.verify_strength_data,
self.ask_workout_type,
)
def _parse_reps(self, reps_raw: list[str]) -> list[list[int]]:
"""Parse reps input - single number or variable reps like '12+11+12'."""
reps: list[list[int]] = []
for r in reps_raw:
if "+" in r:
reps.append([int(x.strip()) for x in r.split("+")])
else:
reps.append([int(r)])
return reps
def _validate_strength_inputs(
self,
exercises: list[str],
sets: list[int],
reps: list[list[int]],
weights: list[float],
) -> str | None:
"""Validate strength workout inputs. Returns error message or None."""
if not (len(exercises) == len(sets) == len(reps) == len(weights)):
return "Number of exercises, sets, reps, and weights must match"
if any(len(ex) < MIN_EXERCISE_NAME_LEN for ex in exercises):
return "Exercise names too short - be specific"
if any(s < 1 or s > MAX_SETS for s in sets):
return f"Sets should be between 1-{MAX_SETS}"
if any(w < 0 or w > MAX_WEIGHT_KG for w in weights):
return f"Weights should be between 0-{MAX_WEIGHT_KG} kg"
return self._validate_reps(exercises, sets, reps)
def _validate_reps(
self,
exercises: list[str],
sets: list[int],
reps: list[list[int]],
) -> str | None:
"""Validate reps data. Returns error message or None if valid."""
for i, rep_list in enumerate(reps):
if any(r < 1 or r > MAX_REPS for r in rep_list):
return f"Reps should be between 1-{MAX_REPS}"
if len(rep_list) > 1 and len(rep_list) != sets[i]:
return (
f"For {exercises[i]!r}: variable reps count "
f"({len(rep_list)}) doesn't match sets ({sets[i]})"
)
return None
def _calculate_expected_total(
self,
sets: list[int],
reps: list[list[int]],
weights: list[float],
) -> float:
"""Calculate expected total weight lifted."""
expected_total = 0.0
for i, rep_list in enumerate(reps):
if len(rep_list) == 1:
expected_total += sets[i] * rep_list[0] * weights[i]
else:
expected_total += sum(rep_list) * weights[i]
return expected_total
def _parse_strength_entries(
self,
) -> tuple[list[str], list[int], list[list[int]], list[float], float]:
"""Parse raw strength training input from entry widgets."""
exercises = [e.strip() for e in self.exercises_entry.get().split(",")]
sets = [int(s.strip()) for s in self.sets_entry.get().split(",")]
reps_raw = [r.strip() for r in self.reps_entry.get().split(",")]
reps = self._parse_reps(reps_raw)
weights = [float(w.strip()) for w in self.weights_entry.get().split(",")]
total_weight = float(self.total_weight_entry.get())
return exercises, sets, reps, weights, total_weight
def _check_total_weight(
self,
sets: list[int],
reps: list[list[int]],
weights: list[float],
total_weight: float,
) -> str | None:
"""Verify total weight matches individual exercise calculations."""
expected = self._calculate_expected_total(sets, reps, weights)
tolerance = expected * 0.15 # 15% tolerance
if abs(total_weight - expected) > tolerance:
return (
f"Total weight doesn't match! "
f"Expected ~{expected:.1f} kg, got {total_weight:.1f}"
)
return None
def _store_strength_data(
self,
exercises: list[str],
sets: list[int],
reps: list[list[int]],
weights: list[float],
total_weight: float,
) -> None:
"""Store validated strength workout data."""
self.workout_data["exercises"] = exercises
self.workout_data["sets"] = [str(s) for s in sets]
self.workout_data["reps"] = [
"+".join(str(r) for r in rep_list) for rep_list in reps
]
self.workout_data["weights_kg"] = [str(w) for w in weights]
self.workout_data["total_weight_kg"] = str(total_weight)
def verify_strength_data(self) -> None:
"""Validate strength workout data and unlock if valid."""
try:
self._verify_strength_data_inner()
except ValueError:
self.show_error("Please enter valid data in correct format")
def _verify_strength_data_inner(self) -> None:
"""Parse, validate, and store strength data."""
data = self._parse_strength_entries()
exercises, sets, reps, weights, total_weight = data
error = self._validate_strength_inputs(exercises, sets, reps, weights)
if error:
self.show_error(error)
return
total_err = self._check_total_weight(sets, reps, weights, total_weight)
if total_err:
self.show_error(total_err)
return
self._store_strength_data(exercises, sets, reps, weights, total_weight)
self._attempt_unlock()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,113 @@
"""Shared fixtures and helpers for screen_locker tests."""
from __future__ import annotations
from pathlib import Path
import tkinter as tk
from typing import TYPE_CHECKING, NamedTuple
from unittest.mock import MagicMock, patch
import pytest
from python_pkg.screen_locker.screen_lock import ScreenLocker
if TYPE_CHECKING:
from collections.abc import Generator
class RunningData(NamedTuple):
"""Running workout data for tests."""
distance: str
time_mins: str
pace: str
class StrengthData(NamedTuple):
"""Strength workout data for tests."""
exercises: str
sets: str
reps: str
weights: str
total_weight: str
@pytest.fixture
def mock_tk() -> Generator[MagicMock]:
"""Mock tkinter module for testing without display."""
with patch("python_pkg.screen_locker.screen_lock.tk") as mock:
# Set up Tk root mock
mock_root = MagicMock()
mock_root.winfo_screenwidth.return_value = 1920
mock_root.winfo_screenheight.return_value = 1080
mock.Tk.return_value = mock_root
# Set up Frame mock
mock_frame = MagicMock()
mock_frame.winfo_children.return_value = []
mock.Frame.return_value = mock_frame
# Set up TclError as actual exception class
mock.TclError = tk.TclError
yield mock
@pytest.fixture
def mock_sys_exit() -> Generator[MagicMock]:
"""Mock sys.exit to prevent test termination."""
with patch("python_pkg.screen_locker.screen_lock.sys.exit") as mock:
yield mock
@pytest.fixture
def _mock_sys_exit(mock_sys_exit: MagicMock) -> MagicMock:
"""Alias for mock_sys_exit when the return value is unused."""
return mock_sys_exit
@pytest.fixture
def temp_log_file(tmp_path: Path) -> Path:
"""Create a temporary log file path."""
return tmp_path / "workout_log.json"
def create_locker(
_mock_tk: MagicMock,
tmp_path: Path,
*,
demo_mode: bool = True,
has_logged: bool = False,
) -> ScreenLocker:
"""Create a ScreenLocker instance for testing."""
with (
patch.object(Path, "resolve", return_value=tmp_path),
patch.object(ScreenLocker, "has_logged_today", return_value=has_logged),
patch.object(ScreenLocker, "_start_phone_check"),
):
return ScreenLocker(demo_mode=demo_mode)
def setup_running_entries(locker: ScreenLocker, data: RunningData) -> None:
"""Set up mock running entry widgets."""
locker.distance_entry = MagicMock()
locker.distance_entry.get.return_value = data.distance
locker.time_entry = MagicMock()
locker.time_entry.get.return_value = data.time_mins
locker.pace_entry = MagicMock()
locker.pace_entry.get.return_value = data.pace
def setup_strength_entries(locker: ScreenLocker, data: StrengthData) -> None:
"""Set up mock strength entry widgets."""
locker.exercises_entry = MagicMock()
locker.exercises_entry.get.return_value = data.exercises
locker.sets_entry = MagicMock()
locker.sets_entry.get.return_value = data.sets
locker.reps_entry = MagicMock()
locker.reps_entry.get.return_value = data.reps
locker.weights_entry = MagicMock()
locker.weights_entry.get.return_value = data.weights
locker.total_weight_entry = MagicMock()
locker.total_weight_entry.get.return_value = data.total_weight

View File

@ -0,0 +1,411 @@
"""Tests for ADB commands, phone connection, and database operations."""
from __future__ import annotations
import sqlite3
import subprocess
import time
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
from python_pkg.screen_locker.screen_lock import STRONGLIFTS_DB_REMOTE
from python_pkg.screen_locker.tests.conftest import create_locker
if TYPE_CHECKING:
from pathlib import Path
class TestRunAdb:
"""Tests for _run_adb ADB command execution."""
def test_run_adb_success(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test successful ADB command."""
locker = create_locker(mock_tk, tmp_path)
mock_result = MagicMock(returncode=0, stdout="ok\n")
with patch(
"python_pkg.screen_locker._phone_verification.subprocess.run",
return_value=mock_result,
) as mock_run:
success, output = locker._run_adb(["devices"])
assert success is True
assert output == "ok\n"
mock_run.assert_called_once()
def test_run_adb_failure(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test failed ADB command."""
locker = create_locker(mock_tk, tmp_path)
mock_result = MagicMock(returncode=1, stdout="")
with patch(
"python_pkg.screen_locker._phone_verification.subprocess.run",
return_value=mock_result,
):
success, _output = locker._run_adb(["devices"])
assert success is False
def test_run_adb_not_found(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ADB binary not found."""
locker = create_locker(mock_tk, tmp_path)
with patch(
"python_pkg.screen_locker._phone_verification.subprocess.run",
side_effect=FileNotFoundError("adb not found"),
):
success, output = locker._run_adb(["devices"])
assert success is False
assert output == ""
def test_run_adb_oserror(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ADB OSError."""
locker = create_locker(mock_tk, tmp_path)
with patch(
"python_pkg.screen_locker._phone_verification.subprocess.run",
side_effect=OSError("permission denied"),
):
success, output = locker._run_adb(["devices"])
assert success is False
assert output == ""
def test_run_adb_timeout(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ADB command timeout."""
locker = create_locker(mock_tk, tmp_path)
with patch(
"python_pkg.screen_locker._phone_verification.subprocess.run",
side_effect=subprocess.TimeoutExpired("adb", 15),
):
success, output = locker._run_adb(["devices"])
assert success is False
assert output == ""
class TestAdbShell:
"""Tests for _adb_shell method."""
def test_adb_shell_no_root(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ADB shell without root."""
locker = create_locker(mock_tk, tmp_path)
locker._run_adb = MagicMock( # type: ignore[method-assign]
return_value=(True, "output"),
)
success, output = locker._adb_shell("ls /sdcard")
locker._run_adb.assert_called_once_with(["shell", "ls /sdcard"])
assert success is True
assert output == "output"
def test_adb_shell_with_root(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ADB shell with root."""
locker = create_locker(mock_tk, tmp_path)
locker._run_adb = MagicMock( # type: ignore[method-assign]
return_value=(True, "output"),
)
success, _output = locker._adb_shell("ls /data", root=True)
locker._run_adb.assert_called_once_with(
["shell", "su", "-c", "ls /data"],
)
assert success is True
class TestIsPhoneConnected:
"""Tests for _is_phone_connected method."""
def test_phone_connected(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test phone detected as connected."""
locker = create_locker(mock_tk, tmp_path)
locker._run_adb = MagicMock( # type: ignore[method-assign]
return_value=(
True,
"List of devices attached\nABC123\tdevice\n\n",
),
)
assert locker._is_phone_connected() is True
def test_phone_not_connected(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test no phone connected."""
locker = create_locker(mock_tk, tmp_path)
locker._run_adb = MagicMock( # type: ignore[method-assign]
return_value=(True, "List of devices attached\n\n"),
)
locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign]
return_value=False,
)
assert locker._is_phone_connected() is False
def test_phone_offline(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test phone connected but offline."""
locker = create_locker(mock_tk, tmp_path)
locker._run_adb = MagicMock( # type: ignore[method-assign]
return_value=(
True,
"List of devices attached\nABC123\toffline\n\n",
),
)
locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign]
return_value=False,
)
assert locker._is_phone_connected() is False
def test_adb_command_fails(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ADB command failure."""
locker = create_locker(mock_tk, tmp_path)
locker._run_adb = MagicMock( # type: ignore[method-assign]
return_value=(False, ""),
)
locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign]
return_value=False,
)
assert locker._is_phone_connected() is False
class TestFindHealthConnectDb:
"""Tests for _pull_stronglifts_db method."""
def test_db_pulled_successfully(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test StrongLifts DB pulled from device."""
locker = create_locker(mock_tk, tmp_path)
locker._adb_shell = MagicMock( # type: ignore[method-assign]
return_value=(True, ""),
)
locker._run_adb = MagicMock( # type: ignore[method-assign]
return_value=(True, ""),
)
result = locker._pull_stronglifts_db()
assert result is not None
locker._adb_shell.assert_called_once()
locker._run_adb.assert_called_once()
call_args = locker._run_adb.call_args[0][0]
assert call_args[0] == "pull"
def test_db_cat_fails(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns None when cat command fails."""
locker = create_locker(mock_tk, tmp_path)
locker._adb_shell = MagicMock( # type: ignore[method-assign]
return_value=(False, ""),
)
assert locker._pull_stronglifts_db() is None
def test_db_pull_fails(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns None when adb pull fails."""
locker = create_locker(mock_tk, tmp_path)
locker._adb_shell = MagicMock( # type: ignore[method-assign]
return_value=(True, ""),
)
locker._run_adb = MagicMock( # type: ignore[method-assign]
return_value=(False, ""),
)
assert locker._pull_stronglifts_db() is None
def test_db_uses_correct_remote_path(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test uses the correct StrongLifts DB remote path."""
locker = create_locker(mock_tk, tmp_path)
locker._adb_shell = MagicMock( # type: ignore[method-assign]
return_value=(True, ""),
)
locker._run_adb = MagicMock( # type: ignore[method-assign]
return_value=(True, ""),
)
locker._pull_stronglifts_db()
shell_cmd = locker._adb_shell.call_args[0][0]
assert STRONGLIFTS_DB_REMOTE in shell_cmd
class TestCountTodayWorkouts:
"""Tests for _count_today_workouts method."""
def test_workouts_found_today(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test workouts found today."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
# Insert a workout with today's timestamp (ms)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms + 3600000),
)
conn.commit()
conn.close()
assert locker._count_today_workouts(db_file) == 1
def test_no_workouts_today(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test no workouts today."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
# Insert a workout from yesterday (24h+ ago)
yesterday_ms = int((time.time() - 200000) * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", yesterday_ms, yesterday_ms + 3600000),
)
conn.commit()
conn.close()
assert locker._count_today_workouts(db_file) == 0
def test_invalid_db_returns_zero(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 for invalid database file."""
locker = create_locker(mock_tk, tmp_path)
bad_file = tmp_path / "not_a_db.db"
bad_file.write_text("not a database")
assert locker._count_today_workouts(bad_file) == 0
def test_missing_table_returns_zero(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test returns 0 when workouts table doesn't exist."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "empty.db"
conn = sqlite3.connect(str(db_file))
conn.execute("CREATE TABLE other (id TEXT)")
conn.commit()
conn.close()
assert locker._count_today_workouts(db_file) == 0
def test_multiple_workouts_today(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test counts multiple workouts today correctly."""
locker = create_locker(mock_tk, tmp_path)
db_file = tmp_path / "sl_test.db"
conn = sqlite3.connect(str(db_file))
conn.execute(
"CREATE TABLE workouts "
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
)
now_ms = int(time.time() * 1000)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w1", now_ms, now_ms + 3600000),
)
conn.execute(
"INSERT INTO workouts VALUES (?, ?, ?)",
("w2", now_ms + 100000, now_ms + 3700000),
)
conn.commit()
conn.close()
assert locker._count_today_workouts(db_file) == 2

View File

@ -0,0 +1,390 @@
"""Tests for screen_locker initialization, logging, and basic operations."""
from __future__ import annotations
from datetime import datetime, timezone
import json
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock
import pytest
from python_pkg.screen_locker.screen_lock import (
MAX_DISTANCE_KM,
MAX_PACE_MIN_PER_KM,
MAX_REPS,
MAX_SETS,
MAX_TIME_MINUTES,
MAX_WEIGHT_KG,
MIN_EXERCISE_NAME_LEN,
)
from python_pkg.screen_locker.tests.conftest import create_locker
if TYPE_CHECKING:
from pathlib import Path
class TestConstants:
"""Tests for module constants."""
def test_max_distance_km(self) -> None:
"""Test MAX_DISTANCE_KM is reasonable."""
assert MAX_DISTANCE_KM == 100
assert MAX_DISTANCE_KM > 0
def test_max_time_minutes(self) -> None:
"""Test MAX_TIME_MINUTES is reasonable."""
assert MAX_TIME_MINUTES == 600
assert MAX_TIME_MINUTES > 0
def test_max_pace_min_per_km(self) -> None:
"""Test MAX_PACE_MIN_PER_KM is reasonable."""
assert MAX_PACE_MIN_PER_KM == 20
assert MAX_PACE_MIN_PER_KM > 0
def test_min_exercise_name_len(self) -> None:
"""Test MIN_EXERCISE_NAME_LEN is reasonable."""
assert MIN_EXERCISE_NAME_LEN == 3
assert MIN_EXERCISE_NAME_LEN > 0
def test_max_sets(self) -> None:
"""Test MAX_SETS is reasonable."""
assert MAX_SETS == 20
assert MAX_SETS > 0
def test_max_reps(self) -> None:
"""Test MAX_REPS is reasonable."""
assert MAX_REPS == 100
assert MAX_REPS > 0
def test_max_weight_kg(self) -> None:
"""Test MAX_WEIGHT_KG is reasonable."""
assert MAX_WEIGHT_KG == 500
assert MAX_WEIGHT_KG > 0
class TestScreenLockerInit:
"""Tests for ScreenLocker initialization."""
def test_init_demo_mode(
self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path
) -> None:
"""Test initialization in demo mode."""
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
assert locker.demo_mode is True
assert locker.lockout_time == 10
mock_sys_exit.assert_not_called()
def test_init_production_mode(
self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path
) -> None:
"""Test initialization in production mode."""
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
assert locker.demo_mode is False
assert locker.lockout_time == 1800
mock_sys_exit.assert_not_called()
def test_init_exits_if_logged_today(
self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path
) -> None:
"""Test that init exits early if workout logged today."""
mock_sys_exit.side_effect = SystemExit(0)
with pytest.raises(SystemExit):
create_locker(mock_tk, tmp_path, has_logged=True)
mock_sys_exit.assert_called_once_with(0)
class TestHasLoggedToday:
"""Tests for has_logged_today method."""
def test_no_log_file(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test when log file doesn't exist."""
log_file = tmp_path / "workout_log.json"
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
assert locker.has_logged_today() is False
def test_empty_log_file(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test when log file is empty/invalid JSON."""
log_file = tmp_path / "workout_log.json"
log_file.write_text("")
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
assert locker.has_logged_today() is False
def test_invalid_json(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test when log file contains invalid JSON."""
log_file = tmp_path / "workout_log.json"
log_file.write_text("{invalid json}")
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
assert locker.has_logged_today() is False
def test_today_logged(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test when today's workout is logged."""
log_file = tmp_path / "workout_log.json"
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
log_file.write_text(json.dumps({today: {"workout": "data"}}))
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
assert locker.has_logged_today() is True
def test_other_day_logged(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test when only other days are logged."""
log_file = tmp_path / "workout_log.json"
log_file.write_text(json.dumps({"2020-01-01": {"workout": "data"}}))
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
assert locker.has_logged_today() is False
class TestSaveWorkoutLog:
"""Tests for save_workout_log method."""
def test_save_to_new_file(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test saving to a new log file."""
log_file = tmp_path / "workout_log.json"
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
locker.workout_data = {"type": "running"}
locker.save_workout_log()
assert log_file.exists()
with log_file.open() as f:
data: dict[str, Any] = json.load(f)
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
assert today in data
assert data[today]["workout_data"]["type"] == "running"
def test_save_to_existing_file(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test saving appends to existing log file."""
log_file = tmp_path / "workout_log.json"
log_file.write_text(json.dumps({"2020-01-01": {"old": "data"}}))
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
locker.workout_data = {"type": "strength"}
locker.save_workout_log()
with log_file.open() as f:
data: dict[str, Any] = json.load(f)
assert "2020-01-01" in data
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
assert today in data
def test_save_with_corrupted_existing_file(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test saving when existing file is corrupted."""
log_file = tmp_path / "workout_log.json"
log_file.write_text("not valid json")
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
locker.workout_data = {"type": "running"}
locker.save_workout_log()
with log_file.open() as f:
data: dict[str, Any] = json.load(f)
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
assert today in data
def test_save_with_write_error(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test saving handles write errors gracefully."""
log_file = tmp_path / "nonexistent_dir" / "workout_log.json"
locker = create_locker(mock_tk, tmp_path)
locker.log_file = log_file
locker.workout_data = {"type": "running"}
# Should not raise, just log warning
locker.save_workout_log()
class TestShowError:
"""Tests for show_error method."""
def test_show_error_displays_message(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test show_error clears container and displays error."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.show_error("Test error message")
locker.clear_container.assert_called_once()
class TestRun:
"""Tests for run method."""
def test_run_starts_mainloop(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test run starts the tkinter mainloop."""
locker = create_locker(mock_tk, tmp_path)
locker.run()
locker.root.mainloop.assert_called_once() # type: ignore[attr-defined]
class TestMainEntry:
"""Tests for main entry point."""
def test_main_demo_mode_default(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test main defaults to demo mode."""
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
assert locker.demo_mode is True
def test_main_production_mode_flag(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test main with --production flag."""
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
assert locker.demo_mode is False
class TestAdjustShutdownTimeLater:
"""Tests for _adjust_shutdown_time_later method."""
def test_adjust_shutdown_time_later_success(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later adds hours successfully."""
locker = create_locker(mock_tk, tmp_path)
locker._read_shutdown_config = MagicMock( # type: ignore[method-assign]
return_value=(21, 22, 8)
)
locker._write_shutdown_config = MagicMock( # type: ignore[method-assign]
return_value=True
)
result = locker._adjust_shutdown_time_later()
assert result is True
locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True)
def test_adjust_shutdown_time_later_caps_at_23(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later caps hours at 23."""
locker = create_locker(mock_tk, tmp_path)
locker._read_shutdown_config = MagicMock( # type: ignore[method-assign]
return_value=(22, 23, 8)
)
locker._write_shutdown_config = MagicMock( # type: ignore[method-assign]
return_value=True
)
result = locker._adjust_shutdown_time_later()
assert result is True
# 22+2=24 capped to 23, 23+2=25 capped to 23
locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True)
def test_adjust_shutdown_time_later_no_config(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later returns False if config missing."""
locker = create_locker(mock_tk, tmp_path)
locker._read_shutdown_config = MagicMock( # type: ignore[method-assign]
return_value=None
)
result = locker._adjust_shutdown_time_later()
assert result is False
def test_adjust_shutdown_time_later_oserror(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _adjust_shutdown_time_later handles OSError."""
locker = create_locker(mock_tk, tmp_path)
locker._read_shutdown_config = MagicMock( # type: ignore[method-assign]
side_effect=OSError("permission denied")
)
result = locker._adjust_shutdown_time_later()
assert result is False

View File

@ -0,0 +1,430 @@
"""Tests for phone workout verification, phone check, and unlock operations."""
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
from python_pkg.screen_locker.screen_lock import (
PHONE_PENALTY_DELAY_DEMO,
PHONE_PENALTY_DELAY_PRODUCTION,
)
from python_pkg.screen_locker.tests.conftest import create_locker
if TYPE_CHECKING:
from pathlib import Path
class TestVerifyPhoneWorkout:
"""Tests for _verify_phone_workout method."""
def test_verified(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test workout verified on phone."""
locker = create_locker(mock_tk, tmp_path)
locker._is_phone_connected = MagicMock( # type: ignore[method-assign]
return_value=True,
)
locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign]
return_value=tmp_path / "sl.db",
)
locker._count_today_workouts = MagicMock( # type: ignore[method-assign]
return_value=2,
)
status, message = locker._verify_phone_workout()
assert status == "verified"
assert "2 session" in message
def test_not_verified(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test no workout found on phone."""
locker = create_locker(mock_tk, tmp_path)
locker._is_phone_connected = MagicMock( # type: ignore[method-assign]
return_value=True,
)
locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign]
return_value=tmp_path / "sl.db",
)
locker._count_today_workouts = MagicMock( # type: ignore[method-assign]
return_value=0,
)
status, message = locker._verify_phone_workout()
assert status == "not_verified"
assert "No workout" in message
def test_no_phone(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test no phone connected."""
locker = create_locker(mock_tk, tmp_path)
locker._is_phone_connected = MagicMock( # type: ignore[method-assign]
return_value=False,
)
status, _ = locker._verify_phone_workout()
assert status == "no_phone"
def test_error_no_db(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test error when StrongLifts DB cannot be pulled."""
locker = create_locker(mock_tk, tmp_path)
locker._is_phone_connected = MagicMock( # type: ignore[method-assign]
return_value=True,
)
locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign]
return_value=None,
)
status, message = locker._verify_phone_workout()
assert status == "error"
assert "database" in message.lower()
class TestStartPhoneCheck:
"""Tests for _start_phone_check and _handle_startup_phone_result."""
def test_start_phone_check_shows_checking_screen(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _start_phone_check shows checking message and starts check."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker._verify_phone_workout = MagicMock( # type: ignore[method-assign]
return_value=("no_phone", "No phone"),
)
locker._poll_phone_check = MagicMock() # type: ignore[method-assign]
locker._start_phone_check()
locker.clear_container.assert_called()
locker._poll_phone_check.assert_called_once()
assert locker._phone_future is not None
def test_handle_startup_verified_unlocks_directly(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test verified result shows success screen then unlocks via after()."""
locker = create_locker(mock_tk, tmp_path)
locker.unlock_screen = MagicMock() # type: ignore[method-assign]
locker.root.after = MagicMock() # type: ignore[method-assign]
locker._handle_startup_phone_result("verified", "Workout verified! (1 session)")
# unlock_screen is deferred via root.after, not called directly
locker.unlock_screen.assert_not_called()
assert locker.workout_data["type"] == "phone_verified"
locker.root.after.assert_called_once_with(1500, locker.unlock_screen)
def test_handle_startup_not_verified_shows_block(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test not_verified result shows blocking screen with buttons."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker._handle_startup_phone_result(
"not_verified", "No workout found on phone today"
)
locker.clear_container.assert_called()
def test_handle_startup_no_phone_shows_penalty(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test no_phone result triggers penalty with ask_workout_done as callback."""
locker = create_locker(mock_tk, tmp_path)
locker._show_phone_penalty = MagicMock() # type: ignore[method-assign]
locker._handle_startup_phone_result("no_phone", "No phone")
locker._show_phone_penalty.assert_called_once()
_, kwargs = locker._show_phone_penalty.call_args
assert kwargs["on_done"] == locker.ask_workout_done
def test_handle_startup_error_shows_penalty(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test error result triggers penalty with ask_workout_done as callback."""
locker = create_locker(mock_tk, tmp_path)
locker._show_phone_penalty = MagicMock() # type: ignore[method-assign]
locker._handle_startup_phone_result("error", "DB not found")
locker._show_phone_penalty.assert_called_once()
_, kwargs = locker._show_phone_penalty.call_args
assert kwargs["on_done"] == locker.ask_workout_done
def test_poll_phone_check_schedules_retry_when_pending(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _poll_phone_check reschedules itself when future is not done."""
locker = create_locker(mock_tk, tmp_path)
mock_future: MagicMock = MagicMock()
mock_future.done.return_value = False
locker._phone_future = mock_future # type: ignore[assignment]
locker.root.after = MagicMock() # type: ignore[method-assign]
locker._poll_phone_check()
locker.root.after.assert_called_once_with(500, locker._poll_phone_check)
def test_poll_phone_check_routes_when_done(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _poll_phone_check calls result handler when future is done."""
locker = create_locker(mock_tk, tmp_path)
mock_future: MagicMock = MagicMock()
mock_future.done.return_value = True
mock_future.result.return_value = ("no_phone", "No phone")
locker._phone_future = mock_future # type: ignore[assignment]
locker._handle_startup_phone_result = MagicMock() # type: ignore[method-assign]
locker._poll_phone_check()
locker._handle_startup_phone_result.assert_called_once_with(
"no_phone", "No phone"
)
class TestAttemptUnlock:
"""Tests for _attempt_unlock method."""
def test_attempt_unlock_calls_unlock_screen(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test _attempt_unlock calls unlock_screen directly."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "strength"}
locker.unlock_screen = MagicMock() # type: ignore[method-assign]
locker._attempt_unlock()
locker.unlock_screen.assert_called_once()
class TestShowPhonePenalty:
"""Tests for _show_phone_penalty and _update_phone_penalty methods."""
def test_show_phone_penalty_demo_delay(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test demo mode uses short penalty delay."""
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker._show_phone_penalty("test message")
# _update_phone_penalty is called once, decrementing by 1
assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_DEMO - 1
def test_show_phone_penalty_production_delay(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test production mode uses long penalty delay."""
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker._show_phone_penalty("test message")
assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_PRODUCTION - 1
def test_update_phone_penalty_countdown(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test phone penalty countdown decrements."""
locker = create_locker(mock_tk, tmp_path)
locker.phone_penalty_remaining = 5
locker.phone_penalty_label = MagicMock()
locker._update_phone_penalty()
assert locker.phone_penalty_remaining == 4
locker.phone_penalty_label.config.assert_called_once_with(text="5")
locker.root.after.assert_called() # type: ignore[attr-defined]
def test_update_phone_penalty_at_zero(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test phone penalty unlocks when timer reaches zero."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "strength"}
locker.phone_penalty_remaining = 0
locker.phone_penalty_label = MagicMock()
locker.unlock_screen = MagicMock() # type: ignore[method-assign]
locker._phone_penalty_done_fn = locker.unlock_screen # type: ignore[attr-defined]
locker._update_phone_penalty()
locker.unlock_screen.assert_called_once()
class TestUnlockScreenShutdownAdjustment:
"""Tests for unlock_screen shutdown time adjustment."""
def test_unlock_screen_adjusts_for_running(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen adjusts shutdown for running workout."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "running"}
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
return_value=True
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_called_once()
def test_unlock_screen_adjusts_for_strength(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen adjusts shutdown for strength workout."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "strength"}
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
return_value=True
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_called_once()
def test_unlock_screen_adjusts_for_phone_verified(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen adjusts shutdown for phone-verified workout."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "phone_verified"}
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
return_value=True
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_called_once()
def test_unlock_screen_skips_adjustment_for_sick_day(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen does not adjust for sick day."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "sick_day"}
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
return_value=True
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_not_called()
def test_unlock_screen_skips_adjustment_no_type(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen does not adjust when no workout type."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {}
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
return_value=True
)
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_not_called()
def test_unlock_screen_handles_adjustment_failure(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen continues when adjustment fails."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "running"}
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
return_value=False
)
# Should not raise, should continue with unlock
locker.unlock_screen()
locker._adjust_shutdown_time_later.assert_called_once()
locker.root.after.assert_called() # type: ignore[attr-defined]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,424 @@
"""Tests for UI transitions, timer logic, and workout detail screens."""
from __future__ import annotations
import tkinter as tk
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
from python_pkg.screen_locker.screen_lock import (
SUBMIT_DELAY_DEMO,
SUBMIT_DELAY_PRODUCTION,
)
from python_pkg.screen_locker.tests.conftest import create_locker
if TYPE_CHECKING:
from pathlib import Path
_TK_TCLERROR = tk.TclError
class TestUITransitions:
"""Tests for UI state transitions."""
def test_clear_container(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test clear_container destroys all child widgets."""
locker = create_locker(mock_tk, tmp_path)
# Set up mock children
mock_child1 = MagicMock()
mock_child2 = MagicMock()
locker.container.winfo_children.return_value = [ # type: ignore[attr-defined]
mock_child1,
mock_child2,
]
locker.clear_container()
mock_child1.destroy.assert_called_once()
mock_child2.destroy.assert_called_once()
def test_unlock_screen_saves_and_schedules_close(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test unlock_screen saves log and schedules close."""
locker = create_locker(mock_tk, tmp_path)
locker.log_file = tmp_path / "workout_log.json"
locker.workout_data = {"type": "running"}
locker.unlock_screen()
# Check that after() was called to schedule close
locker.root.after.assert_called() # type: ignore[attr-defined]
def test_lockout_starts_countdown(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test lockout initializes countdown timer."""
locker = create_locker(mock_tk, tmp_path)
locker.lockout()
# lockout() sets remaining_time to lockout_time (10 in demo mode)
# then calls update_lockout_countdown() which decrements it by 1
assert locker.remaining_time == 9 # 10 - 1 after first update
def test_close_destroys_root_and_exits(
self,
mock_tk: MagicMock,
mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test close destroys root window and exits."""
locker = create_locker(mock_tk, tmp_path)
locker.close()
locker.root.destroy.assert_called_once() # type: ignore[attr-defined]
mock_sys_exit.assert_called_with(0)
class TestTimerLogic:
"""Tests for timer countdown logic."""
def test_update_lockout_countdown_decrements(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test countdown decrements remaining time."""
locker = create_locker(mock_tk, tmp_path)
locker.remaining_time = 5
locker.countdown_label = MagicMock()
locker.update_lockout_countdown()
assert locker.remaining_time == 4
locker.root.after.assert_called_with( # type: ignore[attr-defined]
1000, locker.update_lockout_countdown
)
def test_update_lockout_countdown_at_zero(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test countdown at zero returns to workout question."""
locker = create_locker(mock_tk, tmp_path)
locker.remaining_time = 0
locker.countdown_label = MagicMock()
locker.ask_workout_done = MagicMock() # type: ignore[method-assign]
locker.update_lockout_countdown()
locker.ask_workout_done.assert_called_once()
def test_update_submit_timer_countdown(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test submit timer counts down."""
locker = create_locker(mock_tk, tmp_path)
locker.submit_unlock_time = 5
locker.timer_label = MagicMock()
locker.submit_btn = MagicMock()
locker.entries_to_check = []
locker.update_submit_timer()
assert locker.submit_unlock_time == 4
locker.root.after.assert_called() # type: ignore[attr-defined]
def test_update_submit_timer_enables_when_filled(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test submit enabled when timer done and entries filled."""
locker = create_locker(mock_tk, tmp_path)
locker.submit_unlock_time = 0
locker.timer_label = MagicMock()
locker.submit_btn = MagicMock()
mock_entry = MagicMock()
mock_entry.get.return_value = "some value"
locker.entries_to_check = [mock_entry]
locker.submit_command = MagicMock()
locker.update_submit_timer()
locker.submit_btn.config.assert_called()
def test_update_submit_timer_waits_for_entries(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test submit waits when entries not filled."""
locker = create_locker(mock_tk, tmp_path)
locker.submit_unlock_time = 0
locker.timer_label = MagicMock()
locker.submit_btn = MagicMock()
mock_entry = MagicMock()
mock_entry.get.return_value = "" # Empty entry
locker.entries_to_check = [mock_entry]
locker.update_submit_timer()
locker.root.after.assert_called_with( # type: ignore[attr-defined]
1000, locker.check_entries_filled
)
def test_update_submit_timer_handles_tcl_error(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test timer handles TclError when widgets destroyed."""
locker = create_locker(mock_tk, tmp_path)
locker.submit_unlock_time = 5
locker.timer_label = MagicMock()
locker.timer_label.config.side_effect = _TK_TCLERROR("widget destroyed")
# Should not raise
locker.update_submit_timer()
def test_check_entries_filled_enables_submit(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test check_entries_filled enables submit when all filled."""
locker = create_locker(mock_tk, tmp_path)
locker.timer_label = MagicMock()
locker.submit_btn = MagicMock()
mock_entry = MagicMock()
mock_entry.get.return_value = "value"
locker.entries_to_check = [mock_entry]
locker.submit_command = MagicMock()
locker.check_entries_filled()
locker.submit_btn.config.assert_called()
def test_check_entries_filled_continues_waiting(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test check_entries_filled continues waiting when not filled."""
locker = create_locker(mock_tk, tmp_path)
locker.timer_label = MagicMock()
locker.submit_btn = MagicMock()
mock_entry = MagicMock()
mock_entry.get.return_value = ""
locker.entries_to_check = [mock_entry]
locker.check_entries_filled()
locker.root.after.assert_called_with( # type: ignore[attr-defined]
1000, locker.check_entries_filled
)
def test_check_entries_filled_handles_tcl_error(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test check_entries_filled handles TclError."""
locker = create_locker(mock_tk, tmp_path)
locker.timer_label = MagicMock()
mock_entry = MagicMock()
mock_entry.get.side_effect = _TK_TCLERROR("widget destroyed")
locker.entries_to_check = [mock_entry]
# Should not raise
locker.check_entries_filled()
class TestAskWorkoutType:
"""Tests for ask_workout_type method."""
def test_ask_workout_type_creates_buttons(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ask_workout_type creates running and strength buttons."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.ask_workout_type()
locker.clear_container.assert_called_once()
# Verify Label and Button were called
mock_tk.Label.assert_called()
mock_tk.Button.assert_called()
class TestAskRunningDetails:
"""Tests for ask_running_details method."""
def test_ask_running_details_sets_workout_type(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ask_running_details sets workout type to running."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
locker.ask_running_details()
assert locker.workout_data["type"] == "running"
locker.clear_container.assert_called_once()
def test_ask_running_details_creates_entry_fields(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ask_running_details creates entry fields."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
locker.ask_running_details()
# Verify Entry fields were created
mock_tk.Entry.assert_called()
assert hasattr(locker, "distance_entry")
assert hasattr(locker, "time_entry")
assert hasattr(locker, "pace_entry")
def test_ask_running_details_sets_timer(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ask_running_details initializes submit timer."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
locker.ask_running_details()
assert locker.submit_unlock_time == SUBMIT_DELAY_DEMO
locker.update_submit_timer.assert_called_once()
class TestAskStrengthDetails:
"""Tests for ask_strength_details method."""
def test_ask_strength_details_sets_workout_type(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ask_strength_details sets workout type to strength."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
locker.ask_strength_details()
assert locker.workout_data["type"] == "strength"
locker.clear_container.assert_called_once()
def test_ask_strength_details_creates_entry_fields(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ask_strength_details creates entry fields."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
locker.ask_strength_details()
# Verify Entry fields were created
mock_tk.Entry.assert_called()
assert hasattr(locker, "exercises_entry")
assert hasattr(locker, "sets_entry")
assert hasattr(locker, "reps_entry")
assert hasattr(locker, "weights_entry")
assert hasattr(locker, "total_weight_entry")
def test_ask_strength_details_sets_timer(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ask_strength_details initializes submit timer."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
locker.ask_strength_details()
assert locker.submit_unlock_time == SUBMIT_DELAY_DEMO
locker.update_submit_timer.assert_called_once()
def test_ask_strength_details_production_timer(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test production mode uses longer submit delay."""
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
locker.ask_strength_details()
assert locker.submit_unlock_time == SUBMIT_DELAY_PRODUCTION
class TestAskWorkoutDone:
"""Tests for ask_workout_done method."""
def test_ask_workout_done_creates_buttons(
self,
mock_tk: MagicMock,
_mock_sys_exit: MagicMock,
tmp_path: Path,
) -> None:
"""Test ask_workout_done creates yes/no buttons."""
locker = create_locker(mock_tk, tmp_path)
locker.clear_container = MagicMock() # type: ignore[method-assign]
locker.ask_workout_done()
locker.clear_container.assert_called_once()
mock_tk.Label.assert_called()
mock_tk.Button.assert_called()

Some files were not shown because too many files have changed in this diff Show More