2025-10-12 14:46:55 +02:00
#!/usr/bin/env python3
import argparse
import os
import shutil
import subprocess
import sys
import time
from datetime import timedelta
from typing import List , Optional
2025-12-20 21:49:52 +01:00
def format_bytes(size: int) -> str:
    """Return *size* (a byte count) as a short string such as ``"1.5 MB"``.

    Units step by 1024 from B up to GB; anything of 1024 GB or more is
    expressed in TB. One decimal place is always shown.
    """
    amount = size
    for unit in ("B", "KB", "MB", "GB"):
        if amount < 1024:
            break
        amount /= 1024
    else:
        # Fell through every smaller unit: express the value in terabytes.
        unit = "TB"
    return f"{amount:.1f} {unit}"
# Short faster-whisper model names mapped to their Hugging Face repo IDs.
_FW_MODEL_REPOS = {
    "tiny": "Systran/faster-whisper-tiny",
    "tiny.en": "Systran/faster-whisper-tiny.en",
    "base": "Systran/faster-whisper-base",
    "base.en": "Systran/faster-whisper-base.en",
    "small": "Systran/faster-whisper-small",
    "small.en": "Systran/faster-whisper-small.en",
    "medium": "Systran/faster-whisper-medium",
    "medium.en": "Systran/faster-whisper-medium.en",
    "large-v1": "Systran/faster-whisper-large-v1",
    "large-v2": "Systran/faster-whisper-large-v2",
    "large-v3": "Systran/faster-whisper-large-v3",
    "large": "Systran/faster-whisper-large-v3",
    "distil-large-v2": "Systran/faster-distil-whisper-large-v2",
    "distil-large-v3": "Systran/faster-distil-whisper-large-v3",
    "distil-medium.en": "Systran/faster-distil-whisper-medium.en",
    "distil-small.en": "Systran/faster-distil-whisper-small.en",
}

# Files without which a local model directory cannot be loaded.
_FW_REQUIRED_FILES = ("config.json", "model.bin")
# Files that only some repos ship, so "not found" is expected for some of them.
# preprocessor_config.json holds the feature-extractor settings (large-v3 needs
# it for its 128 mel bins); the vocabulary is .txt or .json depending on repo.
_FW_OPTIONAL_FILES = (
    "preprocessor_config.json",
    "tokenizer.json",
    "vocabulary.txt",
    "vocabulary.json",
)


def _resolve_fw_repo_id(model_name: str) -> str:
    """Map a faster-whisper model name to a Hugging Face repo ID.

    Known short names use ``_FW_MODEL_REPOS``; names containing "/" are taken
    as repo IDs verbatim; anything else is assumed to be a Systran repo.
    """
    if model_name in _FW_MODEL_REPOS:
        return _FW_MODEL_REPOS[model_name]
    if "/" in model_name:
        return model_name
    return f"Systran/faster-whisper-{model_name}"


def _cached_fw_model_dir(repo_id: str) -> Optional[str]:
    """Return the cached snapshot directory of *repo_id*, or None if a download pass is needed.

    ``try_to_load_from_cache`` yields a path (cached), None (unknown), or a
    sentinel object meaning "known not to exist upstream". Required files must
    be cached as real paths; optional files only need to be resolved either
    way, so a cache filled before an optional file was tracked is topped up.
    """
    from huggingface_hub import try_to_load_from_cache

    model_bin = None
    for filename in _FW_REQUIRED_FILES:
        path = try_to_load_from_cache(repo_id, filename)
        if not isinstance(path, str):
            return None
        if filename == "model.bin":
            model_bin = path
    for filename in _FW_OPTIONAL_FILES:
        if try_to_load_from_cache(repo_id, filename) is None:
            return None
    return os.path.dirname(model_bin) if model_bin else None


