diff --git a/python_pkg/music_gen/music_generator.py b/python_pkg/music_gen/music_generator.py index a296f0d..212898b 100755 --- a/python_pkg/music_gen/music_generator.py +++ b/python_pkg/music_gen/music_generator.py @@ -23,6 +23,10 @@ import warnings warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +# VRAM thresholds for model selection (in GB) +VRAM_THRESHOLD_LARGE = 12 # Use large model with 12GB+ VRAM +VRAM_THRESHOLD_MEDIUM = 8 # Use medium model with 8GB+ VRAM + def check_dependencies() -> bool: """Check if required packages are installed.""" @@ -33,30 +37,61 @@ def check_dependencies() -> bool: except ImportError: missing.append("torch") - try: - import torchaudio # noqa: F401 - except ImportError: - missing.append("torchaudio") - try: import transformers # noqa: F401 except ImportError: missing.append("transformers") + try: + import scipy # noqa: F401 + except ImportError: + missing.append("scipy") + if missing: print("Missing dependencies. Install with:") print(f" pip install {' '.join(missing)}") - print("\nOr run the full setup:") - print(" pip install torch torchaudio transformers scipy") + print("\nFor CUDA support:") + print(" pip install torch --index-url https://download.pytorch.org/whl/cu121") + print(" pip install transformers scipy") return False return True def get_device() -> str: - """Get the best available device (CUDA, MPS, or CPU).""" + """Get the best available device (CUDA or MPS). No CPU fallback for NVIDIA. + + Raises: + RuntimeError: If NVIDIA GPU is detected but CUDA is not available. 
+ """ import torch - if torch.cuda.is_available(): + # Probe nvidia-smi to catch a CUDA-less torch build on an NVIDIA host + nvidia_gpu_present = False + try: + import shutil + import subprocess + + nvidia_smi_path = shutil.which("nvidia-smi") + if nvidia_smi_path: + result = subprocess.run( + [nvidia_smi_path], + capture_output=True, + text=True, + check=False, + ) + nvidia_gpu_present = result.returncode == 0 + except OSError: + pass + + if nvidia_gpu_present or torch.cuda.is_available(): + if not torch.cuda.is_available(): + msg = ( + "NVIDIA GPU detected but CUDA is not available!\n" + "Please install PyTorch with CUDA support:\n" + " pip install torch --index-url " + "https://download.pytorch.org/whl/cu121" + ) + raise RuntimeError(msg) device = "cuda" gpu_name = torch.cuda.get_device_name(0) vram = torch.cuda.get_device_properties(0).total_memory / 1024**3 @@ -70,6 +105,49 @@ def get_device() -> str: return device +def get_vram_gb() -> float | None: + """Get total VRAM of CUDA device 0 in GB. Returns None if no CUDA GPU.""" + import torch + + if torch.cuda.is_available(): + return torch.cuda.get_device_properties(0).total_memory / 1024**3 + return None + + +def select_model_size(user_choice: str | None = None) -> str: + """Select model size based on user choice or available VRAM. + + Args: + user_choice: User's explicit model choice, or None for auto-selection. 
+ + Returns: + Model size: 'small', 'medium', or 'large' + """ + if user_choice is not None: + return user_choice + + vram = get_vram_gb() + + if vram is None: + # No GPU, use medium as a safe default + print("No CUDA GPU detected, defaulting to medium model") + return "medium" + + # Select based on VRAM: + # - large: needs ~10GB VRAM (safe with 12GB+) + # - medium: needs ~6GB VRAM (safe with 8GB+) + # - small: needs ~3GB VRAM + if vram >= VRAM_THRESHOLD_LARGE: + selected = "large" + elif vram >= VRAM_THRESHOLD_MEDIUM: + selected = "medium" + else: + selected = "small" + + print(f"Auto-selected '{selected}' model based on {vram:.1f}GB VRAM") + return selected + + def load_model( model_size: str = "medium", ) -> tuple: # type: ignore[type-arg] @@ -93,7 +171,11 @@ def load_model( device = get_device() processor = AutoProcessor.from_pretrained(model_name) - model = MusicgenForConditionalGeneration.from_pretrained(model_name) + # Use safetensors format to avoid torch.load security issues with older PyTorch + model = MusicgenForConditionalGeneration.from_pretrained( + model_name, + use_safetensors=True, + ) model = model.to(device) print(f"Model loaded successfully on {device}!") @@ -255,10 +337,10 @@ Examples: %(prog)s --model small "jazz guitar solo" %(prog)s --interactive -Model sizes: - small - ~500MB, fastest, lower quality - medium - ~3.3GB, good balance (default) - large - ~6.5GB, best quality, needs 16GB+ VRAM +Model sizes (auto-selected based on VRAM if not specified): + small - ~500MB, fastest, lower quality (3GB+ VRAM) + medium - ~3.3GB, good balance (8GB+ VRAM) + large - ~6.5GB, best quality (12GB+ VRAM) """, ) @@ -278,8 +360,8 @@ Model sizes: "-m", "--model", choices=["small", "medium", "large"], - default="medium", - help="Model size (default: medium)", + default=None, + help="Model size (default: auto-select based on VRAM, largest possible)", ) parser.add_argument( "-i", @@ -305,8 +387,11 @@ Model sizes: if not check_dependencies(): sys.exit(1) + # Select 
model size based on VRAM if not specified + model_size = select_model_size(args.model) + # Load model - model, processor = load_model(args.model) + model, processor = load_model(model_size) if args.interactive: interactive_mode(model, processor)