import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
from typing import Optional, List, Union
import gc


class OmniGenImageGenerator:
    """Image generation for monsters; designed around OmniGen2 but currently backed by Stable Diffusion v1.5."""

    def __init__(self, device: str = "cuda"):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        self.model_id = "runwayml/stable-diffusion-v1-5"  # known-working Stable Diffusion checkpoint

        # Generation parameters
        self.default_width = 512
        self.default_height = 512
        self.num_inference_steps = 30
        self.guidance_scale = 7.5

        # Memory optimization flags
        self.enable_attention_slicing = True
        self.enable_vae_slicing = True
        self.enable_cpu_offload = self.device == "cuda"
    def load_model(self):
        """Lazily load the image generation model."""
        if self.pipeline is None:
            try:
                # Determine torch dtype
                torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

                # Load pipeline with optimizations
                self.pipeline = StableDiffusionPipeline.from_pretrained(
                    self.model_id,
                    torch_dtype=torch_dtype,
                    use_safetensors=True,
                    variant="fp16" if self.device == "cuda" else None,
                )

                # Apply optimizations and device placement
                if self.device == "cuda":
                    if self.enable_cpu_offload:
                        self.pipeline.enable_sequential_cpu_offload()
                    else:
                        # Safely move the pipeline to CUDA
                        try:
                            self.pipeline = self.pipeline.to(self.device)
                        except RuntimeError as e:
                            if "meta tensor" in str(e):
                                # Fall back to CPU when meta tensors block the device move
                                print(f"Meta tensor issue detected, using CPU fallback: {e}")
                                self.device = "cpu"
                                self.pipeline = self.pipeline.to("cpu")
                            else:
                                raise

                    if self.enable_attention_slicing and hasattr(self.pipeline, "enable_attention_slicing"):
                        self.pipeline.enable_attention_slicing(1)
                    if self.enable_vae_slicing and hasattr(self.pipeline, "enable_vae_slicing"):
                        self.pipeline.enable_vae_slicing()
                else:
                    self.pipeline = self.pipeline.to(self.device)

                # Compile the UNet for faster inference (optional, PyTorch 2.x)
                if hasattr(torch, "compile") and self.device == "cuda":
                    try:
                        self.pipeline.unet = torch.compile(self.pipeline.unet, mode="reduce-overhead")
                    except Exception:
                        pass  # Compilation is optional

            except Exception as e:
                print(f"Failed to load image generation model: {e}")
                # Retry with the fallback Stable Diffusion loader
                try:
                    self.model_id = "runwayml/stable-diffusion-v1-5"
                    self._load_fallback_model()
                except Exception:
                    raise
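
    # Loading is deferred until the first generate() call; a caller that wants
    # to pay the download/initialization cost up front can invoke load_model()
    # explicitly at startup, e.g.:
    #
    #   gen = OmniGenImageGenerator()
    #   gen.load_model()  # warm start before serving requests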
    def _load_fallback_model(self):
        """Load the fallback Stable Diffusion model."""
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.pipeline = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch_dtype,
            use_safetensors=True,
        )
        if self.device == "cuda" and self.enable_cpu_offload:
            self.pipeline.enable_sequential_cpu_offload()
        else:
            self.pipeline = self.pipeline.to(self.device)
    def _truncate_prompt(self, prompt: str, max_words: int = 75) -> str:
        """Truncate a prompt to fit CLIP's token limit, using word count as a rough proxy."""
        words = prompt.split()
        if len(words) <= max_words:
            return prompt
        truncated = " ".join(words[:max_words])
        print(f"Warning: prompt truncated from {len(words)} to {max_words} words")
        return truncated
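
    # Word count only approximates CLIP's 77-token limit; a token-accurate
    # variant (a sketch, assuming self.pipeline and its CLIP tokenizer are
    # already loaded) could decode a truncated encoding instead:
    #
    #   ids = self.pipeline.tokenizer(
    #       prompt, truncation=True, max_length=77
    #   ).input_ids
    #   prompt = self.pipeline.tokenizer.decode(ids, skip_special_tokens=True)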
    def generate(self,
                 prompt: str,
                 reference_images: Optional[List[Union[str, Image.Image]]] = None,
                 negative_prompt: Optional[str] = None,
                 width: Optional[int] = None,
                 height: Optional[int] = None,
                 num_images: int = 1,
                 seed: Optional[int] = None) -> Union[Image.Image, List[Image.Image]]:
        """Generate monster image(s) from a prompt."""
        try:
            # Load model if needed
            self.load_model()

            # Truncate prompts to avoid CLIP token limit issues
            prompt = self._truncate_prompt(prompt)
            if negative_prompt:
                negative_prompt = self._truncate_prompt(negative_prompt)

            # Set dimensions, rounded down to multiples of 8 as the VAE requires
            width = (width or self.default_width) // 8 * 8
            height = (height or self.default_height) // 8 * 8

            # Enhance prompt for monster generation
            enhanced_prompt = self._enhance_prompt(prompt)

            # Default negative prompt for quality
            if negative_prompt is None:
                negative_prompt = (
                    "low quality, blurry, distorted, disfigured, "
                    "bad anatomy, wrong proportions, ugly, duplicate, "
                    "morbid, mutilated, extra limbs, malformed"
                )

            # Set seed for reproducibility
            generator = None
            if seed is not None:
                generator = torch.Generator(device=self.device).manual_seed(seed)

            # Generate images
            with torch.no_grad():
                if callable(self.pipeline):
                    # Standard diffusion pipeline
                    images = self.pipeline(
                        prompt=enhanced_prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=self.num_inference_steps,
                        guidance_scale=self.guidance_scale,
                        num_images_per_prompt=num_images,
                        generator=generator,
                    ).images
                else:
                    # OmniGen-specific generation (if the API differs)
                    images = self._omnigen_generate(
                        enhanced_prompt,
                        reference_images,
                        width,
                        height,
                        num_images,
                    )

            # Clean up memory
            if self.device == "cuda":
                torch.cuda.empty_cache()

            # Return a single image or a list
            if num_images == 1:
                return images[0]
            return images

        except Exception as e:
            print(f"Image generation error: {e}")
            # Return a procedural fallback image; width/height may still be
            # None if the failure happened before they were set above
            return self._generate_fallback_image(
                width or self.default_width,
                height or self.default_height,
            )
    def _enhance_prompt(self, base_prompt: str) -> str:
        """Enhance a prompt for better monster generation."""
        enhancements = [
            "digital art",
            "creature design",
            "game character",
            "detailed",
            "vibrant colors",
            "fantasy creature",
            "high quality",
            "professional artwork",
        ]
        # Append the style keywords to the base prompt
        return f"{base_prompt}, {', '.join(enhancements)}"
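
    # For example, _enhance_prompt("a fire dragon") returns:
    #   "a fire dragon, digital art, creature design, game character,
    #    detailed, vibrant colors, fantasy creature, high quality,
    #    professional artwork"
    # Note that enhancement runs after truncation, so the final prompt can
    # still exceed CLIP's 77-token limit; the pipeline's tokenizer clips any
    # overflow (with a warning) rather than failing.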
    def _omnigen_generate(self, prompt: str, reference_images: Optional[List],
                          width: int, height: int, num_images: int) -> List[Image.Image]:
        """OmniGen-specific generation with multimodal inputs."""
        # Placeholder for OmniGen's multimodal API; for now, fall back to
        # standard text-to-image generation (reference_images are ignored)
        return self.pipeline(
            prompt=prompt,
            width=width,
            height=height,
            num_images_per_prompt=num_images,
        ).images
    def _generate_fallback_image(self, width: int, height: int) -> Image.Image:
        """Generate a simple procedural fallback monster image."""
        img_array = np.zeros((height, width, 3), dtype=np.uint8)

        # Circular body centered in the frame
        center_x, center_y = width // 2, height // 2
        radius = min(width, height) // 3
        y, x = np.ogrid[:height, :width]
        mask = (x - center_x) ** 2 + (y - center_y) ** 2 <= radius ** 2

        # Random monster color
        color = np.random.randint(50, 200, size=3)
        img_array[mask] = color

        # Two white eyes above the body's midline
        eye_y = center_y - radius // 3
        eye_left_x = center_x - radius // 3
        eye_right_x = center_x + radius // 3
        eye_radius = radius // 8

        left_eye = (x - eye_left_x) ** 2 + (y - eye_y) ** 2 <= eye_radius ** 2
        img_array[left_eye] = [255, 255, 255]
        right_eye = (x - eye_right_x) ** 2 + (y - eye_y) ** 2 <= eye_radius ** 2
        img_array[right_eye] = [255, 255, 255]

        # Convert to PIL Image
        return Image.fromarray(img_array)
    def edit_image(self,
                   image: Union[str, Image.Image],
                   prompt: str,
                   mask: Optional[Union[str, Image.Image]] = None) -> Image.Image:
        """Edit an existing image (for future monster customization)."""
        # Inpainting/editing with the optional mask is not implemented yet
        raise NotImplementedError("Image editing not yet implemented")
    def to(self, device: str):
        """Move the pipeline to the specified device."""
        self.device = device
        if self.pipeline is not None:
            if device == "cuda" and self.enable_cpu_offload:
                self.pipeline.enable_sequential_cpu_offload()
            else:
                self.pipeline = self.pipeline.to(device)
    def __del__(self):
        """Free the pipeline and GPU memory when the object is destroyed."""
        # getattr guards against partially constructed instances
        if getattr(self, "pipeline", None) is not None:
            del self.pipeline
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
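

# Example usage (a minimal sketch: the first generate() call downloads the
# checkpoint from the Hugging Face Hub, the CUDA path assumes a GPU with
# enough VRAM, and the prompt and output filename are illustrative)
if __name__ == "__main__":
    gen = OmniGenImageGenerator()
    image = gen.generate(
        "a friendly slime monster with big round eyes",
        seed=42,
    )
    image.save("monster.png")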