testsAndMisc-archive/python_pkg/praca_magisterska_video/_q23_transformer.py

436 lines
12 KiB
Python
Raw Normal View History

"""Transformer segmentation and methods comparison for Q23 video."""
from __future__ import annotations
from moviepy import (
ColorClip,
CompositeVideoClip,
VideoClip,
)
from moviepy.video.fx import FadeIn, FadeOut
import numpy as np
from python_pkg.praca_magisterska_video._q23_helpers import (
BG_COLOR,
FONT_B,
FONT_R,
FPS,
STEP_DUR,
H,
W,
_compose_slide,
_tc,
_text_slide,
)
# ── Transformer Segmentation ────────────────────────────────────
def _draw_base_grid(
frame: np.ndarray,
gx: int,
gy: int,
grid_n: int,
cell: int,
) -> None:
"""Draw an empty grid of cells."""
for r in range(grid_n):
for c in range(grid_n):
x = gx + c * cell
y = gy + r * cell
frame[y : y + cell - 2, x : x + cell - 2] = (35, 40, 55)
def _draw_cnn_kernel(
frame: np.ndarray,
lx: int,
ly: int,
cell: int,
progress: float,
) -> None:
"""Highlight a 3x3 CNN kernel on the grid."""
cnn_phase = 0.2
if progress <= cnn_phase:
return
cx, cy = 2, 2
for dr in range(-1, 2):
for dc in range(-1, 2):
r, c = cy + dr, cx + dc
x = lx + c * cell
y = ly + r * cell
frame[y : y + cell - 2, x : x + cell - 2] = (70, 130, 200)
x = lx + cx * cell
y = ly + cy * cell
frame[y : y + cell - 2, x : x + cell - 2] = (120, 180, 255)
def _draw_conn_line(
frame: np.ndarray,
x0: int,
y0: int,
x1: int,
y1: int,
) -> None:
"""Draw a dashed connection line between two points."""
steps = max(abs(x1 - x0), abs(y1 - y0))
if steps <= 0:
return
for s in range(0, steps, 3):
px = x0 + int((x1 - x0) * s / steps)
py = y0 + int((y1 - y0) * s / steps)
if 0 <= px < W - 1 and 0 <= py < H - 1:
frame[py : py + 1, px : px + 1] = (200, 180, 50)
def _draw_attention_connections(
frame: np.ndarray,
origin: tuple[int, int],
grid_n: int,
cell: int,
progress: float,
) -> None:
"""Draw transformer self-attention connections on the grid."""
rx, ry = origin
transformer_phase = 0.4
if progress <= transformer_phase:
return
cx_t, cy_t = 2, 2
x0 = rx + cx_t * cell + cell // 2
y0 = ry + cy_t * cell + cell // 2
n_connections = int(progress * 36)
conn_idx = 0
for r in range(grid_n):
for c in range(grid_n):
conn_idx += 1
if conn_idx > n_connections:
break
x = rx + c * cell
y = ry + r * cell
dist = abs(r - cy_t) + abs(c - cx_t)
strength = max(30, 200 - dist * 30)
frame[y : y + cell - 2, x : x + cell - 2] = (
strength // 3,
strength // 2,
strength,
)
_draw_conn_line(frame, x0, y0, x + cell // 2, y + cell // 2)
else:
continue
break
x = rx + cx_t * cell
y = ry + cy_t * cell
frame[y : y + cell - 2, x : x + cell - 2] = (255, 200, 50)
def _make_attention_frame(t: float) -> np.ndarray:
"""Render a CNN-vs-Transformer attention comparison frame."""
frame = np.zeros((H, W, 3), dtype=np.uint8)
frame[:] = BG_COLOR
progress = min(t / (STEP_DUR * 0.7), 1.0)
cell = 40
grid_n = 6
lx, ly = 60, 200
_draw_base_grid(frame, lx, ly, grid_n, cell)
_draw_cnn_kernel(frame, lx, ly, cell, progress)
rx, ry = 680, 200
_draw_base_grid(frame, rx, ry, grid_n, cell)
_draw_attention_connections(frame, (rx, ry), grid_n, cell, progress)
return frame
def _transformer_seg_demo() -> list[CompositeVideoClip]:
"""Animate transformer-based segmentation: self-attention concept."""
dur = STEP_DUR + 1
# Slide 1: CNN local vs Transformer global
att_clip = VideoClip(_make_attention_frame, duration=dur).with_fps(FPS)
labels = [
("Transformer: Self-Attention w segmentacji", 26, "#FFE082", FONT_B, (80, 20)),
("CNN = LOKALNY kontekst", 18, "#64B5F6", FONT_B, (60, 160)),
("Transformer = GLOBALNY kontekst", 18, "#FFE082", FONT_B, (680, 160)),
("Filtr 3x3 widzi", 14, "#64B5F6", FONT_R, (60, 460)),
("TYLKO 9 sąsiadów", 14, "#64B5F6", FONT_R, (60, 485)),
("Self-attention: każdy", 14, "#FFE082", FONT_R, (680, 460)),
("piksel widzi WSZYSTKIE!", 14, "#FFE082", FONT_R, (680, 485)),
("vs", 28, "#B0BEC5", FONT_B, (450, 300)),
]
slides = [_compose_slide(att_clip, labels, dur)]
# Slide 2: Self-attention Q/K/V step by step
qkv_lines = [
("Self-Attention: Q / K / V krok po kroku", 26, "#FFE082", FONT_B, (80, 30)),
("Każdy piksel (token) tworzy 3 wektory:", 18, "#B0BEC5", FONT_R, (100, 100)),
(
" Q (Query) = 'czego szukam?' - pytanie piksela",
17,
"#64B5F6",
FONT_R,
(120, 145),
),
(
" K (Key) = 'co oferuj\u0119?' - odpowied\u017a piksela",
17,
"#A5D6A7",
FONT_R,
(120, 185),
),
(
" V (Value) = 'moja warto\u015b\u0107' - informacja do przekazania",
17,
"#FFE082",
FONT_R,
(120, 225),
),
("Algorytm attention:", 18, "#B0BEC5", FONT_R, (100, 285)),
(
" 1. Mnożenie Q x K\u1d40 → macierz NxN (kto ważny dla kogo)",
16,
"white",
FONT_R,
(120, 320),
),
(
" 2. Skalowanie: / \u221ad (stabilno\u015b\u0107 gradient\u00f3w)",
16,
"white",
FONT_R,
(120, 355),
),
(
" 3. Softmax \u2192 wagi attention (sumuj\u0105 si\u0119 do 1)",
16,
"white",
FONT_R,
(120, 390),
),
(
" 4. Mno\u017cenie wag x V \u2192 wa\u017cona suma warto\u015bci",
16,
"white",
FONT_R,
(120, 425),
),
(
"Attention(Q,K,V) = softmax(Q \u00b7 K\u1d40 / \u221ad) \u00b7 V",
20,
"#FFE082",
FONT_B,
(100, 480),
),
(
"Z\u0142o\u017cono\u015b\u0107: O(n\u00b2) pami\u0119ci \u2014 n = liczba "
"pikseli/token\u00f3w",
16,
"#EF9A9A",
FONT_R,
(100, 535),
),
(
"Dlatego SegFormer u\u017cywa efficient attention (liniowa "
"z\u0142o\u017cono\u015b\u0107)",
15,
"#78909C",
FONT_R,
(100, 570),
),
(
"SegFormer (2021): lightweight + hierarchiczny encoder",
16,
"#A5D6A7",
FONT_R,
(100, 610),
),
(
"Mask2Former (2022): masked attention + "
"unified (semantic+instance+panoptic)",
16,
"#CE93D8",
FONT_R,
(100, 645),
),
]
slides.append(_text_slide(qkv_lines, duration=STEP_DUR + 1))
# Slide 3: Encoder-Decoder in DL summary
summary_lines = [
(
"Podsumowanie: Encoder-Decoder w segmentacji DL",
24,
"#FFE082",
FONT_B,
(80, 30),
),
(
"Wsp\u00f3lna idea WSZYSTKICH sieci segmentacji:",
18,
"#B0BEC5",
FONT_R,
(80, 90),
),
(
"Encoder: obraz \u2192 cechy (zmniejsza rozdzielczo\u015b\u0107, "
"wyci\u0105ga CO)",
16,
"#64B5F6",
FONT_R,
(100, 140),
),
(
"Decoder: cechy \u2192 mapa (zwi\u0119ksza rozdzielczo\u015b\u0107, "
"odtwarza GDZIE)",
16,
"#A5D6A7",
FONT_R,
(100, 175),
),
(
"Skip: przenosi detale z encodera do decodera",
16,
"#FFE082",
FONT_R,
(100, 210),
),
("", 10, "white", FONT_R, (100, 240)),
(
"FCN (2015): Conv1x1 + skip \u2192 pierwsza end-to-end",
16,
"#64B5F6",
FONT_R,
(100, 275),
),
(
"U-Net (2015): U-shape + skip concat \u2192 segmentacja medyczna",
16,
"#A5D6A7",
FONT_R,
(100, 310),
),
(
"DeepLab (2018): dilated conv + ASPP \u2192 multi-scale kontekst",
16,
"#FFE082",
FONT_R,
(100, 345),
),
(
"SegFormer: transformer encoder (globalny kontekst)",
16,
"#CE93D8",
FONT_R,
(100, 380),
),
(
"Mask2Former: masked attention (unified, SOTA)",
16,
"#CE93D8",
FONT_R,
(100, 415),
),
("", 10, "white", FONT_R, (100, 440)),
(
"Ewolucja: wi\u0119cej kontekstu + lepsze skip connections:",
17,
"white",
FONT_R,
(80, 465),
),
(
" CNN lokal. \u2192 dilated (szersze RF) \u2192 transformer (global) "
"\u2192 masked att.",
16,
"#B0BEC5",
FONT_R,
(80, 505),
),
(
" addition skip \u2192 concat skip \u2192 cross-attention skip",
16,
"#B0BEC5",
FONT_R,
(80, 540),
),
(
"Metryki: mIoU (standard), Dice (medycyna), Focal Loss (imbalance)",
16,
"#90CAF9",
FONT_R,
(80, 590),
),
(
"Loss: Cross-Entropy per piksel + opcjonalnie Dice/Focal",
15,
"#78909C",
FONT_R,
(80, 625),
),
]
slides.append(_text_slide(summary_lines, duration=STEP_DUR + 1))
return slides
# ── Methods comparison ────────────────────────────────────────────
def _methods_comparison() -> CompositeVideoClip:
"""Create a comparison table of all segmentation methods."""
bg = ColorClip(size=(W, H), color=BG_COLOR).with_duration(10.0)
title = (
_tc(
text="Por\u00f3wnanie metod segmentacji",
font_size=36,
color="white",
font=FONT_B,
)
.with_duration(10.0)
.with_position(("center", 20))
)
rows = [
("Metoda", "Typ", "Idea", "Mnemonik"),
(
"Thresholding",
"Klasyczna",
"piksel > T \u2192 klasa 1",
"PR\u00d3G na bramce",
),
("Otsu", "Klasyczna", "auto-pr\u00f3g, min \u03c3\u00b2", "AUTO-bramkarz"),
("Region Growing", "Klasyczna", "BFS od seeda", "PLAMA atramentu"),
("Watershed", "Klasyczna", "zalewanie minim\u00f3w", "ZALEWANIE terenu"),
(
"Mean Shift",
"Klasyczna",
"j\u0105dro \u2192 max g\u0119sto\u015bci",
"KULKI do do\u0142k\u00f3w",
),
("U-Net", "Deep Learning", "encoder-decoder + skip", "Litera U + mosty"),
("DeepLab", "Deep Learning", "dilated conv + ASPP", "DZIURY w filtrze"),
]
clips: list[VideoClip] = [bg, title]
mnemonic_col = 3
for i, row in enumerate(rows):
y_pos = 75 + i * 72
col_x = [40, 210, 340, 660]
for j, cell in enumerate(row):
fs = 16 if i > 0 else 18
color = (
"#64B5F6" if i == 0 else ("#E0E0E0" if j < mnemonic_col else "#FFE082")
)
tc = (
_tc(
text=cell,
font_size=fs,
color=color,
font=FONT_B if i == 0 else FONT_R,
)
.with_duration(10.0)
.with_position((col_x[j], y_pos))
)
clips.append(tc)
return CompositeVideoClip(clips, size=(W, H)).with_effects(
[FadeIn(0.5), FadeOut(0.5)]
)