testsAndMisc/linux_configuration/scripts/misc/testsAndMisc-bash/tools/_transcribe_diarize.py
Krzysztof kuhy Rudnicki 2545d72710 test: achieve 100% branch coverage across all python_pkg packages
- Add comprehensive tests for all packages (3572 tests, 100% branch coverage)
- Split oversized test files to stay under 500-line limit
- Add per-file ruff ignores for test-appropriate suppressions
- Fix _cache_decks.py to properly convert JSON lists to tuples
- Add session-scoped conftest fixture for logging handler cleanup (Python 3.14)
- Update ruff pre-commit hook to v0.15.2
- Add codespell ignore words for test data
- Add generated output files to .gitignore
2026-03-21 17:51:36 +01:00

361 lines
9.5 KiB
Python

"""Speaker diarization and audio processing utilities."""
from __future__ import annotations
import contextlib
import importlib
import logging
from pathlib import Path
import shutil
import subprocess
import tempfile
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
import types
import numpy as np
import numpy.typing as npt
logger = logging.getLogger(__name__)

# Rank of a (rows, columns) array, e.g. multi-channel audio or stacked embeddings.
_NDIM_2D = 2
# Sample rate (Hz) that audio is resampled to before speaker-embedding extraction.
_SAMPLE_RATE_16K = 16_000
# Minimum window length, in samples, fed to the embedding model (0.1 s at 16 kHz).
_MIN_SAMPLES_DIAR = _SAMPLE_RATE_16K // 10
def _try_import(name: str) -> types.ModuleType | None:
"""Attempt to import a module, returning None on failure."""
try:
return importlib.import_module(name)
except ImportError:
return None
def _probe_with_ffmpeg_python(
    path: str,
) -> float | None:
    """Return the duration of *path* in seconds using ``ffmpeg-python``.

    Returns ``None`` when:
      - the ``ffmpeg`` module cannot be imported or has no callable ``probe``;
      - probing fails;
      - the probe result has no usable numeric ``format.duration``.
    """
    ffmpeg_mod = _try_import("ffmpeg")
    if ffmpeg_mod is None:
        return None
    probe_fn = getattr(ffmpeg_mod, "probe", None)
    if not callable(probe_fn):
        # A different package may provide a top-level ``ffmpeg`` module
        # without ``probe``; treat that as "unavailable" instead of crashing.
        return None
    # ffmpeg-python raises ``ffmpeg.Error`` (a plain Exception subclass, not
    # OSError/RuntimeError) when ffprobe exits non-zero, so it has to be listed
    # explicitly or the caller never gets to try the ffprobe CLI fallback.
    # ValueError covers unparsable output and non-numeric duration strings.
    errors: tuple[type[Exception], ...] = (OSError, RuntimeError, ValueError)
    probe_error = getattr(ffmpeg_mod, "Error", None)
    if isinstance(probe_error, type) and issubclass(probe_error, Exception):
        errors = (*errors, probe_error)
    try:
        probe = probe_fn(path)
        duration = probe.get("format", {}).get("duration")
        seconds = None if duration is None else float(duration)
    except errors:
        return None
    return seconds
def _probe_with_ffprobe(path: str) -> float | None:
"""Try ffprobe CLI to get duration."""
ffprobe_bin = shutil.which("ffprobe")
if ffprobe_bin is None:
return None
try:
out = subprocess.check_output(
[
ffprobe_bin,
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
path,
],
stderr=subprocess.DEVNULL,
)
return float(out.decode().strip())
except (
OSError,
subprocess.CalledProcessError,
ValueError,
):
return None
def get_media_duration(path: str) -> float | None:
    """Return the media duration of *path* in seconds, or ``None`` if unknown.

    The ``ffmpeg-python`` probe is tried first. The ``ffprobe`` command-line
    probe is consulted only when the first one yields nothing.
    """
    for probe in (_probe_with_ffmpeg_python, _probe_with_ffprobe):
        duration = probe(path)
        if duration is not None:
            return duration
    return None