def _print_fw_download_size(repo_id: str) -> None:
    """Best-effort: print the combined size of the model files listed in *repo_id*."""
    try:
        from huggingface_hub import HfFileSystem

        wanted = set(_FW_REQUIRED_FILES + _FW_OPTIONAL_FILES)
        files_info = HfFileSystem().ls(repo_id, detail=True)
        total_size = sum(
            f.get("size", 0) or 0
            for f in files_info
            if f.get("name", "").split("/")[-1] in wanted
        )
        print(f"[INFO] Total download size: ~{format_bytes(total_size)}", flush=True)
    except Exception:
        pass  # Size info is purely informational.


def download_model_with_progress(model_name: str) -> str:
    """Download faster-whisper model files from Hugging Face with per-file progress.

    Returns the local snapshot directory holding the model. When
    huggingface_hub is unavailable, a required file (config.json / model.bin)
    cannot be obtained, or anything else goes wrong, a warning is printed and
    *model_name* is returned unchanged so faster-whisper can use its own
    download logic instead.
    """
    try:
        from huggingface_hub import hf_hub_download

        try:
            from huggingface_hub.errors import EntryNotFoundError
        except ImportError:  # older huggingface_hub releases
            from huggingface_hub.utils import EntryNotFoundError
    except ImportError:
        print("[WARN] huggingface_hub not available, falling back to default download", file=sys.stderr)
        return model_name

    repo_id = _resolve_fw_repo_id(model_name)
    print(f"[INFO] Checking model: {repo_id}", flush=True)
    try:
        cached_dir = _cached_fw_model_dir(repo_id)
        if cached_dir is not None:
            print(f"[INFO] Model already cached, loading from: {cached_dir}", flush=True)
            return cached_dir

        print(f"[INFO] Downloading model files from {repo_id}...", flush=True)
        print("[INFO] This may take several minutes for large models (~3GB for large-v3)", flush=True)
        _print_fw_download_size(repo_id)

        start_time = time.time()
        model_dir = None
        missing_required = []
        for filename in _FW_REQUIRED_FILES + _FW_OPTIONAL_FILES:
            required = filename in _FW_REQUIRED_FILES
            file_start = time.time()
            print(f"[DOWNLOAD] {filename}...", end=" ", flush=True)
            try:
                # Already-cached files are returned without re-downloading.
                # resume_download is no longer passed: it is deprecated in
                # current huggingface_hub releases.
                local_path = hf_hub_download(repo_id=repo_id, filename=filename)
            except EntryNotFoundError:
                print("not found" if required else "not found (optional)", flush=True)
                if required:
                    missing_required.append(filename)
                continue
            except Exception as e:
                print(f"error: {e}", flush=True)
                if required:
                    missing_required.append(filename)
                continue
            elapsed = time.time() - file_start
            file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
            print(f"done ({format_bytes(file_size)}, {elapsed:.1f}s)", flush=True)
            if filename == "model.bin":
                model_dir = os.path.dirname(local_path)

        if missing_required or model_dir is None:
            # Returning a directory without these files would only fail later
            # inside WhisperModel; let faster-whisper try its own download.
            print(
                f"[WARN] Required model file(s) unavailable ({', '.join(missing_required)}), "
                "falling back to default",
                file=sys.stderr,
            )
            return model_name
        total_time = time.time() - start_time
        print(f"[INFO] Download complete in {total_time:.1f}s", flush=True)
        return model_dir
    except Exception as e:
        print(f"[WARN] Custom download failed ({e}), falling back to default", file=sys.stderr)
        return model_name
