music_gen: force CUDA for NVIDIA GPUs, auto-select model by VRAM

- Fail fast if NVIDIA GPU detected but CUDA unavailable (no CPU fallback)
- Auto-select largest model based on VRAM (large=12GB+, medium=8GB+)
- Remove torchaudio dependency (scipy handles audio I/O)
- Use safetensors format to avoid torch.load security issues
This commit is contained in:
Krzysztof kuhy Rudnicki 2025-12-04 20:57:50 +01:00
parent 150230caf8
commit 96564e121b

View File

@ -23,6 +23,10 @@ import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# VRAM thresholds for model selection (in GB)
VRAM_THRESHOLD_LARGE = 12 # Use large model with 12GB+ VRAM
VRAM_THRESHOLD_MEDIUM = 8 # Use medium model with 8GB+ VRAM
def check_dependencies() -> bool:
"""Check if required packages are installed."""
@ -33,30 +37,61 @@ def check_dependencies() -> bool:
except ImportError:
missing.append("torch")
try:
import torchaudio # noqa: F401
except ImportError:
missing.append("torchaudio")
try:
import transformers # noqa: F401
except ImportError:
missing.append("transformers")
try:
import scipy # noqa: F401
except ImportError:
missing.append("scipy")
if missing:
print("Missing dependencies. Install with:")
print(f" pip install {' '.join(missing)}")
print("\nOr run the full setup:")
print(" pip install torch torchaudio transformers scipy")
print("\nFor CUDA support:")
print(" pip install torch --index-url https://download.pytorch.org/whl/cu121")
print(" pip install transformers scipy")
return False
return True
def get_device() -> str:
"""Get the best available device (CUDA, MPS, or CPU)."""
"""Get the best available device (CUDA or MPS). No CPU fallback for NVIDIA.
Raises:
RuntimeError: If NVIDIA GPU is detected but CUDA is not available.
"""
import torch
if torch.cuda.is_available():
# Check for NVIDIA GPU first
nvidia_gpu_present = False
try:
import shutil
import subprocess
nvidia_smi_path = shutil.which("nvidia-smi")
if nvidia_smi_path:
result = subprocess.run(
[nvidia_smi_path],
capture_output=True,
text=True,
check=False,
)
nvidia_gpu_present = result.returncode == 0
except FileNotFoundError:
pass
if nvidia_gpu_present:
if not torch.cuda.is_available():
msg = (
"NVIDIA GPU detected but CUDA is not available!\n"
"Please install PyTorch with CUDA support:\n"
" pip install torch torchaudio --index-url "
"https://download.pytorch.org/whl/cu121"
)
raise RuntimeError(msg)
device = "cuda"
gpu_name = torch.cuda.get_device_name(0)
vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
@ -70,6 +105,49 @@ def get_device() -> str:
return device
def get_vram_gb() -> float | None:
"""Get available VRAM in GB. Returns None if no CUDA GPU."""
import torch
if torch.cuda.is_available():
return torch.cuda.get_device_properties(0).total_memory / 1024**3
return None
def select_model_size(user_choice: str | None = None) -> str:
"""Select model size based on user choice or available VRAM.
Args:
user_choice: User's explicit model choice, or None for auto-selection.
Returns:
Model size: 'small', 'medium', or 'large'
"""
if user_choice is not None:
return user_choice
vram = get_vram_gb()
if vram is None:
# No GPU, use medium as a safe default
print("No CUDA GPU detected, defaulting to medium model")
return "medium"
# Select based on VRAM:
# - large: needs ~10GB VRAM (safe with 12GB+)
# - medium: needs ~6GB VRAM (safe with 8GB+)
# - small: needs ~3GB VRAM
if vram >= VRAM_THRESHOLD_LARGE:
selected = "large"
elif vram >= VRAM_THRESHOLD_MEDIUM:
selected = "medium"
else:
selected = "small"
print(f"Auto-selected '{selected}' model based on {vram:.1f}GB VRAM")
return selected
def load_model(
model_size: str = "medium",
) -> tuple: # type: ignore[type-arg]
@ -93,7 +171,11 @@ def load_model(
device = get_device()
processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name)
# Use safetensors format to avoid torch.load security issues with older PyTorch
model = MusicgenForConditionalGeneration.from_pretrained(
model_name,
use_safetensors=True,
)
model = model.to(device)
print(f"Model loaded successfully on {device}!")
@ -255,10 +337,10 @@ Examples:
%(prog)s --model small "jazz guitar solo"
%(prog)s --interactive
Model sizes:
small - ~500MB, fastest, lower quality
medium - ~3.3GB, good balance (default)
large - ~6.5GB, best quality, needs 16GB+ VRAM
Model sizes (auto-selected based on VRAM if not specified):
small - ~500MB, fastest, lower quality (3GB+ VRAM)
medium - ~3.3GB, good balance (8GB+ VRAM)
large - ~6.5GB, best quality (12GB+ VRAM)
""",
)
@ -278,8 +360,8 @@ Model sizes:
"-m",
"--model",
choices=["small", "medium", "large"],
default="medium",
help="Model size (default: medium)",
default=None,
help="Model size (default: auto-select based on VRAM, largest possible)",
)
parser.add_argument(
"-i",
@ -305,8 +387,11 @@ Model sizes:
if not check_dependencies():
sys.exit(1)
# Select model size based on VRAM if not specified
model_size = select_model_size(args.model)
# Load model
model, processor = load_model(args.model)
model, processor = load_model(model_size)
if args.interactive:
interactive_mode(model, processor)