def _resample_linear(
    x: npt.NDArray[np.float32],
    src_sr: int,
    tgt_sr: int,
) -> npt.NDArray[np.float32]:
    """Resample a 1-D signal from *src_sr* to *tgt_sr* Hz by linear interpolation.

    Input and output samples are placed on evenly spaced positions over
    ``[0, 1)``, and the output is interpolated with ``numpy.interp``. No
    low-pass filter is applied first, so downsampling can alias.

    Args:
        x: 1-D array of samples.
        src_sr: Sample rate of *x*.
        tgt_sr: Desired sample rate.

    Returns:
        One of the following:
          - *x* itself (not a copy) when the two rates are equal;
          - an empty float32 array when *x* is empty;
          - otherwise a float32 array of ``round(len(x) * tgt_sr / src_sr)``
            samples (at least one).

    Raises:
        RuntimeError: numpy is not installed.
    """
    np_mod = _try_import("numpy")
    if np_mod is None:
        msg = "numpy is required for resampling"
        raise RuntimeError(msg)
    if src_sr == tgt_sr:
        return x
    n_in = x.shape[-1]
    if n_in == 0:
        # numpy.interp raises on an empty set of sample points, and there is
        # nothing to interpolate anyway: empty in, empty out.
        return np_mod.zeros((0,), dtype=np_mod.float32)
    ratio = float(tgt_sr) / float(src_sr)
    n_out = max(1, round(n_in * ratio))
    xp = np_mod.linspace(0.0, 1.0, num=n_in, endpoint=False)
    xq = np_mod.linspace(0.0, 1.0, num=n_out, endpoint=False)
    y = np_mod.interp(xq, xp, x.astype(np_mod.float32))
    return y.astype(np_mod.float32)
def _kmeans_cosine(
    embs: list[Any],
    k: int,
    iters: int = 50,
    seed: int = 0,
) -> npt.NDArray[np.int64]:
    """Cluster embeddings into *k* groups with cosine-similarity k-means.

    Rows are L2-normalised so that a dot product equals cosine similarity.
    The initial centroids are *k* distinct rows drawn with a generator seeded
    by *seed*. When there are fewer rows than clusters, the remaining
    centroids are random unit vectors.

    Args:
        embs: Per-segment embedding vectors, all of the same length.
        k: Number of clusters; must be at least 1 when *embs* is non-empty.
        iters: Maximum number of assignment/update rounds.
        seed: Seed for centroid initialisation.

    Returns:
        One cluster index per embedding. The result is an empty array when
        the stacked embeddings are not 2-D or have no rows.

    Raises:
        RuntimeError: numpy is not installed.
        ValueError: *k* is less than 1 for non-empty input.
    """
    np_mod = _try_import("numpy")
    if np_mod is None:
        msg = "numpy is required for clustering"
        raise RuntimeError(msg)
    rng = np_mod.random.default_rng(seed)
    features = np_mod.asarray(embs, dtype=np_mod.float32)
    if features.ndim != _NDIM_2D or features.shape[0] == 0:
        return np_mod.zeros((0,), dtype=np_mod.int64)
    if k < 1:
        # Without this check, k < 1 fails later with an opaque argmax error
        # on an empty similarity matrix.
        msg = f"k must be at least 1, got {k}"
        raise ValueError(msg)
    # L2-normalise rows so dot products are cosine similarities (1e-8 avoids /0).
    features = features / (np_mod.linalg.norm(features, axis=1, keepdims=True) + 1e-8)
    n_rows, dim = features.shape
    idxs = rng.choice(n_rows, size=min(k, n_rows), replace=False)
    centroids = features[idxs]
    if centroids.shape[0] < k:
        # Fewer embeddings than clusters: pad with random unit vectors so
        # exactly k centroids exist.
        pad = rng.standard_normal(
            size=(k - centroids.shape[0], dim),
        ).astype(np_mod.float32)
        pad /= np_mod.linalg.norm(pad, axis=1, keepdims=True) + 1e-8
        centroids = np_mod.concatenate([centroids, pad], axis=0)
    return _run_kmeans_iterations(np_mod, features, centroids, k, iters)
def _run_kmeans_iterations(
    np_mod: types.ModuleType,
    features: npt.NDArray[np.float32],
    centroids: npt.NDArray[np.float32],
    k: int,
    iters: int,
) -> npt.NDArray[np.int64]:
    """Run k-means assign/update rounds on unit vectors and return labels.

    Each round does two steps:
      1. Assign every row to the centroid with the highest dot product.
      2. Replace each centroid with the re-normalised mean of its members.
         A centroid with no members is kept unchanged.

    Iteration stops early once the updated centroids are within
    ``atol=1e-4`` of the previous ones.

    Args:
        np_mod: The numpy module (passed in because numpy is imported lazily).
        features: Row-normalised data matrix, one row per item.
        centroids: Initial centroids, one row per cluster.
        k: Number of clusters (rows of *centroids*).
        iters: Maximum number of rounds.

    Returns:
        The label of each row from the last assignment step. When *iters*
        < 1, rows are assigned to the initial *centroids*; previously
        ``None`` was returned in that case.
    """
    labels: npt.NDArray[np.int64] | None = None
    for _ in range(iters):
        labels = (features @ centroids.T).argmax(axis=1)
        new_c = np_mod.zeros_like(centroids)
        for j in range(k):
            members = features[labels == j]
            if members.shape[0] == 0:
                # Empty cluster: keep the old centroid (a mean over no rows
                # would be NaN).
                new_c[j] = centroids[j]
            else:
                v = members.mean(axis=0)
                v /= np_mod.linalg.norm(v) + 1e-8
                new_c[j] = v
        if np_mod.allclose(new_c, centroids, atol=1e-4):
            break
        centroids = new_c
    if labels is None:
        # No round ran (iters < 1); assign against the initial centroids so
        # the result is always an array.
        labels = (features @ centroids.T).argmax(axis=1)
    return labels