2025-10-12 14:46:55 +02:00
def format_timestamp(seconds: float) -> str:
    """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond before being split into
    fields. Float noise therefore cannot drop a millisecond (``2.3`` gives
    ``,300`` rather than ``,299``), and values such as ``59.9996`` carry into
    the next second. Negative inputs are clamped to zero.
    """
    total_ms = max(0, int(round(float(seconds) * 1000)))
    total_secs, millis = divmod(total_ms, 1000)
    hours, rem = divmod(total_secs, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def write_srt(segments, srt_path: str):
    """Write *segments* to *srt_path* as a SubRip (.srt) file.

    Each segment needs ``start``, ``end`` (seconds) and ``text``. Segments
    whose text is empty or whitespace are skipped. Cue numbers count only the
    cues actually written, so the numbering stays consecutive (1, 2, 3, ...).
    """
    with open(srt_path, "w", encoding="utf-8") as f:
        cue = 0
        for seg in segments:
            text = (seg.text or "").strip()
            if not text:
                continue
            cue += 1
            start = format_timestamp(seg.start)
            end = format_timestamp(seg.end)
            f.write(f"{cue}\n{start} --> {end}\n{text}\n\n")
def write_txt(segments, txt_path: str):
    """Write the stripped text of every non-empty segment to *txt_path*, one per line."""
    # Lazily stripped; evaluated while the output file is open.
    lines = ((seg.text or "").strip() for seg in segments)
    with open(txt_path, "w", encoding="utf-8") as out:
        for line in lines:
            if line:
                out.write(line + "\n")
def write_srt_with_speakers(segments, labels: List[int], path: str):
    """Write an SRT file whose cue text is prefixed with a ``[SPKn]`` speaker tag.

    *labels* holds one 0-based speaker index per segment. ``zip`` pairs them,
    so entries beyond the shorter sequence are ignored. Empty-text segments
    are skipped, and cue numbers count only the cues actually written so they
    stay consecutive.
    """
    with open(path, "w", encoding="utf-8") as f:
        cue = 0
        for seg, lab in zip(segments, labels):
            text = (seg.text or "").strip()
            if not text:
                continue
            cue += 1
            spk = f"SPK{lab + 1}"
            f.write(
                f"{cue}\n{format_timestamp(seg.start)} --> {format_timestamp(seg.end)}\n"
                f"[{spk}] {text}\n\n"
            )
def write_txt_with_speakers(segments, labels: List[int], path: str):
    """Write one ``[SPKn] text`` line per non-empty segment (n = 0-based label + 1)."""
    with open(path, "w", encoding="utf-8") as out:
        for segment, label in zip(segments, labels):
            stripped = (segment.text or "").strip()
            if not stripped:
                continue
            out.write(f"[SPK{label + 1}] {stripped}\n")
def write_rttm(segments, labels: List[int], path: str, file_id: str = "audio"):
    """Write speaker turns to *path* in NIST RTTM format.

    One ``SPEAKER`` line is written per (segment, label) pair, with the ten
    standard fields:
    ``SPEAKER <file-id> <chnl> <tbeg> <tdur> <ortho> <stype> <name> <conf> <slat>``.
    Missing or ``None`` start/end values fall back to 0.0 / the start time,
    and durations are clamped to be non-negative.
    """
    # RTTM is whitespace-delimited: a file id with spaces (e.g. derived from
    # "my talk.mp3") would shift every following field, so join the words.
    safe_id = "_".join(str(file_id).split()) or "audio"
    with open(path, "w", encoding="utf-8") as f:
        for seg, lab in zip(segments, labels):
            start = float(getattr(seg, "start", 0.0) or 0.0)
            end = float(getattr(seg, "end", start) or start)
            dur = max(0.0, end - start)
            name = f"SPK{lab + 1}"
            f.write(f"SPEAKER {safe_id} 1 {start:.3f} {dur:.3f} <NA> <NA> {name} <NA> <NA>\n")
def hhmmss(seconds: float) -> str:
    """Render a duration as ``HH:MM:SS``; fractions are truncated and negatives clamp to 0."""
    whole = int(max(0.0, float(seconds)))
    total_minutes, secs = divmod(whole, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
def get_media_duration(path: str) -> Optional[float]:
    """Return the duration of the media file at *path* in seconds, or None.

    Tries the ``ffmpeg-python`` package first, then the ``ffprobe``
    executable if it is on PATH. Non-numeric output (such as "N/A"),
    non-positive and non-finite durations are all reported as None.
    ``Optional[float]`` is used instead of ``float | None`` so the module
    still imports on Python versions before 3.10.
    """

    def _as_duration(value) -> Optional[float]:
        # Accept only finite, positive numbers.
        try:
            duration = float(value)
        except (TypeError, ValueError):
            return None
        return duration if 0.0 < duration < float("inf") else None

    # ffmpeg-python (if installed) returns ffprobe's output as parsed JSON.
    try:
        import ffmpeg  # type: ignore

        probe = ffmpeg.probe(path)
        duration = _as_duration(probe.get("format", {}).get("duration"))
        if duration is not None:
            return duration
    except Exception:
        pass  # Not installed, an unrelated "ffmpeg" module, or probe failed.

    # Fallback: call ffprobe directly if available.
    if shutil.which("ffprobe") is None:
        return None
    try:
        out = subprocess.check_output(
            [
                "ffprobe",
                "-v",
                "error",
                "-show_entries",
                "format=duration",
                "-of",
                "default=noprint_wrappers=1:nokey=1",
                path,
            ],
            stderr=subprocess.DEVNULL,
            timeout=60,  # never let a stuck probe block the transcription run
        )
    except Exception:
        return None
    return _as_duration(out.decode(errors="replace").strip())
def _resample_linear ( x , src_sr : int , tgt_sr : int ) :
import numpy as np
if src_sr == tgt_sr :
return x
ratio = float ( tgt_sr ) / float ( src_sr )
n_out = max ( 1 , int ( round ( x . shape [ - 1 ] * ratio ) ) )
xp = np . linspace ( 0.0 , 1.0 , num = x . shape [ - 1 ] , endpoint = False )
xq = np . linspace ( 0.0 , 1.0 , num = n_out , endpoint = False )
y = np . interp ( xq , xp , x . astype ( np . float32 ) )
return y . astype ( np . float32 )
def _kmeans_cosine ( embs , k : int , iters : int = 50 , seed : int = 0 ) :
import numpy as np
rng = np . random . default_rng ( seed )
X = np . asarray ( embs , dtype = np . float32 )
if X . ndim != 2 or X . shape [ 0 ] == 0 :
return np . zeros ( ( 0 , ) , dtype = np . int64 )
# Normalize
X = X / ( np . linalg . norm ( X , axis = 1 , keepdims = True ) + 1e-8 )
# Init centroids as random samples
idxs = rng . choice ( X . shape [ 0 ] , size = min ( k , X . shape [ 0 ] ) , replace = False )
C = X [ idxs ]
# If fewer samples than k, pad with random
if C . shape [ 0 ] < k :
pad = rng . standard_normal ( size = ( k - C . shape [ 0 ] , X . shape [ 1 ] ) ) . astype ( np . float32 )
pad / = ( np . linalg . norm ( pad , axis = 1 , keepdims = True ) + 1e-8 )
C = np . concatenate ( [ C , pad ] , axis = 0 )
for _ in range ( iters ) :
# Assign by cosine similarity (maximize dot product)
sims = X @ C . T # (n, k)
labels = sims . argmax ( axis = 1 )
newC = np . zeros_like ( C )
for j in range ( k ) :
sel = X [ labels == j ]
if sel . shape [ 0 ] == 0 :
newC [ j ] = C [ j ]
else :
v = sel . mean ( axis = 0 )
v / = ( np . linalg . norm ( v ) + 1e-8 )
newC [ j ] = v
if np . allclose ( newC , C , atol = 1e-4 ) :
break
C = newC
return labels
2025-10-12 14:51:41 +02:00
def _ffmpeg_transcode_to_wav16_mono ( src_path : str ) - > Optional [ str ] :
""" If ffmpeg is available, transcode input to a temporary 16k mono WAV and return its path. """
if not shutil . which ( " ffmpeg " ) :
return None
import tempfile
tmp = tempfile . NamedTemporaryFile ( prefix = " fw_diar_ " , suffix = " .wav " , delete = False )
tmp_path = tmp . name
tmp . close ( )
# Run ffmpeg quietly
cmd = [
" ffmpeg " ,
" -y " ,
" -v " ,
" error " ,
" -i " ,
src_path ,
" -ac " ,
" 1 " ,
" -ar " ,
" 16000 " ,
" -f " ,
" wav " ,
tmp_path ,
]
try :
subprocess . run ( cmd , check = True , stdout = subprocess . DEVNULL , stderr = subprocess . DEVNULL )
return tmp_path
except Exception :
try :
os . unlink ( tmp_path )
except Exception :
pass
return None
2025-10-12 14:46:55 +02:00
# Sample rate used for speaker embeddings (segment windows are computed at this rate).
_DIAR_SR = 16000


def _read_audio_for_diarization(audio_path: str):
    """Load *audio_path* as float32 samples; return ``(wav, sr)`` or None on failure.

    Tries soundfile directly. If that fails, it transcodes through ffmpeg to a
    temporary WAV, reads that, and deletes the temp file immediately. The
    samples are already in memory by then, so the file can never leak.
    """
    import soundfile as sf

    try:
        return sf.read(audio_path, dtype="float32", always_2d=False)
    except Exception as e:
        alt = _ffmpeg_transcode_to_wav16_mono(audio_path)
        if alt is None:
            print(f"[WARN] Could not read audio for diarization and no ffmpeg fallback available: {e}", file=sys.stderr)
            return None
        try:
            return sf.read(alt, dtype="float32", always_2d=False)
        except Exception as e2:
            print(f"[WARN] Could not read transcoded audio for diarization: {e2}", file=sys.stderr)
            return None
        finally:
            try:
                os.unlink(alt)
            except OSError:
                pass


def _segment_window(seg, n_samples: int):
    """Return ``(i0, i1)`` sample bounds at ``_DIAR_SR`` of the audio used to embed *seg*.

    The segment is padded by 50 ms on each side, clipped to the signal, and
    widened to at least 0.1 s. Widening also moves the start backwards, so
    segments at or past the end of the audio still get a non-empty window
    whenever the signal itself is non-empty.
    """
    s = float(getattr(seg, "start", 0.0) or 0.0)
    e = float(getattr(seg, "end", s) or s)
    if e <= s:
        e = s + 0.2  # minimal window for zero/negative-length segments
    pad = _DIAR_SR // 20  # 50 ms
    min_len = _DIAR_SR // 10  # 100 ms
    i0 = max(0, int(s * _DIAR_SR) - pad)
    i1 = min(n_samples, int(e * _DIAR_SR) + pad)
    if i1 - i0 < min_len:
        i1 = min(n_samples, i0 + min_len)
        i0 = max(0, i1 - min_len)
    return i0, i1


def diarize_segments(audio_path: str, segments, num_speakers: int = 2) -> Optional[list]:
    """Assign a speaker label to every transcription segment.

    Computes one speechbrain ECAPA embedding (``spkrec-ecapa-voxceleb``, on
    CPU) per segment window. The embeddings are clustered with cosine k-means
    into ``max(1, num_speakers)`` groups.

    Returns a list of 0-based integer labels aligned with *segments*, or None
    when dependencies are missing, the audio cannot be read or is empty, the
    model cannot be loaded, or embedding fails. Failures are reported on
    stderr instead of raised, so a diarization problem cannot abort the
    caller's run.
    """
    try:
        import numpy as np  # noqa: F401  (needed by the resampling/clustering helpers)
        import soundfile as sf  # noqa: F401  (used by _read_audio_for_diarization)
        from speechbrain.inference import EncoderClassifier  # non-deprecated import path
        import torch
    except Exception as e:
        print(f"[WARN] Diarization dependencies missing ({e}); skipping speaker labels.", file=sys.stderr)
        return None

    loaded = _read_audio_for_diarization(audio_path)
    if loaded is None:
        return None
    wav, sr = loaded
    if wav.ndim == 2:  # mix multi-channel audio down to mono
        wav = wav.mean(axis=1)
    if wav.size == 0:
        print("[WARN] Audio is empty; skipping speaker labels.", file=sys.stderr)
        return None
    wav16 = _resample_linear(wav, sr, _DIAR_SR)

    try:
        classifier = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": "cpu"},
            savedir=os.path.join(os.path.expanduser("~"), ".cache", "speechbrain_ecapa"),
        )
    except Exception as e:
        print(f"[WARN] Could not load speaker embedding model: {e}", file=sys.stderr)
        return None

    embs = []
    try:
        for seg in segments:
            i0, i1 = _segment_window(seg, len(wav16))
            segment_wav = torch.tensor(wav16[i0:i1]).unsqueeze(0)
            with torch.no_grad():
                emb = classifier.encode_batch(segment_wav).squeeze(0).squeeze(0).cpu().numpy()
            embs.append(emb.astype("float32"))
    except Exception as e:
        print(f"[WARN] Speaker embedding failed ({e}); skipping speaker labels.", file=sys.stderr)
        return None
    if not embs:
        return None

    labels = _kmeans_cosine(embs, k=max(1, int(num_speakers)))
    return labels.tolist()
def _transcribe_with_progress(model, inp: str, language, show_progress: bool):
    """Run ``model.transcribe`` on *inp* and drain its segment iterator.

    While segments arrive, a progress line is printed to stderr when
    *show_progress* is true. On a TTY it is redrawn in place for every
    segment; otherwise at most one line is printed every 0.2 s.
    Returns ``(segments_list, info)``.
    """
    total_duration = get_media_duration(inp)
    if total_duration:
        print(f"[INFO] Media duration: {hhmmss(total_duration)}")
    start_ts = time.time()
    iter_segments, info = model.transcribe(inp, language=language)
    collected = []
    processed = 0.0
    last_print = 0.0
    tty = sys.stderr.isatty()
    for seg in iter_segments:
        collected.append(seg)
        # Track progress by the furthest segment end seen so far.
        if getattr(seg, "end", None) is not None:
            processed = max(processed, float(seg.end))
        now = time.time()
        if not show_progress or not (tty or (now - last_print) >= 0.2):
            continue
        last_print = now
        if total_duration and total_duration > 0:
            pct = max(0.0, min(100.0, (processed / total_duration) * 100.0))
            line = f"[PROGRESS] {hhmmss(processed)} / {hhmmss(total_duration)} ({pct:5.1f}%)"
            if processed > 0:
                rate = processed / max(1e-6, now - start_ts)
                eta = max(0.0, total_duration - processed) / max(1e-6, rate)
                if eta < 60 * 60 * 24:  # cap unrealistic values
                    line += f" ETA ~{hhmmss(eta)}"
        else:
            line = f"[PROGRESS] processed {hhmmss(processed)}"
        if tty:
            print("\r" + line, end="", file=sys.stderr, flush=True)
        else:
            print(line, file=sys.stderr, flush=True)
    if show_progress and tty:
        print("", file=sys.stderr)  # finish the in-place progress line
    return collected, info


def main():
    """Command-line entry point: transcribe one media file with faster-whisper.

    Writes ``<base>.txt`` and ``<base>.srt`` next to the input (or into
    --outdir). With --diarize it also writes ``.diar.txt``, ``.diar.srt`` and
    ``.rttm`` files with speaker labels. Returns the process exit code: 0 on
    success, 2 when faster-whisper is missing or the input does not exist.
    """
    parser = argparse.ArgumentParser(description="Transcribe audio with faster-whisper and write .txt and .srt")
    parser.add_argument("input", help="Path to audio/video file")
    parser.add_argument("--model", default=os.environ.get("FW_MODEL", "large-v3"), help="Model size or path (default: large-v3)")
    parser.add_argument("--language", default=None, help="Language code (e.g., en). Leave None for auto-detect")
    parser.add_argument("--device", default=os.environ.get("FW_DEVICE", "auto"), choices=["auto", "cpu", "cuda"], help="Device to run on")
    parser.add_argument("--compute-type", dest="compute_type", default=os.environ.get("FW_COMPUTE", "auto"), help="Compute type (auto,int8,float16,float32,int8_float16,etc.)")
    parser.add_argument("--outdir", default=None, help="Output directory (default: next to input)")
    parser.add_argument("--no-progress", action="store_true", help="Disable live progress output")
    parser.add_argument("--diarize", action="store_true", help="Enable speaker diarization (labels)")
    # A string default is run through type=int by argparse, so an invalid
    # FW_NUM_SPEAKERS value yields a normal usage error instead of a traceback.
    parser.add_argument("--num-speakers", type=int, default=os.environ.get("FW_NUM_SPEAKERS", "2"), help="Assumed number of speakers (default: 2)")
    args = parser.parse_args()

    try:
        from faster_whisper import WhisperModel
    except Exception as e:
        print("[ERROR] faster-whisper is not installed in this environment.", file=sys.stderr)
        print(str(e), file=sys.stderr)
        return 2

    inp = os.path.abspath(args.input)
    if not os.path.exists(inp):
        print(f"[ERROR] Input file not found: {inp}", file=sys.stderr)
        return 2
    outdir = os.path.abspath(args.outdir or os.path.dirname(inp) or ".")
    os.makedirs(outdir, exist_ok=True)
    base = os.path.splitext(os.path.basename(inp))[0]
    srt_path = os.path.join(outdir, base + ".srt")
    txt_path = os.path.join(outdir, base + ".txt")

    # Device and compute_type heuristics.
    device = args.device
    compute_type = args.compute_type
    if device == "auto":
        # "auto" currently resolves to CPU; pass --device cuda to use a GPU.
        device = "cpu"
    if compute_type == "auto":
        # Prefer accuracy over speed by default.
        compute_type = "float16" if device == "cuda" else "float32"
    print(f"[INFO] Loading model='{args.model}', device='{device}', compute_type='{compute_type}'")

    # Pre-download model files with explicit progress if not already cached.
    model_path = args.model
    if not os.path.isdir(args.model):  # not a local directory: fetch from Hugging Face
        model_path = download_model_with_progress(args.model)

    # Surface faster-whisper's own INFO messages without enabling INFO-level
    # logging for every other library: the root logger keeps its default level
    # and only the "faster_whisper" logger is lowered to INFO.
    import logging

    logging.basicConfig(format="[%(levelname)s] %(message)s")
    logging.getLogger("faster_whisper").setLevel(logging.INFO)

    print("[INFO] Initializing model...", flush=True)
    model = WhisperModel(model_path, device=device, compute_type=compute_type)
    print("[INFO] Model loaded successfully.", flush=True)

    collected, info = _transcribe_with_progress(model, inp, args.language, not args.no_progress)
    print(f"[INFO] Detected language: {getattr(info, 'language', None)} (prob={getattr(info, 'language_probability', None)})")
    print(f"[INFO] Segments: {len(collected)}")

    # Write the plain transcript first so a diarization problem can never lose it.
    write_txt(collected, txt_path)
    write_srt(collected, srt_path)
    print(f"[OK] Wrote: {txt_path}\n[OK] Wrote: {srt_path}")

    if args.diarize:
        labels = diarize_segments(inp, collected, num_speakers=args.num_speakers)
        if labels is not None and len(labels) == len(collected):
            diar_srt = os.path.join(outdir, base + ".diar.srt")
            diar_txt = os.path.join(outdir, base + ".diar.txt")
            rttm_path = os.path.join(outdir, base + ".rttm")
            write_srt_with_speakers(collected, labels, diar_srt)
            write_txt_with_speakers(collected, labels, diar_txt)
            write_rttm(collected, labels, rttm_path, file_id=base)
            print(f"[OK] Wrote: {diar_txt}\n[OK] Wrote: {diar_srt}\n[OK] Wrote: {rttm_path}")
        else:
            print("[WARN] Diarization failed or returned mismatched labels; only plain outputs were written.", file=sys.stderr)
    return 0
if __name__ == "__main__":
    # Run the CLI; main()'s return value becomes the process exit status.
    raise SystemExit(main())