import spaces  # imported for HuggingFace Spaces GPU support
import torch
import gc
import os
from typing import Optional, List, Dict, Any
from datetime import datetime
from pathlib import Path
import numpy as np
from PIL import Image
import tempfile
import threading
import time

# Model wrapper imports
from models.stt_processor import KyutaiSTTProcessor
from models.text_generator import QwenTextGenerator
from models.image_generator import OmniGenImageGenerator
from models.model_3d_generator import Hunyuan3DGenerator
from models.rigging_processor import UniRigProcessor
from utils.fallbacks import FallbackManager
from utils.caching import ModelCache

class MonsterGenerationPipeline:
    """Main AI pipeline for monster generation"""

    def __init__(self, device: str = "cuda"):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.cache = ModelCache()
        self.fallback_manager = FallbackManager()
        self.models = {}
        self.model_loaded = {
            'stt': False,
            'text_gen': False,
            'image_gen': False,
            '3d_gen': False,
            'rigging': False
        }
        # Pipeline configuration
        self.config = {
            'max_retries': 3,
            'timeout': 180,
            'enable_caching': True,
            'low_vram_mode': False,   # enough VRAM is available on this hardware
            'enable_rigging': False   # rigging disabled by default for faster generation
        }
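
    # Usage note (illustrative): the config flags above can be overridden per
    # instance, e.g. to opt back in to rigging:
    #   pipeline = MonsterGenerationPipeline()
    #   pipeline.config['enable_rigging'] = True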

    def _cleanup_memory(self):
        """Release Python references first, then clear the GPU cache"""
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def _lazy_load_model(self, model_type: str):
        """Lazy loading with memory optimization"""
        if self.model_loaded[model_type]:
            return self.models[model_type]
        # Clear memory before loading a new model
        self._cleanup_memory()
        try:
            if model_type == 'stt':
                self.models['stt'] = KyutaiSTTProcessor(device=self.device)
            elif model_type == 'text_gen':
                self.models['text_gen'] = QwenTextGenerator(device=self.device)
            elif model_type == 'image_gen':
                self.models['image_gen'] = OmniGenImageGenerator(device=self.device)
            elif model_type == '3d_gen':
                self.models['3d_gen'] = Hunyuan3DGenerator(device=self.device)
            elif model_type == 'rigging':
                self.models['rigging'] = UniRigProcessor(device=self.device)
            self.model_loaded[model_type] = True
            return self.models[model_type]
        except Exception as e:
            print(f"Failed to load {model_type}: {e}")
            return None

    def _unload_model(self, model_type: str):
        """Unload a model to free memory"""
        if model_type in self.models and self.model_loaded[model_type]:
            # Don't try to move models to CPU - just delete them.
            # Moving to CPU can fail with meta tensors or when using CPU offloading.
            try:
                # If the model wrapper holds a pipeline, delete that first
                if getattr(self.models[model_type], 'pipeline', None) is not None:
                    del self.models[model_type].pipeline
                # Delete the model wrapper itself
                del self.models[model_type]
            except Exception as e:
                print(f"Warning during model unload: {e}")
            self.model_loaded[model_type] = False
            self._cleanup_memory()

    def generate_monster(self,
                         audio_input: Optional[str] = None,
                         text_input: Optional[str] = None,
                         reference_images: Optional[List] = None,
                         user_id: Optional[str] = None) -> Dict[str, Any]:
        """Main monster generation pipeline"""
        generation_log = {
            'user_id': user_id,
            'timestamp': datetime.now().isoformat(),
            'stages_completed': [],
            'fallbacks_used': [],
            'success': False,
            'errors': []
        }
        # Initialize all outputs up front so the error path at the end can
        # always assemble a response, even when a stage fails early
        description = ""
        monster_traits = {}
        monster_dialogue = ""
        monster_image = None
        model_3d_path = None
        download_files = []
        try:
            print("🚀 Starting monster generation pipeline...")
            # Stage 1: Speech to Text (if audio provided)
            if audio_input and os.path.exists(audio_input):
                try:
                    print("🎤 Processing audio input...")
                    stt_model = self._lazy_load_model('stt')
                    if stt_model:
                        description = stt_model.transcribe(audio_input)
                        generation_log['stages_completed'].append('stt')
                        print(f"✅ STT completed: {description[:100]}...")
                    else:
                        raise Exception("STT model failed to load")
                except Exception as e:
                    print(f"❌ STT failed: {e}")
                    description = text_input or "Create a friendly digital monster"
                    generation_log['fallbacks_used'].append('stt')
                    generation_log['errors'].append(f"STT error: {str(e)}")
                finally:
                    # Unload STT to free memory
                    self._unload_model('stt')
            else:
                description = text_input or "Create a friendly digital monster"
                print(f"📝 Using text input: {description}")
            # Stage 2: Generate monster characteristics
            try:
                print("🧠 Generating monster traits and dialogue...")
                text_gen = self._lazy_load_model('text_gen')
                if text_gen:
                    monster_traits = text_gen.generate_traits(description)
                    monster_dialogue = text_gen.generate_dialogue(monster_traits)
                    generation_log['stages_completed'].append('text_gen')
                    print(f"✅ Text generation completed: {monster_traits.get('name', 'Unknown')}")
                else:
                    raise Exception("Text generation model failed to load")
            except Exception as e:
                print(f"❌ Text generation failed: {e}")
                monster_traits, monster_dialogue = self.fallback_manager.handle_text_gen_failure(description)
                generation_log['fallbacks_used'].append('text_gen')
                generation_log['errors'].append(f"Text generation error: {str(e)}")
            finally:
                self._unload_model('text_gen')
            # Stage 3: Generate monster image
            try:
                print("🎨 Generating monster image...")
                image_gen = self._lazy_load_model('image_gen')
                if image_gen:
                    # Create an enhanced prompt from the generated traits
                    image_prompt = self._create_image_prompt(description, monster_traits)
                    monster_image = image_gen.generate(
                        prompt=image_prompt,
                        reference_images=reference_images,
                        width=512,
                        height=512
                    )
                    generation_log['stages_completed'].append('image_gen')
                    print("✅ Image generation completed")
                else:
                    raise Exception("Image generation model failed to load")
            except Exception as e:
                print(f"❌ Image generation failed: {e}")
                monster_image = self.fallback_manager.handle_image_gen_failure(description)
                generation_log['fallbacks_used'].append('image_gen')
                generation_log['errors'].append(f"Image generation error: {str(e)}")
            finally:
                self._unload_model('image_gen')
            # Stage 4: Convert to 3D model
            model_3d = None
            try:
                print("🔲 Converting to 3D model...")
                model_3d_gen = self._lazy_load_model('3d_gen')
                # "is not None" avoids the ambiguous truth value of ndarray images
                if model_3d_gen and monster_image is not None:
                    # Run 3D generation with a timeout (5 minutes)
                    result = None
                    error = None

                    def generate_3d():
                        nonlocal result, error
                        try:
                            result = model_3d_gen.image_to_3d(monster_image)
                        except Exception as e:
                            error = e

                    # Run 3D generation in a daemon thread; note a timed-out
                    # thread is not killed - it keeps running (and holding GPU
                    # memory) in the background until it finishes on its own
                    thread = threading.Thread(target=generate_3d)
                    thread.daemon = True
                    thread.start()
                    # Wait for completion with timeout
                    timeout = 300  # 5 minutes
                    thread.join(timeout)
                    if thread.is_alive():
                        print(f"⏰ 3D generation timed out after {timeout} seconds")
                        raise Exception(f"3D generation timeout after {timeout} seconds")
                    if error:
                        raise error
                    if result:
                        model_3d = result
                        # Save the 3D model to persistent storage
                        model_3d_path = self._save_3d_model(model_3d, user_id)
                        generation_log['stages_completed'].append('3d_gen')
                        print("✅ 3D generation completed")
                    else:
                        raise Exception("3D generation returned no result")
                else:
                    raise Exception("3D generation skipped - generator unavailable or no source image")
            except Exception as e:
                print(f"❌ 3D generation failed: {e}")
                model_3d = self.fallback_manager.handle_3d_gen_failure(monster_image)
                generation_log['fallbacks_used'].append('3d_gen')
                generation_log['errors'].append(f"3D generation error: {str(e)}")
            finally:
                self._unload_model('3d_gen')
            # Stage 5: Add rigging (optional, skipped by default for performance)
            rigged_model = model_3d
            if model_3d is not None and self.config.get('enable_rigging', False):
                try:
                    print("🦴 Adding rigging...")
                    rigging_proc = self._lazy_load_model('rigging')
                    if rigging_proc:
                        rigged_model = rigging_proc.rig_mesh(model_3d)
                        generation_log['stages_completed'].append('rigging')
                        print("✅ Rigging completed")
                except Exception as e:
                    print(f"❌ Rigging failed: {e}")
                    generation_log['fallbacks_used'].append('rigging')
                    generation_log['errors'].append(f"Rigging error: {str(e)}")
                finally:
                    self._unload_model('rigging')
            # Prepare download files
            download_files = self._prepare_download_files(
                rigged_model if rigged_model is not None else model_3d,
                monster_image,
                user_id
            )
            generation_log['success'] = True
            print("🎉 Monster generation pipeline completed successfully!")
            return {
                'description': description,
                'traits': monster_traits,
                'dialogue': monster_dialogue,
                'image': monster_image,
                'model_3d': model_3d_path,
                'download_files': download_files,
                'generation_log': generation_log,
                'status': 'success'
            }
        except Exception as e:
            print(f"💥 Pipeline error: {e}")
            generation_log['error'] = str(e)
            generation_log['errors'].append(f"Pipeline error: {str(e)}")
            # return self.fallback_generation(description or "digital monster", generation_log)
            # All fields below were initialized before the try block, so this
            # partial response can always be assembled safely
            return {
                'description': description,
                'traits': monster_traits,
                'dialogue': monster_dialogue,
                'image': monster_image,
                'model_3d': model_3d_path,
                'download_files': download_files,
                'generation_log': generation_log,
                'status': 'error'
            }

    def _create_image_prompt(self, base_description: str, traits: Dict) -> str:
        """Create an enhanced prompt for image generation"""
        prompt_parts = [base_description]
        if traits:
            if 'appearance' in traits:
                prompt_parts.append(traits['appearance'])
            if 'personality' in traits:
                prompt_parts.append(f"with {traits['personality']} personality")
            if 'color_scheme' in traits:
                prompt_parts.append(f"featuring {traits['color_scheme']} colors")
        prompt_parts.extend([
            "digital monster",
            "creature design",
            "game character",
            "high quality",
            "detailed"
        ])
        return ", ".join(prompt_parts)

    def _save_3d_model(self, model_3d, user_id: Optional[str]) -> Optional[str]:
        """Save 3D model to persistent storage"""
        if model_3d is None:
            return None
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        user_id_str = user_id or "anonymous"
        filename = f"monster_{user_id_str}_{timestamp}.glb"
        # Use HuggingFace Spaces persistent storage when mounted
        if os.path.exists("/data"):
            filepath = f"/data/models/{filename}"
        else:
            filepath = f"./data/models/{filename}"
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        # Save the model; the exact call depends on the mesh library in use
        if hasattr(model_3d, 'export'):
            # e.g. a trimesh-style export infers the GLB format from the extension
            model_3d.export(filepath)
        else:
            # Fallback: persist whatever representation we have as raw bytes
            with open(filepath, 'wb') as f:
                f.write(str(model_3d).encode())
        return filepath

    def _prepare_download_files(self, model_3d, image, user_id: Optional[str]) -> List[str]:
        """Prepare downloadable files for the user"""
        files = []
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        user_id_str = user_id or "anonymous"
        # Save image ("is not None" avoids the ambiguous truth value of arrays)
        if image is not None:
            image_path = f"/tmp/monster_{user_id_str}_{timestamp}.png"
            if isinstance(image, Image.Image):
                image.save(image_path)
                files.append(image_path)
            elif isinstance(image, np.ndarray):
                Image.fromarray(image).save(image_path)
                files.append(image_path)
        # Save the 3D model in multiple formats when the mesh supports export;
        # only paths that were actually written are returned
        if model_3d is not None and hasattr(model_3d, 'export'):
            for ext in ('glb', 'obj'):
                model_path = f"/tmp/monster_{user_id_str}_{timestamp}.{ext}"
                try:
                    model_3d.export(model_path)
                    files.append(model_path)
                except Exception as e:
                    print(f"Warning: {ext.upper()} export failed: {e}")
        return files

    def fallback_generation(self, description: str, generation_log: Dict) -> Dict[str, Any]:
        """Complete fallback generation when the pipeline fails"""
        return self.fallback_manager.complete_fallback_generation(description, generation_log)

    def cleanup(self):
        """Clean up all loaded models"""
        for model_type in list(self.models.keys()):
            self._unload_model(model_type)
        self._cleanup_memory()
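
# Minimal usage sketch (illustrative; assumes the model wrapper modules above
# are importable and that generate_monster returns the dict built in the code):
if __name__ == "__main__":
    pipeline = MonsterGenerationPipeline()
    result = pipeline.generate_monster(
        text_input="a small crystal golem with glowing eyes"
    )
    print(result['status'], result.get('traits', {}).get('name', 'Unknown'))
    pipeline.cleanup()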