def _ffmpeg_transcode_to_wav16_mono(
src_path: str,
) -> str | None:
"""Transcode input to a temporary 16k mono WAV.
Returns its path, or None if ffmpeg is unavailable.
"""
ffmpeg_bin = shutil.which("ffmpeg")
if ffmpeg_bin is None:
return None
with tempfile.NamedTemporaryFile(
prefix="fw_diar_",
suffix=".wav",
delete=False,
) as tmp:
tmp_path = tmp.name
cmd = [
ffmpeg_bin,
"-y",
"-v",
"error",
"-i",
src_path,
"-ac",
"1",
"-ar",
"16000",
"-f",
"wav",
tmp_path,
]
try:
subprocess.run(
cmd,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
except (OSError, subprocess.CalledProcessError):
with contextlib.suppress(OSError):
Path(tmp_path).unlink()
return None
else:
return tmp_path
def _cleanup_temp(path: str | None) -> None:
"""Remove a temporary file if it exists."""
if path is not None:
with contextlib.suppress(OSError):
Path(path).unlink()
def _load_audio(
    audio_path: str,
) -> tuple[Any, int, str | None] | None:
    """Read *audio_path* as float32 samples, transcoding with ffmpeg if needed.

    The file is first read directly with ``soundfile``. If that fails, it is
    converted into a temporary 16 kHz mono WAV via
    ``_ffmpeg_transcode_to_wav16_mono``, and that file is read instead.

    Returns:
        A tuple ``(wav, sample_rate, temp_path)``. ``temp_path`` is the
        transcoded file, which the caller must delete; it is ``None`` when
        the original file was read directly.

        Returns ``None`` when ``soundfile`` is missing or the audio cannot
        be read either way. Any temporary file is removed in that case.
    """
    sf = _try_import("soundfile")
    if sf is None:
        return None
    # soundfile reports unreadable or unsupported files with its own
    # exceptions, which subclass RuntimeError (e.g. ``LibsndfileError``)
    # rather than OSError. Catching only OSError would skip the ffmpeg
    # fallback for exactly the formats it exists to handle.
    read_errors = (OSError, RuntimeError)
    try:
        wav, sr = sf.read(
            audio_path,
            dtype="float32",
            always_2d=False,
        )
    except read_errors as exc:
        alt = _ffmpeg_transcode_to_wav16_mono(audio_path)
        if alt is None:
            logger.warning(
                "Could not read audio for diarization and no ffmpeg fallback: %s",
                exc,
            )
            return None
        try:
            wav, sr = sf.read(
                alt,
                dtype="float32",
                always_2d=False,
            )
        except read_errors as exc2:
            logger.warning(
                "Could not read transcoded audio: %s",
                exc2,
            )
            _cleanup_temp(alt)
            return None
        else:
            return wav, sr, alt
    else:
        return wav, sr, None
def _load_speaker_classifier(
    temp_to_cleanup: str | None,
) -> object | None:
    """Load SpeechBrain's ECAPA speaker-embedding classifier.

    The pretrained ``speechbrain/spkrec-ecapa-voxceleb`` model is loaded on
    CPU, with ``~/.cache/speechbrain_ecapa`` as its save directory.

    Args:
        temp_to_cleanup: Caller-owned temporary file to delete when no
            classifier can be returned (the caller gives up in that case).
            Pass ``None`` when there is nothing to delete.

    Returns:
        The classifier object. Returns ``None`` when
        ``speechbrain.inference`` cannot be imported or when loading the
        model raises ``OSError``/``RuntimeError``.
    """
    sb_inf = _try_import("speechbrain.inference")
    if sb_inf is None:
        # Report the missing dependency (it was silent before). The temp file
        # must be released on this failure path too, or it is leaked.
        logger.warning(
            "speechbrain.inference is unavailable; skipping speaker labels.",
        )
        _cleanup_temp(temp_to_cleanup)
        return None
    try:
        cache_dir = Path.home() / ".cache" / "speechbrain_ecapa"
        classifier = sb_inf.EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": "cpu"},
            savedir=str(cache_dir),
        )
    except (OSError, RuntimeError) as exc:
        logger.warning(
            "Could not load speaker embedding model: %s",
            exc,
        )
        _cleanup_temp(temp_to_cleanup)
        return None
    else:
        return classifier
