import torch
import numpy as np
from PIL import Image
import trimesh
import tempfile
from typing import Union, Optional, Dict, Any
from pathlib import Path
import os


class Hunyuan3DGenerator:
    """3D model generation using Hunyuan3D-2.1.

    The model is loaded lazily on the first call to :meth:`image_to_3d`.
    If loading or generation fails, a simple height-map mesh derived from
    the image brightness is produced instead, so callers always get a mesh
    file path back.
    """

    # Sentinel stored in ``self.model`` when loading failed. Kept as the
    # plain string "fallback" so external checks against it keep working.
    _FALLBACK = "fallback"

    # Minimum total VRAM (bytes) required to run the full (non-lite) model.
    _FULL_MODEL_MIN_VRAM = 12 * 1024 * 1024 * 1024

    def __init__(self, device: str = "cuda"):
        """Configure the generator without loading any weights.

        Args:
            device: Preferred torch device string (e.g. ``"cuda"``,
                ``"cuda:1"``, ``"cpu"``). Replaced by ``"cpu"`` when CUDA is
                not available.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = None
        self.preprocessor = None

        # Model configuration
        self.model_id = "tencent/Hunyuan3D-2.1"
        self.lite_model_id = "tencent/Hunyuan3D-2.1-Lite"  # For low VRAM

        # Generation parameters
        self.num_inference_steps = 50
        self.guidance_scale = 7.5
        self.resolution = 256  # 3D resolution

        # Use lite model on CPU or when the GPU has too little VRAM
        self.use_lite = self.device == "cpu" or not self._check_vram()

    def _is_cuda_device(self) -> bool:
        """Return True when ``self.device`` names a CUDA device, indexed or not."""
        return str(self.device).startswith("cuda")

    def _has_real_model(self) -> bool:
        """Return True when ``self.model`` is a loaded model, not None/sentinel."""
        model = getattr(self, "model", None)
        if model is None:
            return False
        return not (isinstance(model, str) and model == self._FALLBACK)

    def _check_vram(self) -> bool:
        """Check if we have enough VRAM for the full model.

        Returns:
            True when the selected CUDA device reports at least 12 GiB of
            total memory; False when CUDA is unavailable, the configured
            device is not a CUDA device, or the query fails.
        """
        if not torch.cuda.is_available() or not self._is_cuda_device():
            return False

        try:
            # Query the configured device rather than always GPU 0 so that
            # "cuda:1" etc. are judged on their own memory.
            props = torch.cuda.get_device_properties(torch.device(self.device))
        except Exception:
            # Best-effort probe: an unusable device means "not enough VRAM".
            return False

        return props.total_memory >= self._FULL_MODEL_MIN_VRAM

    def load_model(self):
        """Lazy load the 3D generation model.

        Does nothing when a model (or the fallback sentinel) is already set.
        On any failure the error is printed and ``self.model`` is set to the
        fallback sentinel so later calls use the height-map fallback.
        """
        if self.model is not None:
            return

        try:
            # Import Hunyuan3D components
            from transformers import AutoModel, AutoProcessor

            model_id = self.lite_model_id if self.use_lite else self.model_id
            on_cuda = self._is_cuda_device()

            # Load preprocessor
            self.preprocessor = AutoProcessor.from_pretrained(model_id)

            # Half precision on any CUDA device, full precision elsewhere.
            torch_dtype = torch.float16 if on_cuda else torch.float32

            # Plain "cuda" lets the loader place weights automatically; an
            # indexed device ("cuda:1") is passed through so the whole model
            # lands on the device the inputs will be sent to.
            if self.device == "cuda":
                device_map = "auto"
            elif on_cuda:
                device_map = self.device
            else:
                device_map = None

            self.model = AutoModel.from_pretrained(
                model_id,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                device_map=device_map,
                trust_remote_code=True
            )

            # Without a device map the loader leaves placement to us.
            if device_map is None:
                self.model = self.model.to(self.device)

            # Enable optimizations
            if hasattr(self.model, 'enable_attention_slicing'):
                self.model.enable_attention_slicing()

        except Exception as e:
            print(f"Failed to load Hunyuan3D model: {e}")
            # Model loading failed, will use fallback
            self.model = self._FALLBACK

    def image_to_3d(self,
                    image: Union[str, Image.Image, np.ndarray],
                    remove_background: bool = True,
                    texture_resolution: int = 1024) -> Union[str, trimesh.Trimesh]:
        """Convert 2D image to 3D model.

        Args:
            image: File path, PIL image, or array accepted by
                ``Image.fromarray``.
            remove_background: Strip the background before generation.
            texture_resolution: Forwarded to the model's ``generate`` call.

        Returns:
            Path to the exported ``.glb`` mesh file. When the model is
            unavailable or generation raises, the path of a fallback
            height-map mesh is returned instead.
        """
        try:
            # Load model if needed
            if self.model is None:
                self.load_model()

            # If model loading failed, use fallback
            if not self._has_real_model():
                return self._generate_fallback_3d(image)

            # Prepare image
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            # Ensure RGB
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Resize for processing
            image = image.resize((512, 512), Image.Resampling.LANCZOS)

            # Remove background if requested
            if remove_background:
                image = self._remove_background(image)

            # Process with model
            with torch.no_grad():
                # Preprocess image
                inputs = self.preprocessor(images=image, return_tensors="pt").to(self.device)

                # Generate 3D
                outputs = self.model.generate(
                    **inputs,
                    num_inference_steps=self.num_inference_steps,
                    guidance_scale=self.guidance_scale,
                    texture_resolution=texture_resolution
                )

            # Extract mesh, save it, and hand back the file path
            mesh = self._extract_mesh(outputs)
            return self._save_mesh(mesh)

        except Exception as e:
            print(f"3D generation error: {e}")
            return self._generate_fallback_3d(image)

    def _remove_background(self, image: Image.Image) -> Image.Image:
        """Remove background from image.

        Uses ``rembg`` when it can be imported and runs successfully;
        otherwise makes near-white pixels fully transparent.

        Args:
            image: Input PIL image (not modified).

        Returns:
            The ``rembg`` result, or an RGBA copy with near-white pixels
            replaced by transparent white.
        """
        try:
            # Try using rembg if available
            from rembg import remove
            return remove(image)
        except Exception:
            # Deliberate best-effort: a missing package or a runtime failure
            # inside rembg both drop to the simple white-background removal.
            return self._strip_near_white(image)

    @staticmethod
    def _strip_near_white(image: Image.Image, threshold: int = 230) -> Image.Image:
        """Return an RGBA copy where pixels with R, G and B all above
        ``threshold`` become transparent white ``(255, 255, 255, 0)``."""
        rgba = np.array(image.convert("RGBA"))
        near_white = np.all(rgba[..., :3] > threshold, axis=-1)
        rgba[near_white] = (255, 255, 255, 0)
        return Image.fromarray(rgba)

    def _extract_mesh(self, model_outputs: Dict[str, Any]) -> trimesh.Trimesh:
        """Extract mesh from model outputs.

        Placeholder implementation: the real Hunyuan3D output format is not
        established here. When the outputs contain ``'vertices'`` and
        ``'faces'`` entries they are moved to CPU and used directly;
        otherwise a simple placeholder mesh is returned.
        """
        if 'vertices' not in model_outputs or 'faces' not in model_outputs:
            # Create a simple mesh if outputs are different
            return self._create_simple_mesh()

        vertices = model_outputs['vertices'].cpu().numpy()
        faces = model_outputs['faces'].cpu().numpy()

        # Create trimesh object
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

        # Add texture if available
        if 'texture' in model_outputs:
            # Texture application is not implemented yet
            pass

        return mesh

    def _create_simple_mesh(self) -> trimesh.Trimesh:
        """Create a simple placeholder mesh (a randomly perturbed sphere)."""
        # Create a simple sphere as placeholder
        mesh = trimesh.creation.icosphere(subdivisions=3, radius=1.0)

        # Add some variation
        mesh.vertices += np.random.normal(0, 0.05, mesh.vertices.shape)

        # Smooth the mesh
        mesh = mesh.smoothed()

        return mesh

    def _generate_fallback_3d(self, image: Union[Image.Image, np.ndarray]) -> str:
        """Generate fallback 3D model when main model fails.

        Builds a 64x64 height map from the image brightness and exports it
        as a mesh.

        Args:
            image: PIL image, array, or file path.

        Returns:
            Path to the exported mesh file.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, str):
            image = Image.open(image)

        # Normalise to three channels first: greyscale/palette images give a
        # 2-D array (no channel axis to average over) and RGBA would mix the
        # alpha channel into the brightness.
        image_array = np.array(image.convert("RGB").resize((64, 64)))

        # Create height map from image brightness
        gray = np.mean(image_array, axis=2)
        height_map = gray / 255.0

        # Create mesh from height map
        mesh = self._heightmap_to_mesh(height_map)

        # Save and return path
        return self._save_mesh(mesh)

    def _heightmap_to_mesh(self, heightmap: np.ndarray) -> trimesh.Trimesh:
        """Convert heightmap to 3D mesh.

        Args:
            heightmap: 2-D array of heights; rows map to the y axis and
                columns to the x axis.

        Returns:
            A mesh with one vertex per heightmap cell (row-major order,
            x and y spanning [-1, 1), z equal to half the height) and two
            triangles per grid square.
        """
        h, w = heightmap.shape

        # Vertex grid, flattened row-major so vertex index == i * w + j
        cols, rows = np.meshgrid(np.arange(w), np.arange(h))
        x = (cols - w / 2) / w * 2
        y = (rows - h / 2) / h * 2
        z = np.asarray(heightmap) * 0.5
        vertices = np.column_stack([x.ravel(), y.ravel(), z.ravel()])

        # Two triangles per grid square, built from each square's corner
        v1 = (np.arange(h - 1)[:, None] * w + np.arange(w - 1)[None, :]).ravel()
        v2 = v1 + 1
        v3 = v1 + w
        v4 = v3 + 1
        first = np.stack([v1, v2, v3], axis=1)
        second = np.stack([v2, v4, v3], axis=1)
        # Interleave so both triangles of a square stay adjacent
        faces = np.stack([first, second], axis=1).reshape(-1, 3)

        # Create mesh
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

        # Apply smoothing
        mesh = mesh.smoothed()

        return mesh

    def _save_mesh(self, mesh: trimesh.Trimesh) -> str:
        """Save mesh to a temporary ``.glb`` file and return its path.

        The file is created with ``delete=False``; removing it is left to
        the caller.
        """
        # Reserve a temp file name; the handle is closed before exporting
        with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as tmp:
            mesh_path = tmp.name

        # Export mesh
        mesh.export(mesh_path)

        return mesh_path

    def text_to_3d(self, text_prompt: str) -> str:
        """Generate 3D model from text description.

        Raises:
            NotImplementedError: Always; an image must be generated from
                the text first, which needs image generator integration.
        """
        raise NotImplementedError("Text to 3D requires image generation first")

    def to(self, device: str):
        """Move model to specified device.

        Args:
            device: Target torch device string. Recorded in ``self.device``
                even when no model is loaded yet.
        """
        self.device = device

        if self._has_real_model():
            self.model.to(device)

    def __del__(self):
        """Cleanup when object is destroyed."""
        # __del__ can run on a partially constructed instance or during
        # interpreter shutdown, so never assume attributes or modules exist.
        try:
            self.model = None
            self.preprocessor = None
            if torch is not None and torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Errors here cannot be handled usefully; a destructor must not
            # raise.
            return