import torch
from tqdm import tqdm
import utils
from PIL import Image
import gc
import numpy as np
from .attention import GatedSelfAttentionDense
from .models import torch_device
def encode(model_dict, image, generator):
    """
    `image` should be a PIL Image or a uint8 numpy array with values in [0, 255].
    """
    vae, dtype = model_dict.vae, model_dict.dtype

    if isinstance(image, Image.Image):
        w, h = image.size
        assert w % 8 == 0 and h % 8 == 0, f"h ({h}) and w ({w}) should each be a multiple of 8"
        # w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
        # image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[None, :]
        image = np.array(image)

    if isinstance(image, np.ndarray):
        assert image.dtype == np.uint8, f"Should have dtype uint8 (dtype: {image.dtype})"
        # Rescale to [-1, 1] and convert HWC -> NCHW
        image = image.astype(np.float32) / 255.0
        image = image[None, ...]
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)

    assert isinstance(image, torch.Tensor), f"type of image: {type(image)}"

    image = image.to(device=torch_device, dtype=dtype)
    latents = vae.encode(image).latent_dist.sample(generator)
    latents = vae.config.scaling_factor * latents

    return latents
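
# Usage sketch (not part of the original module; assumes `model_dict` comes
# from this repo's model loader and the input image is RGB with each side a
# multiple of 8):
#
#   generator = torch.Generator(device=torch_device).manual_seed(0)
#   init_image = Image.open("input.png").convert("RGB")  # e.g. 512x512
#   init_latents = encode(model_dict, init_image, generator)
#   # For a 512x512 input, `init_latents` has shape (1, 4, 64, 64).
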
def decode(vae, latents):
    # Undo the scaling applied in `encode`, then decode the latents with the
    # VAE (`vae.config.scaling_factor` is 0.18215 for SD v1 VAEs)
    scaled_latents = 1 / vae.config.scaling_factor * latents
    with torch.no_grad():
        image = vae.decode(scaled_latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")

    return images
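
# Usage sketch (assumes latents produced by `encode` or by the samplers below):
#
#   images = decode(model_dict.vae, latents)  # uint8, shape (N, H*8, W*8, 3)
#   Image.fromarray(images[0]).save("out.png")
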
def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale=7.5, no_set_timesteps=False):
    vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
    # `text_embeddings` is the uncond and cond embeddings concatenated along the batch dimension
    text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings

    if not no_set_timesteps:
        scheduler.set_timesteps(num_inference_steps)

    for t in tqdm(scheduler.timesteps):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    images = decode(vae, latents)

    ret = [latents, images]
    return tuple(ret)
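
# Usage sketch (assumes `input_embeddings` is the
# (text_embeddings, uncond_embeddings, cond_embeddings) tuple produced by this
# repo's text-encoding utilities):
#
#   generator = torch.Generator(device=torch_device).manual_seed(0)
#   latents = torch.randn(
#       (1, model_dict.unet.config.in_channels, 64, 64),
#       generator=generator, device=torch_device, dtype=model_dict.dtype,
#   ) * model_dict.scheduler.init_noise_sigma
#   latents, images = generate(model_dict, latents, input_embeddings, num_inference_steps=50)
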
def gligen_enable_fuser(unet, enabled=True):
    for module in unet.modules():
        if isinstance(module, GatedSelfAttentionDense):
            module.enabled = enabled
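
# Usage sketch: disable all GLIGEN gated self-attention (fuser) layers to run
# the UNet as a vanilla SD model, then re-enable them for grounded sampling:
#
#   gligen_enable_fuser(model_dict.unet, enabled=False)
#   # ... ungrounded denoising ...
#   gligen_enable_fuser(model_dict.unet, enabled=True)
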
# The guidance step below re-enables grad explicitly with `torch.enable_grad()`,
# so the function itself is assumed to run under `no_grad`.
@torch.no_grad()
def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5,
                    frozen_steps=20, frozen_mask=None,
                    return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
                    offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
                    semantic_guidance=False, semantic_guidance_bboxes=None, semantic_guidance_object_positions=None, semantic_guidance_kwargs=None,
                    return_box_vis=False, show_progress=True, save_all_latents=False):
    """
    `bboxes` should be a flat list of boxes (one box per phrase; phrases may be duplicated), not a list of lists.
    """
    vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
    text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
    if latents.dim() == 5:
        # `latents` holds latents for all timesteps (the input-side `latents_all`,
        # distinct from the `latents_all` saved below); start from the first one
        latents_all_input = latents
        latents = latents[0]
    else:
        latents_all_input = None

    # Just in case we have in-place ops downstream
    latents = latents.clone()

    if save_all_latents:
        # offload to cpu to save space
        if offload_latents_to_cpu:
            latents_all = [latents.cpu()]
        else:
            latents_all = [latents]

    scheduler.set_timesteps(num_inference_steps)
    if frozen_mask is not None:
        frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)

    batch_size = 1

    # 5.1 Prepare GLIGEN variables
    assert len(phrases) == len(bboxes)
    # assert batch_size == 1

    max_objs = 30
    _boxes = bboxes
    n_objs = min(len(_boxes), max_objs)
    boxes = torch.zeros(max_objs, 4, device=torch_device, dtype=dtype)
    phrase_embeddings = torch.zeros(max_objs, 768, device=torch_device, dtype=dtype)
    masks = torch.zeros(max_objs, device=torch_device, dtype=dtype)

    if n_objs > 0:
        boxes[:n_objs] = torch.tensor(_boxes[:n_objs])
        tokenizer_inputs = tokenizer(phrases, padding=True, return_tensors="pt").to(torch_device)
        _phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output
        phrase_embeddings[:n_objs] = _phrase_embeddings[:n_objs]
        masks[:n_objs] = 1

    # Classifier-free guidance: duplicate for the uncond/cond halves and zero
    # the masks for the uncond half so grounding only affects the cond branch
    repeat_batch = batch_size * num_images_per_prompt * 2
    boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
    phrase_embeddings = phrase_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
    masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
    masks[:repeat_batch // 2] = 0
    if semantic_guidance_bboxes and semantic_guidance:
        loss = torch.tensor(10000.)
        # TODO: we can also save necessary tokens only to save memory.
        # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
        guidance_cross_attention_kwargs = {
            'offload_cross_attn_to_cpu': False,
            'enable_flash_attn': False,
            'gligen': {
                'boxes': boxes[:repeat_batch // 2],
                'positive_embeddings': phrase_embeddings[:repeat_batch // 2],
                'masks': masks[:repeat_batch // 2],
                'fuser_attn_kwargs': {
                    'enable_flash_attn': False,
                }
            }
        }

    if return_saved_cross_attn:
        saved_attns = []

    main_cross_attention_kwargs = {
        'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
        'return_cond_ca_only': return_cond_ca_only,
        'return_token_ca_only': return_token_ca_only,
        'save_keys': saved_cross_attn_keys,
        'gligen': {
            'boxes': boxes,
            'positive_embeddings': phrase_embeddings,
            'masks': masks
        }
    }
    timesteps = scheduler.timesteps
    num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
    gligen_enable_fuser(unet, True)

    for index, t in enumerate(tqdm(timesteps, disable=not show_progress)):
        # Scheduled sampling: ground only the first `gligen_scheduled_sampling_beta`
        # fraction of steps, then fall back to the unmodified UNet
        if index == num_grounding_steps:
            gligen_enable_fuser(unet, False)

        if semantic_guidance_bboxes and semantic_guidance:
            # `latent_backward_guidance` is expected to be defined or imported
            # elsewhere in this module
            with torch.enable_grad():
                latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, semantic_guidance_bboxes, semantic_guidance_object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)

        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
        # only allocate the temporary dict when the attention maps will be consumed
        if return_saved_cross_attn:
            main_cross_attention_kwargs['save_attn_to_dict'] = {}

        # predict the noise residual
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
                          cross_attention_kwargs=main_cross_attention_kwargs).sample

        if return_saved_cross_attn:
            saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
            del main_cross_attention_kwargs['save_attn_to_dict']
        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

        if frozen_mask is not None and index < frozen_steps:
            latents = latents_all_input[index + 1] * frozen_mask + latents * (1. - frozen_mask)

        if save_all_latents:
            if offload_latents_to_cpu:
                latents_all.append(latents.cpu())
            else:
                latents_all.append(latents)

    # Turn off fuser for typical SD
    gligen_enable_fuser(unet, False)
    images = decode(vae, latents)

    ret = [latents, images]

    if return_saved_cross_attn:
        ret.append(saved_attns)
    if return_box_vis:
        pil_images = [utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images]
        ret.append(pil_images)
    if save_all_latents:
        latents_all = torch.stack(latents_all, dim=0)
        ret.append(latents_all)

    return tuple(ret)
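
# Usage sketch (boxes are normalized (x_min, y_min, x_max, y_max) coordinates
# in [0, 1], as GLIGEN expects; `latents` and `input_embeddings` as in
# `generate` above):
#
#   bboxes = [[0.1, 0.4, 0.45, 0.9], [0.55, 0.4, 0.9, 0.9]]
#   phrases = ["a red apple", "a green apple"]
#   latents, images = generate_gligen(
#       model_dict, latents, input_embeddings, num_inference_steps=50,
#       bboxes=bboxes, phrases=phrases,
#       gligen_scheduled_sampling_beta=0.3, guidance_scale=7.5,
#   )
#   Image.fromarray(images[0]).save("grounded.png")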