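# NOTE(review): "Spaces: Paused / Paused" appears to be page-status text that was
# captured together with the code; it is kept as a comment because it is not Python.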
import os
import tempfile
from pathlib import Path
from typing import Union, Optional, Dict, Any

import numpy as np
import torch
import trimesh
from PIL import Image
class Hunyuan3DGenerator:
    """3D model generation using Hunyuan3D-2.1.

    The model is loaded lazily on the first :meth:`image_to_3d` call. If it
    cannot be loaded, or generation raises, a mesh built from the image's
    brightness (a height map) is produced instead, so ``image_to_3d`` still
    returns a file path.

    Annotations that name PIL or trimesh types are written as strings so
    they are not evaluated when the class is defined.
    """

    # Marker stored in ``self.model`` once loading has failed, so the load is
    # not retried on every call.
    _FALLBACK = "fallback"

    # Total device memory must exceed this many bytes (12 GiB) for the full
    # model to be selected instead of the lite one.
    _MIN_FULL_MODEL_VRAM = 12 * 1024 * 1024 * 1024

    def __init__(self, device: str = "cuda"):
        """Set up configuration; no model weights are loaded here.

        Args:
            device: Requested torch device. Replaced by ``"cpu"`` when CUDA
                is not available.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = None
        self.preprocessor = None

        # Model configuration
        self.model_id = "tencent/Hunyuan3D-2.1"
        self.lite_model_id = "tencent/Hunyuan3D-2.1-Lite"  # For low VRAM

        # Generation parameters
        self.num_inference_steps = 50
        self.guidance_scale = 7.5
        self.resolution = 256  # 3D resolution

        # Use lite model for low VRAM
        self.use_lite = self.device == "cpu" or not self._check_vram()

    def _check_vram(self) -> bool:
        """Return True when CUDA device 0 has more than 12 GiB of memory."""
        if not torch.cuda.is_available():
            return False
        try:
            total_bytes = torch.cuda.get_device_properties(0).total_memory
        except Exception:
            # Best effort: if the device cannot be queried, act as low-VRAM.
            return False
        return total_bytes > self._MIN_FULL_MODEL_VRAM

    def _has_model(self) -> bool:
        """Return True when ``self.model`` holds a real, loaded model."""
        model = getattr(self, "model", None)
        if model is None:
            return False
        return not (isinstance(model, str) and model == self._FALLBACK)

    def load_model(self):
        """Lazy load the 3D generation model.

        Does nothing if a model (or the fallback marker) is already set. On
        any loading error the error is printed and ``self.model`` is set to
        the fallback marker; no exception is raised to the caller.
        """
        if self.model is not None:
            return

        try:
            # Import Hunyuan3D components
            from transformers import AutoModel, AutoProcessor

            model_id = self.lite_model_id if self.use_lite else self.model_id

            # Treat "cuda" and indexed forms such as "cuda:1" alike for dtype.
            on_gpu = str(self.device).startswith("cuda")
            # Automatic placement is only requested for the plain "cuda" device.
            auto_placement = self.device == "cuda"

            preprocessor = AutoProcessor.from_pretrained(model_id)

            # NOTE: trust_remote_code=True lets the model repository's own
            # Python code run during loading; only use trusted model ids.
            model = AutoModel.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
                low_cpu_mem_usage=True,
                device_map="auto" if auto_placement else None,
                trust_remote_code=True,
            )

            if not auto_placement:
                # Without a device map the model must be moved explicitly.
                model = model.to(self.device)

            # Enable optimizations
            if hasattr(model, "enable_attention_slicing"):
                model.enable_attention_slicing()
        except Exception as e:
            print(f"Failed to load Hunyuan3D model: {e}")
            # Model loading failed, will use fallback
            self.preprocessor = None
            self.model = self._FALLBACK
            return

        self.preprocessor = preprocessor
        self.model = model

    @staticmethod
    def _load_image(image) -> "Image.Image":
        """Return *image* as a PIL image, opening paths and wrapping arrays."""
        if isinstance(image, (str, os.PathLike)):
            return Image.open(image)
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        return image

    def image_to_3d(self,
                    image: "Union[str, Image.Image, np.ndarray]",
                    remove_background: bool = True,
                    texture_resolution: int = 1024) -> "Union[str, trimesh.Trimesh]":
        """Convert 2D image to 3D model.

        Args:
            image: File path, PIL image, or array accepted by
                ``Image.fromarray``.
            remove_background: Strip the background before generation.
            texture_resolution: Forwarded to the model's ``generate`` call.

        Returns:
            Path of the exported ``.glb`` file. When the model is unavailable
            or generation fails, the file holds the height-map fallback mesh.
        """
        # Load model if needed
        if self.model is None:
            self.load_model()

        # If model loading failed, use fallback
        if not self._has_model():
            return self._generate_fallback_3d(image)

        try:
            # Prepare image: RGB, fixed 512x512 size for processing
            prepared = self._load_image(image)
            if prepared.mode != "RGB":
                prepared = prepared.convert("RGB")
            prepared = prepared.resize((512, 512), Image.Resampling.LANCZOS)

            if remove_background:
                prepared = self._remove_background(prepared)

            # Process with model
            with torch.no_grad():
                inputs = self.preprocessor(images=prepared, return_tensors="pt").to(self.device)
                outputs = self.model.generate(
                    **inputs,
                    num_inference_steps=self.num_inference_steps,
                    guidance_scale=self.guidance_scale,
                    texture_resolution=texture_resolution,
                )

            mesh = self._extract_mesh(outputs)
            return self._save_mesh(mesh)
        except Exception as e:
            print(f"3D generation error: {e}")
            # Fall back using the caller's original input, not the resized or
            # background-stripped intermediate.
            return self._generate_fallback_3d(image)

    def _remove_background(self, image: "Image.Image") -> "Image.Image":
        """Remove background from image.

        Uses ``rembg`` when it can be imported and runs successfully;
        otherwise makes near-white pixels (all of R, G, B above 230)
        transparent white.
        """
        try:
            # Try using rembg if available
            from rembg import remove
            return remove(image)
        except Exception:
            # Best effort: rembg missing or failing falls through to the
            # simple white-background removal below.
            pass

        rgba = np.array(image.convert("RGBA"))
        near_white = (rgba[..., :3] > 230).all(axis=-1)
        rgba[near_white] = (255, 255, 255, 0)
        return Image.fromarray(rgba)

    def _extract_mesh(self, model_outputs: Dict[str, Any]) -> "trimesh.Trimesh":
        """Extract mesh from model outputs.

        Placeholder implementation: the real Hunyuan3D output format is not
        handled here. Outputs with ``'vertices'`` and ``'faces'`` entries are
        turned into a mesh; anything else yields the placeholder sphere.
        """
        if 'vertices' in model_outputs and 'faces' in model_outputs:
            vertices = model_outputs['vertices'].detach().cpu().numpy()
            faces = model_outputs['faces'].detach().cpu().numpy()
            # TODO: apply model_outputs['texture'] when present; it is
            # currently ignored.
            return trimesh.Trimesh(vertices=vertices, faces=faces)

        # Create a simple mesh if outputs are different
        return self._create_simple_mesh()

    def _create_simple_mesh(self) -> "trimesh.Trimesh":
        """Create a simple placeholder mesh: a randomly perturbed unit sphere."""
        mesh = trimesh.creation.icosphere(subdivisions=3, radius=1.0)
        # Add some variation
        mesh.vertices += np.random.normal(0, 0.05, mesh.vertices.shape)
        # Smooth the mesh
        return mesh.smoothed()

    def _generate_fallback_3d(self, image: "Union[str, Image.Image, np.ndarray]") -> str:
        """Generate fallback 3D model when main model fails.

        The image is reduced to a 64x64 brightness map (mean of the RGB
        channels, scaled to 0..1) that is used as a height map.

        Returns:
            Path of the exported ``.glb`` file.
        """
        picture = self._load_image(image)
        # Normalize to three channels so grayscale and palette images do not
        # break the channel average and alpha is not counted as brightness.
        if picture.mode != "RGB":
            picture = picture.convert("RGB")

        pixels = np.array(picture.resize((64, 64)))
        height_map = np.mean(pixels, axis=2) / 255.0

        mesh = self._heightmap_to_mesh(height_map)
        return self._save_mesh(mesh)

    @staticmethod
    def _heightmap_grid(heightmap: np.ndarray):
        """Build vertex and face arrays for a regular grid over *heightmap*.

        Vertices are listed row by row; x and y are scaled by the grid width
        and height and z is half the height value. Each grid square
        contributes two triangles.

        Returns:
            ``(vertices, faces)`` with shapes ``(h*w, 3)`` and
            ``(2*(h-1)*(w-1), 3)``.
        """
        heights = np.asarray(heightmap, dtype=float)
        h, w = heights.shape

        xs = (np.arange(w) - w / 2) / w * 2
        ys = (np.arange(h) - h / 2) / h * 2
        grid_x, grid_y = np.meshgrid(xs, ys)
        vertices = np.column_stack(
            [grid_x.ravel(), grid_y.ravel(), (heights * 0.5).ravel()]
        )

        # Top-left vertex index of every grid square, in row-major order.
        v1 = np.arange(h * w).reshape(h, w)[:-1, :-1].ravel()
        v2 = v1 + 1
        v3 = v1 + w
        v4 = v3 + 1
        # Two triangles per grid square, kept adjacent in the output.
        faces = np.stack([v1, v2, v3, v2, v4, v3], axis=1).reshape(2 * len(v1), 3)
        return vertices, faces

    def _heightmap_to_mesh(self, heightmap: np.ndarray) -> "trimesh.Trimesh":
        """Convert heightmap to 3D mesh."""
        vertices, faces = self._heightmap_grid(heightmap)
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
        # Apply smoothing
        return mesh.smoothed()

    def _save_mesh(self, mesh: "trimesh.Trimesh") -> str:
        """Save mesh to a new temporary ``.glb`` file and return its path.

        The file is not deleted automatically; the caller owns it.
        """
        with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as tmp:
            mesh_path = tmp.name
        # Export only after the handle is closed: a file that is still open
        # cannot be reopened by name on every platform.
        mesh.export(mesh_path)
        return mesh_path

    def text_to_3d(self, text_prompt: str) -> str:
        """Generate 3D model from text description.

        Raises:
            NotImplementedError: Always; an image generator is required
                first and is not integrated here.
        """
        raise NotImplementedError("Text to 3D requires image generation first")

    def to(self, device: str):
        """Move model to specified device.

        Records *device* and, if a real model is loaded, moves it.

        Returns:
            This generator, so calls can be chained.
        """
        self.device = device
        if self._has_model():
            self.model.to(device)
        return self

    def __del__(self):
        """Cleanup when object is destroyed."""
        try:
            # Drop references first so cached GPU memory can be released.
            self.model = None
            self.preprocessor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Finalizers may run during interpreter shutdown when module
            # globals are already gone; nothing useful can be done then.
            pass