mirror of
https://github.com/kuhyx/testsAndMisc.git
synced 2026-07-04 16:23:04 +02:00
200 lines
5.4 KiB
Python
200 lines
5.4 KiB
Python
|
|
"""Model download and caching for faster-whisper."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import contextlib
|
||
|
|
import importlib
|
||
|
|
import logging
|
||
|
|
from pathlib import Path
|
||
|
|
import time
|
||
|
|
from typing import TYPE_CHECKING
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
import types
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)

# Divisor applied at each step when scaling byte counts (B -> KB -> MB ...).
_BYTES_PER_KB = 2**10

# Short model name -> Hugging Face repo id.
# Standard sizes follow "Systran/faster-whisper-<name>"; bare "large" is an
# alias for large-v3; distilled variants live under "faster-distil-whisper".
_MODEL_MAP: dict[str, str] = {
    **{
        size: f"Systran/faster-whisper-{size}"
        for size in (
            "tiny",
            "tiny.en",
            "base",
            "base.en",
            "small",
            "small.en",
            "medium",
            "medium.en",
            "large-v1",
            "large-v2",
            "large-v3",
        )
    },
    "large": "Systran/faster-whisper-large-v3",
    **{
        f"distil-{size}": f"Systran/faster-distil-whisper-{size}"
        for size in ("large-v2", "large-v3", "medium.en", "small.en")
    },
}
|
||
|
|
|
||
|
|
|
||
|
|
def _try_import(name: str) -> types.ModuleType | None:
    """Import the module called *name*, or give back ``None`` if that fails.

    Intended for optional dependencies: any ``ImportError`` (including
    ``ModuleNotFoundError``) raised by the import is swallowed.
    """
    module: types.ModuleType | None = None
    with contextlib.suppress(ImportError):
        module = importlib.import_module(name)
    return module
|
||
|
|
|
||
|
|
|
||
|
|
def _format_bytes(size: int) -> str:
    """Render a byte count with one decimal and a B/KB/MB/GB/TB suffix.

    The value is divided by ``_BYTES_PER_KB`` for each step up the unit
    ladder; anything still at or above that threshold in TB stays in TB.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = float(size)
    step = 0
    while step < len(units) - 1 and value >= _BYTES_PER_KB:
        value /= _BYTES_PER_KB
        step += 1
    return f"{value:.1f}{units[step]}"
|
||
|
|
|
||
|
|
|
||
|
|
def _check_cache(
    repo_id: str,
) -> str | None:
    """Check HF cache for an already-downloaded model.

    Looks up ``model.bin`` for *repo_id* in the local Hugging Face cache via
    ``huggingface_hub.try_to_load_from_cache``.

    Args:
        repo_id: Hugging Face repository id, e.g.
            ``"Systran/faster-whisper-tiny"``.

    Returns:
        The directory containing the cached ``model.bin``, or ``None`` when
        ``huggingface_hub`` is not installed or no usable cached file exists.
    """
    hh = _try_import("huggingface_hub")
    if hh is None:
        return None

    cache_path = hh.try_to_load_from_cache(repo_id, "model.bin")
    # try_to_load_from_cache returns a path on a hit and None on a miss. It
    # can also return a non-path sentinel (_CACHED_NO_EXIST) when the cache
    # records that the file does not exist upstream. Only a real path is
    # usable; passing the sentinel to Path() would raise TypeError.
    if not isinstance(cache_path, (str, Path)):
        return None

    parent = str(Path(cache_path).parent)
    logger.info(
        "Model already cached, loading from: %s",
        parent,
    )
    return parent
|
||
|
|
|
||
|
|
|
||
|
|
def _download_files(
    repo_id: str,
    required_files: list[str],
) -> str:
    """Download required model files from HuggingFace.

    Files are fetched one at a time with ``hf_hub_download``, and each
    file's size and elapsed time are logged. A file that fails with
    ``OSError`` (huggingface_hub's HTTP / entry-not-found errors derive from
    it) or ``RuntimeError`` is logged and skipped. This lets optional files
    be absent.

    Args:
        repo_id: Hugging Face repository id.
        required_files: File names to fetch from the repository root.

    Returns:
        The local directory holding the downloaded files, i.e. the parent of
        the first file that downloaded successfully.

    Raises:
        RuntimeError: If ``huggingface_hub`` is unavailable, or if none of
            the files could be downloaded.
    """
    hh = _try_import("huggingface_hub")
    if hh is None:
        msg = "huggingface_hub not available"
        raise RuntimeError(msg)

    logger.info(
        "Downloading model files from %s...",
        repo_id,
    )
    logger.info(
        "This may take several minutes for large "
        "models (~3GB for large-v3)",
    )

    _log_total_download_size(repo_id, required_files)

    model_dir = ""
    start_time = time.time()

    for filename in required_files:
        file_start = time.time()
        logger.info("DOWNLOAD %s...", filename)
        try:
            local_path = hh.hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                resume_download=True,
            )
        except OSError as exc:
            logger.info("not found (optional): %s", exc)
            continue
        except RuntimeError as exc:
            logger.info("error: %s", exc)
            continue

        elapsed = time.time() - file_start
        lp = Path(local_path)
        file_size = lp.stat().st_size if lp.exists() else 0
        logger.info(
            "done (%s, %.1fs)",
            _format_bytes(file_size),
            elapsed,
        )
        # All files of one revision land in the same snapshot directory, so
        # the first successful download determines the model directory.
        if not model_dir:
            model_dir = str(lp.parent)

    total_time = time.time() - start_time
    logger.info("Download complete in %.1fs", total_time)

    if not model_dir:
        # Returning "" would hand an empty path to the model loader. Raising
        # RuntimeError instead lets the caller fall back to the default
        # download path.
        msg = f"no model files could be downloaded from {repo_id}"
        raise RuntimeError(msg)
    return model_dir
