mirror of
https://github.com/kuhyx/testsAndMisc-archive.git
synced 2026-07-04 14:23:04 +02:00
music_gen: force CUDA for NVIDIA GPUs, auto-select model by VRAM
- Fail fast if NVIDIA GPU detected but CUDA unavailable (no CPU fallback)
- Auto-select largest model based on VRAM (large=12GB+, medium=8GB+)
- Remove torchaudio dependency (scipy handles audio I/O)
- Use safetensors format to avoid torch.load security issues
This commit is contained in:
parent
150230caf8
commit
96564e121b
@ -23,6 +23,10 @@ import warnings
|
|||||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
# VRAM thresholds for model selection (in GB)
|
||||||
|
VRAM_THRESHOLD_LARGE = 12 # Use large model with 12GB+ VRAM
|
||||||
|
VRAM_THRESHOLD_MEDIUM = 8 # Use medium model with 8GB+ VRAM
|
||||||
|
|
||||||
|
|
||||||
def check_dependencies() -> bool:
|
def check_dependencies() -> bool:
|
||||||
"""Check if required packages are installed."""
|
"""Check if required packages are installed."""
|
||||||
@ -33,30 +37,61 @@ def check_dependencies() -> bool:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
missing.append("torch")
|
missing.append("torch")
|
||||||
|
|
||||||
try:
|
|
||||||
import torchaudio # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
missing.append("torchaudio")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import transformers # noqa: F401
|
import transformers # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
missing.append("transformers")
|
missing.append("transformers")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import scipy # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
missing.append("scipy")
|
||||||
|
|
||||||
if missing:
|
if missing:
|
||||||
print("Missing dependencies. Install with:")
|
print("Missing dependencies. Install with:")
|
||||||
print(f" pip install {' '.join(missing)}")
|
print(f" pip install {' '.join(missing)}")
|
||||||
print("\nOr run the full setup:")
|
print("\nFor CUDA support:")
|
||||||
print(" pip install torch torchaudio transformers scipy")
|
print(" pip install torch --index-url https://download.pytorch.org/whl/cu121")
|
||||||
|
print(" pip install transformers scipy")
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_device() -> str:
|
def get_device() -> str:
|
||||||
"""Get the best available device (CUDA, MPS, or CPU)."""
|
"""Get the best available device (CUDA or MPS). No CPU fallback for NVIDIA.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If NVIDIA GPU is detected but CUDA is not available.
|
||||||
|
"""
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
# Check for NVIDIA GPU first
|
||||||
|
nvidia_gpu_present = False
|
||||||
|
try:
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
nvidia_smi_path = shutil.which("nvidia-smi")
|
||||||
|
if nvidia_smi_path:
|
||||||
|
result = subprocess.run(
|
||||||
|
[nvidia_smi_path],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
nvidia_gpu_present = result.returncode == 0
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if nvidia_gpu_present:
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
msg = (
|
||||||
|
"NVIDIA GPU detected but CUDA is not available!\n"
|
||||||
|
"Please install PyTorch with CUDA support:\n"
|
||||||
|
" pip install torch torchaudio --index-url "
|
||||||
|
"https://download.pytorch.org/whl/cu121"
|
||||||
|
)
|
||||||
|
raise RuntimeError(msg)
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
gpu_name = torch.cuda.get_device_name(0)
|
gpu_name = torch.cuda.get_device_name(0)
|
||||||
vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||||
@ -70,6 +105,49 @@ def get_device() -> str:
|
|||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
def get_vram_gb() -> float | None:
|
||||||
|
"""Get available VRAM in GB. Returns None if no CUDA GPU."""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def select_model_size(user_choice: str | None = None) -> str:
|
||||||
|
"""Select model size based on user choice or available VRAM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_choice: User's explicit model choice, or None for auto-selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model size: 'small', 'medium', or 'large'
|
||||||
|
"""
|
||||||
|
if user_choice is not None:
|
||||||
|
return user_choice
|
||||||
|
|
||||||
|
vram = get_vram_gb()
|
||||||
|
|
||||||
|
if vram is None:
|
||||||
|
# No GPU, use medium as a safe default
|
||||||
|
print("No CUDA GPU detected, defaulting to medium model")
|
||||||
|
return "medium"
|
||||||
|
|
||||||
|
# Select based on VRAM:
|
||||||
|
# - large: needs ~10GB VRAM (safe with 12GB+)
|
||||||
|
# - medium: needs ~6GB VRAM (safe with 8GB+)
|
||||||
|
# - small: needs ~3GB VRAM
|
||||||
|
if vram >= VRAM_THRESHOLD_LARGE:
|
||||||
|
selected = "large"
|
||||||
|
elif vram >= VRAM_THRESHOLD_MEDIUM:
|
||||||
|
selected = "medium"
|
||||||
|
else:
|
||||||
|
selected = "small"
|
||||||
|
|
||||||
|
print(f"Auto-selected '{selected}' model based on {vram:.1f}GB VRAM")
|
||||||
|
return selected
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
model_size: str = "medium",
|
model_size: str = "medium",
|
||||||
) -> tuple: # type: ignore[type-arg]
|
) -> tuple: # type: ignore[type-arg]
|
||||||
@ -93,7 +171,11 @@ def load_model(
|
|||||||
device = get_device()
|
device = get_device()
|
||||||
|
|
||||||
processor = AutoProcessor.from_pretrained(model_name)
|
processor = AutoProcessor.from_pretrained(model_name)
|
||||||
model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
# Use safetensors format to avoid torch.load security issues with older PyTorch
|
||||||
|
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
use_safetensors=True,
|
||||||
|
)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
print(f"Model loaded successfully on {device}!")
|
print(f"Model loaded successfully on {device}!")
|
||||||
@ -255,10 +337,10 @@ Examples:
|
|||||||
%(prog)s --model small "jazz guitar solo"
|
%(prog)s --model small "jazz guitar solo"
|
||||||
%(prog)s --interactive
|
%(prog)s --interactive
|
||||||
|
|
||||||
Model sizes:
|
Model sizes (auto-selected based on VRAM if not specified):
|
||||||
small - ~500MB, fastest, lower quality
|
small - ~500MB, fastest, lower quality (3GB+ VRAM)
|
||||||
medium - ~3.3GB, good balance (default)
|
medium - ~3.3GB, good balance (8GB+ VRAM)
|
||||||
large - ~6.5GB, best quality, needs 16GB+ VRAM
|
large - ~6.5GB, best quality (12GB+ VRAM)
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -278,8 +360,8 @@ Model sizes:
|
|||||||
"-m",
|
"-m",
|
||||||
"--model",
|
"--model",
|
||||||
choices=["small", "medium", "large"],
|
choices=["small", "medium", "large"],
|
||||||
default="medium",
|
default=None,
|
||||||
help="Model size (default: medium)",
|
help="Model size (default: auto-select based on VRAM, largest possible)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-i",
|
"-i",
|
||||||
@ -305,8 +387,11 @@ Model sizes:
|
|||||||
if not check_dependencies():
|
if not check_dependencies():
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Select model size based on VRAM if not specified
|
||||||
|
model_size = select_model_size(args.model)
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, processor = load_model(args.model)
|
model, processor = load_model(model_size)
|
||||||
|
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
interactive_mode(model, processor)
|
interactive_mode(model, processor)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user