mirror of
https://github.com/kuhyx/testsAndMisc.git
synced 2026-07-04 13:03:13 +02:00
WIP: Enforce 500-line limit - split batch 1
Split 16+ files. 27 files still need splitting. See session notes.
This commit is contained in:
parent
e51c12dd8e
commit
c985160d17
@ -64,7 +64,7 @@ unfixable = []
|
||||
"SLF001", # Allow private member access in tests
|
||||
]
|
||||
# Files using urlopen with validated URL schemes
|
||||
"python_pkg/geo_data.py" = ["S310"]
|
||||
"python_pkg/geo_data/_common.py" = ["S310"]
|
||||
"python_pkg/steam_backlog_enforcer/library_hider.py" = ["S310"]
|
||||
"poker_modifier_app/poker_modifier_app.py" = [
|
||||
"FBT003", # Boolean positional values in tkinter API calls
|
||||
@ -76,9 +76,12 @@ unfixable = []
|
||||
"FBT003", # Boolean positional values in tkinter API calls
|
||||
]
|
||||
# Brother printer - optional usb.core/usb.util imports
|
||||
"python_pkg/brother_printer/check_brother_printer.py" = [
|
||||
"python_pkg/brother_printer/cups_service.py" = [
|
||||
"PLC0415", # Late imports for optional pyusb dependency
|
||||
]
|
||||
"python_pkg/brother_printer/usb_query.py" = [
|
||||
"PLC0415", # Late import of cups_service fallback
|
||||
]
|
||||
# Music generator - CLI script with intentional patterns
|
||||
"python_pkg/music_gen/music_generator.py" = [
|
||||
"T201", # print() is intentional for CLI feedback
|
||||
@ -182,6 +185,8 @@ disable = []
|
||||
[tool.pylint.design]
|
||||
# A class with just run() as public API is valid for games/apps
|
||||
min-public-methods = 1
|
||||
# Enforce maximum file length of 500 lines
|
||||
max-module-lines = 500
|
||||
|
||||
[tool.pylint.spelling]
|
||||
# No spelling dictionary to avoid false positives
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
173
python_pkg/brother_printer/constants.py
Normal file
173
python_pkg/brother_printer/constants.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""Constants, status codes, and lookup tables for Brother printer checking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
# ── Colors ───────────────────────────────────────────────────────────
|
||||
|
||||
RED = "\033[0;31m"
|
||||
YELLOW = "\033[1;33m"
|
||||
GREEN = "\033[0;32m"
|
||||
CYAN = "\033[0;36m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
# ── SNMP supply level sentinel values ────────────────────────────────────────
|
||||
|
||||
SNMP_LEVEL_OK = -3
|
||||
SNMP_LEVEL_LOW = -2
|
||||
SUPPLY_LOW_PCT = 10
|
||||
SUPPLY_WARN_PCT = 25
|
||||
PROGRESS_BAR_WIDTH = 25
|
||||
|
||||
# Brother HL-1110 consumable page ratings
|
||||
TONER_RATED_PAGES = 1000
|
||||
DRUM_RATED_PAGES = 10000
|
||||
CUPS_PAGE_LOG_PATH = "/var/log/cups/page_log"
|
||||
CONSUMABLE_STATE_DIR = ".config/brother_printer"
|
||||
MIN_LPSTAT_JOB_PARTS = 4
|
||||
|
||||
BROTHER_USB_VENDOR_ID = 0x04F9
|
||||
|
||||
|
||||
def _out(text: str = "") -> None:
|
||||
"""Write a line to stdout."""
|
||||
sys.stdout.write(text + "\n")
|
||||
|
||||
|
||||
def _prompt(text: str) -> str:
|
||||
"""Read user input with a prompt."""
|
||||
sys.stdout.write(text)
|
||||
sys.stdout.flush()
|
||||
return sys.stdin.readline().strip()
|
||||
|
||||
|
||||
# ── Brother PJL status codes ────────────────────────────────────────
|
||||
# Documented in Brother PJL Technical Reference.
|
||||
# Format: code -> (severity, short_text, action)
|
||||
# Severities: ok, info, warn, critical
|
||||
|
||||
BROTHER_STATUS_CODES: dict[int, tuple[str, str, str]] = {
|
||||
10001: ("ok", "Ready", ""),
|
||||
10002: ("ok", "Sleep", ""),
|
||||
10003: ("info", "Self-test / Calibrating", ""),
|
||||
10004: ("ok", "Warming up", ""),
|
||||
10005: ("ok", "Cooling down", ""),
|
||||
10006: ("info", "Processing", ""),
|
||||
10007: ("info", "Printing", ""),
|
||||
10014: ("ok", "Cancelling", ""),
|
||||
10023: ("info", "Waiting", ""),
|
||||
# Toner
|
||||
30010: (
|
||||
"warn",
|
||||
"Toner Low",
|
||||
"Order replacement toner cartridge (TN-1050/TN-1030 compatible).",
|
||||
),
|
||||
30038: (
|
||||
"warn",
|
||||
"Toner Low",
|
||||
"Order replacement toner cartridge (TN-1050/TN-1030 compatible).",
|
||||
),
|
||||
40038: (
|
||||
"warn",
|
||||
"Toner Low",
|
||||
"Order replacement toner cartridge (TN-1050/TN-1030 compatible).",
|
||||
),
|
||||
40309: (
|
||||
"critical",
|
||||
"Replace Toner",
|
||||
"The toner cartridge needs immediate replacement"
|
||||
" (TN-1050/TN-1030 compatible).",
|
||||
),
|
||||
40310: (
|
||||
"critical",
|
||||
"Toner End",
|
||||
"The toner cartridge is empty." " Replace now (TN-1050/TN-1030 compatible).",
|
||||
),
|
||||
# Drum
|
||||
30201: (
|
||||
"warn",
|
||||
"Drum End Soon",
|
||||
"The drum unit is nearing end of life."
|
||||
" Order replacement (DR-1050 compatible).",
|
||||
),
|
||||
40201: (
|
||||
"warn",
|
||||
"Drum End Soon",
|
||||
"The drum unit is nearing end of life."
|
||||
" Order replacement (DR-1050 compatible).",
|
||||
),
|
||||
40019: (
|
||||
"critical",
|
||||
"Replace Drum",
|
||||
"The drum unit must be replaced (DR-1050 compatible).",
|
||||
),
|
||||
40020: (
|
||||
"critical",
|
||||
"Drum Stop",
|
||||
"The drum unit must be replaced immediately (DR-1050 compatible).",
|
||||
),
|
||||
# Paper / feed
|
||||
40000: ("critical", "Paper Jam", "Clear the paper jam and close all covers."),
|
||||
40300: (
|
||||
"critical",
|
||||
"No Paper / Tray Open",
|
||||
"Load paper or close the paper tray.",
|
||||
),
|
||||
40302: ("critical", "No Paper", "Load paper into the paper tray."),
|
||||
40016: ("warn", "Paper Feed Error", "Check paper tray and re-seat paper."),
|
||||
# Cover
|
||||
41000: ("critical", "Cover Open", "Close the top cover of the printer."),
|
||||
41001: ("critical", "Cover Open", "Close the front cover of the printer."),
|
||||
# Others
|
||||
35078: ("info", "Manual Feed", "Load paper in the manual feed slot."),
|
||||
42000: (
|
||||
"critical",
|
||||
"Machine Error",
|
||||
"Power-cycle the printer. If error persists, contact service.",
|
||||
),
|
||||
}
|
||||
|
||||
# ── CUPS status code mappings ────────────────────────────────────────
|
||||
|
||||
_CUPS_REASONS_TO_STATUS: dict[str, int] = {
|
||||
"paused": 10023,
|
||||
"moving-to-paused": 10023,
|
||||
"toner-low": 30010,
|
||||
"toner-empty": 40310,
|
||||
"marker-supply-low": 30010,
|
||||
"marker-supply-empty": 40310,
|
||||
"media-empty": 40302,
|
||||
"media-needed": 40302,
|
||||
"media-jam": 40000,
|
||||
"cover-open": 41000,
|
||||
"door-open": 41000,
|
||||
"input-tray-missing": 40300,
|
||||
}
|
||||
|
||||
_CUPS_STATE_TO_STATUS: dict[str, int] = {
|
||||
"idle": 10001,
|
||||
"processing": 10007,
|
||||
"stopped": 10023,
|
||||
}
|
||||
|
||||
_ERROR_REASON_MAP: tuple[tuple[tuple[str, ...], str, str], ...] = (
|
||||
(("media-jam",), "40000", "Paper Jam"),
|
||||
(("cover-open", "door-open"), "41000", "Cover Open"),
|
||||
(("toner-empty",), "40310", "Toner End"),
|
||||
(("toner-low",), "30010", "Toner Low"),
|
||||
)
|
||||
|
||||
|
||||
def get_status_info(code: str) -> tuple[str, str, str]:
|
||||
"""Look up a PJL status code. Returns (severity, text, action)."""
|
||||
try:
|
||||
return BROTHER_STATUS_CODES[int(code)]
|
||||
except (KeyError, ValueError):
|
||||
return (
|
||||
"info",
|
||||
f"Unknown status (code {code})",
|
||||
"Check printer display for details.",
|
||||
)
|
||||
459
python_pkg/brother_printer/cups_queue.py
Normal file
459
python_pkg/brother_printer/cups_queue.py
Normal file
@ -0,0 +1,459 @@
|
||||
"""CUPS queue inspection, display, and interactive fix functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from python_pkg.brother_printer.constants import (
|
||||
BOLD,
|
||||
CYAN,
|
||||
DIM,
|
||||
GREEN,
|
||||
MIN_LPSTAT_JOB_PARTS,
|
||||
RED,
|
||||
RESET,
|
||||
YELLOW,
|
||||
_out,
|
||||
_prompt,
|
||||
)
|
||||
from python_pkg.brother_printer.cups_service import find_cups_printer_name
|
||||
from python_pkg.brother_printer.data_classes import CUPSJob, CUPSQueueStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
# ── Queue inspection ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _parse_lpstat_printer_line(line: str) -> tuple[bool, str]:
|
||||
"""Parse an lpstat -p line. Returns (enabled, reason)."""
|
||||
enabled = "disabled" not in line.lower()
|
||||
reason = ""
|
||||
match = re.search(r"\d{4}\s+-\s*(.+)", line)
|
||||
if match:
|
||||
reason = match.group(1).strip()
|
||||
return enabled, reason
|
||||
|
||||
|
||||
def _parse_lpstat_jobs(output: str, printer_name: str) -> list[CUPSJob]:
|
||||
"""Parse lpstat -o output into CUPSJob list."""
|
||||
jobs: list[CUPSJob] = []
|
||||
for line in output.splitlines():
|
||||
if not line.startswith(printer_name):
|
||||
continue
|
||||
parts = line.split()
|
||||
if len(parts) >= MIN_LPSTAT_JOB_PARTS:
|
||||
job_id = parts[0]
|
||||
user = parts[1]
|
||||
size = parts[2]
|
||||
date = " ".join(parts[3:])
|
||||
jobs.append(CUPSJob(job_id=job_id, user=user, size=size, date=date))
|
||||
return jobs
|
||||
|
||||
|
||||
def get_cups_queue_status() -> CUPSQueueStatus:
|
||||
"""Check if the CUPS queue is disabled and list pending jobs."""
|
||||
printer_name = find_cups_printer_name()
|
||||
if not printer_name:
|
||||
return CUPSQueueStatus()
|
||||
|
||||
result = CUPSQueueStatus(printer_name=printer_name)
|
||||
lpstat_path = shutil.which("lpstat")
|
||||
if not lpstat_path:
|
||||
return result
|
||||
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[lpstat_path, "-p", printer_name],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
for line in r.stdout.splitlines():
|
||||
if "printer" in line.lower() and printer_name in line:
|
||||
result.enabled, result.reason = _parse_lpstat_printer_line(line)
|
||||
break
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
|
||||
pass
|
||||
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[lpstat_path, "-o", printer_name],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
result.jobs = _parse_lpstat_jobs(r.stdout, printer_name)
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
|
||||
pass
|
||||
|
||||
has_errors, last_error = _check_cups_backend_errors(printer_name)
|
||||
result.has_backend_errors = has_errors
|
||||
result.last_backend_error = last_error
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ── CUPS fix actions ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _cups_enable_printer(printer_name: str) -> bool:
|
||||
"""Re-enable a disabled CUPS printer. Returns True on success."""
|
||||
cupsenable_path = shutil.which("cupsenable")
|
||||
if not cupsenable_path:
|
||||
_out(f" {RED}cupsenable not found.{RESET}")
|
||||
return False
|
||||
try:
|
||||
subprocess.run(
|
||||
[cupsenable_path, printer_name],
|
||||
timeout=5,
|
||||
check=True,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError) as e:
|
||||
_out(f" {RED}Failed to enable printer: {e}{RESET}")
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def _cups_cancel_all_jobs(printer_name: str) -> bool:
|
||||
"""Cancel all pending jobs. Returns True on success."""
|
||||
cancel_path = shutil.which("cancel")
|
||||
if not cancel_path:
|
||||
_out(f" {RED}cancel command not found.{RESET}")
|
||||
return False
|
||||
try:
|
||||
subprocess.run(
|
||||
[cancel_path, "-a", printer_name],
|
||||
timeout=5,
|
||||
check=True,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError) as e:
|
||||
_out(f" {RED}Failed to cancel jobs: {e}{RESET}")
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def _cups_cancel_job(job_id: str) -> bool:
|
||||
"""Cancel a specific job. Returns True on success."""
|
||||
cancel_path = shutil.which("cancel")
|
||||
if not cancel_path:
|
||||
return False
|
||||
try:
|
||||
subprocess.run(
|
||||
[cancel_path, job_id],
|
||||
timeout=5,
|
||||
check=True,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def _cups_restart_service() -> bool:
|
||||
"""Restart the CUPS service. Returns True on success."""
|
||||
systemctl_path = shutil.which("systemctl")
|
||||
if not systemctl_path:
|
||||
_out(f" {RED}systemctl not found.{RESET}")
|
||||
return False
|
||||
sys.stdout.write(f" {DIM}Restarting CUPS...{RESET}")
|
||||
sys.stdout.flush()
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
[systemctl_path, "restart", "cups"],
|
||||
)
|
||||
deadline = time.time() + 30
|
||||
while proc.poll() is None:
|
||||
if time.time() > deadline:
|
||||
proc.kill()
|
||||
proc.wait()
|
||||
sys.stdout.write("\n")
|
||||
_out(
|
||||
f" {RED}CUPS restart timed out"
|
||||
f" (stuck backend process?).{RESET}"
|
||||
)
|
||||
_out(
|
||||
f" {DIM}Try: sudo kill -9 $(pgrep -f 'cups/backend/usb')"
|
||||
f" && sudo systemctl restart cups{RESET}"
|
||||
)
|
||||
return False
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
time.sleep(1)
|
||||
sys.stdout.write("\n")
|
||||
if proc.returncode != 0:
|
||||
_out(
|
||||
f" {RED}CUPS restart failed" f" (exit code {proc.returncode}).{RESET}"
|
||||
)
|
||||
return False
|
||||
except OSError as e:
|
||||
sys.stdout.write("\n")
|
||||
_out(f" {RED}Failed to restart CUPS: {e}{RESET}")
|
||||
return False
|
||||
time.sleep(2)
|
||||
return True
|
||||
|
||||
|
||||
# ── Backend error detection ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _is_cups_printer_healthy(printer_name: str) -> bool:
|
||||
"""Check live CUPS state via lpstat. Returns True if enabled with no issues."""
|
||||
lpstat_path = shutil.which("lpstat")
|
||||
if not lpstat_path:
|
||||
return False
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[lpstat_path, "-p", printer_name],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
for line in r.stdout.splitlines():
|
||||
if (
|
||||
printer_name in line
|
||||
and "idle" in line.lower()
|
||||
and "enabled" in line.lower()
|
||||
):
|
||||
return True
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def _find_backend_error_in_log(
|
||||
lines: list[str],
|
||||
) -> tuple[str, str, str]:
|
||||
"""Scan CUPS log lines (reversed) for backend errors.
|
||||
|
||||
Returns:
|
||||
(backend_error, error_timestamp, last_success_timestamp)
|
||||
"""
|
||||
backend_error = ""
|
||||
error_timestamp = ""
|
||||
last_success_timestamp = ""
|
||||
|
||||
for line in reversed(lines):
|
||||
if (
|
||||
"backend errors" in line or "stopped with status" in line
|
||||
) and not backend_error:
|
||||
backend_error = line.strip()
|
||||
ts_match = re.search(r"\[([^\]]+)\]", line)
|
||||
if ts_match:
|
||||
error_timestamp = ts_match.group(1)
|
||||
if ("Completed" in line or "total" in line) and error_timestamp:
|
||||
ts_match = re.search(r"\[([^\]]+)\]", line)
|
||||
if ts_match:
|
||||
last_success_timestamp = ts_match.group(1)
|
||||
break
|
||||
|
||||
return backend_error, error_timestamp, last_success_timestamp
|
||||
|
||||
|
||||
def _check_cups_backend_errors(
|
||||
printer_name: str,
|
||||
) -> tuple[bool, str]:
|
||||
"""Check CUPS error log for backend errors. Returns (has_errors, last_error)."""
|
||||
if _is_cups_printer_healthy(printer_name):
|
||||
return False, ""
|
||||
|
||||
log_path = Path("/var/log/cups/error_log")
|
||||
if not log_path.exists():
|
||||
return False, ""
|
||||
try:
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
except OSError:
|
||||
return False, ""
|
||||
|
||||
backend_error, error_timestamp, last_success_timestamp = _find_backend_error_in_log(
|
||||
lines
|
||||
)
|
||||
|
||||
if not backend_error:
|
||||
return False, ""
|
||||
|
||||
if last_success_timestamp and last_success_timestamp > error_timestamp:
|
||||
return False, ""
|
||||
|
||||
return True, backend_error
|
||||
|
||||
|
||||
# ── Queue status display ────────────────────────────────────────────
|
||||
|
||||
|
||||
def display_cups_queue_status(queue: CUPSQueueStatus) -> None:
|
||||
"""Display CUPS queue status and offer interactive fixes."""
|
||||
if not queue.printer_name:
|
||||
return
|
||||
if queue.enabled and not queue.jobs and not queue.has_backend_errors:
|
||||
return
|
||||
|
||||
_out()
|
||||
_out(f"{BOLD}── Print Queue ──{RESET}")
|
||||
_out()
|
||||
|
||||
if queue.has_backend_errors and queue.enabled and not queue.jobs:
|
||||
_out(f" {YELLOW}{BOLD}⚡ CUPS backend has stale errors{RESET}")
|
||||
_out(
|
||||
f" {DIM}New print jobs may silently fail."
|
||||
f" A CUPS restart usually fixes this.{RESET}"
|
||||
)
|
||||
_out()
|
||||
|
||||
if not queue.enabled:
|
||||
_out(f" {RED}{BOLD}⚠ Printer queue is DISABLED{RESET}")
|
||||
if queue.reason:
|
||||
_out(f" {DIM}Reason: {queue.reason}{RESET}")
|
||||
_out()
|
||||
|
||||
if queue.jobs:
|
||||
_out(f" {BOLD}Pending jobs ({len(queue.jobs)}):{RESET}")
|
||||
for job in queue.jobs:
|
||||
_out(f" {job.job_id} {DIM}{job.user} {job.size}B {job.date}{RESET}")
|
||||
_out()
|
||||
|
||||
_offer_queue_fix(queue)
|
||||
|
||||
|
||||
# ── Interactive queue fix ────────────────────────────────────────────
|
||||
|
||||
|
||||
def _offer_queue_fix(queue: CUPSQueueStatus) -> None:
|
||||
"""Prompt the user to fix a disabled queue / pending jobs."""
|
||||
_out(f" {BOLD}Available actions:{RESET}")
|
||||
|
||||
options: list[str] = []
|
||||
if not queue.enabled and queue.jobs:
|
||||
_out(f" {CYAN}1){RESET} Re-enable printer and retry all jobs")
|
||||
_out(f" {CYAN}2){RESET} Re-enable printer and cancel all jobs")
|
||||
_out(f" {CYAN}3){RESET} Cancel all jobs (keep printer disabled)")
|
||||
_out(f" {CYAN}4){RESET} Restart CUPS service (fixes stale backend)")
|
||||
_out(f" {CYAN}5){RESET} Restart CUPS + re-enable + retry all jobs")
|
||||
_out(f" {CYAN}6){RESET} Do nothing")
|
||||
options = ["1", "2", "3", "4", "5", "6"]
|
||||
elif not queue.enabled:
|
||||
_out(f" {CYAN}1){RESET} Re-enable printer")
|
||||
_out(f" {CYAN}2){RESET} Restart CUPS service (fixes stale backend)")
|
||||
_out(f" {CYAN}3){RESET} Do nothing")
|
||||
options = ["1", "2", "3"]
|
||||
elif queue.jobs:
|
||||
_out(f" {CYAN}1){RESET} Cancel all pending jobs")
|
||||
_out(f" {CYAN}2){RESET} Restart CUPS service (fixes stale backend)")
|
||||
_out(f" {CYAN}3){RESET} Do nothing")
|
||||
options = ["1", "2", "3"]
|
||||
else:
|
||||
_out(f" {CYAN}1){RESET} Restart CUPS service (fixes stale backend)")
|
||||
_out(f" {CYAN}2){RESET} Do nothing")
|
||||
options = ["1", "2"]
|
||||
|
||||
_out()
|
||||
choice = _prompt(f" Choose [{'/'.join(options)}]: ")
|
||||
_out()
|
||||
|
||||
if not queue.enabled and queue.jobs:
|
||||
_handle_disabled_with_jobs(queue, choice)
|
||||
elif not queue.enabled:
|
||||
_handle_disabled_no_jobs(queue, choice)
|
||||
elif queue.jobs:
|
||||
_handle_enabled_with_jobs(queue, choice)
|
||||
else:
|
||||
_handle_backend_errors_only(choice)
|
||||
|
||||
|
||||
def _dwj_enable_only(printer_name: str) -> None:
|
||||
"""Choice 1: re-enable printer so queued jobs are retried."""
|
||||
if _cups_enable_printer(printer_name):
|
||||
_out(f" {GREEN}✓ Printer re-enabled. Jobs will be retried.{RESET}")
|
||||
|
||||
|
||||
def _dwj_cancel_and_enable(printer_name: str) -> None:
|
||||
"""Choice 2: cancel all jobs then re-enable."""
|
||||
_cups_cancel_all_jobs(printer_name)
|
||||
if _cups_enable_printer(printer_name):
|
||||
_out(f" {GREEN}✓ All jobs cancelled and printer re-enabled.{RESET}")
|
||||
|
||||
|
||||
def _dwj_cancel_only(printer_name: str) -> None:
|
||||
"""Choice 3: cancel all jobs."""
|
||||
if _cups_cancel_all_jobs(printer_name):
|
||||
_out(f" {GREEN}✓ All jobs cancelled.{RESET}")
|
||||
|
||||
|
||||
def _dwj_restart_only(_printer_name: str) -> None:
|
||||
"""Choice 4: restart CUPS."""
|
||||
if _cups_restart_service():
|
||||
_out(f" {GREEN}✓ CUPS restarted.{RESET}")
|
||||
|
||||
|
||||
def _dwj_restart_and_enable(printer_name: str) -> None:
|
||||
"""Choice 5: restart CUPS and re-enable printer."""
|
||||
if _cups_restart_service():
|
||||
_cups_enable_printer(printer_name)
|
||||
_out(
|
||||
f" {GREEN}✓ CUPS restarted, printer re-enabled."
|
||||
f" Jobs will be retried.{RESET}"
|
||||
)
|
||||
|
||||
|
||||
_DWJ_ACTIONS: dict[str, Callable[[str], None]] = {
|
||||
"1": _dwj_enable_only,
|
||||
"2": _dwj_cancel_and_enable,
|
||||
"3": _dwj_cancel_only,
|
||||
"4": _dwj_restart_only,
|
||||
"5": _dwj_restart_and_enable,
|
||||
}
|
||||
|
||||
|
||||
def _handle_disabled_with_jobs(queue: CUPSQueueStatus, choice: str) -> None:
|
||||
"""Handle fix for disabled printer with pending jobs."""
|
||||
action = _DWJ_ACTIONS.get(choice)
|
||||
if action is not None:
|
||||
action(queue.printer_name)
|
||||
else:
|
||||
_out(f" {DIM}No changes made.{RESET}")
|
||||
|
||||
|
||||
def _handle_disabled_no_jobs(queue: CUPSQueueStatus, choice: str) -> None:
|
||||
"""Handle fix for disabled printer with no pending jobs."""
|
||||
if choice == "1":
|
||||
if _cups_enable_printer(queue.printer_name):
|
||||
_out(f" {GREEN}✓ Printer re-enabled.{RESET}")
|
||||
elif choice == "2":
|
||||
if _cups_restart_service():
|
||||
_cups_enable_printer(queue.printer_name)
|
||||
_out(f" {GREEN}✓ CUPS restarted and printer re-enabled.{RESET}")
|
||||
else:
|
||||
_out(f" {DIM}No changes made.{RESET}")
|
||||
|
||||
|
||||
def _handle_enabled_with_jobs(queue: CUPSQueueStatus, choice: str) -> None:
|
||||
"""Handle fix for enabled printer with stuck jobs."""
|
||||
if choice == "1":
|
||||
if _cups_cancel_all_jobs(queue.printer_name):
|
||||
_out(f" {GREEN}✓ All jobs cancelled.{RESET}")
|
||||
elif choice == "2":
|
||||
if _cups_restart_service():
|
||||
_out(f" {GREEN}✓ CUPS restarted.{RESET}")
|
||||
else:
|
||||
_out(f" {DIM}No changes made.{RESET}")
|
||||
|
||||
|
||||
def _handle_backend_errors_only(choice: str) -> None:
|
||||
"""Handle fix when only stale backend errors are detected."""
|
||||
if choice == "1":
|
||||
if _cups_restart_service():
|
||||
_out(f" {GREEN}✓ CUPS restarted. Stale backend errors cleared.{RESET}")
|
||||
else:
|
||||
_out(f" {DIM}No changes made.{RESET}")
|
||||
479
python_pkg/brother_printer/cups_service.py
Normal file
479
python_pkg/brother_printer/cups_service.py
Normal file
@ -0,0 +1,479 @@
|
||||
"""CUPS service management, USB fallback, and consumable state tracking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
from python_pkg.brother_printer.constants import (
|
||||
_CUPS_REASONS_TO_STATUS,
|
||||
_CUPS_STATE_TO_STATUS,
|
||||
_ERROR_REASON_MAP,
|
||||
BROTHER_USB_VENDOR_ID,
|
||||
CONSUMABLE_STATE_DIR,
|
||||
CUPS_PAGE_LOG_PATH,
|
||||
DRUM_RATED_PAGES,
|
||||
GREEN,
|
||||
RESET,
|
||||
TONER_RATED_PAGES,
|
||||
_out,
|
||||
)
|
||||
from python_pkg.brother_printer.data_classes import (
|
||||
PageCountEstimate,
|
||||
USBPortStatus,
|
||||
USBResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CUPS_PAGE_LOG = Path(CUPS_PAGE_LOG_PATH)
|
||||
CONSUMABLE_STATE_FILE = Path.home() / CONSUMABLE_STATE_DIR / "state.json"
|
||||
|
||||
|
||||
# ── pyusb device info ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_pyusb_device_info() -> dict[str, str]:
|
||||
"""Get Brother USB printer info via pyusb (no interface claim needed)."""
|
||||
try:
|
||||
import usb.core
|
||||
|
||||
dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID)
|
||||
if dev is None:
|
||||
return {}
|
||||
except (ImportError, OSError, ValueError):
|
||||
return {}
|
||||
else:
|
||||
return {
|
||||
"product": dev.product or "",
|
||||
"serial": dev.serial_number or "",
|
||||
}
|
||||
|
||||
|
||||
# ── CUPS service control ────────────────────────────────────────────
|
||||
|
||||
|
||||
def _stop_cups() -> bool:
|
||||
"""Stop CUPS service and sockets. Returns True on success."""
|
||||
systemctl = shutil.which("systemctl")
|
||||
if not systemctl:
|
||||
return False
|
||||
try:
|
||||
subprocess.run(
|
||||
[systemctl, "stop", "cups.service", "cups.socket", "cups.path"],
|
||||
timeout=15,
|
||||
check=True,
|
||||
)
|
||||
time.sleep(2)
|
||||
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_cups_scheduler_running() -> bool:
|
||||
"""Check if the CUPS scheduler is currently running."""
|
||||
lpstat = shutil.which("lpstat")
|
||||
if not lpstat:
|
||||
return False
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[lpstat, "-r"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
check=False,
|
||||
)
|
||||
return (
|
||||
"is running" in r.stdout.lower() and "not running" not in r.stdout.lower()
|
||||
)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def start_cups() -> bool:
|
||||
"""Start CUPS service, socket, and path units. Returns True on success."""
|
||||
systemctl = shutil.which("systemctl")
|
||||
if not systemctl:
|
||||
return False
|
||||
try:
|
||||
subprocess.run(
|
||||
[systemctl, "start", "cups.service", "cups.socket", "cups.path"],
|
||||
timeout=15,
|
||||
check=True,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError):
|
||||
return False
|
||||
for _ in range(10):
|
||||
if is_cups_scheduler_running():
|
||||
return True
|
||||
time.sleep(1)
|
||||
return False
|
||||
|
||||
|
||||
def _ensure_cups_running() -> bool:
|
||||
"""Make sure CUPS is running, starting it if necessary."""
|
||||
if is_cups_scheduler_running():
|
||||
return True
|
||||
return start_cups()
|
||||
|
||||
|
||||
# ── USB port status via pyusb ────────────────────────────────────────
|
||||
|
||||
|
||||
def _query_usb_port_status_raw() -> USBPortStatus | None:
|
||||
"""Query USB printer port status via pyusb control transfer.
|
||||
|
||||
Requires root and temporarily stops CUPS to access the USB device.
|
||||
Returns None if the query fails.
|
||||
"""
|
||||
try:
|
||||
import usb.core
|
||||
import usb.util
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID)
|
||||
if dev is None:
|
||||
return None
|
||||
|
||||
if not _stop_cups():
|
||||
return None
|
||||
|
||||
try:
|
||||
dev.reset()
|
||||
time.sleep(2)
|
||||
dev = usb.core.find(idVendor=BROTHER_USB_VENDOR_ID)
|
||||
if dev is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if dev.is_kernel_driver_active(0):
|
||||
dev.detach_kernel_driver(0)
|
||||
except (usb.core.USBError, NotImplementedError):
|
||||
pass
|
||||
|
||||
usb.util.claim_interface(dev, 0)
|
||||
try:
|
||||
# USB Printer Class GET_PORT_STATUS (bRequest=0x01)
|
||||
raw = dev.ctrl_transfer(0xA1, 0x01, 0, 0, 1, timeout=5000)
|
||||
port_byte = raw[0]
|
||||
return USBPortStatus(
|
||||
paper_empty=bool(port_byte & 0x20),
|
||||
online=bool(port_byte & 0x10),
|
||||
error=not bool(port_byte & 0x08),
|
||||
raw_byte=port_byte,
|
||||
)
|
||||
finally:
|
||||
usb.util.release_interface(dev, 0)
|
||||
usb.util.dispose_resources(dev)
|
||||
except (OSError, ValueError):
|
||||
logger.debug("USB port status query failed", exc_info=True)
|
||||
return None
|
||||
finally:
|
||||
start_cups()
|
||||
|
||||
|
||||
# ── Consumable state management ──────────────────────────────────────
|
||||
|
||||
|
||||
def _get_cups_total_pages() -> int:
|
||||
"""Parse CUPS page_log to get total pages printed (deduplicated by job)."""
|
||||
if not CUPS_PAGE_LOG.exists():
|
||||
return 0
|
||||
try:
|
||||
text = CUPS_PAGE_LOG.read_text(encoding="utf-8", errors="replace")
|
||||
except OSError:
|
||||
return 0
|
||||
jobs: dict[str, int] = {}
|
||||
for line in text.splitlines():
|
||||
match = re.search(r"\s(\d+)\s+\[.*?\]\s+total\s+(\d+)", line)
|
||||
if match:
|
||||
job_id = match.group(1)
|
||||
pages = int(match.group(2))
|
||||
jobs[job_id] = max(jobs.get(job_id, 0), pages)
|
||||
return sum(jobs.values())
|
||||
|
||||
|
||||
def _load_consumable_state() -> dict[str, int]:
|
||||
"""Load consumable replacement state from disk."""
|
||||
defaults: dict[str, int] = {"toner_replaced_at": 0, "drum_replaced_at": 0}
|
||||
if not CONSUMABLE_STATE_FILE.exists():
|
||||
return defaults
|
||||
try:
|
||||
data = json.loads(
|
||||
CONSUMABLE_STATE_FILE.read_text(encoding="utf-8"),
|
||||
)
|
||||
return {
|
||||
"toner_replaced_at": int(data.get("toner_replaced_at", 0)),
|
||||
"drum_replaced_at": int(data.get("drum_replaced_at", 0)),
|
||||
}
|
||||
except (OSError, json.JSONDecodeError, ValueError, TypeError):
|
||||
return defaults
|
||||
|
||||
|
||||
def _save_consumable_state(state: dict[str, int]) -> None:
|
||||
"""Persist consumable replacement state to disk."""
|
||||
CONSUMABLE_STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
CONSUMABLE_STATE_FILE.write_text(
|
||||
json.dumps(state, indent=2) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def reset_consumable(name: str) -> None:
|
||||
"""Record current page count as replacement point for a consumable."""
|
||||
total = _get_cups_total_pages()
|
||||
state = _load_consumable_state()
|
||||
key = f"{name}_replaced_at"
|
||||
state[key] = total
|
||||
_save_consumable_state(state)
|
||||
_out(
|
||||
f"{GREEN}✓ {name.capitalize()} counter reset at page count" f" {total}.{RESET}"
|
||||
)
|
||||
_out(f" State saved to {CONSUMABLE_STATE_FILE}")
|
||||
|
||||
|
||||
def estimate_consumable_life() -> PageCountEstimate:
|
||||
"""Estimate toner/drum life from CUPS page count since last replacement."""
|
||||
total = _get_cups_total_pages()
|
||||
if total <= 0:
|
||||
return PageCountEstimate()
|
||||
state = _load_consumable_state()
|
||||
toner_pages = max(0, total - state["toner_replaced_at"])
|
||||
drum_pages = max(0, total - state["drum_replaced_at"])
|
||||
toner_pct = max(0, 100 - (toner_pages * 100 // TONER_RATED_PAGES))
|
||||
drum_pct = max(0, 100 - (drum_pages * 100 // DRUM_RATED_PAGES))
|
||||
return PageCountEstimate(
|
||||
total_pages=total,
|
||||
toner_pages=toner_pages,
|
||||
drum_pages=drum_pages,
|
||||
toner_pct_remaining=toner_pct,
|
||||
drum_pct_remaining=drum_pct,
|
||||
toner_exhausted=toner_pages >= TONER_RATED_PAGES,
|
||||
toner_low=toner_pages >= TONER_RATED_PAGES * 80 // 100,
|
||||
drum_near_end=drum_pages >= DRUM_RATED_PAGES * 90 // 100,
|
||||
)
|
||||
|
||||
|
||||
# ── IPP / CUPS attribute queries ────────────────────────────────────
|
||||
|
||||
|
||||
def _parse_ipp_attributes(output: str) -> dict[str, str]:
|
||||
"""Parse ipptool verbose output into an attribute dict."""
|
||||
attrs: dict[str, str] = {}
|
||||
for line in output.splitlines():
|
||||
match = re.match(r"\s+(\S+)\s+\([^)]+\)\s+=\s+(.*)", line)
|
||||
if match:
|
||||
attrs[match.group(1)] = match.group(2).strip()
|
||||
return attrs
|
||||
|
||||
|
||||
def _get_cups_ipp_status(printer_name: str) -> dict[str, str]:
|
||||
"""Query printer attributes via CUPS IPP using ipptool."""
|
||||
ipptool_path = shutil.which("ipptool")
|
||||
if not ipptool_path:
|
||||
return {}
|
||||
uri = f"ipp://localhost/printers/{printer_name}"
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[ipptool_path, "-tv", uri, "get-printer-attributes.test"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
check=False,
|
||||
)
|
||||
return _parse_ipp_attributes(r.stdout)
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
|
||||
return {}
|
||||
|
||||
|
||||
def _get_cups_economode(printer_name: str) -> str:
|
||||
"""Query toner save mode setting via lpoptions."""
|
||||
lpoptions_path = shutil.which("lpoptions")
|
||||
if not lpoptions_path:
|
||||
return ""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[lpoptions_path, "-p", printer_name, "-l"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
for line in r.stdout.splitlines():
|
||||
if "conomode" in line.lower():
|
||||
match = re.search(r"\*(\w+)", line)
|
||||
if match:
|
||||
return "ON" if match.group(1).lower() == "true" else "OFF"
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
# ── Status code mapping ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _map_cups_to_status_code(state: str, reasons: str) -> str:
|
||||
"""Map CUPS state + reasons to a Brother PJL status code string."""
|
||||
for keyword, code in _CUPS_REASONS_TO_STATUS.items():
|
||||
if keyword in reasons.lower():
|
||||
return str(code)
|
||||
clean_state = re.sub(r"\(.*\)", "", state).strip().lower()
|
||||
return str(_CUPS_STATE_TO_STATUS.get(clean_state, 10001))
|
||||
|
||||
|
||||
def _cups_reasons_to_error(cups_reasons: str) -> tuple[str, str]:
|
||||
"""Map CUPS reason keywords to a (status_code, display) pair."""
|
||||
reasons_lower = cups_reasons.lower()
|
||||
for keywords, code, display in _ERROR_REASON_MAP:
|
||||
if any(kw in reasons_lower for kw in keywords):
|
||||
return code, display
|
||||
return "42000", "Printer Error"
|
||||
|
||||
|
||||
def _port_status_to_status_code(
|
||||
ps: USBPortStatus,
|
||||
cups_reasons: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Map USB port status + CUPS reasons to (status_code, display)."""
|
||||
if ps.error and ps.paper_empty:
|
||||
return "40302", "No Paper"
|
||||
if ps.error and not ps.online:
|
||||
return "41000", "Cover Open"
|
||||
if ps.error:
|
||||
return _cups_reasons_to_error(cups_reasons)
|
||||
if ps.paper_empty:
|
||||
return "40302", "No Paper"
|
||||
if not ps.online:
|
||||
return "10002", "Offline / Sleep"
|
||||
return "", ""
|
||||
|
||||
|
||||
# ── CUPS printer name discovery ──────────────────────────────────────
|
||||
|
||||
|
||||
def find_cups_printer_name() -> str:
|
||||
"""Find the CUPS queue name for a Brother printer."""
|
||||
lpstat_path = shutil.which("lpstat")
|
||||
if not lpstat_path:
|
||||
return ""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[lpstat_path, "-v"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
for line in r.stdout.splitlines():
|
||||
if "brother" in line.lower():
|
||||
match = re.match(r"device for (\S+):", line)
|
||||
if match:
|
||||
return match.group(1)
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
# ── CUPS-based USB fallback query ────────────────────────────────────
|
||||
|
||||
|
||||
def _parse_cups_usb_uri(uri: str, info: dict[str, str]) -> None:
|
||||
"""Extract product and serial from a CUPS usb:// URI."""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
info["product"] = urllib.parse.unquote(parsed.path.lstrip("/"))
|
||||
qs = urllib.parse.parse_qs(parsed.query)
|
||||
if "serial" in qs:
|
||||
info["serial"] = qs["serial"][0]
|
||||
|
||||
|
||||
def _get_printer_info_from_cups() -> dict[str, str]:
|
||||
"""Get printer model/serial from lpstat."""
|
||||
info: dict[str, str] = {"product": "", "serial": ""}
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["/usr/bin/lpstat", "-v"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
for line in r.stdout.splitlines():
|
||||
if "Brother" in line:
|
||||
for part in line.split():
|
||||
if part.startswith("usb://"):
|
||||
_parse_cups_usb_uri(part, info)
|
||||
break
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
|
||||
logger.debug("Failed to query CUPS for printer info", exc_info=True)
|
||||
return info
|
||||
|
||||
|
||||
def query_usb_via_cups() -> USBResult:
|
||||
"""Query USB printer status through CUPS when /dev/usb/lp* is unavailable."""
|
||||
_ensure_cups_running()
|
||||
printer_name = find_cups_printer_name()
|
||||
if not printer_name:
|
||||
return USBResult(
|
||||
error="No USB printer device at /dev/usb/lp*"
|
||||
" (usblp module not available)"
|
||||
" and no Brother printer found in CUPS.",
|
||||
)
|
||||
|
||||
pyusb_info = _get_pyusb_device_info()
|
||||
cups_info = _get_printer_info_from_cups()
|
||||
|
||||
result = USBResult(
|
||||
device="cups",
|
||||
product=(
|
||||
pyusb_info.get("product")
|
||||
or cups_info.get("product")
|
||||
or "Brother Laser Printer"
|
||||
),
|
||||
serial=pyusb_info.get("serial") or cups_info.get("serial", ""),
|
||||
)
|
||||
|
||||
ipp = _get_cups_ipp_status(printer_name)
|
||||
state = ipp.get("printer-state", "")
|
||||
reasons = ipp.get("printer-state-reasons", "none")
|
||||
result.economode = _get_cups_economode(printer_name)
|
||||
|
||||
port_status = _query_usb_port_status_raw()
|
||||
if port_status is not None:
|
||||
result.port_status = port_status
|
||||
hw_code, hw_display = _port_status_to_status_code(
|
||||
port_status,
|
||||
reasons,
|
||||
)
|
||||
if hw_code:
|
||||
result.status_code = hw_code
|
||||
result.display = hw_display
|
||||
result.online = "TRUE" if port_status.online else "FALSE"
|
||||
return result
|
||||
estimate = estimate_consumable_life()
|
||||
if estimate.toner_exhausted:
|
||||
result.status_code = "40310"
|
||||
result.display = "Toner End (estimated from page count)"
|
||||
result.online = "TRUE"
|
||||
return result
|
||||
if estimate.toner_low:
|
||||
result.status_code = "30010"
|
||||
result.display = "Toner Low (estimated from page count)"
|
||||
result.online = "TRUE"
|
||||
return result
|
||||
result.status_code = _map_cups_to_status_code(state, reasons)
|
||||
result.display = ipp.get("printer-state-message", "")
|
||||
result.online = "TRUE"
|
||||
return result
|
||||
|
||||
result.status_code = _map_cups_to_status_code(state, reasons)
|
||||
result.display = ipp.get("printer-state-message", "")
|
||||
result.online = "TRUE" if state.lower() in {"idle", "processing"} else "FALSE"
|
||||
|
||||
return result
|
||||
96
python_pkg/brother_printer/data_classes.py
Normal file
96
python_pkg/brother_printer/data_classes.py
Normal file
@ -0,0 +1,96 @@
|
||||
"""Data classes for Brother printer status information."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class CUPSJob:
|
||||
"""A single CUPS print job."""
|
||||
|
||||
job_id: str
|
||||
user: str
|
||||
size: str
|
||||
date: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CUPSQueueStatus:
|
||||
"""Status of the CUPS print queue for a printer."""
|
||||
|
||||
printer_name: str = ""
|
||||
enabled: bool = True
|
||||
reason: str = ""
|
||||
jobs: list[CUPSJob] = field(default_factory=list)
|
||||
has_backend_errors: bool = False
|
||||
last_backend_error: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PageCountEstimate:
|
||||
"""Estimated consumable life based on CUPS page count."""
|
||||
|
||||
total_pages: int = 0
|
||||
toner_pages: int = 0
|
||||
drum_pages: int = 0
|
||||
toner_pct_remaining: int = 100
|
||||
drum_pct_remaining: int = 100
|
||||
toner_exhausted: bool = False
|
||||
toner_low: bool = False
|
||||
drum_near_end: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class USBPortStatus:
|
||||
"""IEEE 1284 USB printer port status bits."""
|
||||
|
||||
paper_empty: bool = False
|
||||
online: bool = True
|
||||
error: bool = False
|
||||
raw_byte: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class USBResult:
|
||||
"""Result from a USB PJL query."""
|
||||
|
||||
connection: str = "usb"
|
||||
device: str = ""
|
||||
product: str = "Brother Laser Printer"
|
||||
serial: str = ""
|
||||
status_code: str = ""
|
||||
display: str = ""
|
||||
online: str = ""
|
||||
economode: str = ""
|
||||
error: str = ""
|
||||
port_status: USBPortStatus | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetworkResult:
|
||||
"""Result from an SNMP network query."""
|
||||
|
||||
connection: str = "network"
|
||||
ip: str = ""
|
||||
product: str = "Unknown"
|
||||
serial: str = ""
|
||||
printer_status: str = ""
|
||||
device_status: str = ""
|
||||
display: str = ""
|
||||
page_count: str = ""
|
||||
supply_descriptions: list[str] = field(default_factory=list)
|
||||
supply_max: list[str] = field(default_factory=list)
|
||||
supply_levels: list[str] = field(default_factory=list)
|
||||
error: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SupplyStatus:
|
||||
"""Processed supply level info for display."""
|
||||
|
||||
color: str
|
||||
bar: str
|
||||
status_text: str
|
||||
warning: str
|
||||
needs_replacement: bool
|
||||
385
python_pkg/brother_printer/display.py
Normal file
385
python_pkg/brother_printer/display.py
Normal file
@ -0,0 +1,385 @@
|
||||
"""Display and formatting functions for Brother printer status reports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from python_pkg.brother_printer.constants import (
|
||||
BOLD,
|
||||
CYAN,
|
||||
DIM,
|
||||
DRUM_RATED_PAGES,
|
||||
GREEN,
|
||||
PROGRESS_BAR_WIDTH,
|
||||
RED,
|
||||
RESET,
|
||||
SNMP_LEVEL_LOW,
|
||||
SNMP_LEVEL_OK,
|
||||
SUPPLY_LOW_PCT,
|
||||
SUPPLY_WARN_PCT,
|
||||
TONER_RATED_PAGES,
|
||||
YELLOW,
|
||||
_out,
|
||||
get_status_info,
|
||||
)
|
||||
from python_pkg.brother_printer.cups_queue import (
|
||||
display_cups_queue_status,
|
||||
get_cups_queue_status,
|
||||
)
|
||||
from python_pkg.brother_printer.cups_service import estimate_consumable_life
|
||||
from python_pkg.brother_printer.data_classes import (
|
||||
NetworkResult,
|
||||
SupplyStatus,
|
||||
USBResult,
|
||||
)
|
||||
|
||||
# ── Shared display helpers ───────────────────────────────────────────
|
||||
|
||||
|
||||
def _display_report_header() -> None:
|
||||
"""Print the report banner box."""
|
||||
_out()
|
||||
_out(f"{BOLD}╔══════════════════════════════════════════════════╗{RESET}")
|
||||
_out(f"{BOLD}║ Brother Laser Printer Status Report ║{RESET}")
|
||||
_out(f"{BOLD}╚══════════════════════════════════════════════════╝{RESET}")
|
||||
_out()
|
||||
|
||||
|
||||
def _display_page_count_estimate() -> None:
|
||||
"""Show estimated consumable life based on CUPS page count."""
|
||||
estimate = estimate_consumable_life()
|
||||
if estimate.total_pages <= 0:
|
||||
return
|
||||
_out(f"{BOLD}── Page Count Estimate ──{RESET}")
|
||||
_out()
|
||||
_out(
|
||||
f" {BOLD}Total pages printed:{RESET} {estimate.total_pages}"
|
||||
f" (toner: {estimate.toner_pages} since replacement,"
|
||||
f" drum: {estimate.drum_pages} since replacement)"
|
||||
)
|
||||
_out()
|
||||
# Toner bar
|
||||
toner_pct = estimate.toner_pct_remaining
|
||||
toner_filled = toner_pct * PROGRESS_BAR_WIDTH // 100
|
||||
toner_empty = PROGRESS_BAR_WIDTH - toner_filled
|
||||
toner_bar = f"[{'█' * toner_filled}{'░' * toner_empty}]"
|
||||
if estimate.toner_exhausted:
|
||||
toner_color = RED
|
||||
toner_note = " ← REPLACE NOW"
|
||||
elif estimate.toner_low:
|
||||
toner_color = YELLOW
|
||||
toner_note = " ← order soon"
|
||||
else:
|
||||
toner_color = GREEN
|
||||
toner_note = ""
|
||||
_out(
|
||||
f" {BOLD}Toner:{RESET} {toner_color}{toner_bar} ~{toner_pct}%"
|
||||
f"{toner_note}{RESET}"
|
||||
)
|
||||
# Drum bar
|
||||
drum_pct = estimate.drum_pct_remaining
|
||||
drum_filled = drum_pct * PROGRESS_BAR_WIDTH // 100
|
||||
drum_empty = PROGRESS_BAR_WIDTH - drum_filled
|
||||
drum_bar = f"[{'█' * drum_filled}{'░' * drum_empty}]"
|
||||
if estimate.drum_near_end:
|
||||
drum_color = YELLOW
|
||||
drum_note = " ← nearing end"
|
||||
else:
|
||||
drum_color = GREEN
|
||||
drum_note = ""
|
||||
_out(
|
||||
f" {BOLD}Drum:{RESET} {drum_color}{drum_bar} ~{drum_pct}%"
|
||||
f"{drum_note}{RESET}"
|
||||
)
|
||||
_out(
|
||||
f" {DIM}Based on pages since last replacement"
|
||||
f" vs rated capacity (toner ~{TONER_RATED_PAGES},"
|
||||
f" drum ~{DRUM_RATED_PAGES}).{RESET}"
|
||||
)
|
||||
_out(f" {DIM}Reset after replacing: --reset-toner or --reset-drum{RESET}")
|
||||
if estimate.toner_exhausted:
|
||||
_out()
|
||||
_out(
|
||||
f" {RED}{BOLD}⚠ Toner is likely exhausted."
|
||||
f" This is probably why the orange light is flashing.{RESET}"
|
||||
)
|
||||
_out()
|
||||
|
||||
|
||||
def _display_consumables_reference() -> None:
|
||||
"""Print compatible consumables reference."""
|
||||
_out(f"{BOLD}── Compatible Consumables ──{RESET}")
|
||||
_out()
|
||||
_out(f" {BOLD}Toner:{RESET} TN-1050 / TN-1030 (or compatible third-party)")
|
||||
_out(f" {BOLD}Drum:{RESET} DR-1050 / DR-1030 (or compatible third-party)")
|
||||
_out(f" {DIM} Toner rated ~1000 pages; Drum rated ~10000 pages.{RESET}")
|
||||
_out()
|
||||
|
||||
|
||||
# ── USB display helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _display_usb_device_info(result: USBResult) -> None:
|
||||
"""Print device info block for USB results."""
|
||||
_out(f"{BOLD}Printer:{RESET} {result.product or 'Unknown'}")
|
||||
_out(f"{BOLD}Connection:{RESET} USB")
|
||||
if result.serial:
|
||||
_out(f"{BOLD}Serial:{RESET} {result.serial}")
|
||||
|
||||
if result.online == "TRUE":
|
||||
_out(f"{BOLD}Online:{RESET} {GREEN}Yes{RESET}")
|
||||
elif result.online == "FALSE":
|
||||
_out(f"{BOLD}Online:{RESET} {YELLOW}No (needs attention){RESET}")
|
||||
|
||||
_out()
|
||||
|
||||
if result.economode:
|
||||
if result.economode == "ON":
|
||||
_out(
|
||||
f"{BOLD}Toner Save:{RESET} {GREEN}ON{RESET}"
|
||||
" (extends toner life, lighter prints)"
|
||||
)
|
||||
else:
|
||||
_out(f"{BOLD}Toner Save:{RESET} OFF")
|
||||
|
||||
|
||||
_SEVERITY_ICONS: dict[str, str] = {
|
||||
"ok": "✓",
|
||||
"info": "i",
|
||||
"warn": "⚡",
|
||||
"critical": "⚠",
|
||||
}
|
||||
_SEVERITY_COLORS: dict[str, str] = {
|
||||
"ok": GREEN,
|
||||
"info": CYAN,
|
||||
"warn": YELLOW,
|
||||
"critical": RED,
|
||||
}
|
||||
_SEVERITY_SUMMARIES: dict[str, str] = {
|
||||
"ok": f"{GREEN}{BOLD}✓ Printer is healthy. No replacements needed.{RESET}",
|
||||
"info": (
|
||||
f"{CYAN}{BOLD}i Printer is busy/processing." f" No replacements needed.{RESET}"
|
||||
),
|
||||
"warn": (
|
||||
f"{YELLOW}{BOLD}⚡ WARNING: Maintenance will be needed"
|
||||
f" soon.{RESET}\n{YELLOW} Order replacement parts"
|
||||
f" now to avoid interruption.{RESET}"
|
||||
),
|
||||
"critical": (
|
||||
f"{RED}{BOLD}⚠ ACTION REQUIRED:" f" Replacement or fix needed now!{RESET}"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _format_status_detail(
|
||||
severity: str, short_text: str, action: str, result: USBResult
|
||||
) -> None:
|
||||
"""Print severity icon, display text, and action."""
|
||||
color = _SEVERITY_COLORS.get(severity, GREEN)
|
||||
icon = _SEVERITY_ICONS.get(severity, "✓")
|
||||
|
||||
_out(f" {color}{BOLD}{icon} {short_text}{RESET}")
|
||||
if result.display and result.display != short_text:
|
||||
_out(f" {DIM}Display: {result.display}{RESET}")
|
||||
_out(f" {DIM}Status code: {result.status_code}{RESET}")
|
||||
|
||||
if action:
|
||||
_out()
|
||||
_out(f" {color}{BOLD}Action:{RESET} {color}{action}{RESET}")
|
||||
_out()
|
||||
_out(_SEVERITY_SUMMARIES.get(severity, ""))
|
||||
|
||||
|
||||
def _display_pjl_status(result: USBResult) -> None:
|
||||
"""Display PJL status code interpretation."""
|
||||
_out()
|
||||
_out(f"{BOLD}── Printer Status ──{RESET}")
|
||||
_out()
|
||||
|
||||
if not result.status_code:
|
||||
_out(f" {YELLOW}Could not read status from printer.{RESET}")
|
||||
if result.display:
|
||||
_out(f" Display message: {BOLD}{result.display}{RESET}")
|
||||
return
|
||||
|
||||
severity, short_text, action = get_status_info(result.status_code)
|
||||
_format_status_detail(severity, short_text, action, result)
|
||||
|
||||
|
||||
def _display_cups_fallback_note(result: USBResult) -> None:
|
||||
"""Show a note when running in CUPS fallback mode."""
|
||||
_out()
|
||||
if result.port_status is not None:
|
||||
_out(
|
||||
f" {DIM}Note: Hardware status obtained via USB port query."
|
||||
f" Toner/drum percentages not available.{RESET}"
|
||||
)
|
||||
else:
|
||||
_out(
|
||||
f" {DIM}Note: pyusb not available; status obtained via"
|
||||
f" CUPS only. Detailed toner/drum levels are not"
|
||||
f" available in this mode.{RESET}"
|
||||
)
|
||||
|
||||
|
||||
# ── USB results display ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def display_usb_results(result: USBResult) -> None:
|
||||
"""Print a formatted report for USB PJL query results."""
|
||||
if result.error:
|
||||
_out(f"{RED}Error: {result.error}{RESET}")
|
||||
sys.exit(1)
|
||||
|
||||
_display_report_header()
|
||||
_display_usb_device_info(result)
|
||||
_display_pjl_status(result)
|
||||
|
||||
if result.device == "cups":
|
||||
_display_cups_fallback_note(result)
|
||||
|
||||
_out()
|
||||
_display_page_count_estimate()
|
||||
_display_consumables_reference()
|
||||
|
||||
queue = get_cups_queue_status()
|
||||
display_cups_queue_status(queue)
|
||||
|
||||
|
||||
# ── Network supply level helpers ─────────────────────────────────────
|
||||
|
||||
|
||||
def _classify_percentage_level(desc: str, pct: int) -> tuple[int, str, str, str, bool]:
|
||||
"""Classify a supply by its calculated percentage."""
|
||||
if pct <= SUPPLY_LOW_PCT:
|
||||
return pct, f"{pct}%", RED, f"{desc} at {pct}%.", True
|
||||
if pct <= SUPPLY_WARN_PCT:
|
||||
return pct, f"{pct}%", YELLOW, f"{desc} at {pct}% -- order soon.", False
|
||||
return pct, f"{pct}%", GREEN, "", False
|
||||
|
||||
|
||||
def _classify_supply_level(
|
||||
desc: str, max_val: int, level: int
|
||||
) -> tuple[int, str, str, str, bool]:
|
||||
"""Classify a supply level. Returns (pct, status, color, warning, replace)."""
|
||||
if level == SNMP_LEVEL_OK:
|
||||
return -1, "OK", GREEN, "", False
|
||||
if level == SNMP_LEVEL_LOW:
|
||||
return -1, "LOW", RED, f"{desc} is LOW.", True
|
||||
if level == 0:
|
||||
return 0, "EMPTY", RED, f"{desc} is EMPTY -- replace now!", True
|
||||
if max_val > 0:
|
||||
pct = min(level * 100 // max_val, 100)
|
||||
return _classify_percentage_level(desc, pct)
|
||||
return -1, "", GREEN, "", False
|
||||
|
||||
|
||||
def _format_supply_bar(pct: int) -> str:
|
||||
"""Build a progress bar string for a supply percentage."""
|
||||
if pct < 0:
|
||||
return ""
|
||||
filled = pct * PROGRESS_BAR_WIDTH // 100
|
||||
empty = PROGRESS_BAR_WIDTH - filled
|
||||
return f"[{'█' * filled}{'░' * empty}]"
|
||||
|
||||
|
||||
def _process_supply_item(desc: str, max_val: int, level: int) -> SupplyStatus:
|
||||
"""Process a single supply item into display info."""
|
||||
pct, status_text, color, warning, needs_replacement = _classify_supply_level(
|
||||
desc, max_val, level
|
||||
)
|
||||
bar = _format_supply_bar(pct)
|
||||
return SupplyStatus(color, bar, status_text, warning, needs_replacement)
|
||||
|
||||
|
||||
def _display_supply_warnings(*, needs_replacement: bool, warnings: list[str]) -> None:
|
||||
"""Display supply level warnings summary."""
|
||||
_out()
|
||||
if needs_replacement:
|
||||
_out(f"{RED}{BOLD}⚠ ACTION NEEDED:{RESET}")
|
||||
for w in warnings:
|
||||
_out(f" {RED}• {w}{RESET}")
|
||||
elif warnings:
|
||||
_out(f"{YELLOW}{BOLD}⚡ HEADS UP:{RESET}")
|
||||
for w in warnings:
|
||||
_out(f" {YELLOW}• {w}{RESET}")
|
||||
else:
|
||||
_out(f"{GREEN}{BOLD}✓ All consumables are at healthy levels.{RESET}")
|
||||
|
||||
|
||||
def _parse_supply_value(values: list[str], index: int) -> int:
|
||||
"""Safely parse an integer from a supply value list."""
|
||||
try:
|
||||
return int(values[index])
|
||||
except (IndexError, ValueError):
|
||||
return 0
|
||||
|
||||
|
||||
def _collect_supply_items(
|
||||
result: NetworkResult,
|
||||
) -> tuple[list[SupplyStatus], list[str]]:
|
||||
"""Parse and collect supply items with their descriptions."""
|
||||
items: list[SupplyStatus] = []
|
||||
descs: list[str] = []
|
||||
for i, desc in enumerate(result.supply_descriptions):
|
||||
max_val = _parse_supply_value(result.supply_max, i)
|
||||
level = _parse_supply_value(result.supply_levels, i)
|
||||
items.append(_process_supply_item(desc, max_val, level))
|
||||
descs.append(desc)
|
||||
return items, descs
|
||||
|
||||
|
||||
def _display_supply_levels(result: NetworkResult) -> None:
|
||||
"""Display consumable supply levels section."""
|
||||
_out()
|
||||
_out(f"{BOLD}── Consumable Levels ──{RESET}")
|
||||
_out()
|
||||
|
||||
needs_replacement = False
|
||||
warnings: list[str] = []
|
||||
items, descs = _collect_supply_items(result)
|
||||
|
||||
for desc, item in zip(descs, items, strict=True):
|
||||
_out(
|
||||
f" {BOLD}{desc:<25}{RESET}"
|
||||
f" {item.color}{item.bar} {item.status_text}{RESET}"
|
||||
)
|
||||
if item.needs_replacement:
|
||||
needs_replacement = True
|
||||
if item.warning:
|
||||
warnings.append(item.warning)
|
||||
|
||||
_display_supply_warnings(needs_replacement=needs_replacement, warnings=warnings)
|
||||
|
||||
|
||||
def _display_network_device_info(result: NetworkResult) -> None:
|
||||
"""Display device info section for network results."""
|
||||
_out(f"{BOLD}Printer:{RESET} {result.product or 'Unknown'}")
|
||||
_out(f"{BOLD}Connection:{RESET} Network ({result.ip})")
|
||||
if result.serial:
|
||||
_out(f"{BOLD}Serial:{RESET} {result.serial}")
|
||||
if result.display:
|
||||
_out(f"{BOLD}Display:{RESET} {result.display}")
|
||||
if result.page_count and result.page_count.isdigit():
|
||||
_out(f"{BOLD}Pages:{RESET} {result.page_count} total")
|
||||
|
||||
|
||||
# ── Network results display ──────────────────────────────────────────
|
||||
|
||||
|
||||
def display_network_results(result: NetworkResult) -> None:
|
||||
"""Print a formatted report for SNMP network query results."""
|
||||
if result.error:
|
||||
_out(f"{RED}Error: {result.error}{RESET}")
|
||||
sys.exit(1)
|
||||
|
||||
_display_report_header()
|
||||
_display_network_device_info(result)
|
||||
_display_supply_levels(result)
|
||||
|
||||
_out()
|
||||
_out(
|
||||
f"{CYAN}Tip: Visit http://{result.ip} for the full web management"
|
||||
f" interface.{RESET}"
|
||||
)
|
||||
_out()
|
||||
97
python_pkg/brother_printer/network_query.py
Normal file
97
python_pkg/brother_printer/network_query.py
Normal file
@ -0,0 +1,97 @@
|
||||
"""SNMP network query functions for Brother printers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
from python_pkg.brother_printer.data_classes import NetworkResult
|
||||
|
||||
|
||||
def _snmpwalk_cmd(
|
||||
path: str, community: str, timeout: int, ip: str, oid: str
|
||||
) -> list[str]:
|
||||
"""Build the snmpwalk command arguments."""
|
||||
return [path, "-v", "2c", "-c", community, "-t", str(timeout), "-OQvs", ip, oid]
|
||||
|
||||
|
||||
def snmp_walk(ip: str, oid: str, community: str, timeout: int) -> list[str]:
|
||||
"""Run snmpwalk and return cleaned values."""
|
||||
snmpwalk_path = shutil.which("snmpwalk")
|
||||
if not snmpwalk_path:
|
||||
return []
|
||||
try:
|
||||
r = subprocess.run(
|
||||
_snmpwalk_cmd(snmpwalk_path, community, timeout, ip, oid),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
check=False,
|
||||
)
|
||||
return [
|
||||
line.strip().strip('"')
|
||||
for line in r.stdout.strip().splitlines()
|
||||
if line.strip()
|
||||
]
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
|
||||
return []
|
||||
|
||||
|
||||
def _snmpget_cmd(
|
||||
path: str, community: str, timeout: int, ip: str, oid: str
|
||||
) -> list[str]:
|
||||
"""Build the snmpget command arguments."""
|
||||
return [path, "-v", "2c", "-c", community, "-t", str(timeout), ip, oid]
|
||||
|
||||
|
||||
def _check_snmp_connectivity(ip: str, community: str, timeout: int) -> str | None:
|
||||
"""Verify SNMP connectivity. Returns error message or None on success."""
|
||||
snmpget_path = shutil.which("snmpget")
|
||||
if not snmpget_path:
|
||||
return "snmpget not found. Install: sudo pacman -S net-snmp"
|
||||
try:
|
||||
subprocess.run(
|
||||
_snmpget_cmd(
|
||||
snmpget_path,
|
||||
community,
|
||||
timeout,
|
||||
ip,
|
||||
"1.3.6.1.2.1.43.11.1.1.6.1.1",
|
||||
),
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
check=True,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError):
|
||||
return f"Cannot reach printer at {ip} via SNMP."
|
||||
return None
|
||||
|
||||
|
||||
def _build_network_result(ip: str, community: str, timeout: int) -> NetworkResult:
|
||||
"""Collect all SNMP data into a NetworkResult."""
|
||||
|
||||
def walk(oid: str) -> list[str]:
|
||||
return snmp_walk(ip, oid, community, timeout)
|
||||
|
||||
return NetworkResult(
|
||||
ip=ip,
|
||||
product=" ".join(walk("1.3.6.1.2.1.25.3.2.1.3")[:1]) or "Unknown",
|
||||
serial=" ".join(walk("1.3.6.1.2.1.43.5.1.1.17")[:1]) or "",
|
||||
printer_status=" ".join(walk("1.3.6.1.2.1.25.3.5.1.1")[:1]) or "",
|
||||
device_status=" ".join(walk("1.3.6.1.2.1.25.3.2.1.5")[:1]) or "",
|
||||
display=" ".join(walk("1.3.6.1.2.1.43.16.5.1.2")[:3]) or "",
|
||||
page_count=" ".join(walk("1.3.6.1.2.1.43.10.2.1.4")[:1]) or "",
|
||||
supply_descriptions=walk("1.3.6.1.2.1.43.11.1.1.6"),
|
||||
supply_max=walk("1.3.6.1.2.1.43.11.1.1.8"),
|
||||
supply_levels=walk("1.3.6.1.2.1.43.11.1.1.9"),
|
||||
)
|
||||
|
||||
|
||||
def query_network_snmp(ip: str) -> NetworkResult:
|
||||
"""Query a Brother printer via SNMP over the network."""
|
||||
community = "public"
|
||||
timeout = 5
|
||||
error = _check_snmp_connectivity(ip, community, timeout)
|
||||
if error:
|
||||
return NetworkResult(ip=ip, error=error)
|
||||
return _build_network_result(ip, community, timeout)
|
||||
233
python_pkg/brother_printer/usb_query.py
Normal file
233
python_pkg/brother_printer/usb_query.py
Normal file
@ -0,0 +1,233 @@
|
||||
"""USB printer discovery and PJL query functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import fcntl
|
||||
import os
|
||||
from pathlib import Path
|
||||
import select
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
import urllib.parse
|
||||
|
||||
from python_pkg.brother_printer.data_classes import USBResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── USB printer discovery ────────────────────────────────────────────
|
||||
|
||||
|
||||
def find_brother_usb() -> str:
|
||||
"""Look for any Brother printer on USB via lsusb. Returns the info line."""
|
||||
if not shutil.which("lsusb"):
|
||||
return ""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["/usr/bin/lsusb"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
for line in r.stdout.splitlines():
|
||||
if "04f9:" in line.lower():
|
||||
return line.split(": ", 1)[1] if ": " in line else line
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def find_usb_printer_dev() -> str | None:
|
||||
"""Find /dev/usb/lp* device for the Brother printer."""
|
||||
devices = sorted(Path("/dev/usb").glob("lp*"))
|
||||
return str(devices[0]) if devices else None
|
||||
|
||||
|
||||
def _parse_cups_usb_uri(uri: str, info: dict[str, str]) -> None:
|
||||
"""Extract product and serial from a CUPS usb:// URI."""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
info["product"] = urllib.parse.unquote(parsed.path.lstrip("/"))
|
||||
qs = urllib.parse.parse_qs(parsed.query)
|
||||
if "serial" in qs:
|
||||
info["serial"] = qs["serial"][0]
|
||||
|
||||
|
||||
def get_printer_info_from_cups() -> dict[str, str]:
|
||||
"""Get printer model/serial from lpstat."""
|
||||
info: dict[str, str] = {"product": "", "serial": ""}
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["/usr/bin/lpstat", "-v"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
for line in r.stdout.splitlines():
|
||||
if "Brother" in line:
|
||||
for part in line.split():
|
||||
if part.startswith("usb://"):
|
||||
_parse_cups_usb_uri(part, info)
|
||||
break
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError):
|
||||
logger.debug("Failed to query CUPS for printer info", exc_info=True)
|
||||
return info
|
||||
|
||||
|
||||
# ── PJL over USB ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _drain_buffer(fd: int) -> None:
|
||||
"""Read and discard any stale data from the USB buffer."""
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
with contextlib.suppress(OSError):
|
||||
while os.read(fd, 4096):
|
||||
pass
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
|
||||
|
||||
|
||||
def _read_nonblocking(fd: int, flags: int) -> bytes:
|
||||
"""Read all available data from fd in non-blocking mode."""
|
||||
data = b""
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
with contextlib.suppress(OSError):
|
||||
while True:
|
||||
chunk = os.read(fd, 4096)
|
||||
if not chunk:
|
||||
break
|
||||
data += chunk
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
|
||||
return data
|
||||
|
||||
|
||||
def _wait_for_pjl_response(fd: int, flags: int, deadline: float) -> bytes:
|
||||
"""Poll fd until PJL data arrives or deadline expires."""
|
||||
response = b""
|
||||
while time.time() < deadline:
|
||||
remaining = deadline - time.time()
|
||||
if remaining <= 0:
|
||||
break
|
||||
readable, _, _ = select.select([fd], [], [], min(remaining, 1.0))
|
||||
if readable:
|
||||
response += _read_nonblocking(fd, flags)
|
||||
if response and (b"=" in response or b"@PJL" in response):
|
||||
break
|
||||
return response
|
||||
|
||||
|
||||
def pjl_query(fd: int, cmd: str, timeout_sec: float = 5.0) -> str:
|
||||
"""Send a PJL command via raw fd and read the response."""
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
|
||||
|
||||
pjl_cmd = f"\x1b%-12345X@PJL\r\n{cmd}\r\n\x1b%-12345X"
|
||||
os.write(fd, pjl_cmd.encode())
|
||||
|
||||
deadline = time.time() + timeout_sec
|
||||
response = _wait_for_pjl_response(fd, flags, deadline)
|
||||
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
|
||||
return response.decode("ascii", errors="replace")
|
||||
|
||||
|
||||
def _parse_status(resp: str, result: USBResult) -> bool:
|
||||
"""Parse STATUS response into result. Returns True if code was found."""
|
||||
found = False
|
||||
for raw_line in resp.splitlines():
|
||||
stripped = raw_line.strip()
|
||||
if stripped.startswith("CODE="):
|
||||
result.status_code = stripped.split("=", 1)[1]
|
||||
found = True
|
||||
elif stripped.startswith("DISPLAY="):
|
||||
result.display = stripped.split("=", 1)[1].strip().strip('"').strip()
|
||||
elif stripped.startswith("ONLINE="):
|
||||
result.online = stripped.split("=", 1)[1]
|
||||
return found
|
||||
|
||||
|
||||
def _parse_variables(resp: str, result: USBResult) -> bool:
|
||||
"""Parse VARIABLES response into result. Returns True if data found."""
|
||||
found = False
|
||||
for raw_line in resp.splitlines():
|
||||
stripped = raw_line.strip()
|
||||
if stripped.startswith("ECONOMODE="):
|
||||
result.economode = stripped.split("=", 1)[1].split()[0]
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
def _retry_pjl_query(
|
||||
fd: int,
|
||||
cmd: str,
|
||||
parser: Callable[[str, USBResult], bool],
|
||||
result: USBResult,
|
||||
max_retries: int,
|
||||
) -> None:
|
||||
"""Send a PJL query with retries, draining between attempts."""
|
||||
for attempt in range(max_retries + 1):
|
||||
resp = pjl_query(fd, cmd)
|
||||
if parser(resp, result):
|
||||
break
|
||||
if attempt < max_retries:
|
||||
_drain_buffer(fd)
|
||||
time.sleep(0.5)
|
||||
|
||||
|
||||
def _run_pjl_queries(fd: int, result: USBResult, max_retries: int) -> None:
|
||||
"""Execute PJL query sequence on an open file descriptor."""
|
||||
_drain_buffer(fd)
|
||||
|
||||
os.write(fd, b"\x1b%-12345X@PJL\r\n\x1b%-12345X")
|
||||
time.sleep(0.5)
|
||||
_drain_buffer(fd)
|
||||
|
||||
_retry_pjl_query(fd, "@PJL INFO STATUS", _parse_status, result, max_retries)
|
||||
_drain_buffer(fd)
|
||||
time.sleep(0.5)
|
||||
_retry_pjl_query(fd, "@PJL INFO VARIABLES", _parse_variables, result, max_retries)
|
||||
|
||||
|
||||
def _init_usb_result(dev_path: str) -> USBResult:
|
||||
"""Create a USBResult with device info from CUPS."""
|
||||
cups_info = get_printer_info_from_cups()
|
||||
return USBResult(
|
||||
device=dev_path,
|
||||
product=cups_info.get("product") or "Brother Laser Printer",
|
||||
serial=cups_info.get("serial", ""),
|
||||
)
|
||||
|
||||
|
||||
def query_usb_pjl(max_retries: int = 2) -> USBResult:
|
||||
"""Query a Brother printer via PJL over /dev/usb/lp*."""
|
||||
dev_path = find_usb_printer_dev()
|
||||
if not dev_path:
|
||||
from python_pkg.brother_printer.cups_service import query_usb_via_cups
|
||||
|
||||
return query_usb_via_cups()
|
||||
|
||||
result = _init_usb_result(dev_path)
|
||||
if not os.access(dev_path, os.R_OK | os.W_OK):
|
||||
result.error = f"Permission denied: {dev_path}. Run with sudo."
|
||||
return result
|
||||
|
||||
fd: int | None = None
|
||||
try:
|
||||
fd = os.open(dev_path, os.O_RDWR)
|
||||
fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
_run_pjl_queries(fd, result, max_retries)
|
||||
except OSError as e:
|
||||
result.error = str(e)
|
||||
finally:
|
||||
if fd is not None:
|
||||
os.close(fd)
|
||||
return result
|
||||
File diff suppressed because it is too large
Load Diff
206
python_pkg/geo_data/__init__.py
Normal file
206
python_pkg/geo_data/__init__.py
Normal file
@ -0,0 +1,206 @@
|
||||
"""Shared geographic data module for Warsaw and Poland Anki generators.
|
||||
|
||||
This module handles downloading and caching geographic data from various sources:
|
||||
- OpenStreetMap via Overpass API
|
||||
- Geofabrik OSM extracts
|
||||
- GitHub repositories with pre-processed GeoJSON
|
||||
|
||||
All data is cached locally to avoid repeated downloads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from python_pkg.geo_data._common import (
|
||||
CACHE_DIR,
|
||||
MAX_RETRIES,
|
||||
MIN_LAKE_AREA_KM2,
|
||||
MIN_LINE_COORDS,
|
||||
MIN_PEAK_ELEVATION,
|
||||
MIN_RING_COORDS,
|
||||
MIN_RIVER_LENGTH_KM,
|
||||
OVERPASS_ENDPOINTS,
|
||||
POLSKA_GEOJSON_BASE,
|
||||
REQUEST_TIMEOUT,
|
||||
RETRY_DELAY,
|
||||
WIKIDATA_SPARQL,
|
||||
)
|
||||
from python_pkg.geo_data._poland_admin import (
|
||||
get_poland_boundary,
|
||||
get_polish_gminy,
|
||||
get_polish_powiaty,
|
||||
get_polish_wojewodztwa,
|
||||
)
|
||||
from python_pkg.geo_data._poland_nature import (
|
||||
get_polish_forests,
|
||||
get_polish_landscape_parks,
|
||||
get_polish_mountain_peaks,
|
||||
get_polish_mountain_ranges,
|
||||
get_polish_national_parks,
|
||||
get_polish_nature_reserves,
|
||||
)
|
||||
from python_pkg.geo_data._poland_water import (
|
||||
get_polish_coastal_features,
|
||||
get_polish_islands,
|
||||
get_polish_lakes,
|
||||
get_polish_rivers,
|
||||
get_polish_unesco_sites,
|
||||
)
|
||||
from python_pkg.geo_data._warsaw import (
|
||||
get_vistula_river,
|
||||
get_warsaw_boundary,
|
||||
get_warsaw_bridges,
|
||||
get_warsaw_districts,
|
||||
get_warsaw_metro_stations,
|
||||
get_warsaw_osiedla,
|
||||
)
|
||||
from python_pkg.geo_data._warsaw_places import get_warsaw_landmarks, get_warsaw_streets
|
||||
|
||||
__all__ = [
|
||||
"CACHE_DIR",
|
||||
"MAX_RETRIES",
|
||||
"MIN_LAKE_AREA_KM2",
|
||||
"MIN_LINE_COORDS",
|
||||
"MIN_PEAK_ELEVATION",
|
||||
"MIN_RING_COORDS",
|
||||
"MIN_RIVER_LENGTH_KM",
|
||||
"OVERPASS_ENDPOINTS",
|
||||
"POLSKA_GEOJSON_BASE",
|
||||
"REQUEST_TIMEOUT",
|
||||
"RETRY_DELAY",
|
||||
"WIKIDATA_SPARQL",
|
||||
"clear_cache",
|
||||
"download_all_poland_data",
|
||||
"download_all_warsaw_data",
|
||||
"get_poland_boundary",
|
||||
"get_polish_coastal_features",
|
||||
"get_polish_forests",
|
||||
"get_polish_gminy",
|
||||
"get_polish_islands",
|
||||
"get_polish_lakes",
|
||||
"get_polish_landscape_parks",
|
||||
"get_polish_mountain_peaks",
|
||||
"get_polish_mountain_ranges",
|
||||
"get_polish_national_parks",
|
||||
"get_polish_nature_reserves",
|
||||
"get_polish_powiaty",
|
||||
"get_polish_rivers",
|
||||
"get_polish_unesco_sites",
|
||||
"get_polish_wojewodztwa",
|
||||
"get_vistula_river",
|
||||
"get_warsaw_boundary",
|
||||
"get_warsaw_bridges",
|
||||
"get_warsaw_districts",
|
||||
"get_warsaw_landmarks",
|
||||
"get_warsaw_metro_stations",
|
||||
"get_warsaw_osiedla",
|
||||
"get_warsaw_streets",
|
||||
]
|
||||
|
||||
|
||||
def download_all_warsaw_data() -> None:
|
||||
"""Download and cache all Warsaw geographic data.
|
||||
|
||||
Call this once to pre-populate the cache.
|
||||
"""
|
||||
sys.stdout.write("Downloading all Warsaw geographic data...\n")
|
||||
sys.stdout.write("=" * 60 + "\n")
|
||||
|
||||
sys.stdout.write("\n1. Warsaw boundary...\n")
|
||||
get_warsaw_boundary()
|
||||
|
||||
sys.stdout.write("\n2. Vistula river...\n")
|
||||
get_vistula_river()
|
||||
|
||||
sys.stdout.write("\n3. Warsaw bridges...\n")
|
||||
get_warsaw_bridges()
|
||||
|
||||
sys.stdout.write("\n4. Metro stations...\n")
|
||||
get_warsaw_metro_stations()
|
||||
|
||||
sys.stdout.write("\n5. Major streets...\n")
|
||||
get_warsaw_streets()
|
||||
|
||||
sys.stdout.write("\n6. Landmarks...\n")
|
||||
get_warsaw_landmarks()
|
||||
|
||||
sys.stdout.write("\n7. Osiedla...\n")
|
||||
get_warsaw_osiedla()
|
||||
|
||||
sys.stdout.write("\n" + "=" * 60 + "\n")
|
||||
sys.stdout.write("All Warsaw data cached successfully!\n")
|
||||
|
||||
|
||||
def download_all_poland_data() -> None:
|
||||
"""Download and cache all Poland geographic data.
|
||||
|
||||
Call this once to pre-populate the cache.
|
||||
"""
|
||||
sys.stdout.write("Downloading all Poland geographic data...\n")
|
||||
sys.stdout.write("=" * 60 + "\n")
|
||||
|
||||
sys.stdout.write("\n1. Województwa...\n")
|
||||
get_polish_wojewodztwa()
|
||||
|
||||
sys.stdout.write("\n2. Powiaty...\n")
|
||||
get_polish_powiaty()
|
||||
|
||||
sys.stdout.write("\n3. Gminy (this may take a while)...\n")
|
||||
get_polish_gminy()
|
||||
|
||||
sys.stdout.write("\n4. Poland boundary...\n")
|
||||
get_poland_boundary()
|
||||
|
||||
sys.stdout.write("\n" + "=" * 60 + "\n")
|
||||
sys.stdout.write("All Poland data cached successfully!\n")
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
"""Clear all cached data."""
|
||||
if CACHE_DIR.exists():
|
||||
shutil.rmtree(CACHE_DIR)
|
||||
sys.stdout.write("Cache cleared.\n")
|
||||
else:
|
||||
sys.stdout.write("Cache directory does not exist.\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Manage geographic data cache")
|
||||
parser.add_argument(
|
||||
"--download-warsaw",
|
||||
action="store_true",
|
||||
help="Download all Warsaw data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--download-poland",
|
||||
action="store_true",
|
||||
help="Download all Poland data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--download-all",
|
||||
action="store_true",
|
||||
help="Download all data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clear-cache",
|
||||
action="store_true",
|
||||
help="Clear cached data",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.clear_cache:
|
||||
clear_cache()
|
||||
elif args.download_warsaw:
|
||||
download_all_warsaw_data()
|
||||
elif args.download_poland:
|
||||
download_all_poland_data()
|
||||
elif args.download_all:
|
||||
download_all_warsaw_data()
|
||||
download_all_poland_data()
|
||||
else:
|
||||
parser.print_help()
|
||||
318
python_pkg/geo_data/_common.py
Normal file
318
python_pkg/geo_data/_common.py
Normal file
@ -0,0 +1,318 @@
|
||||
"""Common utilities for geographic data operations.
|
||||
|
||||
Shared constants, API helpers, and geometry extraction functions used
|
||||
across the geo_data package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.request import urlopen
|
||||
|
||||
import geopandas as gpd
|
||||
import requests
|
||||
from shapely.geometry import (
|
||||
GeometryCollection,
|
||||
MultiPolygon,
|
||||
Polygon,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
# Parent directory of the geo_data package (i.e. python_pkg/)
|
||||
_PKG_DIR = Path(__file__).resolve().parent.parent
|
||||
|
||||
# Shared cache directory for all geo data
|
||||
CACHE_DIR = _PKG_DIR / "geo_cache"
|
||||
|
||||
# Overpass API endpoints (multiple for redundancy)
|
||||
# Note: kumi.systems is more reliable, so it's first
|
||||
OVERPASS_ENDPOINTS = [
|
||||
"https://overpass.kumi.systems/api/interpreter",
|
||||
"https://overpass-api.de/api/interpreter",
|
||||
"https://maps.mail.ru/osm/tools/overpass/api/interpreter",
|
||||
]
|
||||
|
||||
# GitHub URLs for pre-processed data
|
||||
POLSKA_GEOJSON_BASE = "https://raw.githubusercontent.com/ppatrzyk/polska-geojson/master"
|
||||
|
||||
# Wikidata SPARQL endpoint
|
||||
WIKIDATA_SPARQL = "https://query.wikidata.org/sparql"
|
||||
|
||||
# Request timeout and retry settings
|
||||
REQUEST_TIMEOUT = 180
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 5
|
||||
|
||||
# Data thresholds for filtering
|
||||
MIN_PEAK_ELEVATION = 300 # meters
|
||||
MIN_LAKE_AREA_KM2 = 0.5 # km²
|
||||
MIN_RIVER_LENGTH_KM = 10 # km
|
||||
MIN_LINE_COORDS = 2 # minimum coordinates for a line
|
||||
MIN_RING_COORDS = 4 # minimum coordinates for a polygon ring
|
||||
|
||||
|
||||
def _ensure_cache_dir() -> None:
|
||||
"""Create cache directory if it doesn't exist."""
|
||||
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _extract_polygonal_geometry(
|
||||
geom: Polygon | MultiPolygon | GeometryCollection,
|
||||
) -> Polygon | MultiPolygon | None:
|
||||
"""Extract only polygonal geometry from a geometry that may be mixed.
|
||||
|
||||
Some OSM data comes as GeometryCollections containing polygons mixed with
|
||||
lines. This function extracts only the polygon/multipolygon parts.
|
||||
|
||||
Args:
|
||||
geom: Input geometry (Polygon, MultiPolygon, or GeometryCollection).
|
||||
|
||||
Returns:
|
||||
Polygon or MultiPolygon with only the polygonal parts, or None if empty.
|
||||
"""
|
||||
if isinstance(geom, Polygon | MultiPolygon):
|
||||
return geom
|
||||
|
||||
if isinstance(geom, GeometryCollection):
|
||||
polygons = [g for g in geom.geoms if isinstance(g, Polygon | MultiPolygon)]
|
||||
if not polygons:
|
||||
return None
|
||||
if len(polygons) == 1:
|
||||
return polygons[0]
|
||||
# Flatten MultiPolygons and combine all polygons
|
||||
all_polys = []
|
||||
for p in polygons:
|
||||
if isinstance(p, Polygon):
|
||||
all_polys.append(p)
|
||||
elif isinstance(p, MultiPolygon):
|
||||
all_polys.extend(p.geoms)
|
||||
return MultiPolygon(all_polys)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _try_single_request(
|
||||
endpoint: str, query: str
|
||||
) -> tuple[dict[str, Any] | None, Exception | None]:
|
||||
"""Try a single request to an endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: Overpass API endpoint URL.
|
||||
query: Overpass QL query string.
|
||||
|
||||
Returns:
|
||||
Tuple of (result, error). One will be None.
|
||||
"""
|
||||
try:
|
||||
sys.stdout.write(f" Querying {endpoint}...\n")
|
||||
response = requests.post(
|
||||
endpoint,
|
||||
data={"data": query},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except (requests.RequestException, requests.Timeout, ValueError) as e:
|
||||
return None, e
|
||||
else:
|
||||
# Check for valid response with elements
|
||||
if not isinstance(result, dict) or "elements" not in result:
|
||||
return None, ValueError("Invalid response format")
|
||||
return result, None
|
||||
|
||||
|
||||
def _overpass_query(query: str) -> dict[str, Any]:
|
||||
"""Execute an Overpass API query with retry logic.
|
||||
|
||||
Args:
|
||||
query: Overpass QL query string.
|
||||
|
||||
Returns:
|
||||
JSON response from the API.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all endpoints fail.
|
||||
"""
|
||||
last_error: Exception | None = None
|
||||
|
||||
for endpoint in OVERPASS_ENDPOINTS:
|
||||
for attempt in range(MAX_RETRIES):
|
||||
result, error = _try_single_request(endpoint, query)
|
||||
if result is not None:
|
||||
return result
|
||||
last_error = error
|
||||
sys.stdout.write(f" Attempt {attempt + 1} failed: {error}\n")
|
||||
if attempt < MAX_RETRIES - 1:
|
||||
time.sleep(RETRY_DELAY)
|
||||
|
||||
msg = f"All Overpass API endpoints failed. Last error: {last_error}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
def _download_github_geojson(url: str, cache_path: Path) -> gpd.GeoDataFrame:
|
||||
"""Download GeoJSON from GitHub and cache it.
|
||||
|
||||
Args:
|
||||
url: URL to download from.
|
||||
cache_path: Path to cache the data.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with the data.
|
||||
"""
|
||||
if cache_path.exists():
|
||||
return gpd.read_file(cache_path)
|
||||
|
||||
sys.stdout.write(f"Downloading from {url}...\n")
|
||||
if not url.startswith(("http://", "https://")):
|
||||
msg = f"Unsupported URL scheme: {url}"
|
||||
raise ValueError(msg)
|
||||
with urlopen(url, timeout=REQUEST_TIMEOUT) as response:
|
||||
data = json.loads(response.read().decode())
|
||||
|
||||
_ensure_cache_dir()
|
||||
cache_path.write_text(json.dumps(data))
|
||||
|
||||
return gpd.GeoDataFrame.from_features(data["features"], crs="EPSG:4326")
|
||||
|
||||
|
||||
def _extract_osiedla_rings(
|
||||
element: dict[str, Any], min_coords: int
|
||||
) -> tuple[list[list[tuple[float, float]]], list[list[tuple[float, float]]]]:
|
||||
"""Extract outer and inner rings from an OSM relation.
|
||||
|
||||
Args:
|
||||
element: OSM relation element.
|
||||
min_coords: Minimum number of coordinates for a valid ring.
|
||||
|
||||
Returns:
|
||||
Tuple of (outer_rings, inner_rings).
|
||||
"""
|
||||
outer_rings: list[list[tuple[float, float]]] = []
|
||||
inner_rings: list[list[tuple[float, float]]] = []
|
||||
|
||||
for member in element.get("members", []):
|
||||
if "geometry" not in member:
|
||||
continue
|
||||
ring = [(p["lon"], p["lat"]) for p in member["geometry"]]
|
||||
if len(ring) < min_coords:
|
||||
continue
|
||||
# Close the ring if not closed
|
||||
if ring[0] != ring[-1]:
|
||||
ring.append(ring[0])
|
||||
if member.get("role") == "outer":
|
||||
outer_rings.append(ring)
|
||||
elif member.get("role") == "inner":
|
||||
inner_rings.append(ring)
|
||||
|
||||
return outer_rings, inner_rings
|
||||
|
||||
|
||||
def _build_osiedla_geometry(
|
||||
outer_rings: list[list[tuple[float, float]]],
|
||||
inner_rings: list[list[tuple[float, float]]],
|
||||
) -> dict[str, Any]:
|
||||
"""Build GeoJSON geometry from outer and inner rings.
|
||||
|
||||
Args:
|
||||
outer_rings: List of outer ring coordinates.
|
||||
inner_rings: List of inner ring coordinates.
|
||||
|
||||
Returns:
|
||||
GeoJSON geometry dict.
|
||||
"""
|
||||
if len(outer_rings) == 1:
|
||||
return {
|
||||
"type": "Polygon",
|
||||
"coordinates": [outer_rings[0], *inner_rings],
|
||||
}
|
||||
# Multiple outer rings - create MultiPolygon
|
||||
# Each polygon in a MultiPolygon is [exterior, hole1, hole2, ...]
|
||||
return {
|
||||
"type": "MultiPolygon",
|
||||
"coordinates": [[ring] for ring in outer_rings],
|
||||
}
|
||||
|
||||
|
||||
def _extract_polygon_from_element(
|
||||
element: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Extract polygon geometry from an OSM relation or way element.
|
||||
|
||||
Args:
|
||||
element: OSM element (relation or way).
|
||||
|
||||
Returns:
|
||||
GeoJSON geometry dict, or None if extraction fails.
|
||||
"""
|
||||
if element.get("type") == "relation":
|
||||
outer_rings, inner_rings = _extract_osiedla_rings(element, MIN_RING_COORDS)
|
||||
if not outer_rings:
|
||||
return None
|
||||
return _build_osiedla_geometry(outer_rings, inner_rings)
|
||||
|
||||
if element.get("type") == "way" and "geometry" in element:
|
||||
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
|
||||
if len(coords) < MIN_RING_COORDS:
|
||||
return None
|
||||
if coords[0] != coords[-1]:
|
||||
coords.append(coords[0])
|
||||
return {"type": "Polygon", "coordinates": [coords]}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_line_from_way(element: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Extract line geometry from an OSM way element.
|
||||
|
||||
Args:
|
||||
element: OSM way element.
|
||||
|
||||
Returns:
|
||||
GeoJSON LineString geometry dict, or None if extraction fails.
|
||||
"""
|
||||
if element.get("type") != "way" or "geometry" not in element:
|
||||
return None
|
||||
|
||||
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
|
||||
if len(coords) < MIN_LINE_COORDS:
|
||||
return None
|
||||
|
||||
return {"type": "LineString", "coordinates": coords}
|
||||
|
||||
|
||||
def _add_area_column(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
|
||||
"""Add area_km2 column to a GeoDataFrame.
|
||||
|
||||
Args:
|
||||
gdf: GeoDataFrame with polygon geometries.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with area_km2 column added.
|
||||
"""
|
||||
if len(gdf) == 0:
|
||||
return gdf
|
||||
gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system
|
||||
gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000
|
||||
return gdf
|
||||
|
||||
|
||||
def _add_length_column(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
|
||||
"""Add length_km column to a GeoDataFrame.
|
||||
|
||||
Args:
|
||||
gdf: GeoDataFrame with line geometries.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with length_km column added.
|
||||
"""
|
||||
if len(gdf) == 0:
|
||||
return gdf
|
||||
gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system
|
||||
gdf["length_km"] = gdf_proj.geometry.length / 1000
|
||||
return gdf
|
||||
226
python_pkg/geo_data/_poland_admin.py
Normal file
226
python_pkg/geo_data/_poland_admin.py
Normal file
@ -0,0 +1,226 @@
|
||||
"""Polish administrative boundary data.
|
||||
|
||||
Functions for downloading and caching Polish administrative divisions:
|
||||
województwa, powiaty, gminy, and the national boundary.
|
||||
Includes Wikidata integration for population data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import geopandas as gpd
|
||||
import requests
|
||||
|
||||
from python_pkg.geo_data._common import (
|
||||
CACHE_DIR,
|
||||
POLSKA_GEOJSON_BASE,
|
||||
WIKIDATA_SPARQL,
|
||||
_build_osiedla_geometry,
|
||||
_download_github_geojson,
|
||||
_ensure_cache_dir,
|
||||
_extract_osiedla_rings,
|
||||
_overpass_query,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _query_wikidata(query: str) -> list[dict[str, Any]]:
|
||||
"""Query Wikidata SPARQL endpoint.
|
||||
|
||||
Args:
|
||||
query: SPARQL query string.
|
||||
|
||||
Returns:
|
||||
List of result bindings.
|
||||
"""
|
||||
response = requests.get(
|
||||
WIKIDATA_SPARQL,
|
||||
params={"query": query, "format": "json"},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["results"]["bindings"]
|
||||
|
||||
|
||||
def _get_powiaty_population() -> dict[str, int]:
|
||||
"""Get population data for all Polish powiaty from Wikidata.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping powiat name to population.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "powiaty_population.json"
|
||||
|
||||
if cache_path.exists():
|
||||
return json.loads(cache_path.read_text())
|
||||
|
||||
# Query Wikidata for all powiaty (Q247073) in Poland (Q36) with population
|
||||
# Filter to only current Polish powiaty using country=Poland filter
|
||||
query = """
|
||||
SELECT ?powiat ?powiatLabel ?population WHERE {
|
||||
?powiat wdt:P31 wd:Q247073.
|
||||
?powiat wdt:P17 wd:Q36.
|
||||
?powiat wdt:P1082 ?population.
|
||||
SERVICE wikibase:label { bd:serviceParam wikibase:language "pl,en". }
|
||||
}
|
||||
ORDER BY DESC(?population)
|
||||
"""
|
||||
|
||||
sys.stdout.write("Fetching powiaty population data from Wikidata...\n")
|
||||
results = _query_wikidata(query)
|
||||
|
||||
population_map: dict[str, int] = {}
|
||||
for item in results:
|
||||
label = item.get("powiatLabel", {}).get("value", "")
|
||||
pop = item.get("population", {}).get("value", "0")
|
||||
if label and pop:
|
||||
# Remove "powiat" prefix if present for matching
|
||||
clean_label = label.replace("powiat ", "").strip()
|
||||
with contextlib.suppress(ValueError):
|
||||
population_map[clean_label] = int(pop)
|
||||
|
||||
_ensure_cache_dir()
|
||||
cache_path.write_text(json.dumps(population_map, ensure_ascii=False, indent=2))
|
||||
|
||||
sys.stdout.write(f"Cached population data for {len(population_map)} powiaty.\n")
|
||||
return population_map
|
||||
|
||||
|
||||
def get_polish_wojewodztwa() -> gpd.GeoDataFrame:
|
||||
"""Get Polish województwa (voivodeships).
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with województwa boundaries.
|
||||
"""
|
||||
url = f"{POLSKA_GEOJSON_BASE}/wojewodztwa/wojewodztwa-min.geojson"
|
||||
cache_path = CACHE_DIR / "polish_wojewodztwa.geojson"
|
||||
return _download_github_geojson(url, cache_path)
|
||||
|
||||
|
||||
def get_polish_powiaty() -> gpd.GeoDataFrame:
|
||||
"""Get Polish powiaty (counties), sorted by population descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with powiat boundaries and population.
|
||||
"""
|
||||
url = f"{POLSKA_GEOJSON_BASE}/powiaty/powiaty-min.geojson"
|
||||
cache_path = CACHE_DIR / "polish_powiaty.geojson"
|
||||
gdf = _download_github_geojson(url, cache_path)
|
||||
|
||||
# Get population data from Wikidata
|
||||
population_map = _get_powiaty_population()
|
||||
|
||||
# Add population column
|
||||
def get_population(nazwa: str) -> int:
|
||||
"""Match powiat name to population data."""
|
||||
if not nazwa:
|
||||
return 0
|
||||
# Remove "powiat " prefix for matching
|
||||
clean_name = nazwa.replace("powiat ", "").strip()
|
||||
# Try direct match
|
||||
if clean_name in population_map:
|
||||
return population_map[clean_name]
|
||||
# Try lowercase
|
||||
name_lower = clean_name.lower()
|
||||
for pop_name, pop in population_map.items():
|
||||
if pop_name.lower() == name_lower:
|
||||
return pop
|
||||
return 0
|
||||
|
||||
gdf["population"] = gdf["nazwa"].apply(get_population)
|
||||
|
||||
# Sort by population descending
|
||||
return gdf.sort_values("population", ascending=False).reset_index(drop=True)
|
||||
|
||||
|
||||
def get_polish_gminy() -> gpd.GeoDataFrame:
|
||||
"""Get Polish gminy (municipalities) from OSM, sorted by area descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with gminy boundaries.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_gminy.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
if "area_km2" in gdf.columns:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write("Fetching gminy data from OSM (this may take a while)...\n")
|
||||
# Polish gminy are admin_level=7 in OSM
|
||||
query = """
|
||||
[out:json][timeout:300];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
relation["boundary"="administrative"]["admin_level"="7"]["name"](area.pl);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
min_ring_coords = 4
|
||||
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") != "relation":
|
||||
continue
|
||||
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
|
||||
if not outer_rings:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": name},
|
||||
"geometry": _build_osiedla_geometry(outer_rings, inner_rings),
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} gminy.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
|
||||
# Add area column
|
||||
from python_pkg.geo_data._common import _add_area_column
|
||||
|
||||
gdf = _add_area_column(gdf)
|
||||
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
|
||||
|
||||
def get_poland_boundary() -> gpd.GeoDataFrame:
|
||||
"""Get Poland country boundary.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with Poland boundary.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "poland_boundary.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
return gpd.read_file(cache_path)
|
||||
|
||||
# Dissolve from województwa
|
||||
woj = get_polish_wojewodztwa()
|
||||
# Fix invalid geometries with buffer(0)
|
||||
woj["geometry"] = woj["geometry"].buffer(0)
|
||||
poland = gpd.GeoDataFrame(geometry=[woj.union_all()], crs=woj.crs)
|
||||
|
||||
_ensure_cache_dir()
|
||||
poland.to_file(cache_path, driver="GeoJSON")
|
||||
|
||||
return poland
|
||||
446
python_pkg/geo_data/_poland_nature.py
Normal file
446
python_pkg/geo_data/_poland_nature.py
Normal file
@ -0,0 +1,446 @@
|
||||
"""Polish natural land features.
|
||||
|
||||
Functions for downloading and caching data about Polish mountains,
|
||||
national parks, forests, nature reserves, and landscape parks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import geopandas as gpd
|
||||
|
||||
from python_pkg.geo_data._common import (
|
||||
CACHE_DIR,
|
||||
MIN_PEAK_ELEVATION,
|
||||
_add_area_column,
|
||||
_build_osiedla_geometry,
|
||||
_ensure_cache_dir,
|
||||
_extract_osiedla_rings,
|
||||
_extract_polygon_from_element,
|
||||
_extract_polygonal_geometry,
|
||||
_overpass_query,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
def get_polish_mountain_peaks() -> gpd.GeoDataFrame:
|
||||
"""Get Polish mountain peaks, sorted by elevation descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with mountain peak points and elevation.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_mountain_peaks.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
return gdf.sort_values("elevation", ascending=False).reset_index(drop=True)
|
||||
|
||||
sys.stdout.write("Fetching mountain peaks data from OSM...\n")
|
||||
query = """
|
||||
[out:json][timeout:120];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
node["natural"="peak"]["name"]["ele"](area.pl);
|
||||
);
|
||||
out;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") != "node":
|
||||
continue
|
||||
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
ele_str = element.get("tags", {}).get("ele", "")
|
||||
|
||||
if not name or not ele_str or name in seen_names:
|
||||
continue
|
||||
|
||||
with contextlib.suppress(ValueError):
|
||||
elevation = float(ele_str.replace(",", ".").split()[0])
|
||||
if elevation < MIN_PEAK_ELEVATION:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": name, "elevation": elevation},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [element["lon"], element["lat"]],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} mountain peaks.\n")
|
||||
|
||||
if not features:
|
||||
msg = "No mountain peaks found in OSM data"
|
||||
raise ValueError(msg)
|
||||
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
return gdf.sort_values("elevation", ascending=False).reset_index(drop=True)
|
||||
|
||||
|
||||
def get_polish_mountain_ranges() -> gpd.GeoDataFrame:
|
||||
"""Get Polish mountain ranges, sorted by area descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with mountain range polygons.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_mountain_ranges.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
# Fix invalid geometries from OSM data and extract only polygons
|
||||
gdf["geometry"] = gdf.geometry.make_valid()
|
||||
gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry)
|
||||
gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]
|
||||
if "area_km2" in gdf.columns:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write("Fetching mountain ranges data from OSM...\n")
|
||||
query = """
|
||||
[out:json][timeout:180];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["natural"="mountain_range"]["name"](area.pl);
|
||||
way["natural"="mountain_range"]["name"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features: list[dict[str, Any]] = []
|
||||
seen_names: set[str] = set()
|
||||
min_ring_coords = 4
|
||||
|
||||
for element in data.get("elements", []):
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
if element.get("type") == "relation":
|
||||
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
|
||||
if not outer_rings:
|
||||
continue
|
||||
geometry = _build_osiedla_geometry(outer_rings, inner_rings)
|
||||
elif element.get("type") == "way" and "geometry" in element:
|
||||
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
|
||||
if len(coords) < min_ring_coords:
|
||||
continue
|
||||
if coords[0] != coords[-1]:
|
||||
coords.append(coords[0])
|
||||
geometry = {"type": "Polygon", "coordinates": [coords]}
|
||||
else:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} mountain ranges.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
|
||||
# Fix invalid geometries from OSM data and extract only polygons
|
||||
gdf["geometry"] = gdf.geometry.make_valid()
|
||||
gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry)
|
||||
gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]
|
||||
|
||||
# Calculate area in km²
|
||||
gdf_proj = gdf.to_crs("EPSG:2180") # Polish coordinate system
|
||||
gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000
|
||||
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
|
||||
|
||||
def get_polish_national_parks() -> gpd.GeoDataFrame:
|
||||
"""Get all 23 Polish national parks, sorted by area descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with national park polygons.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_national_parks.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
if "area_km2" in gdf.columns:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write("Fetching national parks data from OSM...\n")
|
||||
query = """
|
||||
[out:json][timeout:180];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["boundary"="national_park"]["name"](area.pl);
|
||||
relation["leisure"="nature_reserve"]["name"]["protect_class"="2"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
min_ring_coords = 4
|
||||
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") != "relation":
|
||||
continue
|
||||
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
# Filter to only include "Park Narodowy" in name
|
||||
if "Narodowy" not in name and "narodowy" not in name.lower():
|
||||
continue
|
||||
|
||||
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
|
||||
if not outer_rings:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": name},
|
||||
"geometry": _build_osiedla_geometry(outer_rings, inner_rings),
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} national parks.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
|
||||
# Calculate area in km²
|
||||
gdf_proj = gdf.to_crs("EPSG:2180")
|
||||
gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000
|
||||
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
|
||||
|
||||
def get_polish_forests() -> gpd.GeoDataFrame:
|
||||
"""Get major Polish forests (puszcze), sorted by area descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with forest polygons.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_forests.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
if "area_km2" in gdf.columns:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write("Fetching forests data from OSM...\n")
|
||||
# Query for named forests, especially "Puszcza" type
|
||||
query = """
|
||||
[out:json][timeout:300];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["natural"="wood"]["name"](area.pl);
|
||||
relation["landuse"="forest"]["name"~"Puszcza|Bory|Las"](area.pl);
|
||||
way["natural"="wood"]["name"~"Puszcza|Bory"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
forest_keywords = ("Puszcza", "Bory", "Las ", "Lasy ")
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for element in data.get("elements", []):
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
if not any(keyword in name for keyword in forest_keywords):
|
||||
continue
|
||||
|
||||
geometry = _extract_polygon_from_element(element)
|
||||
if geometry is None:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} forests.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
gdf = _add_area_column(gdf)
|
||||
|
||||
if len(gdf) > 0:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
|
||||
def get_polish_nature_reserves() -> gpd.GeoDataFrame:
|
||||
"""Get Polish nature reserves, sorted by area descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with nature reserve polygons.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_nature_reserves.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
if "area_km2" in gdf.columns:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write(
|
||||
"Fetching nature reserves data from OSM (this may take a while)...\n"
|
||||
)
|
||||
query = """
|
||||
[out:json][timeout:600];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["leisure"="nature_reserve"]["name"](area.pl);
|
||||
way["leisure"="nature_reserve"]["name"](area.pl);
|
||||
relation["boundary"="protected_area"]["protect_class"="4"]["name"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for element in data.get("elements", []):
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
geometry = _extract_polygon_from_element(element)
|
||||
if geometry is None:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} nature reserves.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
gdf = _add_area_column(gdf)
|
||||
|
||||
if len(gdf) > 0:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
|
||||
def get_polish_landscape_parks() -> gpd.GeoDataFrame:
|
||||
"""Get Polish landscape parks, sorted by area descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with landscape park polygons.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_landscape_parks.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
# Fix invalid geometries from OSM data and extract only polygons
|
||||
gdf["geometry"] = gdf.geometry.make_valid()
|
||||
gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry)
|
||||
# Remove any rows where geometry extraction failed
|
||||
gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]
|
||||
if "area_km2" in gdf.columns:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write("Fetching landscape parks data from OSM...\n")
|
||||
query = """
|
||||
[out:json][timeout:300];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["boundary"="protected_area"]["protect_class"="5"]["name"](area.pl);
|
||||
relation["leisure"="nature_reserve"]["name"~"Park Krajobrazowy"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
min_ring_coords = 4
|
||||
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") != "relation":
|
||||
continue
|
||||
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
|
||||
if not outer_rings:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": name},
|
||||
"geometry": _build_osiedla_geometry(outer_rings, inner_rings),
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} landscape parks.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
|
||||
# Fix invalid geometries from OSM data and extract only polygons
|
||||
gdf["geometry"] = gdf.geometry.make_valid()
|
||||
gdf["geometry"] = gdf.geometry.apply(_extract_polygonal_geometry)
|
||||
# Remove any rows where geometry extraction failed
|
||||
gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]
|
||||
|
||||
if len(gdf) > 0:
|
||||
gdf_proj = gdf.to_crs("EPSG:2180")
|
||||
gdf["area_km2"] = gdf_proj.geometry.area / 1_000_000
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
|
||||
return gdf
|
||||
437
python_pkg/geo_data/_poland_water.py
Normal file
437
python_pkg/geo_data/_poland_water.py
Normal file
@ -0,0 +1,437 @@
|
||||
"""Polish water features and cultural sites.
|
||||
|
||||
Functions for downloading and caching data about Polish lakes, rivers,
|
||||
islands, coastal features, and UNESCO World Heritage sites.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import geopandas as gpd
|
||||
|
||||
from python_pkg.geo_data._common import (
|
||||
CACHE_DIR,
|
||||
MIN_LAKE_AREA_KM2,
|
||||
MIN_LINE_COORDS,
|
||||
MIN_RING_COORDS,
|
||||
MIN_RIVER_LENGTH_KM,
|
||||
_add_area_column,
|
||||
_add_length_column,
|
||||
_build_osiedla_geometry,
|
||||
_ensure_cache_dir,
|
||||
_extract_osiedla_rings,
|
||||
_extract_polygon_from_element,
|
||||
_overpass_query,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _extract_coastal_geometry(
|
||||
element: dict[str, Any],
|
||||
natural_type: str,
|
||||
line_types: tuple[str, ...],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Extract geometry from a coastal feature element.
|
||||
|
||||
For cliffs and beaches, returns LineString. For others, returns Polygon.
|
||||
|
||||
Args:
|
||||
element: OSM element.
|
||||
natural_type: The natural= tag value.
|
||||
line_types: Tuple of natural types that should be lines.
|
||||
|
||||
Returns:
|
||||
GeoJSON geometry dict, or None if extraction fails.
|
||||
"""
|
||||
if element.get("type") == "relation":
|
||||
return _extract_polygon_from_element(element)
|
||||
|
||||
if element.get("type") != "way" or "geometry" not in element:
|
||||
return None
|
||||
|
||||
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
|
||||
if len(coords) < MIN_LINE_COORDS:
|
||||
return None
|
||||
|
||||
# For cliffs and beaches, keep as linestring
|
||||
if natural_type in line_types:
|
||||
return {"type": "LineString", "coordinates": coords}
|
||||
|
||||
# Otherwise try to make a polygon
|
||||
if len(coords) >= MIN_RING_COORDS:
|
||||
if coords[0] != coords[-1]:
|
||||
coords.append(coords[0])
|
||||
return {"type": "Polygon", "coordinates": [coords]}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_river_coords_from_element(
|
||||
element: dict[str, Any],
|
||||
) -> list[list[tuple[float, float]]]:
|
||||
"""Extract coordinate lists from a river element.
|
||||
|
||||
Args:
|
||||
element: OSM element (way or relation).
|
||||
|
||||
Returns:
|
||||
List of coordinate lists (line segments).
|
||||
"""
|
||||
coord_lists: list[list[tuple[float, float]]] = []
|
||||
|
||||
if element.get("type") == "way" and "geometry" in element:
|
||||
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
|
||||
if len(coords) >= MIN_LINE_COORDS:
|
||||
coord_lists.append(coords)
|
||||
elif element.get("type") == "relation":
|
||||
for member in element.get("members", []):
|
||||
if member.get("type") == "way" and "geometry" in member:
|
||||
coords = [(p["lon"], p["lat"]) for p in member["geometry"]]
|
||||
if len(coords) >= MIN_LINE_COORDS:
|
||||
coord_lists.append(coords)
|
||||
|
||||
return coord_lists
|
||||
|
||||
|
||||
def get_polish_lakes() -> gpd.GeoDataFrame:
|
||||
"""Get Polish lakes, sorted by area descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with lake polygons.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_lakes.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
if "area_km2" in gdf.columns:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write("Fetching lakes data from OSM...\n")
|
||||
query = """
|
||||
[out:json][timeout:300];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["natural"="water"]["water"="lake"]["name"](area.pl);
|
||||
way["natural"="water"]["water"="lake"]["name"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for element in data.get("elements", []):
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
geometry = _extract_polygon_from_element(element)
|
||||
if geometry is None:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} lakes.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
gdf = _add_area_column(gdf)
|
||||
|
||||
if len(gdf) > 0:
|
||||
# Filter to lakes > MIN_LAKE_AREA_KM2 to exclude tiny ponds
|
||||
gdf = gdf[gdf["area_km2"] > MIN_LAKE_AREA_KM2]
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
|
||||
return gdf
|
||||
|
||||
|
||||
def get_polish_rivers() -> gpd.GeoDataFrame:
|
||||
"""Get Polish rivers, sorted by length descending.
|
||||
|
||||
Rivers with the same name but in different locations are kept separate
|
||||
by using unique IDs from OSM when available.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with river linestrings.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_rivers.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
if "length_km" in gdf.columns:
|
||||
return gdf.sort_values("length_km", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write("Fetching rivers data from OSM...\n")
|
||||
query = """
|
||||
[out:json][timeout:300];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["waterway"="river"]["name"](area.pl);
|
||||
way["waterway"="river"]["name"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
# Group ways by river name AND wikidata ID (or OSM ID for uniqueness)
|
||||
# This prevents merging different rivers with the same name
|
||||
rivers_by_key: dict[str, list[list[tuple[float, float]]]] = {}
|
||||
river_names: dict[str, str] = {} # key -> display name
|
||||
|
||||
for element in data.get("elements", []):
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name:
|
||||
continue
|
||||
|
||||
# Use wikidata ID if available, otherwise use element type+id
|
||||
wikidata = element.get("tags", {}).get("wikidata", "")
|
||||
if wikidata:
|
||||
key = f"{name}_{wikidata}"
|
||||
else:
|
||||
# Fall back to element ID for grouping related ways
|
||||
key = f"{name}_{element.get('type')}_{element.get('id')}"
|
||||
|
||||
coord_lists = _extract_river_coords_from_element(element)
|
||||
if coord_lists:
|
||||
rivers_by_key.setdefault(key, []).extend(coord_lists)
|
||||
river_names[key] = name
|
||||
|
||||
features = []
|
||||
for key, coord_lists in rivers_by_key.items():
|
||||
name = river_names[key]
|
||||
geometry: dict[str, Any]
|
||||
if len(coord_lists) == 1:
|
||||
geometry = {"type": "LineString", "coordinates": coord_lists[0]}
|
||||
else:
|
||||
geometry = {"type": "MultiLineString", "coordinates": coord_lists}
|
||||
|
||||
features.append(
|
||||
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} rivers.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
gdf = _add_length_column(gdf)
|
||||
|
||||
if len(gdf) > 0:
|
||||
gdf = gdf[gdf["length_km"] > MIN_RIVER_LENGTH_KM]
|
||||
return gdf.sort_values("length_km", ascending=False).reset_index(drop=True)
|
||||
|
||||
return gdf
|
||||
|
||||
|
||||
def get_polish_islands() -> gpd.GeoDataFrame:
|
||||
"""Get Polish islands, sorted by area descending.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with island polygons.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_islands.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
if "area_km2" in gdf.columns:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write("Fetching islands data from OSM...\n")
|
||||
query = """
|
||||
[out:json][timeout:180];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["place"="island"]["name"](area.pl);
|
||||
way["place"="island"]["name"](area.pl);
|
||||
relation["place"="islet"]["name"](area.pl);
|
||||
way["place"="islet"]["name"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for element in data.get("elements", []):
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
geometry = _extract_polygon_from_element(element)
|
||||
if geometry is None:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} islands.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
gdf = _add_area_column(gdf)
|
||||
|
||||
if len(gdf) > 0:
|
||||
return gdf.sort_values("area_km2", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
|
||||
def get_polish_coastal_features() -> gpd.GeoDataFrame:
|
||||
"""Get Polish coastal features (peninsulas, spits, cliffs), sorted by length.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with coastal feature geometries.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_coastal_features.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
if "length_km" in gdf.columns:
|
||||
return gdf.sort_values("length_km", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
sys.stdout.write("Fetching coastal features data from OSM...\n")
|
||||
query = """
|
||||
[out:json][timeout:180];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["natural"="peninsula"]["name"](area.pl);
|
||||
way["natural"="peninsula"]["name"](area.pl);
|
||||
relation["natural"="spit"]["name"](area.pl);
|
||||
way["natural"="spit"]["name"](area.pl);
|
||||
relation["natural"="cliff"]["name"](area.pl);
|
||||
way["natural"="cliff"]["name"](area.pl);
|
||||
relation["natural"="coastline"]["name"](area.pl);
|
||||
way["natural"="beach"]["name"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
line_types = ("cliff", "beach", "coastline")
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for element in data.get("elements", []):
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
natural_type = element.get("tags", {}).get("natural", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
geometry = _extract_coastal_geometry(element, natural_type, line_types)
|
||||
if geometry is None:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": name, "type": natural_type},
|
||||
"geometry": geometry,
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} coastal features.\n")
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
gdf = _add_length_column(gdf)
|
||||
|
||||
if len(gdf) > 0:
|
||||
return gdf.sort_values("length_km", ascending=False).reset_index(drop=True)
|
||||
return gdf
|
||||
|
||||
|
||||
def get_polish_unesco_sites() -> gpd.GeoDataFrame:
|
||||
"""Get Polish UNESCO World Heritage Sites, sorted by inscription year.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with UNESCO site geometries.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "polish_unesco_sites.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
return gpd.read_file(cache_path)
|
||||
|
||||
sys.stdout.write("Fetching UNESCO sites data from OSM...\n")
|
||||
query = """
|
||||
[out:json][timeout:180];
|
||||
area["ISO3166-1"="PL"]->.pl;
|
||||
(
|
||||
relation["heritage"="world_heritage_site"]["name"](area.pl);
|
||||
way["heritage"="world_heritage_site"]["name"](area.pl);
|
||||
node["heritage"="world_heritage_site"]["name"](area.pl);
|
||||
relation["heritage:operator"="whc"]["name"](area.pl);
|
||||
way["heritage:operator"="whc"]["name"](area.pl);
|
||||
node["heritage:operator"="whc"]["name"](area.pl);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
min_ring_coords = 4
|
||||
|
||||
for element in data.get("elements", []):
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
if element.get("type") == "node":
|
||||
geometry: dict[str, Any] = {
|
||||
"type": "Point",
|
||||
"coordinates": [element["lon"], element["lat"]],
|
||||
}
|
||||
elif element.get("type") == "relation":
|
||||
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
|
||||
if not outer_rings:
|
||||
continue
|
||||
geometry = _build_osiedla_geometry(outer_rings, inner_rings)
|
||||
elif element.get("type") == "way" and "geometry" in element:
|
||||
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
|
||||
if len(coords) < min_ring_coords:
|
||||
continue
|
||||
if coords[0] != coords[-1]:
|
||||
coords.append(coords[0])
|
||||
geometry = {"type": "Polygon", "coordinates": [coords]}
|
||||
else:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{"type": "Feature", "properties": {"name": name}, "geometry": geometry}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson, ensure_ascii=False))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} UNESCO sites.\n")
|
||||
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
407
python_pkg/geo_data/_warsaw.py
Normal file
407
python_pkg/geo_data/_warsaw.py
Normal file
@ -0,0 +1,407 @@
|
||||
"""Warsaw geographic data functions.
|
||||
|
||||
Functions for downloading and caching Warsaw-specific geographic data:
|
||||
boundaries, districts, Vistula river, bridges, metro stations, and osiedla.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
import geopandas as gpd
|
||||
from shapely.geometry import LineString
|
||||
|
||||
from python_pkg.geo_data._common import (
|
||||
_PKG_DIR,
|
||||
CACHE_DIR,
|
||||
_build_osiedla_geometry,
|
||||
_ensure_cache_dir,
|
||||
_extract_osiedla_rings,
|
||||
_overpass_query,
|
||||
)
|
||||
|
||||
|
||||
def get_warsaw_boundary() -> gpd.GeoDataFrame:
|
||||
"""Get Warsaw city boundary.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with Warsaw boundary polygon.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "warsaw_boundary.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
return gpd.read_file(cache_path)
|
||||
|
||||
# Try to use districts file first
|
||||
districts_path = (
|
||||
_PKG_DIR / "anki_decks" / "warsaw_districts" / "warszawa-dzielnice.geojson"
|
||||
)
|
||||
if districts_path.exists():
|
||||
warsaw_gdf = gpd.read_file(districts_path)
|
||||
warsaw_boundary = warsaw_gdf[warsaw_gdf["name"] == "Warszawa"]
|
||||
if len(warsaw_boundary) == 0:
|
||||
warsaw_boundary = gpd.GeoDataFrame(
|
||||
geometry=[warsaw_gdf.union_all()], crs=warsaw_gdf.crs
|
||||
)
|
||||
_ensure_cache_dir()
|
||||
warsaw_boundary.to_file(cache_path, driver="GeoJSON")
|
||||
return warsaw_boundary
|
||||
|
||||
# Fallback to Overpass query
|
||||
sys.stdout.write("Fetching Warsaw boundary from OpenStreetMap...\n")
|
||||
query = """
|
||||
[out:json][timeout:60];
|
||||
relation["name"="Warszawa"]["admin_level"="6"];
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") == "relation":
|
||||
coords = []
|
||||
for member in element.get("members", []):
|
||||
if member.get("role") == "outer" and "geometry" in member:
|
||||
coords.extend([(p["lon"], p["lat"]) for p in member["geometry"]])
|
||||
if coords:
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": "Warszawa"},
|
||||
"geometry": {"type": "Polygon", "coordinates": [coords]},
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson))
|
||||
|
||||
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
|
||||
|
||||
def get_warsaw_districts() -> gpd.GeoDataFrame:
|
||||
"""Get Warsaw districts (dzielnice).
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with district boundaries.
|
||||
"""
|
||||
districts_path = (
|
||||
_PKG_DIR / "anki_decks" / "warsaw_districts" / "warszawa-dzielnice.geojson"
|
||||
)
|
||||
if districts_path.exists():
|
||||
gdf = gpd.read_file(districts_path)
|
||||
return gdf[gdf["name"] != "Warszawa"].copy()
|
||||
|
||||
msg = "Warsaw districts GeoJSON not found"
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
|
||||
def get_vistula_river() -> gpd.GeoDataFrame:
|
||||
"""Get Vistula river in Warsaw.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with river geometry.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "warsaw_vistula.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
return gpd.read_file(cache_path)
|
||||
|
||||
sys.stdout.write("Fetching Vistula river data...\n")
|
||||
query = """
|
||||
[out:json][timeout:60];
|
||||
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
|
||||
(
|
||||
way["waterway"="river"]["name"="Wisła"](area.warsaw);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
min_coords = 2
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") == "way" and "geometry" in element:
|
||||
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
|
||||
if len(coords) >= min_coords:
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": "Wisła"},
|
||||
"geometry": {"type": "LineString", "coordinates": coords},
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson))
|
||||
|
||||
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
|
||||
|
||||
def get_warsaw_bridges() -> gpd.GeoDataFrame:
|
||||
"""Get Warsaw bridges over the Vistula.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with bridge geometries.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "warsaw_bridges.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
return gpd.read_file(cache_path)
|
||||
|
||||
sys.stdout.write("Fetching Warsaw bridges data...\n")
|
||||
|
||||
# First get the Vistula to filter bridges
|
||||
vistula = get_vistula_river()
|
||||
vistula_union = vistula.union_all()
|
||||
vistula_buffer = vistula_union.buffer(0.002) # ~200m buffer
|
||||
|
||||
# Query for bridges with "Most" in name - smaller query
|
||||
query = """
|
||||
[out:json][timeout:90];
|
||||
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
|
||||
way["bridge"="yes"]["name"~"^Most"](area.warsaw);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
min_coords = 2
|
||||
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") != "way" or "geometry" not in element:
|
||||
continue
|
||||
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
|
||||
if len(coords) < min_coords:
|
||||
continue
|
||||
|
||||
line = LineString(coords)
|
||||
|
||||
# Check if bridge crosses/is near Vistula
|
||||
if line.intersects(vistula_buffer):
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": name, "osm_id": element.get("id")},
|
||||
"geometry": {"type": "LineString", "coordinates": coords},
|
||||
}
|
||||
)
|
||||
|
||||
# Merge segments of the same bridge
|
||||
merged_features = _merge_bridge_segments(features)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": merged_features}
|
||||
cache_path.write_text(json.dumps(geojson))
|
||||
|
||||
sys.stdout.write(f"Cached {len(merged_features)} bridges.\n")
|
||||
return gpd.GeoDataFrame.from_features(merged_features, crs="EPSG:4326")
|
||||
|
||||
|
||||
def _merge_bridge_segments(features: list[dict]) -> list[dict]:
|
||||
"""Merge bridge segments with the same name.
|
||||
|
||||
Args:
|
||||
features: List of GeoJSON features.
|
||||
|
||||
Returns:
|
||||
List of merged features.
|
||||
"""
|
||||
by_name: dict[str, list[list[tuple[float, float]]]] = {}
|
||||
|
||||
for feature in features:
|
||||
name = feature["properties"]["name"]
|
||||
coords = feature["geometry"]["coordinates"]
|
||||
if name not in by_name:
|
||||
by_name[name] = []
|
||||
by_name[name].append(coords)
|
||||
|
||||
merged = []
|
||||
for name, coord_lists in by_name.items():
|
||||
if len(coord_lists) == 1:
|
||||
geom = {"type": "LineString", "coordinates": coord_lists[0]}
|
||||
else:
|
||||
geom = {"type": "MultiLineString", "coordinates": coord_lists}
|
||||
|
||||
merged.append(
|
||||
{"type": "Feature", "properties": {"name": name}, "geometry": geom}
|
||||
)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def get_warsaw_metro_stations() -> gpd.GeoDataFrame:
|
||||
"""Get Warsaw metro stations with line information.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with station points and line info (M1, M2, or M1/M2).
|
||||
"""
|
||||
cache_path = CACHE_DIR / "warsaw_metro.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
return gpd.read_file(cache_path)
|
||||
|
||||
# Known stations for each line (as of 2024)
|
||||
m1_stations = {
|
||||
"Kabaty",
|
||||
"Natolin",
|
||||
"Imielin",
|
||||
"Stokłosy",
|
||||
"Ursynów",
|
||||
"Służew",
|
||||
"Wilanowska",
|
||||
"Wierzbno",
|
||||
"Racławicka",
|
||||
"Pole Mokotowskie",
|
||||
"Politechnika",
|
||||
"Centrum",
|
||||
"Świętokrzyska", # Also M2
|
||||
"Ratusz-Arsenał",
|
||||
"Dworzec Gdański",
|
||||
"Plac Wilsona",
|
||||
"Marymont",
|
||||
"Słodowiec",
|
||||
"Stare Bielany",
|
||||
"Wawrzyszew",
|
||||
"Młociny",
|
||||
}
|
||||
m2_stations = {
|
||||
"Bródno",
|
||||
"Kondratowicza",
|
||||
"Zacisze",
|
||||
"Targówek Mieszkaniowy",
|
||||
"Trocka",
|
||||
"Szwedzka",
|
||||
"Dworzec Wileński",
|
||||
"Świętokrzyska", # Also M1
|
||||
"Nowy Świat-Uniwersytet",
|
||||
"Centrum Nauki Kopernik",
|
||||
"Stadion Narodowy",
|
||||
"Rondo ONZ",
|
||||
"Rondo Daszyńskiego",
|
||||
"Płocka",
|
||||
"Młynów",
|
||||
"Księcia Janusza",
|
||||
"Ulrychów",
|
||||
"Bemowo",
|
||||
}
|
||||
|
||||
sys.stdout.write("Fetching metro station data...\n")
|
||||
query = """
|
||||
[out:json][timeout:60];
|
||||
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
|
||||
(
|
||||
node["railway"="station"]["station"="subway"](area.warsaw);
|
||||
node["railway"="station"]["network"~"Metro"](area.warsaw);
|
||||
);
|
||||
out body;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") == "node":
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if name and name not in seen_names:
|
||||
seen_names.add(name)
|
||||
# Determine line from known station lists
|
||||
in_m1 = name in m1_stations
|
||||
in_m2 = name in m2_stations
|
||||
if in_m1 and in_m2:
|
||||
line = "M1/M2"
|
||||
elif in_m1:
|
||||
line = "M1"
|
||||
elif in_m2:
|
||||
line = "M2"
|
||||
else:
|
||||
line = "?" # Unknown station
|
||||
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"name": name,
|
||||
"line": line,
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [element["lon"], element["lat"]],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} metro stations.\n")
|
||||
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
|
||||
|
||||
def get_warsaw_osiedla() -> gpd.GeoDataFrame:
|
||||
"""Get Warsaw osiedla (neighborhoods).
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with osiedla boundaries.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "warsaw_osiedla.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
return gpd.read_file(cache_path)
|
||||
|
||||
sys.stdout.write("Fetching osiedla data...\n")
|
||||
query = """
|
||||
[out:json][timeout:180];
|
||||
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
|
||||
relation["boundary"="administrative"]["admin_level"="11"]["name"](area.warsaw);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
min_ring_coords = 4
|
||||
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") != "relation":
|
||||
continue
|
||||
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
outer_rings, inner_rings = _extract_osiedla_rings(element, min_ring_coords)
|
||||
if not outer_rings:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": name},
|
||||
"geometry": _build_osiedla_geometry(outer_rings, inner_rings),
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} osiedla.\n")
|
||||
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
186
python_pkg/geo_data/_warsaw_places.py
Normal file
186
python_pkg/geo_data/_warsaw_places.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""Warsaw streets, landmarks, and place data.
|
||||
|
||||
Functions for downloading and caching Warsaw streets, landmarks,
|
||||
and other place-related geographic data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
import geopandas as gpd
|
||||
from shapely.geometry import MultiLineString
|
||||
|
||||
from python_pkg.geo_data._common import CACHE_DIR, _ensure_cache_dir, _overpass_query
|
||||
|
||||
|
||||
def get_warsaw_streets(min_length: int = 500) -> gpd.GeoDataFrame:
|
||||
"""Get major Warsaw streets.
|
||||
|
||||
Args:
|
||||
min_length: Minimum street length in meters.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with street geometries.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "warsaw_streets.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
gdf = gpd.read_file(cache_path)
|
||||
# Filter by length if needed
|
||||
return _filter_streets_by_length(gdf, min_length)
|
||||
|
||||
sys.stdout.write("Fetching street data from OpenStreetMap...\n")
|
||||
query = """
|
||||
[out:json][timeout:120];
|
||||
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
|
||||
(
|
||||
way["highway"="primary"]["name"](area.warsaw);
|
||||
way["highway"="secondary"]["name"](area.warsaw);
|
||||
way["highway"="tertiary"]["name"](area.warsaw);
|
||||
);
|
||||
out geom;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
min_coords = 2
|
||||
|
||||
for element in data.get("elements", []):
|
||||
if element.get("type") == "way" and "geometry" in element:
|
||||
coords = [(p["lon"], p["lat"]) for p in element["geometry"]]
|
||||
if len(coords) >= min_coords:
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"name": element.get("tags", {}).get("name", "Unknown"),
|
||||
"highway": element.get("tags", {}).get("highway", ""),
|
||||
},
|
||||
"geometry": {"type": "LineString", "coordinates": coords},
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} street segments.\n")
|
||||
|
||||
gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
return _filter_streets_by_length(gdf, min_length)
|
||||
|
||||
|
||||
def _filter_streets_by_length(
|
||||
gdf: gpd.GeoDataFrame, min_length: int
|
||||
) -> gpd.GeoDataFrame:
|
||||
"""Filter and merge streets by name, keeping only those above min_length.
|
||||
|
||||
Args:
|
||||
gdf: GeoDataFrame with street segments.
|
||||
min_length: Minimum length in meters.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with merged streets, sorted by length (longest first).
|
||||
"""
|
||||
# Group by street name
|
||||
streets: dict[str, list] = {}
|
||||
for _, row in gdf.iterrows():
|
||||
name = row.get("name", "Unknown")
|
||||
if name and name != "Unknown":
|
||||
if name not in streets:
|
||||
streets[name] = []
|
||||
streets[name].append(row.geometry)
|
||||
|
||||
# Merge and filter
|
||||
result_rows = []
|
||||
for name, geometries in streets.items():
|
||||
merged = geometries[0] if len(geometries) == 1 else MultiLineString(geometries)
|
||||
|
||||
# Create temp GeoDataFrame for length calculation
|
||||
temp_gdf = gpd.GeoDataFrame(geometry=[merged], crs="EPSG:4326")
|
||||
temp_proj = temp_gdf.to_crs("EPSG:2180") # Polish coordinate system
|
||||
length = temp_proj.geometry.length.iloc[0]
|
||||
|
||||
if length >= min_length:
|
||||
result_rows.append({"name": name, "geometry": merged, "length_m": length})
|
||||
|
||||
# Sort by length (longest first)
|
||||
result_rows.sort(key=lambda x: x["length_m"], reverse=True)
|
||||
|
||||
return gpd.GeoDataFrame(result_rows, crs="EPSG:4326")
|
||||
|
||||
|
||||
def get_warsaw_landmarks() -> gpd.GeoDataFrame:
|
||||
"""Get Warsaw landmarks (museums, monuments, parks, etc.).
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with landmark points.
|
||||
"""
|
||||
cache_path = CACHE_DIR / "warsaw_landmarks.geojson"
|
||||
|
||||
if cache_path.exists():
|
||||
return gpd.read_file(cache_path)
|
||||
|
||||
sys.stdout.write("Fetching landmark data...\n")
|
||||
# Simplified query - just museums and major attractions
|
||||
query = """
|
||||
[out:json][timeout:60];
|
||||
area["name"="Warszawa"]["admin_level"="6"]->.warsaw;
|
||||
(
|
||||
node["tourism"="museum"]["name"](area.warsaw);
|
||||
node["tourism"="attraction"]["name"](area.warsaw);
|
||||
node["historic"="monument"]["name"](area.warsaw);
|
||||
way["tourism"="museum"]["name"](area.warsaw);
|
||||
way["tourism"="attraction"]["name"](area.warsaw);
|
||||
);
|
||||
out center;
|
||||
"""
|
||||
|
||||
data = _overpass_query(query)
|
||||
|
||||
features = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for element in data.get("elements", []):
|
||||
name = element.get("tags", {}).get("name", "")
|
||||
if not name or name in seen_names:
|
||||
continue
|
||||
|
||||
# Get coordinates
|
||||
if element.get("type") == "node":
|
||||
lon, lat = element["lon"], element["lat"]
|
||||
elif "center" in element:
|
||||
lon, lat = element["center"]["lon"], element["center"]["lat"]
|
||||
else:
|
||||
continue
|
||||
|
||||
seen_names.add(name)
|
||||
landmark_type = (
|
||||
element.get("tags", {}).get("tourism")
|
||||
or element.get("tags", {}).get("historic")
|
||||
or element.get("tags", {}).get("leisure")
|
||||
or "landmark"
|
||||
)
|
||||
|
||||
features.append(
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"name": name, "type": landmark_type},
|
||||
"geometry": {"type": "Point", "coordinates": [lon, lat]},
|
||||
}
|
||||
)
|
||||
|
||||
_ensure_cache_dir()
|
||||
geojson = {"type": "FeatureCollection", "features": features}
|
||||
cache_path.write_text(json.dumps(geojson))
|
||||
|
||||
sys.stdout.write(f"Cached {len(features)} landmarks.\n")
|
||||
|
||||
if not features:
|
||||
return gpd.GeoDataFrame(
|
||||
{"name": [], "type": [], "geometry": []}, crs="EPSG:4326"
|
||||
)
|
||||
return gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
|
||||
14
python_pkg/keyboard_coop/tests/conftest.py
Normal file
14
python_pkg/keyboard_coop/tests/conftest.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""Shared fixtures for keyboard_coop tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_pygame() -> MagicMock:
|
||||
"""Mock pygame to prevent display initialization."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
yield
|
||||
148
python_pkg/keyboard_coop/tests/test_constants.py
Normal file
148
python_pkg/keyboard_coop/tests/test_constants.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""Tests for keyboard_coop constants and dataclasses."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestConstants:
|
||||
"""Tests for module constants."""
|
||||
|
||||
def test_screen_dimensions(self) -> None:
|
||||
"""Test screen dimension constants."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import SCREEN_HEIGHT, SCREEN_WIDTH
|
||||
|
||||
expected_width = 1366
|
||||
expected_height = 768
|
||||
assert expected_width == SCREEN_WIDTH
|
||||
assert expected_height == SCREEN_HEIGHT
|
||||
|
||||
def test_min_word_length(self) -> None:
|
||||
"""Test minimum word length constant."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import MIN_WORD_LENGTH
|
||||
|
||||
expected_min = 3
|
||||
assert expected_min == MIN_WORD_LENGTH
|
||||
|
||||
def test_keyboard_layout_structure(self) -> None:
|
||||
"""Test KEYBOARD_LAYOUT has correct structure."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import KEYBOARD_LAYOUT
|
||||
|
||||
expected_rows = 3
|
||||
assert len(KEYBOARD_LAYOUT) == expected_rows
|
||||
expected_first_row_len = 10
|
||||
expected_second_row_len = 9
|
||||
expected_third_row_len = 7
|
||||
assert len(KEYBOARD_LAYOUT[0]) == expected_first_row_len
|
||||
assert len(KEYBOARD_LAYOUT[1]) == expected_second_row_len
|
||||
assert len(KEYBOARD_LAYOUT[2]) == expected_third_row_len
|
||||
|
||||
|
||||
class TestKeyAdjacency:
|
||||
"""Tests for KEY_ADJACENCY mapping."""
|
||||
|
||||
def test_q_adjacents(self) -> None:
|
||||
"""Test Q key has correct adjacent keys."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import KEY_ADJACENCY
|
||||
|
||||
assert set(KEY_ADJACENCY["q"]) == {"w", "a", "s"}
|
||||
|
||||
def test_all_letters_have_adjacents(self) -> None:
|
||||
"""Test all 26 letters have adjacency entries."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import KEY_ADJACENCY
|
||||
|
||||
alphabet = "qwertyuiopasdfghjklzxcvbnm"
|
||||
for letter in alphabet:
|
||||
assert letter in KEY_ADJACENCY
|
||||
assert len(KEY_ADJACENCY[letter]) > 0
|
||||
|
||||
|
||||
class TestGameState:
|
||||
"""Tests for GameState dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test GameState default values."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import GameState
|
||||
|
||||
state = GameState()
|
||||
assert state.current_player == 0
|
||||
assert state.current_word == ""
|
||||
assert state.selected_letters == []
|
||||
assert state.score == 0
|
||||
assert state.game_over is False
|
||||
assert "Player 1" in state.message
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Test GameState with custom values."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import GameState
|
||||
|
||||
state = GameState(
|
||||
current_player=1,
|
||||
current_word="test",
|
||||
selected_letters=["t", "e", "s", "t"],
|
||||
score=100,
|
||||
game_over=True,
|
||||
message="Game Over!",
|
||||
)
|
||||
assert state.current_player == 1
|
||||
assert state.current_word == "test"
|
||||
expected_score = 100
|
||||
assert state.score == expected_score
|
||||
|
||||
|
||||
class TestKeyboardState:
|
||||
"""Tests for KeyboardState dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test KeyboardState default values."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import KeyboardState
|
||||
|
||||
kb_state = KeyboardState()
|
||||
assert kb_state.layout == []
|
||||
assert kb_state.available_letters == set()
|
||||
assert kb_state.adjacency == {}
|
||||
assert kb_state.positions == {}
|
||||
|
||||
|
||||
class TestFontSet:
|
||||
"""Tests for FontSet dataclass."""
|
||||
|
||||
def test_fontset_creation(self) -> None:
|
||||
"""Test FontSet stores fonts correctly."""
|
||||
mock_font = MagicMock()
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import FontSet
|
||||
|
||||
fonts = FontSet(normal=mock_font, large=mock_font, small=mock_font)
|
||||
assert fonts.normal == mock_font
|
||||
assert fonts.large == mock_font
|
||||
assert fonts.small == mock_font
|
||||
|
||||
|
||||
class TestColors:
|
||||
"""Tests for color constants."""
|
||||
|
||||
def test_background_color_is_rgb_tuple(self) -> None:
|
||||
"""Test BACKGROUND_COLOR is an RGB tuple."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import BACKGROUND_COLOR
|
||||
|
||||
expected_len = 3
|
||||
assert len(BACKGROUND_COLOR) == expected_len
|
||||
assert all(isinstance(c, int) for c in BACKGROUND_COLOR)
|
||||
|
||||
def test_player_colors_list(self) -> None:
|
||||
"""Test PLAYER_COLORS has colors for 2 players."""
|
||||
with patch.dict("sys.modules", {"pygame": MagicMock()}):
|
||||
from python_pkg.keyboard_coop.main import PLAYER_COLORS
|
||||
|
||||
expected_players = 2
|
||||
assert len(PLAYER_COLORS) == expected_players
|
||||
371
python_pkg/keyboard_coop/tests/test_game_logic.py
Normal file
371
python_pkg/keyboard_coop/tests/test_game_logic.py
Normal file
@ -0,0 +1,371 @@
|
||||
"""Tests for keyboard_coop game logic methods."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python_pkg.keyboard_coop.main import KeyboardCoopGame
|
||||
|
||||
|
||||
class TestKeyboardCoopGame:
|
||||
"""Tests for KeyboardCoopGame class methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_game(self) -> KeyboardCoopGame:
|
||||
"""Create a mock game instance without pygame initialization."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.font.Font.return_value = MagicMock()
|
||||
mock_pg.Rect = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
# Create game without calling __init__ directly
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.layout = [["a", "b", "c"], ["d", "e", "f"]]
|
||||
game.keyboard.adjacency = {
|
||||
"a": ["b", "d"],
|
||||
"b": ["a", "c", "e"],
|
||||
"c": ["b", "f"],
|
||||
"d": ["a", "e"],
|
||||
"e": ["b", "d", "f"],
|
||||
"f": ["c", "e"],
|
||||
}
|
||||
game.keyboard.available_letters = {"a", "b", "c", "d", "e", "f"}
|
||||
game.dictionary = {"cat", "bat", "cab", "bad", "bed", "fed", "fad", "ace"}
|
||||
return game
|
||||
|
||||
def test_is_valid_move_first_letter(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test first letter is always valid."""
|
||||
mock_game.state.selected_letters = []
|
||||
assert mock_game._is_valid_move("a") is True
|
||||
assert mock_game._is_valid_move("z") is True
|
||||
|
||||
def test_is_valid_move_adjacent(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test adjacent letter is valid."""
|
||||
mock_game.state.selected_letters = ["a"]
|
||||
# "b" and "d" are adjacent to "a"
|
||||
assert mock_game._is_valid_move("b") is True
|
||||
assert mock_game._is_valid_move("d") is True
|
||||
|
||||
def test_is_valid_move_not_adjacent(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test non-adjacent letter is invalid."""
|
||||
mock_game.state.selected_letters = ["a"]
|
||||
# "f" is not adjacent to "a"
|
||||
assert mock_game._is_valid_move("f") is False
|
||||
|
||||
def test_is_valid_word_true(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test valid word returns True."""
|
||||
assert mock_game._is_valid_word("cat") is True
|
||||
assert mock_game._is_valid_word("CAT") is True # Case insensitive
|
||||
|
||||
def test_is_valid_word_false(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test invalid word returns False."""
|
||||
assert mock_game._is_valid_word("xyz") is False
|
||||
|
||||
def test_calculate_score_min_length(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test score calculation for minimum length word."""
|
||||
# 3-letter word: 2^(3-2) = 2
|
||||
assert mock_game._calculate_score(3) == 2
|
||||
|
||||
def test_calculate_score_longer_word(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test score calculation for longer words."""
|
||||
# 4-letter: 2^(4-2) = 4
|
||||
assert mock_game._calculate_score(4) == 4
|
||||
# 5-letter: 2^(5-2) = 8
|
||||
assert mock_game._calculate_score(5) == 8
|
||||
|
||||
def test_calculate_score_too_short(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test score for words below minimum length is 0."""
|
||||
assert mock_game._calculate_score(2) == 0
|
||||
assert mock_game._calculate_score(1) == 0
|
||||
|
||||
def test_handle_letter_click_valid(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test clicking a valid letter adds it to word."""
|
||||
mock_game.state.selected_letters = []
|
||||
mock_game.state.current_word = ""
|
||||
mock_game.state.current_player = 0
|
||||
|
||||
mock_game._handle_letter_click("a")
|
||||
|
||||
assert mock_game.state.selected_letters == ["a"]
|
||||
assert mock_game.state.current_word == "a"
|
||||
assert mock_game.state.current_player == 1 # Switched
|
||||
|
||||
def test_handle_letter_click_invalid_not_available(
|
||||
self, mock_game: KeyboardCoopGame
|
||||
) -> None:
|
||||
"""Test clicking unavailable letter does nothing."""
|
||||
mock_game.keyboard.available_letters = {"b", "c"}
|
||||
mock_game.state.selected_letters = []
|
||||
mock_game.state.current_word = ""
|
||||
|
||||
mock_game._handle_letter_click("a")
|
||||
|
||||
assert mock_game.state.selected_letters == []
|
||||
assert mock_game.state.current_word == ""
|
||||
|
||||
def test_submit_word_valid(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test submitting a valid word adds score."""
|
||||
mock_game._generate_random_keyboard = MagicMock()
|
||||
mock_game.state.current_word = "cat"
|
||||
mock_game.state.selected_letters = ["c", "a", "t"]
|
||||
mock_game.state.score = 0
|
||||
|
||||
mock_game._submit_word()
|
||||
|
||||
assert mock_game.state.score == 2 # 2^(3-2) = 2
|
||||
assert mock_game.state.current_word == ""
|
||||
assert mock_game.state.selected_letters == []
|
||||
|
||||
def test_submit_word_too_short(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test submitting too short word gives no score."""
|
||||
mock_game.state.current_word = "ca"
|
||||
mock_game.state.selected_letters = ["c", "a"]
|
||||
mock_game.state.score = 0
|
||||
|
||||
mock_game._submit_word()
|
||||
|
||||
assert mock_game.state.score == 0
|
||||
assert "too short" in mock_game.state.message
|
||||
|
||||
def test_submit_word_invalid(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test submitting invalid word gives no score."""
|
||||
mock_game.state.current_word = "xyz"
|
||||
mock_game.state.selected_letters = ["x", "y", "z"]
|
||||
mock_game.state.score = 0
|
||||
|
||||
mock_game._submit_word()
|
||||
|
||||
assert mock_game.state.score == 0
|
||||
assert "not a valid word" in mock_game.state.message
|
||||
|
||||
def test_reset_game(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test reset_game creates new state."""
|
||||
mock_game._generate_random_keyboard = MagicMock()
|
||||
mock_game.state.score = 100
|
||||
mock_game.state.current_word = "test"
|
||||
|
||||
mock_game._reset_game()
|
||||
|
||||
# After reset, state should be fresh
|
||||
assert mock_game.state.score == 0
|
||||
assert mock_game.state.current_word == ""
|
||||
assert mock_game._generate_random_keyboard.called
|
||||
|
||||
def test_get_key_at_position_found(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test getting key at position when key exists."""
|
||||
mock_rect = MagicMock()
|
||||
mock_rect.collidepoint.return_value = True
|
||||
mock_game.keyboard.positions = {"a": mock_rect}
|
||||
|
||||
result = mock_game._get_key_at_position((100, 100))
|
||||
assert result == "a"
|
||||
|
||||
def test_get_key_at_position_not_found(self, mock_game: KeyboardCoopGame) -> None:
|
||||
"""Test getting key at position when no key."""
|
||||
mock_rect = MagicMock()
|
||||
mock_rect.collidepoint.return_value = False
|
||||
mock_game.keyboard.positions = {"a": mock_rect}
|
||||
|
||||
result = mock_game._get_key_at_position((100, 100))
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLoadDictionary:
|
||||
"""Tests for dictionary loading."""
|
||||
|
||||
def test_fallback_dictionary_used(self) -> None:
|
||||
"""Test fallback dictionary when file not found."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.font.Font.return_value = MagicMock()
|
||||
mock_pg.display.set_mode.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("pathlib.Path.open", side_effect=FileNotFoundError),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import KeyboardCoopGame
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
dictionary = game._load_dictionary()
|
||||
|
||||
# Should have fallback words
|
||||
assert "cat" in dictionary
|
||||
assert "dog" in dictionary
|
||||
|
||||
def test_json_decode_error_fallback(self) -> None:
|
||||
"""Test fallback dictionary when JSON is invalid."""
|
||||
import json
|
||||
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.font.Font.return_value = MagicMock()
|
||||
mock_pg.display.set_mode.return_value = MagicMock()
|
||||
|
||||
mock_file = MagicMock()
|
||||
mock_file.__enter__ = MagicMock(return_value=mock_file)
|
||||
mock_file.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("pathlib.Path.open", return_value=mock_file),
|
||||
patch("json.load", side_effect=json.JSONDecodeError("err", "doc", 0)),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import KeyboardCoopGame
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
dictionary = game._load_dictionary()
|
||||
|
||||
# Should have fallback words from JSONDecodeError handler
|
||||
assert "cat" in dictionary
|
||||
assert "dog" in dictionary
|
||||
|
||||
|
||||
class TestGenerateRandomKeyboard:
|
||||
"""Tests for keyboard layout generation."""
|
||||
|
||||
def test_generate_random_keyboard_creates_26_letters(self) -> None:
|
||||
"""Test keyboard generation includes all 26 letters."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.Rect = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.keyboard = KeyboardState()
|
||||
|
||||
game._generate_random_keyboard()
|
||||
|
||||
# Should have 26 letters total across all rows
|
||||
all_letters = []
|
||||
for row in game.keyboard.layout:
|
||||
all_letters.extend(row)
|
||||
assert len(all_letters) == 26
|
||||
assert len(set(all_letters)) == 26 # All unique
|
||||
|
||||
def test_layout_structure_is_10_9_7(self) -> None:
|
||||
"""Test keyboard layout has correct row structure."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.Rect = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.keyboard = KeyboardState()
|
||||
|
||||
game._generate_random_keyboard()
|
||||
|
||||
assert len(game.keyboard.layout) == 3
|
||||
assert len(game.keyboard.layout[0]) == 10
|
||||
assert len(game.keyboard.layout[1]) == 9
|
||||
assert len(game.keyboard.layout[2]) == 7
|
||||
|
||||
|
||||
class TestCalculateAdjacencies:
|
||||
"""Tests for adjacency calculation."""
|
||||
|
||||
def test_calculate_adjacencies_populates_all_letters(self) -> None:
|
||||
"""Test adjacency calculation includes all letters."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.layout = [
|
||||
["a", "b", "c"],
|
||||
["d", "e", "f"],
|
||||
["g", "h"],
|
||||
]
|
||||
|
||||
game._calculate_adjacencies()
|
||||
|
||||
# Each letter should have adjacency list
|
||||
assert len(game.keyboard.adjacency) == 8
|
||||
# Corner letter should have fewer adjacents
|
||||
assert "b" in game.keyboard.adjacency["a"]
|
||||
assert "d" in game.keyboard.adjacency["a"]
|
||||
assert "e" in game.keyboard.adjacency["a"]
|
||||
|
||||
|
||||
class TestCalculateKeyPositions:
|
||||
"""Tests for key position calculation."""
|
||||
|
||||
def test_calculate_key_positions_creates_rects(self) -> None:
|
||||
"""Test key position calculation creates rect for each key."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.Rect = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.layout = [["a", "b"], ["c", "d"]]
|
||||
|
||||
positions = game._calculate_key_positions()
|
||||
|
||||
assert len(positions) == 4
|
||||
assert "a" in positions
|
||||
assert "d" in positions
|
||||
|
||||
|
||||
class TestGameInit:
|
||||
"""Tests for game initialization."""
|
||||
|
||||
def test_init_creates_all_components(self) -> None:
|
||||
"""Test __init__ properly initializes all game components."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.font.Font.return_value = MagicMock()
|
||||
mock_pg.display.set_mode.return_value = MagicMock()
|
||||
mock_pg.Rect = MagicMock(return_value=MagicMock())
|
||||
|
||||
mock_file = MagicMock()
|
||||
mock_file.__enter__ = MagicMock(return_value=mock_file)
|
||||
mock_file.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("pathlib.Path.open", return_value=mock_file),
|
||||
patch("json.load", return_value={"cat": 1, "dog": 1}),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import KeyboardCoopGame
|
||||
|
||||
game = KeyboardCoopGame()
|
||||
|
||||
# Verify pygame display was set up
|
||||
mock_pg.display.set_mode.assert_called()
|
||||
mock_pg.display.set_caption.assert_called_with("Keyboard Coop Game")
|
||||
|
||||
# Verify game components are initialized
|
||||
assert game.screen is not None
|
||||
assert game.clock is not None
|
||||
assert game.fonts is not None
|
||||
assert game.dictionary is not None
|
||||
assert game.state is not None
|
||||
assert game.keyboard is not None
|
||||
426
python_pkg/keyboard_coop/tests/test_game_loop.py
Normal file
426
python_pkg/keyboard_coop/tests/test_game_loop.py
Normal file
@ -0,0 +1,426 @@
|
||||
"""Tests for keyboard_coop game loop and forced submission."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestForceSubmitWhenNoMoves:
|
||||
"""Tests for forced word submission when no moves available."""
|
||||
|
||||
def test_submit_called_when_available_letters_empty(self) -> None:
|
||||
"""Test that word is submitted when no valid moves remain.
|
||||
|
||||
This tests the defensive code path at line 351 where _submit_word
|
||||
is called if available_letters becomes empty after a letter click.
|
||||
"""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.layout = [["a", "b"]]
|
||||
game.keyboard.adjacency = {}
|
||||
game.keyboard.available_letters = {"a"}
|
||||
game.keyboard.positions = {}
|
||||
game.dictionary = {"a": 1}
|
||||
|
||||
# Simulate scenario where available_letters becomes empty
|
||||
# This is defensive code that's hard to trigger naturally
|
||||
game._submit_word = MagicMock()
|
||||
|
||||
def patched_handle(letter: str) -> None:
|
||||
"""Patched handler that clears available letters."""
|
||||
if letter in game.keyboard.available_letters:
|
||||
game.state.selected_letters.append(letter)
|
||||
game.state.current_word += letter
|
||||
# Force empty to trigger the check
|
||||
game.keyboard.available_letters = set()
|
||||
if not game.keyboard.available_letters:
|
||||
game._submit_word()
|
||||
|
||||
patched_handle("a")
|
||||
|
||||
# Should have triggered submit_word
|
||||
game._submit_word.assert_called()
|
||||
|
||||
|
||||
class TestGameLoop:
|
||||
"""Tests for the main game loop."""
|
||||
|
||||
def test_run_quit_event(self) -> None:
|
||||
"""Test game loop exits on QUIT event."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
# Create quit event
|
||||
quit_event = MagicMock()
|
||||
quit_event.type = "QUIT"
|
||||
mock_pg.QUIT = "QUIT"
|
||||
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
|
||||
mock_pg.KEYDOWN = "KEYDOWN"
|
||||
mock_pg.event.get.return_value = [quit_event]
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("sys.exit") as mock_exit,
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.clock = MagicMock()
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
game._draw_keyboard = MagicMock()
|
||||
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
|
||||
|
||||
game.run()
|
||||
|
||||
mock_pg.quit.assert_called()
|
||||
mock_exit.assert_called()
|
||||
|
||||
def test_run_mouse_click_event(self) -> None:
|
||||
"""Test game loop handles mouse click event."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
# Create mouse click event followed by quit
|
||||
click_event = MagicMock()
|
||||
click_event.type = "MOUSEDOWN"
|
||||
click_event.button = 1
|
||||
click_event.pos = (100, 100)
|
||||
|
||||
quit_event = MagicMock()
|
||||
quit_event.type = "QUIT"
|
||||
|
||||
mock_pg.QUIT = "QUIT"
|
||||
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
|
||||
mock_pg.KEYDOWN = "KEYDOWN"
|
||||
# Return click event first, then quit event
|
||||
mock_pg.event.get.side_effect = [[click_event], [quit_event]]
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("sys.exit"),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.clock = MagicMock()
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
game._draw_keyboard = MagicMock()
|
||||
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
|
||||
game._handle_click = MagicMock()
|
||||
|
||||
game.run()
|
||||
|
||||
game._handle_click.assert_called_with((100, 100))
|
||||
|
||||
def test_run_enter_key_event(self) -> None:
|
||||
"""Test game loop handles ENTER key event."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
key_event = MagicMock()
|
||||
key_event.type = "KEYDOWN"
|
||||
key_event.key = "K_RETURN"
|
||||
|
||||
quit_event = MagicMock()
|
||||
quit_event.type = "QUIT"
|
||||
|
||||
mock_pg.QUIT = "QUIT"
|
||||
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
|
||||
mock_pg.KEYDOWN = "KEYDOWN"
|
||||
mock_pg.K_RETURN = "K_RETURN"
|
||||
mock_pg.K_r = "K_r"
|
||||
mock_pg.event.get.side_effect = [[key_event], [quit_event]]
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("sys.exit"),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.clock = MagicMock()
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
game._draw_keyboard = MagicMock()
|
||||
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
|
||||
game._submit_word = MagicMock()
|
||||
|
||||
game.run()
|
||||
|
||||
game._submit_word.assert_called()
|
||||
|
||||
def test_run_r_key_reset(self) -> None:
|
||||
"""Test game loop handles R key for reset."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
key_event = MagicMock()
|
||||
key_event.type = "KEYDOWN"
|
||||
key_event.key = "K_r"
|
||||
|
||||
quit_event = MagicMock()
|
||||
quit_event.type = "QUIT"
|
||||
|
||||
mock_pg.QUIT = "QUIT"
|
||||
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
|
||||
mock_pg.KEYDOWN = "KEYDOWN"
|
||||
mock_pg.K_RETURN = "K_RETURN"
|
||||
mock_pg.K_r = "K_r"
|
||||
mock_pg.event.get.side_effect = [[key_event], [quit_event]]
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("sys.exit"),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.clock = MagicMock()
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
game._draw_keyboard = MagicMock()
|
||||
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
|
||||
game._reset_game = MagicMock()
|
||||
|
||||
game.run()
|
||||
|
||||
game._reset_game.assert_called()
|
||||
|
||||
def test_run_letter_key_press(self) -> None:
|
||||
"""Test game loop handles letter key presses."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
key_event = MagicMock()
|
||||
key_event.type = "KEYDOWN"
|
||||
key_event.key = "some_key"
|
||||
|
||||
quit_event = MagicMock()
|
||||
quit_event.type = "QUIT"
|
||||
|
||||
mock_pg.QUIT = "QUIT"
|
||||
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
|
||||
mock_pg.KEYDOWN = "KEYDOWN"
|
||||
mock_pg.K_RETURN = "K_RETURN"
|
||||
mock_pg.K_r = "K_r"
|
||||
mock_pg.key.name.return_value = "a" # Single letter key
|
||||
mock_pg.event.get.side_effect = [[key_event], [quit_event]]
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("sys.exit"),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.clock = MagicMock()
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
game._draw_keyboard = MagicMock()
|
||||
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
|
||||
game._handle_letter_click = MagicMock()
|
||||
|
||||
game.run()
|
||||
|
||||
game._handle_letter_click.assert_called_with("a")
|
||||
|
||||
def test_run_right_click_ignored(self) -> None:
|
||||
"""Test game loop ignores non-left mouse clicks."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
click_event = MagicMock()
|
||||
click_event.type = "MOUSEDOWN"
|
||||
click_event.button = 3 # Right click
|
||||
click_event.pos = (100, 100)
|
||||
|
||||
quit_event = MagicMock()
|
||||
quit_event.type = "QUIT"
|
||||
|
||||
mock_pg.QUIT = "QUIT"
|
||||
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
|
||||
mock_pg.KEYDOWN = "KEYDOWN"
|
||||
mock_pg.event.get.side_effect = [[click_event], [quit_event]]
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("sys.exit"),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.clock = MagicMock()
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
game._draw_keyboard = MagicMock()
|
||||
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
|
||||
game._handle_click = MagicMock()
|
||||
|
||||
game.run()
|
||||
|
||||
# handle_click should NOT be called for right click
|
||||
game._handle_click.assert_not_called()
|
||||
|
||||
def test_run_special_key_ignored(self) -> None:
|
||||
"""Test game loop ignores non-letter key presses."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
key_event = MagicMock()
|
||||
key_event.type = "KEYDOWN"
|
||||
key_event.key = "some_key"
|
||||
|
||||
quit_event = MagicMock()
|
||||
quit_event.type = "QUIT"
|
||||
|
||||
mock_pg.QUIT = "QUIT"
|
||||
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
|
||||
mock_pg.KEYDOWN = "KEYDOWN"
|
||||
mock_pg.K_RETURN = "K_RETURN"
|
||||
mock_pg.K_r = "K_r"
|
||||
mock_pg.key.name.return_value = "escape" # Multi-char, not a letter
|
||||
mock_pg.event.get.side_effect = [[key_event], [quit_event]]
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("sys.exit"),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.clock = MagicMock()
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
game._draw_keyboard = MagicMock()
|
||||
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
|
||||
game._handle_letter_click = MagicMock()
|
||||
|
||||
game.run()
|
||||
|
||||
# handle_letter_click should NOT be called for special keys
|
||||
game._handle_letter_click.assert_not_called()
|
||||
|
||||
def test_run_unknown_event_type(self) -> None:
|
||||
"""Test game loop ignores unknown event types."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
unknown_event = MagicMock()
|
||||
unknown_event.type = "UNKNOWN"
|
||||
|
||||
quit_event = MagicMock()
|
||||
quit_event.type = "QUIT"
|
||||
|
||||
mock_pg.QUIT = "QUIT"
|
||||
mock_pg.MOUSEBUTTONDOWN = "MOUSEDOWN"
|
||||
mock_pg.KEYDOWN = "KEYDOWN"
|
||||
mock_pg.event.get.side_effect = [[unknown_event], [quit_event]]
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"pygame": mock_pg}),
|
||||
patch("sys.exit"),
|
||||
):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.clock = MagicMock()
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
game._draw_keyboard = MagicMock()
|
||||
game._draw_ui = MagicMock(return_value=(MagicMock(), MagicMock()))
|
||||
game._handle_click = MagicMock()
|
||||
game._submit_word = MagicMock()
|
||||
game._reset_game = MagicMock()
|
||||
game._handle_letter_click = MagicMock()
|
||||
|
||||
game.run()
|
||||
|
||||
# None of the handlers should be called for unknown event
|
||||
game._handle_click.assert_not_called()
|
||||
game._submit_word.assert_not_called()
|
||||
game._reset_game.assert_not_called()
|
||||
game._handle_letter_click.assert_not_called()
|
||||
File diff suppressed because it is too large
Load Diff
311
python_pkg/keyboard_coop/tests/test_ui.py
Normal file
311
python_pkg/keyboard_coop/tests/test_ui.py
Normal file
@ -0,0 +1,311 @@
|
||||
"""Tests for keyboard_coop UI drawing and click handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestHandleClick:
|
||||
"""Tests for click handling."""
|
||||
|
||||
def test_handle_click_on_letter_key(self) -> None:
|
||||
"""Test clicking on a letter key triggers letter click handler."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.Rect = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.available_letters = {"a"}
|
||||
game.keyboard.adjacency = {"a": []}
|
||||
|
||||
# Mock _get_key_at_position to return "a"
|
||||
game._get_key_at_position = MagicMock(return_value="a")
|
||||
game._handle_letter_click = MagicMock()
|
||||
game._submit_word = MagicMock()
|
||||
game._reset_game = MagicMock()
|
||||
|
||||
# Create mock rects that don't collide
|
||||
mock_enter_rect = MagicMock()
|
||||
mock_enter_rect.collidepoint.return_value = False
|
||||
mock_reset_rect = MagicMock()
|
||||
mock_reset_rect.collidepoint.return_value = False
|
||||
|
||||
# Patch pygame.Rect to return our mocks
|
||||
mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect]
|
||||
|
||||
game._handle_click((100, 100))
|
||||
|
||||
game._handle_letter_click.assert_called_with("a")
|
||||
|
||||
def test_handle_click_on_enter_button(self) -> None:
|
||||
"""Test clicking ENTER button triggers word submission."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
|
||||
# Mock methods
|
||||
game._get_key_at_position = MagicMock(return_value=None)
|
||||
game._submit_word = MagicMock()
|
||||
game._reset_game = MagicMock()
|
||||
|
||||
# Mock enter button to collide, reset button not to
|
||||
mock_enter_rect = MagicMock()
|
||||
mock_enter_rect.collidepoint.return_value = True
|
||||
mock_reset_rect = MagicMock()
|
||||
mock_reset_rect.collidepoint.return_value = False
|
||||
|
||||
mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect]
|
||||
|
||||
game._handle_click((750, 200))
|
||||
|
||||
game._submit_word.assert_called()
|
||||
game._reset_game.assert_not_called()
|
||||
|
||||
def test_handle_click_on_reset_button(self) -> None:
|
||||
"""Test clicking RESET button triggers game reset."""
|
||||
mock_pg = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.keyboard.positions = {}
|
||||
|
||||
# Mock methods
|
||||
game._get_key_at_position = MagicMock(return_value=None)
|
||||
game._submit_word = MagicMock()
|
||||
game._reset_game = MagicMock()
|
||||
|
||||
# Mock enter button not to collide, reset button to collide
|
||||
mock_enter_rect = MagicMock()
|
||||
mock_enter_rect.collidepoint.return_value = False
|
||||
mock_reset_rect = MagicMock()
|
||||
mock_reset_rect.collidepoint.return_value = True
|
||||
|
||||
mock_pg.Rect.side_effect = [mock_enter_rect, mock_reset_rect]
|
||||
|
||||
game._handle_click((900, 200))
|
||||
|
||||
game._reset_game.assert_called()
|
||||
game._submit_word.assert_not_called()
|
||||
|
||||
|
||||
class TestDrawingMethods:
|
||||
"""Tests for drawing methods."""
|
||||
|
||||
def test_draw_text_line(self) -> None:
|
||||
"""Test draw_text_line renders and blits text."""
|
||||
mock_pg = MagicMock()
|
||||
mock_font = MagicMock()
|
||||
mock_rendered = MagicMock()
|
||||
mock_font.render.return_value = mock_rendered
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
KeyboardCoopGame,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
|
||||
game._draw_text_line("Test text", (10, 20), mock_font)
|
||||
|
||||
mock_font.render.assert_called()
|
||||
game.screen.blit.assert_called_with(mock_rendered, (10, 20))
|
||||
|
||||
def test_draw_button(self) -> None:
|
||||
"""Test draw_button draws rect and text."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.draw = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
KeyboardCoopGame,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
|
||||
mock_rect = MagicMock()
|
||||
mock_rect.center = (50, 50)
|
||||
|
||||
game._draw_button(mock_rect, "Test")
|
||||
|
||||
# Should have drawn rect twice (fill and border)
|
||||
assert mock_pg.draw.rect.call_count == 2
|
||||
|
||||
|
||||
class TestDrawKeyboard:
|
||||
"""Tests for keyboard drawing."""
|
||||
|
||||
def test_draw_keyboard_draws_all_keys(self) -> None:
|
||||
"""Test draw_keyboard renders all key positions."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.draw = MagicMock()
|
||||
mock_pg.mouse.get_pos.return_value = (0, 0)
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.state = GameState()
|
||||
game.keyboard = KeyboardState()
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
|
||||
# Set up some positions
|
||||
mock_rect_a = MagicMock()
|
||||
mock_rect_a.collidepoint.return_value = False
|
||||
mock_rect_a.center = (100, 100)
|
||||
mock_rect_b = MagicMock()
|
||||
mock_rect_b.collidepoint.return_value = False
|
||||
mock_rect_b.center = (150, 100)
|
||||
|
||||
game.keyboard.positions = {"a": mock_rect_a, "b": mock_rect_b}
|
||||
game.keyboard.available_letters = {"a", "b"}
|
||||
|
||||
game._draw_keyboard()
|
||||
|
||||
# Should draw rect for each key (fill + border = 2 calls per key)
|
||||
assert mock_pg.draw.rect.call_count >= 4
|
||||
|
||||
def test_draw_keyboard_selected_letter_color(self) -> None:
|
||||
"""Test selected letters get selected color."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.draw = MagicMock()
|
||||
mock_pg.mouse.get_pos.return_value = (0, 0)
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
KEY_SELECTED_COLOR,
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.state = GameState()
|
||||
game.state.selected_letters = ["a"] # 'a' is selected
|
||||
game.keyboard = KeyboardState()
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
|
||||
mock_rect_a = MagicMock()
|
||||
mock_rect_a.collidepoint.return_value = False
|
||||
mock_rect_a.center = (100, 100)
|
||||
|
||||
game.keyboard.positions = {"a": mock_rect_a}
|
||||
game.keyboard.available_letters = {"a"}
|
||||
|
||||
game._draw_keyboard()
|
||||
|
||||
# Check that KEY_SELECTED_COLOR was used
|
||||
calls = mock_pg.draw.rect.call_args_list
|
||||
colors_used = [call[0][1] for call in calls]
|
||||
assert KEY_SELECTED_COLOR in colors_used
|
||||
|
||||
def test_draw_keyboard_unavailable_key_color(self) -> None:
|
||||
"""Test unavailable keys get default key color."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.draw = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
KEY_COLOR,
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
KeyboardState,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.state = GameState()
|
||||
game.state.selected_letters = [] # Not selected
|
||||
game.keyboard = KeyboardState()
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
|
||||
mock_rect_a = MagicMock()
|
||||
mock_rect_a.center = (100, 100)
|
||||
|
||||
game.keyboard.positions = {"a": mock_rect_a}
|
||||
# Key is NOT available - should get KEY_COLOR
|
||||
game.keyboard.available_letters = set()
|
||||
|
||||
game._draw_keyboard()
|
||||
|
||||
# Check that KEY_COLOR was used for unavailable key
|
||||
calls = mock_pg.draw.rect.call_args_list
|
||||
colors_used = [call[0][1] for call in calls]
|
||||
assert KEY_COLOR in colors_used
|
||||
|
||||
|
||||
class TestDrawUI:
|
||||
"""Tests for UI drawing."""
|
||||
|
||||
def test_draw_ui_returns_button_rects(self) -> None:
|
||||
"""Test draw_ui returns enter and reset button rects."""
|
||||
mock_pg = MagicMock()
|
||||
mock_pg.draw = MagicMock()
|
||||
mock_rect_instance = MagicMock()
|
||||
mock_pg.Rect.return_value = mock_rect_instance
|
||||
|
||||
with patch.dict("sys.modules", {"pygame": mock_pg}):
|
||||
from python_pkg.keyboard_coop.main import (
|
||||
FontSet,
|
||||
GameState,
|
||||
KeyboardCoopGame,
|
||||
)
|
||||
|
||||
game = object.__new__(KeyboardCoopGame)
|
||||
game.screen = MagicMock()
|
||||
game.state = GameState()
|
||||
game.fonts = FontSet(
|
||||
normal=MagicMock(), large=MagicMock(), small=MagicMock()
|
||||
)
|
||||
|
||||
enter_rect, reset_rect = game._draw_ui()
|
||||
|
||||
# Should return pygame.Rect instances
|
||||
assert enter_rect is not None
|
||||
assert reset_rect is not None
|
||||
File diff suppressed because it is too large
Load Diff
409
python_pkg/lichess_bot/tests/test_main_analysis.py
Normal file
409
python_pkg/lichess_bot/tests/test_main_analysis.py
Normal file
@ -0,0 +1,409 @@
|
||||
"""Tests for lichess_bot main module: game events and analysis."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import chess
|
||||
import pytest
|
||||
|
||||
from python_pkg.lichess_bot.main import (
|
||||
BotContext,
|
||||
GameMeta,
|
||||
GameState,
|
||||
_collect_analysis_lines,
|
||||
_finalize_game,
|
||||
_insert_analysis_into_log,
|
||||
_log_analysis_progress,
|
||||
_process_analysis_output,
|
||||
_process_game_event,
|
||||
_run_analysis_subprocess,
|
||||
_write_pgn_to_log,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
# Type alias to make mypy happy with test event dicts
|
||||
Event = dict[str, Any]
|
||||
|
||||
|
||||
class TestProcessGameEvent:
|
||||
"""Tests for _process_game_event."""
|
||||
|
||||
def test_process_game_event_unhandled_type(self) -> None:
|
||||
"""Test processing unhandled event type."""
|
||||
ctx = MagicMock()
|
||||
state = GameState()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {"type": "chatLine", "text": "hello"}
|
||||
result = _process_game_event(event, ctx, state, meta)
|
||||
assert result is True
|
||||
|
||||
def test_process_game_event_game_full(self) -> None:
|
||||
"""Test processing gameFull event."""
|
||||
api = MagicMock()
|
||||
api.get_my_user_id.return_value = "mybot"
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (
|
||||
chess.Move.from_uci("e2e4"),
|
||||
"opening",
|
||||
)
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {
|
||||
"type": "gameFull",
|
||||
"state": {"moves": "", "status": "started"},
|
||||
"white": {"id": "mybot"},
|
||||
"black": {"id": "opp"},
|
||||
}
|
||||
result = _process_game_event(event, ctx, state, meta)
|
||||
assert result is True
|
||||
|
||||
def test_process_game_event_game_end(self) -> None:
|
||||
"""Test processing game end event."""
|
||||
api = MagicMock()
|
||||
api.get_my_user_id.return_value = "mybot"
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (None, "no moves")
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState(color="white", last_handled_len=-1)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {
|
||||
"type": "gameState",
|
||||
"moves": "e2e4 e7e5",
|
||||
"status": "mate",
|
||||
}
|
||||
result = _process_game_event(event, ctx, state, meta)
|
||||
assert result is False
|
||||
|
||||
def test_process_game_event_game_end_after_move(self) -> None:
|
||||
"""Test game ends with status after handling move.
|
||||
|
||||
This covers the case where _handle_move_if_needed returns True
|
||||
but status indicates game end.
|
||||
"""
|
||||
api = MagicMock()
|
||||
api.get_my_user_id.return_value = "mybot"
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (
|
||||
chess.Move.from_uci("d2d4"),
|
||||
"response",
|
||||
)
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
# Black's turn - it's opponent's move, so we don't need to move
|
||||
state = GameState(color="black", last_handled_len=-1)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {
|
||||
"type": "gameState",
|
||||
"moves": "e2e4", # One move - now it's black's turn
|
||||
"status": "resign", # Game ended with resign
|
||||
}
|
||||
result = _process_game_event(event, ctx, state, meta)
|
||||
assert result is False # Game should end
|
||||
|
||||
def test_process_game_event_unchanged_position(self) -> None:
|
||||
"""Test processing event with unchanged position."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
state = GameState(last_handled_len=2, color="white")
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {"type": "gameState", "moves": "e2e4 e7e5"}
|
||||
result = _process_game_event(event, ctx, state, meta)
|
||||
assert result is True
|
||||
|
||||
def test_process_game_event_color_unknown(self) -> None:
|
||||
"""Test processing event with unknown color."""
|
||||
api = MagicMock()
|
||||
api.get_my_user_id.return_value = "mybot"
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
state = GameState(last_handled_len=-1)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {"type": "gameState", "moves": "e2e4"}
|
||||
result = _process_game_event(event, ctx, state, meta)
|
||||
assert result is True
|
||||
assert state.last_handled_len == 1
|
||||
|
||||
def test_process_game_event_color_unknown_on_gamefull(self) -> None:
|
||||
"""Test processing gameFull event with still unknown color.
|
||||
|
||||
This covers the branch where event_type is gameFull but color
|
||||
is not determined (e.g., spectator watching game).
|
||||
"""
|
||||
api = MagicMock()
|
||||
# Return a user id that doesn't match either player
|
||||
api.get_my_user_id.return_value = "spectator"
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
state = GameState(last_handled_len=-1)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {
|
||||
"type": "gameFull",
|
||||
"state": {"moves": "e2e4", "status": "started"},
|
||||
"white": {"id": "player1"},
|
||||
"black": {"id": "player2"},
|
||||
}
|
||||
result = _process_game_event(event, ctx, state, meta)
|
||||
assert result is True
|
||||
# last_handled_len should NOT be updated for gameFull with unknown color
|
||||
assert state.last_handled_len == -1
|
||||
|
||||
|
||||
class TestWritePgnToLog:
|
||||
"""Tests for _write_pgn_to_log."""
|
||||
|
||||
def test_write_pgn_to_log(self, tmp_path: Path) -> None:
|
||||
"""Test writing PGN to log file."""
|
||||
log_path = tmp_path / "game.log"
|
||||
log_path.write_text("header\n")
|
||||
board = chess.Board()
|
||||
board.push_uci("e2e4")
|
||||
meta = GameMeta(
|
||||
game_id="game1",
|
||||
bot_version=1,
|
||||
site_url="https://lichess.org/game1",
|
||||
date_iso="2021.01.01",
|
||||
white_name="White",
|
||||
black_name="Black",
|
||||
)
|
||||
_write_pgn_to_log(log_path, board, meta)
|
||||
content = log_path.read_text()
|
||||
assert "PGN:" in content
|
||||
assert "e4" in content
|
||||
|
||||
|
||||
class TestRunAnalysisSubprocess:
|
||||
"""Tests for _run_analysis_subprocess."""
|
||||
|
||||
def test_run_analysis_subprocess_script_not_found(self, tmp_path: Path) -> None:
|
||||
"""Test analysis when script not found."""
|
||||
log_path = tmp_path / "game.log"
|
||||
with patch("python_pkg.lichess_bot.main.Path") as mock_path:
|
||||
mock_script = MagicMock()
|
||||
mock_script.is_file.return_value = False
|
||||
resolve = mock_path.return_value.resolve.return_value
|
||||
resolve.parent.parent.__truediv__.return_value.__truediv__.return_value = (
|
||||
mock_script
|
||||
)
|
||||
result = _run_analysis_subprocess("game1", log_path, 10)
|
||||
assert result is None
|
||||
|
||||
def test_run_analysis_subprocess_success(self, tmp_path: Path) -> None:
|
||||
"""Test successful analysis subprocess."""
|
||||
log_path = tmp_path / "game.log"
|
||||
log_path.write_text("test")
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([" 1 e4\n", " 2 e5\n"])
|
||||
mock_proc.stderr.read.return_value = ""
|
||||
mock_proc.wait.return_value = 0
|
||||
mock_proc.__enter__ = MagicMock(return_value=mock_proc)
|
||||
mock_proc.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("python_pkg.lichess_bot.main.Path") as mock_path,
|
||||
patch("subprocess.Popen", return_value=mock_proc),
|
||||
):
|
||||
mock_script = MagicMock()
|
||||
mock_script.is_file.return_value = True
|
||||
resolve = mock_path.return_value.resolve.return_value
|
||||
resolve.parent.parent.__truediv__.return_value.__truediv__.return_value = (
|
||||
mock_script
|
||||
)
|
||||
result = _run_analysis_subprocess("game1", log_path, 2)
|
||||
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestProcessAnalysisOutput:
|
||||
"""Tests for _process_analysis_output."""
|
||||
|
||||
def test_process_analysis_output_success(self) -> None:
|
||||
"""Test processing analysis output successfully."""
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([" 1 e4\n", " 2 e5\n"])
|
||||
mock_proc.stderr.read.return_value = ""
|
||||
mock_proc.wait.return_value = 0
|
||||
|
||||
result = _process_analysis_output(mock_proc, "game1", 2)
|
||||
assert result is not None
|
||||
assert "e4" in result
|
||||
|
||||
def test_process_analysis_output_error_exit(self) -> None:
|
||||
"""Test processing analysis output with error exit."""
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["output\n"])
|
||||
mock_proc.stderr.read.return_value = "error message"
|
||||
mock_proc.wait.return_value = 1
|
||||
|
||||
result = _process_analysis_output(mock_proc, "game1", 1)
|
||||
assert result is not None
|
||||
assert "stderr" in result
|
||||
|
||||
def test_process_analysis_output_error_exit_no_stderr(self) -> None:
|
||||
"""Test processing analysis output with error exit but no stderr."""
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["output\n"])
|
||||
mock_proc.stderr.read.return_value = ""
|
||||
mock_proc.wait.return_value = 1
|
||||
|
||||
result = _process_analysis_output(mock_proc, "game1", 1)
|
||||
assert result is not None
|
||||
assert "stderr" not in result
|
||||
|
||||
def test_process_analysis_output_none_pipes(self) -> None:
|
||||
"""Test processing analysis output with None pipes."""
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = None
|
||||
mock_proc.stderr = None
|
||||
|
||||
with pytest.raises(RuntimeError, match="pipes unexpectedly None"):
|
||||
_process_analysis_output(mock_proc, "game1", 1)
|
||||
|
||||
|
||||
class TestCollectAnalysisLines:
|
||||
"""Tests for _collect_analysis_lines helper."""
|
||||
|
||||
def test_collect_analysis_lines_empty_iterator(self) -> None:
|
||||
"""Test collecting lines from empty iterator."""
|
||||
empty_iter: list[str] = []
|
||||
analyzed, lines = _collect_analysis_lines(iter(empty_iter), "game1", 10)
|
||||
assert analyzed == 0
|
||||
assert lines == []
|
||||
|
||||
def test_collect_analysis_lines_with_content(self) -> None:
|
||||
"""Test collecting lines from iterator with content."""
|
||||
content = [" 1 e4\n", " 2 e5\n", "not a ply line\n"]
|
||||
analyzed, lines = _collect_analysis_lines(iter(content), "game1", 3)
|
||||
assert analyzed == 2
|
||||
assert lines == content
|
||||
|
||||
def test_collect_analysis_lines_full_iteration(self) -> None:
|
||||
"""Test that all lines are collected."""
|
||||
content = ["line1\n", " 3 Nf3\n", "line3\n"]
|
||||
analyzed, lines = _collect_analysis_lines(iter(content), "game1", 1)
|
||||
assert analyzed == 1
|
||||
assert len(lines) == 3
|
||||
|
||||
|
||||
class TestLogAnalysisProgress:
|
||||
"""Tests for _log_analysis_progress."""
|
||||
|
||||
def test_log_analysis_progress_with_total(self) -> None:
|
||||
"""Test logging progress with known total."""
|
||||
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
|
||||
_log_analysis_progress("game1", 5, 10)
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0]
|
||||
assert "50%" in call_args[0] % call_args[1:]
|
||||
|
||||
def test_log_analysis_progress_zero_total(self) -> None:
|
||||
"""Test logging progress with zero total."""
|
||||
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
|
||||
_log_analysis_progress("game1", 5, 0)
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0]
|
||||
assert "unknown" in call_args[0]
|
||||
|
||||
|
||||
class TestInsertAnalysisIntoLog:
|
||||
"""Tests for _insert_analysis_into_log."""
|
||||
|
||||
def test_insert_analysis_before_pgn(self, tmp_path: Path) -> None:
|
||||
"""Test inserting analysis before PGN section."""
|
||||
log_path = tmp_path / "game.log"
|
||||
log_path.write_text("header\n\nPGN:\n1. e4\n")
|
||||
meta = GameMeta(
|
||||
game_id="game1",
|
||||
bot_version=1,
|
||||
date_iso="2021.01.01",
|
||||
white_name="White",
|
||||
black_name="Black",
|
||||
)
|
||||
_insert_analysis_into_log(log_path, "Analysis here", meta)
|
||||
content = log_path.read_text()
|
||||
assert "ANALYSIS:" in content
|
||||
assert content.index("ANALYSIS:") < content.index("PGN:")
|
||||
|
||||
def test_insert_analysis_at_start(self, tmp_path: Path) -> None:
|
||||
"""Test inserting analysis when PGN at start."""
|
||||
log_path = tmp_path / "game.log"
|
||||
log_path.write_text("PGN:\n1. e4\n")
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
_insert_analysis_into_log(log_path, "Analysis here", meta)
|
||||
content = log_path.read_text()
|
||||
assert "ANALYSIS:" in content
|
||||
|
||||
def test_insert_analysis_no_pgn(self, tmp_path: Path) -> None:
|
||||
"""Test inserting analysis when no PGN section."""
|
||||
log_path = tmp_path / "game.log"
|
||||
log_path.write_text("header\n")
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
_insert_analysis_into_log(log_path, "Analysis here", meta)
|
||||
content = log_path.read_text()
|
||||
assert "ANALYSIS:" in content
|
||||
|
||||
def test_insert_analysis_oserror(self, tmp_path: Path) -> None:
|
||||
"""Test inserting analysis with OSError."""
|
||||
log_path = tmp_path / "nonexistent" / "game.log"
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
# Should not raise, just log debug
|
||||
_insert_analysis_into_log(log_path, "Analysis", meta)
|
||||
|
||||
|
||||
class TestFinalizeGame:
|
||||
"""Tests for _finalize_game."""
|
||||
|
||||
def test_finalize_game_no_log_path(self) -> None:
|
||||
"""Test finalize game with no log path."""
|
||||
state = GameState(log_path=None)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
_finalize_game(state, meta) # Should not raise
|
||||
|
||||
def test_finalize_game_write_error(self, tmp_path: Path) -> None:
|
||||
"""Test finalize game with write error."""
|
||||
log_path = tmp_path / "game.log"
|
||||
log_path.write_text("header")
|
||||
state = GameState(log_path=log_path)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
|
||||
with patch(
|
||||
"python_pkg.lichess_bot.main._write_pgn_to_log",
|
||||
side_effect=OSError("error"),
|
||||
):
|
||||
_finalize_game(state, meta) # Should not raise
|
||||
|
||||
def test_finalize_game_type_error_on_move_stack(self, tmp_path: Path) -> None:
|
||||
"""Test finalize game with TypeError on move_stack."""
|
||||
log_path = tmp_path / "game.log"
|
||||
log_path.write_text("header\n")
|
||||
state = GameState(log_path=log_path)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
|
||||
mock_board = MagicMock()
|
||||
# Use PropertyMock to raise TypeError when move_stack is accessed
|
||||
type(mock_board).move_stack = PropertyMock(side_effect=TypeError())
|
||||
state.board = mock_board
|
||||
|
||||
with patch("python_pkg.lichess_bot.main._write_pgn_to_log"):
|
||||
_finalize_game(state, meta) # Should not raise
|
||||
|
||||
def test_finalize_game_analysis_error(self, tmp_path: Path) -> None:
|
||||
"""Test finalize game with analysis error."""
|
||||
log_path = tmp_path / "game.log"
|
||||
log_path.write_text("header\n")
|
||||
state = GameState(log_path=log_path)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
|
||||
with (
|
||||
patch("python_pkg.lichess_bot.main._write_pgn_to_log"),
|
||||
patch(
|
||||
"python_pkg.lichess_bot.main._run_analysis_subprocess",
|
||||
side_effect=OSError("error"),
|
||||
),
|
||||
):
|
||||
_finalize_game(state, meta) # Should not raise
|
||||
474
python_pkg/lichess_bot/tests/test_main_bot_loop.py
Normal file
474
python_pkg/lichess_bot/tests/test_main_bot_loop.py
Normal file
@ -0,0 +1,474 @@
|
||||
"""Tests for lichess_bot main module: bot event loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import chess
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from python_pkg.lichess_bot.main import (
|
||||
BotContext,
|
||||
GameMeta,
|
||||
GameState,
|
||||
_handle_challenge,
|
||||
_handle_game,
|
||||
_process_bot_event,
|
||||
_process_game_events_loop,
|
||||
_run_event_loop,
|
||||
_run_event_loop_iteration,
|
||||
_safe_event_loop_iteration,
|
||||
_stream_bot_events,
|
||||
main,
|
||||
run_bot,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
# Type aliases to make mypy happy with test event dicts
|
||||
Event = dict[str, Any]
|
||||
GameThreads = dict[str, threading.Thread]
|
||||
|
||||
|
||||
class TestHandleGame:
|
||||
"""Tests for _handle_game."""
|
||||
|
||||
def test_handle_game_success(self, tmp_path: Path) -> None:
|
||||
"""Test handling a game successfully."""
|
||||
api = MagicMock()
|
||||
api.get_my_user_id.return_value = "mybot"
|
||||
api.stream_game_events.return_value = iter(
|
||||
[
|
||||
{
|
||||
"type": "gameFull",
|
||||
"state": {"moves": "", "status": "started"},
|
||||
"white": {"id": "mybot"},
|
||||
"black": {"id": "opp"},
|
||||
},
|
||||
{"type": "gameState", "moves": "e2e4", "status": "mate"},
|
||||
]
|
||||
)
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (None, "no moves")
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
|
||||
with (
|
||||
patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path),
|
||||
patch(
|
||||
"python_pkg.lichess_bot.main._run_analysis_subprocess",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_handle_game("game1", ctx, None)
|
||||
|
||||
def test_handle_game_request_error(self, tmp_path: Path) -> None:
|
||||
"""Test handling a game with request error."""
|
||||
api = MagicMock()
|
||||
api.stream_game_events.side_effect = requests.RequestException("error")
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
|
||||
with (
|
||||
patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path),
|
||||
patch(
|
||||
"python_pkg.lichess_bot.main._run_analysis_subprocess",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_handle_game("game1", ctx, None) # Should not raise
|
||||
|
||||
def test_handle_game_skips_chat_events(self, tmp_path: Path) -> None:
|
||||
"""Test handling a game skips chat events."""
|
||||
api = MagicMock()
|
||||
api.stream_game_events.return_value = iter(
|
||||
[
|
||||
{"type": "chatLine", "text": "hello"},
|
||||
{"type": "opponentGone", "gone": True},
|
||||
]
|
||||
)
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
|
||||
with (
|
||||
patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path),
|
||||
patch(
|
||||
"python_pkg.lichess_bot.main._run_analysis_subprocess",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_handle_game("game1", ctx, None)
|
||||
|
||||
|
||||
class TestProcessGameEventsLoop:
|
||||
"""Tests for _process_game_events_loop."""
|
||||
|
||||
def test_empty_events_iterator(self) -> None:
|
||||
"""Test processing empty events iterator."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
state = GameState(color="white")
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
|
||||
empty_iter: list[Event] = []
|
||||
# Should complete without error when iterator is empty
|
||||
_process_game_events_loop(iter(empty_iter), ctx, state, meta)
|
||||
|
||||
def test_processes_all_events(self) -> None:
|
||||
"""Test that all events are processed until break condition."""
|
||||
api = MagicMock()
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (None, "no moves")
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState(color="white")
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
|
||||
events: list[Event] = [
|
||||
{"type": "chatLine", "text": "hello"}, # skipped
|
||||
{"type": "gameState", "moves": "e2e4", "status": "resign"}, # game end
|
||||
]
|
||||
_process_game_events_loop(iter(events), ctx, state, meta)
|
||||
|
||||
def test_processes_multiple_game_events(self) -> None:
|
||||
"""Test processing multiple game events that continue the game."""
|
||||
api = MagicMock()
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (
|
||||
chess.Move.from_uci("e2e4"),
|
||||
"e4",
|
||||
)
|
||||
api.make_move.return_value = None
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState(color="white")
|
||||
state.board = chess.Board()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
|
||||
events: list[Event] = [
|
||||
# First event - game state, game continues
|
||||
{"type": "gameState", "moves": "", "status": "started"},
|
||||
# Second event - opponent moves, game continues
|
||||
{"type": "gameState", "moves": "e2e4 e7e5", "status": "started"},
|
||||
# Third event - game ends
|
||||
{"type": "gameState", "moves": "e2e4 e7e5", "status": "mate"},
|
||||
]
|
||||
_process_game_events_loop(iter(events), ctx, state, meta)
|
||||
|
||||
|
||||
class TestRunEventLoop:
|
||||
"""Tests for _run_event_loop."""
|
||||
|
||||
def test_run_event_loop_zero_iterations(self) -> None:
|
||||
"""Test running event loop with zero iterations."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
|
||||
# Should complete immediately with 0 iterations
|
||||
_run_event_loop(ctx, game_threads, 0, 0)
|
||||
|
||||
def test_run_event_loop_limited_iterations(self) -> None:
|
||||
"""Test running event loop with limited iterations."""
|
||||
api = MagicMock()
|
||||
api.stream_bot_events.return_value = iter([])
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
|
||||
with patch(
|
||||
"python_pkg.lichess_bot.main._safe_event_loop_iteration", return_value=0
|
||||
) as mock_iter:
|
||||
_run_event_loop(ctx, game_threads, 0, 3)
|
||||
assert mock_iter.call_count == 3
|
||||
|
||||
def test_run_event_loop_none_iterations_needs_interrupt(self) -> None:
|
||||
"""Test that None iterations runs until interrupted."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
|
||||
call_count = 0
|
||||
|
||||
def stop_after_calls(*_args: object, **_kwargs: object) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count >= 5:
|
||||
raise KeyboardInterrupt
|
||||
return 0
|
||||
|
||||
with (
|
||||
patch(
|
||||
"python_pkg.lichess_bot.main._safe_event_loop_iteration",
|
||||
side_effect=stop_after_calls,
|
||||
),
|
||||
pytest.raises(KeyboardInterrupt),
|
||||
):
|
||||
_run_event_loop(ctx, game_threads, 0, None)
|
||||
|
||||
assert call_count == 5
|
||||
|
||||
|
||||
class TestHandleChallenge:
|
||||
"""Tests for _handle_challenge."""
|
||||
|
||||
def test_accept_standard_blitz(self) -> None:
|
||||
"""Test accepting standard blitz challenge."""
|
||||
api = MagicMock()
|
||||
challenge: Event = {
|
||||
"id": "ch1",
|
||||
"variant": {"key": "standard"},
|
||||
"speed": "blitz",
|
||||
}
|
||||
_handle_challenge(challenge, api, decline_correspondence=False)
|
||||
api.accept_challenge.assert_called_once_with("ch1")
|
||||
|
||||
def test_decline_variant(self) -> None:
|
||||
"""Test declining non-standard variant."""
|
||||
api = MagicMock()
|
||||
challenge: Event = {
|
||||
"id": "ch1",
|
||||
"variant": {"key": "chess960"},
|
||||
"speed": "blitz",
|
||||
}
|
||||
_handle_challenge(challenge, api, decline_correspondence=False)
|
||||
api.decline_challenge.assert_called_once()
|
||||
|
||||
def test_decline_correspondence(self) -> None:
|
||||
"""Test declining correspondence when flag set."""
|
||||
api = MagicMock()
|
||||
challenge: Event = {
|
||||
"id": "ch1",
|
||||
"variant": {"key": "standard"},
|
||||
"speed": "correspondence",
|
||||
}
|
||||
_handle_challenge(challenge, api, decline_correspondence=True)
|
||||
api.decline_challenge.assert_called_once()
|
||||
|
||||
def test_accept_correspondence_when_allowed(self) -> None:
|
||||
"""Test accepting correspondence when flag not set."""
|
||||
api = MagicMock()
|
||||
challenge: Event = {
|
||||
"id": "ch1",
|
||||
"variant": {"key": "standard"},
|
||||
"speed": "correspondence",
|
||||
}
|
||||
_handle_challenge(challenge, api, decline_correspondence=False)
|
||||
api.decline_challenge.assert_called_once() # Still declined due to perf_ok
|
||||
|
||||
def test_invalid_variant_data(self) -> None:
|
||||
"""Test handling invalid variant data."""
|
||||
api = MagicMock()
|
||||
challenge: Event = {
|
||||
"id": "ch1",
|
||||
"variant": "invalid",
|
||||
"speed": "blitz",
|
||||
}
|
||||
_handle_challenge(challenge, api, decline_correspondence=False)
|
||||
api.accept_challenge.assert_called_once()
|
||||
|
||||
|
||||
class TestProcessBotEvent:
|
||||
"""Tests for _process_bot_event."""
|
||||
|
||||
def test_process_challenge_event(self) -> None:
|
||||
"""Test processing challenge event."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
event: Event = {
|
||||
"type": "challenge",
|
||||
"challenge": {
|
||||
"id": "ch1",
|
||||
"variant": {"key": "standard"},
|
||||
"speed": "blitz",
|
||||
},
|
||||
}
|
||||
_process_bot_event(event, ctx, game_threads)
|
||||
api.accept_challenge.assert_called_once()
|
||||
|
||||
def test_process_game_start_event(self) -> None:
|
||||
"""Test processing gameStart event."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
event: Event = {"type": "gameStart", "game": {"id": "game1"}}
|
||||
|
||||
with patch("python_pkg.lichess_bot.main.threading.Thread") as mock_thread_class:
|
||||
mock_thread = MagicMock()
|
||||
mock_thread_class.return_value = mock_thread
|
||||
_process_bot_event(event, ctx, game_threads)
|
||||
|
||||
assert "game1" in game_threads
|
||||
mock_thread.start.assert_called_once()
|
||||
|
||||
def test_process_game_start_existing_thread(self) -> None:
|
||||
"""Test processing gameStart with existing alive thread."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
mock_thread = MagicMock(spec=threading.Thread)
|
||||
mock_thread.is_alive.return_value = True
|
||||
game_threads: GameThreads = {"game1": mock_thread}
|
||||
event: Event = {"type": "gameStart", "game": {"id": "game1"}}
|
||||
_process_bot_event(event, ctx, game_threads)
|
||||
# Should not create new thread
|
||||
assert game_threads["game1"] is mock_thread
|
||||
|
||||
def test_process_game_finish_event(self) -> None:
|
||||
"""Test processing gameFinish event."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
event: Event = {"type": "gameFinish", "game": {"id": "game1"}}
|
||||
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
|
||||
_process_bot_event(event, ctx, game_threads)
|
||||
mock_logger.info.assert_called()
|
||||
|
||||
def test_process_game_finish_invalid_data(self) -> None:
|
||||
"""Test processing gameFinish event with non-dict game data."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
event: Event = {"type": "gameFinish", "game": "not_a_dict"}
|
||||
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
|
||||
_process_bot_event(event, ctx, game_threads)
|
||||
# Should not log info since game data is invalid
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
def test_process_unknown_event(self) -> None:
|
||||
"""Test processing unknown event."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
event: Event = {"type": "unknown", "data": "test"}
|
||||
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
|
||||
_process_bot_event(event, ctx, game_threads)
|
||||
mock_logger.debug.assert_called()
|
||||
|
||||
def test_process_challenge_invalid_data(self) -> None:
|
||||
"""Test processing challenge with invalid data."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
event: Event = {"type": "challenge", "challenge": "invalid"}
|
||||
_process_bot_event(event, ctx, game_threads)
|
||||
api.accept_challenge.assert_not_called()
|
||||
|
||||
def test_process_game_start_invalid_data(self) -> None:
|
||||
"""Test processing gameStart with invalid data."""
|
||||
api = MagicMock()
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
event: Event = {"type": "gameStart", "game": "invalid"}
|
||||
_process_bot_event(event, ctx, game_threads)
|
||||
assert len(game_threads) == 0
|
||||
|
||||
|
||||
class TestStreamBotEvents:
|
||||
"""Tests for _stream_bot_events."""
|
||||
|
||||
def test_stream_bot_events(self) -> None:
|
||||
"""Test streaming bot events."""
|
||||
api = MagicMock()
|
||||
api.stream_events.return_value = iter([{"type": "test"}])
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
events = list(_stream_bot_events(ctx))
|
||||
assert len(events) == 1
|
||||
|
||||
|
||||
class TestRunEventLoopIteration:
|
||||
"""Tests for _run_event_loop_iteration."""
|
||||
|
||||
def test_run_event_loop_iteration(self) -> None:
|
||||
"""Test running event loop iteration."""
|
||||
api = MagicMock()
|
||||
api.stream_events.return_value = iter([{"type": "unknown"}])
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
result = _run_event_loop_iteration(ctx, game_threads)
|
||||
assert result == 0
|
||||
|
||||
|
||||
class TestSafeEventLoopIteration:
|
||||
"""Tests for _safe_event_loop_iteration."""
|
||||
|
||||
def test_safe_event_loop_iteration_success(self) -> None:
|
||||
"""Test safe event loop iteration success."""
|
||||
api = MagicMock()
|
||||
api.stream_events.return_value = iter([])
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
result = _safe_event_loop_iteration(ctx, game_threads, 0)
|
||||
assert result == 0
|
||||
|
||||
def test_safe_event_loop_iteration_error(self) -> None:
|
||||
"""Test safe event loop iteration with error."""
|
||||
api = MagicMock()
|
||||
api.stream_events.side_effect = requests.RequestException("error")
|
||||
ctx = BotContext(api=api, engine=MagicMock(), bot_version=1)
|
||||
game_threads: GameThreads = {}
|
||||
with patch("python_pkg.lichess_bot.main.backoff_sleep", return_value=5):
|
||||
result = _safe_event_loop_iteration(ctx, game_threads, 2)
|
||||
assert result == 5
|
||||
|
||||
|
||||
class TestRunBot:
|
||||
"""Tests for run_bot."""
|
||||
|
||||
def test_run_bot_no_token(self) -> None:
|
||||
"""Test run_bot without token raises error."""
|
||||
with (
|
||||
patch.dict(os.environ, {}, clear=True),
|
||||
pytest.raises(RuntimeError, match="LICHESS_TOKEN"),
|
||||
):
|
||||
run_bot()
|
||||
|
||||
def test_run_bot_with_token(self) -> None:
|
||||
"""Test run_bot with token starts event loop."""
|
||||
|
||||
class _StopLoopError(Exception):
|
||||
"""Custom exception to stop the loop."""
|
||||
|
||||
def stop_loop(*_args: object, **_kwargs: object) -> None:
|
||||
raise _StopLoopError
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {"LICHESS_TOKEN": "test_token"}),
|
||||
patch(
|
||||
"python_pkg.lichess_bot.main.get_and_increment_version",
|
||||
return_value=1,
|
||||
),
|
||||
patch("python_pkg.lichess_bot.main.LichessAPI"),
|
||||
patch("python_pkg.lichess_bot.main.RandomEngine"),
|
||||
patch(
|
||||
"python_pkg.lichess_bot.main._safe_event_loop_iteration",
|
||||
side_effect=stop_loop,
|
||||
),
|
||||
pytest.raises(_StopLoopError),
|
||||
):
|
||||
run_bot("DEBUG", decline_correspondence=True)
|
||||
|
||||
|
||||
class TestMain:
|
||||
"""Tests for main function."""
|
||||
|
||||
def test_main_parses_args(self) -> None:
|
||||
"""Test main parses command line arguments."""
|
||||
|
||||
class _StopExecutionError(Exception):
|
||||
"""Custom exception to stop execution."""
|
||||
|
||||
with (
|
||||
patch(
|
||||
"sys.argv",
|
||||
["main.py", "--log-level", "DEBUG", "--decline-correspondence"],
|
||||
),
|
||||
patch(
|
||||
"python_pkg.lichess_bot.main.run_bot", side_effect=_StopExecutionError
|
||||
) as mock_run_bot,
|
||||
pytest.raises(_StopExecutionError),
|
||||
):
|
||||
main()
|
||||
mock_run_bot.assert_called_once_with("DEBUG", decline_correspondence=True)
|
||||
403
python_pkg/lichess_bot/tests/test_main_game_state.py
Normal file
403
python_pkg/lichess_bot/tests/test_main_game_state.py
Normal file
@ -0,0 +1,403 @@
|
||||
"""Tests for lichess_bot main module: game state helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import chess
|
||||
import requests
|
||||
|
||||
from python_pkg.lichess_bot.main import (
|
||||
BotContext,
|
||||
GameMeta,
|
||||
GameState,
|
||||
_apply_move_to_board,
|
||||
_attempt_move,
|
||||
_calculate_time_budget,
|
||||
_extract_game_full_data,
|
||||
_extract_game_state_data,
|
||||
_extract_player_info,
|
||||
_handle_move_if_needed,
|
||||
_init_game_log,
|
||||
_is_my_turn,
|
||||
_log_move_to_file,
|
||||
_rebuild_board_from_moves,
|
||||
_update_clocks_from_state,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
# Type alias to make mypy happy with test event dicts
|
||||
Event = dict[str, Any]
|
||||
|
||||
|
||||
class TestApplyMoveToBoard:
|
||||
"""Tests for _apply_move_to_board."""
|
||||
|
||||
def test_apply_valid_move(self) -> None:
|
||||
"""Test applying a valid move."""
|
||||
board = chess.Board()
|
||||
_apply_move_to_board(board, "e2e4", "game1")
|
||||
assert board.fen() != chess.STARTING_FEN
|
||||
|
||||
def test_apply_invalid_move(self) -> None:
|
||||
"""Test applying an invalid move logs debug."""
|
||||
board = chess.Board()
|
||||
with patch("python_pkg.lichess_bot.main._logger") as mock_logger:
|
||||
_apply_move_to_board(board, "invalid", "game1")
|
||||
mock_logger.debug.assert_called_once()
|
||||
|
||||
|
||||
class TestInitGameLog:
|
||||
"""Tests for _init_game_log."""
|
||||
|
||||
def test_init_game_log_success(self, tmp_path: Path) -> None:
|
||||
"""Test successful log initialization."""
|
||||
with patch("python_pkg.lichess_bot.main.Path.cwd", return_value=tmp_path):
|
||||
result = _init_game_log("game123", 42)
|
||||
assert result is not None
|
||||
assert result.exists()
|
||||
content = result.read_text()
|
||||
assert "game game123 started" in content
|
||||
assert "bot_version v42" in content
|
||||
|
||||
def test_init_game_log_oserror(self) -> None:
|
||||
"""Test log initialization with OSError."""
|
||||
with patch("python_pkg.lichess_bot.main.Path.cwd") as mock_cwd:
|
||||
mock_path = MagicMock()
|
||||
mock_path.__truediv__ = MagicMock(return_value=mock_path)
|
||||
mock_path.open.side_effect = OSError("Permission denied")
|
||||
mock_cwd.return_value = mock_path
|
||||
result = _init_game_log("game123", 42)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUpdateClocksFromState:
|
||||
"""Tests for _update_clocks_from_state."""
|
||||
|
||||
def test_update_clocks_white(self) -> None:
|
||||
"""Test clock update when playing as white."""
|
||||
state = GameState(color="white")
|
||||
state_data: Event = {"wtime": 60000, "btime": 55000, "winc": 1000}
|
||||
_update_clocks_from_state(state_data, state)
|
||||
assert state.my_ms == 60000
|
||||
assert state.opp_ms == 55000
|
||||
assert state.inc_ms == 1000
|
||||
|
||||
def test_update_clocks_black(self) -> None:
|
||||
"""Test clock update when playing as black."""
|
||||
state = GameState(color="black")
|
||||
state_data: Event = {"wtime": 60000, "btime": 55000, "binc": 2000}
|
||||
_update_clocks_from_state(state_data, state)
|
||||
assert state.my_ms == 55000
|
||||
assert state.opp_ms == 60000
|
||||
assert state.inc_ms == 2000
|
||||
|
||||
def test_update_clocks_float_values(self) -> None:
|
||||
"""Test clock update with float values."""
|
||||
state = GameState(color="white")
|
||||
state_data: Event = {"wtime": 60000.5, "btime": 55000.5}
|
||||
_update_clocks_from_state(state_data, state)
|
||||
assert state.my_ms == 60000
|
||||
assert state.opp_ms == 55000
|
||||
|
||||
def test_update_clocks_none_values(self) -> None:
|
||||
"""Test clock update with None values."""
|
||||
state = GameState(color="white")
|
||||
state_data: Event = {"wtime": None, "btime": None}
|
||||
_update_clocks_from_state(state_data, state)
|
||||
assert state.my_ms is None
|
||||
assert state.opp_ms is None
|
||||
|
||||
|
||||
class TestExtractPlayerInfo:
|
||||
"""Tests for _extract_player_info."""
|
||||
|
||||
def test_extract_player_info_white(self) -> None:
|
||||
"""Test extracting player info when bot is white."""
|
||||
api = MagicMock()
|
||||
api.get_my_user_id.return_value = "mybot"
|
||||
state = GameState()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {
|
||||
"white": {"id": "mybot", "name": "MyBot"},
|
||||
"black": {"id": "opp", "name": "Opponent"},
|
||||
}
|
||||
_extract_player_info(event, state, meta, api)
|
||||
assert state.color == "white"
|
||||
assert meta.white_name == "MyBot"
|
||||
assert meta.black_name == "Opponent"
|
||||
|
||||
def test_extract_player_info_black(self) -> None:
|
||||
"""Test extracting player info when bot is black."""
|
||||
api = MagicMock()
|
||||
api.get_my_user_id.return_value = "mybot"
|
||||
state = GameState()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {
|
||||
"white": {"id": "opp", "name": "Opponent"},
|
||||
"black": {"id": "mybot", "name": "MyBot"},
|
||||
}
|
||||
_extract_player_info(event, state, meta, api)
|
||||
assert state.color == "black"
|
||||
|
||||
def test_extract_player_info_invalid_data(self) -> None:
|
||||
"""Test extracting player info with invalid data."""
|
||||
api = MagicMock()
|
||||
state = GameState()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {"white": "invalid", "black": "invalid"}
|
||||
_extract_player_info(event, state, meta, api)
|
||||
assert state.color is None
|
||||
|
||||
def test_extract_player_info_missing_name(self) -> None:
|
||||
"""Test extracting player info with missing name uses id."""
|
||||
api = MagicMock()
|
||||
api.get_my_user_id.return_value = "mybot"
|
||||
state = GameState()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {
|
||||
"white": {"id": "mybot"},
|
||||
"black": {"id": "opponent"},
|
||||
}
|
||||
_extract_player_info(event, state, meta, api)
|
||||
assert meta.white_name == "mybot"
|
||||
assert meta.black_name == "opponent"
|
||||
|
||||
|
||||
class TestExtractGameFullData:
|
||||
"""Tests for _extract_game_full_data."""
|
||||
|
||||
def test_extract_game_full_data(self) -> None:
|
||||
"""Test extracting gameFull data."""
|
||||
api = MagicMock()
|
||||
api.get_my_user_id.return_value = "mybot"
|
||||
state = GameState()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {
|
||||
"state": {"moves": "e2e4 e7e5", "status": "started", "wtime": 60000},
|
||||
"white": {"id": "mybot"},
|
||||
"black": {"id": "opp"},
|
||||
"createdAt": 1609459200000, # 2021-01-01
|
||||
}
|
||||
moves, status = _extract_game_full_data(event, state, meta, api)
|
||||
assert moves == "e2e4 e7e5"
|
||||
assert status == "started"
|
||||
assert meta.site_url == "https://lichess.org/game1"
|
||||
assert meta.date_iso == "2021.01.01"
|
||||
|
||||
def test_extract_game_full_data_invalid_state(self) -> None:
|
||||
"""Test extracting gameFull data with invalid state."""
|
||||
api = MagicMock()
|
||||
state = GameState()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
event: Event = {"state": "invalid"}
|
||||
moves, status = _extract_game_full_data(event, state, meta, api)
|
||||
assert moves == ""
|
||||
assert status is None
|
||||
|
||||
|
||||
class TestExtractGameStateData:
|
||||
"""Tests for _extract_game_state_data."""
|
||||
|
||||
def test_extract_game_state_as_white(self) -> None:
|
||||
"""Test extracting gameState data as white."""
|
||||
state = GameState(color="white", my_ms=60000)
|
||||
event: Event = {
|
||||
"moves": "e2e4",
|
||||
"status": "started",
|
||||
"wtime": 59000,
|
||||
"btime": 60000,
|
||||
}
|
||||
moves, status = _extract_game_state_data(event, state)
|
||||
assert moves == "e2e4"
|
||||
assert status == "started"
|
||||
assert state.my_ms == 59000
|
||||
assert state.opp_ms == 60000
|
||||
|
||||
def test_extract_game_state_as_black(self) -> None:
|
||||
"""Test extracting gameState data as black."""
|
||||
state = GameState(color="black")
|
||||
event: Event = {
|
||||
"moves": "e2e4 e7e5",
|
||||
"wtime": 60000,
|
||||
"btime": 59000,
|
||||
"binc": 1000,
|
||||
}
|
||||
moves, __status = _extract_game_state_data(event, state)
|
||||
assert moves == "e2e4 e7e5"
|
||||
assert state.my_ms == 59000
|
||||
assert state.opp_ms == 60000
|
||||
assert state.inc_ms == 1000
|
||||
|
||||
|
||||
class TestCalculateTimeBudget:
|
||||
"""Tests for _calculate_time_budget."""
|
||||
|
||||
def test_calculate_time_budget_normal(self) -> None:
|
||||
"""Test time budget calculation."""
|
||||
state = GameState(my_ms=60000, inc_ms=1000)
|
||||
board = chess.Board()
|
||||
budget = _calculate_time_budget(state, board, 10.0)
|
||||
assert 0.05 <= budget <= 10.0
|
||||
|
||||
def test_calculate_time_budget_low_time(self) -> None:
|
||||
"""Test time budget with low time."""
|
||||
state = GameState(my_ms=1000, inc_ms=0)
|
||||
board = chess.Board()
|
||||
budget = _calculate_time_budget(state, board, 10.0)
|
||||
assert budget >= 0.05
|
||||
|
||||
|
||||
class TestLogMoveToFile:
|
||||
"""Tests for _log_move_to_file."""
|
||||
|
||||
def test_log_move_to_file(self, tmp_path: Path) -> None:
|
||||
"""Test logging a move to file."""
|
||||
log_path = tmp_path / "game.log"
|
||||
log_path.write_text("header\n")
|
||||
move = chess.Move.from_uci("e2e4")
|
||||
_log_move_to_file(log_path, 1, move, "best move")
|
||||
content = log_path.read_text()
|
||||
assert "ply 1: e2e4" in content
|
||||
assert "best move" in content
|
||||
|
||||
def test_log_move_to_file_none_path(self) -> None:
|
||||
"""Test logging with None path does nothing."""
|
||||
move = chess.Move.from_uci("e2e4")
|
||||
_log_move_to_file(None, 1, move, "reason") # Should not raise
|
||||
|
||||
|
||||
class TestAttemptMove:
|
||||
"""Tests for _attempt_move."""
|
||||
|
||||
def test_attempt_move_success(self) -> None:
|
||||
"""Test successful move attempt."""
|
||||
api = MagicMock()
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (
|
||||
chess.Move.from_uci("e2e4"),
|
||||
"opening",
|
||||
)
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState(my_ms=60000)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
board = chess.Board()
|
||||
|
||||
result = _attempt_move(ctx, state, meta, board)
|
||||
assert result is True
|
||||
api.make_move.assert_called_once()
|
||||
|
||||
def test_attempt_move_no_moves(self) -> None:
|
||||
"""Test move attempt with no legal moves."""
|
||||
api = MagicMock()
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (None, "no moves")
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState()
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
board = chess.Board()
|
||||
|
||||
result = _attempt_move(ctx, state, meta, board)
|
||||
assert result is False
|
||||
|
||||
def test_attempt_move_illegal(self) -> None:
|
||||
"""Test move attempt with illegal move."""
|
||||
api = MagicMock()
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
# Return a move that's not legal (e.g., random square move)
|
||||
engine.choose_move_with_explanation.return_value = (
|
||||
chess.Move.from_uci("a1a8"),
|
||||
"bad",
|
||||
)
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState(my_ms=60000)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
board = chess.Board()
|
||||
|
||||
result = _attempt_move(ctx, state, meta, board)
|
||||
assert result is True
|
||||
api.make_move.assert_not_called()
|
||||
|
||||
def test_attempt_move_request_error(self) -> None:
|
||||
"""Test move attempt with request error."""
|
||||
api = MagicMock()
|
||||
api.make_move.side_effect = requests.RequestException("Network error")
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (
|
||||
chess.Move.from_uci("e2e4"),
|
||||
"opening",
|
||||
)
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState(my_ms=60000)
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
board = chess.Board()
|
||||
|
||||
result = _attempt_move(ctx, state, meta, board)
|
||||
assert result is True # Still returns True
|
||||
|
||||
|
||||
class TestIsMyTurn:
|
||||
"""Tests for _is_my_turn."""
|
||||
|
||||
def test_is_my_turn_white_to_move(self) -> None:
|
||||
"""Test checking turn when white to move."""
|
||||
board = chess.Board() # White to move
|
||||
assert _is_my_turn(board, "white") is True
|
||||
assert _is_my_turn(board, "black") is False
|
||||
|
||||
def test_is_my_turn_black_to_move(self) -> None:
|
||||
"""Test checking turn when black to move."""
|
||||
board = chess.Board()
|
||||
board.push_uci("e2e4") # Black to move
|
||||
assert _is_my_turn(board, "white") is False
|
||||
assert _is_my_turn(board, "black") is True
|
||||
|
||||
|
||||
class TestRebuildBoardFromMoves:
|
||||
"""Tests for _rebuild_board_from_moves."""
|
||||
|
||||
def test_rebuild_board_from_moves(self) -> None:
|
||||
"""Test rebuilding board from moves list."""
|
||||
moves_list = ["e2e4", "e7e5", "g1f3"]
|
||||
board = _rebuild_board_from_moves(moves_list, "game1")
|
||||
assert len(board.move_stack) == 3
|
||||
|
||||
|
||||
class TestHandleMoveIfNeeded:
|
||||
"""Tests for _handle_move_if_needed."""
|
||||
|
||||
def test_handle_move_game_state_my_turn(self) -> None:
|
||||
"""Test handling move on gameState when my turn."""
|
||||
api = MagicMock()
|
||||
engine = MagicMock()
|
||||
engine.max_time_sec = 5.0
|
||||
engine.choose_move_with_explanation.return_value = (
|
||||
chess.Move.from_uci("e2e4"),
|
||||
"opening",
|
||||
)
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState(color="white", my_ms=60000, board=chess.Board())
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
|
||||
result = _handle_move_if_needed(ctx, state, meta, "gameState", 0)
|
||||
assert result is True
|
||||
|
||||
def test_handle_move_game_full_with_moves(self) -> None:
|
||||
"""Test handling move on gameFull with existing moves (opponent's turn)."""
|
||||
api = MagicMock()
|
||||
engine = MagicMock()
|
||||
ctx = BotContext(api=api, engine=engine, bot_version=1)
|
||||
state = GameState(color="white", my_ms=60000, board=chess.Board())
|
||||
meta = GameMeta(game_id="game1", bot_version=1)
|
||||
|
||||
# gameFull with moves - don't move
|
||||
result = _handle_move_if_needed(ctx, state, meta, "gameFull", 1)
|
||||
assert result is True
|
||||
engine.choose_move_with_explanation.assert_not_called()
|
||||
445
python_pkg/praca_magisterska_video/_q23_classical.py
Normal file
445
python_pkg/praca_magisterska_video/_q23_classical.py
Normal file
@ -0,0 +1,445 @@
|
||||
"""Classical segmentation methods: concept, thresholding, region growing, watershed."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from moviepy import (
|
||||
CompositeVideoClip,
|
||||
VideoClip,
|
||||
)
|
||||
from moviepy.video.fx import FadeIn, FadeOut
|
||||
import numpy as np
|
||||
|
||||
from python_pkg.praca_magisterska_video._q23_helpers import (
|
||||
BG_COLOR,
|
||||
FONT_B,
|
||||
FONT_R,
|
||||
FPS,
|
||||
STEP_DUR,
|
||||
H,
|
||||
W,
|
||||
_tc,
|
||||
)
|
||||
|
||||
|
||||
# ── Segmentation concept ─────────────────────────────────────────
|
||||
def _segmentation_concept() -> list[CompositeVideoClip]:
|
||||
"""Show what segmentation is: pixel-level labeling."""
|
||||
slides = []
|
||||
|
||||
# Synthetic image: grid of colored pixels
|
||||
def make_image_frame(_t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
# Draw a small "image" grid
|
||||
grid_x, grid_y = 100, 150
|
||||
cell = 40
|
||||
# Sky (top rows)
|
||||
colors_map = [
|
||||
[(135, 206, 235)] * 8, # sky
|
||||
[(135, 206, 235)] * 5 + [(34, 139, 34)] * 3, # sky + tree
|
||||
[(34, 139, 34)] * 3
|
||||
+ [(128, 128, 128)] * 3
|
||||
+ [(34, 139, 34)] * 2, # tree+road+tree
|
||||
[(128, 128, 128)] * 3
|
||||
+ [(200, 50, 50)] * 2
|
||||
+ [(128, 128, 128)] * 3, # road+car+road
|
||||
]
|
||||
labels_map = [
|
||||
["niebo"] * 8,
|
||||
["niebo"] * 5 + ["drzewo"] * 3,
|
||||
["drzewo"] * 3 + ["droga"] * 3 + ["drzewo"] * 2,
|
||||
["droga"] * 3 + ["samochód"] * 2 + ["droga"] * 3,
|
||||
]
|
||||
label_colors = {
|
||||
"niebo": (100, 180, 255),
|
||||
"drzewo": (50, 200, 50),
|
||||
"droga": (180, 180, 180),
|
||||
"samochód": (255, 80, 80),
|
||||
}
|
||||
|
||||
for r, row in enumerate(colors_map):
|
||||
for c, col in enumerate(row):
|
||||
y = grid_y + r * cell
|
||||
x = grid_x + c * cell
|
||||
frame[y : y + cell - 2, x : x + cell - 2] = col
|
||||
|
||||
# Draw segmentation map on the right
|
||||
seg_x = 600
|
||||
for r, row in enumerate(labels_map):
|
||||
for c, lab in enumerate(row):
|
||||
y = grid_y + r * cell
|
||||
x = seg_x + c * cell
|
||||
frame[y : y + cell - 2, x : x + cell - 2] = label_colors[lab]
|
||||
|
||||
return frame
|
||||
|
||||
image_clip = VideoClip(make_image_frame, duration=STEP_DUR).with_fps(FPS)
|
||||
labels_text = [
|
||||
("Obraz wejściowy", 22, "white", FONT_B, (170, 100)),
|
||||
("Mapa segmentacji", 22, "white", FONT_B, (660, 100)),
|
||||
("→", 50, "#FFE082", FONT_B, (450, 250)),
|
||||
("Każdy piksel → etykieta klasy", 20, "#B0BEC5", FONT_R, (100, 420)),
|
||||
("niebo | drzewo | droga | samochód", 18, "#90CAF9", FONT_R, (600, 420)),
|
||||
("Segmentacja = klasyfikacja per-piksel", 24, "#FFE082", FONT_B, (100, 500)),
|
||||
(
|
||||
"Semantic: klasy bez instancji | Instance: "
|
||||
"rozróżnia obiekty | Panoptic: oba",
|
||||
16,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(100, 560),
|
||||
),
|
||||
]
|
||||
clips: list[VideoClip] = [image_clip]
|
||||
for text, fs, color, font, pos in labels_text:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR)
|
||||
.with_position(pos)
|
||||
)
|
||||
clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(clips, size=(W, H)).with_effects([FadeIn(0.3), FadeOut(0.3)])
|
||||
)
|
||||
return slides
|
||||
|
||||
|
||||
# ── Thresholding / Otsu ───────────────────────────────────────────
|
||||
def _thresholding_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate thresholding and Otsu concept."""
|
||||
slides = []
|
||||
|
||||
# Show histogram & threshold
|
||||
def make_threshold_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
# Draw bimodal histogram bars
|
||||
bar_start_x = 80
|
||||
bar_y = 500
|
||||
bar_w = 4
|
||||
|
||||
for i in range(256):
|
||||
# Bimodal: peaks at 60 and 190
|
||||
h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2))
|
||||
h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2))
|
||||
bar_h = int(h1 + h2)
|
||||
x = bar_start_x + i * bar_w
|
||||
if x + bar_w < W:
|
||||
frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (150, 150, 170)
|
||||
|
||||
# Animated threshold line
|
||||
threshold = int(60 + (190 - 60) * min(t / (STEP_DUR * 0.7), 1.0))
|
||||
tx = bar_start_x + threshold * bar_w
|
||||
if tx < W:
|
||||
frame[bar_y - 250 : bar_y + 10, tx : tx + 3] = (255, 80, 80)
|
||||
|
||||
# Color the two sides
|
||||
for i in range(threshold):
|
||||
x = bar_start_x + i * bar_w
|
||||
h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2))
|
||||
h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2))
|
||||
bar_h = int(h1 + h2)
|
||||
if x + bar_w < W and bar_h > 0:
|
||||
frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (70, 130, 200)
|
||||
|
||||
for i in range(threshold, 256):
|
||||
x = bar_start_x + i * bar_w
|
||||
h1 = 200 * np.exp(-((i - 60) ** 2) / (2 * 20**2))
|
||||
h2 = 150 * np.exp(-((i - 190) ** 2) / (2 * 25**2))
|
||||
bar_h = int(h1 + h2)
|
||||
if x + bar_w < W and bar_h > 0:
|
||||
frame[bar_y - bar_h : bar_y, x : x + bar_w - 1] = (200, 100, 80)
|
||||
|
||||
return frame
|
||||
|
||||
hist_clip = VideoClip(make_threshold_frame, duration=STEP_DUR).with_fps(FPS)
|
||||
text_clips: list[VideoClip] = [hist_clip]
|
||||
labels = [
|
||||
("Progowanie (Thresholding) z metodą Otsu", 28, "#FFE082", FONT_B, (80, 30)),
|
||||
(
|
||||
"Histogram jasności pikseli — dwumodalny (bimodal)",
|
||||
20,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 80),
|
||||
),
|
||||
("Garb 1: piksele obiektu (ciemne ~60)", 16, "#64B5F6", FONT_R, (80, 120)),
|
||||
("Garb 2: piksele tła (jasne ~190)", 16, "#EF9A9A", FONT_R, (80, 150)),
|
||||
(
|
||||
"Próg T (czerwona linia) dzieli piksele na 2 klasy",
|
||||
18,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 540),
|
||||
),
|
||||
(
|
||||
"Otsu: automatycznie testuje T=0..255, minimalizuje σ² wewnątrzklasową",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 580),
|
||||
),
|
||||
(
|
||||
"Piksel ≤ T → klasa 0 (tło) | Piksel > T → klasa 1 (obiekt)",
|
||||
16,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 620),
|
||||
),
|
||||
]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
|
||||
|
||||
# ── Region Growing ────────────────────────────────────────────────
|
||||
def _region_growing_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate region growing BFS from a seed pixel."""
|
||||
slides = []
|
||||
|
||||
grid_size = 10
|
||||
cell_size = 40
|
||||
rng = np.random.default_rng(42)
|
||||
# Create a simple grid: dark region (30-80) and bright region (160-220)
|
||||
grid = np.zeros((grid_size, grid_size), dtype=np.uint8)
|
||||
grid[:] = 60 # dark background
|
||||
grid[2:7, 3:8] = 180 # bright rectangle
|
||||
|
||||
# Add some noise
|
||||
noise = rng.integers(-15, 15, (grid_size, grid_size))
|
||||
grid = np.clip(grid.astype(int) + noise, 0, 255).astype(np.uint8)
|
||||
|
||||
# BFS steps from seed (4, 5)
|
||||
seed = (4, 5)
|
||||
threshold_val = 50
|
||||
visited_order: list[tuple[int, int]] = []
|
||||
queue = [seed]
|
||||
visited_set = {seed}
|
||||
while queue:
|
||||
r, c = queue.pop(0)
|
||||
visited_order.append((r, c))
|
||||
for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
|
||||
nr, nc = r + dr, c + dc
|
||||
if (
|
||||
0 <= nr < grid_size
|
||||
and 0 <= nc < grid_size
|
||||
and (nr, nc) not in visited_set
|
||||
) and abs(int(grid[nr, nc]) - int(grid[seed])) < threshold_val:
|
||||
visited_set.add((nr, nc))
|
||||
queue.append((nr, nc))
|
||||
|
||||
def make_region_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
ox, oy = 100, 180
|
||||
|
||||
# How many cells to show as visited
|
||||
progress = min(t / (STEP_DUR * 0.8), 1.0)
|
||||
n_visited = int(progress * len(visited_order))
|
||||
|
||||
for r in range(grid_size):
|
||||
for c in range(grid_size):
|
||||
x = ox + c * cell_size
|
||||
y = oy + r * cell_size
|
||||
val = grid[r, c]
|
||||
color = (val, val, val)
|
||||
|
||||
# Highlight visited
|
||||
if (r, c) in visited_order[:n_visited]:
|
||||
color = (80, 200, 120) # green for region
|
||||
elif (r, c) == seed:
|
||||
color = (255, 200, 50) # yellow seed
|
||||
|
||||
frame[y : y + cell_size - 2, x : x + cell_size - 2] = color
|
||||
|
||||
# Mark the seed with a bright border
|
||||
sx = ox + seed[1] * cell_size
|
||||
sy = ox + seed[0] * cell_size + 80
|
||||
frame[sy : sy + cell_size, sx : sx + 2] = (255, 200, 50)
|
||||
frame[sy : sy + cell_size, sx + cell_size - 2 : sx + cell_size] = (255, 200, 50)
|
||||
frame[sy : sy + 2, sx : sx + cell_size] = (255, 200, 50)
|
||||
frame[sy + cell_size - 2 : sy + cell_size, sx : sx + cell_size] = (255, 200, 50)
|
||||
|
||||
return frame
|
||||
|
||||
region_clip = VideoClip(make_region_frame, duration=STEP_DUR).with_fps(FPS)
|
||||
text_clips: list[VideoClip] = [region_clip]
|
||||
labels = [
|
||||
("Region Growing — rozrastanie regionu", 28, "#FFE082", FONT_B, (100, 30)),
|
||||
("Seed (ziarno) → BFS do podobnych sąsiadów", 20, "#B0BEC5", FONT_R, (100, 80)),
|
||||
(
|
||||
"Żółty = seed | Zielony = region | Szary = nieodwiedzone",
|
||||
16,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(100, 120),
|
||||
),
|
||||
(
|
||||
"Sąsiad PODOBNY (|jasność - jasność_regionu| < próg) → dodaj do regionu",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 600),
|
||||
),
|
||||
(
|
||||
"Algorytm zatrzymuje się gdy brak podobnych sąsiadów",
|
||||
16,
|
||||
"#90CAF9",
|
||||
FONT_R,
|
||||
(100, 640),
|
||||
),
|
||||
(
|
||||
"Mnemonik: PLAMA atramentu — rozlewa się na podobne piksele",
|
||||
18,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(100, 670),
|
||||
),
|
||||
]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
|
||||
|
||||
# ── Watershed ─────────────────────────────────────────────────────
|
||||
def _watershed_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate watershed flooding concept."""
|
||||
slides = []
|
||||
|
||||
def make_watershed_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
# Draw terrain profile (1D cross-section)
|
||||
ox, oy = 100, 450
|
||||
terrain_w = 900
|
||||
terrain_points = 100
|
||||
|
||||
xs = np.linspace(0, 1, terrain_points)
|
||||
# Two valleys with a ridge
|
||||
terrain = (
|
||||
120 * np.exp(-((xs - 0.25) ** 2) / 0.005)
|
||||
+ 80 * np.exp(-((xs - 0.75) ** 2) / 0.008)
|
||||
+ 30
|
||||
)
|
||||
terrain = 250 - terrain # invert for visual (valleys at bottom)
|
||||
|
||||
# Water level rises over time
|
||||
water_level = int(160 + 80 * min(t / (STEP_DUR * 0.7), 1.0))
|
||||
|
||||
for i in range(terrain_points - 1):
|
||||
x1 = ox + int(xs[i] * terrain_w)
|
||||
x2 = ox + int(xs[i + 1] * terrain_w)
|
||||
y1 = oy - int(terrain[i])
|
||||
y2 = oy - int(terrain[i + 1])
|
||||
|
||||
# Fill terrain
|
||||
for x in range(x1, min(x2 + 1, W)):
|
||||
top = min(y1, y2) - 5
|
||||
frame[top:oy, x : x + 1] = (100, 80, 60)
|
||||
|
||||
# Fill water
|
||||
water_y = oy - water_level
|
||||
for x in range(x1, min(x2 + 1, W)):
|
||||
t_y = oy - int(terrain[i])
|
||||
if water_y < t_y:
|
||||
# Water fills below terrain surface
|
||||
fill_top = max(water_y, 0)
|
||||
fill_bot = min(t_y, oy)
|
||||
if fill_top < fill_bot:
|
||||
frame[fill_top:fill_bot, x : x + 1] = (70, 130, 220)
|
||||
|
||||
# Dam marker at ridge
|
||||
ridge_x = ox + int(0.5 * terrain_w)
|
||||
dam_visible_threshold = 160
|
||||
if water_level > dam_visible_threshold:
|
||||
frame[oy - water_level : oy - 140, ridge_x - 2 : ridge_x + 2] = (
|
||||
255,
|
||||
80,
|
||||
80,
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
ws_clip = VideoClip(make_watershed_frame, duration=STEP_DUR).with_fps(FPS)
|
||||
text_clips: list[VideoClip] = [ws_clip]
|
||||
labels = [
|
||||
("Watershed — metoda zlewiska", 28, "#FFE082", FONT_B, (100, 20)),
|
||||
(
|
||||
"Obraz = mapa topograficzna (jasność = wysokość)",
|
||||
20,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(100, 65),
|
||||
),
|
||||
(
|
||||
"Brązowy = teren (ciemne=doliny, jasne=szczyty)",
|
||||
16,
|
||||
"#8D6E63",
|
||||
FONT_R,
|
||||
(100, 100),
|
||||
),
|
||||
("Niebieski = woda zalewająca od minimów", 16, "#64B5F6", FONT_R, (100, 130)),
|
||||
(
|
||||
"Czerwony = TAMA (granica segmentu) — gdy woda z 2 dolin się spotka",
|
||||
16,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(100, 160),
|
||||
),
|
||||
(
|
||||
"Problem: over-segmentation "
|
||||
"(za dużo regionów). "
|
||||
"Rozwiązanie: marker-controlled.",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 560),
|
||||
),
|
||||
(
|
||||
"Mnemonik: ZALEWANIE terenu — granie gór = granice segmentów",
|
||||
18,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(100, 600),
|
||||
),
|
||||
]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
248
python_pkg/praca_magisterska_video/_q23_deeplab.py
Normal file
248
python_pkg/praca_magisterska_video/_q23_deeplab.py
Normal file
@ -0,0 +1,248 @@
|
||||
"""DeepLab architecture animations for Q23 segmentation video."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from moviepy import (
|
||||
CompositeVideoClip,
|
||||
VideoClip,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
from python_pkg.praca_magisterska_video._q23_helpers import (
|
||||
BG_COLOR,
|
||||
FONT_B,
|
||||
FONT_R,
|
||||
FPS,
|
||||
STEP_DUR,
|
||||
H,
|
||||
W,
|
||||
_compose_slide,
|
||||
)
|
||||
|
||||
|
||||
# ── DeepLab Architecture ─────────────────────────────────────────
|
||||
def _make_dilated_frame(t: float) -> np.ndarray:
|
||||
"""Render a dilated convolution comparison frame."""
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
progress = min(t / (STEP_DUR * 0.7), 1.0)
|
||||
|
||||
cell = 36
|
||||
grids = [
|
||||
(
|
||||
"rate=1",
|
||||
60,
|
||||
[
|
||||
(0, 0),
|
||||
(0, 1),
|
||||
(0, 2),
|
||||
(1, 0),
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(2, 0),
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
],
|
||||
),
|
||||
(
|
||||
"rate=2",
|
||||
420,
|
||||
[
|
||||
(0, 0),
|
||||
(0, 2),
|
||||
(0, 4),
|
||||
(2, 0),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(4, 0),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
],
|
||||
),
|
||||
(
|
||||
"rate=3",
|
||||
820,
|
||||
[
|
||||
(0, 0),
|
||||
(0, 3),
|
||||
(0, 6),
|
||||
(3, 0),
|
||||
(3, 3),
|
||||
(3, 6),
|
||||
(6, 0),
|
||||
(6, 3),
|
||||
(6, 6),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
for gi, (_label, gx, positions) in enumerate(grids):
|
||||
if progress < gi * 0.3:
|
||||
break
|
||||
gy = 180
|
||||
grid_size = 7
|
||||
for r in range(grid_size):
|
||||
for c in range(grid_size):
|
||||
x = gx + c * cell
|
||||
y = gy + r * cell
|
||||
frame[y : y + cell - 2, x : x + cell - 2] = (35, 40, 55)
|
||||
for r, c in positions:
|
||||
x = gx + c * cell
|
||||
y = gy + r * cell
|
||||
frame[y : y + cell - 2, x : x + cell - 2] = (70, 130, 200)
|
||||
frame[y : y + 2, x : x + cell - 2] = (120, 180, 255)
|
||||
frame[y + cell - 4 : y + cell - 2, x : x + cell - 2] = (120, 180, 255)
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def _make_aspp_frame(t: float) -> np.ndarray:
|
||||
"""Render a single ASPP module animation frame."""
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
progress = min(t / (STEP_DUR * 0.7), 1.0)
|
||||
|
||||
frame[250:330, 50:130] = (70, 130, 200)
|
||||
frame[250:252, 50:130] = (120, 180, 255)
|
||||
frame[328:330, 50:130] = (120, 180, 255)
|
||||
|
||||
branches = [
|
||||
("1x1 conv", 250, (200, 170), (100, 40), (80, 200, 120)),
|
||||
("rate=6", 310, (200, 250), (100, 40), (200, 160, 80)),
|
||||
("rate=12", 370, (200, 330), (100, 40), (200, 120, 60)),
|
||||
("rate=18", 430, (200, 410), (100, 40), (180, 100, 80)),
|
||||
("GAP", 490, (200, 490), (100, 40), (160, 80, 160)),
|
||||
]
|
||||
n_branches = min(int(progress * 5) + 1, 5)
|
||||
for i, (_lbl, _h, (bx, by), (bw, bh), color) in enumerate(branches):
|
||||
if i < n_branches:
|
||||
frame[by : by + bh, bx : bx + bw] = color
|
||||
frame[by : by + 2, bx : bx + bw] = tuple(min(c + 50, 255) for c in color)
|
||||
ay = by + bh // 2
|
||||
frame[ay - 1 : ay + 2, 133:197] = (150, 150, 170)
|
||||
|
||||
concat_phase = 0.6
|
||||
if progress > concat_phase:
|
||||
frame[250:530, 380:420] = (50, 60, 80)
|
||||
frame[250:252, 380:420] = (200, 200, 100)
|
||||
frame[528:530, 380:420] = (200, 200, 100)
|
||||
for i, (_lbl, _h, (bx, by), (bw, bh), _c) in enumerate(branches):
|
||||
if i < n_branches:
|
||||
ay = by + bh // 2
|
||||
frame[ay - 1 : ay + 2, bx + bw + 3 : 378] = (150, 150, 170)
|
||||
|
||||
final_conv_phase = 0.8
|
||||
if progress > final_conv_phase:
|
||||
frame[350:420, 450:550] = (100, 200, 100)
|
||||
frame[350:352, 450:550] = (150, 230, 150)
|
||||
frame[418:420, 450:550] = (150, 230, 150)
|
||||
frame[388:391, 423:448] = (150, 150, 170)
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def _deeplab_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate DeepLab: dilated convolution + ASPP step by step."""
|
||||
dur = STEP_DUR + 1
|
||||
|
||||
# Slide 1: Regular vs Dilated convolution
|
||||
dil_clip = VideoClip(_make_dilated_frame, duration=dur).with_fps(FPS)
|
||||
labels = [
|
||||
("DeepLab: Atrous (Dilated) Convolution", 26, "#FFE082", FONT_B, (80, 20)),
|
||||
(
|
||||
"KROK 1: Zrozum dilated convolution — filtr z DZIURAMI",
|
||||
18,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 60),
|
||||
),
|
||||
("rate=1 (zwykła)", 14, "#64B5F6", FONT_B, (60, 160)),
|
||||
("RF = 3x3", 14, "#64B5F6", FONT_R, (60, 440)),
|
||||
("9 wag, kontekst 3px", 12, "#78909C", FONT_R, (60, 470)),
|
||||
("rate=2 (dilated)", 14, "#FFE082", FONT_B, (420, 160)),
|
||||
("RF = 5x5", 14, "#FFE082", FONT_R, (420, 440)),
|
||||
("9 wag, kontekst 5px!", 12, "#78909C", FONT_R, (420, 470)),
|
||||
("rate=3 (dilated)", 14, "#A5D6A7", FONT_B, (820, 160)),
|
||||
("RF = 7x7", 14, "#A5D6A7", FONT_R, (820, 440)),
|
||||
("9 wag, kontekst 7px!", 12, "#78909C", FONT_R, (820, 470)),
|
||||
(
|
||||
"Niebieski = pozycja wag filtra 3x3 | Szary = pominięte (dziury)",
|
||||
15,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 510),
|
||||
),
|
||||
(
|
||||
"TE SAME 9 wag → WIĘKSZE pole widzenia "
|
||||
"→ lepszy kontekst BEZ dodatkowych parametrów!",
|
||||
16,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 550),
|
||||
),
|
||||
(
|
||||
"Mnemonik: DZIURY w filtrze — à trous = z dziurami (fr.)",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 600),
|
||||
),
|
||||
]
|
||||
slides = [_compose_slide(dil_clip, labels, dur)]
|
||||
|
||||
# Slide 2: ASPP module step by step
|
||||
aspp_clip = VideoClip(_make_aspp_frame, duration=dur).with_fps(FPS)
|
||||
labels2 = [
|
||||
(
|
||||
"DeepLab: ASPP (Atrous Spatial Pyramid Pooling)",
|
||||
24,
|
||||
"#FFE082",
|
||||
FONT_B,
|
||||
(80, 20),
|
||||
),
|
||||
(
|
||||
"KROK 2: Multi-scale — analizuj obraz na WIELU skalach naraz",
|
||||
17,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 60),
|
||||
),
|
||||
("Wejście", 13, "#64B5F6", FONT_B, (55, 235)),
|
||||
("Conv 1x1", 12, "white", FONT_R, (210, 178)),
|
||||
("Dilated r=6", 12, "white", FONT_R, (205, 258)),
|
||||
("Dilated r=12", 12, "white", FONT_R, (203, 338)),
|
||||
("Dilated r=18", 12, "white", FONT_R, (203, 418)),
|
||||
("GAP (global)", 12, "white", FONT_R, (205, 498)),
|
||||
("Concat", 13, "#FFE082", FONT_B, (381, 537)),
|
||||
("Conv", 13, "#A5D6A7", FONT_B, (470, 425)),
|
||||
(
|
||||
"5 gałęzi RÓWNOLEGŁYCH → różne skale kontekstu:",
|
||||
16,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(550, 170),
|
||||
),
|
||||
(" 1x1: kontekst punktowy (piksel)", 14, "#A5D6A7", FONT_R, (560, 210)),
|
||||
(" r=6: kontekst lokalny (~13px)", 14, "#FFE082", FONT_R, (560, 245)),
|
||||
(" r=12: kontekst średni (~25px)", 14, "#FFE082", FONT_R, (560, 280)),
|
||||
(" r=18: kontekst szeroki (~37px)", 14, "#FFE082", FONT_R, (560, 315)),
|
||||
(" GAP: kontekst GLOBALNY (cały obraz)", 14, "#CE93D8", FONT_R, (560, 350)),
|
||||
("Concat → 1x1 conv → mapa segmentacji", 16, "#A5D6A7", FONT_R, (550, 400)),
|
||||
(
|
||||
"Efekt: sieć widzi OD piksela DO całego obrazu naraz!",
|
||||
17,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 600),
|
||||
),
|
||||
(
|
||||
"Mnemonik: ASPP = Piramida z DZIURAMI, patrzy na 5 skal jednocześnie",
|
||||
15,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 645),
|
||||
),
|
||||
]
|
||||
slides.append(_compose_slide(aspp_clip, labels2, dur))
|
||||
|
||||
return slides
|
||||
116
python_pkg/praca_magisterska_video/_q23_helpers.py
Normal file
116
python_pkg/praca_magisterska_video/_q23_helpers.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""Shared constants and helper functions for Q23 segmentation video."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
os.environ["FFMPEG_BINARY"] = "/usr/bin/ffmpeg"
|
||||
|
||||
from moviepy import (
|
||||
ColorClip,
|
||||
CompositeVideoClip,
|
||||
TextClip,
|
||||
VideoClip,
|
||||
)
|
||||
from moviepy.video.fx import FadeIn, FadeOut
|
||||
|
||||
# ── Constants ─────────────────────────────────────────────────────
|
||||
W, H = 1280, 720
|
||||
FPS = 24
|
||||
STEP_DUR = 7.0
|
||||
HEADER_DUR = 4.0
|
||||
FONT_B = "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf"
|
||||
FONT_R = "/usr/share/fonts/TTF/DejaVuSans.ttf"
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "videos"
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
OUTPUT = str(OUTPUT_DIR / "q23_segmentation.mp4")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
BG_COLOR = (15, 20, 35)
|
||||
rng = np.random.default_rng(42)
|
||||
|
||||
|
||||
def _tc(**kwargs: object) -> TextClip:
|
||||
"""TextClip wrapper that adds enough bottom margin to prevent clipping."""
|
||||
fs = kwargs.get("font_size", 24)
|
||||
m = int(fs) // 3 + 2
|
||||
kwargs["margin"] = (0, m)
|
||||
return TextClip(**kwargs)
|
||||
|
||||
|
||||
def _make_header(
|
||||
title: str, subtitle: str, duration: float = HEADER_DUR
|
||||
) -> CompositeVideoClip:
|
||||
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration)
|
||||
t = (
|
||||
_tc(
|
||||
text=title,
|
||||
font_size=48,
|
||||
color="white",
|
||||
font=FONT_B,
|
||||
)
|
||||
.with_duration(duration)
|
||||
.with_position(("center", 260))
|
||||
)
|
||||
s = (
|
||||
_tc(
|
||||
text=subtitle,
|
||||
font_size=24,
|
||||
color="#90CAF9",
|
||||
font=FONT_R,
|
||||
)
|
||||
.with_duration(duration)
|
||||
.with_position(("center", 340))
|
||||
)
|
||||
return CompositeVideoClip([bg, t, s], size=(W, H)).with_effects(
|
||||
[FadeIn(0.5), FadeOut(0.5)]
|
||||
)
|
||||
|
||||
|
||||
def _text_slide(
|
||||
lines: list[tuple[str, int, str, str, tuple[str | int, str | int]]],
|
||||
duration: float = STEP_DUR,
|
||||
) -> CompositeVideoClip:
|
||||
"""Create a slide with multiple text elements."""
|
||||
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration)
|
||||
clips: list[VideoClip] = [bg]
|
||||
for text, font_size, color, font, pos in lines:
|
||||
tc = (
|
||||
_tc(
|
||||
text=text,
|
||||
font_size=font_size,
|
||||
color=color,
|
||||
font=font,
|
||||
)
|
||||
.with_duration(duration)
|
||||
.with_position(pos)
|
||||
)
|
||||
clips.append(tc)
|
||||
return CompositeVideoClip(clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
|
||||
|
||||
def _compose_slide(
|
||||
base_clip: VideoClip,
|
||||
labels: list[tuple[str, int, str, str, tuple[int, int]]],
|
||||
duration: float,
|
||||
) -> CompositeVideoClip:
|
||||
"""Overlay text labels on an animated base clip."""
|
||||
text_clips: list[VideoClip] = [base_clip]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(duration)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
return CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
430
python_pkg/praca_magisterska_video/_q23_transformer.py
Normal file
430
python_pkg/praca_magisterska_video/_q23_transformer.py
Normal file
@ -0,0 +1,430 @@
|
||||
"""Transformer segmentation and methods comparison for Q23 video."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from moviepy import (
|
||||
ColorClip,
|
||||
CompositeVideoClip,
|
||||
VideoClip,
|
||||
)
|
||||
from moviepy.video.fx import FadeIn, FadeOut
|
||||
import numpy as np
|
||||
|
||||
from python_pkg.praca_magisterska_video._q23_helpers import (
|
||||
BG_COLOR,
|
||||
FONT_B,
|
||||
FONT_R,
|
||||
FPS,
|
||||
STEP_DUR,
|
||||
H,
|
||||
W,
|
||||
_compose_slide,
|
||||
_tc,
|
||||
_text_slide,
|
||||
)
|
||||
|
||||
|
||||
# ── Transformer Segmentation ────────────────────────────────────
|
||||
def _draw_base_grid(
|
||||
frame: np.ndarray,
|
||||
gx: int,
|
||||
gy: int,
|
||||
grid_n: int,
|
||||
cell: int,
|
||||
) -> None:
|
||||
"""Draw an empty grid of cells."""
|
||||
for r in range(grid_n):
|
||||
for c in range(grid_n):
|
||||
x = gx + c * cell
|
||||
y = gy + r * cell
|
||||
frame[y : y + cell - 2, x : x + cell - 2] = (35, 40, 55)
|
||||
|
||||
|
||||
def _draw_cnn_kernel(
|
||||
frame: np.ndarray,
|
||||
lx: int,
|
||||
ly: int,
|
||||
cell: int,
|
||||
progress: float,
|
||||
) -> None:
|
||||
"""Highlight a 3x3 CNN kernel on the grid."""
|
||||
cnn_phase = 0.2
|
||||
if progress <= cnn_phase:
|
||||
return
|
||||
cx, cy = 2, 2
|
||||
for dr in range(-1, 2):
|
||||
for dc in range(-1, 2):
|
||||
r, c = cy + dr, cx + dc
|
||||
x = lx + c * cell
|
||||
y = ly + r * cell
|
||||
frame[y : y + cell - 2, x : x + cell - 2] = (70, 130, 200)
|
||||
x = lx + cx * cell
|
||||
y = ly + cy * cell
|
||||
frame[y : y + cell - 2, x : x + cell - 2] = (120, 180, 255)
|
||||
|
||||
|
||||
def _draw_conn_line(
|
||||
frame: np.ndarray,
|
||||
x0: int,
|
||||
y0: int,
|
||||
x1: int,
|
||||
y1: int,
|
||||
) -> None:
|
||||
"""Draw a dashed connection line between two points."""
|
||||
steps = max(abs(x1 - x0), abs(y1 - y0))
|
||||
if steps <= 0:
|
||||
return
|
||||
for s in range(0, steps, 3):
|
||||
px = x0 + int((x1 - x0) * s / steps)
|
||||
py = y0 + int((y1 - y0) * s / steps)
|
||||
if 0 <= px < W - 1 and 0 <= py < H - 1:
|
||||
frame[py : py + 1, px : px + 1] = (200, 180, 50)
|
||||
|
||||
|
||||
def _draw_attention_connections(
|
||||
frame: np.ndarray,
|
||||
origin: tuple[int, int],
|
||||
grid_n: int,
|
||||
cell: int,
|
||||
progress: float,
|
||||
) -> None:
|
||||
"""Draw transformer self-attention connections on the grid."""
|
||||
rx, ry = origin
|
||||
transformer_phase = 0.4
|
||||
if progress <= transformer_phase:
|
||||
return
|
||||
cx_t, cy_t = 2, 2
|
||||
x0 = rx + cx_t * cell + cell // 2
|
||||
y0 = ry + cy_t * cell + cell // 2
|
||||
n_connections = int(progress * 36)
|
||||
conn_idx = 0
|
||||
for r in range(grid_n):
|
||||
for c in range(grid_n):
|
||||
conn_idx += 1
|
||||
if conn_idx > n_connections:
|
||||
break
|
||||
x = rx + c * cell
|
||||
y = ry + r * cell
|
||||
dist = abs(r - cy_t) + abs(c - cx_t)
|
||||
strength = max(30, 200 - dist * 30)
|
||||
frame[y : y + cell - 2, x : x + cell - 2] = (
|
||||
strength // 3,
|
||||
strength // 2,
|
||||
strength,
|
||||
)
|
||||
_draw_conn_line(frame, x0, y0, x + cell // 2, y + cell // 2)
|
||||
else:
|
||||
continue
|
||||
break
|
||||
x = rx + cx_t * cell
|
||||
y = ry + cy_t * cell
|
||||
frame[y : y + cell - 2, x : x + cell - 2] = (255, 200, 50)
|
||||
|
||||
|
||||
def _make_attention_frame(t: float) -> np.ndarray:
|
||||
"""Render a CNN-vs-Transformer attention comparison frame."""
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
progress = min(t / (STEP_DUR * 0.7), 1.0)
|
||||
|
||||
cell = 40
|
||||
grid_n = 6
|
||||
|
||||
lx, ly = 60, 200
|
||||
_draw_base_grid(frame, lx, ly, grid_n, cell)
|
||||
_draw_cnn_kernel(frame, lx, ly, cell, progress)
|
||||
|
||||
rx, ry = 680, 200
|
||||
_draw_base_grid(frame, rx, ry, grid_n, cell)
|
||||
_draw_attention_connections(frame, (rx, ry), grid_n, cell, progress)
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def _transformer_seg_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate transformer-based segmentation: self-attention concept."""
|
||||
dur = STEP_DUR + 1
|
||||
|
||||
# Slide 1: CNN local vs Transformer global
|
||||
att_clip = VideoClip(_make_attention_frame, duration=dur).with_fps(FPS)
|
||||
labels = [
|
||||
("Transformer: Self-Attention w segmentacji", 26, "#FFE082", FONT_B, (80, 20)),
|
||||
("CNN = LOKALNY kontekst", 18, "#64B5F6", FONT_B, (60, 160)),
|
||||
("Transformer = GLOBALNY kontekst", 18, "#FFE082", FONT_B, (680, 160)),
|
||||
("Filtr 3x3 widzi", 14, "#64B5F6", FONT_R, (60, 460)),
|
||||
("TYLKO 9 sąsiadów", 14, "#64B5F6", FONT_R, (60, 485)),
|
||||
("Self-attention: każdy", 14, "#FFE082", FONT_R, (680, 460)),
|
||||
("piksel widzi WSZYSTKIE!", 14, "#FFE082", FONT_R, (680, 485)),
|
||||
("vs", 28, "#B0BEC5", FONT_B, (450, 300)),
|
||||
]
|
||||
slides = [_compose_slide(att_clip, labels, dur)]
|
||||
|
||||
# Slide 2: Self-attention Q/K/V step by step
|
||||
qkv_lines = [
|
||||
("Self-Attention: Q / K / V krok po kroku", 26, "#FFE082", FONT_B, (80, 30)),
|
||||
("Każdy piksel (token) tworzy 3 wektory:", 18, "#B0BEC5", FONT_R, (100, 100)),
|
||||
(
|
||||
" Q (Query) = 'czego szukam?' - pytanie piksela",
|
||||
17,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(120, 145),
|
||||
),
|
||||
(
|
||||
" K (Key) = 'co oferuj\u0119?' - odpowied\u017a piksela",
|
||||
17,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(120, 185),
|
||||
),
|
||||
(
|
||||
" V (Value) = 'moja warto\u015b\u0107' - informacja do przekazania",
|
||||
17,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(120, 225),
|
||||
),
|
||||
("Algorytm attention:", 18, "#B0BEC5", FONT_R, (100, 285)),
|
||||
(
|
||||
" 1. Mnożenie Q x K\u1d40 → macierz NxN (kto ważny dla kogo)",
|
||||
16,
|
||||
"white",
|
||||
FONT_R,
|
||||
(120, 320),
|
||||
),
|
||||
(
|
||||
" 2. Skalowanie: / \u221ad (stabilno\u015b\u0107 gradient\u00f3w)",
|
||||
16,
|
||||
"white",
|
||||
FONT_R,
|
||||
(120, 355),
|
||||
),
|
||||
(
|
||||
" 3. Softmax \u2192 wagi attention (sumuj\u0105 si\u0119 do 1)",
|
||||
16,
|
||||
"white",
|
||||
FONT_R,
|
||||
(120, 390),
|
||||
),
|
||||
(
|
||||
" 4. Mno\u017cenie wag x V \u2192 wa\u017cona suma warto\u015bci",
|
||||
16,
|
||||
"white",
|
||||
FONT_R,
|
||||
(120, 425),
|
||||
),
|
||||
(
|
||||
"Attention(Q,K,V) = softmax(Q \u00b7 K\u1d40 / \u221ad) \u00b7 V",
|
||||
20,
|
||||
"#FFE082",
|
||||
FONT_B,
|
||||
(100, 480),
|
||||
),
|
||||
(
|
||||
"Z\u0142o\u017cono\u015b\u0107: O(n\u00b2) pami\u0119ci \u2014 n = liczba pikseli/token\u00f3w",
|
||||
16,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(100, 535),
|
||||
),
|
||||
(
|
||||
"Dlatego SegFormer u\u017cywa efficient attention (liniowa z\u0142o\u017cono\u015b\u0107)",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(100, 570),
|
||||
),
|
||||
(
|
||||
"SegFormer (2021): lightweight + hierarchiczny encoder",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 610),
|
||||
),
|
||||
(
|
||||
"Mask2Former (2022): masked attention + "
|
||||
"unified (semantic+instance+panoptic)",
|
||||
16,
|
||||
"#CE93D8",
|
||||
FONT_R,
|
||||
(100, 645),
|
||||
),
|
||||
]
|
||||
slides.append(_text_slide(qkv_lines, duration=STEP_DUR + 1))
|
||||
|
||||
# Slide 3: Encoder-Decoder in DL summary
|
||||
summary_lines = [
|
||||
(
|
||||
"Podsumowanie: Encoder-Decoder w segmentacji DL",
|
||||
24,
|
||||
"#FFE082",
|
||||
FONT_B,
|
||||
(80, 30),
|
||||
),
|
||||
(
|
||||
"Wsp\u00f3lna idea WSZYSTKICH sieci segmentacji:",
|
||||
18,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 90),
|
||||
),
|
||||
(
|
||||
"Encoder: obraz \u2192 cechy (zmniejsza rozdzielczo\u015b\u0107, wyci\u0105ga CO)",
|
||||
16,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(100, 140),
|
||||
),
|
||||
(
|
||||
"Decoder: cechy \u2192 mapa (zwi\u0119ksza rozdzielczo\u015b\u0107, odtwarza GDZIE)",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 175),
|
||||
),
|
||||
(
|
||||
"Skip: przenosi detale z encodera do decodera",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(100, 210),
|
||||
),
|
||||
("", 10, "white", FONT_R, (100, 240)),
|
||||
(
|
||||
"FCN (2015): Conv1x1 + skip \u2192 pierwsza end-to-end",
|
||||
16,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(100, 275),
|
||||
),
|
||||
(
|
||||
"U-Net (2015): U-shape + skip concat \u2192 segmentacja medyczna",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 310),
|
||||
),
|
||||
(
|
||||
"DeepLab (2018): dilated conv + ASPP \u2192 multi-scale kontekst",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(100, 345),
|
||||
),
|
||||
(
|
||||
"SegFormer: transformer encoder (globalny kontekst)",
|
||||
16,
|
||||
"#CE93D8",
|
||||
FONT_R,
|
||||
(100, 380),
|
||||
),
|
||||
(
|
||||
"Mask2Former: masked attention (unified, SOTA)",
|
||||
16,
|
||||
"#CE93D8",
|
||||
FONT_R,
|
||||
(100, 415),
|
||||
),
|
||||
("", 10, "white", FONT_R, (100, 440)),
|
||||
(
|
||||
"Ewolucja: wi\u0119cej kontekstu + lepsze skip connections:",
|
||||
17,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 465),
|
||||
),
|
||||
(
|
||||
" CNN lokal. \u2192 dilated (szersze RF) \u2192 transformer (global) \u2192 masked att.",
|
||||
16,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 505),
|
||||
),
|
||||
(
|
||||
" addition skip \u2192 concat skip \u2192 cross-attention skip",
|
||||
16,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 540),
|
||||
),
|
||||
(
|
||||
"Metryki: mIoU (standard), Dice (medycyna), Focal Loss (imbalance)",
|
||||
16,
|
||||
"#90CAF9",
|
||||
FONT_R,
|
||||
(80, 590),
|
||||
),
|
||||
(
|
||||
"Loss: Cross-Entropy per piksel + opcjonalnie Dice/Focal",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 625),
|
||||
),
|
||||
]
|
||||
slides.append(_text_slide(summary_lines, duration=STEP_DUR + 1))
|
||||
|
||||
return slides
|
||||
|
||||
|
||||
# ── Methods comparison ────────────────────────────────────────────
|
||||
def _methods_comparison() -> CompositeVideoClip:
|
||||
"""Create a comparison table of all segmentation methods."""
|
||||
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(10.0)
|
||||
title = (
|
||||
_tc(
|
||||
text="Por\u00f3wnanie metod segmentacji",
|
||||
font_size=36,
|
||||
color="white",
|
||||
font=FONT_B,
|
||||
)
|
||||
.with_duration(10.0)
|
||||
.with_position(("center", 20))
|
||||
)
|
||||
|
||||
rows = [
|
||||
("Metoda", "Typ", "Idea", "Mnemonik"),
|
||||
(
|
||||
"Thresholding",
|
||||
"Klasyczna",
|
||||
"piksel > T \u2192 klasa 1",
|
||||
"PR\u00d3G na bramce",
|
||||
),
|
||||
("Otsu", "Klasyczna", "auto-pr\u00f3g, min \u03c3\u00b2", "AUTO-bramkarz"),
|
||||
("Region Growing", "Klasyczna", "BFS od seeda", "PLAMA atramentu"),
|
||||
("Watershed", "Klasyczna", "zalewanie minim\u00f3w", "ZALEWANIE terenu"),
|
||||
(
|
||||
"Mean Shift",
|
||||
"Klasyczna",
|
||||
"j\u0105dro \u2192 max g\u0119sto\u015bci",
|
||||
"KULKI do do\u0142k\u00f3w",
|
||||
),
|
||||
("U-Net", "Deep Learning", "encoder-decoder + skip", "Litera U + mosty"),
|
||||
("DeepLab", "Deep Learning", "dilated conv + ASPP", "DZIURY w filtrze"),
|
||||
]
|
||||
|
||||
clips: list[VideoClip] = [bg, title]
|
||||
mnemonic_col = 3
|
||||
for i, row in enumerate(rows):
|
||||
y_pos = 75 + i * 72
|
||||
col_x = [40, 210, 340, 660]
|
||||
for j, cell in enumerate(row):
|
||||
fs = 16 if i > 0 else 18
|
||||
color = (
|
||||
"#64B5F6" if i == 0 else ("#E0E0E0" if j < mnemonic_col else "#FFE082")
|
||||
)
|
||||
tc = (
|
||||
_tc(
|
||||
text=cell,
|
||||
font_size=fs,
|
||||
color=color,
|
||||
font=FONT_B if i == 0 else FONT_R,
|
||||
)
|
||||
.with_duration(10.0)
|
||||
.with_position((col_x[j], y_pos))
|
||||
)
|
||||
clips.append(tc)
|
||||
|
||||
return CompositeVideoClip(clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.5), FadeOut(0.5)]
|
||||
)
|
||||
399
python_pkg/praca_magisterska_video/_q23_unet_fcn.py
Normal file
399
python_pkg/praca_magisterska_video/_q23_unet_fcn.py
Normal file
@ -0,0 +1,399 @@
|
||||
"""U-Net and FCN architecture animations for Q23 segmentation video."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from moviepy import (
|
||||
CompositeVideoClip,
|
||||
VideoClip,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
from python_pkg.praca_magisterska_video._q23_helpers import (
|
||||
BG_COLOR,
|
||||
FONT_B,
|
||||
FONT_R,
|
||||
FPS,
|
||||
STEP_DUR,
|
||||
H,
|
||||
W,
|
||||
_compose_slide,
|
||||
_text_slide,
|
||||
)
|
||||
|
||||
|
||||
# ── U-Net Architecture ───────────────────────────────────────────
|
||||
def _draw_unet_skips(
|
||||
frame: np.ndarray,
|
||||
enc_positions: list[tuple[int, int, int, int]],
|
||||
n_blocks: int,
|
||||
dec_x: int,
|
||||
skip_threshold: int,
|
||||
) -> None:
|
||||
"""Draw horizontal dashed skip-connection lines."""
|
||||
if n_blocks <= skip_threshold:
|
||||
return
|
||||
for i in range(min(n_blocks - 5, 4)):
|
||||
ey = enc_positions[i][1] + enc_positions[i][3] // 2
|
||||
ex_end = enc_positions[i][0] + enc_positions[i][2]
|
||||
for dash_x in range(ex_end + 10, dec_x - 10, 15):
|
||||
frame[ey : ey + 2, dash_x : dash_x + 8] = (255, 200, 50)
|
||||
|
||||
|
||||
def _make_unet_frame(t: float) -> np.ndarray:
|
||||
"""Render a single U-Net animation frame."""
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
enc_sizes = [(80, 120), (60, 100), (45, 80), (30, 60)]
|
||||
dec_sizes = list(reversed(enc_sizes))
|
||||
enc_x = 150
|
||||
dec_x = 850
|
||||
|
||||
progress = min(t / (STEP_DUR * 0.6), 1.0)
|
||||
n_blocks = int(progress * 8) + 1
|
||||
|
||||
enc_positions: list[tuple[int, int, int, int]] = []
|
||||
y_offset = 120
|
||||
for i, (bw, bh) in enumerate(enc_sizes):
|
||||
x = enc_x
|
||||
y = y_offset + i * 130
|
||||
enc_positions.append((x, y, bw, bh))
|
||||
if i < n_blocks:
|
||||
frame[y : y + bh, x : x + bw] = (70, 130, 200)
|
||||
frame[y : y + 2, x : x + bw] = (100, 180, 255)
|
||||
frame[y + bh - 2 : y + bh, x : x + bw] = (100, 180, 255)
|
||||
frame[y : y + bh, x : x + 2] = (100, 180, 255)
|
||||
frame[y : y + bh, x + bw - 2 : x + bw] = (100, 180, 255)
|
||||
if i < len(enc_sizes) - 1:
|
||||
ax = x + bw // 2
|
||||
ay = y + bh + 10
|
||||
frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170)
|
||||
|
||||
bx, by = 500, y_offset + 3 * 130 + 30
|
||||
encoder_count = 4
|
||||
if n_blocks > encoder_count:
|
||||
frame[by : by + 50, bx : bx + 25] = (200, 100, 80)
|
||||
frame[by : by + 2, bx : bx + 25] = (255, 140, 100)
|
||||
frame[by + 48 : by + 50, bx : bx + 25] = (255, 140, 100)
|
||||
|
||||
for i, (bw, bh) in enumerate(dec_sizes):
|
||||
x = dec_x
|
||||
y = y_offset + (3 - i) * 130
|
||||
if n_blocks > 4 + i + 1:
|
||||
frame[y : y + bh, x : x + bw] = (80, 200, 120)
|
||||
frame[y : y + 2, x : x + bw] = (120, 230, 150)
|
||||
frame[y + bh - 2 : y + bh, x : x + bw] = (120, 230, 150)
|
||||
frame[y : y + bh, x : x + 2] = (120, 230, 150)
|
||||
frame[y : y + bh, x + bw - 2 : x + bw] = (120, 230, 150)
|
||||
if i < len(dec_sizes) - 1:
|
||||
ax = x + bw // 2
|
||||
ay = y - 30
|
||||
frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170)
|
||||
|
||||
skip_threshold = 5
|
||||
_draw_unet_skips(frame, enc_positions, n_blocks, dec_x, skip_threshold)
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def _unet_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate U-Net encoder-decoder architecture."""
|
||||
dur = STEP_DUR + 1
|
||||
unet_clip = VideoClip(_make_unet_frame, duration=dur).with_fps(FPS)
|
||||
labels = [
|
||||
("U-Net: Encoder-Decoder + Skip Connections", 28, "#FFE082", FONT_B, (80, 20)),
|
||||
(
|
||||
"Niebieski = Encoder (↓ zmniejsza rozdzielczość, wyciąga cechy)",
|
||||
16,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(80, 65),
|
||||
),
|
||||
(
|
||||
"Zielony = Decoder (↑ zwiększa rozdzielczość, odtwarza mapę)",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 90),
|
||||
),
|
||||
(
|
||||
"Żółte przerywane = Skip connections (przenoszą detale z encodera)",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 115),
|
||||
),
|
||||
(
|
||||
"Czerwony = Bottleneck (najgłębsza warstwa, max abstrakcja)",
|
||||
16,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(450, 570),
|
||||
),
|
||||
(
|
||||
"Kształt U: encoder ↓ decoder ↑, mosty pośrodku",
|
||||
18,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 640),
|
||||
),
|
||||
(
|
||||
"Concatenation: skip łączy kanały (więcej informacji niż dodawanie)",
|
||||
16,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 670),
|
||||
),
|
||||
]
|
||||
return [_compose_slide(unet_clip, labels, dur)]
|
||||
|
||||
|
||||
# ── FCN Architecture ─────────────────────────────────────────────
|
||||
def _draw_pipeline_blocks(
|
||||
frame: np.ndarray,
|
||||
blocks: list[tuple[tuple[int, int], tuple[int, int], tuple[int, int, int]]],
|
||||
n_visible: int,
|
||||
arrow_limit: int,
|
||||
) -> None:
|
||||
"""Draw coloured blocks with connecting arrows."""
|
||||
for i, ((bx, by), (bw, bh), color) in enumerate(blocks):
|
||||
if i < n_visible:
|
||||
frame[by : by + bh, bx : bx + bw] = color
|
||||
frame[by : by + 2, bx : bx + bw] = tuple(min(c + 50, 255) for c in color)
|
||||
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
|
||||
min(c + 50, 255) for c in color
|
||||
)
|
||||
if i < arrow_limit:
|
||||
ax = bx + bw + 3
|
||||
ay = by + bh // 2
|
||||
frame[ay - 1 : ay + 2, ax : ax + 12] = (150, 150, 170)
|
||||
|
||||
|
||||
def _draw_red_cross(
|
||||
frame: np.ndarray,
|
||||
x_start: int,
|
||||
width: int,
|
||||
top_y: int,
|
||||
height: int,
|
||||
) -> None:
|
||||
"""Draw a red X across the given rectangle."""
|
||||
for d in range(-2, 3):
|
||||
for step in range(height):
|
||||
x1 = x_start + int(step * width / height)
|
||||
y1 = top_y + step + d
|
||||
if 0 <= y1 < H and 0 <= x1 < W:
|
||||
frame[y1, x1] = (255, 80, 80)
|
||||
y2 = top_y + height - step + d
|
||||
if 0 <= y2 < H and 0 <= x1 < W:
|
||||
frame[y2, x1] = (255, 80, 80)
|
||||
|
||||
|
||||
def _make_fcn_frame(t: float) -> np.ndarray:
|
||||
"""Render a single FCN comparison frame."""
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
progress = min(t / (STEP_DUR * 0.8), 1.0)
|
||||
|
||||
top_y = 140
|
||||
blocks_classic = [
|
||||
((80, top_y), (70, 50), (70, 130, 200)),
|
||||
((170, top_y), (50, 40), (50, 100, 160)),
|
||||
((240, top_y), (60, 50), (70, 130, 200)),
|
||||
((320, top_y), (40, 35), (50, 100, 160)),
|
||||
((385, top_y), (55, 50), (160, 80, 60)),
|
||||
((465, top_y), (55, 50), (180, 60, 60)),
|
||||
((545, top_y), (80, 50), (200, 80, 80)),
|
||||
]
|
||||
n_top = min(int(progress * 7) + 1, 7)
|
||||
arrow_limit = 6
|
||||
_draw_pipeline_blocks(frame, blocks_classic, n_top, arrow_limit)
|
||||
|
||||
cross_phase = 0.6
|
||||
if progress > cross_phase:
|
||||
_draw_red_cross(frame, 385, 135, top_y, 50)
|
||||
|
||||
bot_y = 380
|
||||
blocks_fcn = [
|
||||
((80, bot_y), (70, 50), (70, 130, 200)),
|
||||
((170, bot_y), (50, 40), (50, 100, 160)),
|
||||
((240, bot_y), (60, 50), (70, 130, 200)),
|
||||
((320, bot_y), (40, 35), (50, 100, 160)),
|
||||
((385, bot_y), (70, 50), (80, 200, 120)),
|
||||
((480, bot_y), (75, 50), (200, 160, 80)),
|
||||
((580, bot_y), (80, 50), (100, 200, 100)),
|
||||
]
|
||||
fcn_phase = 0.4
|
||||
if progress > fcn_phase:
|
||||
n_bot = min(int((progress - fcn_phase) / 0.6 * 7) + 1, 7)
|
||||
_draw_pipeline_blocks(frame, blocks_fcn, n_bot, arrow_limit)
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def _fcn_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate FCN step-by-step: FC → Conv 1x1 transformation."""
|
||||
dur = STEP_DUR + 1
|
||||
fcn_clip = VideoClip(_make_fcn_frame, duration=dur).with_fps(FPS)
|
||||
labels = [
|
||||
("FCN: Fully Convolutional Network (2015)", 26, "#FFE082", FONT_B, (80, 20)),
|
||||
("KROK 1: Zamień FC → Conv 1x1", 18, "#A5D6A7", FONT_R, (80, 60)),
|
||||
("Klasyczny CNN:", 16, "#EF9A9A", FONT_B, (80, 105)),
|
||||
("Conv", 11, "white", FONT_R, (92, 148)),
|
||||
("Pool", 11, "white", FONT_R, (178, 148)),
|
||||
("Conv", 11, "white", FONT_R, (250, 148)),
|
||||
("Pool", 11, "white", FONT_R, (325, 148)),
|
||||
("Flatten", 11, "#EF9A9A", FONT_R, (390, 148)),
|
||||
("FC", 11, "#EF9A9A", FONT_R, (480, 148)),
|
||||
("1 label", 11, "#EF9A9A", FONT_R, (555, 148)),
|
||||
("FCN:", 16, "#A5D6A7", FONT_B, (80, 350)),
|
||||
("Conv", 11, "white", FONT_R, (92, 388)),
|
||||
("Pool", 11, "white", FONT_R, (178, 388)),
|
||||
("Conv", 11, "white", FONT_R, (250, 388)),
|
||||
("Pool", 11, "white", FONT_R, (325, 388)),
|
||||
("Conv1x1", 11, "#A5D6A7", FONT_R, (390, 388)),
|
||||
("Upsample", 11, "#FFE082", FONT_R, (486, 388)),
|
||||
("Mapa", 11, "#A5D6A7", FONT_R, (595, 388)),
|
||||
(
|
||||
"FC: spłaszcza 3D→1D, wymusza stały rozmiar → 1 etykieta",
|
||||
16,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(80, 250),
|
||||
),
|
||||
(
|
||||
"Conv1x1: działa per piksel x kanały → DOWOLNY rozmiar → mapa klasy",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 460),
|
||||
),
|
||||
(
|
||||
"KROK 2: Skip connections — łączą wczesne detale z późną abstrakcją",
|
||||
17,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(80, 510),
|
||||
),
|
||||
(
|
||||
"Wczesne warstwy = krawędzie, tekstury | Późne = koncepty obiektów",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 545),
|
||||
),
|
||||
(
|
||||
"FCN = PIERWSZA sieć end-to-end do segmentacji per-piksel!",
|
||||
18,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 590),
|
||||
),
|
||||
(
|
||||
"Mnemonik: FC → Conv 1x1 = otwieramy bramkę dla DOWOLNEGO rozmiaru",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 640),
|
||||
),
|
||||
]
|
||||
slides = [_compose_slide(fcn_clip, labels, dur)]
|
||||
|
||||
# Slide 2: FCN skip connections step by step
|
||||
skip_lines = [
|
||||
("FCN: Skip Connections — krok po kroku", 26, "#FFE082", FONT_B, (80, 30)),
|
||||
(
|
||||
"1. Encoder zmniejsza: 224→112→56→28→14 (pooling)",
|
||||
18,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(100, 100),
|
||||
),
|
||||
(
|
||||
" Każdy pooling traci detale przestrzenne (dokładne krawędzie)",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(100, 135),
|
||||
),
|
||||
(
|
||||
"2. Decoder powiększa: 14→28→56→112→224 (upsample/deconv)",
|
||||
18,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 190),
|
||||
),
|
||||
(
|
||||
" Upsample ODGADUJE piksele — rozmyty wynik!",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(100, 225),
|
||||
),
|
||||
(
|
||||
"3. Skip connections: dodaj cechy z encodera do decodera",
|
||||
18,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(100, 280),
|
||||
),
|
||||
(
|
||||
" Wczesne cechy = GDZIE (precyzyjne krawędzie)",
|
||||
15,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(100, 315),
|
||||
),
|
||||
(
|
||||
" Późne cechy = CO (abstrakcyjne koncepty)",
|
||||
15,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 345),
|
||||
),
|
||||
(
|
||||
" Skip = daje decoderowi OBA → ostry wynik!",
|
||||
15,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(100, 375),
|
||||
),
|
||||
(
|
||||
"Warianty: FCN-32s (brak skip, rozmyty) → FCN-16s → FCN-8s (najlepszy)",
|
||||
16,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 440),
|
||||
),
|
||||
(
|
||||
"FCN-32s: upsample 32x naraz → ROZMYTE granice",
|
||||
15,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(100, 485),
|
||||
),
|
||||
(
|
||||
"FCN-16s: skip z pool4 + upsample 16x → lepiej",
|
||||
15,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(100, 520),
|
||||
),
|
||||
(
|
||||
"FCN-8s: skip z pool3+pool4 + upsample 8x → OSTRE granice!",
|
||||
15,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 555),
|
||||
),
|
||||
(
|
||||
"Im więcej skip connections → tym więcej "
|
||||
"detali z encodera → ostrzejszy wynik",
|
||||
17,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 620),
|
||||
),
|
||||
]
|
||||
slides.append(_text_slide(skip_lines, duration=STEP_DUR + 1))
|
||||
|
||||
return slides
|
||||
332
python_pkg/praca_magisterska_video/_q24_classical.py
Normal file
332
python_pkg/praca_magisterska_video/_q24_classical.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""Classical detection methods: detection concept, HOG+SVM, Viola-Jones."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q24_common import (
|
||||
BG_COLOR,
|
||||
FONT_B,
|
||||
FONT_R,
|
||||
FPS,
|
||||
STEP_DUR,
|
||||
H,
|
||||
W,
|
||||
_tc,
|
||||
)
|
||||
from moviepy import CompositeVideoClip, VideoClip
|
||||
from moviepy.video.fx import FadeIn, FadeOut
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── Detection concept ────────────────────────────────────────────
|
||||
def _detection_concept() -> list[CompositeVideoClip]:
|
||||
"""Show what detection is: bounding box + class + confidence."""
|
||||
slides = []
|
||||
|
||||
def make_det_frame(_t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
# Draw a "scene" with colored rectangles representing objects
|
||||
# Sky background area
|
||||
frame[140:500, 100:700] = (40, 50, 70)
|
||||
|
||||
# "Car" object
|
||||
frame[350:430, 150:320] = (180, 60, 60)
|
||||
# "Person" object
|
||||
frame[280:440, 450:520] = (60, 120, 180)
|
||||
# "Tree" object
|
||||
frame[200:400, 580:650] = (40, 130, 50)
|
||||
|
||||
# Bounding boxes (with labels drawn as colored borders)
|
||||
# Car bbox
|
||||
for thickness in range(3):
|
||||
t = thickness
|
||||
frame[348 - t : 432 + t, 148 - t : 148 - t + 2] = (255, 80, 80)
|
||||
frame[348 - t : 432 + t, 322 + t - 2 : 322 + t] = (255, 80, 80)
|
||||
frame[348 - t : 348 - t + 2, 148 - t : 322 + t] = (255, 80, 80)
|
||||
frame[432 + t - 2 : 432 + t, 148 - t : 322 + t] = (255, 80, 80)
|
||||
|
||||
# Person bbox
|
||||
for thickness in range(3):
|
||||
t = thickness
|
||||
frame[278 - t : 442 + t, 448 - t : 448 - t + 2] = (80, 180, 255)
|
||||
frame[278 - t : 442 + t, 522 + t - 2 : 522 + t] = (80, 180, 255)
|
||||
frame[278 - t : 278 - t + 2, 448 - t : 522 + t] = (80, 180, 255)
|
||||
frame[442 + t - 2 : 442 + t, 448 - t : 522 + t] = (80, 180, 255)
|
||||
|
||||
# Tree bbox
|
||||
for thickness in range(3):
|
||||
t = thickness
|
||||
frame[198 - t : 402 + t, 578 - t : 578 - t + 2] = (80, 220, 100)
|
||||
frame[198 - t : 402 + t, 652 + t - 2 : 652 + t] = (80, 220, 100)
|
||||
frame[198 - t : 198 - t + 2, 578 - t : 652 + t] = (80, 220, 100)
|
||||
frame[402 + t - 2 : 402 + t, 578 - t : 652 + t] = (80, 220, 100)
|
||||
|
||||
# Comparison boxes on right side
|
||||
# Classification
|
||||
frame[180:260, 800:1150] = (35, 45, 65)
|
||||
# Detection
|
||||
frame[290:370, 800:1150] = (35, 45, 65)
|
||||
# Segmentation
|
||||
frame[400:480, 800:1150] = (35, 45, 65)
|
||||
|
||||
return frame
|
||||
|
||||
det_clip = VideoClip(make_det_frame, duration=STEP_DUR).with_fps(FPS)
|
||||
text_clips: list[VideoClip] = [det_clip]
|
||||
labels = [
|
||||
("Detekcja obiektów — co to jest?", 28, "#FFE082", FONT_B, (100, 20)),
|
||||
("Wynik: (klasa, bounding box, pewność)", 20, "#B0BEC5", FONT_R, (100, 65)),
|
||||
("samochód 95%", 14, "#EF9A9A", FONT_B, (150, 340)),
|
||||
("osoba 88%", 14, "#64B5F6", FONT_B, (450, 268)),
|
||||
("drzewo 72%", 14, "#A5D6A7", FONT_B, (580, 188)),
|
||||
("Klasyfikacja: cały obraz → 1 etykieta", 15, "#78909C", FONT_R, (810, 210)),
|
||||
("Detekcja: bbox + klasa + pewność", 15, "#FFE082", FONT_R, (810, 320)),
|
||||
("Segmentacja: maska per piksel", 15, "#78909C", FONT_R, (810, 430)),
|
||||
("← granulacja rośnie →", 14, "#90CAF9", FONT_R, (810, 520)),
|
||||
]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
|
||||
|
||||
# ── HOG + SVM pipeline ───────────────────────────────────────────
|
||||
def _hog_svm_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate HOG feature computation and SVM classification."""
|
||||
slides = []
|
||||
|
||||
def make_hog_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
progress = min(t / (STEP_DUR * 0.8), 1.0)
|
||||
|
||||
# Pipeline stages as boxes with arrows
|
||||
stages = [
|
||||
("Gradient", (80, 250), (130, 80), (100, 160, 220)),
|
||||
("Orientacja", (260, 250), (130, 80), (80, 180, 140)),
|
||||
("Komórki 8x8", (440, 250), (130, 80), (200, 160, 80)),
|
||||
("Bloki 2x2", (620, 250), (130, 80), (200, 120, 60)),
|
||||
("Normalizacja", (800, 250), (130, 80), (180, 100, 80)),
|
||||
("SVM", (980, 250), (130, 80), (220, 80, 80)),
|
||||
]
|
||||
|
||||
n_active = int(progress * len(stages)) + 1
|
||||
|
||||
for i, (_label, (sx, sy), (sw, sh), color) in enumerate(stages):
|
||||
if i < n_active:
|
||||
frame[sy : sy + sh, sx : sx + sw] = color
|
||||
# Border
|
||||
frame[sy : sy + 2, sx : sx + sw] = tuple(
|
||||
min(c + 60, 255) for c in color
|
||||
)
|
||||
frame[sy + sh - 2 : sy + sh, sx : sx + sw] = tuple(
|
||||
min(c + 60, 255) for c in color
|
||||
)
|
||||
|
||||
# Arrow to next
|
||||
if i < len(stages) - 1:
|
||||
ax = sx + sw + 5
|
||||
ay = sy + sh // 2
|
||||
frame[ay - 1 : ay + 2, ax : ax + 20] = (150, 150, 170)
|
||||
|
||||
# Show gradient computation example at bottom
|
||||
gradient_phase = 0.2
|
||||
if progress > gradient_phase:
|
||||
# Mini pixel grid showing gradient computation
|
||||
gx, gy = 100, 430
|
||||
pixels = [50, 50, 200]
|
||||
for idx, val in enumerate(pixels):
|
||||
x = gx + idx * 50
|
||||
frame[gy : gy + 40, x : x + 40] = (val, val, val)
|
||||
|
||||
return frame
|
||||
|
||||
hog_clip = VideoClip(make_hog_frame, duration=STEP_DUR).with_fps(FPS)
|
||||
text_clips: list[VideoClip] = [hog_clip]
|
||||
labels = [
|
||||
("HOG + SVM — pipeline detekcji pieszych", 28, "#FFE082", FONT_B, (80, 20)),
|
||||
(
|
||||
"Mnemonik: GOKBN = Gradienty→Orientacja→Komórki→Bloki→Normalizacja",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 65),
|
||||
),
|
||||
("Gradient: siła i kierunek zmiany jasności", 14, "#64B5F6", FONT_R, (80, 95)),
|
||||
(
|
||||
"Histogram: 9 binów (0°-180°, co 20°) per komórka 8x8",
|
||||
14,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 120),
|
||||
),
|
||||
(
|
||||
"[50][50][200] → Gx = 200-50 = 150 = silna krawędź!",
|
||||
16,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(80, 490),
|
||||
),
|
||||
(
|
||||
"Wektor HOG (3780 cech) → SVM: pieszy (+1) / tło (-1)",
|
||||
16,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 540),
|
||||
),
|
||||
(
|
||||
"Sliding window 64x128 przesuwa się po obrazie → NMS → wynik",
|
||||
16,
|
||||
"#90CAF9",
|
||||
FONT_R,
|
||||
(80, 580),
|
||||
),
|
||||
(
|
||||
"SVM = LINIA MAKSYMALNEGO ODDECHU (max margines, support vectors)",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 620),
|
||||
),
|
||||
]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
|
||||
|
||||
# ── Viola-Jones ───────────────────────────────────────────────────
|
||||
def _viola_jones_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate Viola-Jones cascade concept."""
|
||||
slides = []
|
||||
|
||||
def make_cascade_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
progress = min(t / (STEP_DUR * 0.8), 1.0)
|
||||
|
||||
# Draw cascade "funnel" — stages filtering out non-faces
|
||||
stages = 5
|
||||
start_width = 1000
|
||||
start_count = 10000
|
||||
x_center = W // 2
|
||||
|
||||
for i in range(stages):
|
||||
stage_progress = min(progress * stages - i, 1.0)
|
||||
if stage_progress <= 0:
|
||||
break
|
||||
|
||||
width = int(start_width * (1 - i * 0.18))
|
||||
int(start_count * (0.3**i))
|
||||
y = 150 + i * 100
|
||||
h_box = 60
|
||||
|
||||
# Stage box
|
||||
x1 = x_center - width // 2
|
||||
frame[y : y + h_box, x1 : x1 + width] = (
|
||||
50 + i * 10,
|
||||
60 + i * 10,
|
||||
80 + i * 10,
|
||||
)
|
||||
# Border
|
||||
frame[y : y + 2, x1 : x1 + width] = (100 + i * 20, 130 + i * 15, 200)
|
||||
frame[y + h_box - 2 : y + h_box, x1 : x1 + width] = (
|
||||
100 + i * 20,
|
||||
130 + i * 15,
|
||||
200,
|
||||
)
|
||||
|
||||
# Arrow down to next
|
||||
if i < stages - 1:
|
||||
frame[y + h_box + 5 : y + h_box + 25, x_center - 1 : x_center + 2] = (
|
||||
150,
|
||||
150,
|
||||
170,
|
||||
)
|
||||
|
||||
# Red "rejected" arrows on sides
|
||||
if i > 0:
|
||||
# Left reject arrow
|
||||
rx = x1 - 30
|
||||
ry = y + h_box // 2
|
||||
frame[ry - 1 : ry + 2, rx : rx + 25] = (200, 80, 80)
|
||||
|
||||
return frame
|
||||
|
||||
cascade_clip = VideoClip(make_cascade_frame, duration=STEP_DUR).with_fps(FPS)
|
||||
text_clips: list[VideoClip] = [cascade_clip]
|
||||
labels = [
|
||||
(
|
||||
"Viola-Jones — kaskada klasyfikatorów (2001)",
|
||||
28,
|
||||
"#FFE082",
|
||||
FONT_B,
|
||||
(80, 20),
|
||||
),
|
||||
(
|
||||
"3 innowacje: HIC = Haar + Integral Image + Cascade",
|
||||
20,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 65),
|
||||
),
|
||||
("Etap 1: 2 cechy Haar", 14, "#64B5F6", FONT_R, (170, 170)),
|
||||
("Etap 2: 10 cech", 14, "#64B5F6", FONT_R, (210, 270)),
|
||||
("Etap 3: 25 cech", 14, "#64B5F6", FONT_R, (240, 370)),
|
||||
("Etap 4: 50 cech", 14, "#64B5F6", FONT_R, (260, 470)),
|
||||
("→ TWARZ!", 16, "#A5D6A7", FONT_B, (590, 560)),
|
||||
(
|
||||
"SITO: 99% okien odpada w pierwszych 3 etapach → REAL-TIME!",
|
||||
16,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(80, 620),
|
||||
),
|
||||
(
|
||||
"Haar: kontrast jasna/ciemna | Integral Image: "
|
||||
"suma prostokąta O(1) = 4 odczyty",
|
||||
14,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 655),
|
||||
),
|
||||
("odrzucone →", 12, "#EF9A9A", FONT_R, (60, 275)),
|
||||
("odrzucone →", 12, "#EF9A9A", FONT_R, (60, 375)),
|
||||
]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
115
python_pkg/praca_magisterska_video/_q24_common.py
Normal file
115
python_pkg/praca_magisterska_video/_q24_common.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""Shared constants and helpers for Q24 object detection visualization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
os.environ["FFMPEG_BINARY"] = "/usr/bin/ffmpeg"
|
||||
|
||||
from moviepy import (
|
||||
ColorClip,
|
||||
CompositeVideoClip,
|
||||
TextClip,
|
||||
VideoClip,
|
||||
)
|
||||
from moviepy.video.fx import FadeIn, FadeOut
|
||||
|
||||
# ── Constants ─────────────────────────────────────────────────────
|
||||
W, H = 1280, 720
|
||||
FPS = 24
|
||||
STEP_DUR = 7.0
|
||||
HEADER_DUR = 4.0
|
||||
FONT_B = "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf"
|
||||
FONT_R = "/usr/share/fonts/TTF/DejaVuSans.ttf"
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "videos"
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
OUTPUT = str(OUTPUT_DIR / "q24_object_detection.mp4")
|
||||
|
||||
BG_COLOR = (15, 20, 35)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# Re-export numpy for sub-modules that need it alongside constants.
|
||||
__all__ = [
|
||||
"BG_COLOR",
|
||||
"FONT_B",
|
||||
"FONT_R",
|
||||
"FPS",
|
||||
"HEADER_DUR",
|
||||
"OUTPUT",
|
||||
"OUTPUT_DIR",
|
||||
"STEP_DUR",
|
||||
"H",
|
||||
"W",
|
||||
"_logger",
|
||||
"_make_header",
|
||||
"_tc",
|
||||
"_text_slide",
|
||||
"np",
|
||||
]
|
||||
|
||||
|
||||
def _tc(**kwargs: object) -> TextClip:
|
||||
"""TextClip wrapper that adds enough bottom margin to prevent clipping."""
|
||||
fs = kwargs.get("font_size", 24)
|
||||
m = int(fs) // 3 + 2
|
||||
kwargs["margin"] = (0, m)
|
||||
return TextClip(**kwargs)
|
||||
|
||||
|
||||
def _make_header(
|
||||
title: str, subtitle: str, duration: float = HEADER_DUR
|
||||
) -> CompositeVideoClip:
|
||||
"""Create a title/subtitle header slide."""
|
||||
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration)
|
||||
t = (
|
||||
_tc(
|
||||
text=title,
|
||||
font_size=48,
|
||||
color="white",
|
||||
font=FONT_B,
|
||||
)
|
||||
.with_duration(duration)
|
||||
.with_position(("center", 260))
|
||||
)
|
||||
s = (
|
||||
_tc(
|
||||
text=subtitle,
|
||||
font_size=24,
|
||||
color="#90CAF9",
|
||||
font=FONT_R,
|
||||
)
|
||||
.with_duration(duration)
|
||||
.with_position(("center", 340))
|
||||
)
|
||||
return CompositeVideoClip([bg, t, s], size=(W, H)).with_effects(
|
||||
[FadeIn(0.5), FadeOut(0.5)]
|
||||
)
|
||||
|
||||
|
||||
def _text_slide(
|
||||
lines: list[tuple[str, int, str, str, tuple[str | int, str | int]]],
|
||||
duration: float = STEP_DUR,
|
||||
) -> CompositeVideoClip:
|
||||
"""Create a text-only slide from a list of (text, size, color, font, pos)."""
|
||||
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(duration)
|
||||
clips: list[VideoClip] = [bg]
|
||||
for text, font_size, color, font, pos in lines:
|
||||
tc = (
|
||||
_tc(
|
||||
text=text,
|
||||
font_size=font_size,
|
||||
color=color,
|
||||
font=font,
|
||||
)
|
||||
.with_duration(duration)
|
||||
.with_position(pos)
|
||||
)
|
||||
clips.append(tc)
|
||||
return CompositeVideoClip(clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
239
python_pkg/praca_magisterska_video/_q24_nms_final.py
Normal file
239
python_pkg/praca_magisterska_video/_q24_nms_final.py
Normal file
@ -0,0 +1,239 @@
|
||||
"""NMS/IoU, detector-from-classifier, and methods comparison."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q24_common import (
|
||||
BG_COLOR,
|
||||
FONT_B,
|
||||
FONT_R,
|
||||
FPS,
|
||||
STEP_DUR,
|
||||
H,
|
||||
W,
|
||||
_tc,
|
||||
_text_slide,
|
||||
)
|
||||
from moviepy import ColorClip, CompositeVideoClip, VideoClip
|
||||
from moviepy.video.fx import FadeIn, FadeOut
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── NMS + IoU ─────────────────────────────────────────────────────
|
||||
def _nms_iou_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate NMS and IoU concepts."""
|
||||
slides = []
|
||||
|
||||
def make_nms_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
progress = min(t / (STEP_DUR * 0.7), 1.0)
|
||||
|
||||
# Draw overlapping bounding boxes
|
||||
ox, oy = 100, 200
|
||||
obj_w, obj_h = 150, 120
|
||||
|
||||
# Multiple overlapping detections for same object
|
||||
boxes = [
|
||||
(ox, oy, obj_w, obj_h, 0.95, (255, 80, 80)), # best
|
||||
(ox + 15, oy - 10, obj_w + 10, obj_h + 5, 0.90, (200, 60, 60)),
|
||||
(ox - 10, oy + 5, obj_w - 5, obj_h + 10, 0.85, (160, 50, 50)),
|
||||
]
|
||||
# Different object far away
|
||||
boxes.append((ox + 350, oy + 50, 100, 100, 0.40, (80, 180, 255)))
|
||||
|
||||
for i, (bx, by, bw, bh, _conf, color) in enumerate(boxes):
|
||||
dc = color
|
||||
nms_phase = 0.4
|
||||
nms_limit = 3
|
||||
if progress > nms_phase and i > 0 and i < nms_limit:
|
||||
# After NMS, these get removed (shown as faded/crossed)
|
||||
dc = (60, 40, 40)
|
||||
|
||||
for tt in range(2):
|
||||
frame[by - tt : by + bh + tt, bx - tt : bx - tt + 2] = dc
|
||||
frame[by - tt : by + bh + tt, bx + bw + tt - 2 : bx + bw + tt] = dc
|
||||
frame[by - tt : by - tt + 2, bx - tt : bx + bw + tt] = dc
|
||||
frame[by + bh + tt - 2 : by + bh + tt, bx - tt : bx + bw + tt] = dc
|
||||
|
||||
# IoU visualization on right side
|
||||
iou_x, iou_y = 700, 200
|
||||
# Box A
|
||||
frame[iou_y : iou_y + 100, iou_x : iou_x + 100] = (80, 80, 200)
|
||||
# Box B (overlapping)
|
||||
frame[iou_y + 40 : iou_y + 140, iou_x + 40 : iou_x + 140] = (200, 80, 80)
|
||||
# Intersection highlighted
|
||||
frame[iou_y + 40 : iou_y + 100, iou_x + 40 : iou_x + 100] = (200, 150, 200)
|
||||
|
||||
return frame
|
||||
|
||||
nms_clip = VideoClip(make_nms_frame, duration=STEP_DUR).with_fps(FPS)
|
||||
text_clips: list[VideoClip] = [nms_clip]
|
||||
labels = [
|
||||
("NMS (Non-Maximum Suppression) + IoU", 28, "#FFE082", FONT_B, (80, 20)),
|
||||
(
|
||||
"NMS = Najlepszy Ma Się dobrze — zachowaj najlepszą, usuń duplikaty",
|
||||
18,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 65),
|
||||
),
|
||||
("conf=0.95 ✓", 14, "#A5D6A7", FONT_B, (100, 340)),
|
||||
("0.90 ✗ IoU>0.5", 13, "#EF9A9A", FONT_R, (100, 365)),
|
||||
("0.85 ✗ IoU>0.5", 13, "#EF9A9A", FONT_R, (100, 390)),
|
||||
("0.40 ✓ INNY obiekt", 13, "#64B5F6", FONT_R, (100, 420)),
|
||||
("IoU = Intersection over Union", 18, "#FFE082", FONT_B, (700, 160)),
|
||||
("IoU = pole(∩) / pole(AUB)", 16, "white", FONT_R, (700, 380)),
|
||||
("Fioletowy = intersection", 14, "#CE93D8", FONT_R, (700, 410)),
|
||||
("IoU > 0.5 → TEN SAM obiekt → usuń", 14, "#EF9A9A", FONT_R, (700, 440)),
|
||||
("IoU < 0.5 → INNY obiekt → zachowaj", 14, "#A5D6A7", FONT_R, (700, 470)),
|
||||
(
|
||||
"DETR: jedyny detektor BEZ NMS (Hungarian matching zamiast tego)",
|
||||
14,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 620),
|
||||
),
|
||||
]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
|
||||
|
||||
# ── Detector from Classifier ─────────────────────────────────────
|
||||
def _detector_from_classifier() -> list[CompositeVideoClip]:
|
||||
"""Show 3 approaches to building a detector from a classifier."""
|
||||
slides = []
|
||||
|
||||
approaches = [
|
||||
(
|
||||
"Podejście 1: Sliding Window (NAJWOLNIEJSZE)",
|
||||
[
|
||||
("Okno przesuwa się po obrazie w wielu skalach", "#B0BEC5"),
|
||||
("Każde okno → klasyfikator (np. ResNet) → klasa + pewność", "#B0BEC5"),
|
||||
("~18 000 okien x 10ms = ~3 minuty na obraz!", "#EF9A9A"),
|
||||
("Mnemonik: WYCINAJ i PYTAJ — jak wycinanie ciasteczek", "#FFE082"),
|
||||
],
|
||||
"SRF",
|
||||
),
|
||||
(
|
||||
"Podejście 2: Region Proposals (= R-CNN)",
|
||||
[
|
||||
("Selective Search → ~2000 inteligentnych regionów", "#B0BEC5"),
|
||||
("Każdy region → CNN → wektor cech → SVM klasyfikuje", "#B0BEC5"),
|
||||
("~2000 x 10ms = ~20 sec — 9x szybciej!", "#64B5F6"),
|
||||
(
|
||||
"Mnemonik: INTELIGENTNE CIĘCIE — wytnij tylko tam gdzie wiśnie",
|
||||
"#FFE082",
|
||||
),
|
||||
],
|
||||
"SRF",
|
||||
),
|
||||
(
|
||||
"Podejście 3: Fine-tune backbone (NAJLEPSZE)",
|
||||
[
|
||||
(
|
||||
"Pretrained backbone (ResNet) → odetnij FC → dodaj detection head",
|
||||
"#B0BEC5",
|
||||
),
|
||||
(
|
||||
"Detection head = głowica klasyfikacji + głowica regresji bbox",
|
||||
"#B0BEC5",
|
||||
),
|
||||
("~0.2 sec/obraz, najlepsza jakość (mAP ~42%)", "#A5D6A7"),
|
||||
("Mnemonik: PRZESZCZEP GŁOWY — ten sam silnik, nowa głowa", "#FFE082"),
|
||||
],
|
||||
"SRF",
|
||||
),
|
||||
]
|
||||
|
||||
for title, points, _mnem in approaches:
|
||||
lines = [
|
||||
(title, 24, "#FFE082", FONT_B, (80, 140)),
|
||||
]
|
||||
for i, (text, color) in enumerate(points):
|
||||
lines.append((f"• {text}", 18, color, FONT_R, (100, 220 + i * 50)))
|
||||
|
||||
lines.append(
|
||||
(
|
||||
"Detektor z klasyfikatora: SRF = Sliding → Region → Fine-tune",
|
||||
16,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 520),
|
||||
)
|
||||
)
|
||||
lines.append(
|
||||
(
|
||||
"= Szukaj Ręcznie, Finalnie optymalizuj!",
|
||||
16,
|
||||
"#90CAF9",
|
||||
FONT_R,
|
||||
(80, 550),
|
||||
)
|
||||
)
|
||||
|
||||
slides.append(_text_slide(lines, duration=STEP_DUR))
|
||||
|
||||
return slides
|
||||
|
||||
|
||||
# ── Methods comparison ────────────────────────────────────────────
|
||||
def _methods_comparison() -> CompositeVideoClip:
|
||||
"""Create a comparison table of all detection methods."""
|
||||
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(10.0)
|
||||
title = (
|
||||
_tc(
|
||||
text="Porównanie detektorów",
|
||||
font_size=36,
|
||||
color="white",
|
||||
font=FONT_B,
|
||||
)
|
||||
.with_duration(10.0)
|
||||
.with_position(("center", 20))
|
||||
)
|
||||
|
||||
rows = [
|
||||
("Model", "Rok", "Typ", "Szybkość", "Kluczowe"),
|
||||
("HOG+SVM", "2005", "Klasyczny", "~1 fps", "Gradient histogramy"),
|
||||
("Viola-Jones", "2001", "Klasyczny", "30+ fps", "Haar+Cascade"),
|
||||
("R-CNN", "2014", "Two-stage", "50 sec!", "CNN per region"),
|
||||
("Fast R-CNN", "2015", "Two-stage", "2 sec", "ROI Pooling"),
|
||||
("Faster R-CNN", "2015", "Two-stage", "5 fps", "RPN w sieci"),
|
||||
("YOLO", "2016", "One-stage", "45+ fps", "Siatka SxS"),
|
||||
("DETR", "2020", "Transformer", "~40 fps", "Bez NMS!"),
|
||||
]
|
||||
|
||||
clips: list[VideoClip] = [bg, title]
|
||||
for i, row in enumerate(rows):
|
||||
y_pos = 75 + i * 72
|
||||
col_x = [40, 200, 280, 400, 530]
|
||||
for j, cell in enumerate(row):
|
||||
fs = 16 if i > 0 else 18
|
||||
color = "#64B5F6" if i == 0 else "#E0E0E0"
|
||||
tc = (
|
||||
_tc(
|
||||
text=cell,
|
||||
font_size=fs,
|
||||
color=color,
|
||||
font=FONT_B if i == 0 else FONT_R,
|
||||
)
|
||||
.with_duration(10.0)
|
||||
.with_position((col_x[j], y_pos))
|
||||
)
|
||||
clips.append(tc)
|
||||
|
||||
return CompositeVideoClip(clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.5), FadeOut(0.5)]
|
||||
)
|
||||
405
python_pkg/praca_magisterska_video/_q24_rcnn.py
Normal file
405
python_pkg/praca_magisterska_video/_q24_rcnn.py
Normal file
@ -0,0 +1,405 @@
|
||||
"""R-CNN family: evolution, detailed pipeline, ROI pooling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q24_common import (
|
||||
BG_COLOR,
|
||||
FONT_B,
|
||||
FONT_R,
|
||||
FPS,
|
||||
STEP_DUR,
|
||||
H,
|
||||
W,
|
||||
_tc,
|
||||
)
|
||||
from moviepy import CompositeVideoClip, VideoClip
|
||||
from moviepy.video.fx import FadeIn, FadeOut
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── R-CNN Evolution ───────────────────────────────────────────────
|
||||
def _rcnn_evolution() -> list[CompositeVideoClip]:
|
||||
"""Animate R-CNN → Fast R-CNN → Faster R-CNN evolution."""
|
||||
slides = []
|
||||
|
||||
def make_evolution_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
progress = min(t / (STEP_DUR * 0.8), 1.0)
|
||||
|
||||
# Three rows: R-CNN, Fast R-CNN, Faster R-CNN
|
||||
models = [
|
||||
(
|
||||
"R-CNN (2014)",
|
||||
50,
|
||||
[
|
||||
("Selective\nSearch", (200, 150), (100, 50), (120, 100, 60)),
|
||||
("2000x\nCNN", (350, 150), (80, 50), (180, 60, 60)),
|
||||
("2000x\nSVM", (480, 150), (80, 50), (180, 60, 60)),
|
||||
("NMS", (610, 150), (60, 50), (100, 140, 100)),
|
||||
],
|
||||
"50 sec/obraz!",
|
||||
),
|
||||
(
|
||||
"Fast R-CNN (2015)",
|
||||
300,
|
||||
[
|
||||
("Selective\nSearch", (200, 150), (100, 50), (120, 100, 60)),
|
||||
("1x CNN\n(cały obraz)", (350, 150), (100, 50), (80, 140, 200)),
|
||||
("ROI Pool\n(2000)", (500, 150), (90, 50), (200, 160, 80)),
|
||||
("FC", (640, 150), (50, 50), (100, 140, 100)),
|
||||
],
|
||||
"2 sec/obraz",
|
||||
),
|
||||
(
|
||||
"Faster R-CNN (2015)",
|
||||
300,
|
||||
[
|
||||
("CNN\nbackbone", (200, 150), (90, 50), (80, 140, 200)),
|
||||
("RPN\n(~300)", (340, 150), (80, 50), (200, 120, 60)),
|
||||
("ROI Pool", (470, 150), (80, 50), (200, 160, 80)),
|
||||
("FC", (600, 150), (50, 50), (100, 140, 100)),
|
||||
],
|
||||
"0.2 sec → 5 fps!",
|
||||
),
|
||||
]
|
||||
|
||||
n_models = int(progress * 3) + 1
|
||||
|
||||
for mi, (_name, base_y, stages, _speed) in enumerate(models):
|
||||
if mi >= n_models:
|
||||
break
|
||||
for _label, (bx, by_off), (bw, bh), color in stages:
|
||||
by = base_y + by_off - 150
|
||||
frame[by : by + bh, bx : bx + bw] = color
|
||||
frame[by : by + 2, bx : bx + bw] = tuple(
|
||||
min(c + 50, 255) for c in color
|
||||
)
|
||||
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
|
||||
min(c + 50, 255) for c in color
|
||||
)
|
||||
|
||||
# Arrows between stages
|
||||
for si in range(len(stages) - 1):
|
||||
sx = stages[si][1][0] + stages[si][2][0]
|
||||
ex = stages[si + 1][1][0]
|
||||
ay = base_y + 25
|
||||
frame[ay - 1 : ay + 2, sx + 3 : ex - 3] = (150, 150, 170)
|
||||
|
||||
return frame
|
||||
|
||||
evo_clip = VideoClip(make_evolution_frame, duration=STEP_DUR + 1).with_fps(FPS)
|
||||
text_clips: list[VideoClip] = [evo_clip]
|
||||
labels = [
|
||||
("Ewolucja R-CNN — CORAZ MNIEJ MARNOWANIA", 28, "#FFE082", FONT_B, (80, 20)),
|
||||
("R-CNN (2014)", 20, "#EF9A9A", FONT_B, (50, 80)),
|
||||
("50 sec/obraz (2000x forward pass!)", 14, "#EF9A9A", FONT_R, (720, 100)),
|
||||
("Fast R-CNN (2015)", 20, "#64B5F6", FONT_B, (50, 330)),
|
||||
("2 sec/obraz (CNN raz + ROI Pool)", 14, "#64B5F6", FONT_R, (720, 350)),
|
||||
("Faster R-CNN (2015)", 20, "#A5D6A7", FONT_B, (50, 580)),
|
||||
("0.2 sec → 5 fps (RPN w sieci!)", 14, "#A5D6A7", FONT_R, (720, 600)),
|
||||
(
|
||||
"Kluczowe innowacje: ROI Pooling → stały rozmiar "
|
||||
"| RPN → propozycje w sieci",
|
||||
14,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 660),
|
||||
),
|
||||
]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR + 1)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
|
||||
|
||||
# ── R-CNN Detailed Pipeline ──────────────────────────────────────
|
||||
def _rcnn_detailed() -> list[CompositeVideoClip]:
|
||||
"""Animate R-CNN step-by-step pipeline in detail."""
|
||||
slides = []
|
||||
|
||||
# Slide 1: R-CNN pipeline step by step
|
||||
def make_rcnn_pipeline(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
progress = min(t / (STEP_DUR * 0.8), 1.0)
|
||||
|
||||
# Step boxes arranged vertically with arrows
|
||||
steps = [
|
||||
((80, 130), (200, 55), (120, 100, 60), "1. Selective Search"),
|
||||
((80, 230), (200, 55), (180, 60, 60), "2. Wytnij 2000 regionów"),
|
||||
((80, 330), (200, 55), (70, 130, 200), "3. CNN per region"),
|
||||
((80, 430), (200, 55), (200, 100, 80), "4. SVM klasyfikuje"),
|
||||
((80, 530), (200, 55), (100, 180, 100), "5. Bbox regresja + NMS"),
|
||||
]
|
||||
n_steps = min(int(progress * 5) + 1, 5)
|
||||
for i, ((bx, by), (bw, bh), color, _lbl) in enumerate(steps):
|
||||
if i < n_steps:
|
||||
frame[by : by + bh, bx : bx + bw] = color
|
||||
frame[by : by + 2, bx : bx + bw] = tuple(
|
||||
min(c + 50, 255) for c in color
|
||||
)
|
||||
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
|
||||
min(c + 50, 255) for c in color
|
||||
)
|
||||
# Arrow down
|
||||
arrow_limit = 4
|
||||
if i < arrow_limit:
|
||||
ax = bx + bw // 2
|
||||
ay = by + bh + 5
|
||||
frame[ay : ay + 20, ax - 1 : ax + 2] = (150, 150, 170)
|
||||
|
||||
# Illustration: many overlapping regions from Selective Search
|
||||
overlay_phase = 0.2
|
||||
if progress > overlay_phase:
|
||||
rng_local = np.random.default_rng(42)
|
||||
n_boxes = min(int((progress - 0.2) * 15), 8)
|
||||
for i in range(n_boxes):
|
||||
rx = 500 + rng_local.integers(-30, 100)
|
||||
ry = 200 + rng_local.integers(-20, 120)
|
||||
rw = 60 + rng_local.integers(0, 80)
|
||||
rh = 50 + rng_local.integers(0, 70)
|
||||
c = (80 + i * 15, 100 + i * 10, 60 + i * 20)
|
||||
for tt in range(2):
|
||||
frame[ry - tt : ry + rh + tt, rx - tt : rx - tt + 2] = c
|
||||
frame[ry - tt : ry + rh + tt, rx + rw + tt - 2 : rx + rw + tt] = c
|
||||
frame[ry - tt : ry - tt + 2, rx - tt : rx + rw + tt] = c
|
||||
frame[ry + rh + tt - 2 : ry + rh + tt, rx - tt : rx + rw + tt] = c
|
||||
|
||||
return frame
|
||||
|
||||
rcnn_clip = VideoClip(make_rcnn_pipeline, duration=STEP_DUR + 1).with_fps(FPS)
|
||||
dur = STEP_DUR + 1
|
||||
labels = [
|
||||
("R-CNN: krok po kroku (2014, Girshick)", 26, "#FFE082", FONT_B, (80, 20)),
|
||||
("Pipeline detekcji two-stage", 16, "#B0BEC5", FONT_R, (80, 60)),
|
||||
("Selective Search", 11, "white", FONT_R, (105, 145)),
|
||||
("2000 regionów", 11, "white", FONT_R, (105, 245)),
|
||||
("CNN per region", 11, "white", FONT_R, (105, 345)),
|
||||
("SVM klasyfikuje", 11, "white", FONT_R, (105, 445)),
|
||||
("Regresja + NMS", 11, "white", FONT_R, (105, 545)),
|
||||
("~2000 propozycji regionów", 14, "#78909C", FONT_R, (500, 155)),
|
||||
("(inteligentne łączenie", 13, "#78909C", FONT_R, (500, 180)),
|
||||
("podobnych fragmentów)", 13, "#78909C", FONT_R, (500, 200)),
|
||||
("Problem: 2000 x CNN forward pass", 16, "#EF9A9A", FONT_R, (400, 400)),
|
||||
("= 50 SEKUND na obraz!", 18, "#EF9A9A", FONT_B, (400, 430)),
|
||||
("CNN liczy cechy per region OSOBNO", 14, "#EF9A9A", FONT_R, (400, 470)),
|
||||
(
|
||||
"→ regiony się nakładają → obliczenia się powtarzają!",
|
||||
14,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(400, 495),
|
||||
),
|
||||
(
|
||||
"Rozwiązanie: CNN raz na cały obraz → Fast R-CNN →",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 620),
|
||||
),
|
||||
]
|
||||
text_clips: list[VideoClip] = [rcnn_clip]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(dur)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
|
||||
return slides
|
||||
|
||||
|
||||
# ── ROI Pooling ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _draw_roi_pool_grid(frame: np.ndarray) -> None:
|
||||
"""Draw the 3x3 ROI pool grid with max-pooled feature values."""
|
||||
out_x, out_y = 400, 220
|
||||
out_cell = 50
|
||||
out_n = 3
|
||||
roi_r1, roi_c1 = 2, 1
|
||||
roi_r2, roi_c2 = 6, 5
|
||||
roi_h = roi_r2 - roi_r1
|
||||
roi_w = roi_c2 - roi_c1
|
||||
for r in range(out_n):
|
||||
for c in range(out_n):
|
||||
x = out_x + c * out_cell
|
||||
y = out_y + r * out_cell
|
||||
|
||||
# Compute the max from corresponding region
|
||||
src_r1 = roi_r1 + r * roi_h // out_n
|
||||
src_r2 = roi_r1 + (r + 1) * roi_h // out_n
|
||||
src_c1 = roi_c1 + c * roi_w // out_n
|
||||
src_c2 = roi_c1 + (c + 1) * roi_w // out_n
|
||||
max_val = 0
|
||||
for sr in range(src_r1, src_r2):
|
||||
for sc in range(src_c1, src_c2):
|
||||
v = 30 + ((sr * 7 + sc * 13 + 42) % 40)
|
||||
max_val = max(max_val, v)
|
||||
|
||||
frame[y : y + out_cell - 2, x : x + out_cell - 2] = (
|
||||
max_val,
|
||||
max_val + 20,
|
||||
max_val + 40,
|
||||
)
|
||||
frame[y : y + 2, x : x + out_cell - 2] = (80, 200, 120)
|
||||
frame[y + out_cell - 4 : y + out_cell - 2, x : x + out_cell - 2] = (
|
||||
80,
|
||||
200,
|
||||
120,
|
||||
)
|
||||
|
||||
|
||||
def _make_roi_frame(t: float) -> np.ndarray:
|
||||
"""Render a single frame for the ROI pooling animation."""
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
progress = min(t / (STEP_DUR * 0.7), 1.0)
|
||||
|
||||
# Left: feature map with ROI highlighted
|
||||
fm_x, fm_y = 60, 180
|
||||
fm_cell = 30
|
||||
fm_grid = 8
|
||||
for r in range(fm_grid):
|
||||
for c in range(fm_grid):
|
||||
x = fm_x + c * fm_cell
|
||||
y = fm_y + r * fm_cell
|
||||
# Random-looking feature values
|
||||
val = 30 + ((r * 7 + c * 13 + 42) % 40)
|
||||
frame[y : y + fm_cell - 1, x : x + fm_cell - 1] = (
|
||||
val,
|
||||
val + 10,
|
||||
val + 20,
|
||||
)
|
||||
|
||||
# ROI region highlighted
|
||||
roi_r1, roi_c1 = 2, 1
|
||||
roi_r2, roi_c2 = 6, 5
|
||||
for tt in range(3):
|
||||
ry1 = fm_y + roi_r1 * fm_cell - tt
|
||||
ry2 = fm_y + roi_r2 * fm_cell + tt
|
||||
rx1 = fm_x + roi_c1 * fm_cell - tt
|
||||
rx2 = fm_x + roi_c2 * fm_cell + tt
|
||||
frame[ry1:ry2, rx1 : rx1 + 2] = (255, 200, 50)
|
||||
frame[ry1:ry2, rx2 - 2 : rx2] = (255, 200, 50)
|
||||
frame[ry1 : ry1 + 2, rx1:rx2] = (255, 200, 50)
|
||||
frame[ry2 - 2 : ry2, rx1:rx2] = (255, 200, 50)
|
||||
|
||||
# Arrow
|
||||
arrow_phase = 0.3
|
||||
if progress > arrow_phase:
|
||||
frame[300:303, 310:380] = (150, 150, 170)
|
||||
|
||||
# Middle: ROI divided into 3x3 grid (output_size)
|
||||
grid_phase = 0.3
|
||||
if progress > grid_phase:
|
||||
_draw_roi_pool_grid(frame)
|
||||
|
||||
# Arrow to FC
|
||||
fc_phase = 0.6
|
||||
if progress > fc_phase:
|
||||
frame[300:303, 560:630] = (150, 150, 170)
|
||||
# FC box
|
||||
frame[270:340, 650:730] = (200, 100, 80)
|
||||
frame[270:272, 650:730] = (240, 140, 120)
|
||||
frame[338:340, 650:730] = (240, 140, 120)
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def _roi_pooling_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate ROI Pooling: key Fast R-CNN innovation."""
|
||||
slides = []
|
||||
|
||||
roi_clip = VideoClip(_make_roi_frame, duration=STEP_DUR + 1).with_fps(FPS)
|
||||
dur = STEP_DUR + 1
|
||||
labels = [
|
||||
("ROI Pooling: kluczowa innowacja Fast R-CNN", 26, "#FFE082", FONT_B, (80, 20)),
|
||||
(
|
||||
"KROK 1: CNN raz na CAŁY obraz → feature mapa",
|
||||
17,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(80, 60),
|
||||
),
|
||||
(
|
||||
"KROK 2: Wytnij ROI z feature mapy (nie z obrazu!)",
|
||||
17,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 90),
|
||||
),
|
||||
(
|
||||
"KROK 3: Siatkuj ROI na 3x3 → max pool per komórka → stały rozmiar",
|
||||
17,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 120),
|
||||
),
|
||||
("Feature mapa", 14, "#64B5F6", FONT_B, (60, 160)),
|
||||
("ROI (żółta ramka)", 13, "#FFE082", FONT_R, (60, 440)),
|
||||
("ROI Pool 3x3", 14, "#A5D6A7", FONT_B, (400, 195)),
|
||||
("(max z komórki)", 13, "#78909C", FONT_R, (400, 380)),
|
||||
("FC", 14, "white", FONT_B, (670, 280)),
|
||||
(
|
||||
"Problem: ROI mają RÓŻNE rozmiary, FC wymaga STAŁEGO",
|
||||
15,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 500),
|
||||
),
|
||||
(
|
||||
"ROI Pooling: dzieli ROI na siatkę, max pool → STAŁY rozmiar!",
|
||||
16,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 535),
|
||||
),
|
||||
(
|
||||
"Fast R-CNN: CNN raz → 1 feature mapa → "
|
||||
"ROI Pool 2000 regionów → 25x szybciej!",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 580),
|
||||
),
|
||||
(
|
||||
"(R-CNN: 2000x CNN = 50s | Fast R-CNN: 1xCNN + ROI Pool = 2s)",
|
||||
15,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(80, 620),
|
||||
),
|
||||
]
|
||||
text_clips: list[VideoClip] = [roi_clip]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(dur)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
383
python_pkg/praca_magisterska_video/_q24_rpn_yolo.py
Normal file
383
python_pkg/praca_magisterska_video/_q24_rpn_yolo.py
Normal file
@ -0,0 +1,383 @@
|
||||
"""RPN anchor boxes and YOLO grid detection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q24_common import (
|
||||
BG_COLOR,
|
||||
FONT_B,
|
||||
FONT_R,
|
||||
FPS,
|
||||
STEP_DUR,
|
||||
H,
|
||||
W,
|
||||
_tc,
|
||||
_text_slide,
|
||||
)
|
||||
from moviepy import CompositeVideoClip, VideoClip
|
||||
from moviepy.video.fx import FadeIn, FadeOut
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── RPN + Anchor Boxes ───────────────────────────────────────────
|
||||
def _rpn_anchors_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate RPN and anchor boxes: Faster R-CNN innovation."""
|
||||
slides = []
|
||||
|
||||
# Slide 1: Anchor boxes concept
|
||||
def make_anchors_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
progress = min(t / (STEP_DUR * 0.7), 1.0)
|
||||
|
||||
# Draw feature map grid point with multiple anchors
|
||||
cx, cy = 350, 360 # center point on feature map
|
||||
|
||||
# Draw a "feature map" grid background
|
||||
cell = 60
|
||||
for r in range(-3, 4):
|
||||
for c in range(-3, 4):
|
||||
x = cx + c * cell - cell // 2
|
||||
y = cy + r * cell - cell // 2
|
||||
frame[y : y + cell - 1, x : x + cell - 1] = (30, 35, 48)
|
||||
|
||||
# Center point highlighted
|
||||
frame[cy - 5 : cy + 5, cx - 5 : cx + 5] = (255, 200, 50)
|
||||
|
||||
# Draw anchors around center: 3 sizes x 3 ratios = 9
|
||||
anchor_specs = [
|
||||
(30, 30, (200, 80, 80)), # small 1:1
|
||||
(20, 40, (200, 60, 60)), # small 1:2
|
||||
(40, 20, (180, 60, 60)), # small 2:1
|
||||
(60, 60, (80, 200, 80)), # medium 1:1
|
||||
(40, 80, (60, 180, 60)), # medium 1:2
|
||||
(80, 40, (60, 160, 60)), # medium 2:1
|
||||
(90, 90, (80, 80, 200)), # large 1:1
|
||||
(60, 120, (60, 60, 180)), # large 1:2
|
||||
(120, 60, (60, 60, 160)), # large 2:1
|
||||
]
|
||||
n_anchors = min(int(progress * 9) + 1, 9)
|
||||
for i in range(n_anchors):
|
||||
hw, hh, color = anchor_specs[i]
|
||||
x1 = max(0, cx - hw)
|
||||
y1 = max(0, cy - hh)
|
||||
x2 = min(W - 1, cx + hw)
|
||||
y2 = min(H - 1, cy + hh)
|
||||
for tt in range(2):
|
||||
frame[y1 - tt : y2 + tt, x1 - tt : x1 - tt + 2] = color
|
||||
frame[y1 - tt : y2 + tt, x2 + tt - 2 : x2 + tt] = color
|
||||
frame[y1 - tt : y1 - tt + 2, x1 - tt : x2 + tt] = color
|
||||
frame[y2 + tt - 2 : y2 + tt, x1 - tt : x2 + tt] = color
|
||||
|
||||
return frame
|
||||
|
||||
anch_clip = VideoClip(make_anchors_frame, duration=STEP_DUR + 1).with_fps(FPS)
|
||||
dur = STEP_DUR + 1
|
||||
labels = [
|
||||
("Anchor Boxes + RPN (Faster R-CNN)", 26, "#FFE082", FONT_B, (80, 20)),
|
||||
(
|
||||
"KROK 1: Anchory = predefiniowane kształty w każdej pozycji",
|
||||
17,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 60),
|
||||
),
|
||||
(
|
||||
"3 rozmiary x 3 proporcje = 9 anchorów per punkt",
|
||||
16,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 90),
|
||||
),
|
||||
("Małe (1:1, 1:2, 2:1)", 14, "#EF9A9A", FONT_R, (750, 170)),
|
||||
("Średnie (1:1, 1:2, 2:1)", 14, "#A5D6A7", FONT_R, (750, 210)),
|
||||
("Duże (1:1, 1:2, 2:1)", 14, "#64B5F6", FONT_R, (750, 250)),
|
||||
("Żółty punkt = pozycja", 14, "#FFE082", FONT_R, (750, 310)),
|
||||
("na feature mapie", 14, "#FFE082", FONT_R, (750, 335)),
|
||||
("Sieć NIE predykuje bbox od zera!", 16, "white", FONT_R, (80, 530)),
|
||||
(
|
||||
"Predykuje OFFSET od najbliższego anchora: (Δx, Δy, Δw, Δh)",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 565),
|
||||
),
|
||||
(
|
||||
"+ P(obiekt) = 'czy w tym anchorze jest coś?'",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 600),
|
||||
),
|
||||
(
|
||||
"Mnemonik: Anchor = KOTWICA — sieć dopasowuje bbox do kotwicy",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 645),
|
||||
),
|
||||
]
|
||||
text_clips: list[VideoClip] = [anch_clip]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(dur)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
|
||||
# Slide 2: RPN step by step
|
||||
rpn_lines = [
|
||||
(
|
||||
"RPN: Region Proposal Network — krok po kroku",
|
||||
24,
|
||||
"#FFE082",
|
||||
FONT_B,
|
||||
(80, 30),
|
||||
),
|
||||
(
|
||||
"Zastępuje Selective Search SIECIĄ NEURONOWĄ (end-to-end!)",
|
||||
17,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 85),
|
||||
),
|
||||
("", 10, "white", FONT_R, (80, 110)),
|
||||
(
|
||||
"1. Backbone (ResNet) przetwarza obraz → feature mapa [40x60x256]",
|
||||
16,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(100, 140),
|
||||
),
|
||||
(
|
||||
"2. Filtr 3x3 przesuwa się po feature mapie",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 180),
|
||||
),
|
||||
(
|
||||
"3. W KAŻDEJ pozycji (x,y) rozważ k=9 anchorów:",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(100, 220),
|
||||
),
|
||||
(" → P(obiekt) — 'czy tu jest coś?'", 15, "white", FONT_R, (120, 255)),
|
||||
(" → (Δx, Δy, Δw, Δh) — poprawka pozycji", 15, "white", FONT_R, (120, 285)),
|
||||
(
|
||||
"4. 40x60 pozycji x 9 anchorów = 21 600 kandydatów!",
|
||||
16,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(100, 325),
|
||||
),
|
||||
(
|
||||
"5. Weź ~300 z najwyższym P(obiekt) → ROI Pool → FC",
|
||||
16,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 365),
|
||||
),
|
||||
("", 10, "white", FONT_R, (100, 395)),
|
||||
("Porównanie generowania propozycji:", 17, "white", FONT_B, (80, 420)),
|
||||
(
|
||||
" Selective Search: ~2000 regionów, osobny algorytm, ~2 sec",
|
||||
15,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(100, 460),
|
||||
),
|
||||
(
|
||||
" RPN: ~300 regionów, W SIECI, ~10 ms → 200x szybciej!",
|
||||
15,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 495),
|
||||
),
|
||||
("", 10, "white", FONT_R, (100, 520)),
|
||||
(
|
||||
"Faster R-CNN = Backbone + RPN + ROI Pool + FC — WSZYSTKO end-to-end",
|
||||
17,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 545),
|
||||
),
|
||||
(
|
||||
"→ 5 fps (0.2 sec/obraz) vs R-CNN 50 sec = 250x szybciej!",
|
||||
17,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 585),
|
||||
),
|
||||
(
|
||||
"Wciąż two-stage: (1) RPN generuje propozycje, (2) FC klasyfikuje",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 630),
|
||||
),
|
||||
]
|
||||
slides.append(_text_slide(rpn_lines, duration=STEP_DUR + 1))
|
||||
|
||||
return slides
|
||||
|
||||
|
||||
# ── YOLO ──────────────────────────────────────────────────────────
|
||||
def _yolo_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate YOLO grid detection concept."""
|
||||
slides = []
|
||||
|
||||
def make_yolo_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
|
||||
progress = min(t / (STEP_DUR * 0.7), 1.0)
|
||||
|
||||
# Draw image with grid overlay
|
||||
img_x, img_y = 100, 140
|
||||
img_size = 420
|
||||
grid_n = 7
|
||||
|
||||
# Background "image"
|
||||
frame[img_y : img_y + img_size, img_x : img_x + img_size] = (50, 55, 70)
|
||||
|
||||
# Objects in the image
|
||||
frame[img_y + 80 : img_y + 200, img_x + 50 : img_x + 180] = (
|
||||
180,
|
||||
60,
|
||||
60,
|
||||
) # "car"
|
||||
frame[img_y + 150 : img_y + 350, img_x + 250 : img_x + 330] = (
|
||||
60,
|
||||
120,
|
||||
180,
|
||||
) # "person"
|
||||
|
||||
# Grid lines
|
||||
cell = img_size // grid_n
|
||||
for i in range(grid_n + 1):
|
||||
# Vertical
|
||||
x = img_x + i * cell
|
||||
frame[img_y : img_y + img_size, x : x + 1] = (100, 100, 120)
|
||||
# Horizontal
|
||||
y = img_y + i * cell
|
||||
frame[y : y + 1, img_x : img_x + img_size] = (100, 100, 120)
|
||||
|
||||
# Highlight cells containing object centers
|
||||
car_phase = 0.3
|
||||
if progress > car_phase:
|
||||
# Car center ~ cell (1, 1)
|
||||
cx, cy = 1, 2
|
||||
hx = img_x + cx * cell
|
||||
hy = img_y + cy * cell
|
||||
frame[hy : hy + cell, hx : hx + cell] = np.clip(
|
||||
frame[hy : hy + cell, hx : hx + cell].astype(int) + 40, 0, 255
|
||||
).astype(np.uint8)
|
||||
|
||||
person_phase = 0.5
|
||||
if progress > person_phase:
|
||||
# Person center ~ cell (4, 4)
|
||||
cx, cy = 4, 4
|
||||
hx = img_x + cx * cell
|
||||
hy = img_y + cy * cell
|
||||
frame[hy : hy + cell, hx : hx + cell] = np.clip(
|
||||
frame[hy : hy + cell, hx : hx + cell].astype(int) + 40, 0, 255
|
||||
).astype(np.uint8)
|
||||
|
||||
# Bounding boxes predictions from cells
|
||||
bbox_phase = 0.6
|
||||
if progress > bbox_phase:
|
||||
# Car bbox
|
||||
for tt in range(2):
|
||||
frame[
|
||||
img_y + 78 - tt : img_y + 202 + tt,
|
||||
img_x + 48 - tt : img_x + 48 - tt + 2,
|
||||
] = (255, 80, 80)
|
||||
frame[
|
||||
img_y + 78 - tt : img_y + 202 + tt,
|
||||
img_x + 182 + tt - 2 : img_x + 182 + tt,
|
||||
] = (255, 80, 80)
|
||||
frame[
|
||||
img_y + 78 - tt : img_y + 78 - tt + 2,
|
||||
img_x + 48 - tt : img_x + 182 + tt,
|
||||
] = (255, 80, 80)
|
||||
frame[
|
||||
img_y + 202 + tt - 2 : img_y + 202 + tt,
|
||||
img_x + 48 - tt : img_x + 182 + tt,
|
||||
] = (255, 80, 80)
|
||||
|
||||
# Person bbox
|
||||
for tt in range(2):
|
||||
frame[
|
||||
img_y + 148 - tt : img_y + 352 + tt,
|
||||
img_x + 248 - tt : img_x + 248 - tt + 2,
|
||||
] = (80, 180, 255)
|
||||
frame[
|
||||
img_y + 148 - tt : img_y + 352 + tt,
|
||||
img_x + 332 + tt - 2 : img_x + 332 + tt,
|
||||
] = (80, 180, 255)
|
||||
frame[
|
||||
img_y + 148 - tt : img_y + 148 - tt + 2,
|
||||
img_x + 248 - tt : img_x + 332 + tt,
|
||||
] = (80, 180, 255)
|
||||
frame[
|
||||
img_y + 352 + tt - 2 : img_y + 352 + tt,
|
||||
img_x + 248 - tt : img_x + 332 + tt,
|
||||
] = (80, 180, 255)
|
||||
|
||||
return frame
|
||||
|
||||
yolo_clip = VideoClip(make_yolo_frame, duration=STEP_DUR).with_fps(FPS)
|
||||
text_clips: list[VideoClip] = [yolo_clip]
|
||||
labels = [
|
||||
("YOLO — You Only Look Once", 28, "#FFE082", FONT_B, (80, 20)),
|
||||
(
|
||||
"Jednoetapowy detektor: siatka SxS → wszystkie detekcje naraz!",
|
||||
18,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 65),
|
||||
),
|
||||
("Siatka 7x7 = 49 komórek", 16, "#64B5F6", FONT_R, (600, 180)),
|
||||
("Każda komórka predykuje:", 16, "white", FONT_R, (600, 220)),
|
||||
(" • B bbox (x, y, w, h, conf)", 14, "#B0BEC5", FONT_R, (600, 255)),
|
||||
(" • C klas (prawdopodobieństwa)", 14, "#B0BEC5", FONT_R, (600, 285)),
|
||||
("Komórka odpowiada za obiekt", 14, "#A5D6A7", FONT_R, (600, 325)),
|
||||
("którego ŚRODEK w niej wpada", 14, "#A5D6A7", FONT_R, (600, 350)),
|
||||
("45-155 fps! (vs 5 fps Faster R-CNN)", 18, "#EF9A9A", FONT_B, (600, 400)),
|
||||
(
|
||||
"Jedno przejście przez sieć → WSZYSTKIE detekcje naraz → NMS → wynik",
|
||||
14,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 620),
|
||||
),
|
||||
(
|
||||
"Two-stage (R-CNN): propozycje+klasyfikacja "
|
||||
"| One-stage (YOLO): bez propozycji!",
|
||||
14,
|
||||
"#90CAF9",
|
||||
FONT_R,
|
||||
(80, 655),
|
||||
),
|
||||
]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(STEP_DUR)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
return slides
|
||||
459
python_pkg/praca_magisterska_video/_q24_yolo_arch_detr.py
Normal file
459
python_pkg/praca_magisterska_video/_q24_yolo_arch_detr.py
Normal file
@ -0,0 +1,459 @@
|
||||
"""YOLO architecture detail and DETR transformer detection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q24_common import (
|
||||
BG_COLOR,
|
||||
FONT_B,
|
||||
FONT_R,
|
||||
FPS,
|
||||
STEP_DUR,
|
||||
H,
|
||||
W,
|
||||
_tc,
|
||||
_text_slide,
|
||||
)
|
||||
from moviepy import CompositeVideoClip, VideoClip
|
||||
from moviepy.video.fx import FadeIn, FadeOut
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── YOLO Architecture Detail ──────────────────────────────────────
|
||||
def _yolo_architecture() -> list[CompositeVideoClip]:
|
||||
"""Show YOLO architecture: backbone → head, output tensor."""
|
||||
slides = []
|
||||
|
||||
# Slide 1: YOLO architecture breakdown
|
||||
def make_yolo_arch(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
progress = min(t / (STEP_DUR * 0.7), 1.0)
|
||||
|
||||
# Pipeline: Image → Backbone → Neck → Head → SxSx(B*5+C) tensor
|
||||
blocks = [
|
||||
((60, 280), (100, 80), (50, 70, 90), "Obraz"),
|
||||
((200, 280), (100, 80), (70, 130, 200), "Backbone"),
|
||||
((340, 280), (100, 80), (200, 160, 80), "Neck"),
|
||||
((480, 280), (100, 80), (200, 100, 60), "Head"),
|
||||
((620, 280), (160, 80), (80, 200, 120), "SxSx(B*5+C)"),
|
||||
]
|
||||
n_blocks = min(int(progress * 5) + 1, 5)
|
||||
for i, ((bx, by), (bw, bh), color, _lbl) in enumerate(blocks):
|
||||
if i < n_blocks:
|
||||
frame[by : by + bh, bx : bx + bw] = color
|
||||
frame[by : by + 2, bx : bx + bw] = tuple(
|
||||
min(c + 50, 255) for c in color
|
||||
)
|
||||
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
|
||||
min(c + 50, 255) for c in color
|
||||
)
|
||||
arrow_limit = 4
|
||||
if i < arrow_limit:
|
||||
ax = bx + bw + 5
|
||||
ay = by + bh // 2
|
||||
frame[ay - 1 : ay + 2, ax : ax + 25] = (150, 150, 170)
|
||||
|
||||
# Output tensor breakdown (right side)
|
||||
tensor_phase = 0.6
|
||||
if progress > tensor_phase:
|
||||
# Show SxS grid
|
||||
gx, gy = 850, 180
|
||||
gs = 120
|
||||
gn = 4 # simplified from 7
|
||||
gc = gs // gn
|
||||
for r in range(gn):
|
||||
for c in range(gn):
|
||||
x = gx + c * gc
|
||||
y = gy + r * gc
|
||||
frame[y : y + gc - 1, x : x + gc - 1] = (40, 50, 65)
|
||||
# Highlight one cell
|
||||
frame[gy + gc : gy + 2 * gc - 1, gx + gc : gx + 2 * gc - 1] = (
|
||||
80,
|
||||
200,
|
||||
120,
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
arch_clip = VideoClip(make_yolo_arch, duration=STEP_DUR + 1).with_fps(FPS)
|
||||
dur = STEP_DUR + 1
|
||||
labels = [
|
||||
("YOLO: Architektura — krok po kroku", 26, "#FFE082", FONT_B, (80, 20)),
|
||||
(
|
||||
"One-stage: JEDEN forward pass → WSZYSTKIE detekcje naraz",
|
||||
17,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 60),
|
||||
),
|
||||
("Obraz", 13, "white", FONT_R, (85, 295)),
|
||||
("Backbone", 13, "white", FONT_R, (215, 295)),
|
||||
("(ResNet/", 11, "#78909C", FONT_R, (210, 370)),
|
||||
("Darknet)", 11, "#78909C", FONT_R, (210, 390)),
|
||||
("Neck", 13, "white", FONT_R, (365, 295)),
|
||||
("(FPN/", 11, "#78909C", FONT_R, (360, 370)),
|
||||
("PANet)", 11, "#78909C", FONT_R, (360, 390)),
|
||||
("Head", 13, "white", FONT_R, (505, 295)),
|
||||
("(conv)", 11, "#78909C", FONT_R, (500, 370)),
|
||||
("Tensor wyjścia", 13, "#A5D6A7", FONT_R, (640, 295)),
|
||||
("Każda komórka SxS predykuje:", 15, "#FFE082", FONT_R, (830, 320)),
|
||||
(" B bbox x (x,y,w,h,conf)", 13, "#B0BEC5", FONT_R, (830, 350)),
|
||||
(" + C klas (prob.)", 13, "#B0BEC5", FONT_R, (830, 375)),
|
||||
("= SxSx(Bx5+C) tensor", 13, "#A5D6A7", FONT_R, (830, 400)),
|
||||
("Np. 7x7x(2x5+20) = 7x7x30", 13, "#78909C", FONT_R, (830, 430)),
|
||||
(
|
||||
"Two-stage (R-CNN): (1) propozycje → (2) klasyfikacja = 2 przejścia",
|
||||
15,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(80, 470),
|
||||
),
|
||||
(
|
||||
"One-stage (YOLO): siatka → predykcja all-in-one = 1 przejście!",
|
||||
15,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 505),
|
||||
),
|
||||
(
|
||||
"Ewolucja YOLO: v1(2016)→v3→v5→v8(2023, anchor-free, SOTA)",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 555),
|
||||
),
|
||||
(
|
||||
"SSD (2016): multi-scale feature maps → lepsza detekcja małych obiektów",
|
||||
15,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(80, 595),
|
||||
),
|
||||
(
|
||||
"FPN: łączy wczesne warstwy (małe obiekty) + późne (duże obiekty)",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 630),
|
||||
),
|
||||
]
|
||||
text_clips: list[VideoClip] = [arch_clip]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(dur)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
|
||||
return slides
|
||||
|
||||
|
||||
# ── DETR ──────────────────────────────────────────────────────────
|
||||
def _detr_demo() -> list[CompositeVideoClip]:
|
||||
"""Animate DETR: transformer detection, object queries, no NMS."""
|
||||
slides = []
|
||||
|
||||
# Slide 1: DETR pipeline
|
||||
def make_detr_frame(t: float) -> np.ndarray:
|
||||
frame = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
frame[:] = BG_COLOR
|
||||
progress = min(t / (STEP_DUR * 0.7), 1.0)
|
||||
|
||||
# DETR pipeline: Image → Backbone → Encoder → Decoder → N predictions
|
||||
blocks = [
|
||||
((50, 260), (80, 60), (50, 70, 90)),
|
||||
((170, 260), (90, 60), (70, 130, 200)),
|
||||
((300, 260), (110, 60), (200, 120, 60)),
|
||||
((450, 260), (110, 60), (200, 80, 160)),
|
||||
((600, 260), (120, 60), (80, 200, 120)),
|
||||
]
|
||||
n_blocks = min(int(progress * 5) + 1, 5)
|
||||
for i, ((bx, by), (bw, bh), color) in enumerate(blocks):
|
||||
if i < n_blocks:
|
||||
frame[by : by + bh, bx : bx + bw] = color
|
||||
frame[by : by + 2, bx : bx + bw] = tuple(
|
||||
min(c + 50, 255) for c in color
|
||||
)
|
||||
frame[by + bh - 2 : by + bh, bx : bx + bw] = tuple(
|
||||
min(c + 50, 255) for c in color
|
||||
)
|
||||
arrow_limit = 4
|
||||
if i < arrow_limit:
|
||||
ax = bx + bw + 5
|
||||
ay = by + bh // 2
|
||||
frame[ay - 1 : ay + 2, ax : ax + 25] = (150, 150, 170)
|
||||
|
||||
# Object queries illustration (right side)
|
||||
query_phase = 0.5
|
||||
if progress > query_phase:
|
||||
qx, qy = 800, 140
|
||||
for i in range(6):
|
||||
y = qy + i * 50
|
||||
w = 130
|
||||
active_limit = 3
|
||||
active = i < active_limit
|
||||
color = (80, 180, 120) if active else (60, 50, 50)
|
||||
frame[y : y + 35, qx : qx + w] = color
|
||||
frame[y : y + 1, qx : qx + w] = tuple(min(c + 40, 255) for c in color)
|
||||
|
||||
# Arrow from decoder to queries
|
||||
frame[285:288, 723:798] = (150, 150, 170)
|
||||
|
||||
return frame
|
||||
|
||||
detr_clip = VideoClip(make_detr_frame, duration=STEP_DUR + 1).with_fps(FPS)
|
||||
dur = STEP_DUR + 1
|
||||
labels = [
|
||||
("DETR: DEtection TRansformer (2020)", 26, "#FFE082", FONT_B, (80, 20)),
|
||||
(
|
||||
"Radykalnie prostszy pipeline: BEZ anchorów, BEZ NMS!",
|
||||
17,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(80, 60),
|
||||
),
|
||||
("Obraz", 12, "white", FONT_R, (65, 275)),
|
||||
("Backbone", 12, "white", FONT_R, (185, 275)),
|
||||
("Transformer", 12, "white", FONT_R, (310, 275)),
|
||||
("Encoder", 12, "white", FONT_R, (325, 295)),
|
||||
("Transformer", 12, "white", FONT_R, (460, 275)),
|
||||
("Decoder", 12, "white", FONT_R, (478, 295)),
|
||||
("N predykcji", 12, "white", FONT_R, (615, 275)),
|
||||
("Object Queries:", 14, "#FFE082", FONT_B, (800, 115)),
|
||||
("samochód 95%", 11, "white", FONT_R, (810, 148)),
|
||||
("pies 88%", 11, "white", FONT_R, (810, 198)),
|
||||
("rower 72%", 11, "white", FONT_R, (810, 248)),
|
||||
("brak", 11, "#78909C", FONT_R, (810, 298)),
|
||||
("brak", 11, "#78909C", FONT_R, (810, 348)),
|
||||
("brak", 11, "#78909C", FONT_R, (810, 398)),
|
||||
("100 wyuczonych queries", 13, "#FFE082", FONT_R, (800, 440)),
|
||||
("→ każdy 'szuka' obiektu", 13, "#FFE082", FONT_R, (800, 465)),
|
||||
]
|
||||
text_clips: list[VideoClip] = [detr_clip]
|
||||
for text, fs, color, font, pos in labels:
|
||||
tc = (
|
||||
_tc(text=text, font_size=fs, color=color, font=font)
|
||||
.with_duration(dur)
|
||||
.with_position(pos)
|
||||
)
|
||||
text_clips.append(tc)
|
||||
slides.append(
|
||||
CompositeVideoClip(text_clips, size=(W, H)).with_effects(
|
||||
[FadeIn(0.3), FadeOut(0.3)]
|
||||
)
|
||||
)
|
||||
|
||||
# Slide 2: Why no NMS + Hungarian matching
|
||||
detr_details = [
|
||||
("DETR: Dlaczego bez NMS? — krok po kroku", 24, "#FFE082", FONT_B, (80, 30)),
|
||||
(
|
||||
"Problem NMS: duplikaty detekcji → ręcznie usuwaj post-hoc",
|
||||
16,
|
||||
"#EF9A9A",
|
||||
FONT_R,
|
||||
(80, 90),
|
||||
),
|
||||
(
|
||||
"DETR rozwiązanie: Hungarian matching (dopasowanie węgierskie)",
|
||||
17,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(80, 130),
|
||||
),
|
||||
("", 10, "white", FONT_R, (80, 155)),
|
||||
("Jak to działa podczas TRENINGU:", 17, "white", FONT_B, (80, 180)),
|
||||
(
|
||||
" 1. Sieć daje N=100 predykcji (queries)",
|
||||
15,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(100, 220),
|
||||
),
|
||||
(
|
||||
" 2. Na obrazie jest np. 5 obiektów (ground truth)",
|
||||
15,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(100, 255),
|
||||
),
|
||||
(
|
||||
" 3. Hungarian matching: optymalne dopasowanie 1:1",
|
||||
15,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(100, 290),
|
||||
),
|
||||
(
|
||||
" → query_1 ↔ gt_samochód (najlepsze dopasowanie)",
|
||||
14,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(120, 325),
|
||||
),
|
||||
(" → query_7 ↔ gt_pies", 14, "#A5D6A7", FONT_R, (120, 355)),
|
||||
(" → query_3 ↔ gt_rower", 14, "#A5D6A7", FONT_R, (120, 385)),
|
||||
(
|
||||
" → pozostałe 97 queries ↔ klasa 'brak obiektu'",
|
||||
14,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(120, 415),
|
||||
),
|
||||
(
|
||||
" 4. Każdy obiekt ma DOKŁADNIE 1 predykcję → BRAK duplikatów!",
|
||||
15,
|
||||
"#A5D6A7",
|
||||
FONT_R,
|
||||
(100, 455),
|
||||
),
|
||||
("", 10, "white", FONT_R, (100, 475)),
|
||||
(
|
||||
"Self-attention w encoderze: cechy obrazu 'rozmawiają' ze sobą",
|
||||
15,
|
||||
"#64B5F6",
|
||||
FONT_R,
|
||||
(80, 500),
|
||||
),
|
||||
(
|
||||
"Cross-attention w decoderze: queries 'pytają' cechy obrazu",
|
||||
15,
|
||||
"#CE93D8",
|
||||
FONT_R,
|
||||
(80, 535),
|
||||
),
|
||||
(
|
||||
"→ query 'rozumie' który fragment obrazu to 'jego' obiekt",
|
||||
15,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 570),
|
||||
),
|
||||
(
|
||||
"DETR = Detekcja Eliminująca Trikowe Redundancje (NMS, anchory)",
|
||||
16,
|
||||
"#FFE082",
|
||||
FONT_R,
|
||||
(80, 620),
|
||||
),
|
||||
(
|
||||
"Wada: wolniejszy trening (O(n²) attention) | Zaleta: prostszy pipeline!",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 660),
|
||||
),
|
||||
]
|
||||
slides.append(_text_slide(detr_details, duration=STEP_DUR + 1))
|
||||
|
||||
# Slide 3: Two-stage vs One-stage vs Transformer summary
|
||||
summary_lines = [
|
||||
(
|
||||
"Podsumowanie: Two-stage vs One-stage vs Transformer",
|
||||
22,
|
||||
"#FFE082",
|
||||
FONT_B,
|
||||
(80, 30),
|
||||
),
|
||||
("", 10, "white", FONT_R, (80, 55)),
|
||||
("TWO-STAGE (R-CNN family):", 18, "#EF9A9A", FONT_B, (80, 90)),
|
||||
(
|
||||
" (1) Generuj propozycje → (2) Klasyfikuj per region",
|
||||
15,
|
||||
"white",
|
||||
FONT_R,
|
||||
(100, 125),
|
||||
),
|
||||
(
|
||||
" + Wysoka precyzja | - Wolniejsze (2 przejścia)",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(100, 155),
|
||||
),
|
||||
(
|
||||
" R-CNN → Fast R-CNN → Faster R-CNN (0.2s)",
|
||||
15,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(100, 185),
|
||||
),
|
||||
("", 10, "white", FONT_R, (80, 210)),
|
||||
("ONE-STAGE (YOLO, SSD):", 18, "#A5D6A7", FONT_B, (80, 240)),
|
||||
(
|
||||
" Siatka → predykcja all-in-one (1 przejście)",
|
||||
15,
|
||||
"white",
|
||||
FONT_R,
|
||||
(100, 275),
|
||||
),
|
||||
(
|
||||
" + Bardzo szybkie (45-155 fps) | - Historycznie mniej precyzyjne",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(100, 305),
|
||||
),
|
||||
(
|
||||
" YOLOv8 (2023): anchor-free, dorównuje two-stage!",
|
||||
15,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(100, 335),
|
||||
),
|
||||
("", 10, "white", FONT_R, (80, 360)),
|
||||
("TRANSFORMER (DETR):", 18, "#CE93D8", FONT_B, (80, 390)),
|
||||
(
|
||||
" Object queries + self-attention (globalny kontekst)",
|
||||
15,
|
||||
"white",
|
||||
FONT_R,
|
||||
(100, 425),
|
||||
),
|
||||
(
|
||||
" + Brak NMS/anchorów | - Wolniejszy trening (O(n²))",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(100, 455),
|
||||
),
|
||||
(
|
||||
" Hungarian matching → 1:1 obiekt↔predykcja → brak duplikatów",
|
||||
15,
|
||||
"#B0BEC5",
|
||||
FONT_R,
|
||||
(100, 485),
|
||||
),
|
||||
("", 10, "white", FONT_R, (80, 510)),
|
||||
(
|
||||
"Trend: coraz prostsze pipeline, mniej ręcznych komponentów",
|
||||
17,
|
||||
"white",
|
||||
FONT_R,
|
||||
(80, 540),
|
||||
),
|
||||
(
|
||||
" R-CNN (SS+CNN+SVM+NMS) → YOLO "
|
||||
"(backbone+head+NMS) → DETR (backbone+transformer)",
|
||||
14,
|
||||
"#90CAF9",
|
||||
FONT_R,
|
||||
(80, 580),
|
||||
),
|
||||
(
|
||||
"Metryki: mAP@0.5 (standard), mAP@0.5:0.95 (surowsza), "
|
||||
"IoU do dopasowania",
|
||||
15,
|
||||
"#78909C",
|
||||
FONT_R,
|
||||
(80, 630),
|
||||
),
|
||||
]
|
||||
slides.append(_text_slide(summary_lines, duration=STEP_DUR + 1))
|
||||
|
||||
return slides
|
||||
@ -0,0 +1,235 @@
|
||||
"""Common drawing primitives and constants for Pub/Sub diagrams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
DPI = 300
|
||||
BG = "white"
|
||||
LN = "black"
|
||||
FS = 9
|
||||
FS_TITLE = 13
|
||||
FIG_W = 8.27 # A4 width in inches
|
||||
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
|
||||
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
GRAY1 = "#E8E8E8"
|
||||
GRAY2 = "#D0D0D0"
|
||||
GRAY3 = "#B8B8B8"
|
||||
GRAY4 = "#F5F5F5"
|
||||
GRAY5 = "#C0C0C0"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BoxStyle:
|
||||
"""Optional styling for boxes."""
|
||||
|
||||
fill: str = "white"
|
||||
lw: float = 1.2
|
||||
fontsize: float = FS
|
||||
fontweight: str = "normal"
|
||||
ha: str = "center"
|
||||
va: str = "center"
|
||||
rounded: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ArrowCfg:
|
||||
"""Config for arrows."""
|
||||
|
||||
lw: float = 1.2
|
||||
style: str = "->"
|
||||
color: str = LN
|
||||
label: str = ""
|
||||
label_offset: float = 0.15
|
||||
label_fs: float = 8
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DashedCfg:
|
||||
"""Config for dashed arrows."""
|
||||
|
||||
lw: float = 1.0
|
||||
color: str = LN
|
||||
label: str = ""
|
||||
label_offset: float = 0.15
|
||||
label_fs: float = 8
|
||||
|
||||
|
||||
def draw_box(
|
||||
ax: Axes,
|
||||
pos: tuple[float, float],
|
||||
size: tuple[float, float],
|
||||
text: str,
|
||||
style: BoxStyle | None = None,
|
||||
) -> None:
|
||||
"""Draw box."""
|
||||
s = style or BoxStyle()
|
||||
x, y = pos
|
||||
w, h = size
|
||||
if s.rounded:
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.05",
|
||||
lw=s.lw,
|
||||
edgecolor=LN,
|
||||
facecolor=s.fill,
|
||||
)
|
||||
else:
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
lw=s.lw,
|
||||
edgecolor=LN,
|
||||
facecolor=s.fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h / 2,
|
||||
text,
|
||||
ha=s.ha,
|
||||
va=s.va,
|
||||
fontsize=s.fontsize,
|
||||
fontweight=s.fontweight,
|
||||
wrap=True,
|
||||
)
|
||||
|
||||
|
||||
def draw_arrow(
|
||||
ax: Axes,
|
||||
start: tuple[float, float],
|
||||
end: tuple[float, float],
|
||||
cfg: ArrowCfg | None = None,
|
||||
) -> None:
|
||||
"""Draw arrow."""
|
||||
c = cfg or ArrowCfg()
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=end,
|
||||
xytext=start,
|
||||
arrowprops={
|
||||
"arrowstyle": c.style,
|
||||
"color": c.color,
|
||||
"lw": c.lw,
|
||||
},
|
||||
)
|
||||
if c.label:
|
||||
mx = (start[0] + end[0]) / 2
|
||||
my = (start[1] + end[1]) / 2 + c.label_offset
|
||||
ax.text(
|
||||
mx,
|
||||
my,
|
||||
c.label,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=c.label_fs,
|
||||
color=c.color,
|
||||
)
|
||||
|
||||
|
||||
def draw_dashed_arrow(
|
||||
ax: Axes,
|
||||
start: tuple[float, float],
|
||||
end: tuple[float, float],
|
||||
cfg: DashedCfg | None = None,
|
||||
) -> None:
|
||||
"""Draw dashed arrow."""
|
||||
c = cfg or DashedCfg()
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=end,
|
||||
xytext=start,
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"color": c.color,
|
||||
"lw": c.lw,
|
||||
"linestyle": "dashed",
|
||||
},
|
||||
)
|
||||
if c.label:
|
||||
mx = (start[0] + end[0]) / 2
|
||||
my = (start[1] + end[1]) / 2 + c.label_offset
|
||||
ax.text(
|
||||
mx,
|
||||
my,
|
||||
c.label,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=c.label_fs,
|
||||
color=c.color,
|
||||
)
|
||||
|
||||
|
||||
def draw_cross(
|
||||
ax: Axes,
|
||||
pos: tuple[float, float],
|
||||
size: float = 0.15,
|
||||
lw: float = 2.5,
|
||||
color: str = "black",
|
||||
) -> None:
|
||||
"""Draw cross."""
|
||||
x, y = pos
|
||||
ax.plot(
|
||||
[x - size, x + size],
|
||||
[y - size, y + size],
|
||||
color=color,
|
||||
lw=lw,
|
||||
)
|
||||
ax.plot(
|
||||
[x - size, x + size],
|
||||
[y + size, y - size],
|
||||
color=color,
|
||||
lw=lw,
|
||||
)
|
||||
|
||||
|
||||
def draw_check(
|
||||
ax: Axes,
|
||||
pos: tuple[float, float],
|
||||
size: float = 0.15,
|
||||
lw: float = 2.5,
|
||||
color: str = "black",
|
||||
) -> None:
|
||||
"""Draw check."""
|
||||
x, y = pos
|
||||
ax.plot(
|
||||
[x - size, x - size * 0.2],
|
||||
[y, y - size * 0.7],
|
||||
color=color,
|
||||
lw=lw,
|
||||
)
|
||||
ax.plot(
|
||||
[x - size * 0.2, x + size],
|
||||
[y - size * 0.7, y + size * 0.5],
|
||||
color=color,
|
||||
lw=lw,
|
||||
)
|
||||
|
||||
|
||||
def save(fig: plt.Figure, name: str) -> None:
|
||||
"""Save."""
|
||||
plt.tight_layout()
|
||||
fig.savefig(
|
||||
str(Path(OUTPUT_DIR) / name),
|
||||
dpi=DPI,
|
||||
bbox_inches="tight",
|
||||
facecolor=BG,
|
||||
)
|
||||
plt.close(fig)
|
||||
@ -0,0 +1,430 @@
|
||||
"""QoS delivery guarantee diagrams: at-most-once, at-least-once, exactly-once."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _pubsub_common import (
|
||||
FIG_W,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
LN,
|
||||
ArrowCfg,
|
||||
BoxStyle,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
draw_check,
|
||||
draw_cross,
|
||||
draw_dashed_arrow,
|
||||
save,
|
||||
)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 5. At-most-once (QoS 0)
|
||||
# ============================================================
|
||||
def draw_qos_at_most_once() -> None:
|
||||
"""Draw qos at most once."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"QoS: At-most-once"
|
||||
" \u2014 \u201ewy\u015blij i zapomnij\u201d"
|
||||
" (0 lub 1 dostarczenie)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
px, bx, sx = 1.0, 4.8, 8.5
|
||||
pw, bw, sw = 2.0, 2.2, 2.0
|
||||
bh = 0.8
|
||||
bold10_g1 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold")
|
||||
bold10_g2 = BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold")
|
||||
draw_box(ax, (px, 5.0), (pw, bh), "Publisher", bold10_g1)
|
||||
draw_box(ax, (bx, 5.0), (bw, bh), "Broker", bold10_g2)
|
||||
draw_box(
|
||||
ax,
|
||||
(sx, 5.0),
|
||||
(sw, bh),
|
||||
"Subscriber",
|
||||
bold10_g1,
|
||||
)
|
||||
|
||||
for xc in [px + pw / 2, bx + bw / 2, sx + sw / 2]:
|
||||
ax.plot(
|
||||
[xc, xc],
|
||||
[5.0, 1.2],
|
||||
color=GRAY3,
|
||||
lw=1,
|
||||
linestyle=":",
|
||||
)
|
||||
|
||||
# Scenario A: success
|
||||
y = 4.3
|
||||
ax.text(
|
||||
0.2,
|
||||
y + 0.15,
|
||||
"Scenariusz A:",
|
||||
fontsize=8.5,
|
||||
fontweight="bold",
|
||||
)
|
||||
msg9 = ArrowCfg(label="MSG", label_fs=9)
|
||||
draw_arrow(ax, (px + pw / 2, y), (bx + bw / 2, y), msg9)
|
||||
draw_arrow(
|
||||
ax,
|
||||
(bx + bw / 2, y - 0.6),
|
||||
(sx + sw / 2, y - 0.6),
|
||||
msg9,
|
||||
)
|
||||
draw_check(ax, (sx + sw / 2 + 0.4, y - 0.6), size=0.18)
|
||||
ax.text(
|
||||
sx + sw / 2 + 0.7,
|
||||
y - 0.6,
|
||||
"OK",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Scenario B: lost
|
||||
y = 2.6
|
||||
ax.text(
|
||||
0.2,
|
||||
y + 0.15,
|
||||
"Scenariusz B:",
|
||||
fontsize=8.5,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, (px + pw / 2, y), (bx + bw / 2, y), msg9)
|
||||
draw_dashed_arrow(ax, (bx + bw / 2, y - 0.6), (7.5, y - 0.6))
|
||||
draw_cross(ax, (7.8, y - 0.6), size=0.2)
|
||||
ax.text(
|
||||
8.2,
|
||||
y - 0.55,
|
||||
"UTRACONA",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
8.2,
|
||||
y - 1.0,
|
||||
"(brak retransmisji)",
|
||||
fontsize=8,
|
||||
style="italic",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
6.0,
|
||||
0.5,
|
||||
"Brak ACK, brak retransmisji."
|
||||
" Najszybszy. Use case:"
|
||||
" logi, metryki, telemetria.",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=9,
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.4",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
},
|
||||
)
|
||||
|
||||
save(fig, "pubsub_qos_at_most_once.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 6. At-least-once (QoS 1)
|
||||
# ============================================================
|
||||
def draw_qos_at_least_once() -> None:
|
||||
"""Draw qos at least once."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.0))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 6.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"QoS: At-least-once"
|
||||
" \u2014 \u201epowtarzaj a\u017c potwierdz\u0105\u201d"
|
||||
" (\u22651 dostarczenie)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
bx, bw = 3.5, 2.2
|
||||
sx, sw = 8.0, 2.2
|
||||
bh = 0.8
|
||||
draw_box(
|
||||
ax,
|
||||
(bx, 5.5),
|
||||
(bw, bh),
|
||||
"Broker",
|
||||
BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"),
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(sx, 5.5),
|
||||
(sw, bh),
|
||||
"Subscriber",
|
||||
BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold"),
|
||||
)
|
||||
|
||||
for xc in [bx + bw / 2, sx + sw / 2]:
|
||||
ax.plot(
|
||||
[xc, xc],
|
||||
[5.5, 0.8],
|
||||
color=GRAY3,
|
||||
lw=1,
|
||||
linestyle=":",
|
||||
)
|
||||
|
||||
# Step 1: send MSG
|
||||
y1 = 4.8
|
||||
draw_arrow(
|
||||
ax,
|
||||
(bx + bw / 2, y1),
|
||||
(sx + sw / 2, y1),
|
||||
ArrowCfg(label="MSG #1", label_fs=9),
|
||||
)
|
||||
draw_check(ax, (sx + sw + 0.2, y1), size=0.15)
|
||||
ax.text(sx + sw + 0.5, y1, "odebrano", fontsize=8)
|
||||
|
||||
# Step 2: ACK lost
|
||||
y2 = 3.9
|
||||
draw_dashed_arrow(
|
||||
ax,
|
||||
(sx + sw / 2, y2),
|
||||
(bx + bw + 1.2, y2),
|
||||
)
|
||||
ax.text(
|
||||
(bx + bw / 2 + sx + sw / 2) / 2,
|
||||
y2 + 0.18,
|
||||
"ACK",
|
||||
fontsize=9,
|
||||
)
|
||||
draw_cross(ax, (bx + bw + 0.8, y2), size=0.18)
|
||||
ax.text(
|
||||
bx + 0.3,
|
||||
y2 - 0.35,
|
||||
"ACK utracony!",
|
||||
fontsize=8.5,
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Step 3: timeout -> retry
|
||||
y3 = 2.9
|
||||
ax.text(
|
||||
bx + bw / 2,
|
||||
y3 + 0.45,
|
||||
"timeout...",
|
||||
fontsize=8.5,
|
||||
style="italic",
|
||||
ha="center",
|
||||
)
|
||||
draw_arrow(
|
||||
ax,
|
||||
(bx + bw / 2, y3),
|
||||
(sx + sw / 2, y3),
|
||||
ArrowCfg(label="MSG #1 (retry)", label_fs=9),
|
||||
)
|
||||
draw_check(ax, (sx + sw + 0.2, y3), size=0.15)
|
||||
ax.text(
|
||||
sx + sw + 0.5,
|
||||
y3,
|
||||
"odebrano\n(ponownie!)",
|
||||
fontsize=8,
|
||||
)
|
||||
|
||||
# Step 4: ACK ok
|
||||
y4 = 2.0
|
||||
draw_arrow(
|
||||
ax,
|
||||
(sx + sw / 2, y4),
|
||||
(bx + bw / 2, y4),
|
||||
ArrowCfg(label="ACK", label_fs=9),
|
||||
)
|
||||
draw_check(ax, (bx + bw / 2 - 0.5, y4), size=0.18)
|
||||
|
||||
# Duplicate bracket
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(sx + sw + 1.3, y1),
|
||||
xytext=(sx + sw + 1.3, y3),
|
||||
arrowprops={
|
||||
"arrowstyle": "<->",
|
||||
"color": "black",
|
||||
"lw": 1.2,
|
||||
},
|
||||
)
|
||||
ax.text(
|
||||
sx + sw + 1.6,
|
||||
(y1 + y3) / 2,
|
||||
"DUPLIKAT!\nSubscriber\notrzyma\u0142 2x",
|
||||
fontsize=9,
|
||||
ha="left",
|
||||
va="center",
|
||||
fontweight="bold",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.25",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
6.0,
|
||||
0.5,
|
||||
"Broker czeka na ACK, retransmituje"
|
||||
" po timeout. Mog\u0105 by\u0107 duplikaty!\n"
|
||||
"Use case: zam\u00f3wienia, p\u0142atno\u015bci"
|
||||
" (subscriber musi by\u0107 idempotentny).",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=9,
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.4",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
},
|
||||
)
|
||||
|
||||
save(fig, "pubsub_qos_at_least_once.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 7. Exactly-once (QoS 2)
|
||||
# ============================================================
|
||||
def draw_qos_exactly_once() -> None:
|
||||
"""Draw qos exactly once."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"QoS: Exactly-once \u2014 4-krokowy"
|
||||
" handshake (dok\u0142adnie 1 dostarczenie)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
bx, bw = 2.5, 2.2
|
||||
sx, sw = 7.5, 2.2
|
||||
bh = 0.8
|
||||
draw_box(
|
||||
ax,
|
||||
(bx, 6.0),
|
||||
(bw, bh),
|
||||
"Broker",
|
||||
BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"),
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(sx, 6.0),
|
||||
(sw, bh),
|
||||
"Subscriber",
|
||||
BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold"),
|
||||
)
|
||||
|
||||
for xc in [bx + bw / 2, sx + sw / 2]:
|
||||
ax.plot(
|
||||
[xc, xc],
|
||||
[6.0, 1.0],
|
||||
color=GRAY3,
|
||||
lw=1,
|
||||
linestyle=":",
|
||||
)
|
||||
|
||||
steps = [
|
||||
(
|
||||
5.2,
|
||||
"right",
|
||||
"PUBLISH (msg_id=42)",
|
||||
"Broker wysy\u0142a wiadomo\u015b\u0107",
|
||||
),
|
||||
(
|
||||
4.2,
|
||||
"left",
|
||||
"PUBREC (otrzyma\u0142em id=42)",
|
||||
"Sub potwierdza odbi\u00f3r," " zapisuje id",
|
||||
),
|
||||
(
|
||||
3.2,
|
||||
"right",
|
||||
"PUBREL (mo\u017cesz przetworzy\u0107)",
|
||||
"Broker zwalnia wiadomo\u015b\u0107",
|
||||
),
|
||||
(
|
||||
2.2,
|
||||
"left",
|
||||
"PUBCOMP (zako\u0144czone)",
|
||||
"Sub potwierdza przetworzenie",
|
||||
),
|
||||
]
|
||||
|
||||
for i, (y, direction, label, desc) in enumerate(steps):
|
||||
ax.text(
|
||||
bx + bw / 2 - 0.7,
|
||||
y,
|
||||
f"{i + 1}",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="center",
|
||||
bbox={
|
||||
"boxstyle": "circle,pad=0.18",
|
||||
"facecolor": GRAY3,
|
||||
"edgecolor": LN,
|
||||
},
|
||||
)
|
||||
|
||||
if direction == "right":
|
||||
draw_arrow(
|
||||
ax,
|
||||
(bx + bw / 2, y),
|
||||
(sx + sw / 2, y),
|
||||
ArrowCfg(label=label, label_fs=9),
|
||||
)
|
||||
else:
|
||||
draw_arrow(
|
||||
ax,
|
||||
(sx + sw / 2, y),
|
||||
(bx + bw / 2, y),
|
||||
ArrowCfg(label=label, label_fs=9),
|
||||
)
|
||||
|
||||
ax.text(
|
||||
sx + sw + 0.3,
|
||||
y,
|
||||
desc,
|
||||
fontsize=8,
|
||||
ha="left",
|
||||
va="center",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
6.0,
|
||||
0.6,
|
||||
"Deduplikacja po msg_id."
|
||||
" Sub nie przetwarza przed PUBREL.\n"
|
||||
"Najkosztowniejszy (4 pakiety)."
|
||||
" Use case: transakcje finansowe,"
|
||||
" krytyczne zdarzenia.",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=9,
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.4",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
},
|
||||
)
|
||||
|
||||
save(fig, "pubsub_qos_exactly_once.png")
|
||||
@ -0,0 +1,239 @@
|
||||
"""Subscription-type diagrams: topic-based and content-based."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _pubsub_common import (
|
||||
FIG_W,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
ArrowCfg,
|
||||
BoxStyle,
|
||||
DashedCfg,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
draw_dashed_arrow,
|
||||
save,
|
||||
)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 1. Topic-based subscription
|
||||
# ============================================================
|
||||
def draw_sub_topic() -> None:
|
||||
"""Draw sub topic."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.0))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 5.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Subskrypcja topic-based" " \u2014 routing po nazwie tematu",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
bold10 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold")
|
||||
fs85 = BoxStyle(fill=GRAY1, fontsize=8.5)
|
||||
|
||||
draw_box(ax, (0.2, 3.2), (2.4, 1.1), "Publisher", bold10)
|
||||
draw_box(
|
||||
ax,
|
||||
(0.3, 1.8),
|
||||
(2.2, 0.8),
|
||||
'topic: "orders"',
|
||||
BoxStyle(fill=GRAY4, fontsize=8),
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(0.3, 0.7),
|
||||
(2.2, 0.8),
|
||||
'topic: "payments"',
|
||||
BoxStyle(fill=GRAY4, fontsize=8),
|
||||
)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
(4.2, 1.5),
|
||||
(2.8, 2.2),
|
||||
"BROKER\n\ntopic routing",
|
||||
BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"),
|
||||
)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
(8.5, 3.8),
|
||||
(3.0, 1.0),
|
||||
'Subscriber A\nsubskrybuje: "orders"',
|
||||
fs85,
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(8.5, 2.2),
|
||||
(3.0, 1.0),
|
||||
'Subscriber B\nsubskrybuje: "payments"',
|
||||
fs85,
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(8.5, 0.6),
|
||||
(3.0, 1.0),
|
||||
'Subscriber C\nsubskrybuje: "orders"',
|
||||
fs85,
|
||||
)
|
||||
|
||||
fs8 = ArrowCfg(label_fs=8)
|
||||
draw_arrow(ax, (2.6, 2.2), (4.2, 2.8), fs8)
|
||||
draw_arrow(ax, (2.6, 1.1), (4.2, 2.2), fs8)
|
||||
|
||||
draw_arrow(
|
||||
ax,
|
||||
(7.0, 3.4),
|
||||
(8.5, 4.2),
|
||||
ArrowCfg(label='"orders"', label_fs=8),
|
||||
)
|
||||
draw_arrow(
|
||||
ax,
|
||||
(7.0, 2.6),
|
||||
(8.5, 2.7),
|
||||
ArrowCfg(label='"payments"', label_fs=8),
|
||||
)
|
||||
draw_arrow(
|
||||
ax,
|
||||
(7.0, 2.2),
|
||||
(8.5, 1.2),
|
||||
ArrowCfg(label='"orders"', label_fs=8),
|
||||
)
|
||||
|
||||
ax.text(
|
||||
6.0,
|
||||
0.1,
|
||||
"Subscriber deklaruje nazw\u0119 tematu."
|
||||
" Broker kieruje wiadomo\u015bci\n"
|
||||
"do WSZYSTKICH subscriber\u00f3w"
|
||||
" danego tematu. Najprostszy model.",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8.5,
|
||||
style="italic",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
},
|
||||
)
|
||||
|
||||
save(fig, "pubsub_sub_topic.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 2. Content-based subscription
|
||||
# ============================================================
|
||||
def draw_sub_content() -> None:
|
||||
"""Draw sub content."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 4.5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Subskrypcja content-based"
|
||||
" \u2014 filtrowanie po tre\u015bci wiadomo\u015bci",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
bold10 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold")
|
||||
draw_box(ax, (0.2, 3.5), (2.4, 1.1), "Publisher", bold10)
|
||||
draw_box(
|
||||
ax,
|
||||
(0.2, 1.8),
|
||||
(2.4, 1.2),
|
||||
'price: 150\ntype: "book"\ncategory: "IT"',
|
||||
BoxStyle(fill=GRAY4, fontsize=8.5),
|
||||
)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
(4.0, 2.0),
|
||||
(3.0, 2.5),
|
||||
"BROKER\n\newaluuje filtry\n" "ka\u017cdego subscribera",
|
||||
BoxStyle(fill=GRAY2, fontsize=9, fontweight="bold"),
|
||||
)
|
||||
|
||||
fs9 = BoxStyle(fill=GRAY1, fontsize=9)
|
||||
draw_box(
|
||||
ax,
|
||||
(8.5, 4.2),
|
||||
(3.2, 1.0),
|
||||
"Sub A\nfiltr: price > 100",
|
||||
fs9,
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(8.5, 2.6),
|
||||
(3.2, 1.0),
|
||||
'Sub B\nfiltr: type = "food"',
|
||||
fs9,
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(8.5, 1.0),
|
||||
(3.2, 1.0),
|
||||
"Sub C\nfiltr: price < 50",
|
||||
fs9,
|
||||
)
|
||||
|
||||
draw_arrow(ax, (2.6, 2.4), (4.0, 3.0))
|
||||
draw_arrow(
|
||||
ax,
|
||||
(7.0, 4.0),
|
||||
(8.5, 4.6),
|
||||
ArrowCfg(
|
||||
label="150 > 100 \u2713 dostarczono",
|
||||
label_fs=8,
|
||||
),
|
||||
)
|
||||
draw_dashed_arrow(
|
||||
ax,
|
||||
(7.0, 3.2),
|
||||
(8.5, 3.1),
|
||||
DashedCfg(
|
||||
label='"book" \u2260 "food"' " \u2717 odrzucono",
|
||||
label_fs=8,
|
||||
),
|
||||
)
|
||||
draw_dashed_arrow(
|
||||
ax,
|
||||
(7.0, 2.5),
|
||||
(8.5, 1.6),
|
||||
DashedCfg(
|
||||
label="150 < 50 \u2717 odrzucono",
|
||||
label_fs=8,
|
||||
),
|
||||
)
|
||||
|
||||
ax.text(
|
||||
6.0,
|
||||
0.2,
|
||||
"Broker analizuje TRE\u015a\u0106 wiadomo\u015bci"
|
||||
" i ewaluuje predykaty.\n"
|
||||
"Bardziej elastyczny ni\u017c topic-based,"
|
||||
" ale wolniejszy (koszt ewaluacji).",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8.5,
|
||||
style="italic",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
},
|
||||
)
|
||||
|
||||
save(fig, "pubsub_sub_content.png")
|
||||
@ -0,0 +1,279 @@
|
||||
"""Subscription-type diagrams: type-based and hierarchical."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _pubsub_common import (
|
||||
FIG_W,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
ArrowCfg,
|
||||
BoxStyle,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
save,
|
||||
)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 3. Type-based subscription
|
||||
# ============================================================
|
||||
def draw_sub_type() -> None:
|
||||
"""Draw sub type."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.0))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 6.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Subskrypcja type-based" " \u2014 routing po typie (klasie) obiektu",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
bold10 = BoxStyle(fill=GRAY1, fontsize=10, fontweight="bold")
|
||||
draw_box(ax, (0.2, 4.2), (2.4, 1.1), "Publisher", bold10)
|
||||
|
||||
fs9_g4 = BoxStyle(fill=GRAY4, fontsize=9)
|
||||
draw_box(
|
||||
ax,
|
||||
(0.1, 2.8),
|
||||
(2.6, 0.9),
|
||||
"new OrderEvent()",
|
||||
fs9_g4,
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(0.1, 1.5),
|
||||
(2.6, 0.9),
|
||||
"new PaymentEvent()",
|
||||
fs9_g4,
|
||||
)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
(4.0, 2.3),
|
||||
(3.0, 2.4),
|
||||
"BROKER\n\nrouting po\ntypie klasy",
|
||||
BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold"),
|
||||
)
|
||||
|
||||
fs9 = BoxStyle(fill=GRAY1, fontsize=9)
|
||||
draw_box(
|
||||
ax,
|
||||
(8.5, 4.8),
|
||||
(3.2, 1.0),
|
||||
"Sub A\n\u2192 OrderEvent",
|
||||
fs9,
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(8.5, 3.2),
|
||||
(3.2, 1.0),
|
||||
"Sub B\n\u2192 PaymentEvent",
|
||||
fs9,
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(8.5, 1.6),
|
||||
(3.2, 1.0),
|
||||
"Sub C\n\u2192 Event (base)",
|
||||
fs9,
|
||||
)
|
||||
|
||||
draw_arrow(ax, (2.7, 3.2), (4.0, 3.8))
|
||||
draw_arrow(ax, (2.7, 2.0), (4.0, 3.0))
|
||||
draw_arrow(
|
||||
ax,
|
||||
(7.0, 4.3),
|
||||
(8.5, 5.2),
|
||||
ArrowCfg(label="OrderEvent", label_fs=8),
|
||||
)
|
||||
draw_arrow(
|
||||
ax,
|
||||
(7.0, 3.5),
|
||||
(8.5, 3.7),
|
||||
ArrowCfg(label="PaymentEvent", label_fs=8),
|
||||
)
|
||||
draw_arrow(
|
||||
ax,
|
||||
(7.0, 3.0),
|
||||
(8.5, 2.2),
|
||||
ArrowCfg(label="oba (dziedziczenie!)", label_fs=8),
|
||||
)
|
||||
|
||||
hx, hy = 0.5, 0.0
|
||||
draw_box(
|
||||
ax,
|
||||
(hx + 2.0, hy + 0.2),
|
||||
(1.8, 0.6),
|
||||
"Event",
|
||||
BoxStyle(fill=GRAY3, fontsize=8, fontweight="bold"),
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(hx, hy + 0.2),
|
||||
(1.8, 0.6),
|
||||
"OrderEvent",
|
||||
BoxStyle(fill=GRAY4, fontsize=7.5),
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(hx + 4.0, hy + 0.2),
|
||||
(2.0, 0.6),
|
||||
"PaymentEvent",
|
||||
BoxStyle(fill=GRAY4, fontsize=7.5),
|
||||
)
|
||||
draw_arrow(
|
||||
ax,
|
||||
(hx + 2.9, hy + 0.2),
|
||||
(hx + 0.9, hy + 0.2),
|
||||
ArrowCfg(
|
||||
lw=1.0,
|
||||
label="extends",
|
||||
label_offset=-0.3,
|
||||
label_fs=7,
|
||||
),
|
||||
)
|
||||
draw_arrow(
|
||||
ax,
|
||||
(hx + 2.9, hy + 0.2),
|
||||
(hx + 5.0, hy + 0.2),
|
||||
ArrowCfg(
|
||||
lw=1.0,
|
||||
label="extends",
|
||||
label_offset=-0.3,
|
||||
label_fs=7,
|
||||
),
|
||||
)
|
||||
|
||||
ax.text(
|
||||
9.5,
|
||||
0.5,
|
||||
"Sub C subskrybuje bazowy Event\n" "\u2192 otrzymuje WSZYSTKIE podtypy",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=8.5,
|
||||
style="italic",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
},
|
||||
)
|
||||
|
||||
save(fig, "pubsub_sub_type.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 4. Hierarchical / Wildcards subscription
|
||||
# ============================================================
|
||||
def draw_sub_hierarchical() -> None:
|
||||
"""Draw sub hierarchical."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(FIG_W, 5.5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Subskrypcja hierarchiczna (wildcards)" " \u2014 wzorce temat\u00f3w",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
bold10 = BoxStyle(fill=GRAY2, fontsize=10, fontweight="bold")
|
||||
draw_box(ax, (4.5, 5.8), (2.4, 0.8), "sensors/", bold10)
|
||||
|
||||
fs9_g3 = BoxStyle(fill=GRAY3, fontsize=9)
|
||||
draw_box(
|
||||
ax,
|
||||
(1.5, 4.2),
|
||||
(2.4, 0.8),
|
||||
"temperature/",
|
||||
fs9_g3,
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
(7.5, 4.2),
|
||||
(2.4, 0.8),
|
||||
"humidity/",
|
||||
fs9_g3,
|
||||
)
|
||||
|
||||
fs85_g4 = BoxStyle(fill=GRAY4, fontsize=8.5)
|
||||
draw_box(ax, (0.2, 2.8), (1.8, 0.7), "room1", fs85_g4)
|
||||
draw_box(ax, (2.4, 2.8), (1.8, 0.7), "room2", fs85_g4)
|
||||
draw_box(ax, (6.8, 2.8), (1.8, 0.7), "room1", fs85_g4)
|
||||
draw_box(ax, (9.0, 2.8), (1.8, 0.7), "room2", fs85_g4)
|
||||
|
||||
thin = ArrowCfg(lw=1.0)
|
||||
draw_arrow(ax, (5.7, 5.8), (2.7, 5.0), thin)
|
||||
draw_arrow(ax, (5.7, 5.8), (8.7, 5.0), thin)
|
||||
draw_arrow(ax, (2.2, 4.2), (1.1, 3.5), thin)
|
||||
draw_arrow(ax, (3.2, 4.2), (3.3, 3.5), thin)
|
||||
draw_arrow(ax, (8.2, 4.2), (7.7, 3.5), thin)
|
||||
draw_arrow(ax, (9.2, 4.2), (9.9, 3.5), thin)
|
||||
|
||||
ax.text(
|
||||
1.1,
|
||||
2.4,
|
||||
"sensors/temperature/room1",
|
||||
fontsize=7,
|
||||
ha="center",
|
||||
fontfamily="monospace",
|
||||
style="italic",
|
||||
)
|
||||
ax.text(
|
||||
3.3,
|
||||
2.4,
|
||||
"sensors/temperature/room2",
|
||||
fontsize=7,
|
||||
ha="center",
|
||||
fontfamily="monospace",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
0.3,
|
||||
1.5,
|
||||
"Wzorce subskrypcji (MQTT-style):",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
patterns = [
|
||||
(
|
||||
'"sensors/temperature/room1"',
|
||||
"\u2192 TYLKO room1",
|
||||
"(dok\u0142adne dopasowanie)",
|
||||
),
|
||||
(
|
||||
'"sensors/temperature/*"',
|
||||
"\u2192 room1, room2",
|
||||
"( * = jeden poziom)",
|
||||
),
|
||||
(
|
||||
'"sensors/#"',
|
||||
"\u2192 WSZYSTKO",
|
||||
"( # = dowolna g\u0142\u0119boko\u015b\u0107)",
|
||||
),
|
||||
]
|
||||
for i, (pat, result, note) in enumerate(patterns):
|
||||
yy = 0.9 - i * 0.55
|
||||
ax.text(
|
||||
0.5,
|
||||
yy,
|
||||
pat,
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
fontfamily="monospace",
|
||||
)
|
||||
ax.text(7.0, yy, result, fontsize=9, fontweight="bold")
|
||||
ax.text(9.5, yy, note, fontsize=8, style="italic")
|
||||
|
||||
save(fig, "pubsub_sub_hierarchical.png")
|
||||
@ -0,0 +1,421 @@
|
||||
"""Spark Streaming, Lambda/Kappa architecture, and exactly-once diagrams for Q20."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q20_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
draw_table,
|
||||
plt,
|
||||
save_fig,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 12. Spark Streaming architecture
|
||||
# ============================================================
|
||||
def gen_spark_streaming_arch() -> None:
|
||||
"""Gen spark streaming arch."""
|
||||
fig, ax = plt.subplots(figsize=(9, 5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Spark Streaming — architektura (micro-batch)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Cluster border
|
||||
draw_box(ax, 0.3, 0.5, 11.4, 5.8, "", fill=GRAY4, rounded=True, lw=2.5)
|
||||
ax.text(
|
||||
6.0, 6.0, "SPARK CLUSTER", fontsize=FS_LABEL, ha="center", fontweight="bold"
|
||||
)
|
||||
|
||||
# Driver
|
||||
draw_box(
|
||||
ax,
|
||||
1.0,
|
||||
4.5,
|
||||
3.0,
|
||||
1.2,
|
||||
"Driver\n(planuje mini-batche)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
draw_arrow(ax, 2.5, 4.5, 6.0, 4.0, lw=1.5)
|
||||
|
||||
# Batches
|
||||
batches = ["batch 1\n(e1,e2,e3)", "batch 2\n(e4,e5,e6)", "batch 3\n(e7,e8,e9)"]
|
||||
for i, b in enumerate(batches):
|
||||
y = 2.8 - i * 1.0
|
||||
draw_box(
|
||||
ax, 4.5, y, 2.5, 0.8, b, fill=GRAY1, fontsize=FS_SMALL, fontweight="bold"
|
||||
)
|
||||
# map → reduce
|
||||
draw_arrow(ax, 7.0, y + 0.4, 7.5, y + 0.4, lw=1)
|
||||
draw_box(ax, 7.5, y, 1.3, 0.8, "map→\nreduce", fill=GRAY3, fontsize=5.5)
|
||||
draw_arrow(ax, 8.8, y + 0.4, 9.3, y + 0.4, lw=1)
|
||||
draw_box(
|
||||
ax, 9.3, y, 1.5, 0.8, f"result {i + 1}", fill="white", fontsize=FS_SMALL
|
||||
)
|
||||
|
||||
# Spark ecosystem
|
||||
draw_box(
|
||||
ax,
|
||||
1.0,
|
||||
1.0,
|
||||
3.0,
|
||||
1.0,
|
||||
"Spark SQL / MLlib\n(ten sam ekosystem!)",
|
||||
fill=GRAY5,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
6.0,
|
||||
0.3,
|
||||
"ZALETA: batch API | WADA: latencja ≥ batch interval (~100ms)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "edgecolor": LN},
|
||||
)
|
||||
|
||||
save_fig(fig, "q20_spark_streaming_arch.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 13. Lambda vs Kappa architecture
|
||||
# ============================================================
|
||||
def gen_lambda_vs_kappa() -> None:
|
||||
"""Gen lambda vs kappa."""
|
||||
fig, axes = plt.subplots(2, 1, figsize=(10, 7))
|
||||
fig.suptitle("Architektura Lambda vs Kappa", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
# --- Lambda ---
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"LAMBDA — 2 ścieżki (batch + speed)", fontsize=FS_LABEL, fontweight="bold"
|
||||
)
|
||||
|
||||
# Source
|
||||
draw_box(
|
||||
ax,
|
||||
0.3,
|
||||
1.8,
|
||||
2.0,
|
||||
1.5,
|
||||
"Źródło\ndanych",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Batch layer (top)
|
||||
draw_box(
|
||||
ax,
|
||||
3.5,
|
||||
3.3,
|
||||
3.0,
|
||||
1.2,
|
||||
"Batch Layer\n(Spark)\nprzelicza co godzinę",
|
||||
fill=GRAY1,
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 2.3, 3.0, 3.5, 3.9, lw=1.5)
|
||||
|
||||
# Speed layer (bottom)
|
||||
draw_box(
|
||||
ax,
|
||||
3.5,
|
||||
0.8,
|
||||
3.0,
|
||||
1.2,
|
||||
"Speed Layer\n(Flink)\nreal-time",
|
||||
fill=GRAY3,
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 2.3, 2.2, 3.5, 1.4, lw=1.5)
|
||||
|
||||
# Results
|
||||
draw_box(
|
||||
ax,
|
||||
7.5,
|
||||
3.3,
|
||||
2.0,
|
||||
1.2,
|
||||
"Dokładne\nwyniki\n(wolne)",
|
||||
fill=GRAY4,
|
||||
fontsize=FS_SMALL,
|
||||
)
|
||||
draw_arrow(ax, 6.5, 3.9, 7.5, 3.9, lw=1.5)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
7.5,
|
||||
0.8,
|
||||
2.0,
|
||||
1.2,
|
||||
"Przybliżone\nwyniki\n(szybkie)",
|
||||
fill=GRAY4,
|
||||
fontsize=FS_SMALL,
|
||||
)
|
||||
draw_arrow(ax, 6.5, 1.4, 7.5, 1.4, lw=1.5)
|
||||
|
||||
# Merge
|
||||
draw_box(
|
||||
ax,
|
||||
10.0,
|
||||
2.0,
|
||||
1.5,
|
||||
1.5,
|
||||
"MERGE\n→ UI",
|
||||
fill=GRAY5,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 9.5, 3.5, 10.0, 3.0, lw=1.5)
|
||||
draw_arrow(ax, 9.5, 1.8, 10.0, 2.5, lw=1.5)
|
||||
|
||||
ax.text(
|
||||
6.0,
|
||||
0.1,
|
||||
"2 systemy, 2 kody — złożone ale pewne",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
|
||||
)
|
||||
|
||||
# --- Kappa ---
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 4)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"KAPPA — 1 ścieżka (streaming only)", fontsize=FS_LABEL, fontweight="bold"
|
||||
)
|
||||
|
||||
# Source
|
||||
draw_box(
|
||||
ax,
|
||||
0.3,
|
||||
1.3,
|
||||
2.0,
|
||||
1.5,
|
||||
"Źródło\ndanych",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Single streaming layer
|
||||
draw_box(
|
||||
ax,
|
||||
3.5,
|
||||
1.3,
|
||||
3.5,
|
||||
1.5,
|
||||
"Streaming Layer\n(Flink)\n+ replay z Kafka log",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 2.3, 2.05, 3.5, 2.05, lw=2)
|
||||
|
||||
# Output
|
||||
draw_box(
|
||||
ax,
|
||||
8.0,
|
||||
1.3,
|
||||
2.5,
|
||||
1.5,
|
||||
"Wyniki\n→ UI",
|
||||
fill=GRAY4,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 7.0, 2.05, 8.0, 2.05, lw=2)
|
||||
|
||||
# Replay arrow
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(3.5, 1.0),
|
||||
xytext=(7.0, 1.0),
|
||||
arrowprops={
|
||||
"arrowstyle": "<-",
|
||||
"lw": 1.5,
|
||||
"color": LN,
|
||||
"connectionstyle": "arc3,rad=0.3",
|
||||
"linestyle": "--",
|
||||
},
|
||||
)
|
||||
ax.text(
|
||||
5.25,
|
||||
0.3,
|
||||
"Replay z Kafka\n(przetwórz historię od nowa)",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
6.0,
|
||||
3.3,
|
||||
"1 system, 1 kod — prostsze, ale replay = dużo I/O",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
|
||||
)
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.92])
|
||||
save_fig(fig, "q20_lambda_vs_kappa.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 14. Lambda vs Kappa comparison table
|
||||
# ============================================================
|
||||
def gen_lambda_kappa_table() -> None:
|
||||
"""Gen lambda kappa table."""
|
||||
fig, ax = plt.subplots(figsize=(8, 3.5))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(-4.5, 1)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Lambda vs Kappa — porównanie", fontsize=FS_TITLE, fontweight="bold", pad=10
|
||||
)
|
||||
|
||||
headers = ["Cecha", "Lambda", "Kappa"]
|
||||
col_w = [2.5, 3.5, 3.5]
|
||||
rows = [
|
||||
["Ścieżki", "2 (batch + speed)", "1 (streaming)"],
|
||||
["Kod", "2 implementacje", "1 implementacja"],
|
||||
["Złożoność", "wysoka", "niska"],
|
||||
["Replay", "batch przelicza", "Kafka replay"],
|
||||
["Spójność", "merge wymagany", "natywna"],
|
||||
["Przykład", "Netflix, LinkedIn", "Uber, Confluent"],
|
||||
]
|
||||
draw_table(
|
||||
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.55, fontsize=7.5
|
||||
)
|
||||
|
||||
save_fig(fig, "q20_lambda_kappa_table.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 15. Exactly-once comparison
|
||||
# ============================================================
|
||||
def gen_exactly_once() -> None:
|
||||
"""Gen exactly once."""
|
||||
fig, ax = plt.subplots(figsize=(9, 5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Exactly-Once — mechanizmy na 3 platformach",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Flink
|
||||
draw_box(ax, 0.3, 4.3, 11.0, 2.0, "", fill=GRAY4, rounded=True, lw=1.5)
|
||||
ax.text(
|
||||
1.0,
|
||||
5.9,
|
||||
"Flink — Distributed Snapshots (Chandy-Lamport)",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
flink_steps = ["source", "|B|", "map()", "|B|", "sink()"]
|
||||
bx = 1.0
|
||||
for s in flink_steps:
|
||||
if s == "|B|":
|
||||
ax.text(
|
||||
bx + 0.25,
|
||||
4.85,
|
||||
s,
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY5, "edgecolor": LN},
|
||||
)
|
||||
draw_arrow(ax, bx - 0.1, 4.85, bx + 0.05, 4.85, lw=1)
|
||||
bx += 0.7
|
||||
else:
|
||||
draw_box(
|
||||
ax,
|
||||
bx,
|
||||
4.6,
|
||||
1.5,
|
||||
0.55,
|
||||
s,
|
||||
fill=GRAY1,
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
bx += 1.8
|
||||
ax.text(
|
||||
8.5,
|
||||
5.0,
|
||||
"barrier → save state\n→ checkpoint (HDFS/S3)",
|
||||
fontsize=FS_SMALL,
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Kafka Streams
|
||||
draw_box(ax, 0.3, 2.3, 11.0, 1.5, "", fill=GRAY1, rounded=True, lw=1.5)
|
||||
ax.text(
|
||||
1.0, 3.5, "Kafka Streams — Transakcje Kafka", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
ax.text(
|
||||
1.5,
|
||||
2.85,
|
||||
"idempotent producer + begin TX → produce → commit TX → consumer offsets w TX",
|
||||
fontsize=FS_SMALL,
|
||||
)
|
||||
|
||||
# Spark
|
||||
draw_box(ax, 0.3, 0.5, 11.0, 1.5, "", fill=GRAY3, rounded=True, lw=1.5)
|
||||
ax.text(
|
||||
1.0,
|
||||
1.7,
|
||||
"Spark Streaming — Write-Ahead Log (WAL)",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
1.5,
|
||||
1.05,
|
||||
"WAL + checkpointing micro-batchów + idempotent sinks (np. upsert do DB)",
|
||||
fontsize=FS_SMALL,
|
||||
)
|
||||
|
||||
save_fig(fig, "q20_exactly_once.png")
|
||||
@ -0,0 +1,449 @@
|
||||
"""Batch vs streaming concept and window type diagrams for Q20."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _q20_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
plt,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 1. Batch vs Streaming concept
|
||||
# ============================================================
|
||||
def gen_batch_vs_streaming() -> None:
|
||||
"""Gen batch vs streaming."""
|
||||
fig, axes = plt.subplots(2, 1, figsize=(9, 5))
|
||||
fig.suptitle(
|
||||
"Batch vs Streaming — dwa modele przetwarzania",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Batch
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 3)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("BATCH (wsadowe)", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Data collected
|
||||
draw_box(
|
||||
ax,
|
||||
0.5,
|
||||
0.8,
|
||||
3.0,
|
||||
1.4,
|
||||
"Zbierz WSZYSTKIE\ndane\n(godziny / dni)",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 3.5, 1.5, 4.5, 1.5, lw=2)
|
||||
draw_box(
|
||||
ax,
|
||||
4.5,
|
||||
0.8,
|
||||
2.5,
|
||||
1.4,
|
||||
"Analiza\n(batch job)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 7.0, 1.5, 8.0, 1.5, lw=2)
|
||||
draw_box(
|
||||
ax,
|
||||
8.0,
|
||||
0.8,
|
||||
2.5,
|
||||
1.4,
|
||||
"Wynik\n(jednorazowy)",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(11.0, 1.5, "min-h", fontsize=FS, va="center", fontweight="bold")
|
||||
|
||||
# Streaming
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 3)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("STREAMING (strumieniowe)", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Events flowing
|
||||
events_x = [0.5, 1.5, 2.5, 3.5]
|
||||
for i, ex in enumerate(events_x):
|
||||
draw_box(
|
||||
ax,
|
||||
ex,
|
||||
1.0,
|
||||
0.8,
|
||||
0.8,
|
||||
f"e{i + 1}",
|
||||
fill=GRAY4,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
if i < len(events_x) - 1:
|
||||
draw_arrow(ax, ex + 0.8, 1.4, ex + 1.0, 1.4, lw=1)
|
||||
|
||||
ax.text(4.8, 1.4, "...", fontsize=FS_LABEL, va="center")
|
||||
draw_arrow(ax, 5.2, 1.4, 5.8, 1.4, lw=2)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
5.8,
|
||||
0.8,
|
||||
2.8,
|
||||
1.4,
|
||||
"Analiza\nCIĄGŁA\n(event-by-event)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 8.6, 1.5, 9.3, 1.5, lw=2)
|
||||
draw_box(
|
||||
ax,
|
||||
9.3,
|
||||
0.8,
|
||||
2.0,
|
||||
1.4,
|
||||
"Wyniki\nciągłe",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(11.5, 0.5, "ms-s", fontsize=FS, va="center", fontweight="bold")
|
||||
|
||||
# Arrow marking infinity
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(0.2, 1.4),
|
||||
xytext=(-0.3, 1.4),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(0.0, 2.3, "∞ zdarzeń", fontsize=FS_SMALL, ha="center", style="italic")
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.92])
|
||||
save_fig(fig, "q20_batch_vs_streaming.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 2. All 4 window types (TSSG)
|
||||
# ============================================================
|
||||
def _draw_tumbling_window(ax: Axes, events: list[int]) -> None:
|
||||
"""Draw tumbling window section."""
|
||||
ax.set_xlim(0, 14)
|
||||
ax.set_ylim(0, 4)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Tumbling Window (okno przerzutne) — rozłączne, stały rozmiar",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Time axis
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(13.5, 1.0),
|
||||
xytext=(0.3, 1.0),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
# Events
|
||||
for i, e in enumerate(events):
|
||||
x = 1.0 + i * 1.0
|
||||
ax.plot(x, 1.0, "ko", markersize=5)
|
||||
ax.text(x, 0.5, f"e{e}", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
# Windows
|
||||
colors_w = [GRAY1, GRAY3, GRAY1, GRAY3]
|
||||
for w in range(4):
|
||||
x_start = 1.0 + w * 3.0 - 0.3
|
||||
rect = mpatches.FancyBboxPatch(
|
||||
(x_start, 1.5),
|
||||
3.0,
|
||||
1.2,
|
||||
boxstyle="round,pad=0.1",
|
||||
facecolor=colors_w[w],
|
||||
edgecolor=LN,
|
||||
lw=1.5,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x_start + 1.5,
|
||||
2.1,
|
||||
f"Okno {w + 1}",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
# Braces down to events
|
||||
for j in range(3):
|
||||
ex = 1.0 + w * 3.0 + j * 1.0
|
||||
ax.plot([ex, ex], [1.0, 1.5], color=LN, lw=0.8, linestyle="--")
|
||||
|
||||
ax.text(
|
||||
7.0,
|
||||
3.2,
|
||||
"Każde zdarzenie → DOKŁADNIE 1 okno. Zero nakładania.",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
|
||||
)
|
||||
|
||||
|
||||
def _draw_sliding_window(ax: Axes, events: list[int]) -> None:
|
||||
"""Draw sliding window section."""
|
||||
ax.set_xlim(0, 14)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Sliding Window (okno przesuwne) — nakładające, stały rozmiar + krok",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(13.5, 1.0),
|
||||
xytext=(0.3, 1.0),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
for i, e in enumerate(events[:8]):
|
||||
x = 1.0 + i * 1.0
|
||||
ax.plot(x, 1.0, "ko", markersize=5)
|
||||
ax.text(x, 0.5, f"e{e}", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
# Sliding windows: size=4, slide=2
|
||||
slide_colors = [GRAY1, GRAY2, GRAY3]
|
||||
for w in range(3):
|
||||
x_start = 0.7 + w * 2.0
|
||||
y_base = 1.5 + w * 0.9
|
||||
rect = mpatches.FancyBboxPatch(
|
||||
(x_start, y_base),
|
||||
4.0,
|
||||
0.7,
|
||||
boxstyle="round,pad=0.08",
|
||||
facecolor=slide_colors[w],
|
||||
edgecolor=LN,
|
||||
lw=1.5,
|
||||
alpha=0.7,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x_start + 2.0,
|
||||
y_base + 0.35,
|
||||
f"Okno {w + 1} (size=4)",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
10.5,
|
||||
3.5,
|
||||
"krok=2\nNakładanie!\ne3,e4 → w oknie 1 i 2",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
|
||||
)
|
||||
|
||||
|
||||
def _draw_session_window(ax: Axes) -> None:
|
||||
"""Draw session window section."""
|
||||
ax.set_xlim(0, 14)
|
||||
ax.set_ylim(0, 4)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Session Window (okno sesji) — dynamiczny rozmiar, gap = przerwa",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(13.5, 1.0),
|
||||
xytext=(0.3, 1.0),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
# Cluster 1: events close together
|
||||
cluster1 = [1.0, 1.8, 2.3, 3.0]
|
||||
for x in cluster1:
|
||||
ax.plot(x, 1.0, "ko", markersize=5)
|
||||
|
||||
# Gap
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(7.0, 0.7),
|
||||
xytext=(4.0, 0.7),
|
||||
arrowprops={"arrowstyle": "<->", "lw": 1, "color": LN},
|
||||
)
|
||||
ax.text(
|
||||
5.5,
|
||||
0.3,
|
||||
"GAP > timeout",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Cluster 2
|
||||
cluster2 = [8.0, 8.8, 9.5]
|
||||
for x in cluster2:
|
||||
ax.plot(x, 1.0, "ko", markersize=5)
|
||||
|
||||
# Session boxes
|
||||
rect1 = mpatches.FancyBboxPatch(
|
||||
(0.7, 1.4),
|
||||
2.6,
|
||||
1.0,
|
||||
boxstyle="round,pad=0.1",
|
||||
facecolor=GRAY1,
|
||||
edgecolor=LN,
|
||||
lw=1.5,
|
||||
)
|
||||
ax.add_patch(rect1)
|
||||
ax.text(
|
||||
2.0, 1.9, "Sesja 1\n(4 zdarzenia)", fontsize=FS, ha="center", fontweight="bold"
|
||||
)
|
||||
|
||||
rect2 = mpatches.FancyBboxPatch(
|
||||
(7.7, 1.4),
|
||||
2.1,
|
||||
1.0,
|
||||
boxstyle="round,pad=0.1",
|
||||
facecolor=GRAY3,
|
||||
edgecolor=LN,
|
||||
lw=1.5,
|
||||
)
|
||||
ax.add_patch(rect2)
|
||||
ax.text(
|
||||
8.75, 1.9, "Sesja 2\n(3 zdarzenia)", fontsize=FS, ha="center", fontweight="bold"
|
||||
)
|
||||
|
||||
ax.text(
|
||||
5.5,
|
||||
3.0,
|
||||
"Nowa sesja po przerwie > gap",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
|
||||
def _draw_global_window(ax: Axes) -> None:
|
||||
"""Draw global window section."""
|
||||
ax.set_xlim(0, 14)
|
||||
ax.set_ylim(0, 4)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Global Window — jedno okno na cały strumień + trigger",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(13.5, 1.0),
|
||||
xytext=(0.3, 1.0),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(13.5, 0.6, "czas", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
for i in range(12):
|
||||
x = 1.0 + i * 1.0
|
||||
ax.plot(x, 1.0, "ko", markersize=5)
|
||||
|
||||
# One big window
|
||||
rect = mpatches.FancyBboxPatch(
|
||||
(0.5, 1.4),
|
||||
12.5,
|
||||
1.0,
|
||||
boxstyle="round,pad=0.1",
|
||||
facecolor=GRAY1,
|
||||
edgecolor=LN,
|
||||
lw=2,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
6.75,
|
||||
1.9,
|
||||
"GLOBAL WINDOW (cały strumień)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Trigger markers
|
||||
for tx in [4.0, 8.0, 12.0]:
|
||||
ax.plot([tx, tx], [1.4, 2.4], color=LN, lw=2, linestyle="--")
|
||||
ax.text(
|
||||
tx,
|
||||
2.7,
|
||||
"EMIT",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY3, "edgecolor": LN},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
6.75,
|
||||
3.3,
|
||||
"Trigger decyduje kiedy emitować (np. co N zdarzeń)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
|
||||
def gen_window_types() -> None:
|
||||
"""Gen window types."""
|
||||
fig, axes = plt.subplots(4, 1, figsize=(9, 10))
|
||||
fig.suptitle("4 typy okien — TSSG", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
events = list(range(1, 13))
|
||||
|
||||
_draw_tumbling_window(axes[0], events)
|
||||
_draw_sliding_window(axes[1], events)
|
||||
_draw_session_window(axes[2])
|
||||
_draw_global_window(axes[3])
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.94])
|
||||
save_fig(fig, "q20_window_types.png")
|
||||
@ -0,0 +1,180 @@
|
||||
"""Common utilities and constants for Q20 diagram generation.
|
||||
|
||||
Monochrome, A4-printable PNGs (300 DPI).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
|
||||
DPI = 300
|
||||
BG = "white"
|
||||
LN = "black"
|
||||
FS = 8
|
||||
FS_TITLE = 11
|
||||
FS_SMALL = 6.5
|
||||
FS_LABEL = 9
|
||||
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
|
||||
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
GRAY1 = "#E8E8E8"
|
||||
GRAY2 = "#D0D0D0"
|
||||
GRAY3 = "#B8B8B8"
|
||||
GRAY4 = "#F5F5F5"
|
||||
GRAY5 = "#C0C0C0"
|
||||
|
||||
|
||||
def draw_box(
|
||||
ax: Axes,
|
||||
x: float,
|
||||
y: float,
|
||||
w: float,
|
||||
h: float,
|
||||
text: str,
|
||||
*,
|
||||
fill: str = "white",
|
||||
lw: float = 1.2,
|
||||
fontsize: float = FS,
|
||||
fontweight: str = "normal",
|
||||
ha: str = "center",
|
||||
va: str = "center",
|
||||
rounded: bool = True,
|
||||
edgecolor: str = LN,
|
||||
linestyle: str = "-",
|
||||
) -> None:
|
||||
"""Draw box."""
|
||||
if rounded:
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.05",
|
||||
lw=lw,
|
||||
edgecolor=edgecolor,
|
||||
facecolor=fill,
|
||||
linestyle=linestyle,
|
||||
)
|
||||
else:
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
lw=lw,
|
||||
edgecolor=edgecolor,
|
||||
facecolor=fill,
|
||||
linestyle=linestyle,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h / 2,
|
||||
text,
|
||||
ha=ha,
|
||||
va=va,
|
||||
fontsize=fontsize,
|
||||
fontweight=fontweight,
|
||||
wrap=True,
|
||||
)
|
||||
|
||||
|
||||
def draw_arrow(
|
||||
ax: Axes,
|
||||
x1: float,
|
||||
y1: float,
|
||||
x2: float,
|
||||
y2: float,
|
||||
lw: float = 1.2,
|
||||
style: str = "->",
|
||||
color: str = LN,
|
||||
) -> None:
|
||||
"""Draw arrow."""
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(x2, y2),
|
||||
xytext=(x1, y1),
|
||||
arrowprops={"arrowstyle": style, "color": color, "lw": lw},
|
||||
)
|
||||
|
||||
|
||||
def save_fig(fig: Figure, name: str) -> None:
|
||||
"""Save fig."""
|
||||
path = str(Path(OUTPUT_DIR) / name)
|
||||
fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15)
|
||||
plt.close(fig)
|
||||
_logger.info(" Saved: %s", path)
|
||||
|
||||
|
||||
def draw_table(
|
||||
ax: Axes,
|
||||
headers: list[str],
|
||||
rows: list[list[str]],
|
||||
x0: float,
|
||||
y0: float,
|
||||
col_widths: list[float],
|
||||
row_h: float = 0.4,
|
||||
header_fill: str = GRAY2,
|
||||
row_fills: list[str] | None = None,
|
||||
fontsize: float = FS,
|
||||
header_fontsize: float | None = None,
|
||||
) -> None:
|
||||
"""Draw table."""
|
||||
if header_fontsize is None:
|
||||
header_fontsize = fontsize
|
||||
len(headers)
|
||||
# Header
|
||||
cx = x0
|
||||
for j, hdr in enumerate(headers):
|
||||
draw_box(
|
||||
ax,
|
||||
cx,
|
||||
y0,
|
||||
col_widths[j],
|
||||
row_h,
|
||||
hdr,
|
||||
fill=header_fill,
|
||||
fontsize=header_fontsize,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
cx += col_widths[j]
|
||||
# Rows
|
||||
for i, row in enumerate(rows):
|
||||
cy = y0 - (i + 1) * row_h
|
||||
cx = x0
|
||||
fill = GRAY4 if (i % 2 == 0) else "white"
|
||||
if row_fills and i < len(row_fills):
|
||||
fill = row_fills[i]
|
||||
for j, cell in enumerate(row):
|
||||
fw = "bold" if j == 0 else "normal"
|
||||
draw_box(
|
||||
ax,
|
||||
cx,
|
||||
cy,
|
||||
col_widths[j],
|
||||
row_h,
|
||||
cell,
|
||||
fill=fill,
|
||||
fontsize=fontsize,
|
||||
fontweight=fw,
|
||||
rounded=False,
|
||||
)
|
||||
cx += col_widths[j]
|
||||
@ -0,0 +1,240 @@
|
||||
"""Late data strategies and decision tree diagrams for Q20."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q20_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
plt,
|
||||
save_fig,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 16. Late data strategies (DRAS)
|
||||
# ============================================================
|
||||
def gen_late_data_strategies() -> None:
|
||||
"""Gen late data strategies."""
|
||||
fig, ax = plt.subplots(figsize=(9, 5.5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Late Data — 4 strategie (mnemonik DRAS)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Setup: window closed, late event arrives
|
||||
draw_box(
|
||||
ax,
|
||||
0.5,
|
||||
5.5,
|
||||
4.5,
|
||||
1.0,
|
||||
"Okno [14:00-14:05]\nZAMKNIĘTE o 14:05",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
6.0,
|
||||
5.5,
|
||||
4.5,
|
||||
1.0,
|
||||
"Spóźnione zdarzenie\nevent_time=14:00:03\narrives=14:05:30",
|
||||
fill="#F8D7DA",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 10.5, 6.0, 5.0, 6.0, lw=2, color="#C62828", style="->")
|
||||
ax.text(
|
||||
7.5,
|
||||
5.2,
|
||||
"LATE!",
|
||||
fontsize=FS_LABEL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828",
|
||||
)
|
||||
|
||||
# 4 strategies
|
||||
strategies = [
|
||||
("D — Drop", "Odrzuć spóźnione", "/dev/null", GRAY4),
|
||||
("R — Recompute", "Przelicz okno ponownie", "poprawne ale kosztowne", GRAY1),
|
||||
(
|
||||
"A — Allowed lateness",
|
||||
"Czekaj dodatkowy czas\n(np. +2 min)",
|
||||
"kompromis pamięci",
|
||||
GRAY2,
|
||||
),
|
||||
(
|
||||
"S — Side output",
|
||||
"Przekieruj do osobnej\nkolejki",
|
||||
"elastyczne, ręczna analiza",
|
||||
GRAY3,
|
||||
),
|
||||
]
|
||||
for i, (name, desc, tradeoff, color) in enumerate(strategies):
|
||||
y = 3.8 - i * 1.1
|
||||
draw_box(ax, 0.5, y, 2.5, 0.9, name, fill=color, fontsize=FS, fontweight="bold")
|
||||
ax.text(3.3, y + 0.45, desc, fontsize=FS_SMALL, va="center")
|
||||
ax.text(
|
||||
8.5,
|
||||
y + 0.45,
|
||||
tradeoff,
|
||||
fontsize=FS_SMALL,
|
||||
va="center",
|
||||
style="italic",
|
||||
color="#555",
|
||||
)
|
||||
|
||||
save_fig(fig, "q20_late_data_strategies.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 17. Decision tree — which platform
|
||||
# ============================================================
|
||||
def gen_decision_tree() -> None:
|
||||
"""Gen decision tree."""
|
||||
fig, ax = plt.subplots(figsize=(10, 5.5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Drzewo decyzyjne — wybór platformy",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Root question
|
||||
draw_box(
|
||||
ax,
|
||||
3.5,
|
||||
5.5,
|
||||
4.5,
|
||||
1.0,
|
||||
"Latencja < 10ms\nwymagana?",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# TAK branch
|
||||
draw_arrow(ax, 3.5, 5.7, 2.0, 5.0, lw=1.5)
|
||||
ax.text(2.3, 5.3, "TAK", fontsize=FS, fontweight="bold")
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
0.3,
|
||||
3.5,
|
||||
3.5,
|
||||
1.0,
|
||||
"Dane już w Kafce?\nProste transformacje?",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# TAK → Kafka Streams
|
||||
draw_arrow(ax, 0.3, 3.7, -0.1, 3.0, lw=1.5)
|
||||
ax.text(0.0, 3.3, "TAK", fontsize=FS_SMALL, fontweight="bold")
|
||||
draw_box(
|
||||
ax,
|
||||
-0.3,
|
||||
1.8,
|
||||
2.5,
|
||||
1.0,
|
||||
"Kafka\nStreams",
|
||||
fill=GRAY5,
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# NIE → Flink
|
||||
draw_arrow(ax, 3.8, 3.7, 4.5, 3.0, lw=1.5)
|
||||
ax.text(4.0, 3.3, "NIE\n(złożona logika)", fontsize=FS_SMALL)
|
||||
draw_box(
|
||||
ax,
|
||||
3.0,
|
||||
1.8,
|
||||
2.5,
|
||||
1.0,
|
||||
"Apache\nFlink",
|
||||
fill=GRAY5,
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# NIE branch
|
||||
draw_arrow(ax, 8.0, 5.7, 9.5, 5.0, lw=1.5)
|
||||
ax.text(8.7, 5.3, "NIE", fontsize=FS, fontweight="bold")
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
7.5,
|
||||
3.5,
|
||||
4.2,
|
||||
1.0,
|
||||
"~100ms-1s OK?\nPotrzeba ML / SQL?",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# TAK + ML → Spark
|
||||
draw_arrow(ax, 9.5, 3.5, 9.5, 3.0, lw=1.5)
|
||||
ax.text(10.0, 3.3, "TAK + ML/SQL", fontsize=FS_SMALL)
|
||||
draw_box(
|
||||
ax,
|
||||
8.0,
|
||||
1.8,
|
||||
2.5,
|
||||
1.0,
|
||||
"Spark\nStreaming",
|
||||
fill=GRAY5,
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# TAK + proste → Kafka Streams too
|
||||
draw_arrow(ax, 7.5, 3.7, 6.5, 3.0, lw=1.5)
|
||||
ax.text(6.3, 3.3, "proste + TAK", fontsize=FS_SMALL)
|
||||
draw_box(
|
||||
ax,
|
||||
5.8,
|
||||
1.8,
|
||||
2.0,
|
||||
1.0,
|
||||
"Kafka\nStreams",
|
||||
fill=GRAY5,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Legend
|
||||
ax.text(
|
||||
6.0,
|
||||
0.7,
|
||||
"Reguła: Kafka Streams = najprostsze (library) | "
|
||||
"Flink = najpotężniejszy (true streaming) | Spark = ekosystem ML",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
|
||||
)
|
||||
|
||||
save_fig(fig, "q20_decision_tree.png")
|
||||
@ -0,0 +1,471 @@
|
||||
"""Streaming ecosystem, micro-batch, platform comparison, and engine diagrams for Q20."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q20_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
draw_table,
|
||||
plt,
|
||||
save_fig,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 7. Streaming ecosystem overview
|
||||
# ============================================================
|
||||
def gen_streaming_ecosystem() -> None:
|
||||
"""Gen streaming ecosystem."""
|
||||
fig, ax = plt.subplots(figsize=(10, 5.5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Ekosystem przetwarzania strumieniowego",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Source
|
||||
draw_box(
|
||||
ax,
|
||||
0.3,
|
||||
2.5,
|
||||
2.0,
|
||||
3.0,
|
||||
"Kafka\nTopics\n(źródło)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Engines
|
||||
engines = [
|
||||
("Kafka Streams\n(library w JVM)", GRAY1, 4.7),
|
||||
("Apache Flink\n(klaster)", GRAY3, 3.2),
|
||||
("Spark Streaming\n(klaster)", GRAY5, 1.7),
|
||||
]
|
||||
for label, color, y in engines:
|
||||
draw_box(
|
||||
ax, 4.0, y, 3.0, 1.2, label, fill=color, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
draw_arrow(ax, 2.3, 4.0, 4.0, y + 0.6, lw=1.5)
|
||||
|
||||
# Sinks
|
||||
sinks = [
|
||||
("Kafka topic\n/ baza danych", GRAY4, 4.7),
|
||||
("DB / Kafka\n/ S3", GRAY4, 3.2),
|
||||
("HDFS / DB\n/ dashboard", GRAY4, 1.7),
|
||||
]
|
||||
for label, color, y in sinks:
|
||||
draw_box(ax, 8.5, y, 2.5, 1.2, label, fill=color, fontsize=FS)
|
||||
draw_arrow(ax, 7.0, y + 0.6, 8.5, y + 0.6, lw=1.5)
|
||||
|
||||
# Labels
|
||||
ax.text(1.3, 6.0, "ŹRÓDŁO", fontsize=FS_LABEL, ha="center", fontweight="bold")
|
||||
ax.text(5.5, 6.2, "SILNIK", fontsize=FS_LABEL, ha="center", fontweight="bold")
|
||||
ax.text(9.75, 6.2, "WYNIK", fontsize=FS_LABEL, ha="center", fontweight="bold")
|
||||
|
||||
# Latency annotations
|
||||
ax.text(5.5, 5.95, "~1-10 ms", fontsize=FS_SMALL, ha="center", style="italic")
|
||||
ax.text(5.5, 4.5, "<10 ms", fontsize=FS_SMALL, ha="center", style="italic")
|
||||
ax.text(5.5, 3.0, "~100 ms", fontsize=FS_SMALL, ha="center", style="italic")
|
||||
|
||||
save_fig(fig, "q20_streaming_ecosystem.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 8. True streaming vs Micro-batch
|
||||
# ============================================================
|
||||
def gen_true_vs_microbatch() -> None:
|
||||
"""Gen true vs microbatch."""
|
||||
fig, axes = plt.subplots(2, 1, figsize=(10, 5.5))
|
||||
fig.suptitle("True Streaming vs Micro-Batch", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
# True streaming
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 3.5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"TRUE STREAMING (Flink, Kafka Streams) — event-by-event",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
for i in range(6):
|
||||
x = 1.0 + i * 1.8
|
||||
# Event
|
||||
draw_box(
|
||||
ax,
|
||||
x,
|
||||
2.0,
|
||||
0.8,
|
||||
0.7,
|
||||
f"e{i + 1}",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
# Arrow down
|
||||
draw_arrow(ax, x + 0.4, 2.0, x + 0.4, 1.4, lw=1)
|
||||
# Result
|
||||
draw_box(
|
||||
ax,
|
||||
x,
|
||||
0.5,
|
||||
0.8,
|
||||
0.7,
|
||||
f"r{i + 1}",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
# Latency label
|
||||
ax.text(x + 0.4, 1.6, "~ms", fontsize=5, ha="center", color="#555")
|
||||
|
||||
ax.text(
|
||||
11.5,
|
||||
1.3,
|
||||
"Latencja:\n< 10 ms",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
|
||||
)
|
||||
|
||||
# Micro-batch
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 3.5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"MICRO-BATCH (Spark Streaming) — grupami co ~100ms",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
batch_colors = [GRAY1, GRAY2, GRAY3]
|
||||
for b in range(3):
|
||||
bx = 0.8 + b * 3.5
|
||||
# Batch boundary
|
||||
draw_box(ax, bx, 1.8, 3.0, 1.0, "", fill=batch_colors[b], rounded=True, lw=1.5)
|
||||
ax.text(
|
||||
bx + 1.5, 2.6, f"Batch {b + 1}", fontsize=FS, ha="center", fontweight="bold"
|
||||
)
|
||||
for j in range(3):
|
||||
ex = bx + 0.3 + j * 0.9
|
||||
draw_box(
|
||||
ax,
|
||||
ex,
|
||||
2.0,
|
||||
0.7,
|
||||
0.5,
|
||||
f"e{b * 3 + j + 1}",
|
||||
fill="white",
|
||||
fontsize=FS_SMALL,
|
||||
rounded=False,
|
||||
)
|
||||
|
||||
# Arrow down
|
||||
draw_arrow(ax, bx + 1.5, 1.8, bx + 1.5, 1.2, lw=1.5)
|
||||
# Result
|
||||
draw_box(
|
||||
ax,
|
||||
bx + 0.5,
|
||||
0.4,
|
||||
2.0,
|
||||
0.7,
|
||||
f"result {b + 1}",
|
||||
fill=GRAY4,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
11.5,
|
||||
1.3,
|
||||
"Latencja:\n~100ms-s",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
|
||||
)
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.92])
|
||||
save_fig(fig, "q20_true_vs_microbatch.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 9. Platform comparison table
|
||||
# ============================================================
|
||||
def gen_platform_comparison() -> None:
|
||||
"""Gen platform comparison."""
|
||||
fig, ax = plt.subplots(figsize=(9, 5))
|
||||
ax.set_xlim(0, 11.5)
|
||||
ax.set_ylim(-6, 1)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Porównanie platform strumieniowych",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
headers = ["Cecha", "Kafka Streams", "Apache Flink", "Spark Streaming"]
|
||||
col_w = [2.5, 2.8, 2.8, 2.8]
|
||||
rows = [
|
||||
["Model", "event-by-event", "event-by-event", "micro-batch (~100ms)"],
|
||||
["Deployment", "library (w JVM)", "klaster", "klaster"],
|
||||
["Latencja", "~1-10 ms", "< 10 ms", "100 ms - sekundy"],
|
||||
["Exactly-once", "Kafka TXN", "checkpointing", "WAL"],
|
||||
["State", "RocksDB local", "RocksDB + ckpt", "in-memory / ext"],
|
||||
["Okna", "T, S, Session", "wszystkie + custom", "T, S"],
|
||||
["Use case", "Kafka → Kafka", "złożona analityka", "ETL + ML / SQL"],
|
||||
]
|
||||
draw_table(
|
||||
ax,
|
||||
headers,
|
||||
rows,
|
||||
x0=0.25,
|
||||
y0=0.5,
|
||||
col_widths=col_w,
|
||||
row_h=0.6,
|
||||
fontsize=7,
|
||||
header_fontsize=8,
|
||||
)
|
||||
|
||||
save_fig(fig, "q20_platform_comparison.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 10. Kafka Streams architecture
|
||||
# ============================================================
|
||||
def gen_kafka_streams_arch() -> None:
|
||||
"""Gen kafka streams arch."""
|
||||
fig, ax = plt.subplots(figsize=(9, 5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Kafka Streams — architektura (library w JVM)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Outer box: Your Java application
|
||||
draw_box(ax, 0.5, 0.5, 11.0, 5.5, "", fill=GRAY4, rounded=True, lw=2.5)
|
||||
ax.text(
|
||||
6.0,
|
||||
5.7,
|
||||
"Twoja aplikacja Java (JVM)",
|
||||
fontsize=FS_LABEL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Kafka Consumer
|
||||
draw_box(
|
||||
ax,
|
||||
1.0,
|
||||
3.0,
|
||||
2.5,
|
||||
1.5,
|
||||
"Kafka\nConsumer\n(input topic)",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Processing
|
||||
draw_box(
|
||||
ax,
|
||||
4.5,
|
||||
3.0,
|
||||
2.5,
|
||||
1.5,
|
||||
"Kafka Streams\n(logika\nbiznesowa)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Kafka Producer
|
||||
draw_box(
|
||||
ax,
|
||||
8.0,
|
||||
3.0,
|
||||
2.5,
|
||||
1.5,
|
||||
"Kafka\nProducer\n(output topic)",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Arrows
|
||||
draw_arrow(ax, 3.5, 3.75, 4.5, 3.75, lw=2)
|
||||
draw_arrow(ax, 7.0, 3.75, 8.0, 3.75, lw=2)
|
||||
|
||||
# RocksDB state store
|
||||
draw_box(
|
||||
ax,
|
||||
4.5,
|
||||
1.0,
|
||||
2.5,
|
||||
1.3,
|
||||
"RocksDB\n(stan lokalny)",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.plot([5.75, 5.75], [3.0, 2.3], color=LN, lw=1.5)
|
||||
ax.text(
|
||||
7.3,
|
||||
1.6,
|
||||
"okna, joiny,\nagregacje",
|
||||
fontsize=FS_SMALL,
|
||||
style="italic",
|
||||
va="center",
|
||||
)
|
||||
|
||||
# Key message
|
||||
ax.text(
|
||||
6.0,
|
||||
0.2,
|
||||
"NIE potrzebujesz osobnego klastra! Skalujesz = więcej instancji JVM.",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "edgecolor": LN},
|
||||
)
|
||||
|
||||
save_fig(fig, "q20_kafka_streams_arch.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 11. Flink architecture + checkpointing
|
||||
# ============================================================
|
||||
def gen_flink_arch() -> None:
|
||||
"""Gen flink arch."""
|
||||
fig, ax = plt.subplots(figsize=(9, 6))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 8)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Apache Flink — architektura klastra + checkpointing",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Cluster border
|
||||
draw_box(ax, 0.3, 1.0, 11.4, 6.2, "", fill=GRAY4, rounded=True, lw=2.5)
|
||||
ax.text(
|
||||
6.0, 6.95, "FLINK CLUSTER", fontsize=FS_LABEL, ha="center", fontweight="bold"
|
||||
)
|
||||
|
||||
# Job Manager
|
||||
draw_box(
|
||||
ax,
|
||||
1.0,
|
||||
5.5,
|
||||
3.0,
|
||||
1.2,
|
||||
"Job Manager\n(koordynacja,\ncheckpointy)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Task Managers
|
||||
draw_box(ax, 1.0, 3.0, 10.0, 2.0, "", fill="white", rounded=True, lw=1.5)
|
||||
ax.text(
|
||||
6.0, 4.7, "Task Managers (workery)", fontsize=FS, ha="center", fontweight="bold"
|
||||
)
|
||||
|
||||
slots = ["source\n& map()", "map()", "window()\n& reduce", "sink()"]
|
||||
for i, s in enumerate(slots):
|
||||
x = 1.5 + i * 2.4
|
||||
draw_box(
|
||||
ax,
|
||||
x,
|
||||
3.3,
|
||||
2.0,
|
||||
1.2,
|
||||
f"Slot {i + 1}\n{s}",
|
||||
fill=GRAY1,
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
draw_arrow(ax, 2.5, 5.5, 6.0, 5.0, lw=1.5, style="->")
|
||||
ax.text(5.0, 5.5, "przydziela\npodzadania", fontsize=FS_SMALL, style="italic")
|
||||
|
||||
# Checkpoint storage
|
||||
draw_box(
|
||||
ax,
|
||||
5.5,
|
||||
1.2,
|
||||
3.5,
|
||||
1.2,
|
||||
"Checkpoint Storage\n(HDFS / S3)",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.plot([7.25, 7.25], [2.4, 3.3], color=LN, lw=1.5, linestyle="--")
|
||||
ax.text(8.0, 2.7, "snapshoty\nstanu", fontsize=FS_SMALL, style="italic")
|
||||
|
||||
# Barrier concept at bottom
|
||||
ax.text(3.0, 1.6, "Barrier:", fontsize=FS, fontweight="bold")
|
||||
barrier_boxes = ["source", "|B|", "map", "|B|", "sink"]
|
||||
bx = 0.8
|
||||
for _i, b in enumerate(barrier_boxes):
|
||||
if b == "|B|":
|
||||
ax.text(
|
||||
bx + 0.3,
|
||||
1.5,
|
||||
b,
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY5, "edgecolor": LN},
|
||||
)
|
||||
draw_arrow(ax, bx, 1.5, bx + 0.1, 1.5, lw=1)
|
||||
bx += 0.7
|
||||
else:
|
||||
draw_box(
|
||||
ax,
|
||||
bx,
|
||||
1.3,
|
||||
1.0,
|
||||
0.45,
|
||||
b,
|
||||
fill=GRAY1,
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
bx += 1.2
|
||||
|
||||
save_fig(fig, "q20_flink_arch.png")
|
||||
@ -0,0 +1,464 @@
|
||||
"""Event time, fraud detection, SLA monitoring, and session diagrams for Q20."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q20_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
draw_box,
|
||||
np,
|
||||
plt,
|
||||
rng,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 3. Event Time vs Processing Time scatter + watermark
|
||||
# ============================================================
|
||||
def gen_event_vs_processing_time() -> None:
|
||||
"""Gen event vs processing time."""
|
||||
fig, axes = plt.subplots(1, 2, figsize=(11, 5))
|
||||
fig.suptitle(
|
||||
"Event Time vs Processing Time + Watermark",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# --- Panel 1: Ideal vs Real ---
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.set_aspect("equal")
|
||||
ax.set_xlabel("Event Time", fontsize=FS_LABEL)
|
||||
ax.set_ylabel("Processing Time", fontsize=FS_LABEL)
|
||||
ax.set_title("Idealny vs Realny świat", fontsize=FS_LABEL, fontweight="bold")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Ideal line
|
||||
ax.plot([0, 9], [0, 9], "k--", lw=1.5, label="ideał (brak opóźnień)")
|
||||
|
||||
# Real scattered points (processing >= event, some out of order)
|
||||
event_times = np.sort(rng.uniform(1, 8, 15))
|
||||
proc_times = event_times + rng.exponential(0.5, 15)
|
||||
# Make some out of order
|
||||
idx = [3, 7, 11]
|
||||
for i in idx:
|
||||
proc_times[i] += 1.5
|
||||
|
||||
ax.scatter(
|
||||
event_times, proc_times, c="black", s=30, zorder=5, label="zdarzenia (realne)"
|
||||
)
|
||||
|
||||
# Highlight out-of-order
|
||||
for i in idx:
|
||||
ax.annotate(
|
||||
"out-of-order",
|
||||
xy=(event_times[i], proc_times[i]),
|
||||
xytext=(event_times[i] + 0.8, proc_times[i] + 0.5),
|
||||
fontsize=FS_SMALL,
|
||||
ha="left",
|
||||
arrowprops={"arrowstyle": "->", "lw": 0.8, "color": "#555"},
|
||||
)
|
||||
|
||||
ax.legend(fontsize=FS_SMALL, loc="upper left")
|
||||
ax.text(
|
||||
7,
|
||||
2,
|
||||
"Opóźnienie\nsieciowe ↑",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
style="italic",
|
||||
color="#555",
|
||||
)
|
||||
|
||||
# --- Panel 2: Watermark concept ---
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.set_aspect("equal")
|
||||
ax.set_xlabel("Event Time", fontsize=FS_LABEL)
|
||||
ax.set_ylabel("Processing Time", fontsize=FS_LABEL)
|
||||
ax.set_title("Watermark — granica postępu", fontsize=FS_LABEL, fontweight="bold")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Events
|
||||
ax.scatter(event_times, proc_times, c="black", s=30, zorder=5)
|
||||
|
||||
# Watermark line (below most points, tracks progress)
|
||||
wm_x = np.linspace(0, 9, 50)
|
||||
wm_y = wm_x + 0.3 # watermark slightly above ideal
|
||||
ax.plot(wm_x, wm_y, "k-", lw=2.5, label="Watermark")
|
||||
ax.fill_between(wm_x, 0, wm_y, alpha=0.15, color="gray")
|
||||
|
||||
ax.text(
|
||||
2.0,
|
||||
1.0,
|
||||
'PONIŻEJ watermark:\n„na pewno dotarło"',
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
|
||||
)
|
||||
|
||||
# Late event
|
||||
late_x, late_y = event_times[7], proc_times[7]
|
||||
ax.scatter(
|
||||
[late_x], [late_y], c="white", s=80, zorder=6, edgecolors="black", linewidths=2
|
||||
)
|
||||
ax.annotate(
|
||||
"LATE DATA!\n(po watermarku)",
|
||||
xy=(late_x, late_y),
|
||||
xytext=(late_x + 1.2, late_y + 0.8),
|
||||
fontsize=FS_SMALL,
|
||||
ha="left",
|
||||
fontweight="bold",
|
||||
arrowprops={"arrowstyle": "->", "lw": 1, "color": LN},
|
||||
)
|
||||
|
||||
ax.legend(fontsize=FS_SMALL, loc="upper left")
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.92])
|
||||
save_fig(fig, "q20_event_vs_processing_time.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 4. Tumbling window example — fraud detection
|
||||
# ============================================================
|
||||
def gen_tumbling_fraud() -> None:
|
||||
"""Gen tumbling fraud."""
|
||||
fig, ax = plt.subplots(figsize=(9, 4))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 5.5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Tumbling Window — fraud detection (okno = 1 min)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Time axis
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(11.5, 1.0),
|
||||
xytext=(0.5, 1.0),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(6.0, 0.4, "czas", fontsize=FS, ha="center")
|
||||
|
||||
# Window 1: normal
|
||||
draw_box(ax, 1.0, 1.5, 4.5, 3.0, "", fill=GRAY4, rounded=True, lw=2)
|
||||
ax.text(
|
||||
3.25, 4.2, "[14:00 — 14:01]", fontsize=FS_LABEL, ha="center", fontweight="bold"
|
||||
)
|
||||
# Transactions
|
||||
txns1 = ["Sklep A: 50 zł", "Sklep B: 30 zł", "Stacja: 80 zł"]
|
||||
for i, t in enumerate(txns1):
|
||||
draw_box(
|
||||
ax,
|
||||
1.3,
|
||||
3.3 - i * 0.55,
|
||||
4.0,
|
||||
0.45,
|
||||
t,
|
||||
fill=GRAY1,
|
||||
fontsize=FS_SMALL,
|
||||
rounded=False,
|
||||
)
|
||||
ax.text(
|
||||
3.25,
|
||||
1.7,
|
||||
"count = 3 → OK",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#2E7D32",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.15",
|
||||
"facecolor": "#E8F5E9",
|
||||
"edgecolor": "#2E7D32",
|
||||
},
|
||||
)
|
||||
|
||||
# Window 2: fraud!
|
||||
draw_box(ax, 6.0, 1.5, 4.5, 3.0, "", fill=GRAY1, rounded=True, lw=2)
|
||||
ax.text(
|
||||
8.25, 4.2, "[14:01 — 14:02]", fontsize=FS_LABEL, ha="center", fontweight="bold"
|
||||
)
|
||||
txns2 = ["ATM Warszawa: 500 zł", "ATM Kraków: 500 zł", "... +45 transakcji"]
|
||||
for i, t in enumerate(txns2):
|
||||
draw_box(
|
||||
ax,
|
||||
6.3,
|
||||
3.3 - i * 0.55,
|
||||
4.0,
|
||||
0.45,
|
||||
t,
|
||||
fill=GRAY3,
|
||||
fontsize=FS_SMALL,
|
||||
rounded=False,
|
||||
)
|
||||
ax.text(
|
||||
8.25,
|
||||
1.7,
|
||||
"count = 47 → ALERT!",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.15",
|
||||
"facecolor": "#F8D7DA",
|
||||
"edgecolor": "#C62828",
|
||||
},
|
||||
)
|
||||
|
||||
save_fig(fig, "q20_tumbling_fraud.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 5. Sliding window — SLA monitoring
|
||||
# ============================================================
|
||||
def gen_sliding_sla() -> None:
|
||||
"""Gen sliding sla."""
|
||||
fig, ax = plt.subplots(figsize=(9, 4.5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Sliding Window — monitoring SLA (okno=5min, krok=1min)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Time axis
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(11.5, 0.5),
|
||||
xytext=(0.5, 0.5),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
times = ["14:05", "14:06", "14:07", "14:08", "14:09"]
|
||||
latencies = [120, 180, 340, 290, 150]
|
||||
sla = 200
|
||||
|
||||
for i, (t, lat) in enumerate(zip(times, latencies, strict=False)):
|
||||
x = 1.5 + i * 2.0
|
||||
ax.text(x, 0.1, t, fontsize=FS, ha="center")
|
||||
|
||||
# Bar proportional to latency
|
||||
bar_h = lat / 100.0
|
||||
is_breach = lat > sla
|
||||
fill = "#F8D7DA" if is_breach else GRAY1
|
||||
edge = "#C62828" if is_breach else LN
|
||||
draw_box(
|
||||
ax,
|
||||
x - 0.5,
|
||||
1.0,
|
||||
1.0,
|
||||
bar_h,
|
||||
"",
|
||||
fill=fill,
|
||||
rounded=False,
|
||||
edgecolor=edge,
|
||||
lw=1.5,
|
||||
)
|
||||
ax.text(
|
||||
x,
|
||||
1.0 + bar_h + 0.15,
|
||||
f"{lat}ms",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828" if is_breach else LN,
|
||||
)
|
||||
|
||||
# Status
|
||||
status = "ALERT!" if is_breach else "OK"
|
||||
ax.text(
|
||||
x,
|
||||
1.0 + bar_h + 0.55,
|
||||
status,
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828" if is_breach else "#2E7D32",
|
||||
)
|
||||
|
||||
# SLA line
|
||||
sla_y = 1.0 + sla / 100.0
|
||||
ax.plot([0.8, 11.2], [sla_y, sla_y], "k--", lw=1.5)
|
||||
ax.text(11.3, sla_y, f"SLA={sla}ms", fontsize=FS, va="center", fontweight="bold")
|
||||
|
||||
# Sliding window bracket
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(1.0, 5.3),
|
||||
xytext=(5.0, 5.3),
|
||||
arrowprops={"arrowstyle": "<->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(3.0, 5.6, "okno = 5 min", fontsize=FS, ha="center", fontweight="bold")
|
||||
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(3.0, 4.8),
|
||||
xytext=(5.0, 4.8),
|
||||
arrowprops={"arrowstyle": "<->", "lw": 1, "color": "#555"},
|
||||
)
|
||||
ax.text(
|
||||
4.0,
|
||||
4.4,
|
||||
"krok = 1 min\n(nakładanie!)",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
save_fig(fig, "q20_sliding_sla.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 6. Session window — user sessions
|
||||
# ============================================================
|
||||
def gen_session_users() -> None:
|
||||
"""Gen session users."""
|
||||
fig, axes = plt.subplots(2, 1, figsize=(10, 5))
|
||||
fig.suptitle(
|
||||
"Session Window — sesje użytkowników (gap = 30 min)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Anna: 2 sessions
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 14)
|
||||
ax.set_ylim(0, 3.5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Użytkownik Anna", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(13.5, 1.0),
|
||||
xytext=(0.3, 1.0),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
|
||||
# Clicks cluster 1
|
||||
for x in [1.0, 1.8, 2.5, 3.2]:
|
||||
ax.plot(x, 1.0, "ko", markersize=6)
|
||||
# Clicks cluster 2
|
||||
for x in [9.0, 9.8, 10.5]:
|
||||
ax.plot(x, 1.0, "ko", markersize=6)
|
||||
|
||||
# Sessions
|
||||
rect1 = mpatches.FancyBboxPatch(
|
||||
(0.7, 1.5),
|
||||
2.8,
|
||||
1.2,
|
||||
boxstyle="round,pad=0.1",
|
||||
facecolor=GRAY1,
|
||||
edgecolor=LN,
|
||||
lw=1.5,
|
||||
)
|
||||
ax.add_patch(rect1)
|
||||
ax.text(
|
||||
2.1,
|
||||
2.1,
|
||||
"Sesja 1\n4 kliknięcia, 12 min",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
rect2 = mpatches.FancyBboxPatch(
|
||||
(8.7, 1.5),
|
||||
2.1,
|
||||
1.2,
|
||||
boxstyle="round,pad=0.1",
|
||||
facecolor=GRAY3,
|
||||
edgecolor=LN,
|
||||
lw=1.5,
|
||||
)
|
||||
ax.add_patch(rect2)
|
||||
ax.text(
|
||||
9.75,
|
||||
2.1,
|
||||
"Sesja 2\n3 kliknięcia, 8 min",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Gap
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(8.5, 0.5),
|
||||
xytext=(3.8, 0.5),
|
||||
arrowprops={"arrowstyle": "<->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(
|
||||
6.15,
|
||||
0.1,
|
||||
"cisza 45 min > gap(30)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Bob: 1 session
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 14)
|
||||
ax.set_ylim(0, 3.5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Użytkownik Bob", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(13.5, 1.0),
|
||||
xytext=(0.3, 1.0),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
|
||||
# Clicks spread evenly
|
||||
bobs = [1.0, 2.5, 4.0, 5.5, 7.0, 8.5, 10.0]
|
||||
for x in bobs:
|
||||
ax.plot(x, 1.0, "ko", markersize=6)
|
||||
|
||||
rect = mpatches.FancyBboxPatch(
|
||||
(0.7, 1.5),
|
||||
9.6,
|
||||
1.2,
|
||||
boxstyle="round,pad=0.1",
|
||||
facecolor=GRAY1,
|
||||
edgecolor=LN,
|
||||
lw=2,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
5.5,
|
||||
2.1,
|
||||
"Sesja 1 (ciągła) — 7 kliknięć, każde < 30 min od poprzedniego",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.92])
|
||||
save_fig(fig, "q20_session_users.png")
|
||||
@ -0,0 +1,467 @@
|
||||
"""FCN and U-Net architecture diagram generators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _q23_common import (
|
||||
ACCENT,
|
||||
ACCENT_LIGHT,
|
||||
BLACK,
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TINY,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY5,
|
||||
GREEN_ACCENT,
|
||||
RED_ACCENT,
|
||||
_save_figure,
|
||||
plt,
|
||||
)
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def generate_fcn() -> None:
|
||||
"""Generate fcn."""
|
||||
_fig, axes = plt.subplots(2, 1, figsize=(10, 7))
|
||||
|
||||
# --- Panel 1: FC vs Conv 1x1 ---
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 20)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"FC (Fully Connected) vs Conv 1x1", fontsize=FS_TITLE, fontweight="bold"
|
||||
)
|
||||
|
||||
# Classic CNN with FC
|
||||
layer_info_fc = [
|
||||
(1.5, "Obraz\n224x224x3", 2.2, GRAY2),
|
||||
(4.5, "Conv+Pool\n112x112x64", 1.8, GRAY2),
|
||||
(7.5, "Conv+Pool\n7x7x512", 1.0, GRAY2),
|
||||
(10, "Flatten\n25088", 0.5, ACCENT_LIGHT),
|
||||
(12, "FC\n4096", 0.5, ACCENT_LIGHT),
|
||||
(14, "FC\n1000", 0.3, ACCENT_LIGHT),
|
||||
(16, '"Kot"', 0.3, "#FFCDD2"),
|
||||
]
|
||||
|
||||
y_fc = 4.5
|
||||
for i, (x, label, w, color) in enumerate(layer_info_fc):
|
||||
rect = FancyBboxPatch(
|
||||
(x - w / 2, y_fc - 0.6),
|
||||
w,
|
||||
1.2,
|
||||
boxstyle="round,pad=0.05",
|
||||
facecolor=color,
|
||||
edgecolor=BLACK,
|
||||
linewidth=0.8,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(x, y_fc, label, ha="center", va="center", fontsize=FS_TINY)
|
||||
if i < len(layer_info_fc) - 1:
|
||||
next_x = layer_info_fc[i + 1][0]
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(next_x - layer_info_fc[i + 1][2] / 2, y_fc),
|
||||
xytext=(x + w / 2, y_fc),
|
||||
arrowprops={"arrowstyle": "->", "color": GRAY5, "lw": 1},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
0.3, y_fc, "CNN:", fontsize=FS, fontweight="bold", color=RED_ACCENT, va="center"
|
||||
)
|
||||
ax.text(
|
||||
12,
|
||||
y_fc + 1,
|
||||
"PROBLEM: FC wymaga\nSTAŁEGO rozmiaru\n(np. 224x224)",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
color=RED_ACCENT,
|
||||
fontweight="bold",
|
||||
bbox={
|
||||
"boxstyle": "round",
|
||||
"facecolor": "#FFCDD2",
|
||||
"edgecolor": RED_ACCENT,
|
||||
"alpha": 0.3,
|
||||
},
|
||||
)
|
||||
|
||||
# FCN with Conv 1x1
|
||||
layer_info_fcn = [
|
||||
(1.5, "Obraz\nHxWx3", 2.2, GRAY2),
|
||||
(4.5, "Conv+Pool\nH/2 x W/2\nx64", 1.8, GRAY2),
|
||||
(7.5, "Conv+Pool\nH/32 x W/32\nx512", 1.0, GRAY2),
|
||||
(10.5, "Conv 1x1\nH/32 x W/32\nxC", 0.8, "#C8E6C9"),
|
||||
(13.5, "Upsample\nHxWxC", 1.8, "#C8E6C9"),
|
||||
(16.5, "Mapa\nsegmentacji", 1.5, "#C8E6C9"),
|
||||
]
|
||||
|
||||
y_fcn = 1.5
|
||||
for i, (x, label, w, color) in enumerate(layer_info_fcn):
|
||||
rect = FancyBboxPatch(
|
||||
(x - w / 2, y_fcn - 0.7),
|
||||
w,
|
||||
1.4,
|
||||
boxstyle="round,pad=0.05",
|
||||
facecolor=color,
|
||||
edgecolor=BLACK,
|
||||
linewidth=0.8,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(x, y_fcn, label, ha="center", va="center", fontsize=FS_TINY)
|
||||
if i < len(layer_info_fcn) - 1:
|
||||
next_x = layer_info_fcn[i + 1][0]
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(next_x - layer_info_fcn[i + 1][2] / 2, y_fcn),
|
||||
xytext=(x + w / 2, y_fcn),
|
||||
arrowprops={"arrowstyle": "->", "color": GRAY5, "lw": 1},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
0.3,
|
||||
y_fcn,
|
||||
"FCN:",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
color=GREEN_ACCENT,
|
||||
va="center",
|
||||
)
|
||||
ax.text(
|
||||
10.5,
|
||||
y_fcn + 1.2,
|
||||
"Conv 1x1:\nkażdy piksel\nosobno x wagi\n(jak FC ale\nzachowuje HxW)",
|
||||
ha="center",
|
||||
fontsize=FS_TINY,
|
||||
color=GREEN_ACCENT,
|
||||
bbox={
|
||||
"boxstyle": "round",
|
||||
"facecolor": "#C8E6C9",
|
||||
"edgecolor": GREEN_ACCENT,
|
||||
"alpha": 0.3,
|
||||
},
|
||||
)
|
||||
|
||||
# --- Panel 2: What FC and Conv do ---
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 20)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Co robi warstwa FC? Co robi konwolucja?", fontsize=FS_TITLE, fontweight="bold"
|
||||
)
|
||||
|
||||
# FC explanation
|
||||
rect = FancyBboxPatch(
|
||||
(0.3, 3.2),
|
||||
9,
|
||||
2.5,
|
||||
boxstyle="round,pad=0.15",
|
||||
facecolor=ACCENT_LIGHT,
|
||||
edgecolor=ACCENT,
|
||||
linewidth=1,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
4.8, 5.2, "Fully Connected (FC)", fontsize=FS, fontweight="bold", ha="center"
|
||||
)
|
||||
ax.text(
|
||||
4.8,
|
||||
4.5,
|
||||
"KAŻDY neuron połączony z KAŻDYM wejściem\n"
|
||||
"25 088 wejść x 4 096 neuronów = ~103 MLN wag!\n"
|
||||
"Traci informację GDZIE (przestrzenną)\n"
|
||||
"Wymaga STAŁEGO rozmiaru wejścia",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
va="top",
|
||||
)
|
||||
|
||||
# Conv explanation
|
||||
rect = FancyBboxPatch(
|
||||
(10.3, 3.2),
|
||||
9,
|
||||
2.5,
|
||||
boxstyle="round,pad=0.15",
|
||||
facecolor="#C8E6C9",
|
||||
edgecolor=GREEN_ACCENT,
|
||||
linewidth=1,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(14.8, 5.2, "Konwolucja (Conv)", fontsize=FS, fontweight="bold", ha="center")
|
||||
ax.text(
|
||||
14.8,
|
||||
4.5,
|
||||
'Filtr (np. 3x3) „jedzie" po obrazie\n'
|
||||
"Te same wagi dla KAŻDEJ pozycji\n"
|
||||
"Zachowuje informację GDZIE\n"
|
||||
"Akceptuje DOWOLNY rozmiar wejścia",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
va="top",
|
||||
)
|
||||
|
||||
# Conv 1x1 explanation
|
||||
rect = FancyBboxPatch(
|
||||
(3, 0.3),
|
||||
14,
|
||||
2.2,
|
||||
boxstyle="round,pad=0.15",
|
||||
facecolor=GRAY1,
|
||||
edgecolor=BLACK,
|
||||
linewidth=1,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
10,
|
||||
2.1,
|
||||
'Conv 1x1 = „FC per piksel"',
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
)
|
||||
ax.text(
|
||||
10,
|
||||
1.5,
|
||||
"Filtr 1x1: patrzy na JEDEN piksel, ale WSZYSTKIE kanały (512→C klas)\n"
|
||||
"Działa jak FC ale zachowuje mapę HxW → każdy piksel osobno klasyfikowany\n"
|
||||
"FCN: zamień FC na Conv1x1 → koniec z wymogiem stałego rozmiaru!",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
va="top",
|
||||
)
|
||||
|
||||
_save_figure("q23_fc_vs_conv1x1.png")
|
||||
|
||||
|
||||
def generate_unet() -> None:
|
||||
"""Generate unet."""
|
||||
_fig, ax = plt.subplots(1, 1, figsize=(10, 6))
|
||||
ax.set_xlim(-1, 21)
|
||||
ax.set_ylim(-1, 12)
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"U-Net: architektura w kształcie litery U",
|
||||
fontsize=FS_TITLE + 1,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Encoder layers (going DOWN-LEFT)
|
||||
encoder_layers = [
|
||||
(2, 10, 2.5, 1.5, "572x572x1\n(wejście)", 64),
|
||||
(2, 7.5, 2.2, 1.3, "284x284\nx64", 64),
|
||||
(2, 5, 1.8, 1.1, "140x140\nx128", 128),
|
||||
(2, 2.5, 1.5, 1.0, "68x68\nx256", 256),
|
||||
]
|
||||
|
||||
# Bottleneck
|
||||
bottleneck = (8, 0.5, 2.5, 1.2, "32x32x512\n(bottleneck)", 512)
|
||||
|
||||
# Decoder layers (going UP-RIGHT)
|
||||
decoder_layers = [
|
||||
(14, 2.5, 1.5, 1.0, "68x68\nx256", 256),
|
||||
(14, 5, 1.8, 1.1, "140x140\nx128", 128),
|
||||
(14, 7.5, 2.2, 1.3, "284x284\nx64", 64),
|
||||
(14, 10, 2.5, 1.5, "572x572xC\n(mapa seg.)", "C"),
|
||||
]
|
||||
|
||||
def draw_block(
|
||||
ax: Axes,
|
||||
x: float,
|
||||
y: float,
|
||||
w: float,
|
||||
h: float,
|
||||
label: str,
|
||||
color: str,
|
||||
) -> None:
|
||||
"""Draw block."""
|
||||
rect = FancyBboxPatch(
|
||||
(x - w / 2, y - h / 2),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.05",
|
||||
facecolor=color,
|
||||
edgecolor=BLACK,
|
||||
linewidth=1.2,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(x, y, label, ha="center", va="center", fontsize=FS_TINY)
|
||||
|
||||
# Draw encoder
|
||||
for x, y, w, h, label, _channels in encoder_layers:
|
||||
draw_block(ax, x, y, w, h, label, ACCENT_LIGHT)
|
||||
|
||||
# Draw arrows down (encoder)
|
||||
for i in range(len(encoder_layers) - 1):
|
||||
x1, y1 = encoder_layers[i][0], encoder_layers[i][1] - encoder_layers[i][3] / 2
|
||||
x2, y2 = (
|
||||
encoder_layers[i + 1][0],
|
||||
encoder_layers[i + 1][1] + encoder_layers[i + 1][3] / 2,
|
||||
)
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(x2, y2),
|
||||
xytext=(x1, y1),
|
||||
arrowprops={"arrowstyle": "->", "color": ACCENT, "lw": 2},
|
||||
)
|
||||
ax.text(
|
||||
x1 - 1.7,
|
||||
(y1 + y2) / 2,
|
||||
"MaxPool\n2x2\n↓ zmniejsz",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
color=ACCENT,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Encoder to bottleneck
|
||||
x1, y1 = encoder_layers[-1][0], encoder_layers[-1][1] - encoder_layers[-1][3] / 2
|
||||
draw_block(
|
||||
ax,
|
||||
bottleneck[0],
|
||||
bottleneck[1],
|
||||
bottleneck[2],
|
||||
bottleneck[3],
|
||||
bottleneck[4],
|
||||
GRAY2,
|
||||
)
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(bottleneck[0] - bottleneck[2] / 2, bottleneck[1] + bottleneck[3] / 2),
|
||||
xytext=(x1, y1),
|
||||
arrowprops={"arrowstyle": "->", "color": ACCENT, "lw": 2},
|
||||
)
|
||||
|
||||
# Bottleneck to decoder
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(
|
||||
decoder_layers[0][0] - decoder_layers[0][2] / 2,
|
||||
decoder_layers[0][1] - decoder_layers[0][3] / 2,
|
||||
),
|
||||
xytext=(bottleneck[0] + bottleneck[2] / 2, bottleneck[1] + bottleneck[3] / 2),
|
||||
arrowprops={"arrowstyle": "->", "color": RED_ACCENT, "lw": 2},
|
||||
)
|
||||
|
||||
# Draw decoder
|
||||
for x, y, w, h, label, channels in decoder_layers:
|
||||
color = "#C8E6C9" if channels != "C" else "#A5D6A7"
|
||||
draw_block(ax, x, y, w, h, label, color)
|
||||
|
||||
# Draw arrows up (decoder)
|
||||
for i in range(len(decoder_layers) - 1):
|
||||
x1, y1 = decoder_layers[i][0], decoder_layers[i][1] + decoder_layers[i][3] / 2
|
||||
x2, y2 = (
|
||||
decoder_layers[i + 1][0],
|
||||
decoder_layers[i + 1][1] - decoder_layers[i + 1][3] / 2,
|
||||
)
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(x2, y2),
|
||||
xytext=(x1, y1),
|
||||
arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT, "lw": 2},
|
||||
)
|
||||
ax.text(
|
||||
x1 + 2,
|
||||
(y1 + y2) / 2,
|
||||
"UpConv\n2x2\n↑ zwiększ",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
color=GREEN_ACCENT,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Skip connections (horizontal arrows)
|
||||
for i in range(len(encoder_layers)):
|
||||
enc = encoder_layers[i]
|
||||
dec = decoder_layers[len(decoder_layers) - 1 - i]
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(dec[0] - dec[2] / 2, dec[1]),
|
||||
xytext=(enc[0] + enc[2] / 2, enc[1]),
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"color": GRAY5,
|
||||
"lw": 1.5,
|
||||
"linestyle": "dashed",
|
||||
},
|
||||
)
|
||||
mid_x = (enc[0] + enc[2] / 2 + dec[0] - dec[2] / 2) / 2
|
||||
ax.text(
|
||||
mid_x,
|
||||
enc[1] + 0.6,
|
||||
"skip\n(concat)",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
color=GRAY5,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Labels
|
||||
ax.text(
|
||||
0,
|
||||
11.5,
|
||||
"ENCODER\n(↓ zmniejsza)",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
color=ACCENT,
|
||||
ha="center",
|
||||
)
|
||||
ax.text(
|
||||
17,
|
||||
11.5,
|
||||
"DECODER\n(↑ zwiększa)",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
color=GREEN_ACCENT,
|
||||
ha="center",
|
||||
)
|
||||
ax.text(
|
||||
8,
|
||||
-0.8,
|
||||
'Kształt litery „U": encoder schodzi ↓ → bottleneck na dnie → decoder wraca ↑',
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
color=GRAY5,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Concatenation explanation
|
||||
rect = FancyBboxPatch(
|
||||
(17.5, 3),
|
||||
3,
|
||||
5,
|
||||
boxstyle="round,pad=0.15",
|
||||
facecolor=GRAY1,
|
||||
edgecolor=GRAY5,
|
||||
linewidth=1,
|
||||
linestyle="--",
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
19, 7.5, "Concatenation:", fontsize=FS_SMALL, ha="center", fontweight="bold"
|
||||
)
|
||||
ax.text(
|
||||
19,
|
||||
6.5,
|
||||
"Encoder: 64 kanały\nDecoder: 64 kanały\n→ concat → 128 kanałów\n\n"
|
||||
"Jak sklejenie\ndwóch stosów\nkart:",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
)
|
||||
ax.text(
|
||||
19,
|
||||
3.7,
|
||||
"[enc₁|enc₂|...|dec₁|dec₂|...]",
|
||||
fontsize=FS_TINY - 1,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color=ACCENT,
|
||||
)
|
||||
|
||||
_save_figure("q23_unet_arch.png")
|
||||
@ -0,0 +1,96 @@
|
||||
"""Common utilities and constants for Q23 diagram generation.
|
||||
|
||||
A4-compatible, monochrome-friendly (grays + one accent), 300 DPI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
|
||||
DPI = 300
|
||||
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
|
||||
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Color palette — monochrome-friendly
|
||||
BLACK = "#000000"
|
||||
WHITE = "#FFFFFF"
|
||||
GRAY1 = "#F5F5F5"
|
||||
GRAY2 = "#E0E0E0"
|
||||
GRAY3 = "#BDBDBD"
|
||||
GRAY4 = "#9E9E9E"
|
||||
GRAY5 = "#757575"
|
||||
GRAY6 = "#424242"
|
||||
ACCENT = "#4A90D9" # single blue accent for highlights
|
||||
ACCENT_LIGHT = "#B3D4FC"
|
||||
RED_ACCENT = "#D32F2F"
|
||||
GREEN_ACCENT = "#388E3C"
|
||||
|
||||
FS = 9
|
||||
FS_TITLE = 11
|
||||
FS_SMALL = 7
|
||||
FS_TINY = 6
|
||||
|
||||
_RIDGE_X = 5
|
||||
_VALLEY2_END = 9
|
||||
_DARK_PIXEL_THRESHOLD = 100
|
||||
_GRID_LAST_IDX = 3
|
||||
_HIGHLIGHT_START = 3
|
||||
_HIGHLIGHT_END = 5
|
||||
_BRIGHT_THRESHOLD = 170
|
||||
_OTSU_THRESHOLD = 128
|
||||
|
||||
|
||||
def _save_figure(name: str) -> None:
|
||||
"""Save current figure and log."""
|
||||
plt.tight_layout()
|
||||
plt.savefig(
|
||||
str(Path(OUTPUT_DIR) / name),
|
||||
dpi=DPI,
|
||||
bbox_inches="tight",
|
||||
facecolor="white",
|
||||
)
|
||||
plt.close()
|
||||
_logger.info(" ✓ %s", name)
|
||||
|
||||
|
||||
def _render_text_lines(
|
||||
ax: Axes,
|
||||
lines: list[tuple[str, int, str, str]],
|
||||
*,
|
||||
x_pos: float = 0.5,
|
||||
start_y: float,
|
||||
y_step: float = 0.5,
|
||||
y_empty_step: float = 0.2,
|
||||
) -> None:
|
||||
"""Render a list of styled text lines on an axis."""
|
||||
y = start_y
|
||||
for txt, size, color, weight in lines:
|
||||
if txt == "":
|
||||
y -= y_empty_step
|
||||
continue
|
||||
ax.text(
|
||||
x_pos,
|
||||
y,
|
||||
txt,
|
||||
fontsize=size,
|
||||
color=color,
|
||||
fontweight=weight,
|
||||
va="top",
|
||||
)
|
||||
y -= y_step
|
||||
@ -0,0 +1,251 @@
|
||||
"""DIY U-Net step-by-step diagram generator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _q23_common import (
|
||||
ACCENT,
|
||||
ACCENT_LIGHT,
|
||||
BLACK,
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TINY,
|
||||
GRAY1,
|
||||
GRAY3,
|
||||
GRAY5,
|
||||
GREEN_ACCENT,
|
||||
WHITE,
|
||||
_save_figure,
|
||||
np,
|
||||
plt,
|
||||
rng,
|
||||
)
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def _draw_unet_layer_stack(
|
||||
ax: Axes,
|
||||
layer_sizes: list[tuple[int, int]],
|
||||
*,
|
||||
face_color: str,
|
||||
edge_color: str,
|
||||
arrow_color: str,
|
||||
arrow_label: str,
|
||||
add_skip: bool = False,
|
||||
) -> None:
|
||||
"""Draw encoder or decoder layer stack for DIY U-Net."""
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
|
||||
y_pos = 8.5
|
||||
for i, (s, c) in enumerate(layer_sizes):
|
||||
w = s / 64 * 4
|
||||
h = 0.8
|
||||
rect = FancyBboxPatch(
|
||||
(5 - w / 2, y_pos),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.05",
|
||||
facecolor=face_color,
|
||||
edgecolor=edge_color,
|
||||
linewidth=1,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
label = f"{s}x{s}x{c}"
|
||||
if add_skip and i < len(layer_sizes) - 1:
|
||||
label += " + skip!"
|
||||
ax.text(
|
||||
5,
|
||||
y_pos + h / 2,
|
||||
label,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
if i < len(layer_sizes) - 1:
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(5, y_pos - 0.3),
|
||||
xytext=(5, y_pos),
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"color": arrow_color,
|
||||
"lw": 1.5,
|
||||
},
|
||||
)
|
||||
ax.text(
|
||||
7,
|
||||
y_pos - 0.15,
|
||||
arrow_label,
|
||||
fontsize=FS_TINY,
|
||||
color=arrow_color,
|
||||
)
|
||||
y_pos -= 2.2
|
||||
|
||||
|
||||
def _draw_unet_pseudocode(ax: Axes) -> None:
|
||||
"""Draw panel 6: U-Net pseudocode."""
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title("Pseudokod U-Net", fontsize=FS, fontweight="bold")
|
||||
|
||||
code_lines = [
|
||||
"# ENCODER",
|
||||
"e1 = conv_block(input, 64) # 64x64",
|
||||
"e2 = conv_block(pool(e1), 128) # 32x32",
|
||||
"e3 = conv_block(pool(e2), 256) # 16x16",
|
||||
"",
|
||||
"# BOTTLENECK",
|
||||
"b = conv_block(pool(e3), 512) # 8x8",
|
||||
"",
|
||||
"# DECODER + SKIP",
|
||||
"d3 = conv_block(concat(",
|
||||
" upconv(b), e3), 256) # 16x16",
|
||||
"d2 = conv_block(concat(",
|
||||
" upconv(d3), e2), 128) # 32x32",
|
||||
"d1 = conv_block(concat(",
|
||||
" upconv(d2), e1), 64) # 64x64",
|
||||
"",
|
||||
"output = conv_1x1(d1, n_classes)",
|
||||
]
|
||||
for i, line in enumerate(code_lines):
|
||||
txt_color = (
|
||||
ACCENT
|
||||
if "concat" in line
|
||||
else (GREEN_ACCENT if "output" in line else BLACK)
|
||||
)
|
||||
ax.text(
|
||||
0.3,
|
||||
9.5 - i * 0.55,
|
||||
line,
|
||||
fontsize=FS_TINY,
|
||||
fontfamily="monospace",
|
||||
color=txt_color,
|
||||
)
|
||||
|
||||
|
||||
def generate_diy_unet() -> None:
|
||||
"""Generate diy unet."""
|
||||
fig, axes = plt.subplots(2, 3, figsize=(11, 7))
|
||||
|
||||
size = 64
|
||||
|
||||
# Create synthetic image with two regions
|
||||
img = np.ones((size, size, 3), dtype=np.uint8) * 200 # bright bg
|
||||
# Dark region (object 1)
|
||||
yy, xx = np.mgrid[:size, :size]
|
||||
mask1 = ((xx - 20) ** 2 + (yy - 30) ** 2) < 12**2
|
||||
img[mask1] = [60, 60, 60]
|
||||
# Medium region (object 2)
|
||||
mask2 = ((xx - 45) ** 2 + (yy - 25) ** 2) < 8**2
|
||||
img[mask2] = [120, 120, 120]
|
||||
|
||||
gt = np.zeros((size, size), dtype=np.uint8)
|
||||
gt[mask1] = 1 # class 1
|
||||
gt[mask2] = 2 # class 2
|
||||
|
||||
# --- Panel 1: Input image ---
|
||||
ax = axes[0, 0]
|
||||
ax.imshow(img)
|
||||
ax.set_title("Krok 1: obraz RGB\n64x64x3", fontsize=FS, fontweight="bold")
|
||||
ax.axis("off")
|
||||
|
||||
# --- Panel 2: Encoder shrinks ---
|
||||
ax = axes[0, 1]
|
||||
ax.set_title("Krok 2: Encoder ZMNIEJSZA", fontsize=FS, fontweight="bold")
|
||||
_draw_unet_layer_stack(
|
||||
ax,
|
||||
[(64, 3), (32, 64), (16, 128), (8, 256)],
|
||||
face_color=ACCENT_LIGHT,
|
||||
edge_color=ACCENT,
|
||||
arrow_color=ACCENT,
|
||||
arrow_label="Conv+Pool",
|
||||
)
|
||||
ax.text(
|
||||
5,
|
||||
0.3,
|
||||
"Wyciąga cechy:\nkrawędzie → tekstury → obiekty",
|
||||
ha="center",
|
||||
fontsize=FS_TINY,
|
||||
color=GRAY5,
|
||||
)
|
||||
|
||||
# --- Panel 3: Bottleneck ---
|
||||
ax = axes[0, 2]
|
||||
# Show feature maps at bottleneck (abstract)
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Krok 3: Bottleneck\n(najbardziej abstrakcyjne cechy)",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Show small abstract feature maps
|
||||
for k in range(4):
|
||||
small = rng.random((4, 4))
|
||||
ax_inset = fig.add_axes(
|
||||
[0.68 + (k % 2) * 0.08, 0.72 - (k // 2) * 0.1, 0.06, 0.06]
|
||||
)
|
||||
ax_inset.imshow(small, cmap="gray")
|
||||
ax_inset.axis("off")
|
||||
|
||||
ax.text(
|
||||
5,
|
||||
5,
|
||||
'8x8x256\n\nMałe mapy, ale DUŻO kanałów\nKażdy kanał = jedna „cecha"\n'
|
||||
'(np. kanał 42 = „wykrył koło"\n kanał 78 = „wykrył krawędź")\n\n'
|
||||
"Wie CO jest na obrazie\nale nie wie GDZIE dokładnie",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_SMALL,
|
||||
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
# --- Panel 4: Decoder enlarges ---
|
||||
ax = axes[1, 0]
|
||||
ax.set_title(
|
||||
"Krok 4: Decoder ZWIĘKSZA\n(+ skip connections!)",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
_draw_unet_layer_stack(
|
||||
ax,
|
||||
[(8, 256), (16, 128), (32, 64), (64, 3)],
|
||||
face_color="#C8E6C9",
|
||||
edge_color=GREEN_ACCENT,
|
||||
arrow_color=GREEN_ACCENT,
|
||||
arrow_label="UpConv+Concat",
|
||||
add_skip=True,
|
||||
)
|
||||
ax.text(
|
||||
5,
|
||||
0.3,
|
||||
"Odtwarza rozdzielczość:\nskip → przywraca krawędzie",
|
||||
ha="center",
|
||||
fontsize=FS_TINY,
|
||||
color=GRAY5,
|
||||
)
|
||||
|
||||
# --- Panel 5: Output segmentation map ---
|
||||
ax = axes[1, 1]
|
||||
cmap = plt.cm.colors.ListedColormap([WHITE, ACCENT_LIGHT, "#FFCDD2"])
|
||||
ax.imshow(gt, cmap=cmap, interpolation="nearest")
|
||||
ax.set_title(
|
||||
"Krok 5: mapa segmentacji\n64x64 (3 klasy)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
ax.axis("off")
|
||||
ax.text(20, -3, "Tło=0, obiekt A=1, obiekt B=2", fontsize=FS_TINY, ha="center")
|
||||
|
||||
# --- Panel 6: Summary pseudocode ---
|
||||
_draw_unet_pseudocode(axes[1, 2])
|
||||
|
||||
_save_figure("q23_diy_unet.png")
|
||||
@ -0,0 +1,380 @@
|
||||
"""Mean shift and normalized cuts diagram generators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _q23_common import (
|
||||
_DARK_PIXEL_THRESHOLD,
|
||||
_GRID_LAST_IDX,
|
||||
ACCENT,
|
||||
ACCENT_LIGHT,
|
||||
BLACK,
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TINY,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
GRAY6,
|
||||
GREEN_ACCENT,
|
||||
RED_ACCENT,
|
||||
_render_text_lines,
|
||||
_save_figure,
|
||||
np,
|
||||
plt,
|
||||
rng,
|
||||
)
|
||||
from matplotlib import patches
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def generate_mean_shift() -> None:
|
||||
"""Generate mean shift."""
|
||||
_fig, axes = plt.subplots(1, 3, figsize=(11, 4))
|
||||
|
||||
# --- Panel 1: Feature space concept ---
|
||||
ax = axes[0]
|
||||
# Three clusters in 2D feature space (brightness, x-position)
|
||||
c1x = rng.normal(2, 0.5, 40)
|
||||
c1y = rng.normal(2, 0.5, 40)
|
||||
c2x = rng.normal(6, 0.6, 35)
|
||||
c2y = rng.normal(7, 0.5, 35)
|
||||
c3x = rng.normal(8, 0.4, 25)
|
||||
c3y = rng.normal(3, 0.6, 25)
|
||||
|
||||
ax.scatter(c1x, c1y, c=GRAY4, s=15, alpha=0.7, zorder=3)
|
||||
ax.scatter(c2x, c2y, c=GRAY4, s=15, alpha=0.7, zorder=3)
|
||||
ax.scatter(c3x, c3y, c=GRAY4, s=15, alpha=0.7, zorder=3)
|
||||
|
||||
# Label peaks
|
||||
ax.scatter([2], [2], c=RED_ACCENT, s=80, marker="*", zorder=5, label="Max gęstości")
|
||||
ax.scatter([6], [7], c=RED_ACCENT, s=80, marker="*", zorder=5)
|
||||
ax.scatter([8], [3], c=RED_ACCENT, s=80, marker="*", zorder=5)
|
||||
|
||||
ax.set_xlabel("Cecha 1: jasność", fontsize=FS)
|
||||
ax.set_ylabel("Cecha 2: pozycja x", fontsize=FS)
|
||||
ax.set_title("Przestrzeń cech", fontsize=FS_TITLE, fontweight="bold")
|
||||
for lx, ly, ltxt in [
|
||||
(2, 0.3, "Klaster 1\n(ciemne, lewo)"),
|
||||
(6, 5.3, "Klaster 2\n(jasne, prawo)"),
|
||||
(8, 1.3, "Klaster 3\n(jasne, dół)"),
|
||||
]:
|
||||
ax.text(lx, ly, ltxt, ha="center", fontsize=FS_TINY, color=GRAY6)
|
||||
ax.legend(fontsize=FS_SMALL, loc="upper left")
|
||||
|
||||
# --- Panel 2: Kernel/window moving ---
|
||||
ax = axes[1]
|
||||
ax.scatter(c1x, c1y, c=ACCENT_LIGHT, s=15, alpha=0.7, zorder=3)
|
||||
ax.scatter(c2x, c2y, c=GRAY3, s=15, alpha=0.7, zorder=3)
|
||||
ax.scatter(c3x, c3y, c=GRAY3, s=15, alpha=0.7, zorder=3)
|
||||
|
||||
# Show kernel movement
|
||||
path_x = [4.5, 3.8, 3.0, 2.3, 2.05]
|
||||
path_y = [4.0, 3.3, 2.7, 2.2, 2.03]
|
||||
|
||||
for i, (px, py) in enumerate(zip(path_x, path_y, strict=False)):
|
||||
alpha = 0.3 + 0.15 * i
|
||||
circle = plt.Circle(
|
||||
(px, py),
|
||||
1.2,
|
||||
fill=False,
|
||||
edgecolor=ACCENT,
|
||||
linewidth=1.5,
|
||||
linestyle="--" if i < len(path_x) - 1 else "-",
|
||||
alpha=alpha,
|
||||
)
|
||||
ax.add_patch(circle)
|
||||
if i < len(path_x) - 1:
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(path_x[i + 1], path_y[i + 1]),
|
||||
xytext=(px, py),
|
||||
arrowprops={"arrowstyle": "->", "color": RED_ACCENT, "lw": 1.5},
|
||||
)
|
||||
|
||||
ax.scatter([path_x[0]], [path_y[0]], c=ACCENT, s=50, marker="o", zorder=5)
|
||||
ax.scatter([path_x[-1]], [path_y[-1]], c=RED_ACCENT, s=80, marker="*", zorder=5)
|
||||
|
||||
ax.text(
|
||||
4.5, 5.2, "Start: losowy\npiksel", fontsize=FS_SMALL, ha="center", color=ACCENT
|
||||
)
|
||||
ax.text(
|
||||
2.05,
|
||||
0.5,
|
||||
"Koniec: max\ngęstości",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
color=RED_ACCENT,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
7,
|
||||
8,
|
||||
"Okno (jądro)\nprzesuwa się\ndo skupiska",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
color=GRAY6,
|
||||
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.set_xlabel("Cecha 1", fontsize=FS)
|
||||
ax.set_ylabel("Cecha 2", fontsize=FS)
|
||||
ax.set_title("Jądro → max gęstości", fontsize=FS_TITLE, fontweight="bold")
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 9)
|
||||
|
||||
# --- Panel 3: Why no K parameter ---
|
||||
ax = axes[2]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title("Dlaczego bez K?", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
lines = [
|
||||
("K-means wymaga:", FS, RED_ACCENT, "bold"),
|
||||
(' „Podaj K=3 klastry"', FS_SMALL, "black", "normal"),
|
||||
(" Problem: skąd wiesz ile klastrów?", FS_SMALL, GRAY5, "normal"),
|
||||
("", 0, "", ""),
|
||||
("Mean Shift NIE wymaga K:", FS, GREEN_ACCENT, "bold"),
|
||||
(" Każdy piksel startuje → toczy się", FS_SMALL, "black", "normal"),
|
||||
(" → trafia do najbliższego szczytu", FS_SMALL, "black", "normal"),
|
||||
(" → ile szczytów = tyle segmentów", FS_SMALL, "black", "normal"),
|
||||
(" → automatycznie!", FS_SMALL, GREEN_ACCENT, "bold"),
|
||||
("", 0, "", ""),
|
||||
("Parametr: bandwidth (szerokość okna)", FS, "black", "bold"),
|
||||
(" Duże okno → mało segmentów", FS_SMALL, "black", "normal"),
|
||||
(" Małe okno → dużo segmentów", FS_SMALL, "black", "normal"),
|
||||
("", 0, "", ""),
|
||||
("Okno = jądro (kernel):", FS, "black", "bold"),
|
||||
(" Koło o promieniu h wokół punktu.", FS_SMALL, "black", "normal"),
|
||||
(" Oblicz średnią pikseli W oknie.", FS_SMALL, "black", "normal"),
|
||||
(" Przesuń okno na tę średnią.", FS_SMALL, "black", "normal"),
|
||||
(" Powtórz aż się zatrzyma.", FS_SMALL, "black", "normal"),
|
||||
]
|
||||
_render_text_lines(ax, lines, start_y=9.0)
|
||||
|
||||
_save_figure("q23_mean_shift.png")
|
||||
|
||||
|
||||
def _draw_ncuts_pixel_grid(
|
||||
ax: Axes,
|
||||
pixel_vals: np.ndarray,
|
||||
) -> None:
|
||||
"""Draw 4x4 pixel grid with value labels and edge weights."""
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
v = pixel_vals[i, j]
|
||||
gray_val = v / 255.0
|
||||
str(gray_val)
|
||||
rect = patches.Rectangle(
|
||||
(j - 0.4, 3 - i - 0.4),
|
||||
0.8,
|
||||
0.8,
|
||||
facecolor=(gray_val, gray_val, gray_val),
|
||||
edgecolor=BLACK,
|
||||
linewidth=0.8,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
text_color = "white" if v < _DARK_PIXEL_THRESHOLD else "black"
|
||||
ax.text(
|
||||
j,
|
||||
3 - i,
|
||||
str(v),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_SMALL,
|
||||
color=text_color,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
|
||||
def _draw_ncuts_edges(
|
||||
ax: Axes,
|
||||
pixel_vals: np.ndarray,
|
||||
) -> None:
|
||||
"""Draw weighted edges between adjacent pixels."""
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
if j < _GRID_LAST_IDX:
|
||||
similarity = max(
|
||||
0,
|
||||
1 - abs(pixel_vals[i, j] - pixel_vals[i, j + 1]) / 255,
|
||||
)
|
||||
lw = similarity * 2.5 + 0.3
|
||||
alpha = similarity * 0.8 + 0.2
|
||||
ax.plot(
|
||||
[j + 0.4, j + 0.6],
|
||||
[3 - i, 3 - i],
|
||||
color=GRAY5,
|
||||
linewidth=lw,
|
||||
alpha=alpha,
|
||||
)
|
||||
if i < _GRID_LAST_IDX:
|
||||
similarity = max(
|
||||
0,
|
||||
1 - abs(pixel_vals[i, j] - pixel_vals[i + 1, j]) / 255,
|
||||
)
|
||||
lw = similarity * 2.5 + 0.3
|
||||
alpha = similarity * 0.8 + 0.2
|
||||
ax.plot(
|
||||
[j, j],
|
||||
[3 - i - 0.4, 3 - i - 0.6],
|
||||
color=GRAY5,
|
||||
linewidth=lw,
|
||||
alpha=alpha,
|
||||
)
|
||||
|
||||
|
||||
def generate_normalized_cuts() -> None:
|
||||
"""Generate normalized cuts."""
|
||||
_fig, axes = plt.subplots(1, 3, figsize=(11, 4))
|
||||
|
||||
# --- Panel 1: Image as graph ---
|
||||
ax = axes[0]
|
||||
ax.set_xlim(-0.5, 4.5)
|
||||
ax.set_ylim(-0.5, 4.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.set_title("Obraz → graf", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
pixel_vals = np.array(
|
||||
[
|
||||
[30, 35, 180, 190],
|
||||
[40, 30, 185, 200],
|
||||
[170, 180, 40, 35],
|
||||
[190, 175, 30, 45],
|
||||
]
|
||||
)
|
||||
_draw_ncuts_pixel_grid(ax, pixel_vals)
|
||||
_draw_ncuts_edges(ax, pixel_vals)
|
||||
|
||||
ax.text(
|
||||
2,
|
||||
-0.8,
|
||||
"Grube linie = duże podobieństwo\n(silna krawędź grafu)",
|
||||
ha="center",
|
||||
fontsize=FS_TINY,
|
||||
color=GRAY5,
|
||||
)
|
||||
ax.axis("off")
|
||||
|
||||
# --- Panel 2: Cut concept ---
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title("Cięcie grafu (graph cut)", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
# Draw two groups of nodes
|
||||
# Group A (dark pixels)
|
||||
positions_a = [(2, 7), (3, 8), (2, 5), (3, 6)]
|
||||
positions_b = [(7, 7), (8, 8), (7, 5), (8, 6)]
|
||||
|
||||
# Intra-group edges (thick = similar)
|
||||
for i, (x1, y1) in enumerate(positions_a):
|
||||
for x2, y2 in positions_a[i + 1 :]:
|
||||
ax.plot([x1, x2], [y1, y2], color=ACCENT, linewidth=2, alpha=0.5)
|
||||
for i, (x1, y1) in enumerate(positions_b):
|
||||
for x2, y2 in positions_b[i + 1 :]:
|
||||
ax.plot([x1, x2], [y1, y2], color=RED_ACCENT, linewidth=2, alpha=0.5)
|
||||
|
||||
# Inter-group edges (thin = dissimilar) — these get cut
|
||||
cut_edges = [((3, 8), (7, 7)), ((3, 6), (7, 5)), ((2, 5), (7, 5))]
|
||||
for (x1, y1), (x2, y2) in cut_edges:
|
||||
ax.plot([x1, x2], [y1, y2], color=GRAY4, linewidth=0.8, linestyle="--")
|
||||
|
||||
# Draw nodes
|
||||
for x, y in positions_a:
|
||||
ax.scatter(x, y, c=ACCENT, s=120, zorder=5, edgecolors=BLACK, linewidth=0.8)
|
||||
for x, y in positions_b:
|
||||
ax.scatter(x, y, c="#FFCDD2", s=120, zorder=5, edgecolors=BLACK, linewidth=0.8)
|
||||
|
||||
# Cut line
|
||||
ax.plot(
|
||||
[5, 5], [3.5, 9.5], color=RED_ACCENT, linewidth=2.5, linestyle="-", zorder=4
|
||||
)
|
||||
ax.text(
|
||||
5, 9.8, "CIĘCIE", ha="center", fontsize=FS, fontweight="bold", color=RED_ACCENT
|
||||
)
|
||||
|
||||
ax.text(
|
||||
2.5,
|
||||
3.8,
|
||||
"Segment A\n(ciemne piksele)",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
color=ACCENT,
|
||||
)
|
||||
ax.text(
|
||||
7.5,
|
||||
3.8,
|
||||
"Segment B\n(jasne piksele)",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
color=RED_ACCENT,
|
||||
)
|
||||
|
||||
# Formula
|
||||
ax.text(
|
||||
5,
|
||||
1.8,
|
||||
"Ncut(A,B) = cut(A,B)/assoc(A,V)\n + cut(A,B)/assoc(B,V)",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3},
|
||||
)
|
||||
ax.text(
|
||||
5,
|
||||
0.5,
|
||||
"Minimalizuj Ncut → tnij SŁABE krawędzie\nzachowuj SILNE (wewnątrz grupy)",
|
||||
ha="center",
|
||||
fontsize=FS_TINY,
|
||||
color=GRAY5,
|
||||
)
|
||||
|
||||
# --- Panel 3: Algorithm summary ---
|
||||
ax = axes[2]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title("Algorytm Normalized Cuts", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
steps = [
|
||||
(
|
||||
"1. Zbuduj graf",
|
||||
"Piksele = węzły\nKrawędzie = podobieństwo"
|
||||
" sąsiadów\n(kolor, jasność, odległość)",
|
||||
),
|
||||
(
|
||||
"2. Macierz podobieństwa W",
|
||||
"W[i,j] = exp(-|kolori - kolorj|² / σ²)"
|
||||
"\n→ im podobniejsze, tym wyższa waga",
|
||||
),
|
||||
("3. Macierz stopni D", "D[i,i] = Σ W[i,j]\n(suma wszystkich wag z węzła i)"),
|
||||
("4. Rozwiąż problem własny", "(D-W)·y = λ·D·y\n→ drugi najm. wektor własny y"),
|
||||
("5. Podziel wg y", "y[i] > 0 → segment A\ny[i] ≤ 0 → segment B"),
|
||||
]
|
||||
|
||||
y = 9.5
|
||||
for title, desc in steps:
|
||||
ax.text(0.5, y, title, fontsize=FS, fontweight="bold", va="top")
|
||||
y -= 0.4
|
||||
ax.text(0.8, y, desc, fontsize=FS_TINY, va="top", color=GRAY6)
|
||||
y -= 1.2
|
||||
|
||||
ax.text(
|
||||
5,
|
||||
0.3,
|
||||
"Złożoność: O(n³) — wymaga eigen decomposition!",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
color=RED_ACCENT,
|
||||
)
|
||||
|
||||
_save_figure("q23_normalized_cuts.png")
|
||||
@ -0,0 +1,327 @@
|
||||
"""Mnemonic summary diagram generator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _q23_common import (
|
||||
ACCENT,
|
||||
ACCENT_LIGHT,
|
||||
BLACK,
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TINY,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY5,
|
||||
GRAY6,
|
||||
GREEN_ACCENT,
|
||||
RED_ACCENT,
|
||||
_save_figure,
|
||||
plt,
|
||||
)
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def generate_mnemonics() -> None:
|
||||
"""Generate mnemonics."""
|
||||
_fig, ax = plt.subplots(1, 1, figsize=(10, 8))
|
||||
ax.set_xlim(0, 20)
|
||||
ax.set_ylim(0, 16)
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Mnemoniki — segmentacja obrazu", fontsize=FS_TITLE + 2, fontweight="bold"
|
||||
)
|
||||
|
||||
def draw_card(
|
||||
ax: Axes,
|
||||
x: float,
|
||||
y: float,
|
||||
w: float,
|
||||
h: float,
|
||||
title: str,
|
||||
mnemonic: str,
|
||||
color: str,
|
||||
detail: str = "",
|
||||
) -> None:
|
||||
"""Draw card."""
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.15",
|
||||
facecolor=color,
|
||||
edgecolor=BLACK,
|
||||
linewidth=1,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h - 0.3,
|
||||
title,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h / 2 - 0.1,
|
||||
mnemonic,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontstyle="italic",
|
||||
color=GRAY6,
|
||||
)
|
||||
if detail:
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + 0.4,
|
||||
detail,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=FS_TINY,
|
||||
color=GRAY5,
|
||||
)
|
||||
|
||||
# Title: STRATEGIE KLASYCZNE
|
||||
ax.text(
|
||||
5,
|
||||
15.5,
|
||||
"STRATEGIE KLASYCZNE",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
color=ACCENT,
|
||||
ha="center",
|
||||
)
|
||||
|
||||
cards_classic = [
|
||||
(
|
||||
0.2,
|
||||
12.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"Thresholding",
|
||||
'„PRÓG na bramce"\nPrzepuszcza > T,\nblokuje ≤ T',
|
||||
ACCENT_LIGHT,
|
||||
"jasne=1, ciemne=0",
|
||||
),
|
||||
(
|
||||
5,
|
||||
12.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"Otsu",
|
||||
'„AUTO-bramkarz"\nSam dobiera próg\nmin σ² wewnątrz',
|
||||
ACCENT_LIGHT,
|
||||
"histogram bimodalny",
|
||||
),
|
||||
(
|
||||
0.2,
|
||||
9.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"Region Growing",
|
||||
'„PLAMA rozlana"\nSeed → BFS po\npodobnych sąsiadach',
|
||||
ACCENT_LIGHT,
|
||||
"jak atrament na papierze",
|
||||
),
|
||||
(
|
||||
5,
|
||||
9.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"Watershed",
|
||||
'„ZALEWANIE terenu"\nDoliny=obiekty\nGranie=granice',
|
||||
ACCENT_LIGHT,
|
||||
"woda + geography",
|
||||
),
|
||||
(
|
||||
0.2,
|
||||
6.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"Mean Shift",
|
||||
'„KULKI toczą się"\nKażda → max gęstości\nBez K!',
|
||||
ACCENT_LIGHT,
|
||||
"bandwidth = okno",
|
||||
),
|
||||
(
|
||||
5,
|
||||
6.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"Normalized Cuts",
|
||||
'„CIĘCIE sznurków"\nGraf: tnij słabe\nkrawędzie (O(n³)!)',
|
||||
ACCENT_LIGHT,
|
||||
"eigenvector problem",
|
||||
),
|
||||
]
|
||||
|
||||
for args in cards_classic:
|
||||
draw_card(ax, *args)
|
||||
|
||||
# Title: SIECI NEURONOWE
|
||||
ax.text(
|
||||
15,
|
||||
15.5,
|
||||
"SIECI NEURONOWE",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
color=GREEN_ACCENT,
|
||||
ha="center",
|
||||
)
|
||||
|
||||
cards_nn = [
|
||||
(
|
||||
10.5,
|
||||
12.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"FCN (2015)",
|
||||
'„FC → Conv 1x1"\nPierwsza end-to-end\nDowolny rozmiar',
|
||||
"#C8E6C9",
|
||||
"skip connections",
|
||||
),
|
||||
(
|
||||
15.3,
|
||||
12.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"U-Net (2015)",
|
||||
'„Litera U"\nEncoder↓ Decoder↑\nSkip = concat',
|
||||
"#C8E6C9",
|
||||
"medycyna, małe dane",
|
||||
),
|
||||
(
|
||||
10.5,
|
||||
9.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"DeepLab v3+",
|
||||
'„DZIURY w filtrze"\nAtrous conv (rate)\nASPP multi-scale',
|
||||
"#C8E6C9",
|
||||
"à trous = z dziurami",
|
||||
),
|
||||
(
|
||||
15.3,
|
||||
9.5,
|
||||
4.5,
|
||||
2.5,
|
||||
"Transformer",
|
||||
'„WSZYSCY ze\nWSZYSTKIMI"\nSelf-attention O(n²)',
|
||||
"#C8E6C9",
|
||||
"SegFormer, Mask2Former",
|
||||
),
|
||||
]
|
||||
|
||||
for args in cards_nn:
|
||||
draw_card(ax, *args)
|
||||
|
||||
# Metryki
|
||||
ax.text(
|
||||
10,
|
||||
8.3,
|
||||
"METRYKI I LOSS",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
color=RED_ACCENT,
|
||||
ha="center",
|
||||
)
|
||||
|
||||
cards_metrics = [
|
||||
(
|
||||
10.5,
|
||||
6.5,
|
||||
4.5,
|
||||
1.6,
|
||||
"mIoU",
|
||||
'„Nakładka / Suma"\nIoU = A∩B / A\u222aB',
|
||||
"#FFCDD2",
|
||||
"",
|
||||
),
|
||||
(
|
||||
15.3,
|
||||
6.5,
|
||||
4.5,
|
||||
1.6,
|
||||
"Dice / Focal",
|
||||
'„Dice=2·nakładka"\nFocal=trudne px',
|
||||
"#FFCDD2",
|
||||
"",
|
||||
),
|
||||
]
|
||||
|
||||
for args in cards_metrics:
|
||||
draw_card(ax, *args)
|
||||
|
||||
# Master mnemonic at bottom
|
||||
rect = FancyBboxPatch(
|
||||
(1, 0.3),
|
||||
18,
|
||||
5.5,
|
||||
boxstyle="round,pad=0.2",
|
||||
facecolor=GRAY1,
|
||||
edgecolor=BLACK,
|
||||
linewidth=1.5,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
10,
|
||||
5.3,
|
||||
"SUPER-MNEMONIK: kolejność algorytmów segmentacji",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
10,
|
||||
4.5,
|
||||
'„TORW-MN FUD-T"',
|
||||
ha="center",
|
||||
fontsize=FS_TITLE + 2,
|
||||
fontweight="bold",
|
||||
color=RED_ACCENT,
|
||||
)
|
||||
ax.text(
|
||||
10,
|
||||
3.5,
|
||||
"Klasyczne: Thresholding → Otsu → Region"
|
||||
" growing → Watershed → Mean shift → Norm. cuts",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
)
|
||||
ax.text(
|
||||
10,
|
||||
2.8,
|
||||
"Neuronowe: FCN → U-Net → DeepLab → Transformer",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
)
|
||||
ax.text(
|
||||
10,
|
||||
1.8,
|
||||
"„Turyści Oglądają Rzekę, Wodospad,"
|
||||
" Morze, Nurt — Fotografują Uroczy"
|
||||
' Dwór Tajemnic"',
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontstyle="italic",
|
||||
color=ACCENT,
|
||||
)
|
||||
ax.text(
|
||||
10,
|
||||
1.0,
|
||||
"Klasyczne: proste→auto→BFS→flood→"
|
||||
"gęstość→graf | Neuronowe:"
|
||||
" FC→U-skip→dilated→attention",
|
||||
ha="center",
|
||||
fontsize=FS_TINY,
|
||||
color=GRAY5,
|
||||
)
|
||||
|
||||
_save_figure("q23_mnemonics.png")
|
||||
@ -0,0 +1,293 @@
|
||||
"""ReLU and dot product diagram generators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q23_common import (
|
||||
ACCENT,
|
||||
ACCENT_LIGHT,
|
||||
BLACK,
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TINY,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY5,
|
||||
GREEN_ACCENT,
|
||||
RED_ACCENT,
|
||||
_save_figure,
|
||||
np,
|
||||
plt,
|
||||
)
|
||||
from matplotlib import patches
|
||||
|
||||
|
||||
def generate_relu() -> None:
|
||||
"""Generate relu."""
|
||||
_fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))
|
||||
|
||||
# --- Panel 1: ReLU plot ---
|
||||
ax = axes[0]
|
||||
x = np.linspace(-5, 5, 200)
|
||||
relu = np.maximum(0, x)
|
||||
ax.plot(x, relu, color=ACCENT, linewidth=2.5, label="ReLU(x) = max(0, x)")
|
||||
ax.axhline(y=0, color=GRAY3, linewidth=0.5)
|
||||
ax.axvline(x=0, color=GRAY3, linewidth=0.5)
|
||||
ax.fill_between(x[x < 0], 0, 0, color=RED_ACCENT, alpha=0.1)
|
||||
ax.fill_between(x[x >= 0], 0, relu[x >= 0], color=ACCENT, alpha=0.1)
|
||||
|
||||
# Annotations
|
||||
ax.annotate(
|
||||
'x < 0 → output = 0\n(neuron „wyłączony")',
|
||||
xy=(-3, 0),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
color=RED_ACCENT,
|
||||
arrowprops={"arrowstyle": "->", "color": RED_ACCENT},
|
||||
xytext=(-3, 2),
|
||||
)
|
||||
ax.annotate(
|
||||
'x ≥ 0 → output = x\n(neuron „włączony")',
|
||||
xy=(3, 3),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
color=ACCENT,
|
||||
arrowprops={"arrowstyle": "->", "color": ACCENT},
|
||||
xytext=(3, 4.5),
|
||||
)
|
||||
ax.scatter([0], [0], c=BLACK, s=40, zorder=5)
|
||||
ax.text(0.3, -0.5, "(0,0)", fontsize=FS_SMALL, color=GRAY5)
|
||||
ax.set_xlabel("x (wejście neuronu)", fontsize=FS)
|
||||
ax.set_ylabel("ReLU(x)", fontsize=FS)
|
||||
ax.set_title("ReLU — Rectified Linear Unit", fontsize=FS_TITLE, fontweight="bold")
|
||||
ax.legend(fontsize=FS_SMALL, loc="upper left")
|
||||
ax.set_ylim(-1, 6)
|
||||
ax.grid(visible=True, alpha=0.2)
|
||||
|
||||
# --- Panel 2: Why ReLU ---
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title("Dlaczego ReLU?", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
y = 9.0
|
||||
lines = [
|
||||
("Neuron oblicza:", FS, BLACK, "bold"),
|
||||
(" z = w₁·x₁ + w₂·x₂ + ... + bias", FS_SMALL, BLACK, "normal"),
|
||||
(" output = ReLU(z) = max(0, z)", FS_SMALL, ACCENT, "bold"),
|
||||
("", 0, "", ""),
|
||||
("Przykład:", FS, BLACK, "bold"),
|
||||
(" wagi: w₁=0.5, w₂=-0.3, bias=0.1", FS_SMALL, BLACK, "normal"),
|
||||
(" wejścia: x₁=2.0, x₂=4.0", FS_SMALL, BLACK, "normal"),
|
||||
(" z = 0.5·2 + (-0.3)·4 + 0.1 = -0.1", FS_SMALL, BLACK, "normal"),
|
||||
(" ReLU(-0.1) = max(0, -0.1) = 0", FS_SMALL, RED_ACCENT, "bold"),
|
||||
(" → neuron milczy (wejście nieistotne)", FS_SMALL, GRAY5, "normal"),
|
||||
("", 0, "", ""),
|
||||
("Gdyby z = 2.3:", FS, BLACK, "bold"),
|
||||
(" ReLU(2.3) = max(0, 2.3) = 2.3", FS_SMALL, GREEN_ACCENT, "bold"),
|
||||
(" → neuron aktywny! Przekazuje sygnał", FS_SMALL, GRAY5, "normal"),
|
||||
("", 0, "", ""),
|
||||
("Szybsza niż sigmoid/tanh", FS_SMALL, GRAY5, "normal"),
|
||||
("(brak exp() → szybkie obliczenia)", FS_SMALL, GRAY5, "normal"),
|
||||
]
|
||||
for txt, size, color, weight in lines:
|
||||
if txt == "":
|
||||
y -= 0.2
|
||||
continue
|
||||
ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
|
||||
y -= 0.5
|
||||
|
||||
_save_figure("q23_relu.png")
|
||||
|
||||
|
||||
def generate_dot_product() -> None:
|
||||
"""Generate dot product."""
|
||||
_fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
|
||||
|
||||
# --- Panel 1: Concept ---
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Iloczyn skalarny\n(dot product)", fontsize=FS_TITLE, fontweight="bold"
|
||||
)
|
||||
|
||||
y = 8.5
|
||||
lines = [
|
||||
("Dwa wektory (listy liczb) → JEDNA liczba", FS, BLACK, "bold"),
|
||||
("", 0, "", ""),
|
||||
("a = [a₁, a₂, a₃] b = [b₁, b₂, b₃]", FS, ACCENT, "normal"),
|
||||
("", 0, "", ""),
|
||||
("a · b = a₁·b₁ + a₂·b₂ + a₃·b₃", FS, BLACK, "bold"),
|
||||
("", 0, "", ""),
|
||||
("Przykład:", FS, BLACK, "bold"),
|
||||
("a = [1, 3, -2] b = [4, -1, 5]", FS_SMALL, BLACK, "normal"),
|
||||
("a·b = 1·4 + 3·(-1) + (-2)·5", FS_SMALL, BLACK, "normal"),
|
||||
(" = 4 + (-3) + (-10) = -9", FS_SMALL, RED_ACCENT, "bold"),
|
||||
("", 0, "", ""),
|
||||
(
|
||||
'Duży wynik → wektory „podobne" (w tym samym kierunku)',
|
||||
FS_SMALL,
|
||||
GREEN_ACCENT,
|
||||
"normal",
|
||||
),
|
||||
('Mały/ujemny → wektory „różne"', FS_SMALL, RED_ACCENT, "normal"),
|
||||
]
|
||||
for txt, size, color, weight in lines:
|
||||
if txt == "":
|
||||
y -= 0.25
|
||||
continue
|
||||
ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
|
||||
y -= 0.55
|
||||
|
||||
# --- Panel 2: Convolution as dot product ---
|
||||
ax = axes[1]
|
||||
ax.set_xlim(-0.5, 5.5)
|
||||
ax.set_ylim(-0.5, 5.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.set_title(
|
||||
"Konwolucja = iloczyn skalarny\nfiltra x fragment obrazu",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Filter 3x3
|
||||
filter_vals = [[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
rect = patches.Rectangle(
|
||||
(j - 0.4, 4 - i - 0.4),
|
||||
0.8,
|
||||
0.8,
|
||||
facecolor=ACCENT_LIGHT,
|
||||
edgecolor=BLACK,
|
||||
linewidth=0.8,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
j,
|
||||
4 - i,
|
||||
str(filter_vals[i][j]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.text(1, 1.5, "Filtr", ha="center", fontsize=FS, fontweight="bold", color=ACCENT)
|
||||
|
||||
# Image patch
|
||||
img_vals = [[50, 50, 200], [50, 50, 200], [50, 50, 200]]
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
rect = patches.Rectangle(
|
||||
(j + 2.6, 4 - i - 0.4),
|
||||
0.8,
|
||||
0.8,
|
||||
facecolor=GRAY2,
|
||||
edgecolor=BLACK,
|
||||
linewidth=0.8,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
j + 3,
|
||||
4 - i,
|
||||
str(img_vals[i][j]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
4,
|
||||
1.5,
|
||||
"Fragment\nobrazu",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
color=GRAY5,
|
||||
)
|
||||
|
||||
ax.text(
|
||||
2.5,
|
||||
0.5,
|
||||
"(-1)·50 + 0·50 + 1·200 +\n"
|
||||
"(-1)·50 + 0·50 + 1·200 +\n"
|
||||
"(-1)·50 + 0·50 + 1·200\n= 450 (krawędź!)",
|
||||
ha="center",
|
||||
fontsize=FS_TINY,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GREEN_ACCENT},
|
||||
)
|
||||
|
||||
ax.axis("off")
|
||||
|
||||
# --- Panel 3: Vector visualization ---
|
||||
ax = axes[2]
|
||||
# Draw two vectors
|
||||
ax.quiver(
|
||||
0,
|
||||
0,
|
||||
3,
|
||||
4,
|
||||
angles="xy",
|
||||
scale_units="xy",
|
||||
scale=1,
|
||||
color=ACCENT,
|
||||
width=0.025,
|
||||
label="a = [3, 4]",
|
||||
)
|
||||
ax.quiver(
|
||||
0,
|
||||
0,
|
||||
4,
|
||||
1,
|
||||
angles="xy",
|
||||
scale_units="xy",
|
||||
scale=1,
|
||||
color=RED_ACCENT,
|
||||
width=0.025,
|
||||
label="b = [4, 1]",
|
||||
)
|
||||
|
||||
# Show angle
|
||||
theta = np.linspace(np.arctan2(1, 4), np.arctan2(4, 3), 30)
|
||||
r = 1.5
|
||||
ax.plot(r * np.cos(theta), r * np.sin(theta), color=GREEN_ACCENT, linewidth=1.5)
|
||||
ax.text(1.8, 1.3, "θ", fontsize=FS, color=GREEN_ACCENT, fontweight="bold")
|
||||
|
||||
ax.text(3.2, 4.2, "a", fontsize=FS, color=ACCENT, fontweight="bold")
|
||||
ax.text(4.2, 1.2, "b", fontsize=FS, color=RED_ACCENT, fontweight="bold")
|
||||
|
||||
ax.text(
|
||||
2.5,
|
||||
-1.0,
|
||||
"a · b = |a|·|b|·cos(θ)\n= 3·4 + 4·1 = 16",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round", "facecolor": GRAY1, "edgecolor": GRAY3},
|
||||
)
|
||||
ax.text(
|
||||
2.5,
|
||||
-2.0,
|
||||
'Mały kąt θ → duży dot product\n= wektory „zgadają się"',
|
||||
ha="center",
|
||||
fontsize=FS_TINY,
|
||||
color=GRAY5,
|
||||
)
|
||||
|
||||
ax.set_xlim(-0.5, 5.5)
|
||||
ax.set_ylim(-2.5, 5.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.grid(visible=True, alpha=0.2)
|
||||
ax.legend(fontsize=FS_SMALL, loc="upper left")
|
||||
ax.set_title("Geometrycznie: kąt", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
_save_figure("q23_dot_product.png")
|
||||
@ -0,0 +1,408 @@
|
||||
"""Otsu thresholding and watershed diagram generators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _q23_common import (
|
||||
_RIDGE_X,
|
||||
_VALLEY2_END,
|
||||
ACCENT,
|
||||
ACCENT_LIGHT,
|
||||
BLACK,
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
GREEN_ACCENT,
|
||||
RED_ACCENT,
|
||||
_render_text_lines,
|
||||
_save_figure,
|
||||
np,
|
||||
plt,
|
||||
rng,
|
||||
)
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def _draw_otsu_variance_panel(ax: Axes) -> None:
|
||||
"""Draw panel 2: within-class variance explanation."""
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title("Wariancja wewnątrzklasowa", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
texts = [
|
||||
(
|
||||
"Wariancja = jak bardzo wartości\nróżnią się od średniej",
|
||||
FS,
|
||||
"black",
|
||||
"normal",
|
||||
),
|
||||
("", 0, "black", "normal"),
|
||||
("Klasa 0 (piksele ≤ T):", FS, ACCENT, "bold"),
|
||||
(" wartości: 30, 50, 45, 60, 55", FS_SMALL, "black", "normal"),
|
||||
(" średnia μ₀ = 48", FS_SMALL, "black", "normal"),
|
||||
(" σ₀² = ((30-48)²+(50-48)²+...)/5 = 108", FS_SMALL, "black", "normal"),
|
||||
("", 0, "black", "normal"),
|
||||
("Klasa 1 (piksele > T):", FS, RED_ACCENT, "bold"),
|
||||
(" wartości: 180, 200, 190, 210, 195", FS_SMALL, "black", "normal"),
|
||||
(" średnia μ₁ = 195", FS_SMALL, "black", "normal"),
|
||||
(" σ₁² = ((180-195)²+...)/5 = 100", FS_SMALL, "black", "normal"),
|
||||
("", 0, "black", "normal"),
|
||||
("σ²_wewnątrz = w₀·σ₀² + w₁·σ₁²", FS, BLACK, "bold"),
|
||||
("= 0.6·108 + 0.4·100 = 104.8", FS_SMALL, "black", "normal"),
|
||||
("", 0, "black", "normal"),
|
||||
("Otsu próbuje KAŻDE T: 0,1,...,255", FS_SMALL, GREEN_ACCENT, "bold"),
|
||||
("Wybiera T dające MINIMUM σ²_wewnątrz", FS_SMALL, GREEN_ACCENT, "bold"),
|
||||
]
|
||||
_render_text_lines(
|
||||
ax,
|
||||
texts,
|
||||
x_pos=0.3,
|
||||
start_y=9.2,
|
||||
y_step=0.55,
|
||||
y_empty_step=0.25,
|
||||
)
|
||||
|
||||
|
||||
def generate_otsu_bimodal() -> None:
|
||||
"""Generate otsu bimodal."""
|
||||
_fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
|
||||
|
||||
# --- Panel 1: Bimodal histogram ---
|
||||
ax = axes[0]
|
||||
dark = rng.normal(60, 20, 3000).clip(0, 255)
|
||||
bright = rng.normal(190, 25, 2000).clip(0, 255)
|
||||
all_pixels = np.concatenate([dark, bright])
|
||||
|
||||
counts, _bins, _bars = ax.hist(
|
||||
all_pixels, bins=64, color=GRAY3, edgecolor=GRAY5, linewidth=0.5
|
||||
)
|
||||
ax.axvline(
|
||||
x=128, color=RED_ACCENT, linewidth=2, linestyle="--", label="Próg Otsu T=128"
|
||||
)
|
||||
ax.fill_betweenx([0, max(counts) * 1.1], 0, 128, alpha=0.12, color=ACCENT)
|
||||
ax.fill_betweenx([0, max(counts) * 1.1], 128, 255, alpha=0.12, color=RED_ACCENT)
|
||||
ax.text(
|
||||
45,
|
||||
max(counts) * 0.85,
|
||||
"Klasa 0\n(tło)",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
color=ACCENT,
|
||||
)
|
||||
ax.text(
|
||||
195,
|
||||
max(counts) * 0.85,
|
||||
"Klasa 1\n(obiekt)",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
color=RED_ACCENT,
|
||||
)
|
||||
ax.annotate(
|
||||
"Garb 1",
|
||||
xy=(60, max(counts) * 0.6),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
arrowprops={"arrowstyle": "->", "color": GRAY5},
|
||||
xytext=(30, max(counts) * 0.45),
|
||||
)
|
||||
ax.annotate(
|
||||
"Garb 2",
|
||||
xy=(190, max(counts) * 0.5),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
arrowprops={"arrowstyle": "->", "color": GRAY5},
|
||||
xytext=(220, max(counts) * 0.35),
|
||||
)
|
||||
ax.set_xlabel("Jasność piksela (0-255)", fontsize=FS)
|
||||
ax.set_ylabel("Liczba pikseli", fontsize=FS)
|
||||
ax.set_title("Histogram bimodalny", fontsize=FS_TITLE, fontweight="bold")
|
||||
ax.legend(fontsize=FS_SMALL, loc="upper right")
|
||||
ax.set_xlim(0, 255)
|
||||
|
||||
# --- Panel 2: Within-class variance explanation ---
|
||||
_draw_otsu_variance_panel(axes[1])
|
||||
|
||||
# --- Panel 3: Jednorodność explanation ---
|
||||
ax = axes[2]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title('"Jednorodne" = małe σ²', fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
# Draw two clusters
|
||||
# Good separation
|
||||
c0 = rng.normal(2, 0.4, 15)
|
||||
c1 = rng.normal(7, 0.4, 15)
|
||||
y_pos_0 = rng.uniform(6, 8, 15)
|
||||
y_pos_1 = rng.uniform(6, 8, 15)
|
||||
ax.scatter(c0, y_pos_0, c=ACCENT, s=30, zorder=5, label="Klasa 0")
|
||||
ax.scatter(c1, y_pos_1, c=RED_ACCENT, s=30, zorder=5, label="Klasa 1")
|
||||
ax.axvline(x=4.5, color=GREEN_ACCENT, linewidth=2, linestyle="--")
|
||||
ax.text(
|
||||
4.5,
|
||||
8.8,
|
||||
"T optymalny",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
color=GREEN_ACCENT,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
2, 5.3, "σ₀² mała\n(skupione)", ha="center", fontsize=FS_SMALL, color=ACCENT
|
||||
)
|
||||
ax.text(
|
||||
7, 5.3, "σ₁² mała\n(skupione)", ha="center", fontsize=FS_SMALL, color=RED_ACCENT
|
||||
)
|
||||
ax.text(
|
||||
5,
|
||||
4,
|
||||
"→ σ²_wewnątrz MINIMALNA\n→ klasy JEDNORODNE\n→ dobra segmentacja!",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
color=GREEN_ACCENT,
|
||||
)
|
||||
|
||||
# Bad separation
|
||||
c0b = rng.normal(3.5, 1.5, 15)
|
||||
c1b = rng.normal(6, 1.5, 15)
|
||||
y_pos_0b = rng.uniform(1, 3, 15)
|
||||
y_pos_1b = rng.uniform(1, 3, 15)
|
||||
ax.scatter(c0b, y_pos_0b, c=ACCENT, s=30, marker="x", zorder=5)
|
||||
ax.scatter(c1b, y_pos_1b, c=RED_ACCENT, s=30, marker="x", zorder=5)
|
||||
ax.axvline(x=4.5, color=GRAY4, linewidth=1, linestyle=":", ymin=0, ymax=0.35)
|
||||
ax.text(
|
||||
5,
|
||||
0.3,
|
||||
"σ²_wewnątrz DUŻA → klasy mieszają się → zły próg",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
color=GRAY5,
|
||||
)
|
||||
|
||||
ax.legend(fontsize=FS_SMALL, loc="upper left")
|
||||
|
||||
_save_figure("q23_otsu_bimodal.png")
|
||||
|
||||
|
||||
def _draw_watershed_result_panel(ax: Axes) -> None:
|
||||
"""Draw panel 3: watershed result with over-segmentation problem."""
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title("Krok 3: wynik", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
rect1 = FancyBboxPatch(
|
||||
(0.5, 6),
|
||||
3.5,
|
||||
3.2,
|
||||
boxstyle="round,pad=0.1",
|
||||
facecolor=ACCENT_LIGHT,
|
||||
edgecolor=BLACK,
|
||||
linewidth=1,
|
||||
)
|
||||
ax.add_patch(rect1)
|
||||
ax.text(2.25, 8.8, "Ideał: 2 segmenty", fontsize=FS, ha="center", fontweight="bold")
|
||||
ax.text(2.25, 7.5, "Segment A Segment B", fontsize=FS_SMALL, ha="center")
|
||||
ax.text(
|
||||
2.25,
|
||||
6.7,
|
||||
"(po marker-controlled)",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
color=GREEN_ACCENT,
|
||||
)
|
||||
|
||||
rect2 = FancyBboxPatch(
|
||||
(5.5, 6),
|
||||
4,
|
||||
3.2,
|
||||
boxstyle="round,pad=0.1",
|
||||
facecolor="#FFCDD2",
|
||||
edgecolor=BLACK,
|
||||
linewidth=1,
|
||||
)
|
||||
ax.add_patch(rect2)
|
||||
ax.text(
|
||||
7.5,
|
||||
8.8,
|
||||
"Problem: over-segmentation",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color=RED_ACCENT,
|
||||
)
|
||||
ax.text(
|
||||
7.5,
|
||||
7.8,
|
||||
"47 regionów zamiast 2!",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
color=RED_ACCENT,
|
||||
)
|
||||
ax.text(7.5, 7.1, "Każde mini-minimum", fontsize=FS_SMALL, ha="center")
|
||||
ax.text(7.5, 6.5, '→ osobna „dolina"', fontsize=FS_SMALL, ha="center")
|
||||
|
||||
# Apply marker-controlled solution
|
||||
rect3 = FancyBboxPatch(
|
||||
(1, 0.5),
|
||||
8,
|
||||
4.5,
|
||||
boxstyle="round,pad=0.15",
|
||||
facecolor=GRAY1,
|
||||
edgecolor=GREEN_ACCENT,
|
||||
linewidth=1.5,
|
||||
)
|
||||
ax.add_patch(rect3)
|
||||
ax.text(
|
||||
5,
|
||||
4.3,
|
||||
"Rozwiązanie: Marker-controlled watershed",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color=GREEN_ACCENT,
|
||||
)
|
||||
ax.text(
|
||||
5,
|
||||
3.4,
|
||||
'1. Zaznacz ręcznie „seeds" (markery) w każdym obiekcie',
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
)
|
||||
ax.text(
|
||||
5,
|
||||
2.7,
|
||||
"2. Zalewaj TYLKO od tych markerów (nie od wszystkich minimów)",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
)
|
||||
ax.text(
|
||||
5,
|
||||
2.0,
|
||||
"3. Eliminuje fałszywe doliny z szumu",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
)
|
||||
ax.text(
|
||||
5,
|
||||
1.2,
|
||||
"Wynik: tyle segmentów, ile podano markerów",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
|
||||
def generate_watershed() -> None:
|
||||
"""Generate watershed."""
|
||||
_fig, axes = plt.subplots(1, 3, figsize=(11, 3.8))
|
||||
|
||||
# --- Panel 1: Image as topographic surface ---
|
||||
ax = axes[0]
|
||||
x = np.linspace(0, 10, 200)
|
||||
# Create a surface with two valleys and a ridge
|
||||
surface = (
|
||||
3 * np.exp(-((x - 3) ** 2) / 1.5)
|
||||
+ 4 * np.exp(-((x - 7) ** 2) / 1.2)
|
||||
+ 0.5 * np.sin(x * 2)
|
||||
+ 1
|
||||
)
|
||||
# Invert: valleys at objects (dark), peaks at boundaries (bright)
|
||||
surface_inv = 6 - surface + 1
|
||||
|
||||
ax.fill_between(x, 0, surface_inv, color=GRAY2, alpha=0.7)
|
||||
ax.plot(x, surface_inv, color=BLACK, linewidth=1.5)
|
||||
|
||||
# Mark valleys
|
||||
ax.annotate(
|
||||
"Dolina 1\n(obiekt A)",
|
||||
xy=(3, surface_inv[60]),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
arrowprops={"arrowstyle": "->", "color": ACCENT},
|
||||
xytext=(1.5, 5.5),
|
||||
)
|
||||
ax.annotate(
|
||||
"Dolina 2\n(obiekt B)",
|
||||
xy=(7, surface_inv[140]),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
arrowprops={"arrowstyle": "->", "color": RED_ACCENT},
|
||||
xytext=(8.5, 5.5),
|
||||
)
|
||||
# Mark ridge
|
||||
ax.annotate(
|
||||
"Grań\n(granica)",
|
||||
xy=(5, surface_inv[100]),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT},
|
||||
xytext=(5, 6.5),
|
||||
)
|
||||
|
||||
ax.set_xlabel("Pozycja piksela", fontsize=FS)
|
||||
ax.set_ylabel("Jasność (= wysokość)", fontsize=FS)
|
||||
ax.set_title("Krok 1: obraz → teren", fontsize=FS_TITLE, fontweight="bold")
|
||||
ax.set_ylim(0, 7)
|
||||
|
||||
# --- Panel 2: Flooding ---
|
||||
ax = axes[1]
|
||||
ax.fill_between(x, 0, surface_inv, color=GRAY2, alpha=0.7)
|
||||
ax.plot(x, surface_inv, color=BLACK, linewidth=1.5)
|
||||
|
||||
# Water level
|
||||
water_level = 3.2
|
||||
|
||||
# Fill water in valley 1
|
||||
x_v1 = x[(x > 1) & (x < _RIDGE_X)]
|
||||
s_v1 = surface_inv[(x > 1) & (x < _RIDGE_X)]
|
||||
ax.fill_between(
|
||||
x_v1, s_v1, water_level, where=s_v1 < water_level, color=ACCENT_LIGHT, alpha=0.6
|
||||
)
|
||||
# Fill water in valley 2
|
||||
x_v2 = x[(x > _RIDGE_X) & (x < _VALLEY2_END)]
|
||||
s_v2 = surface_inv[(x > _RIDGE_X) & (x < _VALLEY2_END)]
|
||||
ax.fill_between(
|
||||
x_v2, s_v2, water_level, where=s_v2 < water_level, color="#FFCDD2", alpha=0.6
|
||||
)
|
||||
|
||||
ax.axhline(y=water_level, color=ACCENT, linewidth=1, linestyle="--", alpha=0.5)
|
||||
ax.text(3, 2.5, "Woda A", fontsize=FS, ha="center", color=ACCENT, fontweight="bold")
|
||||
ax.text(
|
||||
7, 2.2, "Woda B", fontsize=FS, ha="center", color=RED_ACCENT, fontweight="bold"
|
||||
)
|
||||
ax.annotate(
|
||||
"Tu się spotkają!\n→ GRANICA",
|
||||
xy=(5, surface_inv[100]),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
color=GREEN_ACCENT,
|
||||
fontweight="bold",
|
||||
arrowprops={"arrowstyle": "->", "color": GREEN_ACCENT},
|
||||
xytext=(5, 6.2),
|
||||
)
|
||||
|
||||
ax.set_xlabel("Pozycja piksela", fontsize=FS)
|
||||
ax.set_title("Krok 2: zalewanie", fontsize=FS_TITLE, fontweight="bold")
|
||||
ax.set_ylim(0, 7)
|
||||
|
||||
# --- Panel 3: Result with problem ---
|
||||
_draw_watershed_result_panel(axes[2])
|
||||
|
||||
_save_figure("q23_watershed.png")
|
||||
@ -0,0 +1,286 @@
|
||||
"""Receptive field and transformer diagram generators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q23_common import (
|
||||
_HIGHLIGHT_END,
|
||||
_HIGHLIGHT_START,
|
||||
ACCENT,
|
||||
ACCENT_LIGHT,
|
||||
BLACK,
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY3,
|
||||
GRAY5,
|
||||
GREEN_ACCENT,
|
||||
RED_ACCENT,
|
||||
WHITE,
|
||||
_save_figure,
|
||||
plt,
|
||||
)
|
||||
from matplotlib import patches
|
||||
|
||||
|
||||
def generate_receptive_field() -> None:
|
||||
"""Generate receptive field."""
|
||||
_fig, axes = plt.subplots(1, 3, figsize=(11, 4))
|
||||
|
||||
def draw_grid(
|
||||
ax: patches.Axes,
|
||||
size: int,
|
||||
highlight_cells: list[tuple[int, int]],
|
||||
highlight_color: str,
|
||||
title: str,
|
||||
grid_offset: tuple[int, int] = (0, 0),
|
||||
) -> None:
|
||||
"""Draw grid."""
|
||||
ox, oy = grid_offset
|
||||
for i in range(size):
|
||||
for j in range(size):
|
||||
color = WHITE
|
||||
if (i, j) in highlight_cells:
|
||||
color = highlight_color
|
||||
rect = patches.Rectangle(
|
||||
(ox + j, oy + size - 1 - i),
|
||||
1,
|
||||
1,
|
||||
facecolor=color,
|
||||
edgecolor=GRAY3, # Use GRAY3 instead of GRAY4 since unused
|
||||
linewidth=0.5,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.set_title(title, fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
# --- Panel 1: Standard 3x3 conv receptive field ---
|
||||
ax = axes[0]
|
||||
ax.set_xlim(-0.5, 7.5)
|
||||
ax.set_ylim(-1, 8)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
|
||||
# 7x7 input grid
|
||||
highlight_3x3 = [
|
||||
(2, 2),
|
||||
(2, 3),
|
||||
(2, 4),
|
||||
(3, 2),
|
||||
(3, 3),
|
||||
(3, 4),
|
||||
(4, 2),
|
||||
(4, 3),
|
||||
(4, 4),
|
||||
]
|
||||
draw_grid(ax, 7, highlight_3x3, ACCENT_LIGHT, "Zwykła conv 3x3")
|
||||
ax.text(
|
||||
3.5,
|
||||
-0.5,
|
||||
"RF = 3x3 pikseli",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color=ACCENT,
|
||||
)
|
||||
|
||||
# --- Panel 2: Dilated conv (rate=2) ---
|
||||
ax = axes[1]
|
||||
ax.set_xlim(-0.5, 7.5)
|
||||
ax.set_ylim(-1, 8)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
|
||||
# 7x7 input grid with dilated highlights
|
||||
highlight_dilated = [
|
||||
(1, 1),
|
||||
(1, 3),
|
||||
(1, 5),
|
||||
(3, 1),
|
||||
(3, 3),
|
||||
(3, 5),
|
||||
(5, 1),
|
||||
(5, 3),
|
||||
(5, 5),
|
||||
]
|
||||
draw_grid(ax, 7, highlight_dilated, "#FFCDD2", "Dilated conv 3x3\n(rate=2)")
|
||||
ax.text(
|
||||
3.5,
|
||||
-0.5,
|
||||
"RF = 5x5, ale 9 parametrów!",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color=RED_ACCENT,
|
||||
)
|
||||
|
||||
# Connect dots to show pattern
|
||||
dots_x = [1.5, 3.5, 5.5, 1.5, 3.5, 5.5, 1.5, 3.5, 5.5]
|
||||
dots_y = [5.5, 5.5, 5.5, 3.5, 3.5, 3.5, 1.5, 1.5, 1.5]
|
||||
ax.scatter(dots_x, dots_y, c=RED_ACCENT, s=30, zorder=5)
|
||||
|
||||
# --- Panel 3: Comparison ---
|
||||
ax = axes[2]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Receptive Field\n(pole widzenia neuronu)", fontsize=FS_TITLE, fontweight="bold"
|
||||
)
|
||||
|
||||
y = 8.5
|
||||
lines = [
|
||||
("RF = ile pikseli WEJŚCIOWYCH", FS, BLACK, "bold"),
|
||||
("wpływa na JEDEN piksel wyjścia", FS, BLACK, "bold"),
|
||||
("", 0, "", ""),
|
||||
("Rate (współczynnik dylatacji):", FS, BLACK, "bold"),
|
||||
(' rate=1: filtr „dotyka" sąsiadów', FS_SMALL, BLACK, "normal"),
|
||||
(" rate=2: co drugi piksel → RF = 5x5", FS_SMALL, BLACK, "normal"),
|
||||
(" rate=3: co trzeci → RF = 7x7", FS_SMALL, BLACK, "normal"),
|
||||
(" WIĘCEJ kontekstu, TE SAME wagi!", FS_SMALL, GREEN_ACCENT, "bold"),
|
||||
("", 0, "", ""),
|
||||
("Dlaczego ważne w segmentacji?", FS, BLACK, "bold"),
|
||||
(" Piksel sam nie wie czym jest.", FS_SMALL, BLACK, "normal"),
|
||||
(" Potrzebuje KONTEKSTU (otoczenia).", FS_SMALL, BLACK, "normal"),
|
||||
(" Większe RF → widzi obok budynki", FS_SMALL, BLACK, "normal"),
|
||||
(' → wie, że TEN piksel to „droga"', FS_SMALL, GREEN_ACCENT, "bold"),
|
||||
("", 0, "", ""),
|
||||
("Global Average Pooling:", FS, BLACK, "bold"),
|
||||
(" Mapa HxWxC → 1x1xC", FS_SMALL, BLACK, "normal"),
|
||||
(" Średnia z CAŁEGO feature map", FS_SMALL, BLACK, "normal"),
|
||||
(" RF = nieskończone (cały obraz)", FS_SMALL, GREEN_ACCENT, "bold"),
|
||||
]
|
||||
for txt, size, color, weight in lines:
|
||||
if txt == "":
|
||||
y -= 0.2
|
||||
continue
|
||||
ax.text(0.5, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
|
||||
y -= 0.45
|
||||
|
||||
_save_figure("q23_receptive_field.png")
|
||||
|
||||
|
||||
def generate_transformer() -> None:
|
||||
"""Generate transformer."""
|
||||
_fig, axes = plt.subplots(1, 3, figsize=(11, 4))
|
||||
|
||||
# --- Panel 1: CNN local vs Transformer global ---
|
||||
ax = axes[0]
|
||||
ax.set_xlim(-0.5, 8.5)
|
||||
ax.set_ylim(-1.5, 8.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("CNN: widzi LOKALNIE", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
# Draw 8x8 grid
|
||||
for i in range(8):
|
||||
for j in range(8):
|
||||
color = WHITE
|
||||
if (
|
||||
_HIGHLIGHT_START <= i <= _HIGHLIGHT_END
|
||||
and _HIGHLIGHT_START <= j <= _HIGHLIGHT_END
|
||||
):
|
||||
color = ACCENT_LIGHT
|
||||
rect = patches.Rectangle(
|
||||
(j, 7 - i), 1, 1, facecolor=color, edgecolor=GRAY3, linewidth=0.3
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
|
||||
# Highlight center
|
||||
rect = patches.Rectangle(
|
||||
(4, 4), 1, 1, facecolor=RED_ACCENT, edgecolor=BLACK, linewidth=1.5, alpha=0.7
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
4.5,
|
||||
4.5,
|
||||
"?",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
color=WHITE,
|
||||
)
|
||||
ax.text(
|
||||
4.5,
|
||||
-0.8,
|
||||
"Filtr 3x3 widzi tylko\n9 sąsiednich pikseli",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
color=ACCENT,
|
||||
)
|
||||
|
||||
# --- Panel 2: Transformer global ---
|
||||
ax = axes[1]
|
||||
ax.set_xlim(-0.5, 8.5)
|
||||
ax.set_ylim(-1.5, 8.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Transformer: widzi GLOBALNIE", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
# Draw 8x8 grid all highlighted
|
||||
for i in range(8):
|
||||
for j in range(8):
|
||||
color = "#FFCDD2"
|
||||
rect = patches.Rectangle(
|
||||
(j, 7 - i), 1, 1, facecolor=color, edgecolor=GRAY3, linewidth=0.3
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
|
||||
rect = patches.Rectangle(
|
||||
(4, 4), 1, 1, facecolor=RED_ACCENT, edgecolor=BLACK, linewidth=1.5, alpha=0.9
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
4.5,
|
||||
4.5,
|
||||
"?",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
color=WHITE,
|
||||
)
|
||||
ax.text(
|
||||
4.5,
|
||||
-0.8,
|
||||
'Self-attention „pyta"\nALL 64 piksele naraz',
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
color=RED_ACCENT,
|
||||
)
|
||||
|
||||
# --- Panel 3: SOTA + Transformer explanation ---
|
||||
ax = axes[2]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title("Transformer & SOTA", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
y = 9.2
|
||||
lines = [
|
||||
("Transformer:", FS, BLACK, "bold"),
|
||||
(" Architektura z 2017 (Vaswani et al.)", FS_SMALL, BLACK, "normal"),
|
||||
(" Oryginalnie do NLP (tłumaczenie)", FS_SMALL, BLACK, "normal"),
|
||||
(" Kluczowy mechanizm: SELF-ATTENTION", FS_SMALL, ACCENT, "bold"),
|
||||
("", 0, "", ""),
|
||||
("Self-attention w skrócie:", FS, BLACK, "bold"),
|
||||
(" Każdy piksel tworzy trzy wektory:", FS_SMALL, BLACK, "normal"),
|
||||
(' Q (Query — „czego szukam?")', FS_SMALL, ACCENT, "normal"),
|
||||
(' K (Key — „co oferuję innych")', FS_SMALL, RED_ACCENT, "normal"),
|
||||
(' V (Value — „moja wartość")', FS_SMALL, GREEN_ACCENT, "normal"),
|
||||
(" Attention = softmax(Q·Kᵀ/√d)·V", FS_SMALL, BLACK, "bold"),
|
||||
(" Koszt: O(n²) — n=liczba pikseli", FS_SMALL, RED_ACCENT, "normal"),
|
||||
("", 0, "", ""),
|
||||
("SOTA = State Of The Art:", FS, BLACK, "bold"),
|
||||
(" Najlepszy znany wynik na benchmarku", FS_SMALL, BLACK, "normal"),
|
||||
(' Np. „mIoU 85.1% na ADE20K = SOTA"', FS_SMALL, BLACK, "normal"),
|
||||
(" Ciągle się zmienia (nowy paper", FS_SMALL, GRAY5, "normal"),
|
||||
(" → nowy SOTA)", FS_SMALL, GRAY5, "normal"),
|
||||
]
|
||||
for txt, size, color, weight in lines:
|
||||
if txt == "":
|
||||
y -= 0.15
|
||||
continue
|
||||
ax.text(0.3, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
|
||||
y -= 0.45
|
||||
|
||||
_save_figure("q23_transformer_attention.png")
|
||||
@ -0,0 +1,408 @@
|
||||
"""Region growing and DIY thresholding diagram generators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _q23_common import (
|
||||
_BRIGHT_THRESHOLD,
|
||||
_OTSU_THRESHOLD,
|
||||
ACCENT,
|
||||
ACCENT_LIGHT,
|
||||
BLACK,
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TINY,
|
||||
FS_TITLE,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
GREEN_ACCENT,
|
||||
RED_ACCENT,
|
||||
WHITE,
|
||||
_save_figure,
|
||||
np,
|
||||
plt,
|
||||
rng,
|
||||
)
|
||||
from matplotlib import patches
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def _draw_region_growing_grid(ax: Axes) -> None:
|
||||
"""Draw panel 2: region growing step-by-step grid."""
|
||||
ax.set_xlim(-0.5, 6.5)
|
||||
ax.set_ylim(-1.5, 7.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Region Growing: krok po kroku",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
pixel_grid = np.array(
|
||||
[
|
||||
[150, 153, 148, 200, 210, 205],
|
||||
[147, 155, 152, 195, 208, 200],
|
||||
[145, 148, 160, 190, 195, 210],
|
||||
[200, 195, 190, 155, 148, 150],
|
||||
[210, 205, 200, 150, 152, 145],
|
||||
[215, 208, 195, 148, 147, 155],
|
||||
]
|
||||
)
|
||||
region_mask = np.array(
|
||||
[
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
]
|
||||
)
|
||||
|
||||
for i in range(6):
|
||||
for j in range(6):
|
||||
v = pixel_grid[i, j]
|
||||
if region_mask[i, j] == 1 and v < _BRIGHT_THRESHOLD:
|
||||
cell_color = ACCENT_LIGHT
|
||||
elif region_mask[i, j] == 1:
|
||||
cell_color = "#E0E0E0"
|
||||
else:
|
||||
cell_color = WHITE
|
||||
if i == 1 and j == 1:
|
||||
cell_color = "#FFD54F"
|
||||
rect = patches.Rectangle(
|
||||
(j, 5 - i),
|
||||
1,
|
||||
1,
|
||||
facecolor=cell_color,
|
||||
edgecolor=GRAY4,
|
||||
linewidth=0.5,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
j + 0.5,
|
||||
5 - i + 0.5,
|
||||
str(v),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_TINY,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.annotate(
|
||||
"SEED\n(155)",
|
||||
xy=(1.5, 4.5),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
color=RED_ACCENT,
|
||||
fontweight="bold",
|
||||
arrowprops={"arrowstyle": "->", "color": RED_ACCENT},
|
||||
xytext=(-0.5, 7),
|
||||
)
|
||||
ax.text(
|
||||
3,
|
||||
-0.8,
|
||||
"Próg = 20\nNiebieski = region (|val - seed| < 20)",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
color=ACCENT,
|
||||
)
|
||||
|
||||
|
||||
def _draw_bfs_expansion(ax: Axes) -> None:
|
||||
"""Draw panel 3: BFS expansion visualization."""
|
||||
ax.set_xlim(-0.5, 6.5)
|
||||
ax.set_ylim(-1.5, 7.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Rosnący region (BFS)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
wave_colors = ["#FFD54F", "#FFF176", "#FFF9C4", ACCENT_LIGHT, "#B3D4FC"]
|
||||
wave_labels = ["Seed", "Fala 1", "Fala 2", "Fala 3", "Fala 4"]
|
||||
waves = [
|
||||
[(1, 1)],
|
||||
[(0, 1), (1, 0), (1, 2), (2, 1)],
|
||||
[(0, 0), (0, 2), (2, 0), (2, 2)],
|
||||
]
|
||||
|
||||
for i in range(6):
|
||||
for j in range(6):
|
||||
cell_color = WHITE
|
||||
for w_idx, wave in enumerate(waves):
|
||||
if (i, j) in wave:
|
||||
cell_color = wave_colors[w_idx]
|
||||
rect = patches.Rectangle(
|
||||
(j, 5 - i),
|
||||
1,
|
||||
1,
|
||||
facecolor=cell_color,
|
||||
edgecolor=GRAY4,
|
||||
linewidth=0.5,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
|
||||
seed_x, seed_y = 1.5, 4.5
|
||||
for dx, dy, _label in [
|
||||
(0, 1, ""),
|
||||
(0, -1, ""),
|
||||
(1, 0, ""),
|
||||
(-1, 0, ""),
|
||||
]:
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(seed_x + dx * 0.7, seed_y + dy * 0.7),
|
||||
xytext=(seed_x, seed_y),
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"color": RED_ACCENT,
|
||||
"lw": 1.2,
|
||||
},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
3,
|
||||
-0.5,
|
||||
"BFS: sprawdzaj sąsiadów,\ndodawaj podobne do kolejki",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
color=GRAY5,
|
||||
)
|
||||
|
||||
for w_idx, (wave_color, label) in enumerate(
|
||||
zip(wave_colors[:3], wave_labels[:3], strict=False)
|
||||
):
|
||||
rect = patches.Rectangle(
|
||||
(4, 6.5 - w_idx * 0.7),
|
||||
0.5,
|
||||
0.5,
|
||||
facecolor=wave_color,
|
||||
edgecolor=GRAY4,
|
||||
linewidth=0.5,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
4.8,
|
||||
6.75 - w_idx * 0.7,
|
||||
label,
|
||||
fontsize=FS_TINY,
|
||||
va="center",
|
||||
)
|
||||
|
||||
|
||||
def generate_region_growing() -> None:
|
||||
"""Generate region growing."""
|
||||
_fig, axes = plt.subplots(1, 3, figsize=(11, 4.2))
|
||||
|
||||
# --- Panel 1: Manual vs automatic seed ---
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 10)
|
||||
ax.axis("off")
|
||||
ax.set_title("Seed: ręcznie vs automatycznie", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
y = 9.2
|
||||
lines = [
|
||||
("Ręczny seed:", FS, ACCENT, "bold"),
|
||||
(" Użytkownik klika na obraz", FS_SMALL, BLACK, "normal"),
|
||||
(' → „tu jest obiekt, od tego zacznij"', FS_SMALL, BLACK, "normal"),
|
||||
(" Użycie: segmentacja interaktywna", FS_SMALL, GRAY5, "normal"),
|
||||
(" (np. Photoshop — magic wand tool)", FS_SMALL, GRAY5, "normal"),
|
||||
("", 0, "", ""),
|
||||
("Automatyczny seed:", FS, RED_ACCENT, "bold"),
|
||||
(" 1. Histogram → lokalne maxima", FS_SMALL, BLACK, "normal"),
|
||||
(" (najczęstsza jasność → seed)", FS_SMALL, GRAY5, "normal"),
|
||||
(" 2. Grid: siatka co N pikseli", FS_SMALL, BLACK, "normal"),
|
||||
(" (np. seed co 50 px → 100 seedów)", FS_SMALL, GRAY5, "normal"),
|
||||
(" 3. Losowe próbkowanie", FS_SMALL, BLACK, "normal"),
|
||||
(" 4. Ekstrema lokalne gradientu", FS_SMALL, BLACK, "normal"),
|
||||
("", 0, "", ""),
|
||||
("Dlaczego OR?", FS, GREEN_ACCENT, "bold"),
|
||||
(" Ręczny → precyzyjny, ale wolny", FS_SMALL, BLACK, "normal"),
|
||||
(" Auto → szybki, ale over-segmentation", FS_SMALL, BLACK, "normal"),
|
||||
]
|
||||
for txt, size, color, weight in lines:
|
||||
if txt == "":
|
||||
y -= 0.15
|
||||
continue
|
||||
ax.text(0.3, y, txt, fontsize=size, color=color, fontweight=weight, va="top")
|
||||
y -= 0.45
|
||||
|
||||
# --- Panel 2: Region growing step by step ---
|
||||
_draw_region_growing_grid(axes[1])
|
||||
|
||||
# --- Panel 3: BFS expansion ---
|
||||
_draw_bfs_expansion(axes[2])
|
||||
|
||||
_save_figure("q23_region_growing.png")
|
||||
|
||||
|
||||
def _draw_otsu_variance_and_pseudocode(
|
||||
ax_var: Axes,
|
||||
ax_code: Axes,
|
||||
img: np.ndarray,
|
||||
) -> int:
|
||||
"""Draw panels 4 and 5: Otsu variance plot and pseudocode."""
|
||||
thresholds = range(10, 245)
|
||||
variances = []
|
||||
for t in thresholds:
|
||||
c0 = img[img <= t].ravel()
|
||||
c1 = img[img > t].ravel()
|
||||
if len(c0) == 0 or len(c1) == 0:
|
||||
variances.append(np.nan)
|
||||
continue
|
||||
w0 = len(c0) / len(img.ravel())
|
||||
w1 = len(c1) / len(img.ravel())
|
||||
var = w0 * np.var(c0) + w1 * np.var(c1)
|
||||
variances.append(var)
|
||||
|
||||
ax_var.plot(list(thresholds), variances, color=ACCENT, linewidth=1.5)
|
||||
best_t = list(thresholds)[np.nanargmin(variances)]
|
||||
ax_var.axvline(
|
||||
x=best_t,
|
||||
color=RED_ACCENT,
|
||||
linewidth=1.5,
|
||||
linestyle="--",
|
||||
label=f"Otsu T={best_t}",
|
||||
)
|
||||
ax_var.scatter(
|
||||
[best_t],
|
||||
[np.nanmin(variances)],
|
||||
c=RED_ACCENT,
|
||||
s=60,
|
||||
zorder=5,
|
||||
)
|
||||
ax_var.set_xlabel("Próg T", fontsize=FS_SMALL)
|
||||
ax_var.set_ylabel("σ² wewnątrzklasowa", fontsize=FS_SMALL)
|
||||
ax_var.set_title(
|
||||
"Krok 4: Otsu szuka min σ²",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax_var.legend(fontsize=FS_TINY)
|
||||
|
||||
ax_code.set_xlim(0, 10)
|
||||
ax_code.set_ylim(0, 10)
|
||||
ax_code.axis("off")
|
||||
ax_code.set_title("Pseudokod Otsu", fontsize=FS, fontweight="bold")
|
||||
|
||||
code_lines = [
|
||||
"best_T = 0",
|
||||
"min_var = ∞",
|
||||
"",
|
||||
"for T in 0..255:",
|
||||
" c0 = piksele z jasność ≤ T",
|
||||
" c1 = piksele z jasność > T",
|
||||
" w0 = len(c0) / len(all)",
|
||||
" w1 = len(c1) / len(all)",
|
||||
" var = w0·var(c0) + w1·var(c1)",
|
||||
" if var < min_var:",
|
||||
" min_var = var",
|
||||
" best_T = T",
|
||||
"",
|
||||
"return best_T # optymalny próg",
|
||||
]
|
||||
for i, line in enumerate(code_lines):
|
||||
txt_color = ACCENT if "best_T = T" in line or "return" in line else BLACK
|
||||
ax_code.text(
|
||||
0.5,
|
||||
9.5 - i * 0.65,
|
||||
line,
|
||||
fontsize=FS_TINY,
|
||||
fontfamily="monospace",
|
||||
color=txt_color,
|
||||
fontweight="bold" if txt_color == ACCENT else "normal",
|
||||
)
|
||||
return int(best_t)
|
||||
|
||||
|
||||
def generate_diy_thresholding() -> None:
|
||||
"""Generate diy thresholding."""
|
||||
_fig, axes = plt.subplots(2, 3, figsize=(11, 7))
|
||||
|
||||
# Create a simple synthetic image: dark circle on bright background
|
||||
size = 64
|
||||
img = np.ones((size, size)) * 200 # bright background
|
||||
yy, xx = np.mgrid[:size, :size]
|
||||
mask = ((xx - 32) ** 2 + (yy - 32) ** 2) < 15**2
|
||||
img[mask] = 60 # dark circle
|
||||
# Add some noise
|
||||
img += rng.normal(0, 10, img.shape)
|
||||
img = np.clip(img, 0, 255)
|
||||
|
||||
# --- Panel 1: Original image ---
|
||||
ax = axes[0, 0]
|
||||
ax.imshow(img, cmap="gray", vmin=0, vmax=255)
|
||||
ax.set_title("Krok 1: obraz wejściowy", fontsize=FS, fontweight="bold")
|
||||
ax.axis("off")
|
||||
ax.text(32, -3, "64x64 pikseli, szare", fontsize=FS_TINY, ha="center")
|
||||
|
||||
# --- Panel 2: Histogram ---
|
||||
ax = axes[0, 1]
|
||||
counts, _bins, _ = ax.hist(
|
||||
img.ravel(), bins=50, color=GRAY3, edgecolor=GRAY5, linewidth=0.5
|
||||
)
|
||||
ax.axvline(
|
||||
x=128, color=RED_ACCENT, linewidth=2, linestyle="--", label="T=128 (Otsu)"
|
||||
)
|
||||
ax.set_xlabel("Jasność", fontsize=FS_SMALL)
|
||||
ax.set_ylabel("Piksele", fontsize=FS_SMALL)
|
||||
ax.set_title("Krok 2: histogram\n(bimodalny!)", fontsize=FS, fontweight="bold")
|
||||
ax.legend(fontsize=FS_TINY)
|
||||
ax.annotate(
|
||||
"Garb 1\n(obiekt)",
|
||||
xy=(60, max(counts) * 0.5),
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
color=ACCENT,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.annotate(
|
||||
"Garb 2\n(tło)",
|
||||
xy=(200, max(counts) * 0.5),
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
color=RED_ACCENT,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# --- Panel 3: Thresholding result ---
|
||||
ax = axes[0, 2]
|
||||
binary = (img > _OTSU_THRESHOLD).astype(float)
|
||||
ax.imshow(binary, cmap="gray", vmin=0, vmax=1)
|
||||
ax.set_title("Krok 3: progowanie T=128", fontsize=FS, fontweight="bold")
|
||||
ax.axis("off")
|
||||
ax.text(32, -3, "Biały = tło, Czarny = obiekt", fontsize=FS_TINY, ha="center")
|
||||
|
||||
# --- Panels 4+5: Otsu variance plot + pseudocode ---
|
||||
best_t = _draw_otsu_variance_and_pseudocode(
|
||||
axes[1, 0],
|
||||
axes[1, 1],
|
||||
img,
|
||||
)
|
||||
|
||||
# --- Panel 6: Final result with Otsu ---
|
||||
ax = axes[1, 2]
|
||||
binary_otsu = (img > best_t).astype(float)
|
||||
ax.imshow(binary_otsu, cmap="gray", vmin=0, vmax=1)
|
||||
ax.set_title(f"Krok 5: wynik Otsu (T={best_t})", fontsize=FS, fontweight="bold")
|
||||
ax.axis("off")
|
||||
ax.text(
|
||||
32,
|
||||
-3,
|
||||
"Automatyczny próg!",
|
||||
fontsize=FS_TINY,
|
||||
ha="center",
|
||||
color=GREEN_ACCENT,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
_save_figure("q23_diy_thresholding.png")
|
||||
@ -0,0 +1,186 @@
|
||||
"""Common utilities and constants for Q24 diagram generation.
|
||||
|
||||
Monochrome, A4-printable PNGs (300 DPI).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
|
||||
DPI = 300
|
||||
BG = "white"
|
||||
LN = "black"
|
||||
FS = 8
|
||||
FS_TITLE = 11
|
||||
FS_SMALL = 6.5
|
||||
FS_LABEL = 9
|
||||
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
|
||||
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
GRAY1 = "#E8E8E8"
|
||||
GRAY2 = "#D0D0D0"
|
||||
GRAY3 = "#B8B8B8"
|
||||
GRAY4 = "#F5F5F5"
|
||||
GRAY5 = "#C0C0C0"
|
||||
|
||||
_PIXEL_BRIGHT_THRESH = 127
|
||||
_GRADIENT_BRIGHT_THRESH = 100
|
||||
_DATA_BRIGHT_THRESH = 5
|
||||
_II_BRIGHT_THRESH = 25
|
||||
_DOTS_STAGE_IDX = 2
|
||||
|
||||
|
||||
def draw_box(
|
||||
ax: Axes,
|
||||
x: float,
|
||||
y: float,
|
||||
w: float,
|
||||
h: float,
|
||||
text: str,
|
||||
*,
|
||||
fill: str = "white",
|
||||
lw: float = 1.2,
|
||||
fontsize: float = FS,
|
||||
fontweight: str = "normal",
|
||||
ha: str = "center",
|
||||
va: str = "center",
|
||||
rounded: bool = True,
|
||||
edgecolor: str = LN,
|
||||
linestyle: str = "-",
|
||||
) -> None:
|
||||
"""Draw box."""
|
||||
if rounded:
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.05",
|
||||
lw=lw,
|
||||
edgecolor=edgecolor,
|
||||
facecolor=fill,
|
||||
linestyle=linestyle,
|
||||
)
|
||||
else:
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
lw=lw,
|
||||
edgecolor=edgecolor,
|
||||
facecolor=fill,
|
||||
linestyle=linestyle,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h / 2,
|
||||
text,
|
||||
ha=ha,
|
||||
va=va,
|
||||
fontsize=fontsize,
|
||||
fontweight=fontweight,
|
||||
wrap=True,
|
||||
)
|
||||
|
||||
|
||||
def draw_arrow(
|
||||
ax: Axes,
|
||||
x1: float,
|
||||
y1: float,
|
||||
x2: float,
|
||||
y2: float,
|
||||
*,
|
||||
lw: float = 1.2,
|
||||
style: str = "->",
|
||||
color: str = LN,
|
||||
) -> None:
|
||||
"""Draw arrow."""
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(x2, y2),
|
||||
xytext=(x1, y1),
|
||||
arrowprops={"arrowstyle": style, "color": color, "lw": lw},
|
||||
)
|
||||
|
||||
|
||||
def save_fig(fig: Figure, name: str) -> None:
|
||||
"""Save fig."""
|
||||
path = str(Path(OUTPUT_DIR) / name)
|
||||
fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15)
|
||||
plt.close(fig)
|
||||
_logger.info(" Saved: %s", path)
|
||||
|
||||
|
||||
def draw_table(
|
||||
ax: Axes,
|
||||
headers: list[str],
|
||||
rows: list[list[str]],
|
||||
x0: float,
|
||||
y0: float,
|
||||
col_widths: list[float],
|
||||
*,
|
||||
row_h: float = 0.4,
|
||||
header_fill: str = GRAY2,
|
||||
row_fills: list[str] | None = None,
|
||||
fontsize: float = FS,
|
||||
header_fontsize: float | None = None,
|
||||
) -> None:
|
||||
"""Draw table."""
|
||||
if header_fontsize is None:
|
||||
header_fontsize = fontsize
|
||||
len(headers)
|
||||
cx = x0
|
||||
for j, hdr in enumerate(headers):
|
||||
draw_box(
|
||||
ax,
|
||||
cx,
|
||||
y0,
|
||||
col_widths[j],
|
||||
row_h,
|
||||
hdr,
|
||||
fill=header_fill,
|
||||
fontsize=header_fontsize,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
cx += col_widths[j]
|
||||
for i, row in enumerate(rows):
|
||||
cy = y0 - (i + 1) * row_h
|
||||
cx = x0
|
||||
fill = GRAY4 if (i % 2 == 0) else "white"
|
||||
if row_fills and i < len(row_fills):
|
||||
fill = row_fills[i]
|
||||
for j, cell in enumerate(row):
|
||||
fw = "bold" if j == 0 else "normal"
|
||||
draw_box(
|
||||
ax,
|
||||
cx,
|
||||
cy,
|
||||
col_widths[j],
|
||||
row_h,
|
||||
cell,
|
||||
fill=fill,
|
||||
fontsize=fontsize,
|
||||
fontweight=fw,
|
||||
rounded=False,
|
||||
)
|
||||
cx += col_widths[j]
|
||||
@ -0,0 +1,412 @@
|
||||
"""FPN, anchor boxes, detection tasks, and CNN architecture diagrams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q24_common import (
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
np,
|
||||
plt,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 16. FPN (Feature Pyramid Network)
|
||||
# ============================================================
|
||||
def draw_fpn() -> None:
|
||||
"""Draw fpn."""
|
||||
fig, ax = plt.subplots(figsize=(9, 5))
|
||||
ax.set_xlim(-0.5, 9.5)
|
||||
ax.set_ylim(-0.5, 5.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"FPN (Feature Pyramid Network) — detekcja obiektów wszystkich rozmiarów",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
levels = [
|
||||
(0, 0, 2.0, 2.0, "C2\n56x56", "duże\ndetale"),
|
||||
(0, 2.2, 1.5, 1.5, "C3\n28x28", ""),
|
||||
(0, 3.9, 1.0, 1.0, "C4\n14x14", ""),
|
||||
(0, 5.1, 0.6, 0.6, "C5\n7x7", "kontekst"),
|
||||
]
|
||||
|
||||
for x, y, w, h, label, note in levels:
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((x, y - h), w, h, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y - h / 2,
|
||||
label,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
if note:
|
||||
ax.text(
|
||||
x + w + 0.15,
|
||||
y - h / 2,
|
||||
note,
|
||||
ha="left",
|
||||
va="center",
|
||||
fontsize=5,
|
||||
style="italic",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
1.0, -0.3, "Bottom-up\n(backbone)", ha="center", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
# Top-down + lateral
|
||||
td_levels = [
|
||||
(4.5, 5.1, 0.6, 0.6, "P5"),
|
||||
(4.5, 3.9, 1.0, 1.0, "P4"),
|
||||
(4.5, 2.2, 1.5, 1.5, "P3"),
|
||||
(4.5, 0, 2.0, 2.0, "P2"),
|
||||
]
|
||||
|
||||
for x, y, w, h, label in td_levels:
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(x, y - h + h), w, h, facecolor=GRAY2, edgecolor=LN, lw=1.5
|
||||
)
|
||||
)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y - h / 2 + h,
|
||||
label,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Lateral connections
|
||||
for (_, y1, w1, h1, _, _), (x2, y2, _w2, h2, _) in zip(
|
||||
levels, td_levels, strict=False
|
||||
):
|
||||
draw_arrow(ax, w1 + 0.2, y1 - h1 / 2, x2 - 0.1, y2 + h2 / 2, lw=1, style="->")
|
||||
|
||||
# Top-down arrows
|
||||
for i in range(len(td_levels) - 1):
|
||||
x2, y2, w2, h2, _ = td_levels[i]
|
||||
x3, y3, w3, h3, _ = td_levels[i + 1]
|
||||
draw_arrow(
|
||||
ax,
|
||||
x2 + w2 / 2,
|
||||
y2,
|
||||
x3 + w3 / 2,
|
||||
y3 + h3 + 0.1,
|
||||
lw=1.2,
|
||||
style="->",
|
||||
color=GRAY3,
|
||||
)
|
||||
|
||||
ax.text(
|
||||
5.5,
|
||||
-0.3,
|
||||
"Top-down + lateral\n(FPN)",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Detection outputs
|
||||
det_labels = ["małe obj.", "średnie", "duże", "b. duże"]
|
||||
for i, (x, y, w, h, _label) in enumerate(td_levels):
|
||||
draw_arrow(ax, x + w + 0.1, y + h / 2, 7.5, y + h / 2, lw=0.8)
|
||||
ax.text(
|
||||
7.7,
|
||||
y + h / 2,
|
||||
f"detekcja:\n{det_labels[3 - i]}",
|
||||
fontsize=FS_SMALL,
|
||||
va="center",
|
||||
)
|
||||
|
||||
save_fig(fig, "q24_fpn.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 17. Anchor boxes
|
||||
# ============================================================
|
||||
def draw_anchor_boxes() -> None:
|
||||
"""Draw anchor boxes."""
|
||||
fig, ax = plt.subplots(figsize=(7, 5))
|
||||
ax.set_title(
|
||||
"Anchor boxes — predefiniowane kształty",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1))
|
||||
|
||||
# Center point
|
||||
cx, cy = 3, 2.5
|
||||
ax.plot(cx, cy, "ko", markersize=8, zorder=5)
|
||||
ax.text(cx + 0.15, cy + 0.15, "(x, y)", fontsize=FS, fontweight="bold")
|
||||
|
||||
# 9 anchors: 3 sizes x 3 ratios
|
||||
anchors = [
|
||||
(0.8, 0.8, "-", "1:1 small"),
|
||||
(1.6, 1.6, "-", "1:1 medium"),
|
||||
(2.4, 2.4, "-", "1:1 large"),
|
||||
(0.6, 1.2, "--", "1:2 small"),
|
||||
(1.2, 2.4, "--", "1:2 medium"),
|
||||
(1.8, 3.6, "--", "1:2 large"),
|
||||
(1.2, 0.6, ":", "2:1 small"),
|
||||
(2.4, 1.2, ":", "2:1 medium"),
|
||||
(3.6, 1.8, ":", "2:1 large"),
|
||||
]
|
||||
|
||||
for w, h, ls, _label in anchors:
|
||||
rect = mpatches.Rectangle(
|
||||
(cx - w / 2, cy - h / 2),
|
||||
w,
|
||||
h,
|
||||
facecolor="none",
|
||||
edgecolor=LN,
|
||||
lw=1.2,
|
||||
linestyle=ls,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
|
||||
# Legend-style labels
|
||||
ax.text(
|
||||
3,
|
||||
-0.5,
|
||||
"9 anchorów = 3 rozmiary x 3 proporcje (1:1, 1:2, 2:1)\n"
|
||||
"Sieć predykuje PRZESUNIĘCIE od najbliższego anchora",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.set_xlim(-0.5, 6.5)
|
||||
ax.set_ylim(-1.2, 5.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
|
||||
save_fig(fig, "q24_anchor_boxes.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 18. Detection task comparison
|
||||
# ============================================================
|
||||
def draw_detection_tasks() -> None:
|
||||
"""Draw detection tasks."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
|
||||
fig.suptitle(
|
||||
"Klasyfikacja vs Detekcja vs Segmentacja",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
|
||||
# Classification
|
||||
ax = axes[0]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
# Simple cat silhouette
|
||||
ax.add_patch(mpatches.Ellipse((2, 2), 2, 1.5, facecolor=GRAY3, edgecolor=LN, lw=1))
|
||||
ax.add_patch(mpatches.Ellipse((2, 3), 1, 0.8, facecolor=GRAY3, edgecolor=LN, lw=1))
|
||||
# Ears
|
||||
ax.plot([1.6, 1.5, 1.8], [3.3, 3.8, 3.4], color=LN, lw=1.5)
|
||||
ax.plot([2.2, 2.5, 2.4], [3.3, 3.8, 3.4], color=LN, lw=1.5)
|
||||
ax.text(
|
||||
2, -0.4, '→ "KOT" (jedna etykieta)', ha="center", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
ax.set_xlim(-0.5, 4.5)
|
||||
ax.set_ylim(-0.8, 4.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Klasyfikacja\n(co?)", fontsize=FS, fontweight="bold")
|
||||
|
||||
# Detection
|
||||
ax = axes[1]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
# Cat
|
||||
ax.add_patch(
|
||||
mpatches.Ellipse((1.2, 2), 1.2, 1, facecolor=GRAY3, edgecolor=LN, lw=1)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Ellipse((1.2, 2.8), 0.7, 0.5, facecolor=GRAY3, edgecolor=LN, lw=1)
|
||||
)
|
||||
# Dog
|
||||
ax.add_patch(
|
||||
mpatches.Ellipse((3, 1.5), 1.2, 1, facecolor=GRAY2, edgecolor=LN, lw=1)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Ellipse((3, 2.3), 0.7, 0.5, facecolor=GRAY2, edgecolor=LN, lw=1)
|
||||
)
|
||||
# Bounding boxes
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0.3, 1.2), 1.8, 2.2, facecolor="none", edgecolor=LN, lw=2.5)
|
||||
)
|
||||
ax.text(1.2, 3.5, "KOT", ha="center", fontsize=FS_SMALL, fontweight="bold")
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((2.1, 0.8), 1.7, 2.0, facecolor="none", edgecolor=LN, lw=2.5)
|
||||
)
|
||||
ax.text(3.0, 2.9, "PIES", ha="center", fontsize=FS_SMALL, fontweight="bold")
|
||||
ax.text(
|
||||
2,
|
||||
-0.4,
|
||||
"→ bbox + klasa (N obiektów)",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_xlim(-0.5, 4.5)
|
||||
ax.set_ylim(-0.8, 4.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Detekcja\n(co? + gdzie?)", fontsize=FS, fontweight="bold")
|
||||
|
||||
# Segmentation
|
||||
ax = axes[2]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 4, 4, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
# Cat mask (detailed)
|
||||
theta = np.linspace(0, 2 * np.pi, 30)
|
||||
cat_x = 1.2 + 0.6 * np.cos(theta) + 0.1 * np.sin(3 * theta)
|
||||
cat_y = 2 + 0.5 * np.sin(theta) + 0.1 * np.cos(2 * theta)
|
||||
ax.fill(cat_x, cat_y, facecolor=GRAY3, edgecolor=LN, lw=1.5)
|
||||
# Dog mask
|
||||
dog_x = 3.0 + 0.6 * np.cos(theta) + 0.05 * np.sin(4 * theta)
|
||||
dog_y = 1.5 + 0.5 * np.sin(theta) + 0.08 * np.cos(3 * theta)
|
||||
ax.fill(dog_x, dog_y, facecolor=GRAY2, edgecolor=LN, lw=1.5)
|
||||
ax.text(1.2, 2, "KOT", ha="center", fontsize=FS_SMALL, fontweight="bold")
|
||||
ax.text(3.0, 1.5, "PIES", ha="center", fontsize=FS_SMALL, fontweight="bold")
|
||||
ax.text(
|
||||
2,
|
||||
-0.4,
|
||||
"→ maska pikseli (per piksel)",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_xlim(-0.5, 4.5)
|
||||
ax.set_ylim(-0.8, 4.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Segmentacja\n(dokładna maska)", fontsize=FS, fontweight="bold")
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q24_detection_tasks.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 19. CNN Architecture overview
|
||||
# ============================================================
|
||||
def draw_cnn_architecture() -> None:
|
||||
"""Draw cnn architecture."""
|
||||
fig, ax = plt.subplots(figsize=(12, 4))
|
||||
ax.set_xlim(-0.5, 12.5)
|
||||
ax.set_ylim(-1, 4.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"CNN — od obrazu do predykcji (architektura)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
# Input image
|
||||
draw_box(ax, 0, 0.5, 1.5, 3, "Obraz\n224x224x3", fill=GRAY1, fontsize=FS)
|
||||
|
||||
# Conv1
|
||||
draw_arrow(ax, 1.6, 2.0, 2.1, 2.0, lw=1.2)
|
||||
draw_box(
|
||||
ax, 2.2, 0.8, 1.2, 2.4, "Conv1\n+ReLU\n55x55x96", fill=GRAY4, fontsize=FS_SMALL
|
||||
)
|
||||
|
||||
# Pool1
|
||||
draw_arrow(ax, 3.5, 2.0, 3.9, 2.0, lw=1.2)
|
||||
draw_box(ax, 4.0, 1.0, 1.0, 2.0, "Pool\n27x27\nx96", fill=GRAY2, fontsize=FS_SMALL)
|
||||
|
||||
# Conv2
|
||||
draw_arrow(ax, 5.1, 2.0, 5.5, 2.0, lw=1.2)
|
||||
draw_box(
|
||||
ax,
|
||||
5.6,
|
||||
0.8,
|
||||
1.2,
|
||||
2.4,
|
||||
"Conv2\n+ReLU\n27x27\nx256",
|
||||
fill=GRAY4,
|
||||
fontsize=FS_SMALL,
|
||||
)
|
||||
|
||||
# Pool2
|
||||
draw_arrow(ax, 6.9, 2.0, 7.3, 2.0, lw=1.2)
|
||||
draw_box(ax, 7.4, 1.2, 0.8, 1.6, "Pool\n13x13\nx256", fill=GRAY2, fontsize=FS_SMALL)
|
||||
|
||||
# More conv...
|
||||
draw_arrow(ax, 8.3, 2.0, 8.7, 2.0, lw=1.2)
|
||||
ax.text(9.0, 2.0, "...", fontsize=14, ha="center", va="center")
|
||||
draw_arrow(ax, 9.3, 2.0, 9.7, 2.0, lw=1.2)
|
||||
|
||||
# FC
|
||||
draw_box(ax, 9.8, 1.2, 1.0, 1.6, "FC\n4096", fill=GRAY3, fontsize=FS)
|
||||
|
||||
draw_arrow(ax, 10.9, 2.0, 11.3, 2.0, lw=1.2)
|
||||
|
||||
# Output
|
||||
draw_box(
|
||||
ax, 11.4, 1.5, 1.0, 1.0, "Softmax\n1000 klas", fill=GRAY1, fontsize=FS_SMALL
|
||||
)
|
||||
|
||||
# Annotations below
|
||||
ax.text(
|
||||
3.0,
|
||||
0.0,
|
||||
"rozmiar maleje\n224→55→27→13→6",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
style="italic",
|
||||
)
|
||||
ax.text(
|
||||
6.0,
|
||||
0.0,
|
||||
"kanały rosną\n3→96→256→384",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
style="italic",
|
||||
)
|
||||
ax.text(
|
||||
10.0, 0.0, "decyzja\nkońcowa", ha="center", fontsize=FS_SMALL, style="italic"
|
||||
)
|
||||
|
||||
# hierarchy
|
||||
ax.text(
|
||||
6.0,
|
||||
4.0,
|
||||
"Hierarchia: krawędzie → rogi → fragmenty → obiekty (K-R-F-O)",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
save_fig(fig, "q24_cnn_architecture.png")
|
||||
@ -0,0 +1,342 @@
|
||||
"""Haar features, integral image, and SVM hyperplane diagrams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _q24_common import (
|
||||
_DATA_BRIGHT_THRESH,
|
||||
_II_BRIGHT_THRESH,
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
LN,
|
||||
np,
|
||||
plt,
|
||||
rng,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 4. Haar Features
|
||||
# ============================================================
|
||||
def draw_haar_features() -> None:
|
||||
"""Draw haar features."""
|
||||
fig, axes = plt.subplots(1, 4, figsize=(11, 3))
|
||||
fig.suptitle(
|
||||
"Cechy Haar — typy i zastosowanie na twarzy",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
|
||||
# Feature 1: Vertical edge
|
||||
ax = axes[0]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 1, 2, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((1, 0), 1, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.text(
|
||||
0.5, 1, "+Σ₁", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold"
|
||||
)
|
||||
ax.text(
|
||||
1.5, 1, "-Σ₂", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold"
|
||||
)
|
||||
ax.set_xlim(-0.2, 2.2)
|
||||
ax.set_ylim(-0.5, 2.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Krawędź pionowa\nwartość = Σ₁ - Σ₂", fontsize=FS)
|
||||
|
||||
# Feature 2: Horizontal edge
|
||||
ax = axes[1]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 1), 2, 1, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 2, 1, facecolor=GRAY3, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.text(
|
||||
1, 1.5, "+Σ₁", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold"
|
||||
)
|
||||
ax.text(
|
||||
1, 0.5, "-Σ₂", ha="center", va="center", fontsize=FS_LABEL, fontweight="bold"
|
||||
)
|
||||
ax.set_xlim(-0.2, 2.2)
|
||||
ax.set_ylim(-0.5, 2.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Krawędź pozioma\n(oczy vs czoło)", fontsize=FS)
|
||||
|
||||
# Feature 3: Three-rectangle (line)
|
||||
ax = axes[2]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 0.7, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0.7, 0), 0.7, 2, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((1.4, 0), 0.7, 2, facecolor=GRAY3, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.text(
|
||||
0.35, 1, "-Σ₁", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold"
|
||||
)
|
||||
ax.text(
|
||||
1.05, 1, "+Σ₂", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold"
|
||||
)
|
||||
ax.text(
|
||||
1.75, 1, "-Σ₃", ha="center", va="center", fontsize=FS_SMALL, fontweight="bold"
|
||||
)
|
||||
ax.set_xlim(-0.2, 2.3)
|
||||
ax.set_ylim(-0.5, 2.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Linia (3 prostokąty)\n(nos vs policzki)", fontsize=FS)
|
||||
|
||||
_draw_haar_face_panel(axes[3])
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q24_haar_features.png")
|
||||
|
||||
|
||||
def _draw_haar_face_panel(ax: Axes) -> None:
|
||||
"""Draw Haar feature application on face schematic."""
|
||||
face = mpatches.Ellipse(
|
||||
(1.2, 1.2),
|
||||
2.0,
|
||||
2.4,
|
||||
facecolor=GRAY4,
|
||||
edgecolor=LN,
|
||||
lw=1.5,
|
||||
)
|
||||
ax.add_patch(face)
|
||||
ax.add_patch(
|
||||
mpatches.Ellipse((0.7, 1.6), 0.4, 0.2, facecolor=GRAY3, edgecolor=LN, lw=1)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Ellipse((1.7, 1.6), 0.4, 0.2, facecolor=GRAY3, edgecolor=LN, lw=1)
|
||||
)
|
||||
ax.plot([1.2, 1.1, 1.3], [1.3, 0.9, 0.9], color=LN, lw=1)
|
||||
ax.plot([0.8, 1.0, 1.2, 1.4, 1.6], [0.55, 0.5, 0.55, 0.5, 0.55], color=LN, lw=1)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0.3, 1.4),
|
||||
1.8,
|
||||
0.4,
|
||||
facecolor="none",
|
||||
edgecolor=LN,
|
||||
lw=2,
|
||||
linestyle="--",
|
||||
)
|
||||
)
|
||||
ax.annotate(
|
||||
"cechy Haar\n(oczy ciemne\nvs czoło jasne)",
|
||||
xy=(1.2, 1.85),
|
||||
xytext=(2.2, 2.3),
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
arrowprops={"arrowstyle": "->", "color": LN, "lw": 1},
|
||||
)
|
||||
ax.set_xlim(-0.2, 3.0)
|
||||
ax.set_ylim(-0.2, 2.8)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Zastosowanie na twarzy", fontsize=FS)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 5. Integral Image
|
||||
# ============================================================
|
||||
def draw_integral_image() -> None:
|
||||
"""Draw integral image."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
|
||||
fig.suptitle(
|
||||
"Integral Image — suma prostokąta w O(1)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
|
||||
# Original image
|
||||
ax = axes[0]
|
||||
data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
ax.imshow(data, cmap="gray", vmin=0, vmax=10)
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
ax.text(
|
||||
j,
|
||||
i,
|
||||
str(data[i, j]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=12,
|
||||
fontweight="bold",
|
||||
color="white" if data[i, j] > _DATA_BRIGHT_THRESH else "black",
|
||||
)
|
||||
ax.set_title("① Obraz oryginalny", fontsize=FS, fontweight="bold")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Integral image
|
||||
ax = axes[1]
|
||||
ii = np.array([[1, 3, 6], [5, 12, 21], [12, 27, 45]])
|
||||
ax.imshow(ii, cmap="gray", vmin=0, vmax=50)
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
ax.text(
|
||||
j,
|
||||
i,
|
||||
str(ii[i, j]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=12,
|
||||
fontweight="bold",
|
||||
color="white" if ii[i, j] > _II_BRIGHT_THRESH else "black",
|
||||
)
|
||||
ax.set_title("② Integral Image\n(sumy kumulatywne)", fontsize=FS, fontweight="bold")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Formula illustration
|
||||
ax = axes[2]
|
||||
ax.axis("off")
|
||||
ax.set_xlim(0, 4)
|
||||
ax.set_ylim(0, 4)
|
||||
# Draw rectangle
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0.5, 0.5), 3, 3, facecolor="white", edgecolor=LN, lw=1)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((1.5, 0.5), 2, 2, facecolor=GRAY3, edgecolor=LN, lw=2)
|
||||
)
|
||||
# Labels
|
||||
ax.text(0.3, 3.7, "A", fontsize=12, fontweight="bold")
|
||||
ax.text(3.6, 3.7, "B", fontsize=12, fontweight="bold")
|
||||
ax.text(0.3, 0.3, "C", fontsize=12, fontweight="bold")
|
||||
ax.text(3.6, 0.3, "D", fontsize=12, fontweight="bold")
|
||||
ax.text(
|
||||
2.5,
|
||||
1.5,
|
||||
"SZUKANA\nSUMA",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
2.0,
|
||||
-0.3,
|
||||
"Suma = D - B - C + A\n= 4 odczyty → O(1) ZAWSZE!",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
ax.set_title(
|
||||
"③ Formuła: 4 odczyty\n= O(1) niezależnie od rozmiaru",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q24_integral_image.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 11. SVM Hyperplane
|
||||
# ============================================================
|
||||
def draw_svm_hyperplane() -> None:
|
||||
"""Draw svm hyperplane."""
|
||||
fig, ax = plt.subplots(figsize=(6, 5))
|
||||
ax.set_title(
|
||||
"SVM — hiperpłaszczyzna i margines",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
x_pos = rng.standard_normal(15) * 0.5 + 3
|
||||
y_pos = rng.standard_normal(15) * 0.5 + 3
|
||||
ax.scatter(
|
||||
x_pos,
|
||||
y_pos,
|
||||
marker="o",
|
||||
s=50,
|
||||
facecolors="white",
|
||||
edgecolors=LN,
|
||||
linewidths=1.5,
|
||||
label="klasa +1 (pieszy)",
|
||||
zorder=3,
|
||||
)
|
||||
|
||||
x_neg = rng.standard_normal(15) * 0.5 + 1
|
||||
y_neg = rng.standard_normal(15) * 0.5 + 1
|
||||
ax.scatter(
|
||||
x_neg,
|
||||
y_neg,
|
||||
marker="x",
|
||||
s=50,
|
||||
c=LN,
|
||||
linewidths=1.5,
|
||||
label="klasa -1 (tło)",
|
||||
zorder=3,
|
||||
)
|
||||
|
||||
# Hyperplane (decision boundary)
|
||||
x_line = np.linspace(-0.5, 5, 100)
|
||||
y_line = -x_line + 4.0
|
||||
ax.plot(x_line, y_line, "k-", lw=2, label="hiperpłaszczyzna")
|
||||
|
||||
# Margin lines
|
||||
ax.plot(x_line, y_line + 0.7, "k--", lw=1, alpha=0.5)
|
||||
ax.plot(x_line, y_line - 0.7, "k--", lw=1, alpha=0.5)
|
||||
|
||||
# Margin annotation
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(2.5, 1.5 + 0.7),
|
||||
xytext=(2.5, 1.5 - 0.7),
|
||||
arrowprops={"arrowstyle": "<->", "color": LN, "lw": 1.5},
|
||||
)
|
||||
ax.text(2.8, 1.5, "margines\n(MAX!)", fontsize=FS, fontweight="bold")
|
||||
|
||||
# Support vectors (highlight closest points)
|
||||
ax.scatter(
|
||||
[2.5],
|
||||
[2.2],
|
||||
marker="o",
|
||||
s=120,
|
||||
facecolors="none",
|
||||
edgecolors=LN,
|
||||
linewidths=2.5,
|
||||
zorder=4,
|
||||
)
|
||||
ax.scatter([1.5], [1.8], marker="x", s=120, c=LN, linewidths=2.5, zorder=4)
|
||||
ax.annotate(
|
||||
"support\nvectors",
|
||||
xy=(1.5, 1.8),
|
||||
xytext=(0.2, 3.0),
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
arrowprops={"arrowstyle": "->", "color": LN, "lw": 1},
|
||||
)
|
||||
|
||||
ax.set_xlim(-0.5, 5)
|
||||
ax.set_ylim(-0.5, 5)
|
||||
ax.set_xlabel("cecha 1 (np. gradient pionowy)", fontsize=FS)
|
||||
ax.set_ylabel("cecha 2 (np. gradient poziomy)", fontsize=FS)
|
||||
ax.legend(fontsize=FS_SMALL, loc="lower right")
|
||||
ax.set_aspect("equal")
|
||||
|
||||
save_fig(fig, "q24_svm_hyperplane.png")
|
||||
@ -0,0 +1,380 @@
|
||||
"""HOG + SVM pipeline, HOG gradient steps, Viola-Jones cascade."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q24_common import (
|
||||
_DOTS_STAGE_IDX,
|
||||
_GRADIENT_BRIGHT_THRESH,
|
||||
_PIXEL_BRIGHT_THRESH,
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
np,
|
||||
plt,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 1. HOG + SVM Pipeline
|
||||
# ============================================================
|
||||
def draw_hog_svm_pipeline() -> None:
|
||||
"""Draw hog svm pipeline."""
|
||||
fig, ax = plt.subplots(figsize=(10, 4.5))
|
||||
ax.set_xlim(-0.5, 10.5)
|
||||
ax.set_ylim(-1, 4.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"HOG + SVM — pipeline detekcji pieszych",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
# Step 1: Image with sliding window
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 1.5), 2, 2, lw=1.5, edgecolor=LN, facecolor=GRAY1)
|
||||
)
|
||||
ax.text(1, 2.5, "Obraz\nwejściowy", ha="center", va="center", fontsize=FS)
|
||||
# sliding window overlay
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0.3, 1.8),
|
||||
0.8,
|
||||
1.2,
|
||||
lw=1.5,
|
||||
edgecolor="black",
|
||||
facecolor="none",
|
||||
linestyle="--",
|
||||
)
|
||||
)
|
||||
ax.text(
|
||||
0.7,
|
||||
1.35,
|
||||
"okno 64x128",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_SMALL,
|
||||
style="italic",
|
||||
)
|
||||
|
||||
draw_arrow(ax, 2.1, 2.5, 2.8, 2.5, lw=1.5)
|
||||
ax.text(2.45, 2.75, "①", ha="center", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Step 2: Gradient computation
|
||||
draw_box(
|
||||
ax, 2.9, 1.8, 1.6, 1.4, "Oblicz\ngradienty\nGx, Gy", fill=GRAY4, fontsize=FS
|
||||
)
|
||||
ax.text(
|
||||
3.7, 1.55, "kierunek + siła", ha="center", fontsize=FS_SMALL, style="italic"
|
||||
)
|
||||
|
||||
draw_arrow(ax, 4.6, 2.5, 5.2, 2.5, lw=1.5)
|
||||
ax.text(4.9, 2.75, "②", ha="center", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Step 3: HOG histogram
|
||||
draw_box(
|
||||
ax,
|
||||
5.3,
|
||||
1.8,
|
||||
1.6,
|
||||
1.4,
|
||||
"Histogramy\nkierunkowe\n9 binów/cel",
|
||||
fill=GRAY4,
|
||||
fontsize=FS,
|
||||
)
|
||||
ax.text(6.1, 1.55, "komórki 8x8 px", ha="center", fontsize=FS_SMALL, style="italic")
|
||||
|
||||
draw_arrow(ax, 7.0, 2.5, 7.6, 2.5, lw=1.5)
|
||||
ax.text(7.3, 2.75, "③", ha="center", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Step 4: SVM
|
||||
draw_box(
|
||||
ax,
|
||||
7.7,
|
||||
1.8,
|
||||
1.4,
|
||||
1.4,
|
||||
"SVM\nklasyfikator\npieszy/tło",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
draw_arrow(ax, 9.2, 2.5, 9.7, 2.5, lw=1.5)
|
||||
ax.text(9.45, 2.75, "④", ha="center", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Step 5: NMS + output
|
||||
draw_box(ax, 9.3, 2.0, 1.0, 1.0, "NMS\n→ wynik", fill=GRAY1, fontsize=FS)
|
||||
|
||||
# Bottom: HOG feature vector illustration
|
||||
ax.text(
|
||||
5.0,
|
||||
0.7,
|
||||
"Wektor HOG: 3780 cech = 105 bloków x 4 komórki x 9 binów",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
# Show small histogram bars
|
||||
bar_x = 3.2
|
||||
bar_y = 0.0
|
||||
angles = [0, 20, 40, 60, 80, 100, 120, 140, 160]
|
||||
values = [0.3, 0.1, 0.5, 0.8, 0.2, 0.6, 0.15, 0.4, 0.25]
|
||||
for i, (_a, v) in enumerate(zip(angles, values, strict=False)):
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(bar_x + i * 0.18, bar_y),
|
||||
0.15,
|
||||
v * 0.6,
|
||||
facecolor=GRAY3,
|
||||
edgecolor=LN,
|
||||
lw=0.5,
|
||||
)
|
||||
)
|
||||
ax.text(bar_x + 0.8, -0.2, "9 binów (0°-160°)", ha="center", fontsize=FS_SMALL)
|
||||
|
||||
save_fig(fig, "q24_hog_svm_pipeline.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 2. HOG Gradient Step-by-Step
|
||||
# ============================================================
|
||||
def draw_hog_gradient_steps() -> None:
|
||||
"""Draw hog gradient steps."""
|
||||
fig, axes = plt.subplots(1, 4, figsize=(12, 3.5))
|
||||
fig.suptitle(
|
||||
"HOG — kroki obliczania cech", fontsize=FS_TITLE, fontweight="bold", y=1.02
|
||||
)
|
||||
|
||||
# Step 1: Original patch
|
||||
ax = axes[0]
|
||||
patch = np.array([[50, 50, 200], [50, 50, 200], [50, 50, 200]])
|
||||
ax.imshow(patch, cmap="gray", vmin=0, vmax=255)
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
ax.text(
|
||||
j,
|
||||
i,
|
||||
str(patch[i, j]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
color="white" if patch[i, j] > _PIXEL_BRIGHT_THRESH else "black",
|
||||
)
|
||||
ax.set_title("① Fragment obrazu\n(jasność pikseli)", fontsize=FS, fontweight="bold")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Step 2: Gradient magnitude
|
||||
ax = axes[1]
|
||||
gx = np.array([[0, 150, 0], [0, 150, 0], [0, 150, 0]])
|
||||
ax.imshow(gx, cmap="gray", vmin=0, vmax=255)
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
ax.text(
|
||||
j,
|
||||
i,
|
||||
str(gx[i, j]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
color="white" if gx[i, j] > _GRADIENT_BRIGHT_THRESH else "black",
|
||||
)
|
||||
ax.set_title("② Gradient Gx\n(krawędź pionowa!)", fontsize=FS, fontweight="bold")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Step 3: Cell histogram
|
||||
ax = axes[2]
|
||||
angles = ["0°", "20°", "40°", "60°", "80°", "100°", "120°", "140°", "160°"]
|
||||
values = [150, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
bars = ax.bar(range(9), values, color=GRAY3, edgecolor=LN, linewidth=0.5)
|
||||
bars[0].set_facecolor(GRAY5)
|
||||
ax.set_xticks(range(9))
|
||||
ax.set_xticklabels(angles, fontsize=5, rotation=45)
|
||||
ax.set_title(
|
||||
"③ Histogram komórki\n(bin 0° = krawędź pionowa)",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_ylabel("siła", fontsize=FS_SMALL)
|
||||
|
||||
# Step 4: Block normalization
|
||||
ax = axes[3]
|
||||
# 2x2 block of cells
|
||||
for i in range(2):
|
||||
for j in range(2):
|
||||
rect = mpatches.Rectangle(
|
||||
(j * 1.2, (1 - i) * 1.2),
|
||||
1.0,
|
||||
1.0,
|
||||
lw=1.2,
|
||||
edgecolor=LN,
|
||||
facecolor=GRAY4,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
j * 1.2 + 0.5,
|
||||
(1 - i) * 1.2 + 0.5,
|
||||
f"hist\n{i * 2 + j + 1}",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS_SMALL,
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(-0.1, -0.1), 2.6, 2.6, lw=2, edgecolor=LN, facecolor="none", linestyle="--"
|
||||
)
|
||||
)
|
||||
ax.text(
|
||||
1.2,
|
||||
-0.4,
|
||||
"blok 2x2 → L2-norm",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_xlim(-0.3, 2.8)
|
||||
ax.set_ylim(-0.7, 2.8)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"④ Normalizacja bloków\n(odporność na oświetlenie)",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q24_hog_gradient_steps.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 3. Viola-Jones Cascade
|
||||
# ============================================================
|
||||
def draw_viola_jones_cascade() -> None:
|
||||
"""Draw viola jones cascade."""
|
||||
fig, ax = plt.subplots(figsize=(10, 5))
|
||||
ax.set_xlim(-0.5, 10.5)
|
||||
ax.set_ylim(-1.5, 5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Viola-Jones — kaskada klasyfikatorów (SITO)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
# Input
|
||||
draw_box(
|
||||
ax,
|
||||
-0.3,
|
||||
2.5,
|
||||
1.5,
|
||||
1.2,
|
||||
"500 000\nokien",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
stages = [
|
||||
("Etap 1\n2 cechy", "50%\nodrzucone", "250 000", GRAY4),
|
||||
("Etap 2\n10 cech", "80%\nodrzucone", "50 000", GRAY4),
|
||||
("Etap 3\n25 cech", "90%\nodrzucone", "5 000", GRAY4),
|
||||
("Etap 25\n200 cech", "99%\nodrzucone", "50", GRAY3),
|
||||
]
|
||||
|
||||
x_pos = 1.6
|
||||
for i, (label, reject, remain, col) in enumerate(stages):
|
||||
# Stage box
|
||||
draw_box(
|
||||
ax, x_pos, 2.5, 1.6, 1.2, label, fill=col, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
# Arrow from previous
|
||||
draw_arrow(ax, x_pos - 0.3, 3.1, x_pos - 0.05, 3.1, lw=1.5)
|
||||
|
||||
# Reject arrow down
|
||||
draw_arrow(ax, x_pos + 0.8, 2.45, x_pos + 0.8, 1.6, lw=1.2)
|
||||
ax.text(
|
||||
x_pos + 0.8,
|
||||
1.3,
|
||||
reject,
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
color="black",
|
||||
style="italic",
|
||||
)
|
||||
ax.text(
|
||||
x_pos + 0.8,
|
||||
0.8,
|
||||
"✗ NIE-TWARZ",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Remaining count above
|
||||
if i < len(stages) - 1:
|
||||
ax.text(
|
||||
x_pos + 2.0,
|
||||
3.9,
|
||||
f"→ {remain}",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Dots between stage 3 and stage 25
|
||||
if i == _DOTS_STAGE_IDX:
|
||||
ax.text(
|
||||
x_pos + 2.0, 3.1, "· · ·", ha="center", fontsize=12, fontweight="bold"
|
||||
)
|
||||
x_pos += 2.5
|
||||
else:
|
||||
x_pos += 2.1
|
||||
|
||||
# Final output
|
||||
draw_arrow(ax, x_pos + 0.3, 3.1, x_pos + 0.9, 3.1, lw=1.5)
|
||||
draw_box(
|
||||
ax,
|
||||
x_pos + 0.5,
|
||||
2.5,
|
||||
1.3,
|
||||
1.2,
|
||||
"~50\nTWARZE\n✓",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Timing info
|
||||
ax.text(
|
||||
5.0,
|
||||
-0.5,
|
||||
"Czas: 99% okien odrzucone w etapach 1-3 (~5 μs każde)\n"
|
||||
"Tylko 0.01% dochodzi do etapu 25 → cały obraz w ~30 ms = 30+ fps",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.4", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
save_fig(fig, "q24_viola_jones_cascade.png")
|
||||
@ -0,0 +1,413 @@
|
||||
"""IoU diagram, NMS steps, and detector-from-classifier diagrams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q24_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
plt,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 8. IoU Diagram
|
||||
# ============================================================
|
||||
def draw_iou_diagram() -> None:
|
||||
"""Draw iou diagram."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
|
||||
fig.suptitle(
|
||||
"IoU (Intersection over Union) — miara nakładania bboxów",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
|
||||
# Low IoU
|
||||
ax = axes[0]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5, label="A"
|
||||
)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(2.5, 2.5),
|
||||
3,
|
||||
3,
|
||||
facecolor=GRAY2,
|
||||
edgecolor=LN,
|
||||
lw=1.5,
|
||||
alpha=0.7,
|
||||
label="B",
|
||||
)
|
||||
)
|
||||
# Intersection
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((2.5, 2.5), 0.5, 0.5, facecolor=GRAY3, edgecolor=LN, lw=2)
|
||||
)
|
||||
ax.text(1.5, 1.5, "A", ha="center", va="center", fontsize=12, fontweight="bold")
|
||||
ax.text(4, 4, "B", ha="center", va="center", fontsize=12, fontweight="bold")
|
||||
ax.set_xlim(-0.5, 6)
|
||||
ax.set_ylim(-0.5, 6)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"IoU ≈ 0.04\n(prawie się nie nakładają)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
# Medium IoU
|
||||
ax = axes[1]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(1.5, 1.5), 3, 3, facecolor=GRAY2, edgecolor=LN, lw=1.5, alpha=0.7
|
||||
)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((1.5, 1.5), 1.5, 1.5, facecolor=GRAY3, edgecolor=LN, lw=2)
|
||||
)
|
||||
ax.text(0.7, 0.7, "A", ha="center", va="center", fontsize=12, fontweight="bold")
|
||||
ax.text(3.5, 3.5, "B", ha="center", va="center", fontsize=12, fontweight="bold")
|
||||
ax.text(2.25, 2.25, "∩", ha="center", va="center", fontsize=14, fontweight="bold")
|
||||
ax.set_xlim(-0.5, 5)
|
||||
ax.set_ylim(-0.5, 5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("IoU ≈ 0.14\n(częściowe nakładanie)", fontsize=FS, fontweight="bold")
|
||||
|
||||
# High IoU
|
||||
ax = axes[2]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 3, 3, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0.3, 0.3), 3, 3, facecolor=GRAY2, edgecolor=LN, lw=1.5, alpha=0.7
|
||||
)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0.3, 0.3), 2.7, 2.7, facecolor=GRAY3, edgecolor=LN, lw=2)
|
||||
)
|
||||
ax.text(-0.3, -0.3, "A", ha="center", va="center", fontsize=12, fontweight="bold")
|
||||
ax.text(3.5, 3.5, "B", ha="center", va="center", fontsize=12, fontweight="bold")
|
||||
ax.text(1.65, 1.65, "∩", ha="center", va="center", fontsize=14, fontweight="bold")
|
||||
ax.set_xlim(-0.8, 4)
|
||||
ax.set_ylim(-0.8, 4)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"IoU ≈ 0.74\n(duże nakładanie → duplikat!)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q24_iou_diagram.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 9. NMS Step-by-Step
|
||||
# ============================================================
|
||||
def draw_nms_steps() -> None:
|
||||
"""Draw nms steps."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
|
||||
fig.suptitle(
|
||||
"NMS (Non-Maximum Suppression) — usuwanie duplikatów",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
|
||||
# Before NMS
|
||||
ax = axes[0]
|
||||
ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1))
|
||||
# Multiple overlapping boxes for same object
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((1, 1), 2.5, 3, facecolor="none", edgecolor=LN, lw=2)
|
||||
)
|
||||
ax.text(2.25, 4.2, "conf=0.95", ha="center", fontsize=FS_SMALL, fontweight="bold")
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(1.2, 1.3), 2.3, 2.8, facecolor="none", edgecolor=LN, lw=1.5, linestyle="--"
|
||||
)
|
||||
)
|
||||
ax.text(2.35, 1.1, "conf=0.90", ha="center", fontsize=FS_SMALL)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0.8, 0.8), 2.7, 3.2, facecolor="none", edgecolor=LN, lw=1, linestyle=":"
|
||||
)
|
||||
)
|
||||
ax.text(2.15, 0.6, "conf=0.85", ha="center", fontsize=FS_SMALL)
|
||||
# Different object
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((4, 2), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=1.5)
|
||||
)
|
||||
ax.text(4.75, 3.7, "conf=0.80", ha="center", fontsize=FS_SMALL)
|
||||
ax.text(
|
||||
2,
|
||||
0.2,
|
||||
"⚠ 4 detekcje (3 duplikaty!)",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_xlim(-0.3, 6.3)
|
||||
ax.set_ylim(-0.3, 5.3)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"① Przed NMS\n(wiele nakładających się)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
# NMS process
|
||||
ax = axes[1]
|
||||
ax.axis("off")
|
||||
ax.set_xlim(0, 6)
|
||||
ax.set_ylim(0, 5)
|
||||
|
||||
steps = [
|
||||
("1. Sortuj: [0.95, 0.90, 0.85, 0.80]", 4.5),
|
||||
("2. Weź najlepszą (0.95) → ZACHOWAJ", 3.7),
|
||||
("3. IoU(0.95, 0.90)=0.82 > 0.5 → USUŃ", 2.9),
|
||||
("4. IoU(0.95, 0.85)=0.75 > 0.5 → USUŃ", 2.1),
|
||||
("5. IoU(0.95, 0.80)=0.10 < 0.5 → ZACHOWAJ", 1.3),
|
||||
]
|
||||
colors = [GRAY4, GRAY2, GRAY4, GRAY4, GRAY2]
|
||||
for (text, yp), c in zip(steps, colors, strict=False):
|
||||
ax.text(
|
||||
3.0,
|
||||
yp,
|
||||
text,
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": c, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.set_title("② Algorytm NMS\n(próg IoU = 0.5)", fontsize=FS, fontweight="bold")
|
||||
|
||||
# After NMS
|
||||
ax = axes[2]
|
||||
ax.add_patch(mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1))
|
||||
# Only best box for each object
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((1, 1), 2.5, 3, facecolor="none", edgecolor=LN, lw=2.5)
|
||||
)
|
||||
ax.text(2.25, 4.2, "conf=0.95 ✓", ha="center", fontsize=FS_SMALL, fontweight="bold")
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((4, 2), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=2.5)
|
||||
)
|
||||
ax.text(4.75, 3.7, "conf=0.80 ✓", ha="center", fontsize=FS_SMALL, fontweight="bold")
|
||||
ax.text(
|
||||
3,
|
||||
0.2,
|
||||
"✓ 2 unikalne obiekty",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_xlim(-0.3, 6.3)
|
||||
ax.set_ylim(-0.3, 5.3)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("③ Po NMS\n(1 bbox na obiekt)", fontsize=FS, fontweight="bold")
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q24_nms_steps.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 10. Detector from Classifier — 3 approaches
|
||||
# ============================================================
|
||||
def draw_detector_from_classifier() -> None:
|
||||
"""Draw detector from classifier."""
|
||||
fig, ax = plt.subplots(figsize=(11, 9))
|
||||
ax.set_xlim(-0.5, 11)
|
||||
ax.set_ylim(-1, 9.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Jak zbudować detektor z klasyfikatora? — 3 podejścia",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
# ---- Approach 1: Sliding Window ----
|
||||
y = 7.0
|
||||
ax.text(
|
||||
0,
|
||||
y + 1.5,
|
||||
"① Sliding Window (NAJWOLNIEJSZE)",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
# Image with sliding window
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5
|
||||
)
|
||||
)
|
||||
ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL)
|
||||
# Sliding windows
|
||||
for dx, dy in [(0.1, 0.1), (0.4, 0.1), (0.7, 0.1), (0.1, 0.5), (0.4, 0.5)]:
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(dx, y - 0.5 + dy),
|
||||
0.5,
|
||||
0.5,
|
||||
facecolor="none",
|
||||
edgecolor=LN,
|
||||
lw=0.8,
|
||||
linestyle="--",
|
||||
)
|
||||
)
|
||||
|
||||
draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2)
|
||||
ax.text(2.35, y + 0.6, "xmiliony", fontsize=FS_SMALL, style="italic")
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
2.8,
|
||||
y - 0.3,
|
||||
1.8,
|
||||
1.2,
|
||||
'Klasyfikator\n(ResNet)\n"kot? pies? tło?"',
|
||||
fill=GRAY4,
|
||||
fontsize=FS,
|
||||
)
|
||||
draw_arrow(ax, 4.7, y + 0.3, 5.3, y + 0.3, lw=1.2)
|
||||
draw_box(ax, 5.4, y - 0.3, 1.2, 1.2, "NMS", fill=GRAY1, fontsize=FS)
|
||||
draw_arrow(ax, 6.7, y + 0.3, 7.3, y + 0.3, lw=1.2)
|
||||
ax.text(
|
||||
8.5,
|
||||
y + 0.3,
|
||||
"~3.3h / obraz!\n⚠ NIEPRAKTYCZNE",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
# ---- Approach 2: Region Proposals ----
|
||||
y = 3.8
|
||||
ax.text(
|
||||
0,
|
||||
y + 1.5,
|
||||
"② Region Proposals + Klasyfikator (= R-CNN)",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5
|
||||
)
|
||||
)
|
||||
ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL)
|
||||
# A few smart regions
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0.1, y - 0.4), 0.7, 0.9, facecolor="none", edgecolor=LN, lw=1.5
|
||||
)
|
||||
)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0.9, y + 0.0), 0.7, 0.6, facecolor="none", edgecolor=LN, lw=1.5
|
||||
)
|
||||
)
|
||||
|
||||
draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2)
|
||||
draw_box(
|
||||
ax,
|
||||
2.8,
|
||||
y - 0.3,
|
||||
1.6,
|
||||
1.2,
|
||||
"Selective\nSearch\n~2000 regionów",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
)
|
||||
draw_arrow(ax, 4.5, y + 0.3, 5.1, y + 0.3, lw=1.2)
|
||||
ax.text(4.8, y + 0.6, "x2000", fontsize=FS_SMALL, style="italic")
|
||||
draw_box(ax, 5.2, y - 0.3, 1.5, 1.2, "Klasyfikator\n(CNN)", fill=GRAY4, fontsize=FS)
|
||||
draw_arrow(ax, 6.8, y + 0.3, 7.4, y + 0.3, lw=1.2)
|
||||
draw_box(ax, 7.5, y - 0.3, 1.0, 1.2, "NMS", fill=GRAY1, fontsize=FS)
|
||||
draw_arrow(ax, 8.6, y + 0.3, 9.0, y + 0.3, lw=1.2)
|
||||
ax.text(
|
||||
10.0,
|
||||
y + 0.3,
|
||||
"~20-50 s/obraz\n(250x szybciej)",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
# ---- Approach 3: Fine-tune backbone ----
|
||||
y = 0.5
|
||||
ax.text(
|
||||
0,
|
||||
y + 1.5,
|
||||
"③ Fine-tune backbone + detection head (NAJLEPSZE)",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY2, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0, y - 0.6), 1.8, 1.8, facecolor=GRAY1, edgecolor=LN, lw=1.5
|
||||
)
|
||||
)
|
||||
ax.text(0.9, y + 0.3, "obraz", ha="center", fontsize=FS_SMALL)
|
||||
|
||||
draw_arrow(ax, 2.0, y + 0.3, 2.7, y + 0.3, lw=1.2)
|
||||
draw_box(
|
||||
ax,
|
||||
2.8,
|
||||
y - 0.3,
|
||||
1.8,
|
||||
1.2,
|
||||
"Pretrained\nbackbone\n(ResNet)",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 4.7, y + 0.3, 5.3, y + 0.3, lw=1.2)
|
||||
|
||||
# Two heads from feature map
|
||||
draw_box(ax, 5.4, y + 0.3, 1.6, 0.6, "cls head\nP(klasa)", fill=GRAY4, fontsize=FS)
|
||||
draw_box(
|
||||
ax, 5.4, y - 0.5, 1.6, 0.6, "bbox head\nΔx,Δy,Δw,Δh", fill=GRAY4, fontsize=FS
|
||||
)
|
||||
|
||||
draw_arrow(ax, 7.1, y + 0.6, 7.7, y + 0.6, lw=1.0)
|
||||
draw_arrow(ax, 7.1, y - 0.2, 7.7, y - 0.2, lw=1.0)
|
||||
draw_box(ax, 7.8, y - 0.3, 1.0, 1.2, "NMS", fill=GRAY1, fontsize=FS)
|
||||
|
||||
draw_arrow(ax, 8.9, y + 0.3, 9.3, y + 0.3, lw=1.2)
|
||||
ax.text(
|
||||
10.2,
|
||||
y + 0.3,
|
||||
"5-155 fps!\n✓ NAJLEPSZE",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY2, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
save_fig(fig, "q24_detector_from_classifier.png")
|
||||
@ -0,0 +1,365 @@
|
||||
"""Two-stage vs one-stage table, ROI pooling, DETR, and sliding window."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q24_common import (
|
||||
_DATA_BRIGHT_THRESH,
|
||||
FS,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
draw_table,
|
||||
np,
|
||||
plt,
|
||||
rng,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 12. Two-stage vs One-stage comparison table
|
||||
# ============================================================
|
||||
def draw_two_vs_one_stage() -> None:
|
||||
"""Draw two vs one stage."""
|
||||
fig, ax = plt.subplots(figsize=(10, 3.5))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(-0.5, 4.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Two-stage vs One-stage — porównanie",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=8,
|
||||
)
|
||||
|
||||
headers = ["Cecha", "Two-stage\n(Faster R-CNN)", "One-stage\n(YOLO)"]
|
||||
rows = [
|
||||
["Szybkość", "~5 fps", "45-155 fps"],
|
||||
["Dokładność (mAP)", "wyższa (historycznie)", "dorównuje (YOLOv8)"],
|
||||
["Małe obiekty", "lepszy", "gorszy (SSD/FPN pomaga)"],
|
||||
["Architektura", "2 etapy + NMS", "1 etap + NMS"],
|
||||
["Real-time?", "NIE", "TAK"],
|
||||
]
|
||||
col_widths = [2.5, 3.5, 3.5]
|
||||
draw_table(
|
||||
ax,
|
||||
headers,
|
||||
rows,
|
||||
0.2,
|
||||
3.8,
|
||||
col_widths,
|
||||
row_h=0.65,
|
||||
fontsize=FS,
|
||||
header_fontsize=FS,
|
||||
)
|
||||
|
||||
save_fig(fig, "q24_two_vs_one_stage.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 13. ROI Pooling illustration
|
||||
# ============================================================
|
||||
def draw_roi_pooling() -> None:
|
||||
"""Draw roi pooling."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
|
||||
fig.suptitle(
|
||||
"ROI Pooling — dowolny rozmiar → stały rozmiar",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
|
||||
# Feature map with ROI
|
||||
ax = axes[0]
|
||||
# Draw feature map grid
|
||||
fm = rng.integers(0, 10, (8, 8))
|
||||
ax.imshow(fm, cmap="gray", vmin=0, vmax=10, alpha=0.3)
|
||||
for i in range(9):
|
||||
ax.axhline(y=i - 0.5, color=LN, lw=0.3)
|
||||
ax.axvline(x=i - 0.5, color=LN, lw=0.3)
|
||||
# ROI rectangle
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(1.5, 1.5), 4, 4, facecolor="none", edgecolor=LN, lw=3, linestyle="-"
|
||||
)
|
||||
)
|
||||
ax.text(3.5, 0.8, "ROI", ha="center", fontsize=FS, fontweight="bold")
|
||||
ax.set_xlim(-0.5, 7.5)
|
||||
ax.set_ylim(7.5, -0.5)
|
||||
ax.set_title("① Feature map\nz zaznaczonym ROI", fontsize=FS, fontweight="bold")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# ROI divided into grid
|
||||
ax = axes[1]
|
||||
roi_data = np.array(
|
||||
[
|
||||
[1, 3, 2, 1],
|
||||
[0, 5, 1, 6],
|
||||
[0, 4, 1, 0],
|
||||
[7, 2, 9, 1],
|
||||
]
|
||||
)
|
||||
ax.imshow(roi_data, cmap="gray", vmin=0, vmax=10)
|
||||
for i in range(5):
|
||||
ax.axhline(y=i - 0.5, color=LN, lw=1)
|
||||
ax.axvline(x=i - 0.5, color=LN, lw=1)
|
||||
# Grid lines for 2x2 pooling
|
||||
ax.axhline(y=0.5, color=LN, lw=3, linestyle="--")
|
||||
ax.axvline(x=0.5, color=LN, lw=3, linestyle="--")
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
ax.text(
|
||||
j,
|
||||
i,
|
||||
str(roi_data[i, j]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
color="white" if roi_data[i, j] > _DATA_BRIGHT_THRESH else "black",
|
||||
)
|
||||
ax.set_title("② ROI podzielony\nna siatkę 2x2", fontsize=FS, fontweight="bold")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Output after pooling
|
||||
ax = axes[2]
|
||||
out = np.array([[5, 6], [7, 9]])
|
||||
ax.imshow(out, cmap="gray", vmin=0, vmax=10)
|
||||
for i in range(3):
|
||||
ax.axhline(y=i - 0.5, color=LN, lw=1.5)
|
||||
ax.axvline(x=i - 0.5, color=LN, lw=1.5)
|
||||
for i in range(2):
|
||||
for j in range(2):
|
||||
ax.text(
|
||||
j,
|
||||
i,
|
||||
str(out[i, j]),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=14,
|
||||
fontweight="bold",
|
||||
color="white" if out[i, j] > _DATA_BRIGHT_THRESH else "black",
|
||||
)
|
||||
ax.set_title(
|
||||
"③ Po ROI Pool 2x2\n(max z każdej komórki)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q24_roi_pooling.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 14. DETR Pipeline
|
||||
# ============================================================
|
||||
def draw_detr_pipeline() -> None:
|
||||
"""Draw detr pipeline."""
|
||||
fig, ax = plt.subplots(figsize=(11, 4.5))
|
||||
ax.set_xlim(-0.5, 11.5)
|
||||
ax.set_ylim(-1, 4.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"DETR — Transformer do detekcji (bez NMS, bez anchorów)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
# Pipeline
|
||||
draw_box(ax, 0, 1.5, 1.5, 1.5, "Obraz\nwejściowy", fill=GRAY1, fontsize=FS)
|
||||
draw_arrow(ax, 1.6, 2.25, 2.1, 2.25, lw=1.5)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
2.2,
|
||||
1.5,
|
||||
1.5,
|
||||
1.5,
|
||||
"CNN\nBackbone\n(ResNet)",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 3.8, 2.25, 4.3, 2.25, lw=1.5)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
4.4,
|
||||
1.5,
|
||||
1.8,
|
||||
1.5,
|
||||
"Transformer\nEncoder\n(self-attention)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
)
|
||||
draw_arrow(ax, 6.3, 2.25, 6.8, 2.25, lw=1.5)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
6.9,
|
||||
1.5,
|
||||
1.8,
|
||||
1.5,
|
||||
"Transformer\nDecoder\n(N=100 queries)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Output branches
|
||||
draw_arrow(ax, 8.8, 2.5, 9.5, 3.0, lw=1.2)
|
||||
draw_box(ax, 9.6, 2.7, 1.5, 0.7, "klasa₁...klasa₁₀₀", fill=GRAY4, fontsize=FS_SMALL)
|
||||
|
||||
draw_arrow(ax, 8.8, 2.0, 9.5, 1.5, lw=1.2)
|
||||
draw_box(ax, 9.6, 1.2, 1.5, 0.7, "bbox₁...bbox₁₀₀", fill=GRAY4, fontsize=FS_SMALL)
|
||||
|
||||
# Annotations
|
||||
ax.text(
|
||||
7.8,
|
||||
0.5,
|
||||
'100 object queries → 5 obiektów + 95x "brak"',
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
5.5,
|
||||
0.0,
|
||||
"Hungarian matching (trening): optymalne dopasowanie predykcji do GT",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY5},
|
||||
)
|
||||
|
||||
# Big benefit box
|
||||
ax.text(
|
||||
5.5,
|
||||
4.0,
|
||||
"BEZ anchorów • BEZ NMS • end-to-end • prosty pipeline",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY2, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
save_fig(fig, "q24_detr_pipeline.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 15. Sliding Window illustration
|
||||
# ============================================================
|
||||
def draw_sliding_window() -> None:
|
||||
"""Draw sliding window."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
|
||||
fig.suptitle(
|
||||
"Sliding Window — najprostsze podejście do detekcji",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
|
||||
# Multi-position
|
||||
ax = axes[0]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 8, 6, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
# Grid of sliding windows
|
||||
for i in range(4):
|
||||
for j in range(3):
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(i * 1.8 + 0.2, j * 1.8 + 0.2),
|
||||
1.5,
|
||||
1.5,
|
||||
facecolor="none",
|
||||
edgecolor=LN,
|
||||
lw=0.6,
|
||||
linestyle="--",
|
||||
)
|
||||
)
|
||||
# Highlight current window
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((2.0, 2.0), 1.5, 1.5, facecolor="none", edgecolor=LN, lw=2.5)
|
||||
)
|
||||
ax.set_xlim(-0.5, 8.5)
|
||||
ax.set_ylim(-0.5, 6.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("① Wiele pozycji\n(krok co 8 px)", fontsize=FS, fontweight="bold")
|
||||
|
||||
# Multi-scale
|
||||
ax = axes[1]
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle((0, 0), 6, 5, facecolor=GRAY4, edgecolor=LN, lw=1.5)
|
||||
)
|
||||
sizes = [(0.8, 0.8), (1.5, 1.5), (2.5, 2.5), (3.5, 3.5)]
|
||||
for i, (w, h) in enumerate(sizes):
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0.3 + i * 0.3, 0.3 + i * 0.3),
|
||||
w,
|
||||
h,
|
||||
facecolor="none",
|
||||
edgecolor=LN,
|
||||
lw=1 + i * 0.3,
|
||||
linestyle=[":", "--", "-.", "-"][i],
|
||||
)
|
||||
)
|
||||
ax.text(3, 0, "4+ skal", ha="center", fontsize=FS_SMALL, fontweight="bold")
|
||||
ax.set_xlim(-0.5, 6.5)
|
||||
ax.set_ylim(-0.5, 5.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"② Wiele skal\n(obiekty mają różne rozmiary)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
# Count
|
||||
ax = axes[2]
|
||||
ax.axis("off")
|
||||
ax.set_xlim(0, 6)
|
||||
ax.set_ylim(0, 5)
|
||||
|
||||
lines = [
|
||||
("Obraz: 640 x 480 px", 4.5),
|
||||
("Okno: 64 x 64 px, krok 8 px", 3.8),
|
||||
("Pozycje: ~72 x 52 = 3 744", 3.1),
|
||||
("x 5 skal = 18 720 okien", 2.4),
|
||||
("x klasyfikacja = WOLNE!", 1.7),
|
||||
("→ ~3h na jeden obraz", 0.8),
|
||||
]
|
||||
for text, yp in lines:
|
||||
fw = "bold" if "~3h" in text or "WOLNE" in text else "normal"
|
||||
col = GRAY2 if "WOLNE" in text or "~3h" in text else GRAY4
|
||||
ax.text(
|
||||
3.0,
|
||||
yp,
|
||||
text,
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight=fw,
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": col, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.set_title(
|
||||
"③ Dlaczego wolne?\n(miliony klasyfikacji)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q24_sliding_window.png")
|
||||
@ -0,0 +1,344 @@
|
||||
"""R-CNN evolution and YOLO grid diagrams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _q24_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
plt,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def _draw_yolo_cell_prediction(ax: Axes) -> None:
|
||||
"""Draw YOLO cell prediction vector panel."""
|
||||
ax.axis("off")
|
||||
ax.set_xlim(0, 6)
|
||||
ax.set_ylim(-1, 5)
|
||||
|
||||
labels = [
|
||||
"x",
|
||||
"y",
|
||||
"w",
|
||||
"h",
|
||||
"conf",
|
||||
"x",
|
||||
"y",
|
||||
"w",
|
||||
"h",
|
||||
"conf",
|
||||
"P(c₁)",
|
||||
"...",
|
||||
"P(c₂₀)",
|
||||
]
|
||||
colors_vec = [GRAY4] * 5 + [GRAY2] * 5 + [GRAY1] * 3
|
||||
|
||||
bw = 0.42
|
||||
for i, (label, col) in enumerate(
|
||||
zip(labels, colors_vec, strict=False),
|
||||
):
|
||||
x_pos = 0.3 + i * bw
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(x_pos, 2.5),
|
||||
bw - 0.02,
|
||||
0.6,
|
||||
facecolor=col,
|
||||
edgecolor=LN,
|
||||
lw=0.8,
|
||||
)
|
||||
)
|
||||
ax.text(
|
||||
x_pos + bw / 2,
|
||||
2.8,
|
||||
label,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=5,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(0.3, 2.4),
|
||||
xytext=(2.4, 2.4),
|
||||
arrowprops={"arrowstyle": "-", "lw": 1},
|
||||
)
|
||||
ax.text(1.35, 2.15, "bbox 1 (5 wartości)", ha="center", fontsize=FS_SMALL)
|
||||
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(2.4, 2.4),
|
||||
xytext=(4.5, 2.4),
|
||||
arrowprops={"arrowstyle": "-", "lw": 1},
|
||||
)
|
||||
ax.text(3.45, 2.15, "bbox 2 (5 wartości)", ha="center", fontsize=FS_SMALL)
|
||||
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(4.5, 2.4),
|
||||
xytext=(5.8, 2.4),
|
||||
arrowprops={"arrowstyle": "-", "lw": 1},
|
||||
)
|
||||
ax.text(5.15, 2.15, "20 klas", ha="center", fontsize=FS_SMALL)
|
||||
|
||||
ax.text(
|
||||
3.0,
|
||||
3.5,
|
||||
"Każda komórka → 30 wartości\n= 2x(x,y,w,h,conf) + 20 klas",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.set_title(
|
||||
"② Predykcja jednej komórki\n(S=7, B=2, C=20)",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 6. R-CNN Evolution
|
||||
# ============================================================
|
||||
def draw_rcnn_evolution() -> None:
|
||||
"""Draw rcnn evolution."""
|
||||
fig, ax = plt.subplots(figsize=(11, 7))
|
||||
ax.set_xlim(-0.5, 11)
|
||||
ax.set_ylim(-0.5, 7.5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Ewolucja R-CNN: od 50s do 0.2s na obraz",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
y_positions = [5.5, 3.0, 0.5]
|
||||
labels = [
|
||||
"R-CNN (2014) — 50 s/obraz",
|
||||
"Fast R-CNN (2015) — 2 s/obraz",
|
||||
"Faster R-CNN (2015) — 0.2 s/obraz",
|
||||
]
|
||||
|
||||
# R-CNN
|
||||
y = y_positions[0]
|
||||
ax.text(0, y + 1.3, labels[0], fontsize=FS_LABEL, fontweight="bold")
|
||||
draw_box(ax, 0, y, 2, 0.9, "Selective\nSearch", fill=GRAY2, fontsize=FS)
|
||||
draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45)
|
||||
ax.text(2.3, y + 0.8, "~2000", ha="center", fontsize=FS_SMALL, style="italic")
|
||||
draw_box(ax, 2.6, y, 1.5, 0.9, "Resize\n224x224", fill=GRAY4, fontsize=FS)
|
||||
draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45)
|
||||
draw_box(
|
||||
ax, 4.7, y, 1.5, 0.9, "CNN\nx2000!", fill=GRAY3, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
draw_arrow(ax, 6.3, y + 0.45, 6.7, y + 0.45)
|
||||
draw_box(ax, 6.8, y, 1.3, 0.9, "SVM\nklasyf.", fill=GRAY4, fontsize=FS)
|
||||
draw_arrow(ax, 8.2, y + 0.45, 8.6, y + 0.45)
|
||||
draw_box(ax, 8.7, y, 1.0, 0.9, "NMS", fill=GRAY1, fontsize=FS)
|
||||
# Problem annotation
|
||||
ax.text(
|
||||
5.5,
|
||||
y - 0.4,
|
||||
"⚠ CNN uruchamiane 2000x → 50 sek!",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
# Fast R-CNN
|
||||
y = y_positions[1]
|
||||
ax.text(0, y + 1.3, labels[1], fontsize=FS_LABEL, fontweight="bold")
|
||||
draw_box(ax, 0, y, 2, 0.9, "Selective\nSearch", fill=GRAY2, fontsize=FS)
|
||||
draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45)
|
||||
draw_box(
|
||||
ax,
|
||||
2.6,
|
||||
y,
|
||||
1.5,
|
||||
0.9,
|
||||
"CNN\nx1 (RAZ!)",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45)
|
||||
draw_box(
|
||||
ax, 4.7, y, 1.5, 0.9, "ROI\nPooling", fill=GRAY1, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
draw_arrow(ax, 6.3, y + 0.45, 6.7, y + 0.45)
|
||||
draw_box(ax, 6.8, y, 1.3, 0.9, "FC\nklasa+bbox", fill=GRAY4, fontsize=FS)
|
||||
draw_arrow(ax, 8.2, y + 0.45, 8.6, y + 0.45)
|
||||
draw_box(ax, 8.7, y, 1.0, 0.9, "NMS", fill=GRAY1, fontsize=FS)
|
||||
ax.text(
|
||||
3.8,
|
||||
y - 0.4,
|
||||
"✓ CNN RAZ na cały obraz → 25x szybciej",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
# Faster R-CNN
|
||||
y = y_positions[2]
|
||||
ax.text(0, y + 1.3, labels[2], fontsize=FS_LABEL, fontweight="bold")
|
||||
draw_box(
|
||||
ax,
|
||||
0.5,
|
||||
y,
|
||||
1.5,
|
||||
0.9,
|
||||
"CNN\nBackbone",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 2.1, y + 0.45, 2.5, y + 0.45)
|
||||
draw_box(ax, 2.6, y, 1.5, 0.9, "Feature\nMap", fill=GRAY1, fontsize=FS)
|
||||
draw_arrow(ax, 4.2, y + 0.45, 4.6, y + 0.45)
|
||||
draw_box(
|
||||
ax,
|
||||
4.7,
|
||||
y,
|
||||
1.3,
|
||||
0.9,
|
||||
"RPN\n(w sieci!)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 6.1, y + 0.45, 6.5, y + 0.45)
|
||||
draw_box(ax, 6.6, y, 1.3, 0.9, "ROI\nPooling", fill=GRAY1, fontsize=FS)
|
||||
draw_arrow(ax, 8.0, y + 0.45, 8.4, y + 0.45)
|
||||
draw_box(ax, 8.5, y, 1.3, 0.9, "FC\nklasa+bbox", fill=GRAY4, fontsize=FS)
|
||||
ax.text(
|
||||
5.0,
|
||||
y - 0.4,
|
||||
"✓ RPN zastępuje Selective Search → end-to-end",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
save_fig(fig, "q24_rcnn_evolution.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 7. YOLO Grid
|
||||
# ============================================================
|
||||
def draw_yolo_grid() -> None:
|
||||
"""Draw yolo grid."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
|
||||
fig.suptitle(
|
||||
"YOLO — detekcja jednoetapowa (siatka SxS)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
|
||||
# Grid on image
|
||||
ax = axes[0]
|
||||
grid_size = 7
|
||||
ax.set_xlim(0, grid_size)
|
||||
ax.set_ylim(0, grid_size)
|
||||
for i in range(grid_size + 1):
|
||||
ax.axhline(y=i, color=LN, lw=0.5, alpha=0.5)
|
||||
ax.axvline(x=i, color=LN, lw=0.5, alpha=0.5)
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0, 0),
|
||||
grid_size,
|
||||
grid_size,
|
||||
facecolor=GRAY4,
|
||||
edgecolor=LN,
|
||||
lw=1.5,
|
||||
)
|
||||
)
|
||||
# Highlight one cell
|
||||
ax.add_patch(mpatches.Rectangle((3, 3), 1, 1, facecolor=GRAY2, edgecolor=LN, lw=2))
|
||||
# Object center dot
|
||||
ax.plot(3.5, 3.5, "ko", markersize=8)
|
||||
# Bounding box from that cell
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(2.0, 2.2), 3.0, 2.6, facecolor="none", edgecolor=LN, lw=2, linestyle="--"
|
||||
)
|
||||
)
|
||||
ax.text(
|
||||
3.5,
|
||||
1.8,
|
||||
"bbox z komórki (3,3)",
|
||||
ha="center",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_aspect("equal")
|
||||
ax.invert_yaxis()
|
||||
ax.set_title("① Siatka 7x7\nna obrazie", fontsize=FS, fontweight="bold")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
_draw_yolo_cell_prediction(axes[1])
|
||||
|
||||
# Speed comparison
|
||||
ax = axes[2]
|
||||
ax.axis("off")
|
||||
ax.set_xlim(0, 5)
|
||||
ax.set_ylim(0, 5)
|
||||
|
||||
methods = ["R-CNN", "Fast R-CNN", "Faster R-CNN", "YOLO", "YOLOv8"]
|
||||
fps_vals = [0.02, 0.5, 5, 45, 100]
|
||||
bar_colors = [GRAY3, GRAY3, GRAY3, GRAY2, GRAY1]
|
||||
|
||||
for i, (m, f, c) in enumerate(zip(methods, fps_vals, bar_colors, strict=False)):
|
||||
bar_w = f / 100 * 4.0
|
||||
y_pos = 4.0 - i * 0.8
|
||||
ax.add_patch(
|
||||
mpatches.Rectangle(
|
||||
(0.5, y_pos), max(bar_w, 0.1), 0.5, facecolor=c, edgecolor=LN, lw=0.8
|
||||
)
|
||||
)
|
||||
ax.text(
|
||||
0.4,
|
||||
y_pos + 0.25,
|
||||
m,
|
||||
ha="right",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
max(0.7, 0.5 + bar_w + 0.1),
|
||||
y_pos + 0.25,
|
||||
f"{f} fps",
|
||||
ha="left",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
)
|
||||
|
||||
ax.set_title(
|
||||
"③ Porównanie szybkości\n(fps = klatki/sek)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q24_yolo_grid.png")
|
||||
@ -0,0 +1,102 @@
|
||||
"""Common constants and utilities for Q31 diagrams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
DPI = 300
|
||||
BG = "white"
|
||||
LN = "black"
|
||||
FS = 8
|
||||
FS_TITLE = 11
|
||||
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
|
||||
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
GRAY1 = "#E8E8E8"
|
||||
GRAY2 = "#D0D0D0"
|
||||
GRAY3 = "#B8B8B8"
|
||||
GRAY4 = "#F5F5F5"
|
||||
GRAY5 = "#C0C0C0"
|
||||
|
||||
# Number of regret table header columns before the max-regret column
|
||||
_REGRET_HEADER_COLS = 4
|
||||
# Number of data state columns
|
||||
_DATA_STATE_COLS = 3
|
||||
# Expected-value for the winning alternative
|
||||
_WINNING_EV = 95
|
||||
|
||||
|
||||
def draw_box(
|
||||
ax: Axes,
|
||||
x: float,
|
||||
y: float,
|
||||
w: float,
|
||||
h: float,
|
||||
text: str,
|
||||
*,
|
||||
fill: str = "white",
|
||||
lw: float = 1.2,
|
||||
fontsize: float = FS,
|
||||
fontweight: str = "normal",
|
||||
ha: str = "center",
|
||||
va: str = "center",
|
||||
rounded: bool = True,
|
||||
) -> None:
|
||||
"""Draw a labeled box on the axes."""
|
||||
if rounded:
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.05",
|
||||
lw=lw,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
else:
|
||||
rect = mpatches.Rectangle((x, y), w, h, lw=lw, edgecolor=LN, facecolor=fill)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h / 2,
|
||||
text,
|
||||
ha=ha,
|
||||
va=va,
|
||||
fontsize=fontsize,
|
||||
fontweight=fontweight,
|
||||
wrap=True,
|
||||
)
|
||||
|
||||
|
||||
def draw_arrow(
|
||||
ax: Axes,
|
||||
x1: float,
|
||||
y1: float,
|
||||
x2: float,
|
||||
y2: float,
|
||||
*,
|
||||
lw: float = 1.2,
|
||||
style: str = "->",
|
||||
color: str = LN,
|
||||
) -> None:
|
||||
"""Draw an arrow between two points."""
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(x2, y2),
|
||||
xytext=(x1, y1),
|
||||
arrowprops={
|
||||
"arrowstyle": style,
|
||||
"color": color,
|
||||
"lw": lw,
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,256 @@
|
||||
"""Q31 Diagram 1: Payoff matrix + all criteria bar chart."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from python_pkg.praca_magisterska_video.generate_images._q31_common import (
|
||||
_DATA_STATE_COLS,
|
||||
BG,
|
||||
DPI,
|
||||
FS,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
OUTPUT_DIR,
|
||||
_logger,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def _draw_payoff_table(ax: Axes) -> None:
|
||||
"""Draw the payoff matrix table on the left panel."""
|
||||
ax.axis("off")
|
||||
ax.set_xlim(0, 6)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_title(
|
||||
"Macierz wypłat (tys. zł)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=8,
|
||||
)
|
||||
|
||||
headers_col = ["", "S₁\n(dobra)", "S₂\n(średnia)", "S₃\n(zła)"]
|
||||
rows = [
|
||||
["A₁ (fabryka)", "200", "50", "-100"],
|
||||
["A₂ (sklep)", "80", "70", "40"],
|
||||
["A₃ (obligacje)", "30", "30", "30"],
|
||||
]
|
||||
|
||||
col_w = [1.8, 1.2, 1.2, 1.2]
|
||||
row_h = 0.7
|
||||
start_y = 4.5
|
||||
start_x = 0.2
|
||||
|
||||
# Draw header row
|
||||
x = start_x
|
||||
for j, h in enumerate(headers_col):
|
||||
fill = GRAY2 if j > 0 else GRAY3
|
||||
rect = mpatches.Rectangle(
|
||||
(x, start_y),
|
||||
col_w[j],
|
||||
row_h,
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + col_w[j] / 2,
|
||||
start_y + row_h / 2,
|
||||
h,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
x += col_w[j]
|
||||
|
||||
# Draw data rows
|
||||
for i, row in enumerate(rows):
|
||||
x = start_x
|
||||
y = start_y - (i + 1) * row_h
|
||||
for j, val in enumerate(row):
|
||||
fill = GRAY4 if j == 0 else ("white" if i % 2 == 0 else GRAY1)
|
||||
if val.startswith("-"):
|
||||
fill = "#D8D8D8"
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y),
|
||||
col_w[j],
|
||||
row_h,
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
fw = "bold" if j == 0 else "normal"
|
||||
ax.text(
|
||||
x + col_w[j] / 2,
|
||||
y + row_h / 2,
|
||||
val,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight=fw,
|
||||
)
|
||||
x += col_w[j]
|
||||
|
||||
# Probability row for EV
|
||||
x = start_x
|
||||
y = start_y - 4 * row_h
|
||||
probs = ["p (dla E[X]):", "0.5", "0.3", "0.2"]
|
||||
for j, val in enumerate(probs):
|
||||
fill = GRAY5 if j > 0 else GRAY3
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y),
|
||||
col_w[j],
|
||||
row_h * 0.7,
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + col_w[j] / 2,
|
||||
y + row_h * 0.35,
|
||||
val,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
style="italic",
|
||||
)
|
||||
x += col_w[j]
|
||||
|
||||
|
||||
def _draw_criteria_bars(ax2: Axes) -> None:
|
||||
"""Draw the criteria comparison bar chart on the right panel."""
|
||||
criteria = [
|
||||
"E[X]",
|
||||
"Laplace",
|
||||
"Maximax",
|
||||
"Maximin",
|
||||
"Hurwicz\n\u03b1=0.6",
|
||||
"Savage",
|
||||
]
|
||||
|
||||
ev = [95, 69, 30]
|
||||
laplace = [50, 63.3, 30]
|
||||
maximax = [200, 80, 30]
|
||||
maximin = [-100, 40, 30]
|
||||
hurwicz = [80, 64, 30]
|
||||
savage_maxregret = [140, 120, 170]
|
||||
|
||||
winners = [0, 1, 0, 1, 0, 1]
|
||||
|
||||
x_pos = np.arange(len(criteria))
|
||||
width = 0.22
|
||||
hatches = ["///", "...", "xxx"]
|
||||
labels = ["A₁ (fabryka)", "A₂ (sklep)", "A₃ (obligacje)"]
|
||||
|
||||
all_vals = [
|
||||
[
|
||||
ev[0],
|
||||
laplace[0],
|
||||
maximax[0],
|
||||
maximin[0],
|
||||
hurwicz[0],
|
||||
savage_maxregret[0],
|
||||
],
|
||||
[
|
||||
ev[1],
|
||||
laplace[1],
|
||||
maximax[1],
|
||||
maximin[1],
|
||||
hurwicz[1],
|
||||
savage_maxregret[1],
|
||||
],
|
||||
[
|
||||
ev[2],
|
||||
laplace[2],
|
||||
maximax[2],
|
||||
maximin[2],
|
||||
hurwicz[2],
|
||||
savage_maxregret[2],
|
||||
],
|
||||
]
|
||||
|
||||
for i in range(_DATA_STATE_COLS):
|
||||
ax2.bar(
|
||||
x_pos + (i - 1) * width,
|
||||
all_vals[i],
|
||||
width,
|
||||
label=labels[i],
|
||||
color="white",
|
||||
edgecolor=LN,
|
||||
hatch=hatches[i],
|
||||
lw=0.8,
|
||||
)
|
||||
|
||||
for c_idx in range(len(criteria)):
|
||||
w = winners[c_idx]
|
||||
val = all_vals[w][c_idx]
|
||||
ax2.text(
|
||||
x_pos[c_idx] + (w - 1) * width,
|
||||
val + 5,
|
||||
"★",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax2.set_xticks(x_pos)
|
||||
ax2.set_xticklabels(criteria, fontsize=7)
|
||||
ax2.set_ylabel("Wartość kryterium", fontsize=8)
|
||||
ax2.set_title(
|
||||
"Porównanie kryteriów",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=8,
|
||||
)
|
||||
ax2.legend(fontsize=7, loc="upper right")
|
||||
ax2.axhline(y=0, color=LN, lw=0.5, ls="-")
|
||||
ax2.spines["top"].set_visible(False)
|
||||
ax2.spines["right"].set_visible(False)
|
||||
ax2.tick_params(labelsize=7)
|
||||
|
||||
ax2.text(
|
||||
5,
|
||||
-30,
|
||||
"(Savage: niżej\n= lepiej)",
|
||||
fontsize=6,
|
||||
ha="center",
|
||||
va="top",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
|
||||
def draw_criteria_comparison() -> None:
|
||||
"""Draw payoff matrix and criteria comparison chart."""
|
||||
fig, axes = plt.subplots(
|
||||
1,
|
||||
2,
|
||||
figsize=(8.27, 4.5),
|
||||
gridspec_kw={"width_ratios": [1.2, 1]},
|
||||
)
|
||||
|
||||
_draw_payoff_table(axes[0])
|
||||
_draw_criteria_bars(axes[1])
|
||||
|
||||
plt.tight_layout()
|
||||
outpath = str(Path(OUTPUT_DIR) / "q31_criteria_comparison.png")
|
||||
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
|
||||
plt.close(fig)
|
||||
_logger.info(" Saved: %s", outpath)
|
||||
@ -0,0 +1,289 @@
|
||||
"""Q31 Diagrams 5 & 6: Expected value + decision conditions spectrum."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from python_pkg.praca_magisterska_video.generate_images._q31_common import (
|
||||
_WINNING_EV,
|
||||
BG,
|
||||
DPI,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
LN,
|
||||
OUTPUT_DIR,
|
||||
_logger,
|
||||
)
|
||||
|
||||
|
||||
def draw_expected_value() -> None:
|
||||
"""Draw expected value criterion with probability-weighted bars."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(8.27, 3.5), sharey=True)
|
||||
fig.suptitle(
|
||||
"Kryterium wartości oczekiwanej E[X]" " \u2014 rozkład wyników per alternatywa",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
|
||||
probs = [0.5, 0.3, 0.2]
|
||||
alts = [
|
||||
("A₁ (fabryka)", [200, 50, -100], 95),
|
||||
("A₂ (sklep)", [80, 70, 40], 69),
|
||||
("A₃ (obligacje)", [30, 30, 30], 30),
|
||||
]
|
||||
|
||||
hatches = ["///", "...", "xxx"]
|
||||
|
||||
for _idx, (ax, (name, vals, ev)) in enumerate(zip(axes, alts, strict=False)):
|
||||
x_positions = [0, 0.6, 1.0]
|
||||
widths = [p * 0.9 for p in probs]
|
||||
|
||||
for i, (v, p, h) in enumerate(zip(vals, probs, hatches, strict=False)):
|
||||
color = "white" if v >= 0 else GRAY2
|
||||
ax.bar(
|
||||
x_positions[i],
|
||||
v,
|
||||
width=widths[i],
|
||||
color=color,
|
||||
edgecolor=LN,
|
||||
hatch=h,
|
||||
lw=0.8,
|
||||
align="edge",
|
||||
)
|
||||
offset = 8 if v >= 0 else -12
|
||||
ax.text(
|
||||
x_positions[i] + widths[i] / 2,
|
||||
v + offset,
|
||||
f"{v}",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
contrib = v * p
|
||||
ax.text(
|
||||
x_positions[i] + widths[i] / 2,
|
||||
v / 2,
|
||||
f"{v}x{p}\n={contrib:.0f}",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=6,
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Expected value line
|
||||
ax.axhline(y=ev, color=LN, lw=2, ls="--")
|
||||
ax.text(
|
||||
1.35,
|
||||
ev,
|
||||
f"E[X]={ev}",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
va="center",
|
||||
ha="left",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.15",
|
||||
"facecolor": GRAY1,
|
||||
"edgecolor": LN,
|
||||
},
|
||||
)
|
||||
|
||||
ax.set_title(name, fontsize=9, fontweight="bold")
|
||||
ax.set_xticks([0.225, 0.735, 1.09])
|
||||
ax.set_xticklabels(["S₁", "S₂", "S₃"], fontsize=7)
|
||||
ax.axhline(y=0, color=LN, lw=0.5)
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
ax.tick_params(labelsize=7)
|
||||
|
||||
# Star on winner
|
||||
if ev == _WINNING_EV:
|
||||
ax.text(
|
||||
0.7,
|
||||
ev + 20,
|
||||
"★ MAX",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
)
|
||||
|
||||
axes[0].set_ylabel("Wypłata (tys. zł)", fontsize=8)
|
||||
|
||||
plt.tight_layout()
|
||||
outpath = str(Path(OUTPUT_DIR) / "q31_expected_value.png")
|
||||
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
|
||||
plt.close(fig)
|
||||
_logger.info(" Saved: %s", outpath)
|
||||
|
||||
|
||||
def draw_conditions_spectrum() -> None:
|
||||
"""Draw decision conditions spectrum diagram."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(8.27, 3.5))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Warunki decyzyjne" " \u2014 spektrum wiedzy decydenta",
|
||||
fontsize=FS_TITLE + 1,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Three zones
|
||||
zones = [
|
||||
(
|
||||
0.3,
|
||||
1.5,
|
||||
2.8,
|
||||
2.5,
|
||||
"PEWNOŚĆ",
|
||||
"white",
|
||||
[
|
||||
"Znamy dokładny wynik",
|
||||
"Przykład: lokata 5%",
|
||||
"Metoda: po prostu wybierz",
|
||||
"najlepszy wynik",
|
||||
],
|
||||
),
|
||||
(
|
||||
3.5,
|
||||
1.5,
|
||||
2.8,
|
||||
2.5,
|
||||
"RYZYKO",
|
||||
GRAY1,
|
||||
[
|
||||
"Znamy wyniki I prawdop.",
|
||||
"Przykład: gra w kości",
|
||||
"Metoda: wartość",
|
||||
"oczekiwana E[X]",
|
||||
],
|
||||
),
|
||||
(
|
||||
6.7,
|
||||
1.5,
|
||||
2.8,
|
||||
2.5,
|
||||
"NIEPEWNOŚĆ",
|
||||
GRAY3,
|
||||
[
|
||||
"Znamy wyniki, ale",
|
||||
"NIE znamy prawdop.",
|
||||
"Metody: Laplace, maximax,",
|
||||
"maximin, Hurwicz, Savage",
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
for x, y, w, h, title, fill, lines in zones:
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.1",
|
||||
lw=2,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h - 0.3,
|
||||
title,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
for i, line in enumerate(lines):
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h - 0.7 - i * 0.4,
|
||||
line,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
)
|
||||
|
||||
# Arrows between zones
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(3.4, 2.75),
|
||||
xytext=(3.15, 2.75),
|
||||
arrowprops={"arrowstyle": "->", "color": LN, "lw": 2},
|
||||
)
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(6.6, 2.75),
|
||||
xytext=(6.35, 2.75),
|
||||
arrowprops={"arrowstyle": "->", "color": LN, "lw": 2},
|
||||
)
|
||||
|
||||
# Bottom: knowledge gradient bar
|
||||
gradient_y = 0.5
|
||||
gradient_h = 0.5
|
||||
n_steps = 50
|
||||
for i in range(n_steps):
|
||||
x = 0.3 + i * (9.2 / n_steps)
|
||||
w = 9.2 / n_steps + 0.01
|
||||
gray_val = 1 - (i / n_steps) * 0.7
|
||||
rect = mpatches.Rectangle(
|
||||
(x, gradient_y),
|
||||
w,
|
||||
gradient_h,
|
||||
lw=0,
|
||||
facecolor=str(gray_val),
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
|
||||
rect = mpatches.Rectangle(
|
||||
(0.3, gradient_y),
|
||||
9.2,
|
||||
gradient_h,
|
||||
lw=1.5,
|
||||
edgecolor=LN,
|
||||
facecolor="none",
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
|
||||
ax.text(
|
||||
0.3,
|
||||
gradient_y - 0.15,
|
||||
"Dużo wiedzy",
|
||||
fontsize=7,
|
||||
ha="left",
|
||||
va="top",
|
||||
)
|
||||
ax.text(
|
||||
9.5,
|
||||
gradient_y - 0.15,
|
||||
"Mało wiedzy",
|
||||
fontsize=7,
|
||||
ha="right",
|
||||
va="top",
|
||||
)
|
||||
ax.text(
|
||||
4.95,
|
||||
gradient_y + gradient_h / 2,
|
||||
"POZIOM WIEDZY DECYDENTA",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="center",
|
||||
color="white",
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
outpath = str(Path(OUTPUT_DIR) / "q31_conditions_spectrum.png")
|
||||
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
|
||||
plt.close(fig)
|
||||
_logger.info(" Saved: %s", outpath)
|
||||
@ -0,0 +1,344 @@
|
||||
"""Q31 Diagrams 3 & 4: Hurwicz interpolation + criteria mnemonic map."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from python_pkg.praca_magisterska_video.generate_images._q31_common import (
|
||||
BG,
|
||||
DPI,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
LN,
|
||||
OUTPUT_DIR,
|
||||
_logger,
|
||||
draw_box,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def draw_hurwicz_interpolation() -> None:
|
||||
"""Draw Hurwicz alpha interpolation diagram."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(8.27, 4))
|
||||
ax.set_title(
|
||||
"Kryterium Hurwicza" " \u2014 wpływ \u03b1 na wybór alternatywy",
|
||||
fontsize=FS_TITLE + 1,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
alphas = np.linspace(0, 1, 200)
|
||||
|
||||
v1 = alphas * 200 + (1 - alphas) * (-100)
|
||||
v2 = alphas * 80 + (1 - alphas) * 40
|
||||
v3 = alphas * 30 + (1 - alphas) * 30
|
||||
|
||||
ax.plot(
|
||||
alphas,
|
||||
v1,
|
||||
"k-",
|
||||
lw=2,
|
||||
label="A₁ (fabryka): V = 300\u03b1 - 100",
|
||||
)
|
||||
ax.plot(
|
||||
alphas,
|
||||
v2,
|
||||
"k--",
|
||||
lw=2,
|
||||
label="A₂ (sklep): V = 40\u03b1 + 40",
|
||||
)
|
||||
ax.plot(
|
||||
alphas,
|
||||
v3,
|
||||
"k:",
|
||||
lw=2,
|
||||
label="A₃ (obligacje): V = 30",
|
||||
)
|
||||
|
||||
# Crossover A2=A1
|
||||
alpha_cross_12 = 140 / 260
|
||||
v_cross_12 = 40 * alpha_cross_12 + 40
|
||||
|
||||
ax.plot(alpha_cross_12, v_cross_12, "ko", markersize=8, zorder=5)
|
||||
ax.annotate(
|
||||
f"\u03b1 ≈ {alpha_cross_12:.2f}\nA₁ = A₂",
|
||||
xy=(alpha_cross_12, v_cross_12),
|
||||
xytext=(alpha_cross_12 + 0.12, v_cross_12 - 30),
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"color": LN,
|
||||
"lw": 1,
|
||||
},
|
||||
)
|
||||
|
||||
# Shade winning regions
|
||||
ax.axvspan(0, alpha_cross_12, alpha=0.08, color="black")
|
||||
ax.axvspan(alpha_cross_12, 1, alpha=0.15, color="black")
|
||||
|
||||
ax.text(
|
||||
alpha_cross_12 / 2,
|
||||
-60,
|
||||
"A₂ wygrywa\n(pesymistycznie)",
|
||||
fontsize=8,
|
||||
ha="center",
|
||||
va="center",
|
||||
bbox={
|
||||
"boxstyle": "round",
|
||||
"facecolor": "white",
|
||||
"edgecolor": LN,
|
||||
},
|
||||
)
|
||||
ax.text(
|
||||
(alpha_cross_12 + 1) / 2,
|
||||
160,
|
||||
"A₁ wygrywa\n(optymistycznie)",
|
||||
fontsize=8,
|
||||
ha="center",
|
||||
va="center",
|
||||
bbox={
|
||||
"boxstyle": "round",
|
||||
"facecolor": "white",
|
||||
"edgecolor": LN,
|
||||
},
|
||||
)
|
||||
|
||||
# Special alpha values
|
||||
ax.axvline(x=0, color=LN, lw=0.5, ls=":")
|
||||
ax.axvline(x=1, color=LN, lw=0.5, ls=":")
|
||||
ax.text(
|
||||
0,
|
||||
-115,
|
||||
"\u03b1=0\nmaximin",
|
||||
fontsize=7,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
1,
|
||||
-115,
|
||||
"\u03b1=1\nmaximax",
|
||||
fontsize=7,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.set_xlabel("Współczynnik optymizmu \u03b1", fontsize=9)
|
||||
ax.set_ylabel("V(Aᵢ) = \u03b1·max + (1-\u03b1)·min", fontsize=9)
|
||||
ax.legend(fontsize=8, loc="upper left")
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
ax.set_xlim(-0.05, 1.05)
|
||||
ax.axhline(y=0, color=LN, lw=0.3, ls="-")
|
||||
ax.tick_params(labelsize=8)
|
||||
|
||||
plt.tight_layout()
|
||||
outpath = str(Path(OUTPUT_DIR) / "q31_hurwicz_alpha.png")
|
||||
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
|
||||
plt.close(fig)
|
||||
_logger.info(" Saved: %s", outpath)
|
||||
|
||||
|
||||
def _draw_mnemonic_criteria_boxes(ax: Axes) -> None:
|
||||
"""Draw the 6 criteria boxes around the center."""
|
||||
criteria = [
|
||||
(
|
||||
0,
|
||||
6.5,
|
||||
3,
|
||||
1.2,
|
||||
"WARTOŚĆ OCZEKIWANA",
|
||||
"\u201eMam prawdopodobieństwa\u201d",
|
||||
"E[Aᵢ] = Σ pⱼ·aᵢⱼ",
|
||||
),
|
||||
(
|
||||
3.5,
|
||||
6.5,
|
||||
3,
|
||||
1.2,
|
||||
"LAPLACE",
|
||||
"\u201eWszystko po równo\u201d",
|
||||
"V = Σaᵢⱼ / n",
|
||||
),
|
||||
(
|
||||
7,
|
||||
6.5,
|
||||
3,
|
||||
1.2,
|
||||
"MAXIMAX",
|
||||
"\u201eOptymista: max z max\u201d",
|
||||
"max maxⱼ aᵢⱼ",
|
||||
),
|
||||
(
|
||||
0,
|
||||
0.5,
|
||||
3,
|
||||
1.2,
|
||||
"MAXIMIN (Wald)",
|
||||
"\u201ePesymista: max z min\u201d",
|
||||
"max minⱼ aᵢⱼ",
|
||||
),
|
||||
(
|
||||
3.5,
|
||||
0.5,
|
||||
3,
|
||||
1.2,
|
||||
"HURWICZ",
|
||||
"\u201e\u03b1 pomiędzy\u201d",
|
||||
"\u03b1·max + (1-\u03b1)·min",
|
||||
),
|
||||
(
|
||||
7,
|
||||
0.5,
|
||||
3,
|
||||
1.2,
|
||||
"SAVAGE",
|
||||
"\u201eMin max żalu\u201d",
|
||||
"min maxⱼ rᵢⱼ",
|
||||
),
|
||||
]
|
||||
|
||||
fills = [GRAY3, GRAY1, "white", "white", GRAY1, GRAY3]
|
||||
|
||||
for i, (x, y, w, h, title, mnem, formula) in enumerate(criteria):
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.08",
|
||||
lw=1.5,
|
||||
edgecolor=LN,
|
||||
facecolor=fills[i],
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h * 0.78,
|
||||
title,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h * 0.45,
|
||||
mnem,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
style="italic",
|
||||
)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h * 0.15,
|
||||
formula,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
family="monospace",
|
||||
)
|
||||
|
||||
# Arrows from center to each box
|
||||
cx, cy = 5, 4
|
||||
bx = x + w / 2
|
||||
by_center = y + h / 2
|
||||
if by_center > cy:
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(bx, y),
|
||||
xytext=(cx, 4.5),
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"color": LN,
|
||||
"lw": 1,
|
||||
"connectionstyle": "arc3,rad=0",
|
||||
},
|
||||
)
|
||||
else:
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(bx, y + h),
|
||||
xytext=(cx, 3.5),
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"color": LN,
|
||||
"lw": 1,
|
||||
"connectionstyle": "arc3,rad=0",
|
||||
},
|
||||
)
|
||||
|
||||
# Labels on arrows
|
||||
arrow_labels = [
|
||||
(1.2, 5.6, "znane p"),
|
||||
(5, 5.6, "p = 1/n"),
|
||||
(8.7, 5.6, "max ↑"),
|
||||
(1.2, 2.5, "min ↑"),
|
||||
(5, 2.5, "podaj \u03b1"),
|
||||
(8.7, 2.5, "macierz\nżalu"),
|
||||
]
|
||||
for lx, ly, ltext in arrow_labels:
|
||||
ax.text(
|
||||
lx,
|
||||
ly,
|
||||
ltext,
|
||||
fontsize=7,
|
||||
ha="center",
|
||||
va="center",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.15",
|
||||
"facecolor": "white",
|
||||
"edgecolor": GRAY3,
|
||||
"lw": 0.5,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def draw_criteria_mnemonic() -> None:
|
||||
"""Draw decision criteria mnemonic map diagram."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(8.27, 6))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 8)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Mapa mnemoniczna \u2014 6 kryteriów decyzyjnych",
|
||||
fontsize=FS_TITLE + 2,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Central node
|
||||
draw_box(
|
||||
ax,
|
||||
3.5,
|
||||
3.5,
|
||||
3,
|
||||
1,
|
||||
"MACIERZ\nWYPŁAT",
|
||||
fill=GRAY2,
|
||||
lw=2,
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
_draw_mnemonic_criteria_boxes(ax)
|
||||
|
||||
plt.tight_layout()
|
||||
outpath = str(Path(OUTPUT_DIR) / "q31_criteria_mnemonic.png")
|
||||
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
|
||||
plt.close(fig)
|
||||
_logger.info(" Saved: %s", outpath)
|
||||
@ -0,0 +1,322 @@
|
||||
"""Q31 Diagram 2: Regret matrix construction step-by-step."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from python_pkg.praca_magisterska_video.generate_images._q31_common import (
|
||||
_DATA_STATE_COLS,
|
||||
_REGRET_HEADER_COLS,
|
||||
BG,
|
||||
DPI,
|
||||
FS,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
LN,
|
||||
OUTPUT_DIR,
|
||||
_logger,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
|
||||
def _draw_original_payoff(
|
||||
ax: Axes,
|
||||
start_y: float,
|
||||
row_h: float,
|
||||
) -> None:
|
||||
"""Draw the original payoff matrix (left side of regret fig)."""
|
||||
ax.text(
|
||||
2.2,
|
||||
6.3,
|
||||
"Krok 1: Macierz wypłat",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="center",
|
||||
)
|
||||
|
||||
col_w = 1.0
|
||||
headers = ["", "S₁", "S₂", "S₃"]
|
||||
data = [
|
||||
["A₁", "200", "50", "-100"],
|
||||
["A₂", "80", "70", "40"],
|
||||
["A₃", "30", "30", "30"],
|
||||
]
|
||||
start_x = 0.3
|
||||
|
||||
for j, h in enumerate(headers):
|
||||
w = 0.7 if j == 0 else col_w
|
||||
x = start_x + (0 if j == 0 else 0.7 + (j - 1) * col_w)
|
||||
rect = mpatches.Rectangle(
|
||||
(x, start_y),
|
||||
w,
|
||||
row_h,
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=GRAY2,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
start_y + row_h / 2,
|
||||
h,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
for i, row in enumerate(data):
|
||||
y = start_y - (i + 1) * row_h
|
||||
for j, val in enumerate(row):
|
||||
w = 0.7 if j == 0 else col_w
|
||||
x = start_x + (0 if j == 0 else 0.7 + (j - 1) * col_w)
|
||||
fill = GRAY4 if j == 0 else "white"
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y),
|
||||
w,
|
||||
row_h,
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + row_h / 2,
|
||||
val,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
)
|
||||
|
||||
# Max per column annotation
|
||||
max_y = start_y - _DATA_STATE_COLS * row_h - 0.1
|
||||
col_maxes = ["max=200", "max=70", "max=40"]
|
||||
for idx, label in enumerate(col_maxes):
|
||||
ax.text(
|
||||
start_x + 0.7 + (idx + 0.5) * col_w,
|
||||
max_y,
|
||||
label,
|
||||
fontsize=7,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontweight="bold",
|
||||
color="#333",
|
||||
)
|
||||
|
||||
# Arrow
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(5.0, 4.8),
|
||||
xytext=(4.2, 4.8),
|
||||
arrowprops={"arrowstyle": "->", "color": LN, "lw": 2},
|
||||
)
|
||||
ax.text(
|
||||
4.6,
|
||||
5.0,
|
||||
"rᵢⱼ = max - aᵢⱼ",
|
||||
fontsize=8,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
|
||||
def _draw_regret_table(
|
||||
ax: Axes,
|
||||
start_y: float,
|
||||
row_h: float,
|
||||
) -> None:
|
||||
"""Draw the regret matrix (right side of regret fig)."""
|
||||
ax.text(
|
||||
7.5,
|
||||
6.3,
|
||||
"Krok 2: Macierz żalu",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="center",
|
||||
)
|
||||
|
||||
regret_data = [
|
||||
["A₁", "0", "20", "140"],
|
||||
["A₂", "120", "0", "0"],
|
||||
["A₃", "170", "40", "10"],
|
||||
]
|
||||
headers2 = ["", "S₁", "S₂", "S₃", "max rᵢ"]
|
||||
start_x2 = 5.3
|
||||
|
||||
for j, h in enumerate(headers2):
|
||||
w = 0.7 if j == 0 else (0.9 if j < _REGRET_HEADER_COLS else 1.0)
|
||||
x = start_x2
|
||||
if j == 0:
|
||||
x = start_x2
|
||||
elif j <= _DATA_STATE_COLS:
|
||||
x = start_x2 + 0.7 + (j - 1) * 0.9
|
||||
else:
|
||||
x = start_x2 + 0.7 + _DATA_STATE_COLS * 0.9
|
||||
rect = mpatches.Rectangle(
|
||||
(x, start_y),
|
||||
w,
|
||||
row_h,
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=(GRAY2 if j < _REGRET_HEADER_COLS else GRAY3),
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
start_y + row_h / 2,
|
||||
h,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
max_regrets = [140, 120, 170]
|
||||
for i, row in enumerate(regret_data):
|
||||
y = start_y - (i + 1) * row_h
|
||||
for j, val in enumerate(row):
|
||||
w = 0.7 if j == 0 else 0.9
|
||||
x = start_x2 + (0 if j == 0 else 0.7 + (j - 1) * 0.9)
|
||||
fill = GRAY4 if j == 0 else "white"
|
||||
if j > 0 and int(val) == max_regrets[i]:
|
||||
fill = GRAY2
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y),
|
||||
w,
|
||||
row_h,
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
fw = "bold" if (j > 0 and int(val) == max_regrets[i]) else "normal"
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + row_h / 2,
|
||||
val,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight=fw,
|
||||
)
|
||||
|
||||
# Max regret column
|
||||
x = start_x2 + 0.7 + _DATA_STATE_COLS * 0.9
|
||||
w = 1.0
|
||||
is_winner = max_regrets[i] == min(max_regrets)
|
||||
fill = "#C8C8C8" if is_winner else GRAY1
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y),
|
||||
w,
|
||||
row_h,
|
||||
lw=1.5 if is_winner else 1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
marker = " ★" if is_winner else ""
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + row_h / 2,
|
||||
f"{max_regrets[i]}{marker}",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
|
||||
def draw_regret_matrix() -> None:
|
||||
"""Draw the regret matrix construction diagram."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(8.27, 5))
|
||||
ax.axis("off")
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_title(
|
||||
"Kryterium Savage'a \u2014 budowa macierzy żalu",
|
||||
fontsize=FS_TITLE + 1,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
start_y = 5.5
|
||||
row_h = 0.55
|
||||
|
||||
_draw_original_payoff(ax, start_y, row_h)
|
||||
_draw_regret_table(ax, start_y, row_h)
|
||||
|
||||
# Bottom conclusion
|
||||
ax.text(
|
||||
5.0,
|
||||
2.8,
|
||||
"Krok 3: Wybierz min z max żalu" " → A₂ (max żal = 120)",
|
||||
fontsize=10,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontweight="bold",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": GRAY1,
|
||||
"edgecolor": LN,
|
||||
"lw": 1.5,
|
||||
},
|
||||
)
|
||||
|
||||
# Interpretation examples
|
||||
ax.text(
|
||||
5.0,
|
||||
2.0,
|
||||
"Interpretacja żalu: r₁₃ = 140 oznacza:\n"
|
||||
"\u201eGdyby nastąpił S₃ (zła koniunktura),"
|
||||
" a wybrałbym A₁,\n"
|
||||
"żałowałbym, bo najlepszą opcją byłoby"
|
||||
" A₂ z wynikiem 40 \u2014 traciłbym 140\u201d",
|
||||
fontsize=7.5,
|
||||
ha="center",
|
||||
va="center",
|
||||
style="italic",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
"lw": 0.8,
|
||||
},
|
||||
)
|
||||
|
||||
# Mnemonic
|
||||
ax.text(
|
||||
5.0,
|
||||
0.8,
|
||||
"Mnemonik: Savage = \u201eŻal jak nóż\u201d\n"
|
||||
"Maksymalny żal to nóż "
|
||||
"\u2014 wybierz opcję z NAJMNIEJSZYM nożem",
|
||||
fontsize=8,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontweight="bold",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": "white",
|
||||
"edgecolor": LN,
|
||||
"lw": 1,
|
||||
},
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
outpath = str(Path(OUTPUT_DIR) / "q31_regret_matrix.png")
|
||||
fig.savefig(outpath, dpi=DPI, bbox_inches="tight", facecolor=BG)
|
||||
plt.close(fig)
|
||||
_logger.info(" Saved: %s", outpath)
|
||||
448
python_pkg/praca_magisterska_video/generate_images/_q9_basics.py
Normal file
448
python_pkg/praca_magisterska_video/generate_images/_q9_basics.py
Normal file
@ -0,0 +1,448 @@
|
||||
"""Q9 diagrams 1-6: process/thread basics, memory, states, PCB, speed."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q9_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
draw_table,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 1. Process vs Thread comparison table
|
||||
# ============================================================
|
||||
def gen_process_vs_thread() -> None:
|
||||
"""Gen process vs thread."""
|
||||
fig, ax = plt.subplots(figsize=(7.5, 4.5))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(-4.5, 1.5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Proces vs Wątek — porównanie", fontsize=FS_TITLE, fontweight="bold", pad=10
|
||||
)
|
||||
|
||||
headers = ["Cecha", "Proces", "Wątek"]
|
||||
col_w = [2.5, 3.5, 3.5]
|
||||
rows = [
|
||||
["Pamięć", "Własna, izolowana", "Współdzielona (heap)"],
|
||||
["Tworzenie", "~1-10 ms", "~10-100 μs (100x szybciej)"],
|
||||
["Przełączanie", "~1-5 μs (TLB flush)", "~0.1-0.5 μs (10x)"],
|
||||
["Komunikacja", "IPC (pipe, socket, shm)", "Bezpośrednia (wspólna pam.)"],
|
||||
["Izolacja", "Pełna — awaria izolowana", "Brak — może zabić proces"],
|
||||
["Zastosowanie", "Bezpieczeństwo, izolacja", "Wydajność, współdzielenie"],
|
||||
]
|
||||
draw_table(
|
||||
ax,
|
||||
headers,
|
||||
rows,
|
||||
x0=0.25,
|
||||
y0=0.8,
|
||||
col_widths=col_w,
|
||||
row_h=0.55,
|
||||
fontsize=7.5,
|
||||
header_fontsize=FS_LABEL,
|
||||
)
|
||||
|
||||
# Analogy at bottom
|
||||
ax.text(
|
||||
5.0,
|
||||
-4.2,
|
||||
"Analogia: Proces = mieszkanie (własny adres) "
|
||||
"Wątek = pokój w mieszkaniu (wspólna kuchnia = heap)",
|
||||
ha="center",
|
||||
fontsize=FS,
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
save_fig(fig, "q9_process_vs_thread.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 2. Memory segments layout
|
||||
# ============================================================
|
||||
def gen_memory_layout() -> None:
|
||||
"""Gen memory layout."""
|
||||
fig, ax = plt.subplots(figsize=(6, 5))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 8)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Segmenty pamięci procesu", fontsize=FS_TITLE, fontweight="bold", pad=10
|
||||
)
|
||||
|
||||
segments = [
|
||||
("STACK ↓", "zmienne lokalne, adresy\npowrotu (każdy wątek WŁASNY)", GRAY1),
|
||||
("...", "(wolna przestrzeń)", "white"),
|
||||
("HEAP ↑", "malloc/new — dynamiczna\nalokacja (współdzielony)", GRAY4),
|
||||
("BSS", "zmienne globalne\nniezainicjalizowane (zerowane)", GRAY2),
|
||||
("DATA", "zmienne globalne\nzainicjalizowane", GRAY3),
|
||||
("TEXT", "kod maszynowy\n(read-only, współdzielony)", GRAY5),
|
||||
]
|
||||
bx, bw = 2.0, 2.5
|
||||
seg_h = 0.9
|
||||
gap = 0.05
|
||||
top_y = 7.0
|
||||
|
||||
for i, (name, desc, color) in enumerate(segments):
|
||||
y = top_y - i * (seg_h + gap)
|
||||
draw_box(
|
||||
ax,
|
||||
bx,
|
||||
y,
|
||||
bw,
|
||||
seg_h,
|
||||
name,
|
||||
fill=color,
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
ax.text(bx + bw + 0.3, y + seg_h / 2, desc, fontsize=7.5, va="center")
|
||||
|
||||
# Address labels
|
||||
ax.text(
|
||||
bx - 0.2,
|
||||
top_y + seg_h / 2,
|
||||
"wysoki\nadres",
|
||||
fontsize=FS_SMALL,
|
||||
va="center",
|
||||
ha="right",
|
||||
style="italic",
|
||||
)
|
||||
bottom_y = top_y - 5 * (seg_h + gap)
|
||||
ax.text(
|
||||
bx - 0.2,
|
||||
bottom_y + seg_h / 2,
|
||||
"niski\nadres",
|
||||
fontsize=FS_SMALL,
|
||||
va="center",
|
||||
ha="right",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Arrows for growth
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(bx - 0.5, top_y - 0.1),
|
||||
xytext=(bx - 0.5, top_y + seg_h + 0.1),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(bx - 0.9, top_y + 0.4, "rośnie\nw dół", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
heap_y = top_y - 2 * (seg_h + gap)
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(bx - 0.5, heap_y + seg_h + 0.1),
|
||||
xytext=(bx - 0.5, heap_y - 0.1),
|
||||
arrowprops={"arrowstyle": "->", "lw": 1.5, "color": LN},
|
||||
)
|
||||
ax.text(bx - 0.9, heap_y + 0.5, "rośnie\nw górę", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
save_fig(fig, "q9_memory_layout.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 3. Process states diagram
|
||||
# ============================================================
|
||||
def gen_process_states() -> None:
|
||||
"""Gen process states."""
|
||||
fig, ax = plt.subplots(figsize=(7, 3.5))
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Stany procesu — diagram przejść", fontsize=FS_TITLE, fontweight="bold", pad=10
|
||||
)
|
||||
|
||||
states = {
|
||||
"NEW": (1.0, 2.5),
|
||||
"READY": (3.5, 2.5),
|
||||
"RUNNING": (6.5, 2.5),
|
||||
"BLOCKED": (6.5, 0.5),
|
||||
"TERMINATED": (10.0, 2.5),
|
||||
}
|
||||
fills = {
|
||||
"NEW": GRAY4,
|
||||
"READY": GRAY1,
|
||||
"RUNNING": GRAY3,
|
||||
"BLOCKED": GRAY2,
|
||||
"TERMINATED": GRAY5,
|
||||
}
|
||||
bw, bh = 1.8, 0.9
|
||||
for name, (x, y) in states.items():
|
||||
draw_box(
|
||||
ax,
|
||||
x,
|
||||
y,
|
||||
bw,
|
||||
bh,
|
||||
name,
|
||||
fill=fills[name],
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Transitions
|
||||
transitions = [
|
||||
("NEW", "READY", "admit"),
|
||||
("READY", "RUNNING", "dispatch\n(scheduler)"),
|
||||
("RUNNING", "TERMINATED", "exit"),
|
||||
("RUNNING", "BLOCKED", "I/O wait"),
|
||||
]
|
||||
for src, dst, label in transitions:
|
||||
sx, sy = states[src]
|
||||
dx, dy = states[dst]
|
||||
if sy == dy: # horizontal
|
||||
draw_arrow(ax, sx + bw, sy + bh / 2, dx, dy + bh / 2, lw=1.5)
|
||||
mx = (sx + bw + dx) / 2
|
||||
ax.text(
|
||||
mx,
|
||||
sy + bh / 2 + 0.25,
|
||||
label,
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
)
|
||||
else: # vertical
|
||||
draw_arrow(ax, sx + bw / 2, sy, dx + bw / 2, dy + bh, lw=1.5)
|
||||
ax.text(
|
||||
sx + bw + 0.2,
|
||||
(sy + dy + bh) / 2,
|
||||
label,
|
||||
fontsize=FS_SMALL,
|
||||
ha="left",
|
||||
va="center",
|
||||
)
|
||||
|
||||
# BLOCKED → READY
|
||||
bx, by = states["BLOCKED"]
|
||||
rx, ry = states["READY"]
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(rx + bw / 2, ry),
|
||||
xytext=(bx - 0.3, by + bh / 2),
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"lw": 1.5,
|
||||
"color": LN,
|
||||
"connectionstyle": "arc3,rad=0.3",
|
||||
},
|
||||
)
|
||||
ax.text(3.5, 0.7, "I/O done", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
# RUNNING → READY (preemption)
|
||||
rux, ruy = states["RUNNING"]
|
||||
draw_arrow(ax, rux, ruy + bh, rx + bw, ry + bh, lw=1.2)
|
||||
ax.text(5.0, 3.7, "preempt /\ntimeout", fontsize=FS_SMALL, ha="center")
|
||||
|
||||
save_fig(fig, "q9_process_states.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 4. Thread structure within process
|
||||
# ============================================================
|
||||
def gen_thread_structure() -> None:
|
||||
"""Gen thread structure."""
|
||||
fig, ax = plt.subplots(figsize=(8, 4.5))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Wątki wewnątrz procesu (PID=42)", fontsize=FS_TITLE, fontweight="bold", pad=10
|
||||
)
|
||||
|
||||
# Shared memory region
|
||||
draw_box(ax, 0.5, 3.5, 9.0, 1.8, "", fill=GRAY1, rounded=False, lw=2)
|
||||
ax.text(5.0, 5.0, "WSPÓŁDZIELONE", fontsize=FS, fontweight="bold", ha="center")
|
||||
|
||||
labels_shared = ["TEXT", "DATA", "BSS", "HEAP", "pliki", "PID"]
|
||||
for i, lab in enumerate(labels_shared):
|
||||
x = 1.0 + i * 1.4
|
||||
draw_box(
|
||||
ax,
|
||||
x,
|
||||
3.8,
|
||||
1.1,
|
||||
0.6,
|
||||
lab,
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
ax.text(x + 0.55, 4.6, lab, fontsize=FS_SMALL, ha="center", color="#555555")
|
||||
|
||||
# Per-thread regions
|
||||
draw_box(
|
||||
ax, 0.5, 0.5, 9.0, 2.7, "", fill="white", rounded=False, lw=2, linestyle="--"
|
||||
)
|
||||
ax.text(
|
||||
5.0, 2.95, "PRYWATNE (każdy wątek)", fontsize=FS, fontweight="bold", ha="center"
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
x = 1.0 + i * 3.0
|
||||
tid = i + 1
|
||||
draw_box(ax, x, 0.7, 2.3, 2.0, "", fill=GRAY4, rounded=False)
|
||||
ax.text(
|
||||
x + 1.15,
|
||||
2.4,
|
||||
f"Wątek {tid}",
|
||||
fontsize=FS_LABEL,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
)
|
||||
items = [f"stos_{tid}", f"rejestry_{tid}", f"PC_{tid}", f"TID={40 + tid}"]
|
||||
for j, item in enumerate(items):
|
||||
ax.text(
|
||||
x + 1.15,
|
||||
2.0 - j * 0.35,
|
||||
item,
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
family="monospace",
|
||||
)
|
||||
|
||||
save_fig(fig, "q9_thread_structure.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 5. PCB structure
|
||||
# ============================================================
|
||||
def gen_pcb_structure() -> None:
|
||||
"""Gen pcb structure."""
|
||||
fig, ax = plt.subplots(figsize=(5, 3.5))
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 5.5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"PCB (Process Control Block)", fontsize=FS_TITLE, fontweight="bold", pad=10
|
||||
)
|
||||
|
||||
fields = [
|
||||
("PID", "42"),
|
||||
("Stan", "READY / RUNNING / BLOCKED"),
|
||||
("Rejestry CPU", "EAX, EBX, ESP, EIP ..."),
|
||||
("Tablice stron", "mapowanie wirtualne → fizyczne"),
|
||||
("Otwarte pliki", "fd[0], fd[1], fd[2] ..."),
|
||||
("Priorytety", "nice value, scheduling class"),
|
||||
("Statystyki", "CPU time, I/O count"),
|
||||
]
|
||||
|
||||
top_y = 4.8
|
||||
for i, (field, value) in enumerate(fields):
|
||||
y = top_y - i * 0.55
|
||||
draw_box(
|
||||
ax,
|
||||
0.5,
|
||||
y,
|
||||
2.2,
|
||||
0.45,
|
||||
field,
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
draw_box(ax, 2.7, y, 4.5, 0.45, value, fill=GRAY4, fontsize=FS, rounded=False)
|
||||
|
||||
ax.text(
|
||||
4.0,
|
||||
0.3,
|
||||
"Context switch = zapisz PCB starego → wczytaj PCB nowego",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
save_fig(fig, "q9_pcb_structure.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 6. Speed comparison
|
||||
# ============================================================
|
||||
def gen_speed_comparison() -> None:
|
||||
"""Gen speed comparison."""
|
||||
fig, axes = plt.subplots(1, 2, figsize=(9, 3.5))
|
||||
fig.suptitle(
|
||||
"Szybkość — procesy vs wątki (benchmarki Linux)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Creation time panel
|
||||
ax = axes[0]
|
||||
ops = ["fork()\n(nowy proces)", "pthread_create()\n(nowy wątek)"]
|
||||
times = [3.0, 0.05] # ms
|
||||
colors = [GRAY3, GRAY1]
|
||||
bars = ax.barh(ops, times, color=colors, edgecolor=LN, height=0.5, linewidth=1.2)
|
||||
ax.set_xlabel("Czas [ms]", fontsize=FS)
|
||||
ax.set_title("Tworzenie", fontsize=FS_LABEL, fontweight="bold")
|
||||
ax.set_xlim(0, 4.5)
|
||||
for bar, t in zip(bars, times, strict=False):
|
||||
ax.text(
|
||||
bar.get_width() + 0.1,
|
||||
bar.get_y() + bar.get_height() / 2,
|
||||
f"{t} ms",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
)
|
||||
ax.text(
|
||||
2.5,
|
||||
-0.6,
|
||||
"~100x szybciej",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
transform=ax.get_xaxis_transform(),
|
||||
)
|
||||
ax.tick_params(labelsize=FS)
|
||||
|
||||
# Right: context switch
|
||||
ax = axes[1]
|
||||
ops2 = ["Proces→Proces\n(TLB flush)", "Wątek→Wątek\n(TLB warm)"]
|
||||
times2 = [3000, 300] # ns
|
||||
bars2 = ax.barh(ops2, times2, color=colors, edgecolor=LN, height=0.5, linewidth=1.2)
|
||||
ax.set_xlabel("Czas [ns]", fontsize=FS)
|
||||
ax.set_title("Przełączanie kontekstu", fontsize=FS_LABEL, fontweight="bold")
|
||||
ax.set_xlim(0, 4500)
|
||||
for bar, t in zip(bars2, times2, strict=False):
|
||||
ax.text(
|
||||
bar.get_width() + 50,
|
||||
bar.get_y() + bar.get_height() / 2,
|
||||
f"{t} ns",
|
||||
va="center",
|
||||
fontsize=FS,
|
||||
)
|
||||
ax.text(
|
||||
2500,
|
||||
-0.6,
|
||||
"~10x szybciej",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
transform=ax.get_xaxis_transform(),
|
||||
)
|
||||
ax.tick_params(labelsize=FS)
|
||||
|
||||
fig.tight_layout(rect=[0, 0.05, 1, 0.92])
|
||||
save_fig(fig, "q9_speed_comparison.png")
|
||||
@ -0,0 +1,420 @@
|
||||
"""Q9 diagrams 14-16: classic sync problems, mechanism comparison, semaphore."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q9_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
OCCUPIED_SLOTS,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
draw_table,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 14. Bounded buffer + readers-writers + philosophers
|
||||
# ============================================================
|
||||
def _draw_bounded_buffer_panel(ax: plt.Axes) -> None:
|
||||
"""Draw the bounded-buffer (producer-consumer) panel."""
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Producent-Konsument\n(Bounded Buffer, N=4)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
0.2,
|
||||
4.0,
|
||||
2.0,
|
||||
1.2,
|
||||
"Producent\nP(empty)\nP(mutex)\nwstaw()\nV(mutex)\nV(full)",
|
||||
fill=GRAY1,
|
||||
fontsize=5.5,
|
||||
)
|
||||
items = ["A", "B", "", ""]
|
||||
for i, item in enumerate(items):
|
||||
x = 2.8 + i * 0.9
|
||||
fill = GRAY3 if item else "white"
|
||||
draw_box(
|
||||
ax,
|
||||
x,
|
||||
4.3,
|
||||
0.9,
|
||||
0.7,
|
||||
item,
|
||||
fill=fill,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
ax.text(4.6, 5.2, "Bufor (N=4)", fontsize=FS_SMALL, ha="center", fontweight="bold")
|
||||
draw_box(
|
||||
ax,
|
||||
6.0,
|
||||
4.0,
|
||||
2.0,
|
||||
1.2,
|
||||
"Konsument\nP(full)\nP(mutex)\npobierz()\nV(mutex)\nV(empty)",
|
||||
fill=GRAY4,
|
||||
fontsize=5.5,
|
||||
)
|
||||
draw_arrow(ax, 2.2, 4.6, 2.8, 4.65, lw=1.2)
|
||||
draw_arrow(ax, 6.4, 4.65, 6.0, 4.6, lw=1.2)
|
||||
|
||||
sems = [("mutex = 1", GRAY2), ("empty = N", GRAY1), ("full = 0", GRAY3)]
|
||||
for i, (s, c) in enumerate(sems):
|
||||
draw_box(
|
||||
ax,
|
||||
2.0,
|
||||
2.5 - i * 0.6,
|
||||
4.0,
|
||||
0.45,
|
||||
s,
|
||||
fill=c,
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
4.0,
|
||||
0.5,
|
||||
"KOLEJNOŚĆ: P(empty/full)\nPRZED P(mutex)!\nOdwrotnie = DEADLOCK",
|
||||
fontsize=5.5,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828",
|
||||
bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"},
|
||||
)
|
||||
|
||||
|
||||
def _draw_readers_writers_panel(ax: plt.Axes) -> None:
|
||||
"""Draw the readers-writers panel."""
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Czytelnicy-Pisarze\n(Readers-Writers)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
draw_box(
|
||||
ax,
|
||||
2.5,
|
||||
3.5,
|
||||
3.0,
|
||||
1.5,
|
||||
"Dane\n(współdzielone)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
x = 0.3 + i * 1.0
|
||||
draw_box(
|
||||
ax,
|
||||
x,
|
||||
5.5,
|
||||
0.8,
|
||||
0.7,
|
||||
f"R{i + 1}",
|
||||
fill=GRAY4,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, x + 0.4, 5.5, 3.0 + i * 0.5, 5.0, lw=1)
|
||||
|
||||
ax.text(
|
||||
1.5,
|
||||
6.5,
|
||||
"Czytelnicy (wielu naraz)",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
draw_box(
|
||||
ax, 5.5, 5.5, 1.5, 0.7, "Pisarz", fill=GRAY5, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
draw_arrow(ax, 6.25, 5.5, 5.0, 5.0, lw=1.5)
|
||||
ax.text(
|
||||
6.25,
|
||||
6.5,
|
||||
"WYŁĄCZNY",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828",
|
||||
)
|
||||
|
||||
rules = [
|
||||
"Wielu czytelników = OK",
|
||||
"Jeden pisarz = wyłączny",
|
||||
"Czytelnik + Pisarz = ✗",
|
||||
"Problem: pisarze głodują",
|
||||
]
|
||||
for i, r in enumerate(rules):
|
||||
ax.text(4.0, 2.5 - i * 0.45, r, fontsize=FS_SMALL, ha="center")
|
||||
|
||||
ax.text(
|
||||
4.0,
|
||||
0.5,
|
||||
"Rozwiązanie:\nrw_mutex + count_mutex\n+ zmienna readers",
|
||||
fontsize=5.5,
|
||||
ha="center",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
|
||||
def _draw_philosophers_panel(ax: plt.Axes) -> None:
|
||||
"""Draw the dining-philosophers panel."""
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Ucztujący filozofowie\n(Dining Philosophers)", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
cx, cy, r = 4.0, 3.8, 1.8
|
||||
table = plt.Circle((cx, cy), 0.8, fill=True, facecolor=GRAY2, edgecolor=LN, lw=1.5)
|
||||
ax.add_patch(table)
|
||||
ax.text(cx, cy, "Stół", fontsize=FS, ha="center", fontweight="bold")
|
||||
|
||||
for i in range(5):
|
||||
angle = np.pi / 2 + i * 2 * np.pi / 5
|
||||
px = cx + r * np.cos(angle)
|
||||
py = cy + r * np.sin(angle)
|
||||
circle = plt.Circle(
|
||||
(px, py), 0.35, fill=True, facecolor=GRAY1, edgecolor=LN, lw=1.2
|
||||
)
|
||||
ax.add_patch(circle)
|
||||
ax.text(
|
||||
px, py, f"F{i}", ha="center", va="center", fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
fork_angle = np.pi / 2 + (i + 0.5) * 2 * np.pi / 5
|
||||
fx = cx + (r * 0.6) * np.cos(fork_angle)
|
||||
fy = cy + (r * 0.6) * np.sin(fork_angle)
|
||||
ax.plot(
|
||||
[fx - 0.1, fx + 0.1],
|
||||
[fy - 0.15, fy + 0.15],
|
||||
color=LN,
|
||||
lw=2.5,
|
||||
solid_capstyle="round",
|
||||
)
|
||||
ax.text(fx + 0.2, fy + 0.15, f"w{i}", fontsize=5, color="#555555")
|
||||
|
||||
rules = [
|
||||
"Jedzenie = 2 widelce",
|
||||
"Naiwne → DEADLOCK",
|
||||
"Fix: F4 bierze odwrotnie",
|
||||
"Alt: semafor(4)",
|
||||
]
|
||||
for i, r in enumerate(rules):
|
||||
ax.text(4.0, 1.2 - i * 0.35, r, fontsize=FS_SMALL, ha="center")
|
||||
|
||||
|
||||
def gen_classic_problems() -> None:
|
||||
"""Gen classic problems."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(12, 5))
|
||||
fig.suptitle(
|
||||
"Klasyczne problemy synchronizacji", fontsize=FS_TITLE, fontweight="bold"
|
||||
)
|
||||
|
||||
_draw_bounded_buffer_panel(axes[0])
|
||||
_draw_readers_writers_panel(axes[1])
|
||||
_draw_philosophers_panel(axes[2])
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.88])
|
||||
save_fig(fig, "q9_classic_problems.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 15. Sync mechanisms comparison + mutex/sem/spinlock
|
||||
# ============================================================
|
||||
def gen_sync_comparison() -> None:
|
||||
"""Gen sync comparison."""
|
||||
fig, axes = plt.subplots(2, 1, figsize=(9, 7))
|
||||
|
||||
# Top: comparison table
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 11.5)
|
||||
ax.set_ylim(-5, 1)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Mechanizmy synchronizacji — porównanie",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
headers = ["Mechanizm", "Opis", "Kiedy używać"]
|
||||
col_w = [2.5, 4.5, 4.0]
|
||||
rows = [
|
||||
["Mutex", "Zamek: 1 wątek w sekcji", "Sekcja krytyczna"],
|
||||
["Semafor(n)", "Licznik: max n wątków", "Ograniczone zasoby (n miejsc)"],
|
||||
["Monitor", "Obiekt z wbudowanym mutex", "Java synchronized"],
|
||||
["Cond. Variable", "wait()/signal() na warunek", "Producent-konsument"],
|
||||
["Spinlock", "Aktywne czekanie (busy-wait)", "Bardzo krótkie sekcje (<1 μs)"],
|
||||
["RW Lock", "Wielu czytelników LUB 1 pisarz", "Bazy danych, cache"],
|
||||
["Barrier", "Czekaj aż wszyscy dotrą", "Obliczenia równoległe"],
|
||||
]
|
||||
draw_table(
|
||||
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7
|
||||
)
|
||||
|
||||
# Bottom: mutex vs semafor vs spinlock
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 12)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Mutex vs Semafor vs Spinlock", fontsize=FS_TITLE, fontweight="bold", pad=5
|
||||
)
|
||||
|
||||
# Mutex
|
||||
draw_box(ax, 0.3, 2.5, 3.5, 2.0, "", fill=GRAY4)
|
||||
ax.text(2.05, 4.2, "MUTEX", fontsize=FS_LABEL, fontweight="bold", ha="center")
|
||||
ax.text(2.05, 3.6, "= klucz do łazienki\n(1 osoba)", fontsize=FS, ha="center")
|
||||
ax.text(
|
||||
2.05,
|
||||
2.8,
|
||||
"Wątek ZASYPIA gdy czeka\nOS go obudzi (~μs)",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Semafor
|
||||
draw_box(ax, 4.3, 2.5, 3.5, 2.0, "", fill=GRAY1)
|
||||
ax.text(6.05, 4.2, "SEMAFOR(n)", fontsize=FS_LABEL, fontweight="bold", ha="center")
|
||||
ax.text(
|
||||
6.05, 3.6, "= parking na n miejsc\n(n wątków naraz)", fontsize=FS, ha="center"
|
||||
)
|
||||
ax.text(
|
||||
6.05,
|
||||
2.8,
|
||||
"Semafor(1) = mutex\nP() = zmniejsz, V() = zwiększ",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Spinlock
|
||||
draw_box(ax, 8.3, 2.5, 3.5, 2.0, "", fill=GRAY2)
|
||||
ax.text(10.05, 4.2, "SPINLOCK", fontsize=FS_LABEL, fontweight="bold", ha="center")
|
||||
ax.text(
|
||||
10.05, 3.6, "= obrotowe drzwi\n(kręcisz się w kółko)", fontsize=FS, ha="center"
|
||||
)
|
||||
ax.text(
|
||||
10.05,
|
||||
2.8,
|
||||
"Wątek KRĘCI się w pętli\nLepszy gdy sekcja < 1 μs",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# Dividing rule
|
||||
ax.text(
|
||||
6.0,
|
||||
1.5,
|
||||
"Reguła kciuka: sekcja > 1 μs → MUTEX | "
|
||||
"sekcja < 1 μs → SPINLOCK | n jednocześnie → SEMAFOR(n)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
save_fig(fig, "q9_sync_comparison.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 16. Semaphore concept diagram
|
||||
# ============================================================
|
||||
def gen_semaphore_concept() -> None:
|
||||
"""Gen semaphore concept."""
|
||||
fig, ax = plt.subplots(figsize=(6, 3))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Semafor — koncepcja (parking na 3 miejsca)",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Parking slots
|
||||
for i in range(3):
|
||||
x = 2.0 + i * 2.0
|
||||
occupied = i < OCCUPIED_SLOTS
|
||||
fill = GRAY3 if occupied else "white"
|
||||
label = f"Wątek {i + 1}" if occupied else "(wolne)"
|
||||
draw_box(
|
||||
ax,
|
||||
x,
|
||||
2.5,
|
||||
1.5,
|
||||
1.2,
|
||||
label,
|
||||
fill=fill,
|
||||
fontsize=FS,
|
||||
fontweight="bold" if occupied else "normal",
|
||||
rounded=False,
|
||||
)
|
||||
|
||||
ax.text(
|
||||
5.0,
|
||||
4.2,
|
||||
"semafor(3): counter = 1 (jedno wolne miejsce)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Waiting thread
|
||||
draw_box(
|
||||
ax,
|
||||
0.2,
|
||||
0.5,
|
||||
1.5,
|
||||
0.8,
|
||||
"Wątek 4\nP() → czeka",
|
||||
fill="#F8D7DA",
|
||||
fontsize=FS_SMALL,
|
||||
)
|
||||
draw_arrow(ax, 1.7, 0.9, 2.0, 2.5, lw=1.2, color="#C62828")
|
||||
|
||||
ax.text(
|
||||
5.0,
|
||||
0.6,
|
||||
"P() = counter-- (jeśli 0 → czekaj)\nV() = counter++ (obudź czekającego)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
family="monospace",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
save_fig(fig, "q9_semaphore_concept.png")
|
||||
200
python_pkg/praca_magisterska_video/generate_images/_q9_common.py
Normal file
200
python_pkg/praca_magisterska_video/generate_images/_q9_common.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""Common utilities and constants for Q9 diagram generation.
|
||||
|
||||
Monochrome, A4-printable PNGs (300 DPI).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
DPI = 300
|
||||
BG = "white"
|
||||
LN = "black"
|
||||
FS = 8
|
||||
FS_TITLE = 11
|
||||
FS_SMALL = 6.5
|
||||
FS_LABEL = 9
|
||||
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
|
||||
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
GRAY1 = "#E8E8E8"
|
||||
GRAY2 = "#D0D0D0"
|
||||
GRAY3 = "#B8B8B8"
|
||||
GRAY4 = "#F5F5F5"
|
||||
GRAY5 = "#C0C0C0"
|
||||
OCCUPIED_SLOTS = 2
|
||||
|
||||
|
||||
def draw_box(
|
||||
ax: Axes,
|
||||
x: float,
|
||||
y: float,
|
||||
w: float,
|
||||
h: float,
|
||||
text: str,
|
||||
fill: str = "white",
|
||||
lw: float = 1.2,
|
||||
fontsize: float = FS,
|
||||
fontweight: str = "normal",
|
||||
ha: str = "center",
|
||||
va: str = "center",
|
||||
*,
|
||||
rounded: bool = True,
|
||||
edgecolor: str = LN,
|
||||
linestyle: str = "-",
|
||||
) -> None:
|
||||
"""Draw box."""
|
||||
if rounded:
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
boxstyle="round,pad=0.05",
|
||||
lw=lw,
|
||||
edgecolor=edgecolor,
|
||||
facecolor=fill,
|
||||
linestyle=linestyle,
|
||||
)
|
||||
else:
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y),
|
||||
w,
|
||||
h,
|
||||
lw=lw,
|
||||
edgecolor=edgecolor,
|
||||
facecolor=fill,
|
||||
linestyle=linestyle,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h / 2,
|
||||
text,
|
||||
ha=ha,
|
||||
va=va,
|
||||
fontsize=fontsize,
|
||||
fontweight=fontweight,
|
||||
wrap=True,
|
||||
)
|
||||
|
||||
|
||||
def draw_arrow(
|
||||
ax: Axes,
|
||||
x1: float,
|
||||
y1: float,
|
||||
x2: float,
|
||||
y2: float,
|
||||
lw: float = 1.2,
|
||||
style: str = "->",
|
||||
color: str = LN,
|
||||
) -> None:
|
||||
"""Draw arrow."""
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(x2, y2),
|
||||
xytext=(x1, y1),
|
||||
arrowprops={"arrowstyle": style, "color": color, "lw": lw},
|
||||
)
|
||||
|
||||
|
||||
def draw_double_arrow(
|
||||
ax: Axes,
|
||||
x1: float,
|
||||
y1: float,
|
||||
x2: float,
|
||||
y2: float,
|
||||
lw: float = 1.2,
|
||||
color: str = LN,
|
||||
) -> None:
|
||||
"""Draw double arrow."""
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(x2, y2),
|
||||
xytext=(x1, y1),
|
||||
arrowprops={"arrowstyle": "<->", "color": color, "lw": lw},
|
||||
)
|
||||
|
||||
|
||||
def save_fig(fig: Figure, name: str) -> None:
|
||||
"""Save fig."""
|
||||
path = str(Path(OUTPUT_DIR) / name)
|
||||
fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=BG, pad_inches=0.15)
|
||||
plt.close(fig)
|
||||
_logger.info(" Saved: %s", path)
|
||||
|
||||
|
||||
def draw_table(
|
||||
ax: Axes,
|
||||
headers: list[str],
|
||||
rows: list[list[str]],
|
||||
x0: float,
|
||||
y0: float,
|
||||
col_widths: list[float],
|
||||
row_h: float = 0.4,
|
||||
header_fill: str = GRAY2,
|
||||
row_fills: list[str] | None = None,
|
||||
fontsize: float = FS,
|
||||
header_fontsize: float | None = None,
|
||||
) -> None:
|
||||
"""Draw a clean table on axes."""
|
||||
if header_fontsize is None:
|
||||
header_fontsize = fontsize
|
||||
len(headers)
|
||||
len(rows)
|
||||
sum(col_widths)
|
||||
|
||||
# Header
|
||||
cx = x0
|
||||
for j, hdr in enumerate(headers):
|
||||
draw_box(
|
||||
ax,
|
||||
cx,
|
||||
y0,
|
||||
col_widths[j],
|
||||
row_h,
|
||||
hdr,
|
||||
fill=header_fill,
|
||||
fontsize=header_fontsize,
|
||||
fontweight="bold",
|
||||
rounded=False,
|
||||
)
|
||||
cx += col_widths[j]
|
||||
|
||||
# Rows
|
||||
for i, row in enumerate(rows):
|
||||
cy = y0 - (i + 1) * row_h
|
||||
cx = x0
|
||||
fill = GRAY4 if (i % 2 == 0) else "white"
|
||||
if row_fills and i < len(row_fills):
|
||||
fill = row_fills[i]
|
||||
for j, cell in enumerate(row):
|
||||
fw = "bold" if j == 0 else "normal"
|
||||
draw_box(
|
||||
ax,
|
||||
cx,
|
||||
cy,
|
||||
col_widths[j],
|
||||
row_h,
|
||||
cell,
|
||||
fill=fill,
|
||||
fontsize=fontsize,
|
||||
fontweight=fw,
|
||||
rounded=False,
|
||||
)
|
||||
cx += col_widths[j]
|
||||
212
python_pkg/praca_magisterska_video/generate_images/_q9_ipc.py
Normal file
212
python_pkg/praca_magisterska_video/generate_images/_q9_ipc.py
Normal file
@ -0,0 +1,212 @@
|
||||
"""Q9 diagrams 7-9: IPC mechanisms and scenario tables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q9_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
draw_double_arrow,
|
||||
draw_table,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 7. Scenario table (when to use process vs thread)
|
||||
# ============================================================
|
||||
def gen_scenario_table() -> None:
|
||||
"""Gen scenario table."""
|
||||
fig, ax = plt.subplots(figsize=(8.5, 4.5))
|
||||
ax.set_xlim(0, 11)
|
||||
ax.set_ylim(-5.5, 1)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Kiedy proces, kiedy wątek? — typowe scenariusze",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
headers = ["Scenariusz", "Wybór", "Dlaczego?"]
|
||||
col_w = [3.5, 2.5, 4.5]
|
||||
rows = [
|
||||
["Serwer WWW (Apache)", "Proces", "izolacja klientów"],
|
||||
["Serwer WWW (nginx)", "Wątek / async", "szybkość, cooperacja"],
|
||||
["Przeglądarka (karty)", "Proces", "crash isolation"],
|
||||
["Przeglądarka (JS+render)", "Wątek", "współdzielony DOM"],
|
||||
["Gra (fizyka+rendering)", "Wątek", "współdzielony świat gry"],
|
||||
["Kompilacja (make -j8)", "Proces", "izolacja, prostota"],
|
||||
["Baza danych (zapytania)", "Wątek", "współdzielony cache"],
|
||||
["Microservices", "Proces (kontener)", "izolacja, deployment"],
|
||||
]
|
||||
draw_table(
|
||||
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7
|
||||
)
|
||||
|
||||
save_fig(fig, "q9_scenario_table.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 8. IPC details: pipe, shared memory, socket (3-panel)
|
||||
# ============================================================
|
||||
def gen_ipc_details() -> None:
|
||||
"""Gen ipc details."""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(11, 3.5))
|
||||
fig.suptitle("Mechanizmy IPC — szczegóły", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
# Panel 1: Pipe
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Pipe (potok)", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
draw_box(ax, 0.2, 2.0, 1.8, 1.2, "Proces A\n(ls)\nstdout", fill=GRAY1, fontsize=FS)
|
||||
draw_box(
|
||||
ax,
|
||||
3.0,
|
||||
2.0,
|
||||
1.8,
|
||||
1.2,
|
||||
"Bufor\njądra\n(4 KB)",
|
||||
fill=GRAY2,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_box(ax, 5.8, 2.0, 1.8, 1.2, "Proces B\n(grep)\nstdin", fill=GRAY1, fontsize=FS)
|
||||
draw_arrow(ax, 2.0, 2.6, 3.0, 2.6, lw=1.5)
|
||||
ax.text(2.5, 3.0, "write()\nfd[1]", fontsize=FS_SMALL, ha="center")
|
||||
draw_arrow(ax, 4.8, 2.6, 5.8, 2.6, lw=1.5)
|
||||
ax.text(5.3, 3.0, "read()\nfd[0]", fontsize=FS_SMALL, ha="center")
|
||||
ax.text(
|
||||
4.0,
|
||||
0.8,
|
||||
"Jednokierunkowy\nBufor pełny → write() blokuje",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
# Panel 2: Shared Memory
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Shared Memory", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
draw_box(ax, 0.3, 3.0, 2.2, 1.2, "Proces A\nstrona 7", fill=GRAY1, fontsize=FS)
|
||||
draw_box(ax, 5.5, 3.0, 2.2, 1.2, "Proces B\nstrona 3", fill=GRAY1, fontsize=FS)
|
||||
draw_box(
|
||||
ax,
|
||||
2.8,
|
||||
1.0,
|
||||
2.4,
|
||||
1.2,
|
||||
"RAM\nramka 42",
|
||||
fill=GRAY3,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 2.0, 3.0, 3.5, 2.2, lw=1.5)
|
||||
draw_arrow(ax, 6.0, 3.0, 4.5, 2.2, lw=1.5)
|
||||
ax.text(
|
||||
4.0,
|
||||
0.3,
|
||||
"Zero kopiowania!\nA pisze → B widzi od razu\nWymaga synchronizacji (semafor)",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
# Panel 3: Socket
|
||||
ax = axes[2]
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Socket", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Network socket
|
||||
draw_box(
|
||||
ax, 0.3, 3.2, 1.8, 0.9, "Klient", fill=GRAY1, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
draw_box(
|
||||
ax, 5.5, 3.2, 1.8, 0.9, "Serwer", fill=GRAY1, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
draw_double_arrow(ax, 2.1, 3.65, 5.5, 3.65, lw=1.5)
|
||||
ax.text(3.8, 4.3, "TCP/IP (sieciowy)", fontsize=FS, ha="center", fontweight="bold")
|
||||
|
||||
# Unix socket
|
||||
draw_box(
|
||||
ax, 0.3, 1.3, 1.8, 0.9, "Proces A", fill=GRAY4, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
draw_box(
|
||||
ax, 5.5, 1.3, 1.8, 0.9, "Proces B", fill=GRAY4, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
draw_double_arrow(ax, 2.1, 1.75, 5.5, 1.75, lw=1.5)
|
||||
ax.text(
|
||||
3.8,
|
||||
2.4,
|
||||
"Unix domain socket\n(/tmp/app.sock)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
ax.text(
|
||||
3.8,
|
||||
0.5,
|
||||
"Dwukierunkowy\nNajbardziej uniwersalny IPC",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.9])
|
||||
save_fig(fig, "q9_ipc_details.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 9. IPC comparison table
|
||||
# ============================================================
|
||||
def gen_ipc_table() -> None:
|
||||
"""Gen ipc table."""
|
||||
fig, ax = plt.subplots(figsize=(8.5, 3.5))
|
||||
ax.set_xlim(0, 11)
|
||||
ax.set_ylim(-4.5, 1)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Porównanie mechanizmów IPC", fontsize=FS_TITLE, fontweight="bold", pad=10
|
||||
)
|
||||
|
||||
headers = ["Mechanizm", "Kierunek", "Szybkość", "Zastosowanie"]
|
||||
col_w = [2.5, 2.0, 2.5, 3.5]
|
||||
rows = [
|
||||
["Pipe", "jednokierunkowy", "średnia", "ls | grep"],
|
||||
["Named Pipe", "jednokierunkowy", "średnia", "demon → klient"],
|
||||
["Shared Memory", "dwukierunkowy", "NAJSZYBSZA", "video, bazy danych"],
|
||||
["Message Queue", "dwukierunkowy", "średnia", "wieloproducentowe"],
|
||||
["Socket", "dwukierunkowy", "wolna (sieć)", "klient-serwer"],
|
||||
["Signal", "jednokierunkowy", "natychmiastowa", "powiadomienia (nr)"],
|
||||
]
|
||||
draw_table(
|
||||
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.5, fontsize=7.5
|
||||
)
|
||||
|
||||
save_fig(fig, "q9_ipc_table.png")
|
||||
@ -0,0 +1,404 @@
|
||||
"""Q9 diagrams 10-13: race conditions, deadlock, Coffman, starvation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _q9_common import (
|
||||
FS,
|
||||
FS_LABEL,
|
||||
FS_SMALL,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
draw_arrow,
|
||||
draw_box,
|
||||
draw_table,
|
||||
save_fig,
|
||||
)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 10. Race condition (simple x + bank timeline)
|
||||
# ============================================================
|
||||
def gen_race_condition() -> None:
|
||||
"""Gen race condition."""
|
||||
fig, axes = plt.subplots(1, 2, figsize=(11, 5))
|
||||
fig.suptitle(
|
||||
"Wyścig (Race Condition) — przykłady", fontsize=FS_TITLE, fontweight="bold"
|
||||
)
|
||||
|
||||
# Panel 1: simple x increment
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Prosty wyścig: x = x + 1", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Timeline
|
||||
steps_a = ["czytaj x (=0)", "dodaj 1", "zapisz x (=1)"]
|
||||
steps_b = ["czytaj x (=0)", "dodaj 1", "zapisz x (=1)"]
|
||||
ax.text(2.0, 6.3, "Wątek A", fontsize=FS_LABEL, ha="center", fontweight="bold")
|
||||
ax.text(6.0, 6.3, "Wątek B", fontsize=FS_LABEL, ha="center", fontweight="bold")
|
||||
ax.plot([2, 2], [0.8, 6.0], color=LN, lw=1)
|
||||
ax.plot([6, 6], [0.8, 6.0], color=LN, lw=1)
|
||||
|
||||
for i, (sa, sb) in enumerate(zip(steps_a, steps_b, strict=False)):
|
||||
y = 5.3 - i * 1.2
|
||||
draw_box(ax, 0.5, y, 3.0, 0.6, sa, fill=GRAY4, fontsize=FS)
|
||||
draw_box(ax, 4.5, y - 0.3, 3.0, 0.6, sb, fill=GRAY1, fontsize=FS)
|
||||
|
||||
ax.text(
|
||||
4.0,
|
||||
0.4,
|
||||
"Wynik: x = 1 (powinno 2!)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828",
|
||||
bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"},
|
||||
)
|
||||
|
||||
# Panel 2: bank account
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Konto bankowe: saldo = 1000 zł", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
ax.text(2.5, 6.3, "Wątek A (+500)", fontsize=FS, ha="center", fontweight="bold")
|
||||
ax.text(7.5, 6.3, "Wątek B (-200)", fontsize=FS, ha="center", fontweight="bold")
|
||||
ax.plot([2.5, 2.5], [0.8, 6.0], color=LN, lw=1)
|
||||
ax.plot([7.5, 7.5], [0.8, 6.0], color=LN, lw=1)
|
||||
|
||||
events = [
|
||||
("t1", "czytaj → 1000", "", 5.3),
|
||||
("t2", "", "czytaj → 1000", 4.6),
|
||||
("t3", "1000+500=1500", "", 3.9),
|
||||
("t4", "", "1000-200=800", 3.2),
|
||||
("t5", "zapisz 1500", "", 2.5),
|
||||
("t6", "", "zapisz 800 ✗", 1.8),
|
||||
]
|
||||
for t, a, b, y in events:
|
||||
ax.text(0.3, y + 0.15, t, fontsize=FS_SMALL, fontweight="bold", va="center")
|
||||
if a:
|
||||
draw_box(ax, 1.0, y, 3.0, 0.45, a, fill=GRAY4, fontsize=FS_SMALL)
|
||||
if b:
|
||||
fill = "#F8D7DA" if "✗" in b else GRAY1
|
||||
draw_box(ax, 6.0, y, 3.0, 0.45, b, fill=fill, fontsize=FS_SMALL)
|
||||
|
||||
ax.text(
|
||||
5.0,
|
||||
0.4,
|
||||
"Wynik: 800 zł (powinno 1300!)",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828",
|
||||
bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"},
|
||||
)
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.9])
|
||||
save_fig(fig, "q9_race_condition.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 11. Deadlock scenario + cycle
|
||||
# ============================================================
|
||||
def gen_deadlock_scenario() -> None:
|
||||
"""Gen deadlock scenario."""
|
||||
fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
|
||||
fig.suptitle("Zakleszczenie (Deadlock)", fontsize=FS_TITLE, fontweight="bold")
|
||||
|
||||
# Panel 1: timeline
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Scenariusz z 2 mutexami", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
ax.text(2.5, 5.3, "Wątek A", fontsize=FS_LABEL, ha="center", fontweight="bold")
|
||||
ax.text(6.0, 5.3, "Wątek B", fontsize=FS_LABEL, ha="center", fontweight="bold")
|
||||
|
||||
steps = [
|
||||
("lock(mutex1) OK", "", "trzyma", False, 4.5),
|
||||
("", "lock(mutex2) OK", "trzyma", False, 3.7),
|
||||
("lock(mutex2) ...WAIT", "", "CZEKA!", True, 2.9),
|
||||
("", "lock(mutex1) ...WAIT", "CZEKA!", True, 2.1),
|
||||
]
|
||||
for a_text, b_text, _note, is_wait, y in steps:
|
||||
if a_text:
|
||||
fill = "#F8D7DA" if is_wait else GRAY4
|
||||
draw_box(ax, 0.5, y, 3.3, 0.55, a_text, fill=fill, fontsize=FS_SMALL)
|
||||
if b_text:
|
||||
fill = "#F8D7DA" if is_wait else GRAY4
|
||||
draw_box(ax, 4.3, y, 3.3, 0.55, b_text, fill=fill, fontsize=FS_SMALL)
|
||||
|
||||
ax.text(
|
||||
4.0,
|
||||
1.2,
|
||||
"DEADLOCK!\nŻaden nie odpuści",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828",
|
||||
bbox={"boxstyle": "round", "facecolor": "#F8D7DA", "edgecolor": "#C62828"},
|
||||
)
|
||||
|
||||
# Panel 2: cycle diagram
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Cykl oczekiwania", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Thread boxes
|
||||
draw_box(
|
||||
ax,
|
||||
0.5,
|
||||
3.5,
|
||||
2.2,
|
||||
1.2,
|
||||
"Wątek A\ntrzyma Mutex 1",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_box(
|
||||
ax,
|
||||
5.3,
|
||||
3.5,
|
||||
2.2,
|
||||
1.2,
|
||||
"Wątek B\ntrzyma Mutex 2",
|
||||
fill=GRAY1,
|
||||
fontsize=FS,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Mutex boxes
|
||||
draw_box(
|
||||
ax, 0.5, 1.0, 2.2, 1.0, "Mutex 1", fill=GRAY3, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
draw_box(
|
||||
ax, 5.3, 1.0, 2.2, 1.0, "Mutex 2", fill=GRAY3, fontsize=FS, fontweight="bold"
|
||||
)
|
||||
|
||||
# holds arrows (down)
|
||||
draw_arrow(ax, 1.6, 3.5, 1.6, 2.0, lw=2)
|
||||
ax.text(0.9, 2.7, "trzyma", fontsize=FS_SMALL, rotation=90, va="center")
|
||||
draw_arrow(ax, 6.4, 3.5, 6.4, 2.0, lw=2)
|
||||
ax.text(7.0, 2.7, "trzyma", fontsize=FS_SMALL, rotation=90, va="center")
|
||||
|
||||
# waits-for arrows (across, red)
|
||||
draw_arrow(ax, 2.7, 4.3, 5.3, 4.3, lw=2.5, color="#C62828")
|
||||
ax.text(
|
||||
4.0,
|
||||
4.7,
|
||||
"czeka na Mutex 2",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828",
|
||||
)
|
||||
draw_arrow(ax, 5.3, 3.7, 2.7, 3.7, lw=2.5, color="#C62828")
|
||||
ax.text(
|
||||
4.0,
|
||||
3.1,
|
||||
"czeka na Mutex 1",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
color="#C62828",
|
||||
)
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.9])
|
||||
save_fig(fig, "q9_deadlock_scenario.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 12. Coffman conditions + prevention strategies
|
||||
# ============================================================
|
||||
def gen_coffman_strategies() -> None:
|
||||
"""Gen coffman strategies."""
|
||||
fig, ax = plt.subplots(figsize=(9, 4))
|
||||
ax.set_xlim(0, 11.5)
|
||||
ax.set_ylim(-3.5, 1)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Warunki Coffmana — zapobieganie deadlockowi",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
headers = ["Warunek", "Opis", "Jak złamać", "Przykład"]
|
||||
col_w = [2.5, 2.5, 3.0, 3.0]
|
||||
rows = [
|
||||
[
|
||||
"1. Mutual Exclusion",
|
||||
"zasób wyłączny",
|
||||
"współdzielony zasób",
|
||||
"Read-write lock",
|
||||
],
|
||||
[
|
||||
"2. Hold and Wait",
|
||||
"trzymaj + czekaj",
|
||||
"bierz WSZYSTKIE naraz",
|
||||
"lock(m1,m2) atomowo",
|
||||
],
|
||||
[
|
||||
"3. No Preemption",
|
||||
"nie zabierzesz siłą",
|
||||
"timeout / trylock",
|
||||
"pthread_mutex_trylock()",
|
||||
],
|
||||
[
|
||||
"4. Circular Wait",
|
||||
"cykliczne oczekiw.",
|
||||
"porządek liniowy",
|
||||
"zawsze m1 przed m2",
|
||||
],
|
||||
]
|
||||
draw_table(
|
||||
ax, headers, rows, x0=0.25, y0=0.5, col_widths=col_w, row_h=0.6, fontsize=7
|
||||
)
|
||||
|
||||
ax.text(
|
||||
5.75,
|
||||
-3.1,
|
||||
"▸ Najczęstsza strategia: PORZĄDEK LINIOWY — "
|
||||
"numeruj mutexy, zawsze blokuj rosnąco",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
save_fig(fig, "q9_coffman_strategies.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 13. Starvation + Priority Inversion (2-panel)
|
||||
# ============================================================
|
||||
def gen_starvation_priority() -> None:
|
||||
"""Gen starvation priority."""
|
||||
fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
|
||||
fig.suptitle(
|
||||
"Zagłodzenie i Inwersja priorytetów", fontsize=FS_TITLE, fontweight="bold"
|
||||
)
|
||||
|
||||
# Panel 1: Starvation + aging
|
||||
ax = axes[0]
|
||||
ax.set_xlim(0, 8)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Zagłodzenie (Starvation)", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
threads = [
|
||||
("Wątek HIGH", "prio=10", GRAY5, 3.0),
|
||||
("Wątek HIGH", "prio=9", GRAY3, 2.2),
|
||||
("Wątek MED", "prio=5", GRAY2, 1.4),
|
||||
("Wątek LOW", "prio=1 → głoduje!", "#F8D7DA", 0.6),
|
||||
]
|
||||
for name, prio, color, y in threads:
|
||||
draw_box(
|
||||
ax, 0.5, y, 2.0, 0.6, name, fill=color, fontsize=FS_SMALL, fontweight="bold"
|
||||
)
|
||||
ax.text(2.8, y + 0.3, prio, fontsize=FS_SMALL, va="center")
|
||||
|
||||
ax.text(
|
||||
1.5,
|
||||
4.2,
|
||||
"CPU zawsze\ndostaje HIGH!",
|
||||
fontsize=FS,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
draw_arrow(ax, 1.5, 3.9, 1.5, 3.65, lw=1.5)
|
||||
|
||||
# Aging solution
|
||||
draw_box(ax, 4.5, 1.5, 3.2, 2.5, "", fill=GRAY4, rounded=True)
|
||||
ax.text(6.1, 3.7, "Rozwiązanie: AGING", fontsize=FS, fontweight="bold", ha="center")
|
||||
aging = [
|
||||
"t=0: prio=1",
|
||||
"t=100ms: prio=2",
|
||||
"t=200ms: prio=3",
|
||||
"...",
|
||||
"w końcu → CPU!",
|
||||
]
|
||||
for i, line in enumerate(aging):
|
||||
ax.text(
|
||||
6.1, 3.2 - i * 0.4, line, fontsize=FS_SMALL, ha="center", family="monospace"
|
||||
)
|
||||
|
||||
# Panel 2: Priority Inversion
|
||||
ax = axes[1]
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("auto")
|
||||
ax.axis("off")
|
||||
ax.set_title("Inwersja priorytetów", fontsize=FS_LABEL, fontweight="bold")
|
||||
|
||||
# Timeline
|
||||
labels = ["H (wysoki)", "M (średni)", "L (niski)"]
|
||||
ys = [4.2, 2.8, 1.4]
|
||||
for label, y in zip(labels, ys, strict=False):
|
||||
ax.text(0.3, y + 0.2, label, fontsize=FS, fontweight="bold", va="center")
|
||||
|
||||
# L runs and locks mutex
|
||||
draw_box(ax, 2.0, ys[2], 1.2, 0.5, "lock(m)", fill=GRAY1, fontsize=FS_SMALL)
|
||||
|
||||
# M preempts L
|
||||
draw_box(ax, 3.5, ys[1], 3.0, 0.5, "M pracuje...", fill=GRAY3, fontsize=FS_SMALL)
|
||||
|
||||
# H waits for mutex
|
||||
draw_box(
|
||||
ax,
|
||||
3.5,
|
||||
ys[0],
|
||||
3.0,
|
||||
0.5,
|
||||
"CZEKA na mutex!",
|
||||
fill="#F8D7DA",
|
||||
fontsize=FS_SMALL,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# M finishes, L continues, unlocks
|
||||
draw_box(ax, 6.8, ys[2], 1.5, 0.5, "unlock(m)", fill=GRAY1, fontsize=FS_SMALL)
|
||||
draw_box(ax, 8.5, ys[0], 1.2, 0.5, "H runs", fill=GRAY4, fontsize=FS_SMALL)
|
||||
|
||||
# Explanation
|
||||
ax.text(
|
||||
5.0,
|
||||
0.5,
|
||||
"H czeka na M (mimo H > M)!\n"
|
||||
"Rozwiązanie: Priority Inheritance\n"
|
||||
"L dziedziczy priorytet H → M nie wypycha L",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
style="italic",
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
5.0,
|
||||
0.0,
|
||||
"Mars Pathfinder (1997) — klasyczny bug!",
|
||||
fontsize=FS_SMALL,
|
||||
ha="center",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.9])
|
||||
save_fig(fig, "q9_starvation_priority.png")
|
||||
@ -0,0 +1,87 @@
|
||||
"""Common constants and utilities for scheduling diagrams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
DPI = 300
|
||||
BG = "white"
|
||||
LN = "black"
|
||||
FS = 8
|
||||
FS_TITLE = 11
|
||||
OUTPUT_DIR = str(Path(__file__).resolve().parent / "img")
|
||||
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
GRAY1 = "#E8E8E8"
|
||||
GRAY2 = "#D0D0D0"
|
||||
GRAY3 = "#B8B8B8"
|
||||
GRAY4 = "#F5F5F5"
|
||||
GRAY5 = "#C0C0C0"
|
||||
|
||||
MIN_COLUMN_INDEX = 3
|
||||
FONTWEIGHT_THRESHOLD = 3
|
||||
|
||||
|
||||
def draw_box(
|
||||
ax: Axes,
|
||||
x: float,
|
||||
y: float,
|
||||
w: float,
|
||||
h: float,
|
||||
text: str,
|
||||
fill: str = "white",
|
||||
lw: float = 1.2,
|
||||
fontsize: float = FS,
|
||||
fontweight: str = "normal",
|
||||
ha: str = "center",
|
||||
va: str = "center",
|
||||
*,
|
||||
rounded: bool = True,
|
||||
) -> None:
|
||||
"""Draw box."""
|
||||
if rounded:
|
||||
rect = FancyBboxPatch(
|
||||
(x, y), w, h, boxstyle="round,pad=0.05", lw=lw, edgecolor=LN, facecolor=fill
|
||||
)
|
||||
else:
|
||||
rect = mpatches.Rectangle((x, y), w, h, lw=lw, edgecolor=LN, facecolor=fill)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + w / 2,
|
||||
y + h / 2,
|
||||
text,
|
||||
ha=ha,
|
||||
va=va,
|
||||
fontsize=fontsize,
|
||||
fontweight=fontweight,
|
||||
wrap=True,
|
||||
)
|
||||
|
||||
|
||||
def draw_arrow(
|
||||
ax: Axes,
|
||||
x1: float,
|
||||
y1: float,
|
||||
x2: float,
|
||||
y2: float,
|
||||
lw: float = 1.2,
|
||||
style: str = "->",
|
||||
color: str = LN,
|
||||
) -> None:
|
||||
"""Draw arrow."""
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(x2, y2),
|
||||
xytext=(x1, y1),
|
||||
arrowprops={"arrowstyle": style, "color": color, "lw": lw},
|
||||
)
|
||||
@ -0,0 +1,309 @@
|
||||
"""Scheduling complexity landscape and EDD example diagrams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from python_pkg.praca_magisterska_video.generate_images._sched_common import (
|
||||
BG,
|
||||
DPI,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
OUTPUT_DIR,
|
||||
draw_arrow,
|
||||
)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# SCHEDULING COMPLEXITY LANDSCAPE
|
||||
# ============================================================
|
||||
def draw_complexity_map() -> None:
|
||||
"""Draw complexity map."""
|
||||
_fig, ax = plt.subplots(1, 1, figsize=(8.27, 5))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 7)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Złożoność problemów szeregowania — od łatwych do NP-trudnych",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Gradient arrow at the top
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(9.5, 6.2),
|
||||
xytext=(0.5, 6.2),
|
||||
arrowprops={"arrowstyle": "->", "color": LN, "lw": 2},
|
||||
)
|
||||
ax.text(5, 6.5, "Rosnąca złożoność", ha="center", fontsize=9, fontweight="bold")
|
||||
|
||||
# Easy (polynomial) region
|
||||
easy_rect = FancyBboxPatch(
|
||||
(0.3, 2.8),
|
||||
4.0,
|
||||
3.0,
|
||||
boxstyle="round,pad=0.15",
|
||||
lw=1.5,
|
||||
edgecolor="#666666",
|
||||
facecolor=GRAY4,
|
||||
linestyle="-",
|
||||
)
|
||||
ax.add_patch(easy_rect)
|
||||
ax.text(
|
||||
2.3,
|
||||
5.5,
|
||||
"WIELOMIANOWE O(n log n)",
|
||||
ha="center",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
color="#444444",
|
||||
)
|
||||
|
||||
easy_problems = [
|
||||
("1 || ΣCⱼ", "SPT", GRAY1, 4.8),
|
||||
("1 || Lmax", "EDD", GRAY2, 4.0),
|
||||
("F2 || Cmax", "Johnson", GRAY1, 3.2),
|
||||
]
|
||||
for prob, method, fill, y in easy_problems:
|
||||
rect = FancyBboxPatch(
|
||||
(0.6, y),
|
||||
3.5,
|
||||
0.6,
|
||||
boxstyle="round,pad=0.05",
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
1.2,
|
||||
y + 0.3,
|
||||
prob,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
fontfamily="monospace",
|
||||
)
|
||||
ax.text(3.0, y + 0.3, f"→ {method}", ha="center", va="center", fontsize=8)
|
||||
|
||||
# Hard (NP) region
|
||||
hard_rect = FancyBboxPatch(
|
||||
(5.3, 2.8),
|
||||
4.3,
|
||||
3.0,
|
||||
boxstyle="round,pad=0.15",
|
||||
lw=1.5,
|
||||
edgecolor="#444444",
|
||||
facecolor=GRAY3,
|
||||
linestyle="-",
|
||||
)
|
||||
ax.add_patch(hard_rect)
|
||||
ax.text(
|
||||
7.45,
|
||||
5.5,
|
||||
"NP-TRUDNE",
|
||||
ha="center",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
color="#333333",
|
||||
)
|
||||
|
||||
hard_problems = [
|
||||
("Pm || Cmax\n(m≥2)", "LPT heuryst.", GRAY2, 4.5),
|
||||
("1 || ΣTⱼ", "branch&bound", GRAY4, 3.7),
|
||||
("Jm || Cmax\n(m≥3)", "metaheuryst.", GRAY5, 2.9),
|
||||
]
|
||||
for prob, method, fill, y in hard_problems:
|
||||
rect = FancyBboxPatch(
|
||||
(5.6, y),
|
||||
3.7,
|
||||
0.7,
|
||||
boxstyle="round,pad=0.05",
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
6.5,
|
||||
y + 0.35,
|
||||
prob,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
fontfamily="monospace",
|
||||
)
|
||||
ax.text(8.2, y + 0.35, f"→ {method}", ha="center", va="center", fontsize=7)
|
||||
|
||||
# Arrow connecting
|
||||
draw_arrow(ax, 4.4, 4.0, 5.2, 4.0, lw=2, color="#888888")
|
||||
ax.text(4.8, 4.25, "+1\nmaszyna", ha="center", fontsize=6, color="#888888")
|
||||
|
||||
# Bottom: key insight
|
||||
ax.text(
|
||||
5.0,
|
||||
1.8,
|
||||
"„Dodanie jednej maszyny lub jednego ograniczenia\n"
|
||||
'może zmienić problem z łatwego na NP-trudny!"',
|
||||
ha="center",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
style="italic",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
"lw": 1,
|
||||
},
|
||||
)
|
||||
|
||||
# Bottom examples
|
||||
ax.text(
|
||||
5.0,
|
||||
0.8,
|
||||
"1 maszyna → łatwe (sortuj) | ≥2 maszyny równoległe → NP-trudne\n"
|
||||
"Flow shop 2 maszyny → Johnson O(n log n) | Flow shop 3 maszyny → NP-trudne",
|
||||
ha="center",
|
||||
fontsize=7,
|
||||
color="#555555",
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(
|
||||
str(Path(OUTPUT_DIR) / "scheduling_complexity_map.png"),
|
||||
dpi=DPI,
|
||||
bbox_inches="tight",
|
||||
facecolor=BG,
|
||||
)
|
||||
plt.close()
|
||||
_logger.info(" ✓ scheduling_complexity_map.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# EDD EXAMPLE (1 || Lmax)
|
||||
# ============================================================
|
||||
def draw_edd_example() -> None:
|
||||
"""Draw edd example."""
|
||||
_fig, ax = plt.subplots(1, 1, figsize=(8.27, 4))
|
||||
ax.set_xlim(-2, 28)
|
||||
ax.set_ylim(-2, 4)
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"EDD (Earliest Due Date) — 1 || Lmax — Przykład",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=8,
|
||||
)
|
||||
|
||||
# Tasks: name, processing time, due date
|
||||
tasks = [("J1", 4, 10), ("J2", 2, 6), ("J3", 6, 15), ("J4", 3, 8), ("J5", 5, 18)]
|
||||
# EDD: sort by due date
|
||||
edd_order = sorted(tasks, key=lambda x: x[2])
|
||||
|
||||
bar_y = 1.5
|
||||
bar_h = 0.8
|
||||
t = 0
|
||||
fills_edd = [GRAY1, GRAY2, GRAY4, GRAY3, GRAY5]
|
||||
|
||||
lateness_vals = []
|
||||
for i, (name, p, d) in enumerate(edd_order):
|
||||
rect = mpatches.Rectangle(
|
||||
(t, bar_y), p, bar_h, lw=1.2, edgecolor=LN, facecolor=fills_edd[i]
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
t + p / 2,
|
||||
bar_y + bar_h / 2,
|
||||
f"{name}\np={p}, d={d}",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=6.5,
|
||||
fontweight="bold",
|
||||
)
|
||||
t += p
|
||||
lateness = t - d
|
||||
lateness_vals.append(lateness)
|
||||
|
||||
# Due date marker
|
||||
ax.plot(
|
||||
[d, d], [bar_y - 0.4, bar_y - 0.1], color="#888888", lw=0.8, linestyle="--"
|
||||
)
|
||||
ax.text(
|
||||
d,
|
||||
bar_y - 0.5,
|
||||
f"d={d}",
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=5.5,
|
||||
color="#888888",
|
||||
)
|
||||
|
||||
# Completion + lateness
|
||||
ax.plot([t, t], [bar_y + bar_h, bar_y + bar_h + 0.15], color=LN, lw=0.8)
|
||||
ax.text(
|
||||
t,
|
||||
bar_y + bar_h + 0.2,
|
||||
f"C={t}\nL={lateness}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=5.5,
|
||||
)
|
||||
|
||||
# Time axis
|
||||
ax.plot([0, 22], [bar_y - 0.05, bar_y - 0.05], color=LN, lw=0.5)
|
||||
|
||||
lmax = max(lateness_vals)
|
||||
ax.text(
|
||||
22,
|
||||
bar_y + bar_h / 2,
|
||||
f"Lmax = {lmax}",
|
||||
ha="left",
|
||||
va="center",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.2", "facecolor": GRAY1, "edgecolor": LN},
|
||||
)
|
||||
|
||||
# Bottom mnemonic
|
||||
ax.text(
|
||||
10,
|
||||
-1.3,
|
||||
'„Early Due Date Does it first" — najpilniejszy deadline idzie pierwszy',
|
||||
ha="center",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
style="italic",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
"lw": 0.8,
|
||||
},
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(
|
||||
str(Path(OUTPUT_DIR) / "scheduling_edd_example.png"),
|
||||
dpi=DPI,
|
||||
bbox_inches="tight",
|
||||
facecolor=BG,
|
||||
)
|
||||
plt.close()
|
||||
_logger.info(" ✓ scheduling_edd_example.png")
|
||||
@ -0,0 +1,484 @@
|
||||
"""Graham notation α|β|γ visual mnemonic map diagram."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from python_pkg.praca_magisterska_video.generate_images._sched_common import (
|
||||
BG,
|
||||
DPI,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
OUTPUT_DIR,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def draw_graham_notation() -> None:
|
||||
"""Draw graham notation."""
|
||||
_fig, ax = plt.subplots(1, 1, figsize=(8.27, 10))
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 14)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Notacja Grahama \u03b1 | β | \u03b3 — Mapa mnemoniczna",
|
||||
fontsize=FS_TITLE + 1,
|
||||
fontweight="bold",
|
||||
pad=12,
|
||||
)
|
||||
|
||||
_draw_graham_formula_bar(ax)
|
||||
_draw_graham_alpha_beta(ax)
|
||||
_draw_graham_lower(ax)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(
|
||||
str(Path(OUTPUT_DIR) / "scheduling_graham_notation.png"),
|
||||
dpi=DPI,
|
||||
bbox_inches="tight",
|
||||
facecolor=BG,
|
||||
)
|
||||
plt.close()
|
||||
_logger.info(" ✓ scheduling_graham_notation.png")
|
||||
|
||||
|
||||
def _draw_graham_formula_bar(ax: Axes) -> None:
|
||||
"""Draw the top alpha|beta|gamma formula bar."""
|
||||
bar_y = 12.5
|
||||
bar_h = 1.0
|
||||
# alpha box
|
||||
rect = FancyBboxPatch(
|
||||
(0.5, bar_y),
|
||||
2.5,
|
||||
bar_h,
|
||||
boxstyle="round,pad=0.08",
|
||||
lw=2,
|
||||
edgecolor=LN,
|
||||
facecolor=GRAY1,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
1.75,
|
||||
bar_y + bar_h / 2,
|
||||
"\u03b1",
|
||||
fontsize=20,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="center",
|
||||
)
|
||||
ax.text(
|
||||
1.75,
|
||||
bar_y - 0.25,
|
||||
"MASZYNY",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="top",
|
||||
color="#444444",
|
||||
)
|
||||
|
||||
# separator |
|
||||
ax.text(
|
||||
3.3,
|
||||
bar_y + bar_h / 2,
|
||||
"|",
|
||||
fontsize=24,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="center",
|
||||
)
|
||||
|
||||
# β box
|
||||
rect = FancyBboxPatch(
|
||||
(3.7, bar_y),
|
||||
2.5,
|
||||
bar_h,
|
||||
boxstyle="round,pad=0.08",
|
||||
lw=2,
|
||||
edgecolor=LN,
|
||||
facecolor=GRAY2,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
4.95,
|
||||
bar_y + bar_h / 2,
|
||||
"β",
|
||||
fontsize=20,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="center",
|
||||
)
|
||||
ax.text(
|
||||
4.95,
|
||||
bar_y - 0.25,
|
||||
"OGRANICZENIA",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="top",
|
||||
color="#444444",
|
||||
)
|
||||
|
||||
# separator |
|
||||
ax.text(
|
||||
6.5,
|
||||
bar_y + bar_h / 2,
|
||||
"|",
|
||||
fontsize=24,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="center",
|
||||
)
|
||||
|
||||
# gamma box
|
||||
rect = FancyBboxPatch(
|
||||
(6.9, bar_y),
|
||||
2.5,
|
||||
bar_h,
|
||||
boxstyle="round,pad=0.08",
|
||||
lw=2,
|
||||
edgecolor=LN,
|
||||
facecolor=GRAY3,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
8.15,
|
||||
bar_y + bar_h / 2,
|
||||
"\u03b3",
|
||||
fontsize=20,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="center",
|
||||
)
|
||||
ax.text(
|
||||
8.15,
|
||||
bar_y - 0.25,
|
||||
"CEL",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
ha="center",
|
||||
va="top",
|
||||
color="#444444",
|
||||
)
|
||||
|
||||
|
||||
def _draw_graham_alpha_beta(ax: Axes) -> None:
|
||||
"""Draw alpha (machines) and beta (constraints) sections."""
|
||||
start_x = 0.3
|
||||
col_w = 1.28
|
||||
|
||||
# === SECTION alpha: MACHINES ===
|
||||
sec_y = 11.5
|
||||
ax.text(
|
||||
0.3,
|
||||
sec_y,
|
||||
'\u03b1 — „1 Prawdziwy Quasi-Rycerz Forsuje Jaskinię Orków"',
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
va="top",
|
||||
style="italic",
|
||||
color="#333333",
|
||||
)
|
||||
|
||||
alpha_items = [
|
||||
("1", "jedna maszyna", "●", GRAY4),
|
||||
("Pm", "identyczne Parallel", "●●●", GRAY1),
|
||||
("Qm", "Quasi-uniform\n(różne prędkości)", "●●◐", GRAY4),
|
||||
("Rm", "Random unrelated\n(czasy per para)", "●◆▲", GRAY1),
|
||||
("Fm", "Flow shop\n(ta sama kolejność)", "→→→", GRAY2),
|
||||
("Jm", "Job shop\n(indyw. trasy)", "↗↙↘", GRAY4),
|
||||
("Om", "Open shop\n(dowolna kolej.)", "?→?", GRAY1),
|
||||
]
|
||||
|
||||
col_w = 1.28
|
||||
box_h_a = 1.1
|
||||
start_x = 0.3
|
||||
start_y = 9.6
|
||||
|
||||
for i, (symbol, desc, icon, fill) in enumerate(alpha_items):
|
||||
x = start_x + i * col_w
|
||||
y = start_y
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
col_w - 0.1,
|
||||
box_h_a,
|
||||
boxstyle="round,pad=0.04",
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + (col_w - 0.1) / 2,
|
||||
y + box_h_a - 0.15,
|
||||
symbol,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
x + (col_w - 0.1) / 2,
|
||||
y + box_h_a / 2 - 0.1,
|
||||
desc,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=5.5,
|
||||
)
|
||||
ax.text(
|
||||
x + (col_w - 0.1) / 2, y + 0.12, icon, ha="center", va="bottom", fontsize=7
|
||||
)
|
||||
|
||||
# Complexity arrow under alpha
|
||||
arr_y = 9.35
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(9.0, arr_y),
|
||||
xytext=(0.5, arr_y),
|
||||
arrowprops={"arrowstyle": "->", "color": "#666666", "lw": 1.5},
|
||||
)
|
||||
ax.text(
|
||||
4.8,
|
||||
arr_y - 0.18,
|
||||
"rosnąca złożoność →",
|
||||
ha="center",
|
||||
fontsize=6,
|
||||
color="#666666",
|
||||
)
|
||||
|
||||
# === SECTION β: CONSTRAINTS ===
|
||||
sec_y2 = 8.9
|
||||
ax.text(
|
||||
0.3,
|
||||
sec_y2,
|
||||
"β — „Robak Daje Deadline: Przerwy Poprzedzają Pojedyncze Setup'y\"",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
va="top",
|
||||
style="italic",
|
||||
color="#333333",
|
||||
)
|
||||
|
||||
beta_items = [
|
||||
("rⱼ", "release\ndates", "Robak\ndostępne\nod czasu rⱼ", GRAY1),
|
||||
("dⱼ", "due\ndates", "Daje\ntermin soft\n(kara za spóźn.)", GRAY4),
|
||||
("d̄ⱼ", "dead-\nlines", "Deadline\ntermin hard\n(musi dotrzymać)", GRAY1),
|
||||
("pmtn", "preemp-\ntion", "Przerwy\nmożna\nprzerwać", GRAY2),
|
||||
("prec", "prece-\ndencje", "Poprzedzają\nA->B (DAG)", GRAY4),
|
||||
("pⱼ=1", "unit\ntime", "Pojedyncze\nwszystkie = 1", GRAY1),
|
||||
("sⱼₖ", "setup\ntimes", "Setup'y\nprzezbrojenie\nmiędzy j->k", GRAY4),
|
||||
]
|
||||
|
||||
start_y2 = 7.0
|
||||
box_h_b = 1.4
|
||||
|
||||
for i, (symbol, _label, desc, fill) in enumerate(beta_items):
|
||||
x = start_x + i * col_w
|
||||
y = start_y2
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
col_w - 0.1,
|
||||
box_h_b,
|
||||
boxstyle="round,pad=0.04",
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + (col_w - 0.1) / 2,
|
||||
y + box_h_b - 0.12,
|
||||
symbol,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
x + (col_w - 0.1) / 2,
|
||||
y + box_h_b / 2 - 0.05,
|
||||
desc,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=5,
|
||||
)
|
||||
|
||||
|
||||
def _draw_graham_lower(ax: Axes) -> None:
|
||||
"""Draw gamma criteria, examples, and footer sections."""
|
||||
start_x = 0.3
|
||||
|
||||
# === SECTION gamma: CRITERIA ===
|
||||
sec_y3 = 6.5
|
||||
ax.text(
|
||||
0.3,
|
||||
sec_y3,
|
||||
'\u03b3 — „Ciężki Sum Ważony Lata, Tardiness Uderza"',
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
va="top",
|
||||
style="italic",
|
||||
color="#333333",
|
||||
)
|
||||
|
||||
gamma_items = [
|
||||
("Cmax", "makespan\nmax(Cⱼ)", "Jak długo\ntrwa WSZYSTKO?", GRAY2),
|
||||
("ΣCⱼ", "suma\nukończeń", "Średni czas\noczekiwania?", GRAY4),
|
||||
("ΣwⱼCⱼ", "ważona\nsuma", "Priorytety\nzadań?", GRAY1),
|
||||
("Lmax", "max\nopóźnienie", "Najgorsze\nspóźnienie?", GRAY2),
|
||||
("ΣTⱼ", "suma\nspóźnień", "Łączne\nspóźnienia?", GRAY4),
|
||||
("ΣUⱼ", "liczba\nspóźnionych", "Ile spóźnionych\nzadań?", GRAY1),
|
||||
]
|
||||
|
||||
start_y3 = 4.5
|
||||
box_h_g = 1.4
|
||||
col_w_g = 1.5
|
||||
|
||||
for i, (symbol, label, question, fill) in enumerate(gamma_items):
|
||||
x = start_x + i * col_w_g
|
||||
y = start_y3
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
col_w_g - 0.1,
|
||||
box_h_g,
|
||||
boxstyle="round,pad=0.04",
|
||||
lw=1,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + (col_w_g - 0.1) / 2,
|
||||
y + box_h_g - 0.1,
|
||||
symbol,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.text(
|
||||
x + (col_w_g - 0.1) / 2,
|
||||
y + box_h_g / 2 - 0.05,
|
||||
label,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=6,
|
||||
)
|
||||
ax.text(
|
||||
x + (col_w_g - 0.1) / 2,
|
||||
y + 0.15,
|
||||
f'„{question}"',
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=5,
|
||||
style="italic",
|
||||
)
|
||||
|
||||
# === BOTTOM: Example + Optimal methods ===
|
||||
ex_y = 3.5
|
||||
ax.text(
|
||||
0.3,
|
||||
ex_y,
|
||||
"Przykłady zapisu i optymalne metody:",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
va="top",
|
||||
)
|
||||
|
||||
examples = [
|
||||
("1 || ΣCⱼ", "SPT (najkrótsze\nnajpierw)", "O(n log n)", GRAY1),
|
||||
("1 || Lmax", "EDD (najwcześniejszy\ntermin)", "O(n log n)", GRAY4),
|
||||
("F2 || Cmax", "Algorytm\nJohnsona", "O(n log n)", GRAY2),
|
||||
("Pm || Cmax", "LPT heurystyka\n(NP-trudny!)", "NP-hard", GRAY3),
|
||||
("Jm || Cmax", "Branch & Bound\n(NP-trudny!)", "NP-hard", GRAY5),
|
||||
]
|
||||
|
||||
ex_start_y = 1.8
|
||||
ex_box_w = 1.72
|
||||
ex_box_h = 1.4
|
||||
|
||||
for i, (notation, method, complexity, fill) in enumerate(examples):
|
||||
x = start_x + i * (ex_box_w + 0.1)
|
||||
y = ex_start_y
|
||||
rect = FancyBboxPatch(
|
||||
(x, y),
|
||||
ex_box_w,
|
||||
ex_box_h,
|
||||
boxstyle="round,pad=0.04",
|
||||
lw=1.2,
|
||||
edgecolor=LN,
|
||||
facecolor=fill,
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + ex_box_w / 2,
|
||||
y + ex_box_h - 0.12,
|
||||
notation,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
fontfamily="monospace",
|
||||
)
|
||||
ax.text(
|
||||
x + ex_box_w / 2,
|
||||
y + ex_box_h / 2 - 0.05,
|
||||
method,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=6,
|
||||
)
|
||||
ax.text(
|
||||
x + ex_box_w / 2,
|
||||
y + 0.12,
|
||||
complexity,
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=6.5,
|
||||
fontweight="bold",
|
||||
color="#555555",
|
||||
)
|
||||
|
||||
# Footer mnemonic summary
|
||||
ax.text(
|
||||
5.0,
|
||||
0.8,
|
||||
'„\u03b1|β|\u03b3 = Maszyny | Ograniczenia | Cel"',
|
||||
ha="center",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
style="italic",
|
||||
color="#333333",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
"lw": 1,
|
||||
},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
5.0,
|
||||
0.2,
|
||||
"\u03b1: ILE maszyn i JAKIE? "
|
||||
"β: JAKIE ograniczenia zadań? "
|
||||
"\u03b3: CO minimalizujemy?",
|
||||
ha="center",
|
||||
fontsize=7,
|
||||
color="#555555",
|
||||
)
|
||||
@ -0,0 +1,318 @@
|
||||
"""Johnson's algorithm Gantt chart diagram (F2||Cmax)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from python_pkg.praca_magisterska_video.generate_images._sched_common import (
|
||||
BG,
|
||||
DPI,
|
||||
FONTWEIGHT_THRESHOLD,
|
||||
FS_TITLE,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
MIN_COLUMN_INDEX,
|
||||
OUTPUT_DIR,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def draw_johnson_gantt() -> None:
|
||||
"""Draw johnson gantt."""
|
||||
_fig, axes = plt.subplots(
|
||||
2, 1, figsize=(8.27, 7), gridspec_kw={"height_ratios": [1, 1.8]}
|
||||
)
|
||||
|
||||
_draw_johnson_decision_table(axes[0])
|
||||
_draw_johnson_gantt_chart(axes[1])
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(
|
||||
str(Path(OUTPUT_DIR) / "scheduling_johnson_gantt.png"),
|
||||
dpi=DPI,
|
||||
bbox_inches="tight",
|
||||
facecolor=BG,
|
||||
)
|
||||
plt.close()
|
||||
_logger.info(" ✓ scheduling_johnson_gantt.png")
|
||||
|
||||
|
||||
def _draw_johnson_decision_table(ax: Axes) -> None:
|
||||
"""Draw the Johnson algorithm decision table."""
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 5)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title(
|
||||
"Algorytm Johnsona (F2 || Cmax) — Decyzja + Diagram Gantta",
|
||||
fontsize=FS_TITLE,
|
||||
fontweight="bold",
|
||||
pad=10,
|
||||
)
|
||||
|
||||
# Task table
|
||||
tasks = ["J1", "J2", "J3", "J4", "J5"]
|
||||
a_times = [4, 2, 6, 1, 3]
|
||||
b_times = [5, 3, 2, 7, 4]
|
||||
min_vals = [min(a, b) for a, b in zip(a_times, b_times, strict=False)]
|
||||
min_on = ["M1" if a <= b else "M2" for a, b in zip(a_times, b_times, strict=False)]
|
||||
assign = ["POCZątek" if m == "M1" else "KONIEC" for m in min_on]
|
||||
|
||||
# Draw table
|
||||
col_w_t = 1.3
|
||||
row_h = 0.55
|
||||
headers = ["Zadanie", "aⱼ (M1)", "bⱼ (M2)", "min", "min na", "Przydziel"]
|
||||
table_x = 0.8
|
||||
table_y = 3.8
|
||||
|
||||
for j, hdr in enumerate(headers):
|
||||
x = table_x + j * col_w_t
|
||||
rect = mpatches.Rectangle(
|
||||
(x, table_y), col_w_t, row_h, lw=1, edgecolor=LN, facecolor=GRAY2
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
x + col_w_t / 2,
|
||||
table_y + row_h / 2,
|
||||
hdr,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=6.5,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
row_data = [
|
||||
tasks[i],
|
||||
str(a_times[i]),
|
||||
str(b_times[i]),
|
||||
str(min_vals[i]),
|
||||
min_on[i],
|
||||
assign[i],
|
||||
]
|
||||
for j, val in enumerate(row_data):
|
||||
x = table_x + j * col_w_t
|
||||
y = table_y - (i + 1) * row_h
|
||||
fill_c = GRAY1 if min_on[i] == "M1" else GRAY4
|
||||
if j == MIN_COLUMN_INDEX: # min column - highlight
|
||||
fill_c = GRAY3
|
||||
rect = mpatches.Rectangle(
|
||||
(x, y), col_w_t, row_h, lw=0.8, edgecolor=LN, facecolor=fill_c
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
fw = "bold" if j >= FONTWEIGHT_THRESHOLD else "normal"
|
||||
ax.text(
|
||||
x + col_w_t / 2,
|
||||
y + row_h / 2,
|
||||
val,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=6.5,
|
||||
fontweight=fw,
|
||||
)
|
||||
|
||||
# Sorting result
|
||||
result_y = 0.7
|
||||
ax.text(
|
||||
5.0,
|
||||
result_y + 0.4,
|
||||
"Sortuj → POCZĄTEK ↑aⱼ: J4(1), J2(2), J5(3), J1(4) | KONIEC ↓bⱼ: J3(2)",
|
||||
ha="center",
|
||||
fontsize=7,
|
||||
color="#333333",
|
||||
)
|
||||
ax.text(
|
||||
5.0,
|
||||
result_y,
|
||||
"Optymalna kolejność: J4 → J2 → J5 → J1 → J3",
|
||||
ha="center",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.2",
|
||||
"facecolor": GRAY1,
|
||||
"edgecolor": LN,
|
||||
"lw": 1.2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _draw_johnson_gantt_chart(ax2: Axes) -> None:
|
||||
"""Draw the Johnson algorithm Gantt chart."""
|
||||
ax2.set_xlim(-1, 24)
|
||||
ax2.set_ylim(-1, 4)
|
||||
ax2.axis("off")
|
||||
|
||||
# Machines labels
|
||||
m1_y = 2.5
|
||||
m2_y = 0.8
|
||||
bar_h = 0.9
|
||||
|
||||
ax2.text(
|
||||
-0.8,
|
||||
m1_y + bar_h / 2,
|
||||
"M1",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax2.text(
|
||||
-0.8,
|
||||
m2_y + bar_h / 2,
|
||||
"M2",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Schedule: J4 → J2 → J5 → J1 → J3
|
||||
order = ["J4", "J2", "J5", "J1", "J3"]
|
||||
a_ord = [1, 2, 3, 4, 6] # M1 times in order
|
||||
b_ord = [7, 3, 4, 5, 2] # M2 times in order
|
||||
fills = [GRAY1, GRAY2, GRAY4, GRAY3, GRAY5]
|
||||
hatches = ["", "///", "", "\\\\\\", "xxx"]
|
||||
|
||||
# M1 schedule
|
||||
m1_starts = []
|
||||
t = 0
|
||||
for a in a_ord:
|
||||
m1_starts.append(t)
|
||||
t += a
|
||||
m1_ends = [s + a for s, a in zip(m1_starts, a_ord, strict=False)]
|
||||
|
||||
# M2 schedule (must wait for M1 finish AND previous M2 finish)
|
||||
m2_starts = []
|
||||
m2_ends = []
|
||||
prev_m2_end = 0
|
||||
for i, b in enumerate(b_ord):
|
||||
start = max(m1_ends[i], prev_m2_end)
|
||||
m2_starts.append(start)
|
||||
m2_ends.append(start + b)
|
||||
prev_m2_end = start + b
|
||||
|
||||
# Draw M1 bars
|
||||
for i in range(5):
|
||||
rect = mpatches.Rectangle(
|
||||
(m1_starts[i], m1_y),
|
||||
a_ord[i],
|
||||
bar_h,
|
||||
lw=1.2,
|
||||
edgecolor=LN,
|
||||
facecolor=fills[i],
|
||||
hatch=hatches[i],
|
||||
)
|
||||
ax2.add_patch(rect)
|
||||
ax2.text(
|
||||
m1_starts[i] + a_ord[i] / 2,
|
||||
m1_y + bar_h / 2,
|
||||
f"{order[i]}\n({a_ord[i]})",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Draw M2 bars
|
||||
for i in range(5):
|
||||
rect = mpatches.Rectangle(
|
||||
(m2_starts[i], m2_y),
|
||||
b_ord[i],
|
||||
bar_h,
|
||||
lw=1.2,
|
||||
edgecolor=LN,
|
||||
facecolor=fills[i],
|
||||
hatch=hatches[i],
|
||||
)
|
||||
ax2.add_patch(rect)
|
||||
ax2.text(
|
||||
m2_starts[i] + b_ord[i] / 2,
|
||||
m2_y + bar_h / 2,
|
||||
f"{order[i]}\n({b_ord[i]})",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Draw idle regions on M2
|
||||
idle_starts = [0]
|
||||
idle_ends = [m2_starts[0]]
|
||||
for i in range(1, 5):
|
||||
if m2_starts[i] > m2_ends[i - 1]:
|
||||
idle_starts.append(m2_ends[i - 1])
|
||||
idle_ends.append(m2_starts[i])
|
||||
|
||||
for s, e in zip(idle_starts, idle_ends, strict=False):
|
||||
if e > s:
|
||||
rect = mpatches.Rectangle(
|
||||
(s, m2_y),
|
||||
e - s,
|
||||
bar_h,
|
||||
lw=0.5,
|
||||
edgecolor="#AAAAAA",
|
||||
facecolor="white",
|
||||
linestyle="--",
|
||||
)
|
||||
ax2.add_patch(rect)
|
||||
ax2.text(
|
||||
s + (e - s) / 2,
|
||||
m2_y + bar_h / 2,
|
||||
"idle",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=5,
|
||||
color="#999999",
|
||||
)
|
||||
|
||||
# Time axis
|
||||
ax_y = m2_y - 0.15
|
||||
ax2.plot([0, 23], [ax_y, ax_y], color=LN, lw=0.8)
|
||||
for t in range(0, 24, 2):
|
||||
ax2.plot([t, t], [ax_y - 0.08, ax_y + 0.08], color=LN, lw=0.8)
|
||||
ax2.text(t, ax_y - 0.25, str(t), ha="center", va="top", fontsize=6)
|
||||
ax2.text(11.5, ax_y - 0.55, "czas", ha="center", fontsize=7)
|
||||
|
||||
# Cmax annotation
|
||||
ax2.annotate(
|
||||
f"Cmax = {m2_ends[-1]}",
|
||||
xy=(m2_ends[-1], m2_y + bar_h),
|
||||
xytext=(m2_ends[-1] + 0.5, m2_y + bar_h + 0.6),
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
color="#333333",
|
||||
arrowprops={"arrowstyle": "->", "color": "#333333", "lw": 1.5},
|
||||
)
|
||||
|
||||
# Mnemonic at bottom
|
||||
ax2.text(
|
||||
11,
|
||||
-0.7,
|
||||
"„Krótki na M1 → START (szybko karmi M2)"
|
||||
" Krótki na M2 → KONIEC"
|
||||
' (szybko kończy)"',
|
||||
ha="center",
|
||||
fontsize=7.5,
|
||||
fontweight="bold",
|
||||
style="italic",
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.3",
|
||||
"facecolor": GRAY4,
|
||||
"edgecolor": GRAY3,
|
||||
"lw": 0.8,
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,352 @@
|
||||
"""SPT vs LPT comparison and Flow Shop vs Job Shop diagrams."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from python_pkg.praca_magisterska_video.generate_images._sched_common import (
|
||||
BG,
|
||||
DPI,
|
||||
GRAY1,
|
||||
GRAY2,
|
||||
GRAY3,
|
||||
GRAY4,
|
||||
GRAY5,
|
||||
LN,
|
||||
OUTPUT_DIR,
|
||||
draw_arrow,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# SPT vs LPT COMPARISON (1 || ΣCⱼ)
|
||||
# ============================================================
|
||||
def draw_spt_comparison() -> None:
|
||||
"""Draw spt comparison."""
|
||||
fig, axes = plt.subplots(2, 1, figsize=(8.27, 5.5))
|
||||
|
||||
tasks_orig = [("J1", 5), ("J2", 3), ("J3", 8), ("J4", 2), ("J5", 6)]
|
||||
|
||||
spt_order = sorted(tasks_orig, key=lambda x: x[1])
|
||||
lpt_order = sorted(tasks_orig, key=lambda x: -x[1])
|
||||
|
||||
fills_map = {"J1": GRAY1, "J2": GRAY2, "J3": GRAY3, "J4": GRAY4, "J5": GRAY5}
|
||||
hatch_map = {"J1": "", "J2": "///", "J3": "xxx", "J4": "", "J5": "\\\\\\"}
|
||||
|
||||
for _idx, (ax, order_list, title, is_optimal) in enumerate(
|
||||
[
|
||||
(axes[0], spt_order, "SPT (Shortest Processing Time) — OPTYMALNE", True),
|
||||
(axes[1], lpt_order, "LPT (Longest Processing Time) — gorsze!", False),
|
||||
]
|
||||
):
|
||||
ax.set_xlim(-2, 26)
|
||||
ax.set_ylim(-0.5, 2.5)
|
||||
ax.axis("off")
|
||||
color = "#222222" if is_optimal else "#666666"
|
||||
marker = "✓" if is_optimal else "✗"
|
||||
ax.set_title(
|
||||
f"{marker} {title}",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
loc="left",
|
||||
color=color,
|
||||
pad=5,
|
||||
)
|
||||
|
||||
bar_y = 1.0
|
||||
bar_h = 0.8
|
||||
t = 0
|
||||
completions = []
|
||||
|
||||
for name, duration in order_list:
|
||||
rect = mpatches.Rectangle(
|
||||
(t, bar_y),
|
||||
duration,
|
||||
bar_h,
|
||||
lw=1.2,
|
||||
edgecolor=LN,
|
||||
facecolor=fills_map[name],
|
||||
hatch=hatch_map[name],
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
ax.text(
|
||||
t + duration / 2,
|
||||
bar_y + bar_h / 2,
|
||||
f"{name}\n({duration})",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
)
|
||||
t += duration
|
||||
completions.append(t)
|
||||
|
||||
# Completion time marker
|
||||
ax.plot([t, t], [bar_y - 0.15, bar_y], color=LN, lw=0.8)
|
||||
ax.text(
|
||||
t,
|
||||
bar_y - 0.25,
|
||||
f"C={t}",
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=6,
|
||||
color="#555555",
|
||||
)
|
||||
|
||||
total = sum(completions)
|
||||
# Time axis
|
||||
ax.plot([0, 25], [bar_y - 0.05, bar_y - 0.05], color=LN, lw=0.5)
|
||||
|
||||
# Sum annotation
|
||||
comp_str = " + ".join(str(c) for c in completions)
|
||||
ax.text(
|
||||
25,
|
||||
bar_y + bar_h / 2,
|
||||
f"ΣCⱼ = {comp_str}\n = {total}",
|
||||
ha="left",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
fontweight="bold" if is_optimal else "normal",
|
||||
color=color,
|
||||
bbox={
|
||||
"boxstyle": "round,pad=0.2",
|
||||
"facecolor": GRAY1 if is_optimal else "white",
|
||||
"edgecolor": color,
|
||||
"lw": 1,
|
||||
},
|
||||
)
|
||||
|
||||
# Bottom annotation
|
||||
fig.text(
|
||||
0.5,
|
||||
0.02,
|
||||
'„Short People To the front"'
|
||||
" — krótkie najpierw,"
|
||||
" jak niskie osoby w zdjęciu klasowym",
|
||||
ha="center",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
style="italic",
|
||||
color="#444444",
|
||||
)
|
||||
|
||||
plt.tight_layout(rect=[0, 0.05, 1, 1])
|
||||
plt.savefig(
|
||||
str(Path(OUTPUT_DIR) / "scheduling_spt_comparison.png"),
|
||||
dpi=DPI,
|
||||
bbox_inches="tight",
|
||||
facecolor=BG,
|
||||
)
|
||||
plt.close()
|
||||
_logger.info(" ✓ scheduling_spt_comparison.png")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# FLOW SHOP vs JOB SHOP
|
||||
# ============================================================
|
||||
def draw_flow_vs_job() -> None:
|
||||
"""Draw flow vs job."""
|
||||
_fig, axes = plt.subplots(1, 2, figsize=(8.27, 4.5))
|
||||
|
||||
_draw_flow_shop(axes[0])
|
||||
_draw_job_shop(axes[1])
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(
|
||||
str(Path(OUTPUT_DIR) / "scheduling_flow_vs_job.png"),
|
||||
dpi=DPI,
|
||||
bbox_inches="tight",
|
||||
facecolor=BG,
|
||||
)
|
||||
plt.close()
|
||||
_logger.info(" ✓ scheduling_flow_vs_job.png")
|
||||
|
||||
|
||||
def _draw_flow_shop(ax: Axes) -> None:
|
||||
"""Draw the Flow Shop diagram."""
|
||||
ax.set_xlim(0, 6)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Flow Shop (Fm)", fontsize=10, fontweight="bold", pad=8)
|
||||
|
||||
# Machines in a row
|
||||
machines_x = [1, 3, 5]
|
||||
machines_y = 3
|
||||
mach_r = 0.4
|
||||
|
||||
for i, mx in enumerate(machines_x):
|
||||
circle = plt.Circle(
|
||||
(mx, machines_y), mach_r, facecolor=GRAY2, edgecolor=LN, lw=1.5
|
||||
)
|
||||
ax.add_patch(circle)
|
||||
ax.text(
|
||||
mx,
|
||||
machines_y,
|
||||
f"M{i + 1}",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Arrows between machines
|
||||
for i in range(len(machines_x) - 1):
|
||||
draw_arrow(
|
||||
ax,
|
||||
machines_x[i] + mach_r + 0.05,
|
||||
machines_y,
|
||||
machines_x[i + 1] - mach_r - 0.05,
|
||||
machines_y,
|
||||
lw=2,
|
||||
)
|
||||
|
||||
# Jobs all flowing the same way
|
||||
jobs_flow = ["J1", "J2", "J3"]
|
||||
for _j, (job, y_off) in enumerate(zip(jobs_flow, [0.8, 0, -0.8], strict=False)):
|
||||
ax.text(
|
||||
0.2,
|
||||
machines_y + y_off,
|
||||
job,
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.15", "facecolor": GRAY1, "edgecolor": LN},
|
||||
)
|
||||
# Dashed flow line
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(5.5, machines_y + y_off * 0.3),
|
||||
xytext=(0.5, machines_y + y_off),
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"color": "#888888",
|
||||
"lw": 0.8,
|
||||
"linestyle": "dashed",
|
||||
},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
3,
|
||||
1.2,
|
||||
"Wszystkie zadania:\nM1 → M2 → M3",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=8,
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
3,
|
||||
0.4,
|
||||
"Jak taśma montażowa",
|
||||
ha="center",
|
||||
fontsize=7,
|
||||
style="italic",
|
||||
color="#666666",
|
||||
)
|
||||
|
||||
|
||||
def _draw_job_shop(ax: Axes) -> None:
|
||||
"""Draw the Job Shop diagram."""
|
||||
mach_r = 0.4
|
||||
ax.set_xlim(0, 6)
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
ax.set_title("Job Shop (Jm)", fontsize=10, fontweight="bold", pad=8)
|
||||
|
||||
# Machines scattered
|
||||
m_positions = [(1.5, 4.2), (4.5, 4.2), (3, 2.5)]
|
||||
|
||||
for i, (mx, my) in enumerate(m_positions):
|
||||
circle = plt.Circle((mx, my), mach_r, facecolor=GRAY2, edgecolor=LN, lw=1.5)
|
||||
ax.add_patch(circle)
|
||||
ax.text(
|
||||
mx, my, f"M{i + 1}", ha="center", va="center", fontsize=9, fontweight="bold"
|
||||
)
|
||||
|
||||
# J1: M1 → M2 → M3 (solid)
|
||||
route1 = [(1.5, 4.2), (4.5, 4.2), (3, 2.5)]
|
||||
for i in range(len(route1) - 1):
|
||||
x1, y1 = route1[i]
|
||||
x2, y2 = route1[i + 1]
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
d = (dx**2 + dy**2) ** 0.5
|
||||
draw_arrow(
|
||||
ax,
|
||||
x1 + mach_r * dx / d + 0.05,
|
||||
y1 + mach_r * dy / d,
|
||||
x2 - mach_r * dx / d - 0.05,
|
||||
y2 - mach_r * dy / d,
|
||||
lw=1.5,
|
||||
)
|
||||
ax.text(
|
||||
0.3,
|
||||
4.8,
|
||||
"J1: M1→M2→M3",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY1, "edgecolor": LN},
|
||||
)
|
||||
|
||||
# J2: M2 → M3 → M1 (dashed)
|
||||
route2 = [(4.5, 4.2), (3, 2.5), (1.5, 4.2)]
|
||||
for i in range(len(route2) - 1):
|
||||
x1, y1 = route2[i]
|
||||
x2, y2 = route2[i + 1]
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
d = (dx**2 + dy**2) ** 0.5
|
||||
off = 0.15 # offset to avoid overlap
|
||||
ax.annotate(
|
||||
"",
|
||||
xy=(x2 - mach_r * dx / d - 0.05, y2 - mach_r * dy / d + off),
|
||||
xytext=(x1 + mach_r * dx / d + 0.05, y1 + mach_r * dy / d + off),
|
||||
arrowprops={
|
||||
"arrowstyle": "->",
|
||||
"color": "#555555",
|
||||
"lw": 1.5,
|
||||
"linestyle": "dashed",
|
||||
},
|
||||
)
|
||||
ax.text(
|
||||
3.8,
|
||||
5.2,
|
||||
"J2: M2→M3→M1",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
bbox={"boxstyle": "round,pad=0.1", "facecolor": GRAY4, "edgecolor": LN},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
3,
|
||||
1.2,
|
||||
"Każde zadanie:\nwłasna trasa!",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=8,
|
||||
bbox={"boxstyle": "round,pad=0.3", "facecolor": GRAY4, "edgecolor": GRAY3},
|
||||
)
|
||||
|
||||
ax.text(
|
||||
3,
|
||||
0.4,
|
||||
"NP-trudny już dla 3 maszyn",
|
||||
ha="center",
|
||||
fontsize=7,
|
||||
style="italic",
|
||||
color="#666666",
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
2153
python_pkg/praca_magisterska_video/generate_images/generate_q20_diagrams.py
Executable file → Normal file
2153
python_pkg/praca_magisterska_video/generate_images/generate_q20_diagrams.py
Executable file → Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
2312
python_pkg/praca_magisterska_video/generate_images/generate_q24_diagrams.py
Executable file → Normal file
2312
python_pkg/praca_magisterska_video/generate_images/generate_q24_diagrams.py
Executable file → Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
1473
python_pkg/praca_magisterska_video/generate_images/generate_scheduling_diagrams.py
Executable file → Normal file
1473
python_pkg/praca_magisterska_video/generate_images/generate_scheduling_diagrams.py
Executable file → Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
36
python_pkg/screen_locker/_constants.py
Normal file
36
python_pkg/screen_locker/_constants.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Constants for the screen locker module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# Validation limits for workout data
|
||||
MAX_DISTANCE_KM = 100
|
||||
MAX_TIME_MINUTES = 600
|
||||
MAX_PACE_MIN_PER_KM = 20
|
||||
MIN_EXERCISE_NAME_LEN = 3
|
||||
MAX_SETS = 20
|
||||
MAX_REPS = 100
|
||||
MAX_WEIGHT_KG = 500
|
||||
SICK_LOCKOUT_SECONDS = 120 # 2 minutes wait when sick
|
||||
SUBMIT_DELAY_DEMO = 30
|
||||
SUBMIT_DELAY_PRODUCTION = 180
|
||||
PHONE_PENALTY_DELAY_DEMO = 10
|
||||
PHONE_PENALTY_DELAY_PRODUCTION = 600
|
||||
ADB_TIMEOUT = 15
|
||||
STRONGLIFTS_DB_REMOTE = (
|
||||
"/data/data/com.stronglifts.app/databases/StrongLifts-Database-3"
|
||||
)
|
||||
SHUTDOWN_CONFIG_FILE = Path("/etc/shutdown-schedule.conf")
|
||||
# Helper script path (relative to this file)
|
||||
ADJUST_SHUTDOWN_SCRIPT = Path(__file__).resolve().parent / "adjust_shutdown_schedule.sh"
|
||||
# State file to track sick day usage and original config values
|
||||
SICK_DAY_STATE_FILE = Path(__file__).resolve().parent / "sick_day_state.json"
|
||||
|
||||
STRENGTH_FIELDS: list[tuple[str, int]] = [
|
||||
("Exercises (comma-separated):", 50),
|
||||
("Sets per exercise (comma-separated):", 20),
|
||||
("Reps (comma-sep, + for variable: 12+11+12):", 30),
|
||||
("Weight per exercise kg (comma-separated):", 20),
|
||||
("Total weight lifted (kg):", 15),
|
||||
]
|
||||
203
python_pkg/screen_locker/_phone_verification.py
Normal file
203
python_pkg/screen_locker/_phone_verification.py
Normal file
@ -0,0 +1,203 @@
|
||||
"""Phone workout verification mixin using ADB and StrongLifts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import socket
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from python_pkg.screen_locker._constants import ADB_TIMEOUT, STRONGLIFTS_DB_REMOTE
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PhoneVerificationMixin:
|
||||
"""Mixin providing phone-based workout verification via ADB."""
|
||||
|
||||
def _run_adb(self, args: list[str]) -> tuple[bool, str]:
|
||||
"""Run an ADB command and return success flag and stdout."""
|
||||
adb = shutil.which("adb") or "adb"
|
||||
# When multiple devices are connected (e.g. USB + wireless), pin to
|
||||
# the wireless device's serial to avoid "more than one device" errors.
|
||||
_discovery_cmds = {"devices", "connect", "disconnect", "kill-server"}
|
||||
serial = (
|
||||
self._get_wireless_serial()
|
||||
if args and args[0] not in _discovery_cmds
|
||||
else None
|
||||
)
|
||||
serial_args = ["-s", serial] if serial else []
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[adb, *serial_args, *args],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=ADB_TIMEOUT,
|
||||
check=False,
|
||||
)
|
||||
except (FileNotFoundError, OSError) as exc:
|
||||
_logger.warning("ADB not available: %s", exc)
|
||||
return False, ""
|
||||
except subprocess.TimeoutExpired:
|
||||
_logger.warning("ADB command timed out: %s", args)
|
||||
return False, ""
|
||||
return result.returncode == 0, result.stdout
|
||||
|
||||
def _adb_shell(
|
||||
self,
|
||||
command: str,
|
||||
*,
|
||||
root: bool = False,
|
||||
) -> tuple[bool, str]:
|
||||
"""Run a shell command on the connected Android device."""
|
||||
if root:
|
||||
return self._run_adb(["shell", "su", "-c", command])
|
||||
return self._run_adb(["shell", command])
|
||||
|
||||
def _get_wireless_serial(self) -> str | None:
|
||||
"""Return the serial (ip:port) of the first connected wireless ADB device.
|
||||
|
||||
Used to pin ADB commands to the wireless device when multiple devices
|
||||
(e.g. USB cable + wireless debugging) are simultaneously connected.
|
||||
"""
|
||||
success, output = self._run_adb(["devices"])
|
||||
if not success:
|
||||
return None
|
||||
for line in output.strip().split("\n")[1:]:
|
||||
parts = line.split()
|
||||
if parts and ":" in parts[0] and "device" in line and "offline" not in line:
|
||||
return parts[0]
|
||||
return None
|
||||
|
||||
def _has_adb_device(self) -> bool:
|
||||
"""Return True if adb devices shows at least one connected device."""
|
||||
success, output = self._run_adb(["devices"])
|
||||
if not success:
|
||||
return False
|
||||
lines = output.strip().split("\n")[1:]
|
||||
return any("device" in line and "offline" not in line for line in lines)
|
||||
|
||||
def _try_adb_connect(self, address: str) -> bool:
|
||||
"""Run adb connect to address. Returns True on success."""
|
||||
_, output = self._run_adb(["connect", address])
|
||||
lower = output.lower()
|
||||
return "connected" in lower and "unable" not in lower and "failed" not in lower
|
||||
|
||||
def _get_local_subnet_prefix(self) -> str | None:
|
||||
"""Detect the local /24 network prefix (e.g. '192.168.1')."""
|
||||
with (
|
||||
contextlib.suppress(OSError),
|
||||
socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock,
|
||||
):
|
||||
sock.connect(("8.8.8.8", 80))
|
||||
return ".".join(sock.getsockname()[0].split(".")[:3])
|
||||
return None
|
||||
|
||||
def _try_wireless_reconnect(self) -> bool:
|
||||
"""Scan local /24 subnet on port 5555 and attempt ADB connect to phone."""
|
||||
prefix = self._get_local_subnet_prefix()
|
||||
if prefix is None:
|
||||
_logger.info("Could not determine local subnet for wireless scan")
|
||||
return False
|
||||
|
||||
def probe(i: int) -> bool:
|
||||
ip = f"{prefix}.{i}"
|
||||
with (
|
||||
contextlib.suppress(OSError),
|
||||
socket.create_connection((ip, 5555), timeout=0.5),
|
||||
):
|
||||
if self._try_adb_connect(f"{ip}:5555"):
|
||||
return self._has_adb_device()
|
||||
return False
|
||||
|
||||
_logger.info("Scanning %s.1-254:5555 for phone...", prefix)
|
||||
with ThreadPoolExecutor(max_workers=64) as executor:
|
||||
for future in as_completed(
|
||||
executor.submit(probe, i) for i in range(1, 255)
|
||||
):
|
||||
if future.result():
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_phone_connected(self) -> bool:
|
||||
"""Check if an Android device is connected via ADB.
|
||||
|
||||
If no device is visible, attempts wireless reconnection using the
|
||||
stored phone IP/port config. USB-connected devices are detected
|
||||
automatically by adb devices without any extra steps.
|
||||
"""
|
||||
if self._has_adb_device():
|
||||
return True
|
||||
_logger.info("No ADB device detected — attempting wireless reconnect...")
|
||||
return self._try_wireless_reconnect()
|
||||
|
||||
def _pull_stronglifts_db(self) -> Path | None:
|
||||
"""Pull StrongLifts database from phone to a local temp file.
|
||||
|
||||
Returns:
|
||||
Path to the local copy, or None on failure.
|
||||
"""
|
||||
tmp = Path(tempfile.gettempdir()) / "stronglifts_check.db"
|
||||
success, _ = self._adb_shell(
|
||||
f"cat '{STRONGLIFTS_DB_REMOTE}' > /sdcard/_sl_tmp.db",
|
||||
root=True,
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
ok, _ = self._run_adb(["pull", "/sdcard/_sl_tmp.db", str(tmp)])
|
||||
if not ok:
|
||||
return None
|
||||
return tmp
|
||||
|
||||
def _count_today_workouts(self, db_path: Path) -> int:
|
||||
"""Count today's workouts in a local copy of StrongLifts DB.
|
||||
|
||||
Args:
|
||||
db_path: Path to the locally-pulled StrongLifts database.
|
||||
|
||||
Returns:
|
||||
Number of workouts started today (local time).
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
try:
|
||||
cursor = conn.execute(
|
||||
"SELECT COUNT(*) FROM workouts "
|
||||
"WHERE date(start / 1000, 'unixepoch', 'localtime') "
|
||||
"= date('now', 'localtime')",
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return int(row[0]) if row else 0
|
||||
finally:
|
||||
conn.close()
|
||||
except (sqlite3.Error, ValueError, TypeError):
|
||||
_logger.warning("Failed to query StrongLifts database")
|
||||
return 0
|
||||
|
||||
def _verify_phone_workout(self) -> tuple[str, str]:
|
||||
"""Verify workout was recorded in StrongLifts on the phone.
|
||||
|
||||
Returns:
|
||||
Tuple of (status, message) where status is one of:
|
||||
- "verified": Workout confirmed on phone.
|
||||
- "not_verified": Phone connected but no workout found.
|
||||
- "no_phone": No phone connected via ADB.
|
||||
- "error": Could not access StrongLifts database.
|
||||
"""
|
||||
if not self._is_phone_connected():
|
||||
return "no_phone", "No phone connected via ADB"
|
||||
local_db = self._pull_stronglifts_db()
|
||||
if local_db is None:
|
||||
return "error", "StrongLifts database not found on phone"
|
||||
count = self._count_today_workouts(local_db)
|
||||
if count > 0:
|
||||
return (
|
||||
"verified",
|
||||
f"Workout verified! ({count} session(s) found on phone)",
|
||||
)
|
||||
return "not_verified", "No workout found on phone today"
|
||||
262
python_pkg/screen_locker/_shutdown.py
Normal file
262
python_pkg/screen_locker/_shutdown.py
Normal file
@ -0,0 +1,262 @@
|
||||
"""Shutdown schedule adjustment mixin for the screen locker."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
from python_pkg.screen_locker._constants import (
|
||||
ADJUST_SHUTDOWN_SCRIPT,
|
||||
SHUTDOWN_CONFIG_FILE,
|
||||
SICK_DAY_STATE_FILE,
|
||||
)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ShutdownMixin:
|
||||
"""Mixin providing shutdown schedule adjustment functionality."""
|
||||
|
||||
def _apply_earlier_shutdown(self, today: str) -> bool:
|
||||
"""Read config, save state, and write earlier shutdown hours."""
|
||||
config_values = self._read_shutdown_config()
|
||||
if config_values is None:
|
||||
return False
|
||||
mon_wed_hour, thu_sun_hour, morning_end_hour = config_values
|
||||
if not self._save_sick_day_state(today, mon_wed_hour, thu_sun_hour):
|
||||
_logger.error("Failed to save state - aborting adjustment")
|
||||
return False
|
||||
new_mon_wed = max(18, mon_wed_hour - 1)
|
||||
new_thu_sun = max(18, thu_sun_hour - 1)
|
||||
return self._write_shutdown_config(
|
||||
new_mon_wed,
|
||||
new_thu_sun,
|
||||
morning_end_hour,
|
||||
)
|
||||
|
||||
def _adjust_shutdown_time_earlier(self) -> bool:
|
||||
"""Adjust shutdown schedule 1.5 hours earlier (stricter).
|
||||
|
||||
This can only be used once per day. Original values are saved and
|
||||
automatically restored when checked the next day.
|
||||
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
self._restore_original_config_if_needed()
|
||||
if self._sick_mode_used_today():
|
||||
_logger.warning("Sick mode already used today")
|
||||
return False
|
||||
try:
|
||||
return self._apply_earlier_shutdown(today)
|
||||
except (OSError, ValueError) as e:
|
||||
_logger.warning("Failed to adjust shutdown time: %s", e)
|
||||
return False
|
||||
|
||||
def _adjust_shutdown_time_later(self) -> bool:
|
||||
"""Adjust shutdown schedule 2 hours later as workout reward.
|
||||
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
config_values = self._read_shutdown_config()
|
||||
if config_values is None:
|
||||
return False
|
||||
mon_wed_hour, thu_sun_hour, morning_end_hour = config_values
|
||||
new_mon_wed = min(23, mon_wed_hour + 2)
|
||||
new_thu_sun = min(23, thu_sun_hour + 2)
|
||||
return self._write_shutdown_config(
|
||||
new_mon_wed,
|
||||
new_thu_sun,
|
||||
morning_end_hour,
|
||||
restore=True,
|
||||
)
|
||||
except (OSError, ValueError) as e:
|
||||
_logger.warning("Failed to adjust shutdown time for workout: %s", e)
|
||||
return False
|
||||
|
||||
def _sick_mode_used_today(self) -> bool:
|
||||
"""Check if sick mode was already used today."""
|
||||
if not SICK_DAY_STATE_FILE.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
with SICK_DAY_STATE_FILE.open() as f:
|
||||
state = json.load(f)
|
||||
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
return state.get("date") == today
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return False
|
||||
|
||||
def _save_sick_day_state(
|
||||
self,
|
||||
date: str,
|
||||
orig_mon_wed: int,
|
||||
orig_thu_sun: int,
|
||||
) -> bool:
|
||||
"""Save sick day state with original config values.
|
||||
|
||||
Returns True if saved successfully, False otherwise.
|
||||
"""
|
||||
state = {
|
||||
"date": date,
|
||||
"original_mon_wed_hour": orig_mon_wed,
|
||||
"original_thu_sun_hour": orig_thu_sun,
|
||||
}
|
||||
try:
|
||||
with SICK_DAY_STATE_FILE.open("w") as f:
|
||||
json.dump(state, f, indent=2)
|
||||
except OSError as e:
|
||||
_logger.warning("Failed to save sick day state: %s", e)
|
||||
return False
|
||||
|
||||
_logger.info("Saved sick day state for %s", date)
|
||||
return True
|
||||
|
||||
def _load_sick_day_state(self) -> tuple[str, int, int] | None:
|
||||
"""Load sick day state file.
|
||||
|
||||
Returns (date, orig_mon_wed_hour, orig_thu_sun_hour) or None.
|
||||
"""
|
||||
with SICK_DAY_STATE_FILE.open() as f:
|
||||
state = json.load(f)
|
||||
date = state.get("date")
|
||||
orig_mw = state.get("original_mon_wed_hour")
|
||||
orig_ts = state.get("original_thu_sun_hour")
|
||||
if date is None or orig_mw is None or orig_ts is None:
|
||||
return None
|
||||
return (str(date), int(orig_mw), int(orig_ts))
|
||||
|
||||
def _write_restored_config(
|
||||
self,
|
||||
orig_mw: int,
|
||||
orig_ts: int,
|
||||
state_date: str,
|
||||
) -> None:
|
||||
"""Write restored config values and clean up state file."""
|
||||
config_values = self._read_shutdown_config()
|
||||
if config_values:
|
||||
_, _, morning_end = config_values
|
||||
_logger.info(
|
||||
"Restoring original shutdown config from %s",
|
||||
state_date,
|
||||
)
|
||||
self._write_shutdown_config(
|
||||
orig_mw,
|
||||
orig_ts,
|
||||
morning_end,
|
||||
restore=True,
|
||||
)
|
||||
SICK_DAY_STATE_FILE.unlink()
|
||||
_logger.info("Removed stale sick day state from %s", state_date)
|
||||
|
||||
def _restore_original_config_if_needed(self) -> None:
|
||||
"""Restore original config if sick day state is from a previous day."""
|
||||
if not SICK_DAY_STATE_FILE.exists():
|
||||
return
|
||||
try:
|
||||
loaded = self._load_sick_day_state()
|
||||
if loaded is None:
|
||||
return
|
||||
state_date, orig_mw, orig_ts = loaded
|
||||
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
if state_date != today:
|
||||
self._write_restored_config(orig_mw, orig_ts, state_date)
|
||||
except (OSError, json.JSONDecodeError) as e:
|
||||
_logger.warning("Error checking sick day state: %s", e)
|
||||
|
||||
def _read_shutdown_config(self) -> tuple[int, int, int] | None:
|
||||
"""Read shutdown config. Returns (mw_hour, ts_hour, me_hour) or None."""
|
||||
if not SHUTDOWN_CONFIG_FILE.exists():
|
||||
_logger.warning("Config not found: %s", SHUTDOWN_CONFIG_FILE)
|
||||
return None
|
||||
parsed: dict[str, int] = {}
|
||||
keys = ("MON_WED_HOUR", "THU_SUN_HOUR", "MORNING_END_HOUR")
|
||||
with SHUTDOWN_CONFIG_FILE.open() as f:
|
||||
for line in f:
|
||||
stripped = line.strip()
|
||||
for key in keys:
|
||||
if stripped.startswith(f"{key}="):
|
||||
parsed[key] = int(stripped.split("=")[1])
|
||||
if len(parsed) < len(keys):
|
||||
_logger.warning("Shutdown config missing required values")
|
||||
return None
|
||||
return (
|
||||
parsed["MON_WED_HOUR"],
|
||||
parsed["THU_SUN_HOUR"],
|
||||
parsed["MORNING_END_HOUR"],
|
||||
)
|
||||
|
||||
def _build_shutdown_cmd(
|
||||
self,
|
||||
mon_wed: int,
|
||||
thu_sun: int,
|
||||
morning: int,
|
||||
*,
|
||||
restore: bool,
|
||||
) -> list[str]:
|
||||
"""Build the shutdown adjustment command."""
|
||||
cmd = ["/usr/bin/sudo", str(ADJUST_SHUTDOWN_SCRIPT)]
|
||||
if restore:
|
||||
cmd.append("--restore")
|
||||
cmd.extend([str(mon_wed), str(thu_sun), str(morning)])
|
||||
return cmd
|
||||
|
||||
def _write_shutdown_config(
|
||||
self,
|
||||
mon_wed_hour: int,
|
||||
thu_sun_hour: int,
|
||||
morning_end_hour: int,
|
||||
*,
|
||||
restore: bool = False,
|
||||
) -> bool:
|
||||
"""Write new shutdown config values using helper script.
|
||||
|
||||
Args:
|
||||
mon_wed_hour: Shutdown hour for Monday-Wednesday.
|
||||
thu_sun_hour: Shutdown hour for Thursday-Sunday.
|
||||
morning_end_hour: Morning end hour.
|
||||
restore: If True, allows restoring to later times.
|
||||
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
if not ADJUST_SHUTDOWN_SCRIPT.exists():
|
||||
_logger.warning(
|
||||
"Script not found: %s",
|
||||
ADJUST_SHUTDOWN_SCRIPT,
|
||||
)
|
||||
return False
|
||||
cmd = self._build_shutdown_cmd(
|
||||
mon_wed_hour,
|
||||
thu_sun_hour,
|
||||
morning_end_hour,
|
||||
restore=restore,
|
||||
)
|
||||
return self._run_shutdown_cmd(cmd, mon_wed_hour, thu_sun_hour)
|
||||
|
||||
def _run_shutdown_cmd(
|
||||
self,
|
||||
cmd: list[str],
|
||||
mon_wed_hour: int,
|
||||
thu_sun_hour: int,
|
||||
) -> bool:
|
||||
"""Execute the shutdown adjustment command."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except subprocess.SubprocessError as e:
|
||||
_logger.warning("Failed to adjust shutdown config: %s", e)
|
||||
return False
|
||||
_logger.info(
|
||||
"Adjusted shutdown: Mon-Wed=%d, Thu-Sun=%d. %s",
|
||||
mon_wed_hour,
|
||||
thu_sun_hour,
|
||||
result.stdout.strip(),
|
||||
)
|
||||
return True
|
||||
294
python_pkg/screen_locker/_ui_flows.py
Normal file
294
python_pkg/screen_locker/_ui_flows.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""UI flow methods mixin for the screen locker."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import contextlib
|
||||
import tkinter as tk
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from python_pkg.screen_locker._constants import (
|
||||
PHONE_PENALTY_DELAY_DEMO,
|
||||
PHONE_PENALTY_DELAY_PRODUCTION,
|
||||
SICK_LOCKOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
class UIFlowsMixin:
|
||||
"""Mixin providing UI flow logic for the screen locker."""
|
||||
|
||||
def ask_workout_done(self) -> None:
|
||||
"""Display the initial workout question dialog."""
|
||||
self.clear_container()
|
||||
self._label("Did you work out today?", pady=30)
|
||||
frame = self._button_row()
|
||||
self._button(
|
||||
frame,
|
||||
"YES",
|
||||
bg="#00aa00",
|
||||
command=self.ask_workout_type,
|
||||
).pack(side="left", padx=20)
|
||||
self._button(
|
||||
frame,
|
||||
"NO",
|
||||
bg="#aa0000",
|
||||
command=self.ask_if_sick,
|
||||
).pack(side="left", padx=20)
|
||||
|
||||
def _start_phone_check(self) -> None:
|
||||
"""Check phone for today's workout immediately at startup."""
|
||||
self.clear_container()
|
||||
self._label("Checking phone...", font_size=36, color="#ffaa00", pady=30)
|
||||
self._text("Looking for today's workout in StrongLifts...", font_size=18)
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
self._phone_future = executor.submit(self._verify_phone_workout)
|
||||
executor.shutdown(wait=False)
|
||||
self._poll_phone_check()
|
||||
|
||||
def _poll_phone_check(self) -> None:
|
||||
"""Poll background phone check and route to result handler when done."""
|
||||
if self._phone_future is not None and self._phone_future.done():
|
||||
status, message = self._phone_future.result()
|
||||
self._handle_startup_phone_result(status, message)
|
||||
else:
|
||||
self.root.after(500, self._poll_phone_check)
|
||||
|
||||
def _handle_startup_phone_result(self, status: str, message: str) -> None:
|
||||
"""Route to appropriate screen based on startup phone check result."""
|
||||
if status == "verified":
|
||||
self.workout_data["type"] = "phone_verified"
|
||||
self.workout_data["source"] = message
|
||||
self.clear_container()
|
||||
self._label(
|
||||
"\u2713 Workout Verified!", font_size=42, color="#00cc44", pady=30
|
||||
)
|
||||
self._text(message, font_size=20, color="#aaffaa")
|
||||
self._text("Unlocking...", font_size=18, color="#888888")
|
||||
unlock_delay = 1500 if self.demo_mode else 2000
|
||||
self.root.after(unlock_delay, self.unlock_screen)
|
||||
elif status == "not_verified":
|
||||
self.clear_container()
|
||||
self._label("No Workout Found", font_size=36, color="#ff4444", pady=20)
|
||||
self._text(
|
||||
f"\u274c {message}\n\n"
|
||||
"StrongLifts shows no workout today.\n"
|
||||
"Go do your workout first!",
|
||||
color="#ffaa00",
|
||||
)
|
||||
frame = self._button_row()
|
||||
self._button(
|
||||
frame,
|
||||
"TRY AGAIN",
|
||||
bg="#0066cc",
|
||||
command=self._start_phone_check,
|
||||
width=12,
|
||||
).pack(side="left", padx=10)
|
||||
self._button(
|
||||
frame,
|
||||
"I'm sick",
|
||||
bg="#cc6600",
|
||||
command=self.ask_if_sick,
|
||||
width=12,
|
||||
).pack(side="left", padx=10)
|
||||
else:
|
||||
# no_phone or error — penalty timer, then proceed to logging form
|
||||
self._show_phone_penalty(message, on_done=self.ask_workout_done)
|
||||
|
||||
def ask_if_sick(self) -> None:
|
||||
"""Display sick day question dialog."""
|
||||
self.clear_container()
|
||||
self._label("Are you sick?", pady=30)
|
||||
self._text(
|
||||
"If yes, shutdown time will be moved 1.5 hours earlier",
|
||||
color="#ffaa00",
|
||||
)
|
||||
self._sick_question_buttons()
|
||||
|
||||
def _sick_question_buttons(self) -> None:
|
||||
"""Create the sick day yes/no buttons."""
|
||||
frame = self._button_row()
|
||||
self._button(
|
||||
frame,
|
||||
"YES (sick)",
|
||||
bg="#cc6600",
|
||||
command=self.handle_sick_day,
|
||||
width=12,
|
||||
).pack(side="left", padx=20)
|
||||
self._button(
|
||||
frame,
|
||||
"NO",
|
||||
bg="#aa0000",
|
||||
command=self.lockout,
|
||||
width=12,
|
||||
).pack(side="left", padx=20)
|
||||
|
||||
def _get_sick_day_status(self) -> tuple[str, str]:
|
||||
"""Determine sick day status text and color."""
|
||||
if self._sick_mode_used_today():
|
||||
return "Shutdown time already adjusted today", "#ffaa00"
|
||||
if self._adjust_shutdown_time_earlier():
|
||||
return (
|
||||
"Shutdown time moved 1.5 hours earlier \u2713\n(Will revert tomorrow)"
|
||||
), "#00aa00"
|
||||
return "Could not adjust shutdown time (check permissions)", "#ff4444"
|
||||
|
||||
def handle_sick_day(self) -> None:
|
||||
"""Handle sick day: adjust shutdown time and start 2-minute wait."""
|
||||
self.clear_container()
|
||||
status_text, status_color = self._get_sick_day_status()
|
||||
self._show_sick_day_ui(status_text, status_color)
|
||||
self.sick_remaining_time = SICK_LOCKOUT_SECONDS
|
||||
self._update_sick_countdown()
|
||||
|
||||
def _show_sick_day_ui(self, status_text: str, status_color: str) -> None:
|
||||
"""Display sick day UI labels and countdown."""
|
||||
self._label("Sick Day Mode", color="#cc6600", pady=20)
|
||||
self._text(status_text, color=status_color)
|
||||
self._text(
|
||||
"Please wait 2 minutes before unlocking...",
|
||||
font_size=24,
|
||||
pady=20,
|
||||
)
|
||||
self.sick_countdown_label = self._label(
|
||||
str(SICK_LOCKOUT_SECONDS),
|
||||
font_size=80,
|
||||
pady=30,
|
||||
)
|
||||
|
||||
def _update_sick_countdown(self) -> None:
|
||||
"""Update the sick day countdown timer."""
|
||||
if self.sick_remaining_time > 0:
|
||||
self.sick_countdown_label.config(text=str(self.sick_remaining_time))
|
||||
self.sick_remaining_time -= 1
|
||||
self.root.after(1000, self._update_sick_countdown)
|
||||
else:
|
||||
# Record sick day and unlock
|
||||
self.workout_data["type"] = "sick_day"
|
||||
self.workout_data["note"] = "Sick day - shutdown moved earlier"
|
||||
self.unlock_screen()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lockout flow
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def lockout(self) -> None:
|
||||
"""Display lockout screen with countdown timer."""
|
||||
self.clear_container()
|
||||
self.lockout_label = self._label(
|
||||
f"Go work out!\nLocked for {self.lockout_time} seconds",
|
||||
font_size=48,
|
||||
color="#ff4444",
|
||||
pady=30,
|
||||
)
|
||||
self.countdown_label = self._label(
|
||||
str(self.lockout_time),
|
||||
font_size=120,
|
||||
pady=30,
|
||||
)
|
||||
self.remaining_time = self.lockout_time
|
||||
self.update_lockout_countdown()
|
||||
|
||||
def update_lockout_countdown(self) -> None:
|
||||
"""Update the lockout countdown timer display."""
|
||||
if self.remaining_time > 0:
|
||||
self.countdown_label.config(text=str(self.remaining_time))
|
||||
self.remaining_time -= 1
|
||||
self.root.after(1000, self.update_lockout_countdown)
|
||||
else:
|
||||
self.ask_workout_done()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phone penalty
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _attempt_unlock(self) -> None:
|
||||
"""Unlock screen after workout form submission."""
|
||||
self.unlock_screen()
|
||||
|
||||
def _show_phone_penalty(
|
||||
self, message: str, *, on_done: Callable[[], None] | None = None
|
||||
) -> None:
|
||||
"""Show penalty countdown when phone verification is unavailable."""
|
||||
self.clear_container()
|
||||
self._phone_penalty_done_fn: Callable[[], None] = (
|
||||
on_done if on_done is not None else self.unlock_screen
|
||||
)
|
||||
delay = (
|
||||
PHONE_PENALTY_DELAY_DEMO
|
||||
if self.demo_mode
|
||||
else PHONE_PENALTY_DELAY_PRODUCTION
|
||||
)
|
||||
self._label(
|
||||
"Cannot Verify Workout",
|
||||
font_size=36,
|
||||
color="#ff8800",
|
||||
pady=20,
|
||||
)
|
||||
self._text(message, color="#ffaa00")
|
||||
self._text(
|
||||
"Connect phone via ADB to skip this wait,\n"
|
||||
"or wait for the penalty timer.\n\n"
|
||||
"Note: Phone must be rooted and StrongLifts installed.",
|
||||
font_size=18,
|
||||
)
|
||||
self.phone_penalty_remaining = delay
|
||||
self.phone_penalty_label = self._label(
|
||||
str(delay),
|
||||
font_size=80,
|
||||
pady=20,
|
||||
)
|
||||
self._update_phone_penalty()
|
||||
|
||||
def _update_phone_penalty(self) -> None:
|
||||
"""Update phone penalty countdown."""
|
||||
if self.phone_penalty_remaining > 0:
|
||||
self.phone_penalty_label.config(
|
||||
text=str(self.phone_penalty_remaining),
|
||||
)
|
||||
self.phone_penalty_remaining -= 1
|
||||
self.root.after(1000, self._update_phone_penalty)
|
||||
else:
|
||||
self._phone_penalty_done_fn()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Submit timer and entry checking
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _tick_submit_timer(self) -> None:
|
||||
"""Decrement submit timer and schedule next tick."""
|
||||
self.timer_label.config(
|
||||
text=f"Submit available in {self.submit_unlock_time} seconds...",
|
||||
)
|
||||
self.submit_unlock_time -= 1
|
||||
self.root.after(1000, self.update_submit_timer)
|
||||
|
||||
def _try_enable_submit(self) -> None:
|
||||
"""Enable submit button if all entries are filled."""
|
||||
all_filled = all(entry.get().strip() for entry in self.entries_to_check)
|
||||
if all_filled:
|
||||
self.submit_btn.config(
|
||||
text="SUBMIT",
|
||||
state="normal",
|
||||
bg="#00aa00",
|
||||
command=self.submit_command,
|
||||
)
|
||||
self.timer_label.config(text="You can now submit!")
|
||||
else:
|
||||
self.timer_label.config(text="Fill all fields to enable submit")
|
||||
self.root.after(1000, self.check_entries_filled)
|
||||
|
||||
def update_submit_timer(self) -> None:
|
||||
"""Update countdown timer and check if submit can be enabled."""
|
||||
with contextlib.suppress(tk.TclError):
|
||||
if self.submit_unlock_time > 0:
|
||||
self._tick_submit_timer()
|
||||
else:
|
||||
self._try_enable_submit()
|
||||
|
||||
def check_entries_filled(self) -> None:
|
||||
"""Continuously check if entries are filled after timer expires."""
|
||||
with contextlib.suppress(tk.TclError):
|
||||
self._try_enable_submit()
|
||||
269
python_pkg/screen_locker/_workout_forms.py
Normal file
269
python_pkg/screen_locker/_workout_forms.py
Normal file
@ -0,0 +1,269 @@
|
||||
"""Workout form methods mixin for the screen locker."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from python_pkg.screen_locker._constants import (
|
||||
MAX_DISTANCE_KM,
|
||||
MAX_PACE_MIN_PER_KM,
|
||||
MAX_REPS,
|
||||
MAX_SETS,
|
||||
MAX_TIME_MINUTES,
|
||||
MAX_WEIGHT_KG,
|
||||
MIN_EXERCISE_NAME_LEN,
|
||||
STRENGTH_FIELDS,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import tkinter as tk
|
||||
|
||||
|
||||
class WorkoutFormsMixin:
|
||||
"""Mixin providing workout form creation and validation."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Workout type selection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def ask_workout_type(self) -> None:
|
||||
"""Display workout type selection dialog."""
|
||||
self.clear_container()
|
||||
self._label("What type of workout?", pady=30)
|
||||
frame = self._button_row()
|
||||
self._button(
|
||||
frame,
|
||||
"STRENGTH",
|
||||
bg="#cc6600",
|
||||
command=self.ask_strength_details,
|
||||
width=12,
|
||||
).pack(side="left", padx=20)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Running workout
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _create_running_entries(self) -> list[tk.Entry]:
|
||||
"""Create running workout entry fields."""
|
||||
self.distance_entry = self._entry_row("Distance (km):")
|
||||
self.time_entry = self._entry_row("Time (minutes):")
|
||||
self.pace_entry = self._entry_row("Pace (min/km):")
|
||||
return [self.distance_entry, self.time_entry, self.pace_entry]
|
||||
|
||||
def ask_running_details(self) -> None:
|
||||
"""Display running workout input form."""
|
||||
self.clear_container()
|
||||
self.workout_data["type"] = "running"
|
||||
self._label("Running Details", pady=20)
|
||||
entries = self._create_running_entries()
|
||||
self._setup_form_controls(
|
||||
entries,
|
||||
self.verify_running_data,
|
||||
self.ask_workout_type,
|
||||
)
|
||||
|
||||
def _check_running_ranges(
|
||||
self,
|
||||
distance: float,
|
||||
time_mins: float,
|
||||
pace: float,
|
||||
) -> str | None:
|
||||
"""Check if running values are in valid ranges."""
|
||||
if distance <= 0 or distance > MAX_DISTANCE_KM:
|
||||
return f"Distance seems unrealistic (0-{MAX_DISTANCE_KM} km)"
|
||||
if time_mins <= 0 or time_mins > MAX_TIME_MINUTES:
|
||||
return f"Time seems unrealistic (0-{MAX_TIME_MINUTES} minutes)"
|
||||
if pace <= 0 or pace > MAX_PACE_MIN_PER_KM:
|
||||
return f"Pace seems unrealistic (0-{MAX_PACE_MIN_PER_KM} min/km)"
|
||||
expected_pace = time_mins / distance
|
||||
tolerance = expected_pace * 0.15 # 15% tolerance
|
||||
if abs(pace - expected_pace) > tolerance:
|
||||
return (
|
||||
f"Pace doesn't match! "
|
||||
f"Expected ~{expected_pace:.2f} min/km, got {pace:.2f}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _validate_running_input(self) -> tuple[float, float, float] | None:
|
||||
"""Parse and validate running input fields."""
|
||||
try:
|
||||
distance = float(self.distance_entry.get())
|
||||
time_mins = float(self.time_entry.get())
|
||||
pace = float(self.pace_entry.get())
|
||||
except ValueError:
|
||||
self.show_error("Please enter valid numbers")
|
||||
return None
|
||||
error = self._check_running_ranges(distance, time_mins, pace)
|
||||
if error:
|
||||
self.show_error(error)
|
||||
return None
|
||||
return distance, time_mins, pace
|
||||
|
||||
def verify_running_data(self) -> None:
|
||||
"""Validate running workout data and unlock if valid."""
|
||||
result = self._validate_running_input()
|
||||
if result is None:
|
||||
return
|
||||
distance, time_mins, pace = result
|
||||
self.workout_data["distance_km"] = str(distance)
|
||||
self.workout_data["time_minutes"] = str(time_mins)
|
||||
self.workout_data["pace_min_per_km"] = str(pace)
|
||||
self._attempt_unlock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Strength workout
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _create_strength_entries(self) -> list[tk.Entry]:
|
||||
"""Create strength training entry fields."""
|
||||
entries = [
|
||||
self._entry_row(lbl, width=w, font_size=18) for lbl, w in STRENGTH_FIELDS
|
||||
]
|
||||
(
|
||||
self.exercises_entry,
|
||||
self.sets_entry,
|
||||
self.reps_entry,
|
||||
self.weights_entry,
|
||||
self.total_weight_entry,
|
||||
) = entries
|
||||
return entries
|
||||
|
||||
def ask_strength_details(self) -> None:
|
||||
"""Display strength training input form."""
|
||||
self.clear_container()
|
||||
self.workout_data["type"] = "strength"
|
||||
self._label("Strength Training Details", pady=20)
|
||||
entries = self._create_strength_entries()
|
||||
self._setup_form_controls(
|
||||
entries,
|
||||
self.verify_strength_data,
|
||||
self.ask_workout_type,
|
||||
)
|
||||
|
||||
def _parse_reps(self, reps_raw: list[str]) -> list[list[int]]:
|
||||
"""Parse reps input - single number or variable reps like '12+11+12'."""
|
||||
reps: list[list[int]] = []
|
||||
for r in reps_raw:
|
||||
if "+" in r:
|
||||
reps.append([int(x.strip()) for x in r.split("+")])
|
||||
else:
|
||||
reps.append([int(r)])
|
||||
return reps
|
||||
|
||||
def _validate_strength_inputs(
|
||||
self,
|
||||
exercises: list[str],
|
||||
sets: list[int],
|
||||
reps: list[list[int]],
|
||||
weights: list[float],
|
||||
) -> str | None:
|
||||
"""Validate strength workout inputs. Returns error message or None."""
|
||||
if not (len(exercises) == len(sets) == len(reps) == len(weights)):
|
||||
return "Number of exercises, sets, reps, and weights must match"
|
||||
if any(len(ex) < MIN_EXERCISE_NAME_LEN for ex in exercises):
|
||||
return "Exercise names too short - be specific"
|
||||
if any(s < 1 or s > MAX_SETS for s in sets):
|
||||
return f"Sets should be between 1-{MAX_SETS}"
|
||||
if any(w < 0 or w > MAX_WEIGHT_KG for w in weights):
|
||||
return f"Weights should be between 0-{MAX_WEIGHT_KG} kg"
|
||||
return self._validate_reps(exercises, sets, reps)
|
||||
|
||||
def _validate_reps(
|
||||
self,
|
||||
exercises: list[str],
|
||||
sets: list[int],
|
||||
reps: list[list[int]],
|
||||
) -> str | None:
|
||||
"""Validate reps data. Returns error message or None if valid."""
|
||||
for i, rep_list in enumerate(reps):
|
||||
if any(r < 1 or r > MAX_REPS for r in rep_list):
|
||||
return f"Reps should be between 1-{MAX_REPS}"
|
||||
if len(rep_list) > 1 and len(rep_list) != sets[i]:
|
||||
return (
|
||||
f"For {exercises[i]!r}: variable reps count "
|
||||
f"({len(rep_list)}) doesn't match sets ({sets[i]})"
|
||||
)
|
||||
return None
|
||||
|
||||
def _calculate_expected_total(
|
||||
self,
|
||||
sets: list[int],
|
||||
reps: list[list[int]],
|
||||
weights: list[float],
|
||||
) -> float:
|
||||
"""Calculate expected total weight lifted."""
|
||||
expected_total = 0.0
|
||||
for i, rep_list in enumerate(reps):
|
||||
if len(rep_list) == 1:
|
||||
expected_total += sets[i] * rep_list[0] * weights[i]
|
||||
else:
|
||||
expected_total += sum(rep_list) * weights[i]
|
||||
return expected_total
|
||||
|
||||
def _parse_strength_entries(
|
||||
self,
|
||||
) -> tuple[list[str], list[int], list[list[int]], list[float], float]:
|
||||
"""Parse raw strength training input from entry widgets."""
|
||||
exercises = [e.strip() for e in self.exercises_entry.get().split(",")]
|
||||
sets = [int(s.strip()) for s in self.sets_entry.get().split(",")]
|
||||
reps_raw = [r.strip() for r in self.reps_entry.get().split(",")]
|
||||
reps = self._parse_reps(reps_raw)
|
||||
weights = [float(w.strip()) for w in self.weights_entry.get().split(",")]
|
||||
total_weight = float(self.total_weight_entry.get())
|
||||
return exercises, sets, reps, weights, total_weight
|
||||
|
||||
def _check_total_weight(
|
||||
self,
|
||||
sets: list[int],
|
||||
reps: list[list[int]],
|
||||
weights: list[float],
|
||||
total_weight: float,
|
||||
) -> str | None:
|
||||
"""Verify total weight matches individual exercise calculations."""
|
||||
expected = self._calculate_expected_total(sets, reps, weights)
|
||||
tolerance = expected * 0.15 # 15% tolerance
|
||||
if abs(total_weight - expected) > tolerance:
|
||||
return (
|
||||
f"Total weight doesn't match! "
|
||||
f"Expected ~{expected:.1f} kg, got {total_weight:.1f}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _store_strength_data(
|
||||
self,
|
||||
exercises: list[str],
|
||||
sets: list[int],
|
||||
reps: list[list[int]],
|
||||
weights: list[float],
|
||||
total_weight: float,
|
||||
) -> None:
|
||||
"""Store validated strength workout data."""
|
||||
self.workout_data["exercises"] = exercises
|
||||
self.workout_data["sets"] = [str(s) for s in sets]
|
||||
self.workout_data["reps"] = [
|
||||
"+".join(str(r) for r in rep_list) for rep_list in reps
|
||||
]
|
||||
self.workout_data["weights_kg"] = [str(w) for w in weights]
|
||||
self.workout_data["total_weight_kg"] = str(total_weight)
|
||||
|
||||
def verify_strength_data(self) -> None:
|
||||
"""Validate strength workout data and unlock if valid."""
|
||||
try:
|
||||
self._verify_strength_data_inner()
|
||||
except ValueError:
|
||||
self.show_error("Please enter valid data in correct format")
|
||||
|
||||
def _verify_strength_data_inner(self) -> None:
|
||||
"""Parse, validate, and store strength data."""
|
||||
data = self._parse_strength_entries()
|
||||
exercises, sets, reps, weights, total_weight = data
|
||||
error = self._validate_strength_inputs(exercises, sets, reps, weights)
|
||||
if error:
|
||||
self.show_error(error)
|
||||
return
|
||||
total_err = self._check_total_weight(sets, reps, weights, total_weight)
|
||||
if total_err:
|
||||
self.show_error(total_err)
|
||||
return
|
||||
self._store_strength_data(exercises, sets, reps, weights, total_weight)
|
||||
self._attempt_unlock()
|
||||
File diff suppressed because it is too large
Load Diff
113
python_pkg/screen_locker/tests/conftest.py
Normal file
113
python_pkg/screen_locker/tests/conftest.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""Shared fixtures and helpers for screen_locker tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import tkinter as tk
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from python_pkg.screen_locker.screen_lock import ScreenLocker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
class RunningData(NamedTuple):
|
||||
"""Running workout data for tests."""
|
||||
|
||||
distance: str
|
||||
time_mins: str
|
||||
pace: str
|
||||
|
||||
|
||||
class StrengthData(NamedTuple):
|
||||
"""Strength workout data for tests."""
|
||||
|
||||
exercises: str
|
||||
sets: str
|
||||
reps: str
|
||||
weights: str
|
||||
total_weight: str
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tk() -> Generator[MagicMock]:
|
||||
"""Mock tkinter module for testing without display."""
|
||||
with patch("python_pkg.screen_locker.screen_lock.tk") as mock:
|
||||
# Set up Tk root mock
|
||||
mock_root = MagicMock()
|
||||
mock_root.winfo_screenwidth.return_value = 1920
|
||||
mock_root.winfo_screenheight.return_value = 1080
|
||||
mock.Tk.return_value = mock_root
|
||||
|
||||
# Set up Frame mock
|
||||
mock_frame = MagicMock()
|
||||
mock_frame.winfo_children.return_value = []
|
||||
mock.Frame.return_value = mock_frame
|
||||
|
||||
# Set up TclError as actual exception class
|
||||
mock.TclError = tk.TclError
|
||||
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sys_exit() -> Generator[MagicMock]:
|
||||
"""Mock sys.exit to prevent test termination."""
|
||||
with patch("python_pkg.screen_locker.screen_lock.sys.exit") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_sys_exit(mock_sys_exit: MagicMock) -> MagicMock:
|
||||
"""Alias for mock_sys_exit when the return value is unused."""
|
||||
return mock_sys_exit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_log_file(tmp_path: Path) -> Path:
|
||||
"""Create a temporary log file path."""
|
||||
return tmp_path / "workout_log.json"
|
||||
|
||||
|
||||
def create_locker(
|
||||
_mock_tk: MagicMock,
|
||||
tmp_path: Path,
|
||||
*,
|
||||
demo_mode: bool = True,
|
||||
has_logged: bool = False,
|
||||
) -> ScreenLocker:
|
||||
"""Create a ScreenLocker instance for testing."""
|
||||
with (
|
||||
patch.object(Path, "resolve", return_value=tmp_path),
|
||||
patch.object(ScreenLocker, "has_logged_today", return_value=has_logged),
|
||||
patch.object(ScreenLocker, "_start_phone_check"),
|
||||
):
|
||||
return ScreenLocker(demo_mode=demo_mode)
|
||||
|
||||
|
||||
def setup_running_entries(locker: ScreenLocker, data: RunningData) -> None:
|
||||
"""Set up mock running entry widgets."""
|
||||
locker.distance_entry = MagicMock()
|
||||
locker.distance_entry.get.return_value = data.distance
|
||||
locker.time_entry = MagicMock()
|
||||
locker.time_entry.get.return_value = data.time_mins
|
||||
locker.pace_entry = MagicMock()
|
||||
locker.pace_entry.get.return_value = data.pace
|
||||
|
||||
|
||||
def setup_strength_entries(locker: ScreenLocker, data: StrengthData) -> None:
|
||||
"""Set up mock strength entry widgets."""
|
||||
locker.exercises_entry = MagicMock()
|
||||
locker.exercises_entry.get.return_value = data.exercises
|
||||
locker.sets_entry = MagicMock()
|
||||
locker.sets_entry.get.return_value = data.sets
|
||||
locker.reps_entry = MagicMock()
|
||||
locker.reps_entry.get.return_value = data.reps
|
||||
locker.weights_entry = MagicMock()
|
||||
locker.weights_entry.get.return_value = data.weights
|
||||
locker.total_weight_entry = MagicMock()
|
||||
locker.total_weight_entry.get.return_value = data.total_weight
|
||||
411
python_pkg/screen_locker/tests/test_adb_and_phone.py
Normal file
411
python_pkg/screen_locker/tests/test_adb_and_phone.py
Normal file
@ -0,0 +1,411 @@
|
||||
"""Tests for ADB commands, phone connection, and database operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from python_pkg.screen_locker.screen_lock import STRONGLIFTS_DB_REMOTE
|
||||
from python_pkg.screen_locker.tests.conftest import create_locker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestRunAdb:
|
||||
"""Tests for _run_adb ADB command execution."""
|
||||
|
||||
def test_run_adb_success(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test successful ADB command."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
mock_result = MagicMock(returncode=0, stdout="ok\n")
|
||||
with patch(
|
||||
"python_pkg.screen_locker._phone_verification.subprocess.run",
|
||||
return_value=mock_result,
|
||||
) as mock_run:
|
||||
success, output = locker._run_adb(["devices"])
|
||||
|
||||
assert success is True
|
||||
assert output == "ok\n"
|
||||
mock_run.assert_called_once()
|
||||
|
||||
def test_run_adb_failure(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test failed ADB command."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
mock_result = MagicMock(returncode=1, stdout="")
|
||||
with patch(
|
||||
"python_pkg.screen_locker._phone_verification.subprocess.run",
|
||||
return_value=mock_result,
|
||||
):
|
||||
success, _output = locker._run_adb(["devices"])
|
||||
|
||||
assert success is False
|
||||
|
||||
def test_run_adb_not_found(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ADB binary not found."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
with patch(
|
||||
"python_pkg.screen_locker._phone_verification.subprocess.run",
|
||||
side_effect=FileNotFoundError("adb not found"),
|
||||
):
|
||||
success, output = locker._run_adb(["devices"])
|
||||
|
||||
assert success is False
|
||||
assert output == ""
|
||||
|
||||
def test_run_adb_oserror(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ADB OSError."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
with patch(
|
||||
"python_pkg.screen_locker._phone_verification.subprocess.run",
|
||||
side_effect=OSError("permission denied"),
|
||||
):
|
||||
success, output = locker._run_adb(["devices"])
|
||||
|
||||
assert success is False
|
||||
assert output == ""
|
||||
|
||||
def test_run_adb_timeout(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ADB command timeout."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
with patch(
|
||||
"python_pkg.screen_locker._phone_verification.subprocess.run",
|
||||
side_effect=subprocess.TimeoutExpired("adb", 15),
|
||||
):
|
||||
success, output = locker._run_adb(["devices"])
|
||||
|
||||
assert success is False
|
||||
assert output == ""
|
||||
|
||||
|
||||
class TestAdbShell:
|
||||
"""Tests for _adb_shell method."""
|
||||
|
||||
def test_adb_shell_no_root(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ADB shell without root."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._run_adb = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(True, "output"),
|
||||
)
|
||||
|
||||
success, output = locker._adb_shell("ls /sdcard")
|
||||
|
||||
locker._run_adb.assert_called_once_with(["shell", "ls /sdcard"])
|
||||
assert success is True
|
||||
assert output == "output"
|
||||
|
||||
def test_adb_shell_with_root(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ADB shell with root."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._run_adb = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(True, "output"),
|
||||
)
|
||||
|
||||
success, _output = locker._adb_shell("ls /data", root=True)
|
||||
|
||||
locker._run_adb.assert_called_once_with(
|
||||
["shell", "su", "-c", "ls /data"],
|
||||
)
|
||||
assert success is True
|
||||
|
||||
|
||||
class TestIsPhoneConnected:
|
||||
"""Tests for _is_phone_connected method."""
|
||||
|
||||
def test_phone_connected(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test phone detected as connected."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._run_adb = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(
|
||||
True,
|
||||
"List of devices attached\nABC123\tdevice\n\n",
|
||||
),
|
||||
)
|
||||
|
||||
assert locker._is_phone_connected() is True
|
||||
|
||||
def test_phone_not_connected(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test no phone connected."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._run_adb = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(True, "List of devices attached\n\n"),
|
||||
)
|
||||
locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign]
|
||||
return_value=False,
|
||||
)
|
||||
|
||||
assert locker._is_phone_connected() is False
|
||||
|
||||
def test_phone_offline(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test phone connected but offline."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._run_adb = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(
|
||||
True,
|
||||
"List of devices attached\nABC123\toffline\n\n",
|
||||
),
|
||||
)
|
||||
locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign]
|
||||
return_value=False,
|
||||
)
|
||||
|
||||
assert locker._is_phone_connected() is False
|
||||
|
||||
def test_adb_command_fails(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ADB command failure."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._run_adb = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(False, ""),
|
||||
)
|
||||
locker._try_wireless_reconnect = MagicMock( # type: ignore[method-assign]
|
||||
return_value=False,
|
||||
)
|
||||
|
||||
assert locker._is_phone_connected() is False
|
||||
|
||||
|
||||
class TestFindHealthConnectDb:
|
||||
"""Tests for _pull_stronglifts_db method."""
|
||||
|
||||
def test_db_pulled_successfully(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test StrongLifts DB pulled from device."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._adb_shell = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(True, ""),
|
||||
)
|
||||
locker._run_adb = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(True, ""),
|
||||
)
|
||||
|
||||
result = locker._pull_stronglifts_db()
|
||||
|
||||
assert result is not None
|
||||
locker._adb_shell.assert_called_once()
|
||||
locker._run_adb.assert_called_once()
|
||||
call_args = locker._run_adb.call_args[0][0]
|
||||
assert call_args[0] == "pull"
|
||||
|
||||
def test_db_cat_fails(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test returns None when cat command fails."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._adb_shell = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(False, ""),
|
||||
)
|
||||
|
||||
assert locker._pull_stronglifts_db() is None
|
||||
|
||||
def test_db_pull_fails(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test returns None when adb pull fails."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._adb_shell = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(True, ""),
|
||||
)
|
||||
locker._run_adb = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(False, ""),
|
||||
)
|
||||
|
||||
assert locker._pull_stronglifts_db() is None
|
||||
|
||||
def test_db_uses_correct_remote_path(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test uses the correct StrongLifts DB remote path."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._adb_shell = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(True, ""),
|
||||
)
|
||||
locker._run_adb = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(True, ""),
|
||||
)
|
||||
|
||||
locker._pull_stronglifts_db()
|
||||
|
||||
shell_cmd = locker._adb_shell.call_args[0][0]
|
||||
assert STRONGLIFTS_DB_REMOTE in shell_cmd
|
||||
|
||||
|
||||
class TestCountTodayWorkouts:
|
||||
"""Tests for _count_today_workouts method."""
|
||||
|
||||
def test_workouts_found_today(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test workouts found today."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
db_file = tmp_path / "sl_test.db"
|
||||
conn = sqlite3.connect(str(db_file))
|
||||
conn.execute(
|
||||
"CREATE TABLE workouts "
|
||||
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
|
||||
)
|
||||
# Insert a workout with today's timestamp (ms)
|
||||
now_ms = int(time.time() * 1000)
|
||||
conn.execute(
|
||||
"INSERT INTO workouts VALUES (?, ?, ?)",
|
||||
("w1", now_ms, now_ms + 3600000),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
assert locker._count_today_workouts(db_file) == 1
|
||||
|
||||
def test_no_workouts_today(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test no workouts today."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
db_file = tmp_path / "sl_test.db"
|
||||
conn = sqlite3.connect(str(db_file))
|
||||
conn.execute(
|
||||
"CREATE TABLE workouts "
|
||||
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
|
||||
)
|
||||
# Insert a workout from yesterday (24h+ ago)
|
||||
yesterday_ms = int((time.time() - 200000) * 1000)
|
||||
conn.execute(
|
||||
"INSERT INTO workouts VALUES (?, ?, ?)",
|
||||
("w1", yesterday_ms, yesterday_ms + 3600000),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
assert locker._count_today_workouts(db_file) == 0
|
||||
|
||||
def test_invalid_db_returns_zero(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test returns 0 for invalid database file."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
bad_file = tmp_path / "not_a_db.db"
|
||||
bad_file.write_text("not a database")
|
||||
|
||||
assert locker._count_today_workouts(bad_file) == 0
|
||||
|
||||
def test_missing_table_returns_zero(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test returns 0 when workouts table doesn't exist."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
db_file = tmp_path / "empty.db"
|
||||
conn = sqlite3.connect(str(db_file))
|
||||
conn.execute("CREATE TABLE other (id TEXT)")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
assert locker._count_today_workouts(db_file) == 0
|
||||
|
||||
def test_multiple_workouts_today(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test counts multiple workouts today correctly."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
db_file = tmp_path / "sl_test.db"
|
||||
conn = sqlite3.connect(str(db_file))
|
||||
conn.execute(
|
||||
"CREATE TABLE workouts "
|
||||
"(id TEXT PRIMARY KEY, start INTEGER, finish INTEGER)",
|
||||
)
|
||||
now_ms = int(time.time() * 1000)
|
||||
conn.execute(
|
||||
"INSERT INTO workouts VALUES (?, ?, ?)",
|
||||
("w1", now_ms, now_ms + 3600000),
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO workouts VALUES (?, ?, ?)",
|
||||
("w2", now_ms + 100000, now_ms + 3700000),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
assert locker._count_today_workouts(db_file) == 2
|
||||
390
python_pkg/screen_locker/tests/test_init_and_log.py
Normal file
390
python_pkg/screen_locker/tests/test_init_and_log.py
Normal file
@ -0,0 +1,390 @@
|
||||
"""Tests for screen_locker initialization, logging, and basic operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from python_pkg.screen_locker.screen_lock import (
|
||||
MAX_DISTANCE_KM,
|
||||
MAX_PACE_MIN_PER_KM,
|
||||
MAX_REPS,
|
||||
MAX_SETS,
|
||||
MAX_TIME_MINUTES,
|
||||
MAX_WEIGHT_KG,
|
||||
MIN_EXERCISE_NAME_LEN,
|
||||
)
|
||||
from python_pkg.screen_locker.tests.conftest import create_locker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestConstants:
|
||||
"""Tests for module constants."""
|
||||
|
||||
def test_max_distance_km(self) -> None:
|
||||
"""Test MAX_DISTANCE_KM is reasonable."""
|
||||
assert MAX_DISTANCE_KM == 100
|
||||
assert MAX_DISTANCE_KM > 0
|
||||
|
||||
def test_max_time_minutes(self) -> None:
|
||||
"""Test MAX_TIME_MINUTES is reasonable."""
|
||||
assert MAX_TIME_MINUTES == 600
|
||||
assert MAX_TIME_MINUTES > 0
|
||||
|
||||
def test_max_pace_min_per_km(self) -> None:
|
||||
"""Test MAX_PACE_MIN_PER_KM is reasonable."""
|
||||
assert MAX_PACE_MIN_PER_KM == 20
|
||||
assert MAX_PACE_MIN_PER_KM > 0
|
||||
|
||||
def test_min_exercise_name_len(self) -> None:
|
||||
"""Test MIN_EXERCISE_NAME_LEN is reasonable."""
|
||||
assert MIN_EXERCISE_NAME_LEN == 3
|
||||
assert MIN_EXERCISE_NAME_LEN > 0
|
||||
|
||||
def test_max_sets(self) -> None:
|
||||
"""Test MAX_SETS is reasonable."""
|
||||
assert MAX_SETS == 20
|
||||
assert MAX_SETS > 0
|
||||
|
||||
def test_max_reps(self) -> None:
|
||||
"""Test MAX_REPS is reasonable."""
|
||||
assert MAX_REPS == 100
|
||||
assert MAX_REPS > 0
|
||||
|
||||
def test_max_weight_kg(self) -> None:
|
||||
"""Test MAX_WEIGHT_KG is reasonable."""
|
||||
assert MAX_WEIGHT_KG == 500
|
||||
assert MAX_WEIGHT_KG > 0
|
||||
|
||||
|
||||
class TestScreenLockerInit:
|
||||
"""Tests for ScreenLocker initialization."""
|
||||
|
||||
def test_init_demo_mode(
|
||||
self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test initialization in demo mode."""
|
||||
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
|
||||
|
||||
assert locker.demo_mode is True
|
||||
assert locker.lockout_time == 10
|
||||
mock_sys_exit.assert_not_called()
|
||||
|
||||
def test_init_production_mode(
|
||||
self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test initialization in production mode."""
|
||||
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
|
||||
|
||||
assert locker.demo_mode is False
|
||||
assert locker.lockout_time == 1800
|
||||
mock_sys_exit.assert_not_called()
|
||||
|
||||
def test_init_exits_if_logged_today(
|
||||
self, mock_tk: MagicMock, mock_sys_exit: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that init exits early if workout logged today."""
|
||||
mock_sys_exit.side_effect = SystemExit(0)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
create_locker(mock_tk, tmp_path, has_logged=True)
|
||||
|
||||
mock_sys_exit.assert_called_once_with(0)
|
||||
|
||||
|
||||
class TestHasLoggedToday:
|
||||
"""Tests for has_logged_today method."""
|
||||
|
||||
def test_no_log_file(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test when log file doesn't exist."""
|
||||
log_file = tmp_path / "workout_log.json"
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
|
||||
locker.log_file = log_file
|
||||
assert locker.has_logged_today() is False
|
||||
|
||||
def test_empty_log_file(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test when log file is empty/invalid JSON."""
|
||||
log_file = tmp_path / "workout_log.json"
|
||||
log_file.write_text("")
|
||||
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = log_file
|
||||
assert locker.has_logged_today() is False
|
||||
|
||||
def test_invalid_json(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test when log file contains invalid JSON."""
|
||||
log_file = tmp_path / "workout_log.json"
|
||||
log_file.write_text("{invalid json}")
|
||||
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = log_file
|
||||
assert locker.has_logged_today() is False
|
||||
|
||||
def test_today_logged(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test when today's workout is logged."""
|
||||
log_file = tmp_path / "workout_log.json"
|
||||
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
log_file.write_text(json.dumps({today: {"workout": "data"}}))
|
||||
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = log_file
|
||||
assert locker.has_logged_today() is True
|
||||
|
||||
def test_other_day_logged(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test when only other days are logged."""
|
||||
log_file = tmp_path / "workout_log.json"
|
||||
log_file.write_text(json.dumps({"2020-01-01": {"workout": "data"}}))
|
||||
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = log_file
|
||||
assert locker.has_logged_today() is False
|
||||
|
||||
|
||||
class TestSaveWorkoutLog:
|
||||
"""Tests for save_workout_log method."""
|
||||
|
||||
def test_save_to_new_file(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test saving to a new log file."""
|
||||
log_file = tmp_path / "workout_log.json"
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = log_file
|
||||
locker.workout_data = {"type": "running"}
|
||||
locker.save_workout_log()
|
||||
|
||||
assert log_file.exists()
|
||||
with log_file.open() as f:
|
||||
data: dict[str, Any] = json.load(f)
|
||||
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
assert today in data
|
||||
assert data[today]["workout_data"]["type"] == "running"
|
||||
|
||||
def test_save_to_existing_file(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test saving appends to existing log file."""
|
||||
log_file = tmp_path / "workout_log.json"
|
||||
log_file.write_text(json.dumps({"2020-01-01": {"old": "data"}}))
|
||||
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = log_file
|
||||
locker.workout_data = {"type": "strength"}
|
||||
locker.save_workout_log()
|
||||
|
||||
with log_file.open() as f:
|
||||
data: dict[str, Any] = json.load(f)
|
||||
assert "2020-01-01" in data
|
||||
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
assert today in data
|
||||
|
||||
def test_save_with_corrupted_existing_file(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test saving when existing file is corrupted."""
|
||||
log_file = tmp_path / "workout_log.json"
|
||||
log_file.write_text("not valid json")
|
||||
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = log_file
|
||||
locker.workout_data = {"type": "running"}
|
||||
locker.save_workout_log()
|
||||
|
||||
with log_file.open() as f:
|
||||
data: dict[str, Any] = json.load(f)
|
||||
today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
assert today in data
|
||||
|
||||
def test_save_with_write_error(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test saving handles write errors gracefully."""
|
||||
log_file = tmp_path / "nonexistent_dir" / "workout_log.json"
|
||||
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = log_file
|
||||
locker.workout_data = {"type": "running"}
|
||||
# Should not raise, just log warning
|
||||
locker.save_workout_log()
|
||||
|
||||
|
||||
class TestShowError:
|
||||
"""Tests for show_error method."""
|
||||
|
||||
def test_show_error_displays_message(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test show_error clears container and displays error."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.show_error("Test error message")
|
||||
|
||||
locker.clear_container.assert_called_once()
|
||||
|
||||
|
||||
class TestRun:
|
||||
"""Tests for run method."""
|
||||
|
||||
def test_run_starts_mainloop(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test run starts the tkinter mainloop."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
|
||||
locker.run()
|
||||
|
||||
locker.root.mainloop.assert_called_once() # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestMainEntry:
|
||||
"""Tests for main entry point."""
|
||||
|
||||
def test_main_demo_mode_default(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test main defaults to demo mode."""
|
||||
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
|
||||
|
||||
assert locker.demo_mode is True
|
||||
|
||||
def test_main_production_mode_flag(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test main with --production flag."""
|
||||
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
|
||||
|
||||
assert locker.demo_mode is False
|
||||
|
||||
|
||||
class TestAdjustShutdownTimeLater:
|
||||
"""Tests for _adjust_shutdown_time_later method."""
|
||||
|
||||
def test_adjust_shutdown_time_later_success(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test _adjust_shutdown_time_later adds hours successfully."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._read_shutdown_config = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(21, 22, 8)
|
||||
)
|
||||
locker._write_shutdown_config = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True
|
||||
)
|
||||
|
||||
result = locker._adjust_shutdown_time_later()
|
||||
|
||||
assert result is True
|
||||
locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True)
|
||||
|
||||
def test_adjust_shutdown_time_later_caps_at_23(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test _adjust_shutdown_time_later caps hours at 23."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._read_shutdown_config = MagicMock( # type: ignore[method-assign]
|
||||
return_value=(22, 23, 8)
|
||||
)
|
||||
locker._write_shutdown_config = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True
|
||||
)
|
||||
|
||||
result = locker._adjust_shutdown_time_later()
|
||||
|
||||
assert result is True
|
||||
# 22+2=24 capped to 23, 23+2=25 capped to 23
|
||||
locker._write_shutdown_config.assert_called_once_with(23, 23, 8, restore=True)
|
||||
|
||||
def test_adjust_shutdown_time_later_no_config(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test _adjust_shutdown_time_later returns False if config missing."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._read_shutdown_config = MagicMock( # type: ignore[method-assign]
|
||||
return_value=None
|
||||
)
|
||||
|
||||
result = locker._adjust_shutdown_time_later()
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_adjust_shutdown_time_later_oserror(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test _adjust_shutdown_time_later handles OSError."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._read_shutdown_config = MagicMock( # type: ignore[method-assign]
|
||||
side_effect=OSError("permission denied")
|
||||
)
|
||||
|
||||
result = locker._adjust_shutdown_time_later()
|
||||
|
||||
assert result is False
|
||||
430
python_pkg/screen_locker/tests/test_phone_check_unlock.py
Normal file
430
python_pkg/screen_locker/tests/test_phone_check_unlock.py
Normal file
@ -0,0 +1,430 @@
|
||||
"""Tests for phone workout verification, phone check, and unlock operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from python_pkg.screen_locker.screen_lock import (
|
||||
PHONE_PENALTY_DELAY_DEMO,
|
||||
PHONE_PENALTY_DELAY_PRODUCTION,
|
||||
)
|
||||
from python_pkg.screen_locker.tests.conftest import create_locker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestVerifyPhoneWorkout:
|
||||
"""Tests for _verify_phone_workout method."""
|
||||
|
||||
def test_verified(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test workout verified on phone."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._is_phone_connected = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True,
|
||||
)
|
||||
locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign]
|
||||
return_value=tmp_path / "sl.db",
|
||||
)
|
||||
locker._count_today_workouts = MagicMock( # type: ignore[method-assign]
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
status, message = locker._verify_phone_workout()
|
||||
|
||||
assert status == "verified"
|
||||
assert "2 session" in message
|
||||
|
||||
def test_not_verified(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test no workout found on phone."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._is_phone_connected = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True,
|
||||
)
|
||||
locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign]
|
||||
return_value=tmp_path / "sl.db",
|
||||
)
|
||||
locker._count_today_workouts = MagicMock( # type: ignore[method-assign]
|
||||
return_value=0,
|
||||
)
|
||||
|
||||
status, message = locker._verify_phone_workout()
|
||||
|
||||
assert status == "not_verified"
|
||||
assert "No workout" in message
|
||||
|
||||
def test_no_phone(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test no phone connected."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._is_phone_connected = MagicMock( # type: ignore[method-assign]
|
||||
return_value=False,
|
||||
)
|
||||
|
||||
status, _ = locker._verify_phone_workout()
|
||||
|
||||
assert status == "no_phone"
|
||||
|
||||
def test_error_no_db(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test error when StrongLifts DB cannot be pulled."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._is_phone_connected = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True,
|
||||
)
|
||||
locker._pull_stronglifts_db = MagicMock( # type: ignore[method-assign]
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
status, message = locker._verify_phone_workout()
|
||||
|
||||
assert status == "error"
|
||||
assert "database" in message.lower()
|
||||
|
||||
|
||||
class TestStartPhoneCheck:
|
||||
"""Tests for _start_phone_check and _handle_startup_phone_result."""
|
||||
|
||||
def test_start_phone_check_shows_checking_screen(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test _start_phone_check shows checking message and starts check."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
locker._verify_phone_workout = MagicMock( # type: ignore[method-assign]
|
||||
return_value=("no_phone", "No phone"),
|
||||
)
|
||||
locker._poll_phone_check = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker._start_phone_check()
|
||||
|
||||
locker.clear_container.assert_called()
|
||||
locker._poll_phone_check.assert_called_once()
|
||||
assert locker._phone_future is not None
|
||||
|
||||
def test_handle_startup_verified_unlocks_directly(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test verified result shows success screen then unlocks via after()."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.unlock_screen = MagicMock() # type: ignore[method-assign]
|
||||
locker.root.after = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker._handle_startup_phone_result("verified", "Workout verified! (1 session)")
|
||||
|
||||
# unlock_screen is deferred via root.after, not called directly
|
||||
locker.unlock_screen.assert_not_called()
|
||||
assert locker.workout_data["type"] == "phone_verified"
|
||||
locker.root.after.assert_called_once_with(1500, locker.unlock_screen)
|
||||
|
||||
def test_handle_startup_not_verified_shows_block(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test not_verified result shows blocking screen with buttons."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
locker._handle_startup_phone_result(
|
||||
"not_verified", "No workout found on phone today"
|
||||
)
|
||||
|
||||
locker.clear_container.assert_called()
|
||||
|
||||
def test_handle_startup_no_phone_shows_penalty(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test no_phone result triggers penalty with ask_workout_done as callback."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._show_phone_penalty = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker._handle_startup_phone_result("no_phone", "No phone")
|
||||
|
||||
locker._show_phone_penalty.assert_called_once()
|
||||
_, kwargs = locker._show_phone_penalty.call_args
|
||||
assert kwargs["on_done"] == locker.ask_workout_done
|
||||
|
||||
def test_handle_startup_error_shows_penalty(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test error result triggers penalty with ask_workout_done as callback."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker._show_phone_penalty = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker._handle_startup_phone_result("error", "DB not found")
|
||||
|
||||
locker._show_phone_penalty.assert_called_once()
|
||||
_, kwargs = locker._show_phone_penalty.call_args
|
||||
assert kwargs["on_done"] == locker.ask_workout_done
|
||||
|
||||
def test_poll_phone_check_schedules_retry_when_pending(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test _poll_phone_check reschedules itself when future is not done."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
mock_future: MagicMock = MagicMock()
|
||||
mock_future.done.return_value = False
|
||||
locker._phone_future = mock_future # type: ignore[assignment]
|
||||
locker.root.after = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker._poll_phone_check()
|
||||
|
||||
locker.root.after.assert_called_once_with(500, locker._poll_phone_check)
|
||||
|
||||
def test_poll_phone_check_routes_when_done(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test _poll_phone_check calls result handler when future is done."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
mock_future: MagicMock = MagicMock()
|
||||
mock_future.done.return_value = True
|
||||
mock_future.result.return_value = ("no_phone", "No phone")
|
||||
locker._phone_future = mock_future # type: ignore[assignment]
|
||||
locker._handle_startup_phone_result = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker._poll_phone_check()
|
||||
|
||||
locker._handle_startup_phone_result.assert_called_once_with(
|
||||
"no_phone", "No phone"
|
||||
)
|
||||
|
||||
|
||||
class TestAttemptUnlock:
|
||||
"""Tests for _attempt_unlock method."""
|
||||
|
||||
def test_attempt_unlock_calls_unlock_screen(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test _attempt_unlock calls unlock_screen directly."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = tmp_path / "workout_log.json"
|
||||
locker.workout_data = {"type": "strength"}
|
||||
locker.unlock_screen = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker._attempt_unlock()
|
||||
|
||||
locker.unlock_screen.assert_called_once()
|
||||
|
||||
|
||||
class TestShowPhonePenalty:
|
||||
"""Tests for _show_phone_penalty and _update_phone_penalty methods."""
|
||||
|
||||
def test_show_phone_penalty_demo_delay(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test demo mode uses short penalty delay."""
|
||||
locker = create_locker(mock_tk, tmp_path, demo_mode=True)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker._show_phone_penalty("test message")
|
||||
|
||||
# _update_phone_penalty is called once, decrementing by 1
|
||||
assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_DEMO - 1
|
||||
|
||||
def test_show_phone_penalty_production_delay(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test production mode uses long penalty delay."""
|
||||
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker._show_phone_penalty("test message")
|
||||
|
||||
assert locker.phone_penalty_remaining == PHONE_PENALTY_DELAY_PRODUCTION - 1
|
||||
|
||||
def test_update_phone_penalty_countdown(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test phone penalty countdown decrements."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.phone_penalty_remaining = 5
|
||||
locker.phone_penalty_label = MagicMock()
|
||||
|
||||
locker._update_phone_penalty()
|
||||
|
||||
assert locker.phone_penalty_remaining == 4
|
||||
locker.phone_penalty_label.config.assert_called_once_with(text="5")
|
||||
locker.root.after.assert_called() # type: ignore[attr-defined]
|
||||
|
||||
def test_update_phone_penalty_at_zero(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test phone penalty unlocks when timer reaches zero."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = tmp_path / "workout_log.json"
|
||||
locker.workout_data = {"type": "strength"}
|
||||
locker.phone_penalty_remaining = 0
|
||||
locker.phone_penalty_label = MagicMock()
|
||||
locker.unlock_screen = MagicMock() # type: ignore[method-assign]
|
||||
locker._phone_penalty_done_fn = locker.unlock_screen # type: ignore[attr-defined]
|
||||
|
||||
locker._update_phone_penalty()
|
||||
|
||||
locker.unlock_screen.assert_called_once()
|
||||
|
||||
|
||||
class TestUnlockScreenShutdownAdjustment:
|
||||
"""Tests for unlock_screen shutdown time adjustment."""
|
||||
|
||||
def test_unlock_screen_adjusts_for_running(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test unlock_screen adjusts shutdown for running workout."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = tmp_path / "workout_log.json"
|
||||
locker.workout_data = {"type": "running"}
|
||||
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True
|
||||
)
|
||||
|
||||
locker.unlock_screen()
|
||||
|
||||
locker._adjust_shutdown_time_later.assert_called_once()
|
||||
|
||||
def test_unlock_screen_adjusts_for_strength(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test unlock_screen adjusts shutdown for strength workout."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = tmp_path / "workout_log.json"
|
||||
locker.workout_data = {"type": "strength"}
|
||||
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True
|
||||
)
|
||||
|
||||
locker.unlock_screen()
|
||||
|
||||
locker._adjust_shutdown_time_later.assert_called_once()
|
||||
|
||||
def test_unlock_screen_adjusts_for_phone_verified(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test unlock_screen adjusts shutdown for phone-verified workout."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = tmp_path / "workout_log.json"
|
||||
locker.workout_data = {"type": "phone_verified"}
|
||||
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True
|
||||
)
|
||||
|
||||
locker.unlock_screen()
|
||||
|
||||
locker._adjust_shutdown_time_later.assert_called_once()
|
||||
|
||||
def test_unlock_screen_skips_adjustment_for_sick_day(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test unlock_screen does not adjust for sick day."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = tmp_path / "workout_log.json"
|
||||
locker.workout_data = {"type": "sick_day"}
|
||||
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True
|
||||
)
|
||||
|
||||
locker.unlock_screen()
|
||||
|
||||
locker._adjust_shutdown_time_later.assert_not_called()
|
||||
|
||||
def test_unlock_screen_skips_adjustment_no_type(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test unlock_screen does not adjust when no workout type."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = tmp_path / "workout_log.json"
|
||||
locker.workout_data = {}
|
||||
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
|
||||
return_value=True
|
||||
)
|
||||
|
||||
locker.unlock_screen()
|
||||
|
||||
locker._adjust_shutdown_time_later.assert_not_called()
|
||||
|
||||
def test_unlock_screen_handles_adjustment_failure(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test unlock_screen continues when adjustment fails."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = tmp_path / "workout_log.json"
|
||||
locker.workout_data = {"type": "running"}
|
||||
locker._adjust_shutdown_time_later = MagicMock( # type: ignore[method-assign]
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Should not raise, should continue with unlock
|
||||
locker.unlock_screen()
|
||||
|
||||
locker._adjust_shutdown_time_later.assert_called_once()
|
||||
locker.root.after.assert_called() # type: ignore[attr-defined]
|
||||
File diff suppressed because it is too large
Load Diff
424
python_pkg/screen_locker/tests/test_ui_and_timers.py
Normal file
424
python_pkg/screen_locker/tests/test_ui_and_timers.py
Normal file
@ -0,0 +1,424 @@
|
||||
"""Tests for UI transitions, timer logic, and workout detail screens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tkinter as tk
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from python_pkg.screen_locker.screen_lock import (
|
||||
SUBMIT_DELAY_DEMO,
|
||||
SUBMIT_DELAY_PRODUCTION,
|
||||
)
|
||||
from python_pkg.screen_locker.tests.conftest import create_locker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
_TK_TCLERROR = tk.TclError
|
||||
|
||||
|
||||
class TestUITransitions:
|
||||
"""Tests for UI state transitions."""
|
||||
|
||||
def test_clear_container(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test clear_container destroys all child widgets."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
|
||||
# Set up mock children
|
||||
mock_child1 = MagicMock()
|
||||
mock_child2 = MagicMock()
|
||||
locker.container.winfo_children.return_value = [ # type: ignore[attr-defined]
|
||||
mock_child1,
|
||||
mock_child2,
|
||||
]
|
||||
|
||||
locker.clear_container()
|
||||
|
||||
mock_child1.destroy.assert_called_once()
|
||||
mock_child2.destroy.assert_called_once()
|
||||
|
||||
def test_unlock_screen_saves_and_schedules_close(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test unlock_screen saves log and schedules close."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.log_file = tmp_path / "workout_log.json"
|
||||
locker.workout_data = {"type": "running"}
|
||||
|
||||
locker.unlock_screen()
|
||||
|
||||
# Check that after() was called to schedule close
|
||||
locker.root.after.assert_called() # type: ignore[attr-defined]
|
||||
|
||||
def test_lockout_starts_countdown(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test lockout initializes countdown timer."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
|
||||
locker.lockout()
|
||||
|
||||
# lockout() sets remaining_time to lockout_time (10 in demo mode)
|
||||
# then calls update_lockout_countdown() which decrements it by 1
|
||||
assert locker.remaining_time == 9 # 10 - 1 after first update
|
||||
|
||||
def test_close_destroys_root_and_exits(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test close destroys root window and exits."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
|
||||
locker.close()
|
||||
|
||||
locker.root.destroy.assert_called_once() # type: ignore[attr-defined]
|
||||
mock_sys_exit.assert_called_with(0)
|
||||
|
||||
|
||||
class TestTimerLogic:
|
||||
"""Tests for timer countdown logic."""
|
||||
|
||||
def test_update_lockout_countdown_decrements(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test countdown decrements remaining time."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.remaining_time = 5
|
||||
locker.countdown_label = MagicMock()
|
||||
|
||||
locker.update_lockout_countdown()
|
||||
|
||||
assert locker.remaining_time == 4
|
||||
locker.root.after.assert_called_with( # type: ignore[attr-defined]
|
||||
1000, locker.update_lockout_countdown
|
||||
)
|
||||
|
||||
def test_update_lockout_countdown_at_zero(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test countdown at zero returns to workout question."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.remaining_time = 0
|
||||
locker.countdown_label = MagicMock()
|
||||
locker.ask_workout_done = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.update_lockout_countdown()
|
||||
|
||||
locker.ask_workout_done.assert_called_once()
|
||||
|
||||
def test_update_submit_timer_countdown(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test submit timer counts down."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.submit_unlock_time = 5
|
||||
locker.timer_label = MagicMock()
|
||||
locker.submit_btn = MagicMock()
|
||||
locker.entries_to_check = []
|
||||
|
||||
locker.update_submit_timer()
|
||||
|
||||
assert locker.submit_unlock_time == 4
|
||||
locker.root.after.assert_called() # type: ignore[attr-defined]
|
||||
|
||||
def test_update_submit_timer_enables_when_filled(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test submit enabled when timer done and entries filled."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.submit_unlock_time = 0
|
||||
locker.timer_label = MagicMock()
|
||||
locker.submit_btn = MagicMock()
|
||||
mock_entry = MagicMock()
|
||||
mock_entry.get.return_value = "some value"
|
||||
locker.entries_to_check = [mock_entry]
|
||||
locker.submit_command = MagicMock()
|
||||
|
||||
locker.update_submit_timer()
|
||||
|
||||
locker.submit_btn.config.assert_called()
|
||||
|
||||
def test_update_submit_timer_waits_for_entries(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test submit waits when entries not filled."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.submit_unlock_time = 0
|
||||
locker.timer_label = MagicMock()
|
||||
locker.submit_btn = MagicMock()
|
||||
mock_entry = MagicMock()
|
||||
mock_entry.get.return_value = "" # Empty entry
|
||||
locker.entries_to_check = [mock_entry]
|
||||
|
||||
locker.update_submit_timer()
|
||||
|
||||
locker.root.after.assert_called_with( # type: ignore[attr-defined]
|
||||
1000, locker.check_entries_filled
|
||||
)
|
||||
|
||||
def test_update_submit_timer_handles_tcl_error(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test timer handles TclError when widgets destroyed."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.submit_unlock_time = 5
|
||||
locker.timer_label = MagicMock()
|
||||
locker.timer_label.config.side_effect = _TK_TCLERROR("widget destroyed")
|
||||
|
||||
# Should not raise
|
||||
locker.update_submit_timer()
|
||||
|
||||
def test_check_entries_filled_enables_submit(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test check_entries_filled enables submit when all filled."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.timer_label = MagicMock()
|
||||
locker.submit_btn = MagicMock()
|
||||
mock_entry = MagicMock()
|
||||
mock_entry.get.return_value = "value"
|
||||
locker.entries_to_check = [mock_entry]
|
||||
locker.submit_command = MagicMock()
|
||||
|
||||
locker.check_entries_filled()
|
||||
|
||||
locker.submit_btn.config.assert_called()
|
||||
|
||||
def test_check_entries_filled_continues_waiting(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test check_entries_filled continues waiting when not filled."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.timer_label = MagicMock()
|
||||
locker.submit_btn = MagicMock()
|
||||
mock_entry = MagicMock()
|
||||
mock_entry.get.return_value = ""
|
||||
locker.entries_to_check = [mock_entry]
|
||||
|
||||
locker.check_entries_filled()
|
||||
|
||||
locker.root.after.assert_called_with( # type: ignore[attr-defined]
|
||||
1000, locker.check_entries_filled
|
||||
)
|
||||
|
||||
def test_check_entries_filled_handles_tcl_error(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test check_entries_filled handles TclError."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.timer_label = MagicMock()
|
||||
mock_entry = MagicMock()
|
||||
mock_entry.get.side_effect = _TK_TCLERROR("widget destroyed")
|
||||
locker.entries_to_check = [mock_entry]
|
||||
|
||||
# Should not raise
|
||||
locker.check_entries_filled()
|
||||
|
||||
|
||||
class TestAskWorkoutType:
|
||||
"""Tests for ask_workout_type method."""
|
||||
|
||||
def test_ask_workout_type_creates_buttons(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ask_workout_type creates running and strength buttons."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.ask_workout_type()
|
||||
|
||||
locker.clear_container.assert_called_once()
|
||||
# Verify Label and Button were called
|
||||
mock_tk.Label.assert_called()
|
||||
mock_tk.Button.assert_called()
|
||||
|
||||
|
||||
class TestAskRunningDetails:
|
||||
"""Tests for ask_running_details method."""
|
||||
|
||||
def test_ask_running_details_sets_workout_type(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ask_running_details sets workout type to running."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.ask_running_details()
|
||||
|
||||
assert locker.workout_data["type"] == "running"
|
||||
locker.clear_container.assert_called_once()
|
||||
|
||||
def test_ask_running_details_creates_entry_fields(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ask_running_details creates entry fields."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.ask_running_details()
|
||||
|
||||
# Verify Entry fields were created
|
||||
mock_tk.Entry.assert_called()
|
||||
assert hasattr(locker, "distance_entry")
|
||||
assert hasattr(locker, "time_entry")
|
||||
assert hasattr(locker, "pace_entry")
|
||||
|
||||
def test_ask_running_details_sets_timer(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ask_running_details initializes submit timer."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.ask_running_details()
|
||||
|
||||
assert locker.submit_unlock_time == SUBMIT_DELAY_DEMO
|
||||
locker.update_submit_timer.assert_called_once()
|
||||
|
||||
|
||||
class TestAskStrengthDetails:
|
||||
"""Tests for ask_strength_details method."""
|
||||
|
||||
def test_ask_strength_details_sets_workout_type(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ask_strength_details sets workout type to strength."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.ask_strength_details()
|
||||
|
||||
assert locker.workout_data["type"] == "strength"
|
||||
locker.clear_container.assert_called_once()
|
||||
|
||||
def test_ask_strength_details_creates_entry_fields(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ask_strength_details creates entry fields."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.ask_strength_details()
|
||||
|
||||
# Verify Entry fields were created
|
||||
mock_tk.Entry.assert_called()
|
||||
assert hasattr(locker, "exercises_entry")
|
||||
assert hasattr(locker, "sets_entry")
|
||||
assert hasattr(locker, "reps_entry")
|
||||
assert hasattr(locker, "weights_entry")
|
||||
assert hasattr(locker, "total_weight_entry")
|
||||
|
||||
def test_ask_strength_details_sets_timer(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ask_strength_details initializes submit timer."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.ask_strength_details()
|
||||
|
||||
assert locker.submit_unlock_time == SUBMIT_DELAY_DEMO
|
||||
locker.update_submit_timer.assert_called_once()
|
||||
|
||||
def test_ask_strength_details_production_timer(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test production mode uses longer submit delay."""
|
||||
locker = create_locker(mock_tk, tmp_path, demo_mode=False)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
locker.update_submit_timer = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.ask_strength_details()
|
||||
|
||||
assert locker.submit_unlock_time == SUBMIT_DELAY_PRODUCTION
|
||||
|
||||
|
||||
class TestAskWorkoutDone:
|
||||
"""Tests for ask_workout_done method."""
|
||||
|
||||
def test_ask_workout_done_creates_buttons(
|
||||
self,
|
||||
mock_tk: MagicMock,
|
||||
_mock_sys_exit: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test ask_workout_done creates yes/no buttons."""
|
||||
locker = create_locker(mock_tk, tmp_path)
|
||||
locker.clear_container = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
locker.ask_workout_done()
|
||||
|
||||
locker.clear_container.assert_called_once()
|
||||
mock_tk.Label.assert_called()
|
||||
mock_tk.Button.assert_called()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user