|
||
|
|
|
||
|
|
|
||
|
|
def _log_total_download_size(
    repo_id: str, required_files: list[str]
) -> None:
    """Log total download size if available.

    This is best-effort. It lists the repository root with
    ``huggingface_hub.HfFileSystem`` and sums the ``size`` of every entry
    whose base name is in *required_files*. Any lookup failure is ignored,
    because the result is purely informational.

    Args:
        repo_id: Hugging Face repository id.
        required_files: File names whose sizes should be summed.
    """
    hh = _try_import("huggingface_hub")
    if hh is None:
        return
    wanted = set(required_files)
    # ValueError covers huggingface_hub's HFValidationError, which is raised
    # for malformed repo ids. Without it, this informational call could abort
    # the whole download.
    with contextlib.suppress(OSError, RuntimeError, ValueError):
        fs = hh.HfFileSystem()
        files_info = fs.ls(repo_id, detail=True)
        total_size = sum(
            entry.get("size") or 0
            for entry in files_info
            if entry.get("name", "").rsplit("/", 1)[-1] in wanted
        )
        # Nothing matched (or sizes unknown): a "~0.0B" line would mislead.
        if total_size:
            logger.info(
                "Total download size: ~%s",
                _format_bytes(total_size),
            )
|
||
|
|
|
||
|
|
|
||
|
|
def download_model_with_progress(
    model_name: str,
) -> str:
    """Download model files from HuggingFace with progress.

    *model_name* may be a short alias from ``_MODEL_MAP``, a bare size name
    (mapped to ``Systran/faster-whisper-<name>``), a full ``owner/repo`` id,
    or an existing local model directory.

    Args:
        model_name: Model alias, repo id, or local directory path.

    Returns:
        The local path to the downloaded model. The cached snapshot directory
        is returned when ``model.bin`` is already cached. *model_name* itself
        is returned unchanged in three cases:
        - it is a local directory;
        - ``huggingface_hub`` is missing;
        - the custom download fails.
        This lets the caller fall back to its default loading behaviour.
    """
    # A local model directory needs no download. It is also not a valid repo
    # id, so huggingface_hub would raise a validation error or try to
    # download a nonexistent repository.
    if Path(model_name).is_dir():
        return model_name

    hh = _try_import("huggingface_hub")
    if hh is None:
        logger.warning(
            "huggingface_hub not available, "
            "falling back to default download",
        )
        return model_name

    repo_id = _MODEL_MAP.get(model_name, model_name)
    # Unknown bare names (no "owner/" prefix) are assumed to be Systran
    # conversions.
    if "/" not in repo_id and model_name not in _MODEL_MAP:
        repo_id = f"Systran/faster-whisper-{model_name}"

    logger.info("Checking model: %s", repo_id)

    # Mirrors the set faster-whisper's own downloader fetches. Repositories
    # ship either vocabulary.txt or vocabulary.json, and some lack
    # preprocessor_config.json. Missing ones are logged as optional and
    # skipped by _download_files.
    required_files = [
        "config.json",
        "preprocessor_config.json",
        "model.bin",
        "tokenizer.json",
        "vocabulary.txt",
        "vocabulary.json",
    ]

    try:
        cached = _check_cache(repo_id)
        if cached is not None:
            return cached
        return _download_files(repo_id, required_files)
    except (OSError, RuntimeError, ValueError) as exc:
        # ValueError covers huggingface_hub's HFValidationError for repo ids
        # it rejects; like network errors, it should trigger the fallback.
        logger.warning(
            "Custom download failed (%s), "
            "falling back to default",
            exc,
        )
        return model_name
|