music_gen: force CUDA for NVIDIA GPUs, auto-select model by VRAM

- Fail fast if NVIDIA GPU detected but CUDA unavailable (no CPU fallback)
- Auto-select largest model based on VRAM (large=12GB+, medium=8GB+)
- Remove torchaudio dependency (scipy handles audio I/O)
- Use safetensors format to avoid torch.load security issues
This commit is contained in:
Krzysztof kuhy Rudnicki 2025-12-04 20:57:50 +01:00
parent 150230caf8
commit 96564e121b

View File

@ -23,6 +23,10 @@ import warnings
warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
# VRAM thresholds for model selection (in GB)
VRAM_THRESHOLD_LARGE = 12 # Use large model with 12GB+ VRAM
VRAM_THRESHOLD_MEDIUM = 8 # Use medium model with 8GB+ VRAM
def check_dependencies() -> bool: def check_dependencies() -> bool:
"""Check if required packages are installed.""" """Check if required packages are installed."""
@ -33,30 +37,61 @@ def check_dependencies() -> bool:
except ImportError: except ImportError:
missing.append("torch") missing.append("torch")
try:
import torchaudio # noqa: F401
except ImportError:
missing.append("torchaudio")
try: try:
import transformers # noqa: F401 import transformers # noqa: F401
except ImportError: except ImportError:
missing.append("transformers") missing.append("transformers")
try:
import scipy # noqa: F401
except ImportError:
missing.append("scipy")
if missing: if missing:
print("Missing dependencies. Install with:") print("Missing dependencies. Install with:")
print(f" pip install {' '.join(missing)}") print(f" pip install {' '.join(missing)}")
print("\nOr run the full setup:") print("\nFor CUDA support:")
print(" pip install torch torchaudio transformers scipy") print(" pip install torch --index-url https://download.pytorch.org/whl/cu121")
print(" pip install transformers scipy")
return False return False
return True return True
def get_device() -> str: def get_device() -> str:
"""Get the best available device (CUDA, MPS, or CPU).""" """Get the best available device (CUDA or MPS). No CPU fallback for NVIDIA.
Raises:
RuntimeError: If NVIDIA GPU is detected but CUDA is not available.
"""
import torch import torch
if torch.cuda.is_available(): # Check for NVIDIA GPU first
nvidia_gpu_present = False
try:
import shutil
import subprocess
nvidia_smi_path = shutil.which("nvidia-smi")
if nvidia_smi_path:
result = subprocess.run(
[nvidia_smi_path],
capture_output=True,
text=True,
check=False,
)
nvidia_gpu_present = result.returncode == 0
except FileNotFoundError:
pass
if nvidia_gpu_present:
if not torch.cuda.is_available():
msg = (
"NVIDIA GPU detected but CUDA is not available!\n"
"Please install PyTorch with CUDA support:\n"
" pip install torch torchaudio --index-url "
"https://download.pytorch.org/whl/cu121"
)
raise RuntimeError(msg)
device = "cuda" device = "cuda"
gpu_name = torch.cuda.get_device_name(0) gpu_name = torch.cuda.get_device_name(0)
vram = torch.cuda.get_device_properties(0).total_memory / 1024**3 vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
@ -70,6 +105,49 @@ def get_device() -> str:
return device return device
def get_vram_gb() -> float | None:
"""Get available VRAM in GB. Returns None if no CUDA GPU."""
import torch
if torch.cuda.is_available():
return torch.cuda.get_device_properties(0).total_memory / 1024**3
return None
def select_model_size(user_choice: str | None = None) -> str:
"""Select model size based on user choice or available VRAM.
Args:
user_choice: User's explicit model choice, or None for auto-selection.
Returns:
Model size: 'small', 'medium', or 'large'
"""
if user_choice is not None:
return user_choice
vram = get_vram_gb()
if vram is None:
# No GPU, use medium as a safe default
print("No CUDA GPU detected, defaulting to medium model")
return "medium"
# Select based on VRAM:
# - large: needs ~10GB VRAM (safe with 12GB+)
# - medium: needs ~6GB VRAM (safe with 8GB+)
# - small: needs ~3GB VRAM
if vram >= VRAM_THRESHOLD_LARGE:
selected = "large"
elif vram >= VRAM_THRESHOLD_MEDIUM:
selected = "medium"
else:
selected = "small"
print(f"Auto-selected '{selected}' model based on {vram:.1f}GB VRAM")
return selected
def load_model( def load_model(
model_size: str = "medium", model_size: str = "medium",
) -> tuple: # type: ignore[type-arg] ) -> tuple: # type: ignore[type-arg]
@ -93,7 +171,11 @@ def load_model(
device = get_device() device = get_device()
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name) # Use safetensors format to avoid torch.load security issues with older PyTorch
model = MusicgenForConditionalGeneration.from_pretrained(
model_name,
use_safetensors=True,
)
model = model.to(device) model = model.to(device)
print(f"Model loaded successfully on {device}!") print(f"Model loaded successfully on {device}!")
@ -255,10 +337,10 @@ Examples:
%(prog)s --model small "jazz guitar solo" %(prog)s --model small "jazz guitar solo"
%(prog)s --interactive %(prog)s --interactive
Model sizes: Model sizes (auto-selected based on VRAM if not specified):
small - ~500MB, fastest, lower quality small - ~500MB, fastest, lower quality (3GB+ VRAM)
medium - ~3.3GB, good balance (default) medium - ~3.3GB, good balance (8GB+ VRAM)
large - ~6.5GB, best quality, needs 16GB+ VRAM large - ~6.5GB, best quality (12GB+ VRAM)
""", """,
) )
@ -278,8 +360,8 @@ Model sizes:
"-m", "-m",
"--model", "--model",
choices=["small", "medium", "large"], choices=["small", "medium", "large"],
default="medium", default=None,
help="Model size (default: medium)", help="Model size (default: auto-select based on VRAM, largest possible)",
) )
parser.add_argument( parser.add_argument(
"-i", "-i",
@ -305,8 +387,11 @@ Model sizes:
if not check_dependencies(): if not check_dependencies():
sys.exit(1) sys.exit(1)
# Select model size based on VRAM if not specified
model_size = select_model_size(args.model)
# Load model # Load model
model, processor = load_model(args.model) model, processor = load_model(model_size)
if args.interactive: if args.interactive:
interactive_mode(model, processor) interactive_mode(model, processor)