mirror of
https://github.com/kuhyx/testsAndMisc.git
synced 2026-07-04 16:23:04 +02:00
music_gen: force CUDA for NVIDIA GPUs, auto-select model by VRAM
- Fail fast if an NVIDIA GPU is detected but CUDA is unavailable (no CPU fallback)
- Auto-select the largest model based on VRAM (large=12GB+, medium=8GB+)
- Remove the torchaudio dependency (scipy handles audio I/O)
- Use the safetensors format to avoid torch.load security issues
This commit is contained in:
parent
bfef1a532b
commit
e4b3d8cbdc
@ -23,6 +23,10 @@ import warnings
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
# VRAM thresholds for model selection (in GB).
# A GPU whose VRAM is at or above a threshold is eligible for that model size;
# anything below the medium threshold falls back to the small model.
VRAM_THRESHOLD_LARGE = 12  # Use large model with 12GB+ VRAM
VRAM_THRESHOLD_MEDIUM = 8  # Use medium model with 8GB+ VRAM
|
||||
|
||||
|
||||
def check_dependencies() -> bool:
|
||||
"""Check if required packages are installed."""
|
||||
@ -33,30 +37,61 @@ def check_dependencies() -> bool:
|
||||
except ImportError:
|
||||
missing.append("torch")
|
||||
|
||||
try:
|
||||
import torchaudio # noqa: F401
|
||||
except ImportError:
|
||||
missing.append("torchaudio")
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
except ImportError:
|
||||
missing.append("transformers")
|
||||
|
||||
try:
|
||||
import scipy # noqa: F401
|
||||
except ImportError:
|
||||
missing.append("scipy")
|
||||
|
||||
if missing:
|
||||
print("Missing dependencies. Install with:")
|
||||
print(f" pip install {' '.join(missing)}")
|
||||
print("\nOr run the full setup:")
|
||||
print(" pip install torch torchaudio transformers scipy")
|
||||
print("\nFor CUDA support:")
|
||||
print(" pip install torch --index-url https://download.pytorch.org/whl/cu121")
|
||||
print(" pip install transformers scipy")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_device() -> str:
|
||||
"""Get the best available device (CUDA, MPS, or CPU)."""
|
||||
"""Get the best available device (CUDA or MPS). No CPU fallback for NVIDIA.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If NVIDIA GPU is detected but CUDA is not available.
|
||||
"""
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Check for NVIDIA GPU first
|
||||
nvidia_gpu_present = False
|
||||
try:
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
nvidia_smi_path = shutil.which("nvidia-smi")
|
||||
if nvidia_smi_path:
|
||||
result = subprocess.run(
|
||||
[nvidia_smi_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
nvidia_gpu_present = result.returncode == 0
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
if nvidia_gpu_present:
|
||||
if not torch.cuda.is_available():
|
||||
msg = (
|
||||
"NVIDIA GPU detected but CUDA is not available!\n"
|
||||
"Please install PyTorch with CUDA support:\n"
|
||||
" pip install torch torchaudio --index-url "
|
||||
"https://download.pytorch.org/whl/cu121"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
device = "cuda"
|
||||
gpu_name = torch.cuda.get_device_name(0)
|
||||
vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
@ -70,6 +105,49 @@ def get_device() -> str:
|
||||
return device
|
||||
|
||||
|
||||
def get_vram_gb() -> float | None:
|
||||
"""Get available VRAM in GB. Returns None if no CUDA GPU."""
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
return None
|
||||
|
||||
|
||||
def select_model_size(user_choice: str | None = None) -> str:
|
||||
"""Select model size based on user choice or available VRAM.
|
||||
|
||||
Args:
|
||||
user_choice: User's explicit model choice, or None for auto-selection.
|
||||
|
||||
Returns:
|
||||
Model size: 'small', 'medium', or 'large'
|
||||
"""
|
||||
if user_choice is not None:
|
||||
return user_choice
|
||||
|
||||
vram = get_vram_gb()
|
||||
|
||||
if vram is None:
|
||||
# No GPU, use medium as a safe default
|
||||
print("No CUDA GPU detected, defaulting to medium model")
|
||||
return "medium"
|
||||
|
||||
# Select based on VRAM:
|
||||
# - large: needs ~10GB VRAM (safe with 12GB+)
|
||||
# - medium: needs ~6GB VRAM (safe with 8GB+)
|
||||
# - small: needs ~3GB VRAM
|
||||
if vram >= VRAM_THRESHOLD_LARGE:
|
||||
selected = "large"
|
||||
elif vram >= VRAM_THRESHOLD_MEDIUM:
|
||||
selected = "medium"
|
||||
else:
|
||||
selected = "small"
|
||||
|
||||
print(f"Auto-selected '{selected}' model based on {vram:.1f}GB VRAM")
|
||||
return selected
|
||||
|
||||
|
||||
def load_model(
|
||||
model_size: str = "medium",
|
||||
) -> tuple: # type: ignore[type-arg]
|
||||
@ -93,7 +171,11 @@ def load_model(
|
||||
device = get_device()
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
# Use safetensors format to avoid torch.load security issues with older PyTorch
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
use_safetensors=True,
|
||||
)
|
||||
model = model.to(device)
|
||||
|
||||
print(f"Model loaded successfully on {device}!")
|
||||
@ -255,10 +337,10 @@ Examples:
|
||||
%(prog)s --model small "jazz guitar solo"
|
||||
%(prog)s --interactive
|
||||
|
||||
Model sizes:
|
||||
small - ~500MB, fastest, lower quality
|
||||
medium - ~3.3GB, good balance (default)
|
||||
large - ~6.5GB, best quality, needs 16GB+ VRAM
|
||||
Model sizes (auto-selected based on VRAM if not specified):
|
||||
small - ~500MB, fastest, lower quality (3GB+ VRAM)
|
||||
medium - ~3.3GB, good balance (8GB+ VRAM)
|
||||
large - ~6.5GB, best quality (12GB+ VRAM)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -278,8 +360,8 @@ Model sizes:
|
||||
"-m",
|
||||
"--model",
|
||||
choices=["small", "medium", "large"],
|
||||
default="medium",
|
||||
help="Model size (default: medium)",
|
||||
default=None,
|
||||
help="Model size (default: auto-select based on VRAM, largest possible)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
@ -305,8 +387,11 @@ Model sizes:
|
||||
if not check_dependencies():
|
||||
sys.exit(1)
|
||||
|
||||
# Select model size based on VRAM if not specified
|
||||
model_size = select_model_size(args.model)
|
||||
|
||||
# Load model
|
||||
model, processor = load_model(args.model)
|
||||
model, processor = load_model(model_size)
|
||||
|
||||
if args.interactive:
|
||||
interactive_mode(model, processor)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user