def _extract_embeddings(
    segments: list[Any],
    wav16: npt.NDArray[np.float32],
    classifier: Any,
    torch_mod: types.ModuleType,
) -> list[Any]:
    """Compute one speaker embedding per transcript segment.

    Each segment's ``start``/``end`` attributes are read as seconds at
    ``_SAMPLE_RATE_16K``. A missing or falsy ``start`` becomes 0.0; a missing
    or falsy ``end`` becomes ``start``. The bounds are then converted to a
    sample window of *wav16*:
      - the window is widened by 50 ms on each side;
      - it is then grown to at least ``_MIN_SAMPLES_DIAR`` samples where the
        audio is long enough, extending backwards when it hits the end of
        the audio.

    Args:
        segments: Objects exposing ``start``/``end`` attributes.
        wav16: Mono samples at ``_SAMPLE_RATE_16K``.
        classifier: Object exposing ``encode_batch`` (e.g. the SpeechBrain
            ECAPA classifier).
        torch_mod: The ``torch`` module (imported lazily by the caller).

    Returns:
        One float32 embedding per segment, in segment order.
    """
    n_samples = len(wav16)
    pad = int(0.05 * _SAMPLE_RATE_16K)
    embs: list[Any] = []
    # Inference only: disable autograd once for the whole loop.
    with torch_mod.no_grad():
        for seg in segments:
            s = float(getattr(seg, "start", 0.0) or 0.0)
            e = float(getattr(seg, "end", s) or s)
            if e <= s:
                # Zero-length or inverted segment: give it a nominal 0.2 s span.
                e = s + 0.2
            i0 = max(0, int(s * _SAMPLE_RATE_16K) - pad)
            i1 = min(n_samples, int(e * _SAMPLE_RATE_16K) + pad)
            if i1 - i0 < _MIN_SAMPLES_DIAR:
                # Too short for a stable embedding. Grow forwards first, then
                # backwards, so a segment at (or past) the end of the audio
                # still gets a full-length window instead of an empty slice.
                i1 = min(n_samples, i0 + _MIN_SAMPLES_DIAR)
                i0 = max(0, i1 - _MIN_SAMPLES_DIAR)
            seg_wav = torch_mod.tensor(wav16[i0:i1]).unsqueeze(0)
            emb = classifier.encode_batch(seg_wav).squeeze(0).squeeze(0).cpu().numpy()
            embs.append(emb.astype("float32"))
    return embs
def diarize_segments(
    audio_path: str,
    segments: list[Any],
    num_speakers: int = 2,
) -> list[int] | None:
    """Assign a speaker label to each transcript segment.

    Steps:
      1. Load the audio, downmix it to mono and resample it to 16 kHz.
      2. Embed each segment with the ECAPA speaker model.
      3. Cluster the embeddings with cosine k-means.

    Args:
        audio_path: Path of the audio/media file the segments came from.
        segments: Objects with ``start``/``end`` attributes in seconds.
        num_speakers: Number of clusters; values below 1 are treated as 1.

    Returns:
        One integer label per segment, in segment order. Returns ``None``
        when a dependency is missing, the audio cannot be read or is empty,
        or no embeddings were produced (e.g. no segments).
    """
    torch_mod = _try_import("torch")
    if torch_mod is None:
        logger.warning(
            "Diarization dependencies missing; skipping speaker labels.",
        )
        return None
    audio_result = _load_audio(audio_path)
    if audio_result is None:
        return None
    wav, sr, temp_to_cleanup = audio_result
    # The samples are fully in memory now, so any transcoded temp file can be
    # deleted immediately. This guarantees it is not leaked by any early
    # return or exception below.
    _cleanup_temp(temp_to_cleanup)
    if wav.ndim == _NDIM_2D:
        # Downmix to mono by averaging axis 1 (the channel axis of
        # soundfile's (frames, channels) layout).
        wav = wav.mean(axis=1)
    if wav.size == 0:
        logger.warning("Audio is empty; skipping speaker labels.")
        return None
    wav16 = _resample_linear(wav, sr, _SAMPLE_RATE_16K)
    # The temp file is already gone, so the loader has nothing to clean up.
    classifier = _load_speaker_classifier(None)
    if classifier is None:
        return None
    embs = _extract_embeddings(segments, wav16, classifier, torch_mod)
    if not embs:
        return None
    labels = _kmeans_cosine(embs, k=max(1, int(num_speakers)))
    return labels.tolist()