diff --git a/linux_configuration/scripts/digital_wellbeing/focus_mode_daemon.py b/linux_configuration/scripts/digital_wellbeing/focus_mode_daemon.py index 784bbc2..5d39b77 100755 --- a/linux_configuration/scripts/digital_wellbeing/focus_mode_daemon.py +++ b/linux_configuration/scripts/digital_wellbeing/focus_mode_daemon.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Focus Mode Daemon - Steam/Browser Mutual Exclusion +"""Focus Mode Daemon - Steam/Browser Mutual Exclusion. This daemon monitors running processes and enforces mutual exclusion between Steam (gaming) and web browsers. Whichever starts first "wins" and the other @@ -8,17 +8,27 @@ category is blocked/killed. Run as a systemd user service for continuous monitoring. """ -from datetime import datetime -import os +from __future__ import annotations + +import contextlib +from datetime import datetime, timezone +import logging from pathlib import Path +import shutil import signal import subprocess import sys import time +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import FrameType + +logger = logging.getLogger(__name__) # Configuration STATE_DIR = ( - Path(os.environ.get("XDG_STATE_HOME", Path.home() / ".local/state")) / "focus-mode" + Path.home() / ".local" / "state" / "focus-mode" ) LOG_FILE = STATE_DIR / "focus-mode.log" POLL_INTERVAL = 2 # seconds between process checks @@ -75,36 +85,44 @@ IGNORE_PATTERNS = frozenset( def log(message: str) -> None: """Log message with timestamp.""" - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + timestamp = datetime.now(tz=timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S" + ) log_line = f"{timestamp} - {message}" - print(log_line) - try: + logger.info("%s", log_line) + with contextlib.suppress(OSError): STATE_DIR.mkdir(parents=True, exist_ok=True) - with open(LOG_FILE, "a") as f: + with LOG_FILE.open("a") as f: f.write(log_line + "\n") - except Exception: - pass -def notify(title: str, message: str, urgency: str = "normal") -> None: +def notify( + title: str, message: str, urgency: str = "normal" +) -> None: """Send desktop notification.""" - try: + notify_send = shutil.which("notify-send") + if notify_send is None: + return + with contextlib.suppress( + OSError, subprocess.SubprocessError + ): subprocess.run( - ["notify-send", "-u", urgency, title, message], + [notify_send, "-u", urgency, title, message], capture_output=True, timeout=5, check=False, ) - except Exception: - pass def get_running_processes() -> set[str]: """Get set of currently running process names.""" - processes = set() + processes: set[str] = set() + ps_bin = shutil.which("ps") + if ps_bin is None: + return processes try: result = subprocess.run( - ["ps", "-eo", "comm="], + [ps_bin, "-eo", "comm="], capture_output=True, text=True, timeout=10, @@ -115,18 +133,16 @@ def get_running_processes() -> set[str]: proc_name = line.strip().lower() if proc_name: processes.add(proc_name) - except Exception as e: - log(f"Error getting processes: {e}") + except (OSError, subprocess.SubprocessError) as exc: + log(f"Error getting processes: {exc}") return processes def is_steam_running(processes: set[str]) -> bool: """Check if Steam or any Steam game is running.""" for proc in processes: - # Check for Steam main processes if proc in STEAM_PATTERNS: return True - # Check for Steam games (have steam_app_ prefix) if proc.startswith(STEAM_GAME_PREFIX): return True return False @@ -135,129 +151,181 @@ def is_steam_running(processes: set[str]) -> bool: def is_browser_running(processes: set[str]) -> bool: """Check if any browser is running.""" for proc in processes: - # Skip Electron apps and ignored patterns if proc in ELECTRON_IGNORE: continue if any(ign in proc for ign in IGNORE_PATTERNS): continue - # Use exact match to avoid false positives from Electron apps if proc in BROWSER_PATTERNS: return True return False +def _run_pkill( + pattern: str, *, force: bool = False +) -> None: + """Run pkill with the given pattern.""" + pkill_bin = shutil.which("pkill") + if pkill_bin is None: + return + cmd = [pkill_bin] + if force: + cmd.append("-9") + cmd.extend(["-f", pattern]) + with contextlib.suppress( + OSError, subprocess.SubprocessError + ): + subprocess.run( + cmd, capture_output=True, timeout=5, check=False + ) + + def kill_steam() -> None: """Kill all Steam-related processes.""" log("Killing Steam processes...") - notify("🎮 Gaming Blocked", "Browser is active. Closing Steam.", "critical") + notify( + "\U0001f3ae Gaming Blocked", + "Browser is active. Closing Steam.", + "critical", + ) - try: - # First try graceful shutdown - subprocess.run( - ["pkill", "-f", "steam"], capture_output=True, timeout=5, check=False - ) - time.sleep(2) - - # Force kill if still running - subprocess.run( - ["pkill", "-9", "-f", "steam"], capture_output=True, timeout=5, check=False - ) - except Exception as e: - log(f"Error killing Steam: {e}") + _run_pkill("steam") + time.sleep(2) + _run_pkill("steam", force=True) def kill_browsers() -> None: """Kill all browser processes.""" log("Killing browser processes...") - notify("🌐 Browsers Blocked", "Steam is active. Closing browsers.", "critical") + notify( + "\U0001f310 Browsers Blocked", + "Steam is active. Closing browsers.", + "critical", + ) for browser in BROWSER_PATTERNS: - try: - subprocess.run( - ["pkill", "-f", browser], capture_output=True, timeout=5, check=False - ) - except Exception: - pass + _run_pkill(browser) time.sleep(2) - # Force kill if still running for browser in BROWSER_PATTERNS: - try: - subprocess.run( - ["pkill", "-9", "-f", browser], - capture_output=True, - timeout=5, - check=False, - ) - except Exception: - pass + _run_pkill(browser, force=True) class FocusMode: """Tracks current focus mode and enforces mutual exclusion.""" - def __init__(self): - self.current_mode: str | None = None # "gaming" or "browsing" or None + def __init__(self) -> None: + """Initialize focus mode as inactive.""" + self.current_mode: str | None = None self.mode_start_time: datetime | None = None + def _enter_mode( + self, mode: str, msg: str, notification: str + ) -> None: + """Enter a new focus mode.""" + log(msg) + self.current_mode = mode + self.mode_start_time = datetime.now(tz=timezone.utc) + notify(*notification.split("|", 1)) + + def _handle_no_mode( + self, + *, + steam_running: bool, + browser_running: bool, + ) -> None: + """Handle updates when no mode is active.""" + if steam_running and browser_running: + log( + "Both Steam and browsers detected at " + "startup - entering GAMING mode" + ) + self.current_mode = "gaming" + self.mode_start_time = datetime.now( + tz=timezone.utc + ) + kill_browsers() + elif steam_running: + self._enter_mode( + "gaming", + "Steam detected - entering GAMING mode", + "\U0001f3ae Gaming Mode|" + "Steam detected. Browsers are now blocked.", + ) + elif browser_running: + self._enter_mode( + "browsing", + "Browser detected - entering BROWSING mode", + "\U0001f310 Browsing Mode|" + "Browser detected. Steam is now blocked.", + ) + + def _handle_gaming( + self, + *, + steam_running: bool, + browser_running: bool, + ) -> None: + """Handle updates in gaming mode.""" + if not steam_running: + log("Steam closed - exiting GAMING mode") + self.current_mode = None + self.mode_start_time = None + notify( + "\U0001f3ae Gaming Mode Ended", + "You can now use browsers.", + "normal", + ) + elif browser_running: + log( + "Browser detected during GAMING mode " + "- killing browsers" + ) + kill_browsers() + + def _handle_browsing( + self, + *, + steam_running: bool, + browser_running: bool, + ) -> None: + """Handle updates in browsing mode.""" + if not browser_running: + log("Browsers closed - exiting BROWSING mode") + self.current_mode = None + self.mode_start_time = None + notify( + "\U0001f310 Browsing Mode Ended", + "You can now use Steam.", + "normal", + ) + elif steam_running: + log( + "Steam detected during BROWSING mode " + "- killing Steam" + ) + kill_steam() + def update(self, processes: set[str]) -> None: """Update focus mode based on running processes.""" steam_running = is_steam_running(processes) browser_running = is_browser_running(processes) if self.current_mode is None: - # No mode set yet - first to start wins - if steam_running and browser_running: - # Both running at startup - prefer gaming mode (close browsers) - log( - "Both Steam and browsers detected at startup - entering GAMING mode" - ) - self.current_mode = "gaming" - self.mode_start_time = datetime.now() - kill_browsers() - elif steam_running: - log("Steam detected - entering GAMING mode") - self.current_mode = "gaming" - self.mode_start_time = datetime.now() - notify( - "🎮 Gaming Mode", - "Steam detected. Browsers are now blocked.", - "normal", - ) - elif browser_running: - log("Browser detected - entering BROWSING mode") - self.current_mode = "browsing" - self.mode_start_time = datetime.now() - notify( - "🌐 Browsing Mode", - "Browser detected. Steam is now blocked.", - "normal", - ) - + self._handle_no_mode( + steam_running=steam_running, + browser_running=browser_running, + ) elif self.current_mode == "gaming": - if not steam_running: - # Steam closed - exit gaming mode - log("Steam closed - exiting GAMING mode") - self.current_mode = None - self.mode_start_time = None - notify("🎮 Gaming Mode Ended", "You can now use browsers.", "normal") - elif browser_running: - # Browser started while in gaming mode - kill it - log("Browser detected during GAMING mode - killing browsers") - kill_browsers() - + self._handle_gaming( + steam_running=steam_running, + browser_running=browser_running, + ) elif self.current_mode == "browsing": - if not browser_running: - # Browsers closed - exit browsing mode - log("Browsers closed - exiting BROWSING mode") - self.current_mode = None - self.mode_start_time = None - notify("🌐 Browsing Mode Ended", "You can now use Steam.", "normal") - elif steam_running: - # Steam started while in browsing mode - kill it - log("Steam detected during BROWSING mode - killing Steam") - kill_steam() + self._handle_browsing( + steam_running=steam_running, + browser_running=browser_running, + ) def get_status(self) -> str: """Get current status string.""" @@ -266,33 +334,47 @@ class FocusMode: duration = "" if self.mode_start_time: - elapsed = datetime.now() - self.mode_start_time + elapsed = ( + datetime.now(tz=timezone.utc) + - self.mode_start_time + ) minutes = int(elapsed.total_seconds() // 60) duration = f" (active for {minutes}m)" if self.current_mode == "gaming": - return f"🎮 GAMING mode{duration} - browsers blocked" - return f"🌐 BROWSING mode{duration} - Steam blocked" + return ( + f"\U0001f3ae GAMING mode{duration}" + " - browsers blocked" + ) + return ( + f"\U0001f310 BROWSING mode{duration}" + " - Steam blocked" + ) def write_status(focus: FocusMode) -> None: """Write current status to state file for external queries.""" - try: + with contextlib.suppress(OSError): STATE_DIR.mkdir(parents=True, exist_ok=True) status_file = STATE_DIR / "status" - with open(status_file, "w") as f: + with status_file.open("w") as f: f.write(focus.get_status() + "\n") - f.write(f"mode={focus.current_mode or 'none'}\n") - except Exception: - pass + f.write( + f"mode={focus.current_mode or 'none'}\n" + ) -def main(): - """Main daemon loop.""" +def main() -> None: + """Run the main daemon loop.""" + logging.basicConfig( + format="%(message)s", level=logging.INFO + ) log("Focus Mode Daemon starting...") - # Setup signal handlers - def handle_signal(signum, frame): + def handle_signal( + signum: int, _frame: FrameType | None + ) -> None: + """Handle termination signals.""" log(f"Received signal {signum} - shutting down") sys.exit(0) @@ -306,8 +388,11 @@ def main(): processes = get_running_processes() focus.update(processes) write_status(focus) - except Exception as e: - log(f"Error in main loop: {e}") + except ( + OSError, + subprocess.SubprocessError, + ) as exc: + log(f"Error in main loop: {exc}") time.sleep(POLL_INTERVAL) diff --git a/linux_configuration/scripts/misc/testsAndMisc-bash/tools/__init__.py b/linux_configuration/scripts/misc/testsAndMisc-bash/tools/__init__.py new file mode 100644 index 0000000..52e2d4e --- /dev/null +++ b/linux_configuration/scripts/misc/testsAndMisc-bash/tools/__init__.py @@ -0,0 +1 @@ +"""Transcription and helper tools for testsAndMisc bash scripts.""" diff --git a/linux_configuration/scripts/misc/testsAndMisc-bash/tools/transcribe_fw.py b/linux_configuration/scripts/misc/testsAndMisc-bash/tools/transcribe_fw.py index 54695ad..820168e 100755 --- a/linux_configuration/scripts/misc/testsAndMisc-bash/tools/transcribe_fw.py +++ b/linux_configuration/scripts/misc/testsAndMisc-bash/tools/transcribe_fw.py @@ -1,211 +1,333 @@ #!/usr/bin/env python3 +"""Transcribe audio with faster-whisper and write .txt and .srt.""" + +from __future__ import annotations + import argparse +import contextlib from datetime import timedelta +import importlib +import logging import os +from pathlib import Path import shutil import subprocess import sys +import tempfile import time +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import types + + import numpy as np + import numpy.typing as npt + +logger = logging.getLogger(__name__) + +# Constants +_BYTES_PER_KB = 1024 +_NDIM_2D = 2 +_SAMPLE_RATE_16K = 16000 +_MIN_SAMPLES_DIAR = 1600 +_PROGRESS_THROTTLE_SEC = 0.2 +_SECONDS_PER_DAY = 60 * 60 * 24 + +# Model name to HF repo mapping +_MODEL_MAP: dict[str, str] = { + "tiny": "Systran/faster-whisper-tiny", + "tiny.en": "Systran/faster-whisper-tiny.en", + "base": "Systran/faster-whisper-base", + "base.en": "Systran/faster-whisper-base.en", + "small": "Systran/faster-whisper-small", + "small.en": "Systran/faster-whisper-small.en", + "medium": "Systran/faster-whisper-medium", + "medium.en": "Systran/faster-whisper-medium.en", + "large-v1": "Systran/faster-whisper-large-v1", + "large-v2": "Systran/faster-whisper-large-v2", + "large-v3": "Systran/faster-whisper-large-v3", + "large": "Systran/faster-whisper-large-v3", + "distil-large-v2": "Systran/faster-distil-whisper-large-v2", + "distil-large-v3": "Systran/faster-distil-whisper-large-v3", + "distil-medium.en": "Systran/faster-distil-whisper-medium.en", + "distil-small.en": "Systran/faster-distil-whisper-small.en", +} + + +def _try_import(name: str) -> types.ModuleType | None: + """Attempt to import a module, returning None on failure.""" + try: + return importlib.import_module(name) + except ImportError: + return None def format_bytes(size: int) -> str: """Format bytes as human-readable string.""" + fsize = float(size) for unit in ["B", "KB", "MB", "GB"]: - if size < 1024: - return f"{size:.1f}{unit}" - size /= 1024 - return f"{size:.1f}TB" + if fsize < _BYTES_PER_KB: + return f"{fsize:.1f}{unit}" + fsize /= _BYTES_PER_KB + return f"{fsize:.1f}TB" -def download_model_with_progress(model_name: str) -> str: - """Download model files from HuggingFace with a visible progress bar. +def _check_cache( + repo_id: str, +) -> str | None: + """Check HF cache for an already-downloaded model.""" + hh = _try_import("huggingface_hub") + if hh is None: + return None + cache_path = hh.try_to_load_from_cache( + repo_id, "model.bin" + ) + if cache_path is not None: + parent = str(Path(cache_path).parent) + logger.info( + "Model already cached, loading from: %s", + parent, + ) + return parent + return None + + +def _download_files( + repo_id: str, + required_files: list[str], +) -> str: + """Download required model files from HuggingFace.""" + hh = _try_import("huggingface_hub") + if hh is None: + msg = "huggingface_hub not available" + raise RuntimeError(msg) + + logger.info( + "Downloading model files from %s...", + repo_id, + ) + logger.info( + "This may take several minutes for large " + "models (~3GB for large-v3)", + ) + + _log_total_download_size(repo_id, required_files) + + downloaded = 0 + model_dir = "" + start_time = time.time() + + for filename in required_files: + file_start = time.time() + logger.info("DOWNLOAD %s...", filename) + try: + local_path = hh.hf_hub_download( + repo_id=repo_id, + filename=filename, + resume_download=True, + ) + elapsed = time.time() - file_start + lp = Path(local_path) + file_size = ( + lp.stat().st_size + if lp.exists() + else 0 + ) + logger.info( + "done (%s, %.1fs)", + format_bytes(file_size), + elapsed, + ) + downloaded += 1 + if downloaded == 1: + model_dir = str(lp.parent) + except OSError: + logger.info("not found (optional)") + except RuntimeError as exc: + logger.info("error: %s", exc) + + total_time = time.time() - start_time + logger.info("Download complete in %.1fs", total_time) + return model_dir + + +def _log_total_download_size( + repo_id: str, required_files: list[str] +) -> None: + """Log total download size if available.""" + hh = _try_import("huggingface_hub") + if hh is None: + return + with contextlib.suppress(OSError, RuntimeError): + fs = hh.HfFileSystem() + files_info = fs.ls(repo_id, detail=True) + total_size = sum( + f.get("size", 0) + for f in files_info + if f.get("name", "").split("/")[-1] + in required_files + ) + logger.info( + "Total download size: ~%s", + format_bytes(total_size), + ) + + +def download_model_with_progress( + model_name: str, +) -> str: + """Download model files from HuggingFace with progress. Returns the local path to the downloaded model. """ - try: - from huggingface_hub import hf_hub_download - from huggingface_hub.utils import EntryNotFoundError - except ImportError: - print( - "[WARN] huggingface_hub not available, falling back to default download", - file=sys.stderr, + hh = _try_import("huggingface_hub") + if hh is None: + logger.warning( + "huggingface_hub not available, " + "falling back to default download", ) return model_name - # Map common model names to HF repo IDs - model_map = { - "tiny": "Systran/faster-whisper-tiny", - "tiny.en": "Systran/faster-whisper-tiny.en", - "base": "Systran/faster-whisper-base", - "base.en": "Systran/faster-whisper-base.en", - "small": "Systran/faster-whisper-small", - "small.en": "Systran/faster-whisper-small.en", - "medium": "Systran/faster-whisper-medium", - "medium.en": "Systran/faster-whisper-medium.en", - "large-v1": "Systran/faster-whisper-large-v1", - "large-v2": "Systran/faster-whisper-large-v2", - "large-v3": "Systran/faster-whisper-large-v3", - "large": "Systran/faster-whisper-large-v3", - "distil-large-v2": "Systran/faster-distil-whisper-large-v2", - "distil-large-v3": "Systran/faster-distil-whisper-large-v3", - "distil-medium.en": "Systran/faster-distil-whisper-medium.en", - "distil-small.en": "Systran/faster-distil-whisper-small.en", - } + repo_id = _MODEL_MAP.get(model_name, model_name) - repo_id = model_map.get(model_name, model_name) - - # Check if it looks like a repo ID - if "/" not in repo_id and model_name not in model_map: - # Assume it's a Systran model + if "/" not in repo_id and model_name not in _MODEL_MAP: repo_id = f"Systran/faster-whisper-{model_name}" - print(f"[INFO] Checking model: {repo_id}", flush=True) + logger.info("Checking model: %s", repo_id) - # Files we need to download (model.bin is the large one) - required_files = ["config.json", "model.bin", "tokenizer.json", "vocabulary.txt"] + required_files = [ + "config.json", + "model.bin", + "tokenizer.json", + "vocabulary.txt", + ] try: - # Use snapshot_download which handles caching and shows what's happening - # First, let's check if model.bin needs downloading by checking cache - from huggingface_hub import HfFileSystem, try_to_load_from_cache - - cache_path = try_to_load_from_cache(repo_id, "model.bin") - if cache_path is not None: - print( - f"[INFO] Model already cached, loading from: {os.path.dirname(cache_path)}", - flush=True, - ) - # Return the directory containing the cached files - return os.path.dirname(cache_path) - - # Model not cached, need to download - print(f"[INFO] Downloading model files from {repo_id}...", flush=True) - print( - "[INFO] This may take several minutes for large models (~3GB for large-v3)", - flush=True, - ) - - # Get file sizes to show progress - try: - fs = HfFileSystem() - files_info = fs.ls(repo_id, detail=True) - total_size = sum( - f.get("size", 0) - for f in files_info - if f.get("name", "").split("/")[-1] in required_files - ) - print( - f"[INFO] Total download size: ~{format_bytes(total_size)}", flush=True - ) - except Exception: - pass # Size info is optional - - # Download with progress - downloaded = 0 - start_time = time.time() - - for filename in required_files: - file_start = time.time() - print(f"[DOWNLOAD] {filename}...", end=" ", flush=True) - try: - local_path = hf_hub_download( - repo_id=repo_id, - filename=filename, - resume_download=True, - ) - elapsed = time.time() - file_start - file_size = ( - os.path.getsize(local_path) if os.path.exists(local_path) else 0 - ) - print(f"done ({format_bytes(file_size)}, {elapsed:.1f}s)", flush=True) - downloaded += 1 - - # Return directory on first successful download - if downloaded == 1: - model_dir = os.path.dirname(local_path) - except EntryNotFoundError: - print("not found (optional)", flush=True) - except Exception as e: - print(f"error: {e}", flush=True) - - total_time = time.time() - start_time - print(f"[INFO] Download complete in {total_time:.1f}s", flush=True) - - return model_dir - - except Exception as e: - print( - f"[WARN] Custom download failed ({e}), falling back to default", - file=sys.stderr, + cached = _check_cache(repo_id) + if cached is not None: + return cached + return _download_files(repo_id, required_files) + except (OSError, RuntimeError) as exc: + logger.warning( + "Custom download failed (%s), " + "falling back to default", + exc, ) return model_name def format_timestamp(seconds: float) -> str: + """Format seconds as SRT timestamp HH:MM:SS,mmm.""" td = timedelta(seconds=seconds) - # Ensure SRT format HH:MM:SS,mmm total_seconds = int(td.total_seconds()) hours = total_seconds // 3600 minutes = (total_seconds % 3600) // 60 secs = total_seconds % 60 millis = int((seconds - int(seconds)) * 1000) - return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}" + return ( + f"{hours:02d}:{minutes:02d}:" + f"{secs:02d},{millis:03d}" + ) -def write_srt(segments, srt_path: str): - with open(srt_path, "w", encoding="utf-8") as f: +def write_srt( + segments: list[Any], srt_path: str +) -> None: + """Write segments to an SRT subtitle file.""" + with Path(srt_path).open( + "w", encoding="utf-8" + ) as f: for i, seg in enumerate(segments, start=1): start = format_timestamp(seg.start) end = format_timestamp(seg.end) text = (seg.text or "").strip() if not text: continue - f.write(f"{i}\n{start} --> {end}\n{text}\n\n") + f.write( + f"{i}\n{start} --> {end}\n{text}\n\n" + ) -def write_txt(segments, txt_path: str): - with open(txt_path, "w", encoding="utf-8") as f: +def write_txt( + segments: list[Any], txt_path: str +) -> None: + """Write segments as plain text, one per line.""" + with Path(txt_path).open( + "w", encoding="utf-8" + ) as f: for seg in segments: text = (seg.text or "").strip() if text: f.write(text + "\n") -def write_srt_with_speakers(segments, labels: list[int], path: str): - with open(path, "w", encoding="utf-8") as f: - for i, (seg, lab) in enumerate(zip(segments, labels, strict=False), start=1): +def write_srt_with_speakers( + segments: list[Any], + labels: list[int], + path: str, +) -> None: + """Write SRT subtitles with speaker labels.""" + with Path(path).open("w", encoding="utf-8") as f: + for i, (seg, lab) in enumerate( + zip(segments, labels, strict=False), + start=1, + ): text = (seg.text or "").strip() if not text: continue - spk = f"SPK{lab+1}" + spk = f"SPK{lab + 1}" + start_ts = format_timestamp(seg.start) + end_ts = format_timestamp(seg.end) f.write( - f"{i}\n{format_timestamp(seg.start)} --> {format_timestamp(seg.end)}\n[{spk}] {text}\n\n" + f"{i}\n{start_ts} --> {end_ts}\n" + f"[{spk}] {text}\n\n" ) -def write_txt_with_speakers(segments, labels: list[int], path: str): - with open(path, "w", encoding="utf-8") as f: - for seg, lab in zip(segments, labels, strict=False): +def write_txt_with_speakers( + segments: list[Any], + labels: list[int], + path: str, +) -> None: + """Write plain text with speaker labels.""" + with Path(path).open("w", encoding="utf-8") as f: + for seg, lab in zip( + segments, labels, strict=False + ): text = (seg.text or "").strip() if text: - spk = f"SPK{lab+1}" + spk = f"SPK{lab + 1}" f.write(f"[{spk}] {text}\n") -def write_rttm(segments, labels: list[int], path: str, file_id: str = "audio"): - # RTTM format: SPEAKER 1 - with open(path, "w", encoding="utf-8") as f: - for seg, lab in zip(segments, labels, strict=False): - start = float(getattr(seg, "start", 0.0) or 0.0) - end = float(getattr(seg, "end", start) or start) +def write_rttm( + segments: list[Any], + labels: list[int], + path: str, + file_id: str = "audio", +) -> None: + """Write RTTM speaker diarization output.""" + with Path(path).open("w", encoding="utf-8") as f: + for seg, lab in zip( + segments, labels, strict=False + ): + start = float( + getattr(seg, "start", 0.0) or 0.0 + ) + end = float( + getattr(seg, "end", start) or start + ) dur = max(0.0, end - start) - name = f"SPK{lab+1}" + name = f"SPK{lab + 1}" f.write( - f"SPEAKER {file_id} 1 {start:.3f} {dur:.3f} {name} \n" + f"SPEAKER {file_id} 1 " + f"{start:.3f} {dur:.3f} " + f" {name} \n" ) def hhmmss(seconds: float) -> str: + """Format seconds as HH:MM:SS string.""" seconds = max(0.0, float(seconds)) total_seconds = int(seconds) h = total_seconds // 3600 @@ -214,104 +336,190 @@ def hhmmss(seconds: float) -> str: return f"{h:02d}:{m:02d}:{s:02d}" -def get_media_duration(path: str) -> float | None: - """Try to get media duration in seconds using ffmpeg-python or ffprobe. - Returns None if unavailable. - """ - # Try ffmpeg-python first (if installed) which uses ffprobe under the hood +def _probe_with_ffmpeg_python( + path: str, +) -> float | None: + """Try ffmpeg-python to get duration.""" + ffmpeg_mod = _try_import("ffmpeg") + if ffmpeg_mod is None: + return None try: - import ffmpeg # type: ignore - - probe = ffmpeg.probe(path) + probe = ffmpeg_mod.probe(path) fmt = probe.get("format", {}) if "duration" in fmt: - return float(fmt["duration"]) # type: ignore - except Exception: + return float(fmt["duration"]) + except (OSError, RuntimeError): pass - - # Fallback: call ffprobe directly if available - if shutil.which("ffprobe"): - try: - out = subprocess.check_output( - [ - "ffprobe", - "-v", - "error", - "-show_entries", - "format=duration", - "-of", - "default=noprint_wrappers=1:nokey=1", - path, - ], - stderr=subprocess.DEVNULL, - ) - return float(out.decode().strip()) - except Exception: - return None return None -def _resample_linear(x, src_sr: int, tgt_sr: int): - import numpy as np +def _probe_with_ffprobe(path: str) -> float | None: + """Try ffprobe CLI to get duration.""" + ffprobe_bin = shutil.which("ffprobe") + if ffprobe_bin is None: + return None + try: + out = subprocess.check_output( + [ + ffprobe_bin, + "-v", + "error", + "-show_entries", + "format=duration", + "-of", + "default=" + "noprint_wrappers=1:nokey=1", + path, + ], + stderr=subprocess.DEVNULL, + ) + return float(out.decode().strip()) + except ( + OSError, + subprocess.CalledProcessError, + ValueError, + ): + return None + + +def get_media_duration(path: str) -> float | None: + """Try to get media duration in seconds. + + Returns None if unavailable. + """ + result = _probe_with_ffmpeg_python(path) + if result is not None: + return result + return _probe_with_ffprobe(path) + + +def _resample_linear( + x: npt.NDArray[np.float32], + src_sr: int, + tgt_sr: int, +) -> npt.NDArray[np.float32]: + """Linearly resample 1-D audio array.""" + np_mod = _try_import("numpy") + if np_mod is None: + msg = "numpy is required for resampling" + raise RuntimeError(msg) if src_sr == tgt_sr: return x ratio = float(tgt_sr) / float(src_sr) - n_out = max(1, int(round(x.shape[-1] * ratio))) - xp = np.linspace(0.0, 1.0, num=x.shape[-1], endpoint=False) - xq = np.linspace(0.0, 1.0, num=n_out, endpoint=False) - y = np.interp(xq, xp, x.astype(np.float32)) - return y.astype(np.float32) + n_out = max(1, round(x.shape[-1] * ratio)) + xp = np_mod.linspace( + 0.0, 1.0, num=x.shape[-1], endpoint=False + ) + xq = np_mod.linspace( + 0.0, 1.0, num=n_out, endpoint=False + ) + y = np_mod.interp( + xq, xp, x.astype(np_mod.float32) + ) + return y.astype(np_mod.float32) -def _kmeans_cosine(embs, k: int, iters: int = 50, seed: int = 0): - import numpy as np +def _kmeans_cosine( + embs: list[Any], + k: int, + iters: int = 50, + seed: int = 0, +) -> npt.NDArray[np.int64]: + """Cluster embeddings with cosine-similarity k-means.""" + np_mod = _try_import("numpy") + if np_mod is None: + msg = "numpy is required for clustering" + raise RuntimeError(msg) - rng = np.random.default_rng(seed) - X = np.asarray(embs, dtype=np.float32) - if X.ndim != 2 or X.shape[0] == 0: - return np.zeros((0,), dtype=np.int64) - # Normalize - X = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-8) - # Init centroids as random samples - idxs = rng.choice(X.shape[0], size=min(k, X.shape[0]), replace=False) - C = X[idxs] - # If fewer samples than k, pad with random - if C.shape[0] < k: - pad = rng.standard_normal(size=(k - C.shape[0], X.shape[1])).astype(np.float32) - pad /= np.linalg.norm(pad, axis=1, keepdims=True) + 1e-8 - C = np.concatenate([C, pad], axis=0) + rng = np_mod.random.default_rng(seed) + features = np_mod.asarray(embs, dtype=np_mod.float32) + if ( + features.ndim != _NDIM_2D + or features.shape[0] == 0 + ): + return np_mod.zeros((0,), dtype=np_mod.int64) + features = features / ( + np_mod.linalg.norm( + features, axis=1, keepdims=True + ) + + 1e-8 + ) + idxs = rng.choice( + features.shape[0], + size=min(k, features.shape[0]), + replace=False, + ) + centroids = features[idxs] + if centroids.shape[0] < k: + pad = rng.standard_normal( + size=( + k - centroids.shape[0], + features.shape[1], + ) + ).astype(np_mod.float32) + pad /= ( + np_mod.linalg.norm( + pad, axis=1, keepdims=True + ) + + 1e-8 + ) + centroids = np_mod.concatenate( + [centroids, pad], axis=0 + ) + return _run_kmeans_iterations( + np_mod, features, centroids, k, iters + ) + + +def _run_kmeans_iterations( + np_mod: object, + features: object, + centroids: object, + k: int, + iters: int, +) -> object: + """Run k-means iteration loop and return labels.""" + labels: object = None for _ in range(iters): - # Assign by cosine similarity (maximize dot product) - sims = X @ C.T # (n, k) + sims = features @ centroids.T labels = sims.argmax(axis=1) - newC = np.zeros_like(C) + new_c = np_mod.zeros_like(centroids) for j in range(k): - sel = X[labels == j] + sel = features[labels == j] if sel.shape[0] == 0: - newC[j] = C[j] + new_c[j] = centroids[j] else: v = sel.mean(axis=0) - v /= np.linalg.norm(v) + 1e-8 - newC[j] = v - if np.allclose(newC, C, atol=1e-4): + v /= np_mod.linalg.norm(v) + 1e-8 + new_c[j] = v + if np_mod.allclose( + new_c, centroids, atol=1e-4 + ): break - C = newC + centroids = new_c return labels -def _ffmpeg_transcode_to_wav16_mono(src_path: str) -> str | None: - """If ffmpeg is available, transcode input to a temporary 16k mono WAV and return its path.""" - if not shutil.which("ffmpeg"): - return None - import tempfile +def _ffmpeg_transcode_to_wav16_mono( + src_path: str, +) -> str | None: + """Transcode input to a temporary 16k mono WAV. + + Returns its path, or None if ffmpeg is unavailable. + """ + ffmpeg_bin = shutil.which("ffmpeg") + if ffmpeg_bin is None: + return None + with tempfile.NamedTemporaryFile( + prefix="fw_diar_", + suffix=".wav", + delete=False, + ) as tmp: + tmp_path = tmp.name - tmp = tempfile.NamedTemporaryFile(prefix="fw_diar_", suffix=".wav", delete=False) - tmp_path = tmp.name - tmp.close() - # Run ffmpeg quietly cmd = [ - "ffmpeg", + ffmpeg_bin, "-y", "-v", "error", @@ -327,132 +535,207 @@ def _ffmpeg_transcode_to_wav16_mono(src_path: str) -> str | None: ] try: subprocess.run( - cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + cmd, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, ) + except (OSError, subprocess.CalledProcessError): + with contextlib.suppress(OSError): + Path(tmp_path).unlink() + return None + else: return tmp_path - except Exception: - try: - os.unlink(tmp_path) - except Exception: - pass - return None -def diarize_segments(audio_path: str, segments, num_speakers: int = 2) -> list | None: - """Simple diarization: compute speaker embeddings per segment and cluster with KMeans. - Returns a list of speaker labels aligned with segments, or None on failure. +def _cleanup_temp(path: str | None) -> None: + """Remove a temporary file if it exists.""" + if path is not None: + with contextlib.suppress(OSError): + Path(path).unlink() + + +def _load_audio( + audio_path: str, +) -> tuple[Any, int, str | None] | None: + """Load audio, with ffmpeg fallback. + + Returns (wav, sample_rate, temp_path) or None. """ - try: - import soundfile as sf - - # Use non-deprecated import path - from speechbrain.inference import EncoderClassifier - import torch - except Exception as e: - print( - f"[WARN] Diarization dependencies missing ({e}); skipping speaker labels.", - file=sys.stderr, - ) + sf = _try_import("soundfile") + if sf is None: return None - # Load audio - temp_to_cleanup: str | None = None try: - wav, sr = sf.read(audio_path, dtype="float32", always_2d=False) - except Exception as e: - # Try ffmpeg transcoding fallback - alt = _ffmpeg_transcode_to_wav16_mono(audio_path) + wav, sr = sf.read( + audio_path, + dtype="float32", + always_2d=False, + ) + except OSError as exc: + alt = _ffmpeg_transcode_to_wav16_mono( + audio_path + ) if alt is None: - print( - f"[WARN] Could not read audio for diarization and no ffmpeg fallback available: {e}", - file=sys.stderr, + logger.warning( + "Could not read audio for diarization " + "and no ffmpeg fallback: %s", + exc, ) return None try: - wav, sr = sf.read(alt, dtype="float32", always_2d=False) - temp_to_cleanup = alt - except Exception as e2: - print( - f"[WARN] Could not read transcoded audio for diarization: {e2}", - file=sys.stderr, + wav, sr = sf.read( + alt, + dtype="float32", + always_2d=False, ) - try: - os.unlink(alt) - except Exception: - pass + except OSError as exc2: + logger.warning( + "Could not read transcoded audio: %s", + exc2, + ) + _cleanup_temp(alt) return None - if wav.ndim == 2: # mixdown - wav = wav.mean(axis=1) - # Resample to 16k for ECAPA - wav16 = _resample_linear(wav, sr, 16000) + else: + return wav, sr, alt + else: + return wav, sr, None - # Load speaker embedding model (CPU is fine) + +def _load_speaker_classifier( + temp_to_cleanup: str | None, +) -> object | None: + """Load the ECAPA speaker embedding classifier.""" + sb_inf = _try_import("speechbrain.inference") + if sb_inf is None: + return None try: - classifier = EncoderClassifier.from_hparams( + cache_dir = ( + Path.home() / ".cache" / "speechbrain_ecapa" + ) + classifier = sb_inf.EncoderClassifier.from_hparams( source="speechbrain/spkrec-ecapa-voxceleb", run_opts={"device": "cpu"}, - savedir=os.path.join( - os.path.expanduser("~"), ".cache", "speechbrain_ecapa" - ), + savedir=str(cache_dir), ) - except Exception as e: - print(f"[WARN] Could not load speaker embedding model: {e}", file=sys.stderr) - if temp_to_cleanup: - try: - os.unlink(temp_to_cleanup) - except Exception: - pass + except (OSError, RuntimeError) as exc: + logger.warning( + "Could not load speaker embedding model: %s", + exc, + ) + _cleanup_temp(temp_to_cleanup) return None + else: + return classifier - embs = [] - # Extract embedding per segment window + +def _extract_embeddings( + segments: list[Any], + wav16: object, + classifier: object, + torch_mod: types.ModuleType, +) -> list[Any]: + """Extract speaker embeddings per segment.""" + embs: list[Any] = [] for seg in segments: s = float(getattr(seg, "start", 0.0) or 0.0) e = float(getattr(seg, "end", s) or s) if e <= s: - e = s + 0.2 # minimal window - # Convert to samples in 16k - i0 = int(s * 16000) - i1 = int(e * 16000) - # Add small margins to help very short segments - pad = int(0.05 * 16000) + e = s + 0.2 + i0 = int(s * _SAMPLE_RATE_16K) + i1 = int(e * _SAMPLE_RATE_16K) + pad = int(0.05 * _SAMPLE_RATE_16K) i0 = max(0, i0 - pad) i1 = min(len(wav16), i1 + pad) - if i1 - i0 < 1600: # <0.1s, too short; expand if possible - i1 = min(len(wav16), i0 + 1600) - segment_wav = torch.tensor(wav16[i0:i1]).unsqueeze(0) - with torch.no_grad(): + if i1 - i0 < _MIN_SAMPLES_DIAR: + i1 = min( + len(wav16), i0 + _MIN_SAMPLES_DIAR + ) + seg_wav = torch_mod.tensor( + wav16[i0:i1] + ).unsqueeze(0) + with torch_mod.no_grad(): emb = ( - classifier.encode_batch(segment_wav).squeeze(0).squeeze(0).cpu().numpy() + classifier.encode_batch(seg_wav) + .squeeze(0) + .squeeze(0) + .cpu() + .numpy() ) embs.append(emb.astype("float32")) + return embs + + +def diarize_segments( + audio_path: str, + segments: list[Any], + num_speakers: int = 2, +) -> list[int] | None: + """Compute speaker embeddings per segment and cluster. + + Returns speaker labels aligned with segments, + or None on failure. + """ + torch_mod = _try_import("torch") + if torch_mod is None: + logger.warning( + "Diarization dependencies missing; " + "skipping speaker labels.", + ) + return None + + audio_result = _load_audio(audio_path) + if audio_result is None: + return None + wav, sr, temp_to_cleanup = audio_result + + if wav.ndim == _NDIM_2D: + wav = wav.mean(axis=1) + wav16 = _resample_linear( + wav, sr, _SAMPLE_RATE_16K + ) + + classifier = _load_speaker_classifier( + temp_to_cleanup + ) + if classifier is None: + return None + + embs = _extract_embeddings( + segments, wav16, classifier, torch_mod + ) if len(embs) == 0: return None - # Cluster - labels = _kmeans_cosine(embs, k=max(1, int(num_speakers))) - if temp_to_cleanup: - try: - os.unlink(temp_to_cleanup) - except Exception: - pass + labels = _kmeans_cosine( + embs, k=max(1, int(num_speakers)) + ) + _cleanup_temp(temp_to_cleanup) return labels.tolist() -def main(): +def _parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" parser = argparse.ArgumentParser( - description="Transcribe audio with faster-whisper and write .txt and .srt" + description=( + "Transcribe audio with faster-whisper " + "and write .txt and .srt" + ), + ) + parser.add_argument( + "input", help="Path to audio/video file" ) - parser.add_argument("input", help="Path to audio/video file") parser.add_argument( "--model", - default=os.environ.get("FW_MODEL", "large-v3"), + default=os.environ.get( + "FW_MODEL", "large-v3" + ), help="Model size or path (default: large-v3)", ) parser.add_argument( "--language", default=None, - help="Language code (e.g., en). Leave None for auto-detect", + help="Language code (e.g., en). None=auto", ) parser.add_argument( "--device", @@ -464,145 +747,259 @@ def main(): "--compute-type", dest="compute_type", default=os.environ.get("FW_COMPUTE", "auto"), - help="Compute type (auto,int8,float16,float32,int8_float16,etc.)", + help="Compute type (auto,int8,float16,...)", ) parser.add_argument( - "--outdir", default=None, help="Output directory (default: next to input)" + "--outdir", + default=None, + help="Output dir (default: next to input)", ) parser.add_argument( - "--no-progress", action="store_true", help="Disable live progress output" + "--no-progress", + action="store_true", + help="Disable live progress output", ) parser.add_argument( - "--diarize", action="store_true", help="Enable speaker diarization (labels)" + "--diarize", + action="store_true", + help="Enable speaker diarization (labels)", ) parser.add_argument( "--num-speakers", type=int, - default=int(os.environ.get("FW_NUM_SPEAKERS", "2")), - help="Assumed number of speakers (default: 2)", + default=int( + os.environ.get("FW_NUM_SPEAKERS", "2") + ), + help="Number of speakers (default: 2)", ) - args = parser.parse_args() + return parser.parse_args() - try: - from faster_whisper import WhisperModel - except Exception as e: - print( - "[ERROR] faster-whisper is not installed in this environment.", - file=sys.stderr, - ) - print(str(e), file=sys.stderr) - return 2 - inp = os.path.abspath(args.input) - if not os.path.exists(inp): - print(f"[ERROR] Input file not found: {inp}", file=sys.stderr) - return 2 - - outdir = os.path.abspath(args.outdir or os.path.dirname(inp) or ".") - os.makedirs(outdir, exist_ok=True) - base = os.path.splitext(os.path.basename(inp))[0] - srt_path = os.path.join(outdir, base + ".srt") - txt_path = os.path.join(outdir, base + ".txt") - - # Device and compute_type heuristics +def _resolve_device_and_compute( + args: argparse.Namespace, +) -> tuple[str, str]: + """Resolve device and compute_type from args.""" device = args.device compute_type = args.compute_type if device == "auto": device = "cpu" if compute_type == "auto": - # Prefer accuracy over speed by default - compute_type = "float16" if device == "cuda" else "float32" + compute_type = ( + "float16" + if device == "cuda" + else "float32" + ) + return device, compute_type - print( - f"[INFO] Loading model='{args.model}', device='{device}', compute_type='{compute_type}'" + +def _run_progress_loop( + args: argparse.Namespace, + model: object, + inp: str, + total_duration: float | None, +) -> tuple[list[Any], object]: + """Transcribe with live progress output.""" + start_ts = time.time() + iter_segments, info = model.transcribe( + inp, language=args.language + ) + collected: list[Any] = [] + processed = 0.0 + last_prt = 0.0 + tty = sys.stderr.isatty() + + for seg in iter_segments: + collected.append(seg) + if getattr(seg, "end", None) is not None: + processed = max( + processed, float(seg.end) + ) + now = time.time() + if not args.no_progress and ( + tty + or (now - last_prt) + >= _PROGRESS_THROTTLE_SEC + ): + last_prt = now + line = _format_progress_line( + processed, + total_duration, + now, + start_ts, + ) + if tty: + logger.info("\r%s", line) + else: + logger.info("%s", line) + + if not args.no_progress and tty: + logger.info("") + + return collected, info + + +def _format_progress_line( + processed: float, + total_duration: float | None, + now: float, + start_ts: float, +) -> str: + """Format a progress line string.""" + if total_duration and total_duration > 0: + pct = max( + 0.0, + min( + 100.0, + (processed / total_duration) * 100.0, + ), + ) + elapsed = now - start_ts + line = ( + f"[PROGRESS] {hhmmss(processed)} / " + f"{hhmmss(total_duration)} " + f"({pct:5.1f}%)" + ) + if processed > 0: + rate = processed / max(1e-6, elapsed) + remaining = max( + 0.0, total_duration - processed + ) + eta = remaining / max(1e-6, rate) + if eta < _SECONDS_PER_DAY: + line += f" ETA ~{hhmmss(eta)}" + return line + return f"[PROGRESS] processed {hhmmss(processed)}" + + +def _write_diarized_outputs( + args: argparse.Namespace, + inp: str, + outdir: Path, + base: str, + collected: list[Any], +) -> None: + """Optionally diarize and write speaker outputs.""" + if not args.diarize: + return + labels = diarize_segments( + inp, + collected, + num_speakers=args.num_speakers, + ) + if labels is not None and len(labels) == len( + collected + ): + diar_srt = str(outdir / (base + ".diar.srt")) + diar_txt = str(outdir / (base + ".diar.txt")) + rttm_path = str(outdir / (base + ".rttm")) + write_srt_with_speakers( + collected, labels, diar_srt + ) + write_txt_with_speakers( + collected, labels, diar_txt + ) + write_rttm( + collected, + labels, + rttm_path, + file_id=base, + ) + logger.info("Wrote: %s", diar_txt) + logger.info("Wrote: %s", diar_srt) + logger.info("Wrote: %s", rttm_path) + else: + logger.warning( + "Diarization failed or returned " + "mismatched labels; writing plain.", + ) + + +def main() -> int: + """Run the main transcription pipeline.""" + logging.basicConfig( + level=logging.INFO, + format="%(message)s", ) - # Pre-download model files with explicit progress if not already cached - model_path = args.model - if not os.path.isdir(args.model): # Not a local path, need to download from HF - model_path = download_model_with_progress(args.model) + args = _parse_args() - # Show CTranslate2 conversion progress - import logging + fw = _try_import("faster_whisper") + if fw is None: + logger.error( + "faster-whisper is not installed " + "in this environment.", + ) + return 2 + + inp_path = Path(args.input).resolve() + if not inp_path.exists(): + logger.error("Input file not found: %s", inp_path) + return 2 + + inp = str(inp_path) + outdir = Path( + args.outdir or str(inp_path.parent) or "." + ).resolve() + outdir.mkdir(parents=True, exist_ok=True) + base = inp_path.stem + srt_path = str(outdir / (base + ".srt")) + txt_path = str(outdir / (base + ".txt")) + + device, compute_type = ( + _resolve_device_and_compute(args) + ) + + logger.info( + "Loading model='%s', device='%s', " + "compute_type='%s'", + args.model, + device, + compute_type, + ) + + model_path: str = args.model + if not Path(args.model).is_dir(): + model_path = download_model_with_progress( + args.model + ) - logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") ct2_logger = logging.getLogger("faster_whisper") ct2_logger.setLevel(logging.INFO) - print("[INFO] Initializing model...", flush=True) - model = WhisperModel(model_path, device=device, compute_type=compute_type) - print("[INFO] Model loaded successfully.", flush=True) + logger.info("Initializing model...") + model = fw.WhisperModel( + model_path, + device=device, + compute_type=compute_type, + ) + logger.info("Model loaded successfully.") - # Transcription with live progress total_duration = get_media_duration(inp) if total_duration: - print(f"[INFO] Media duration: {hhmmss(total_duration)}") - start_ts = time.time() + logger.info( + "Media duration: %s", + hhmmss(total_duration), + ) - iter_segments, info = model.transcribe(inp, language=args.language) - collected = [] - processed = 0.0 - last_print = 0.0 - tty = sys.stderr.isatty() - for seg in iter_segments: - collected.append(seg) - # Update processed time from segment end if available - if getattr(seg, "end", None) is not None: - processed = max(processed, float(seg.end)) - now = time.time() - # Print each segment or throttle to ~5 per second - if not args.no_progress and (tty or (now - last_print) >= 0.2): - last_print = now - if total_duration and total_duration > 0: - pct = max(0.0, min(100.0, (processed / total_duration) * 100.0)) - elapsed = now - start_ts - eta = None - if processed > 0: - rate = processed / max(1e-6, elapsed) - remaining = max(0.0, total_duration - processed) - eta = remaining / max(1e-6, rate) - line = f"[PROGRESS] {hhmmss(processed)} / {hhmmss(total_duration)} ({pct:5.1f}%)" - if eta is not None and eta < 60 * 60 * 24: # cap unrealistic values - line += f" ETA ~{hhmmss(eta)}" - else: - line = f"[PROGRESS] processed {hhmmss(processed)}" - if tty: - print("\r" + line, end="", file=sys.stderr, flush=True) - else: - print(line, file=sys.stderr, flush=True) - - # Finish progress line - if not args.no_progress and sys.stderr.isatty(): - print(file=sys.stderr) # newline - - print( - f"[INFO] Detected language: {getattr(info, 'language', None)} (prob={getattr(info, 'language_probability', None)})" + collected, info = _run_progress_loop( + args, model, inp, total_duration ) - print(f"[INFO] Segments: {len(collected)}") - # Optionally diarize - if args.diarize: - labels = diarize_segments(inp, collected, num_speakers=args.num_speakers) - if labels is not None and len(labels) == len(collected): - diar_srt = os.path.join(outdir, base + ".diar.srt") - diar_txt = os.path.join(outdir, base + ".diar.txt") - rttm_path = os.path.join(outdir, base + ".rttm") - write_srt_with_speakers(collected, labels, diar_srt) - write_txt_with_speakers(collected, labels, diar_txt) - write_rttm(collected, labels, rttm_path, file_id=base) - print( - f"[OK] Wrote: {diar_txt}\n[OK] Wrote: {diar_srt}\n[OK] Wrote: {rttm_path}" - ) - else: - print( - "[WARN] Diarization failed or returned mismatched labels; writing plain outputs.", - file=sys.stderr, - ) + logger.info( + "Detected language: %s (prob=%s)", + getattr(info, "language", None), + getattr(info, "language_probability", None), + ) + logger.info("Segments: %d", len(collected)) + + _write_diarized_outputs( + args, inp, outdir, base, collected + ) - # Write base outputs write_txt(collected, txt_path) write_srt(collected, srt_path) - print(f"[OK] Wrote: {txt_path}\n[OK] Wrote: {srt_path}") + logger.info("Wrote: %s", txt_path) + logger.info("Wrote: %s", srt_path) return 0 diff --git a/linux_configuration/scripts/misc/testsAndMisc-bash/tools/transcribe_helpers.py b/linux_configuration/scripts/misc/testsAndMisc-bash/tools/transcribe_helpers.py index b87e3fe..b92624d 100755 --- a/linux_configuration/scripts/misc/testsAndMisc-bash/tools/transcribe_helpers.py +++ b/linux_configuration/scripts/misc/testsAndMisc-bash/tools/transcribe_helpers.py @@ -1,13 +1,31 @@ #!/usr/bin/env python3 """Helper utilities for transcribe.sh - replaces inline Python snippets.""" +from __future__ import annotations + import argparse import array +import importlib +import logging import math import os import sys +from typing import TYPE_CHECKING import wave +if TYPE_CHECKING: + import types + +logger = logging.getLogger(__name__) + + +def _try_import(name: str) -> types.ModuleType | None: + """Attempt to import a module, returning None on failure.""" + try: + return importlib.import_module(name) + except ImportError: + return None + def get_python_version() -> str: """Return Python major.minor version string.""" @@ -16,42 +34,36 @@ def get_python_version() -> str: def check_faster_whisper() -> bool: """Check if faster_whisper is importable. Exit 7 if not.""" - try: - import faster_whisper # noqa: F401 - - return True - except ImportError: - return False + return _try_import("faster_whisper") is not None def check_diarization_deps() -> bool: - """Check if diarization dependencies are available. Returns False with warning if missing.""" - try: - import soundfile # noqa: F401 - import speechbrain # noqa: F401 - import torch # noqa: F401 + """Check if diarization dependencies are available. - return True - except Exception as e: - print( - f"[WARN] Diarization deps missing offline ({e}); speaker labels will be skipped." + Returns False with warning if missing. + """ + _sf = _try_import("soundfile") + _sb = _try_import("speechbrain") + _torch = _try_import("torch") + if _sf is None or _sb is None or _torch is None: + logger.warning( + "Diarization deps missing offline; " + "speaker labels will be skipped.", ) return False + return True def check_ctranslate2() -> bool: """Check if ctranslate2 is importable.""" - try: - import ctranslate2 # noqa: F401 - - return True - except ImportError: - return False + return _try_import("ctranslate2") is not None -def print_deps_installed(): +def print_deps_installed() -> None: """Print confirmation that Python dependencies are installed.""" - print(f"[PY] Python {sys.version.split()[0]} dependencies installed.") + logger.info( + "Python %s dependencies installed.", sys.version.split()[0] + ) def generate_sine_wav( @@ -84,7 +96,12 @@ def generate_sine_wav( min( 1.0, amplitude - * math.sin(2 * math.pi * frequency * (i / sample_rate)), + * math.sin( + 2 + * math.pi + * frequency + * (i / sample_rate) + ), ), ) * 32767 @@ -97,10 +114,11 @@ def generate_sine_wav( wf.setsampwidth(2) wf.setframerate(sample_rate) wf.writeframes(data.tobytes()) - return True - except Exception as e: - print(f"[ERROR] Failed to generate WAV: {e}", file=sys.stderr) + except OSError: + logger.exception("Failed to generate WAV") return False + else: + return True def prepare_model(model_name: str, model_dir: str) -> bool: @@ -113,35 +131,40 @@ def prepare_model(model_name: str, model_dir: str) -> bool: Returns: True on success, False on failure """ - try: - from faster_whisper import WhisperModel - - # Enable HuggingFace Hub progress bars for model download - try: - from huggingface_hub import logging as hf_logging - - hf_logging.set_verbosity_info() - import huggingface_hub - - huggingface_hub.constants.HF_HUB_DISABLE_PROGRESS_BARS = False - os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "0" - except ImportError: - pass - - print(f"[PY] Preparing model '{model_name}' into {model_dir}") - print( - "[INFO] Downloading model files (progress bar should appear below)...", - flush=True, - ) - WhisperModel( - model_name, device="cpu", compute_type="int8", download_root=model_dir - ) - print("[PY] Model prepared.") - return True - except Exception as e: - print(f"[ERROR] Failed to prepare model: {e}", file=sys.stderr) + fw = _try_import("faster_whisper") + if fw is None: + logger.error("faster_whisper is not installed") return False + try: + hf_logging = _try_import("huggingface_hub.logging") + if hf_logging is not None: + hf_logging.set_verbosity_info() + hh = _try_import("huggingface_hub") + if hh is not None: + hh.constants.HF_HUB_DISABLE_PROGRESS_BARS = False + os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "0" + + logger.info( + "Preparing model '%s' into %s", model_name, model_dir + ) + logger.info( + "Downloading model files " + "(progress bar should appear below)...", + ) + fw.WhisperModel( + model_name, + device="cpu", + compute_type="int8", + download_root=model_dir, + ) + logger.info("Model prepared.") + except (OSError, RuntimeError): + logger.exception("Failed to prepare model") + return False + else: + return True + def test_cuda() -> bool: """Test CUDA initialization with faster-whisper. @@ -149,30 +172,96 @@ def test_cuda() -> bool: Returns: True if CUDA works, False otherwise """ - try: - from faster_whisper import WhisperModel - - WhisperModel("tiny", device="cuda", compute_type="float16") - print("[PY] CUDA test init succeeded.") - return True - except Exception as e: - print(f"[ERROR] CUDA test failed: {e}", file=sys.stderr) + fw = _try_import("faster_whisper") + if fw is None: + logger.error("faster_whisper is not installed") return False + try: + fw.WhisperModel( + "tiny", device="cuda", compute_type="float16" + ) + logger.info("CUDA test init succeeded.") + except (OSError, RuntimeError): + logger.exception("CUDA test failed") + return False + else: + return True + + +def _handle_python_version() -> None: + """Handle python-version command.""" + logger.info("%s", get_python_version()) + + +def _handle_check_faster_whisper() -> None: + """Handle check-faster-whisper command.""" + if not check_faster_whisper(): + logger.error( + "Python dependency 'faster_whisper' not found in " + "offline mode. Run with --online to install.", + ) + sys.exit(7) + + +def _handle_check_diarization() -> None: + """Handle check-diarization command.""" + check_diarization_deps() + + +def _handle_check_ctranslate2() -> None: + """Handle check-ctranslate2 command.""" + if not check_ctranslate2(): + sys.exit(1) + + +def _handle_deps_installed() -> None: + """Handle deps-installed command.""" + print_deps_installed() + + +def _handle_generate_wav(args: argparse.Namespace) -> None: + """Handle generate-wav command.""" + if not args.file: + logger.error("--file is required for generate-wav") + sys.exit(2) + if not generate_sine_wav(args.file): + sys.exit(1) + + +def _handle_prepare_model(args: argparse.Namespace) -> None: + """Handle prepare-model command.""" + if not args.model or not args.model_dir: + logger.error( + "--model and --model-dir are required for prepare-model", + ) + sys.exit(2) + if not prepare_model(args.model, args.model_dir): + sys.exit(1) + + +def _handle_test_cuda() -> None: + """Handle test-cuda command.""" + if not test_cuda(): + sys.exit(1) + + +def main() -> None: + """Parse arguments and dispatch helper commands.""" + logging.basicConfig(format="%(message)s", level=logging.INFO) -def main(): parser = argparse.ArgumentParser( description="Helper utilities for transcribe.sh", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Commands: python-version Print Python major.minor version - check-faster-whisper Check if faster_whisper is installed (exit 7 if not) + check-faster-whisper Check if faster_whisper is installed check-diarization Check diarization deps (warn if missing) - check-ctranslate2 Check if ctranslate2 is installed (exit 1 if not) - deps-installed Print deps installed confirmation message - generate-wav FILE Generate a 3s 1kHz sine wave WAV file - prepare-model Download model for offline use (requires --model and --model-dir) + check-ctranslate2 Check if ctranslate2 is installed + deps-installed Print deps installed confirmation + generate-wav FILE Generate a 3s 1kHz sine wave WAV + prepare-model Download model for offline use test-cuda Test CUDA initialization """, ) @@ -190,46 +279,32 @@ Commands: ], help="Command to run", ) - parser.add_argument("--file", help="Output file path (for generate-wav)") - parser.add_argument("--model", help="Model name (for prepare-model)") - parser.add_argument("--model-dir", help="Model directory (for prepare-model)") + parser.add_argument( + "--file", help="Output file path (for generate-wav)" + ) + parser.add_argument( + "--model", help="Model name (for prepare-model)" + ) + parser.add_argument( + "--model-dir", help="Model directory (for prepare-model)" + ) args = parser.parse_args() - if args.command == "python-version": - print(get_python_version()) - elif args.command == "check-faster-whisper": - if not check_faster_whisper(): - print( - "Python dependency 'faster_whisper' not found in offline mode. Run with --online to install.", - file=sys.stderr, - ) - sys.exit(7) - elif args.command == "check-diarization": - check_diarization_deps() - elif args.command == "check-ctranslate2": - if not check_ctranslate2(): - sys.exit(1) - elif args.command == "deps-installed": - print_deps_installed() - elif args.command == "generate-wav": - if not args.file: - print("--file is required for generate-wav", file=sys.stderr) - sys.exit(2) - if not generate_sine_wav(args.file): - sys.exit(1) - elif args.command == "prepare-model": - if not args.model or not args.model_dir: - print( - "--model and --model-dir are required for prepare-model", - file=sys.stderr, - ) - sys.exit(2) - if not prepare_model(args.model, args.model_dir): - sys.exit(1) - elif args.command == "test-cuda": - if not test_cuda(): - sys.exit(1) + dispatch: dict[str, object] = { + "python-version": _handle_python_version, + "check-faster-whisper": _handle_check_faster_whisper, + "check-diarization": _handle_check_diarization, + "check-ctranslate2": _handle_check_ctranslate2, + "deps-installed": _handle_deps_installed, + "generate-wav": lambda: _handle_generate_wav(args), + "prepare-model": lambda: _handle_prepare_model(args), + "test-cuda": _handle_test_cuda, + } + + handler = dispatch.get(args.command) + if handler is not None and callable(handler): + handler() if __name__ == "__main__":