from typing import List, Optional, Tuple
import math
import os
import random

import numpy as np
import pandas as pd
import torch
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

def extract_clip_features(clip, image, encoder):
    """
    Extract feature embeddings from an image using either a CLIP or a DINOv2 model.

    Args:
        clip (torch.nn.Module): The feature extraction model (either CLIP or DINOv2).
        image (torch.Tensor): Input image tensor normalized according to the model's requirements.
        encoder (str): Type of encoder to use ('dinov2-small' or 'clip').

    Returns:
        torch.Tensor: Feature embeddings extracted from the image.

    Note:
        - For DINOv2 models, uses the pooled output features.
        - For CLIP models, uses the image features from the vision encoder.
        - The input image should already be properly resized and normalized.
    """
    # Handle DINOv2 models: the backbone's pooled output serves as the embedding
    if 'dino' in encoder:
        features = clip(image)
        features = features.pooler_output
    # Handle CLIP models: use the vision encoder's projected image features
    else:
        features = clip.get_image_features(image)
    return features
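
# Illustrative usage sketch for extract_clip_features (not part of the pipeline
# in this file). The checkpoint names, the AutoImageProcessor/CLIPProcessor
# preprocessing, and the dummy image are assumptions made for demonstration only.
def _example_extract_clip_features():
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModel, CLIPModel, CLIPProcessor

    image = Image.new("RGB", (224, 224))  # dummy stand-in for a real image

    with torch.no_grad():
        # CLIP branch: encoder='clip' -> get_image_features on preprocessed pixels
        clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        pixels = clip_proc(images=image, return_tensors="pt")["pixel_values"]
        clip_feats = extract_clip_features(clip_model, pixels, encoder="clip")

        # DINOv2 branch: encoder='dinov2-small' -> pooler_output of the backbone
        dino_model = AutoModel.from_pretrained("facebook/dinov2-small")
        dino_proc = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
        pixels = dino_proc(images=image, return_tensors="pt")["pixel_values"]
        dino_feats = extract_clip_features(dino_model, pixels, encoder="dinov2-small")

    return clip_feats, dino_feats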

def compute_clip_pca(
    diverse_prompts: List[str],
    pipe,
    clip_model,
    clip_processor,
    device,
    guidance_scale,
    params,
    total_samples=5000,
    num_pca_components=100,
    batch_size=10,
) -> Tuple[torch.Tensor, List[str], List[str]]:
    """
    Generate images from the given prompts, extract their CLIP features, and
    compute the principal components of those features.

    Args:
        diverse_prompts: List of prompts to generate images from.
        pipe: Diffusion pipeline used to generate the training images.
        clip_model: CLIP model used to embed the generated images.
        clip_processor: Processor that prepares the generated images for `clip_model`.
        device: Device on which the resulting tensors are placed.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        params: Dictionary of generation settings (save path, image size, denoising steps, ...).
        total_samples: Total number of images to generate.
        num_pca_components: Number of principal components to keep.
        batch_size: Number of images generated per sampled prompt.

    Returns:
        Tuple of (principal component directions, training prompts, saved image paths).
    """
    # Calculate how many batches are needed to reach total_samples images
    num_batches = math.ceil(total_samples / batch_size)
    # Randomly sample one prompt per batch (with replacement)
    sampled_prompts_clip = random.choices(diverse_prompts, k=num_batches)

    clip_features_path = f"{params['savepath_training_images']}/clip_principle_directions.pt"
    # Reuse cached principal directions and training metadata if they already exist
    if os.path.exists(clip_features_path):
        df = pd.read_csv(f"{params['savepath_training_images']}/training_data.csv")
        prompts_training = list(df.prompt)
        image_paths = list(df.image_path)
        return torch.load(clip_features_path).to(device), prompts_training, image_paths

    os.makedirs(params['savepath_training_images'], exist_ok=True)

    # Generate images and extract CLIP features batch by batch
    img_idx = 0
    clip_features = []
    image_paths = []
    prompts_training = []
    print('Calculating Semantic PCA')
    for prompt in tqdm(sampled_prompts_clip):
        # Pass max_sequence_length only to pipelines that expect it (signalled via params)
        if 'max_sequence_length' in params:
            images = pipe(
                prompt,
                num_images_per_prompt=batch_size,
                num_inference_steps=params['max_denoising_steps'],
                guidance_scale=guidance_scale,
                max_sequence_length=params['max_sequence_length'],
                height=params['height'],
                width=params['width'],
            ).images
        else:
            images = pipe(
                prompt,
                num_images_per_prompt=batch_size,
                num_inference_steps=params['max_denoising_steps'],
                guidance_scale=guidance_scale,
                height=params['height'],
                width=params['width'],
            ).images
        # Preprocess the generated PIL images for the CLIP vision encoder
        clip_inputs = clip_processor(images=images, return_tensors="pt", padding=True)
        pixel_values = clip_inputs['pixel_values'].to(device)

        # Get image embeddings
        with torch.no_grad():
            image_features = clip_model.get_image_features(pixel_values)

        # L2-normalize the embeddings before collecting them for PCA
        clip_feats = image_features / image_features.norm(dim=1, keepdim=True)
        clip_features.append(clip_feats)

        # Save each generated image and record its prompt and path
        for im in images:
            image_path = f"{params['savepath_training_images']}/{img_idx}.png"
            im.save(image_path)
            image_paths.append(image_path)
            prompts_training.append(prompt)
            img_idx += 1

    clip_features = torch.cat(clip_features)

    # Compute the principal components of the collected CLIP embeddings
    pca = PCA(n_components=num_pca_components)
    clip_embeds_np = clip_features.float().cpu().numpy()
    pca.fit(clip_embeds_np)
    clip_principles = torch.from_numpy(pca.components_).to(device, dtype=pipe.vae.dtype)

    # Save results so subsequent runs can reuse them
    torch.save(clip_principles, clip_features_path)
    pd.DataFrame({
        'prompt': prompts_training,
        'image_path': image_paths
    }).to_csv(f"{params['savepath_training_images']}/training_data.csv", index=False)

    return clip_principles, prompts_training, image_paths
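
# Illustrative call sketch for compute_clip_pca. The Stable Diffusion checkpoint,
# the CLIP checkpoint, the example prompts, and the `params` values below are
# assumptions for demonstration; the real training script supplies its own.
def _example_compute_clip_pca():
    from diffusers import StableDiffusionPipeline
    from transformers import CLIPModel, CLIPProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    params = {
        'savepath_training_images': './training_images',  # where images and the cache are written
        'max_denoising_steps': 30,
        'height': 512,
        'width': 512,
    }
    prompts = ["a photo of a dog", "a watercolor landscape", "a futuristic city at night"]

    clip_principles, prompts_training, image_paths = compute_clip_pca(
        diverse_prompts=prompts,
        pipe=pipe,
        clip_model=clip_model,
        clip_processor=clip_processor,
        device=device,
        guidance_scale=7.5,
        params=params,
        total_samples=100,      # small run for the sketch; the default is 5000
        num_pca_components=50,  # must not exceed the number of generated samples
        batch_size=5,
    )
    # Each row of clip_principles is one principal direction in CLIP embedding space.
    return clip_principles, prompts_training, image_paths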