Spaces:
Runtime error
Runtime error
| from typing import Optional, Union | |
| import torch | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from diffusers import UNet2DConditionModel, SchedulerMixin, FluxImg2ImgPipeline | |
| from diffusers.image_processor import VaeImageProcessor | |
| # from model_util import SDXL_TEXT_ENCODER_TYPE | |
| from diffusers.utils.torch_utils import randn_tensor | |
| from tqdm import tqdm | |
# Latent channel count of the Stable Diffusion UNet; fixed at 4 (SDXL uses the same value).
UNET_IN_CHANNELS = 4
# Spatial downsampling of the VAE: 2 ** (len(vae.config.block_out_channels) - 1) == 8.
VAE_SCALE_FACTOR = 8
# SDXL sizes used by get_add_time_ids to validate the added time embedding width.
UNET_ATTENTION_TIME_EMBED_DIM = 256  # width contributed per time id (SDXL)
TEXT_ENCODER_2_PROJECTION_DIM = 1280  # projection dim of text_encoder_2
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816  # expected total: 256 * 6 + 1280
def get_random_noise(
    batch_size: int, height: int, width: int, generator: torch.Generator = None
) -> torch.Tensor:
    """Sample standard-normal latent noise on the CPU.

    Args:
        batch_size: Number of noise tensors to draw.
        height: Image height in pixels; divided by ``VAE_SCALE_FACTOR``.
        width: Image width in pixels; divided by ``VAE_SCALE_FACTOR``.
        generator: Optional ``torch.Generator`` for reproducible sampling.

    Returns:
        A tensor of shape ``(batch_size, UNET_IN_CHANNELS, height // VAE_SCALE_FACTOR,
        width // VAE_SCALE_FACTOR)`` allocated on the CPU.
    """
    latent_h = height // VAE_SCALE_FACTOR
    latent_w = width // VAE_SCALE_FACTOR
    # (batch, channels, height, width) order: height precedes width.
    noise_shape = (batch_size, UNET_IN_CHANNELS, latent_h, latent_w)
    return torch.randn(noise_shape, generator=generator, device="cpu")
# https://www.crosslabs.org/blog/diffusion-with-offset-noise
def apply_noise_offset(
    latents: torch.FloatTensor,
    noise_offset: float,
    generator: Optional[torch.Generator] = None,
):
    """Add "offset noise": one random constant per (sample, channel) pair.

    Args:
        latents: 4-D latents ``(batch, channel, H, W)``; the offset is broadcast
            over the last two dimensions.
        noise_offset: Multiplier applied to the standard-normal offsets.
        generator: Optional generator for reproducible offsets. It must live on
            ``latents.device``.

    Returns:
        A new tensor with the same shape and dtype as ``latents``.
    """
    offset = torch.randn(
        (latents.shape[0], latents.shape[1], 1, 1),
        device=latents.device,
        generator=generator,
    )
    # The noise is sampled in the default dtype and then cast. Without the cast,
    # half-precision latents would be silently promoted to float32 by the addition.
    return latents + noise_offset * offset.to(latents.dtype)
def get_initial_latents(
    scheduler: SchedulerMixin,
    n_imgs: int,
    height: int,
    width: int,
    n_prompts: int,
    generator=None,
) -> torch.Tensor:
    """Create starting latents scaled by the scheduler's ``init_noise_sigma``.

    ``n_imgs`` noise tensors are drawn once and tiled ``n_prompts`` times along
    the batch dimension, so every prompt starts from the same noise.

    Returns:
        A tensor with ``n_imgs * n_prompts`` entries along dim 0.
    """
    base_noise = get_random_noise(n_imgs, height, width, generator=generator)
    tiled_noise = base_noise.repeat(n_prompts, 1, 1, 1)
    return tiled_noise * scheduler.init_noise_sigma
def text_tokenize(
    tokenizer: CLIPTokenizer,  # SD uses one tokenizer; SDXL calls this once per tokenizer
    prompts: list[str],
):
    """Tokenize ``prompts`` into fixed-length PyTorch input ids.

    Every prompt is padded to, and truncated at, ``tokenizer.model_max_length``.

    Returns:
        The ``input_ids`` of the tokenizer's encoding.
    """
    encoding = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return encoding.input_ids
def text_encode(text_encoder: CLIPTextModel, tokens):
    """Encode token ids and return the encoder's first output.

    The tokens are moved to ``text_encoder.device`` before the forward pass.
    """
    device_tokens = tokens.to(text_encoder.device)
    encoder_outputs = text_encoder(device_tokens)
    return encoder_outputs[0]
def encode_prompts(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    prompts: list[str],
):
    """Tokenize ``prompts`` and run them through ``text_encoder``.

    Args:
        tokenizer: Tokenizer passed to ``text_tokenize``.
        text_encoder: Text encoder model (previously mis-annotated as a tokenizer).
        prompts: Prompts to encode.

    Returns:
        The first output of ``text_encoder`` for the tokenized prompts.
    """
    text_tokens = text_tokenize(tokenizer, prompts)
    return text_encode(text_encoder, text_tokens)
# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
def text_encode_xl(
    text_encoder,
    tokens: torch.LongTensor,
    num_images_per_prompt: int = 1,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Encode token ids for SDXL, following the upstream SDXL pipeline.

    Args:
        text_encoder: Encoder called with ``output_hidden_states=True``.
        tokens: Integer token ids (e.g. from ``text_tokenize``); moved to the
            encoder's device.
        num_images_per_prompt: How many times each prompt's embedding is repeated.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``. ``prompt_embeds`` is the
        penultimate hidden state, repeated so dim 0 has
        ``batch * num_images_per_prompt`` rows. ``pooled_prompt_embeds`` is the
        encoder's first output and is NOT repeated here (``encode_prompts_xl``
        does that).
    """
    encoder_output = text_encoder(tokens.to(text_encoder.device), output_hidden_states=True)
    pooled_prompt_embeds = encoder_output[0]
    prompt_embeds = encoder_output.hidden_states[-2]  # always the penultimate layer
    bs_embed, seq_len, _ = prompt_embeds.shape
    # Repeat along the sequence dim, then fold into the batch dim so each prompt's
    # copies are adjacent.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    return prompt_embeds, pooled_prompt_embeds
def encode_prompts_xl(
    tokenizers,
    text_encoders,
    prompts: list[str],
    num_images_per_prompt: int = 1,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Encode ``prompts`` with every (tokenizer, text encoder) pair for SDXL.

    The penultimate-layer embeddings of all encoders are concatenated along the
    last dimension. The pooled embedding is taken from the LAST encoder in
    ``text_encoders`` (text_encoder_2 in SDXL) and repeated
    ``num_images_per_prompt`` times per prompt.

    Returns:
        ``(concatenated_embeds, pooled_embeds)``.
    """
    per_encoder_embeds = []
    last_pooled = None
    for tok, enc in zip(tokenizers, text_encoders):
        input_ids = text_tokenize(tok, prompts)
        embeds, last_pooled = text_encode_xl(enc, input_ids, num_images_per_prompt)
        per_encoder_embeds.append(embeds)

    n_rows = last_pooled.shape[0]
    last_pooled = last_pooled.repeat(1, num_images_per_prompt).view(
        n_rows * num_images_per_prompt, -1
    )
    return torch.concat(per_encoder_embeds, dim=-1), last_pooled
def concat_embeddings(
    unconditional: torch.FloatTensor,
    conditional: torch.FloatTensor,
    n_imgs: int,
):
    """Stack unconditional rows before conditional rows, repeating each row ``n_imgs`` times.

    Each row is repeated consecutively (``repeat_interleave``), so the result is
    ``u0, u0, …, c0, c0, …`` along dim 0.
    """
    stacked = torch.cat([unconditional, conditional])
    return stacked.repeat_interleave(n_imgs, dim=0)
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
def predict_noise(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    timestep: int,  # current timestep
    latents: torch.FloatTensor,
    text_embeddings: torch.FloatTensor,  # unconditional and conditional embeds concatenated
    guidance_scale=7.5,
) -> torch.FloatTensor:
    """Predict noise for one timestep, optionally applying classifier-free guidance.

    Behaviour by ``guidance_scale``:
      * ``0``: the latents are passed once; the raw UNet output is returned.
      * ``1``: the latents are duplicated; the raw, doubled-batch output
        (unconditional half first) is returned for the caller to split.
      * anything else: the latents are duplicated, and the output is
        ``uncond + guidance_scale * (cond - uncond)``.
    """
    batched_cfg = guidance_scale != 0
    # Duplicating the latents lets the uncond and cond passes share one forward call.
    model_input = torch.cat([latents] * 2) if batched_cfg else latents
    model_input = scheduler.scale_model_input(model_input, timestep)

    prediction = unet(
        model_input,
        timestep,
        encoder_hidden_states=text_embeddings,
    ).sample

    if not batched_cfg or guidance_scale == 1:
        return prediction
    uncond_pred, cond_pred = prediction.chunk(2)
    return uncond_pred + guidance_scale * (cond_pred - uncond_pred)
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
def diffusion(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    latents: torch.FloatTensor,  # starting latents (pure noise)
    text_embeddings: torch.FloatTensor,
    total_timesteps: int = 1000,
    start_timesteps=0,
    guidance_scale=1,
    composition=False,
    **kwargs,
):
    """Run the denoising loop over ``scheduler.timesteps[start_timesteps:total_timesteps]``.

    Args:
        unet: Noise-prediction model, called through ``predict_noise``.
        scheduler: Provides ``timesteps`` and ``step``.
        latents: Initial latents.
        text_embeddings: Embeddings passed to ``predict_noise``. With
            ``composition=True``, they are fed one row at a time.
        total_timesteps: End index (exclusive) into ``scheduler.timesteps``.
        start_timesteps: Start index into ``scheduler.timesteps``.
        guidance_scale: Classifier-free guidance scale. With ``1``, only the
            conditional half of the prediction is kept.
        composition: Predict each row of ``text_embeddings`` separately and sum
            the scaled conditional predictions.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The latents after the final scheduler step.
    """
    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        if not composition:
            noise_pred = predict_noise(
                unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale
            )
            if guidance_scale == 1:
                # predict_noise returns the doubled (uncond, cond) batch at scale 1.
                _, noise_pred = noise_pred.chunk(2)
        else:
            for idx in range(text_embeddings.shape[0]):
                pred = predict_noise(
                    unet, scheduler, timestep, latents, text_embeddings[idx:idx + 1], guidance_scale=1
                )
                # Split THIS row's prediction. The old code chunked `noise_pred`, which
                # is unbound on the very first step (UnboundLocalError) and is the
                # previous step's result afterwards.
                uncond, pred = pred.chunk(2)
                if idx == 0:
                    noise_pred = guidance_scale * pred
                else:
                    noise_pred += guidance_scale * pred
            # NOTE(review): the unconditional term (from the last row) is added once,
            # after summing the rows; confirm this matches the intended composition.
            noise_pred += uncond
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample
    return latents
def rescale_noise_cfg(
    noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
):
    """Blend a guided prediction with a copy rescaled to the text prediction's std.

    Based on Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps are
    Flawed" (https://arxiv.org/pdf/2305.08891.pdf). Standard deviations are taken
    per sample over every non-batch dimension.

    Args:
        noise_cfg: Classifier-free-guided noise prediction.
        noise_pred_text: Text-conditioned noise prediction.
        guidance_rescale: Mix weight. ``0`` keeps ``noise_cfg``; ``1`` returns the
            fully rescaled prediction (up to floating-point rounding).
    """
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(dim=text_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=cfg_dims, keepdim=True)

    # Matching the text branch's std counteracts over-exposure from strong guidance.
    rescaled = noise_cfg * (std_text / std_cfg)
    # Mixing back the raw guided result avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def predict_noise_xl(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    timestep: int,  # current timestep
    latents: torch.FloatTensor,
    text_embeddings: torch.FloatTensor,  # unconditional and conditional embeds concatenated
    add_text_embeddings: torch.FloatTensor,  # pooled embeddings
    add_time_ids: torch.FloatTensor,
    guidance_scale=7.5,
    guidance_rescale=0.7,
) -> torch.FloatTensor:
    """SDXL noise prediction with optional classifier-free guidance.

    Behaves like ``predict_noise``, but also passes ``added_cond_kwargs``
    (``text_embeds`` and ``time_ids``) to the UNet:
      * ``guidance_scale == 0``: single pass, raw output.
      * ``guidance_scale == 1``: doubled batch, raw (uncond, cond) output.
      * otherwise: ``uncond + guidance_scale * (cond - uncond)``.

    Note:
        ``guidance_rescale`` is accepted but currently unused; ``rescale_noise_cfg``
        is not applied here.
    """
    batched_cfg = guidance_scale != 0
    model_input = torch.cat([latents] * 2) if batched_cfg else latents
    model_input = scheduler.scale_model_input(model_input, timestep)

    conditioning = {"text_embeds": add_text_embeddings, "time_ids": add_time_ids}
    noise_pred = unet(
        model_input,
        timestep,
        encoder_hidden_states=text_embeddings,
        added_cond_kwargs=conditioning,
    ).sample

    if batched_cfg and guidance_scale != 1:
        uncond_pred, cond_pred = noise_pred.chunk(2)
        noise_pred = uncond_pred + guidance_scale * (cond_pred - uncond_pred)
    return noise_pred
def diffusion_xl(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    latents: torch.FloatTensor,  # starting latents (pure noise)
    text_embeddings: torch.FloatTensor,
    add_text_embeddings: torch.FloatTensor,  # pooled embeddings
    add_time_ids: torch.FloatTensor,
    guidance_scale: float = 1.0,
    total_timesteps: int = 1000,
    start_timesteps=0,
    composition=False,
):
    """SDXL denoising loop over ``scheduler.timesteps[start_timesteps:total_timesteps]``.

    Args:
        unet: Noise-prediction model, called through ``predict_noise_xl``.
        scheduler: Provides ``timesteps`` and ``step``.
        latents: Initial latents.
        text_embeddings: Embeddings passed as ``encoder_hidden_states`` (a single
            tensor; previously mis-annotated as a tuple).
        add_text_embeddings: Pooled embeddings for ``added_cond_kwargs``.
        add_time_ids: Time ids for ``added_cond_kwargs`` (see ``get_add_time_ids``).
        guidance_scale: Classifier-free guidance scale. With ``1``, only the
            conditional half of the prediction is kept.
        total_timesteps: End index (exclusive) into ``scheduler.timesteps``.
        start_timesteps: Start index into ``scheduler.timesteps``.
        composition: Not supported for SDXL.

    Returns:
        The latents after the final scheduler step.

    Raises:
        NotImplementedError: If ``composition`` is True.
    """
    if composition:
        # There is no composition branch for SDXL. Previously the loop skipped
        # prediction and then used an unassigned `noise_pred`.
        raise NotImplementedError("composition is not supported by diffusion_xl")

    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        noise_pred = predict_noise_xl(
            unet,
            scheduler,
            timestep,
            latents,
            text_embeddings,
            add_text_embeddings,
            add_time_ids,
            guidance_scale=guidance_scale,
            guidance_rescale=0.7,
        )
        if guidance_scale == 1:
            # predict_noise_xl returns the doubled (uncond, cond) batch at scale 1.
            _, noise_pred = noise_pred.chunk(2)
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample
    return latents
# for XL
def get_add_time_ids(
    height: int,
    width: int,
    dynamic_crops: bool = False,
    dtype: torch.dtype = torch.float32,
):
    """Build SDXL time ids ``[orig_h, orig_w, crop_top, crop_left, target_h, target_w]``.

    Args:
        height: Target height.
        width: Target width.
        dynamic_crops: If True, the original size is the target scaled by a random
            factor in ``[1, 3)``, and the crop offset is random within it.
            Otherwise the original size equals the target and the crop is ``(0, 0)``.
        dtype: dtype of the returned tensor.

    Returns:
        A tensor of shape ``(1, 6)``.

    Raises:
        ValueError: If the module constants imply an added-embedding width other
            than ``UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM``.
    """
    target_size = (height, width)
    if dynamic_crops:
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        # torch.randint's `high` is exclusive. The +1 keeps the range non-empty when
        # the scaled size rounds down to the target size (randint(0, 0) raises).
        crops_coords_top_left = (
            torch.randint(0, original_size[0] - height + 1, (1,)).item(),
            torch.randint(0, original_size[1] - width + 1, (1,)).item(),
        )
    else:
        original_size = target_size
        crops_coords_top_left = (0, 0)

    # expected to have 6 entries
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    # expected to be 2816 (256 * 6 + 1280)
    passed_add_embed_dim = (
        UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids)
        + TEXT_ENCODER_2_PROJECTION_DIM
    )
    if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
        raise ValueError(
            f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )
    return torch.tensor([add_time_ids], dtype=dtype)
def get_optimizer(name: str):
    """Resolve an optimizer class from a case-insensitive name.

    Supported names:
      * ``"adam"``, ``"adamw"`` (torch),
      * ``"lion"`` (lion_pytorch), ``"prodigy"`` (prodigyopt),
      * ``"dadaptadam"``, ``"dadaptlion"`` (dadaptation),
      * ``"adam8bit"``, ``"lion8bit"`` (bitsandbytes; not verified).

    Third-party packages are imported only when one of their optimizers is requested.

    Raises:
        ValueError: For an unrecognised name.
    """
    key = name.lower()

    if key.startswith("dadapt"):
        import dadaptation

        if key == "dadaptadam":
            return dadaptation.DAdaptAdam
        if key == "dadaptlion":
            return dadaptation.DAdaptLion
        raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")

    if key.endswith("8bit"):  # not verified
        import bitsandbytes as bnb

        if key == "adam8bit":
            return bnb.optim.Adam8bit
        if key == "lion8bit":
            return bnb.optim.Lion8bit
        raise ValueError("8bit optimizer must be adam8bit or lion8bit")

    if key == "adam":
        return torch.optim.Adam
    if key == "adamw":
        return torch.optim.AdamW
    if key == "lion":
        from lion_pytorch import Lion

        return Lion
    if key == "prodigy":
        import prodigyopt

        return prodigyopt.Prodigy
    raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
def get_noisy_image(
    image,
    vae,
    unet,
    scheduler,
    timesteps_to = 1000,
    generator=None,
    **kwargs,
):
    """Encode ``image`` with the VAE and noise it to ``scheduler.timesteps[timesteps_to]``.

    Args:
        image: Image accepted by ``VaeImageProcessor.preprocess``.
        vae: VAE whose ``encode`` output exposes either ``latent_dist`` or ``latents``.
        unet: Unused; kept for call-site compatibility.
        scheduler: Provides ``timesteps`` and ``add_noise``.
        timesteps_to: INDEX into ``scheduler.timesteps`` (not a raw timestep value).
        generator: Optional generator, used both for latent sampling and for the noise.
        **kwargs: Accepted and ignored.

    Returns:
        ``(noisy_latents, noise)``.
    """
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
    device = vae.device
    image = image_processor.preprocess(image).to(device).to(vae.dtype)

    # Some VAEs (e.g. diffusers' AutoencoderKL) return a `latent_dist` rather than
    # `latents`. Reading `.latents` directly raised AttributeError for them.
    # retrieve_latents handles both forms.
    init_latents = retrieve_latents(vae.encode(image), generator=generator)
    init_latents = vae.config.scaling_factor * init_latents

    noise = randn_tensor(init_latents.shape, generator=generator, device=device)
    # Slice (rather than index) so `timestep` stays a 1-element tensor.
    timestep = scheduler.timesteps[timesteps_to:timesteps_to + 1]
    init_latents = scheduler.add_noise(init_latents, noise, timestep)
    return init_latents, noise
def get_lr_scheduler(
    name: Optional[str],
    optimizer: torch.optim.Optimizer,
    max_iterations: Optional[int],
    lr_min: Optional[float],
    **kwargs,
):
    """Build a ``torch.optim.lr_scheduler`` instance by name.

    Args:
        name: One of ``"cosine"``, ``"cosine_with_restarts"``, ``"step"``,
            ``"constant"`` or ``"linear"``.
        optimizer: Optimizer whose learning rate is scheduled.
        max_iterations: Total training iterations. The periods and step sizes are
            derived from it, so it must be an int for every scheduler except
            ``"constant"``.
        lr_min: Minimum learning rate (``eta_min``) for the cosine schedulers.
        **kwargs: Forwarded to the scheduler constructor.

    Returns:
        The constructed scheduler.

    Raises:
        ValueError: If ``name`` is not a supported scheduler.
    """
    if name == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
        )
    if name == "cosine_with_restarts":
        # T_0 must be a positive int; clamp so runs shorter than 10 iterations work.
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=max(1, max_iterations // 10), T_mult=2, eta_min=lr_min, **kwargs
        )
    if name == "step":
        # A step_size of 0 fails (modulo by zero) once the scheduler steps; clamp to >= 1.
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=max(1, max_iterations // 100), gamma=0.999, **kwargs
        )
    if name == "constant":
        return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
    if name == "linear":
        # LinearLR has no `factor` argument; the old `factor=0.5` always raised TypeError.
        # NOTE(review): start_factor=0.5 (ramp from half to full LR over the first ~1%
        # of iterations) is the assumed intent; confirm.
        return torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.5, total_iters=max(1, max_iterations // 100), **kwargs
        )
    raise ValueError(
        "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
    )
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
    """Sample a random ``(height, width)`` on a 64-pixel grid.

    Each side is ``64 * k``, with ``k`` drawn independently (height first) from
    ``[(bucket_resolution // 2) // 64, bucket_resolution // 64)``. The upper end is
    exclusive (``torch.randint``), so ``bucket_resolution`` itself is never returned.
    """
    grid = 64
    lowest = (bucket_resolution // 2) // grid
    highest = bucket_resolution // grid
    height, width = (torch.randint(lowest, highest, (1,)).item() * grid for _ in range(2))
    return height, width
| def _get_t5_prompt_embeds( | |
| text_encoder, | |
| tokenizer, | |
| prompt, | |
| max_sequence_length=512, | |
| device=None, | |
| dtype=None | |
| ): | |
| """Helper function to get T5 embeddings in Flux format""" | |
| device = device or text_encoder.device | |
| dtype = dtype or text_encoder.dtype | |
| prompt = [prompt] if isinstance(prompt, str) else prompt | |
| batch_size = len(prompt) | |
| text_inputs = tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=max_sequence_length, | |
| truncation=True, | |
| return_length=False, | |
| return_overflowing_tokens=False, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids = text_inputs.input_ids | |
| prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0] | |
| prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) | |
| return prompt_embeds | |
| def _get_clip_prompt_embeds( | |
| text_encoder, | |
| tokenizer, | |
| prompt, | |
| device=None, | |
| ): | |
| """Helper function to get CLIP embeddings in Flux format""" | |
| device = device or text_encoder.device | |
| prompt = [prompt] if isinstance(prompt, str) else prompt | |
| batch_size = len(prompt) | |
| text_inputs = tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=tokenizer.model_max_length, | |
| truncation=True, | |
| return_overflowing_tokens=False, | |
| return_length=False, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids = text_inputs.input_ids | |
| prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) | |
| # Use pooled output for Flux | |
| prompt_embeds = prompt_embeds.pooler_output | |
| prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) | |
| return prompt_embeds | |
def get_noisy_image_flux(
    image,
    vae,
    transformer,
    scheduler,
    timesteps_to=1000,
    generator=None,
    params = None
):
    """
    Encode an image with the Flux VAE and noise it to a target timestep.

    Args:
        image (Union[PIL.Image.Image, torch.Tensor]): Input image.
        vae (AutoencoderKL): Flux VAE model.
        transformer (FluxTransformer2DModel): Flux transformer. Its
            ``config.in_channels``, ``dtype`` and ``device`` are used.
        scheduler (FlowMatchEulerDiscreteScheduler): Flux noise scheduler.
        timesteps_to (int or torch.Tensor, optional): Target timestep, repeated once
            per batch item. Defaults to 1000.
        generator (torch.Generator, optional): Random generator for reproducibility.
        params (dict): Required. Must contain ``'vae_scale_factor'``, ``'height'``,
            ``'width'`` and ``'batchsize'``.

    Returns:
        tuple: ``(noisy_latents, latent_image_ids)``. The latents are in packed Flux
        format, and the ids are their patch position ids.

    Raises:
        TypeError: If ``params`` is not provided.
    """
    if params is None:
        raise TypeError(
            "get_noisy_image_flux requires `params` with 'vae_scale_factor', 'height', 'width' and 'batchsize'"
        )
    vae_scale_factor = params['vae_scale_factor']
    batch_size = params['batchsize']
    height = params['height']
    width = params['width']

    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
    image = image_processor.preprocess(image, height=height, width=width)
    image = image.to(dtype=torch.float32)

    # Packed Flux latents hold 2x2 patches, so in_channels is 4x the latent channels.
    num_channels_latents = transformer.config.in_channels // 4
    # The documented int default has no `.repeat`; as_tensor leaves tensors untouched.
    timesteps = torch.as_tensor(timesteps_to).repeat(batch_size)
    latents, latent_image_ids = prepare_latents_flux(
        image,
        timesteps,
        batch_size,
        num_channels_latents,
        height,
        width,
        transformer.dtype,
        transformer.device,
        generator,
        None,
        vae_scale_factor,
        vae,
        scheduler,
    )
    return latents, latent_image_ids
| def _pack_latents(latents, batch_size, num_channels_latents, height, width): | |
| """ | |
| Pack latents into Flux's 2x2 patch format | |
| """ | |
| latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) | |
| latents = latents.permute(0, 2, 4, 1, 3, 5) | |
| latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) | |
| return latents | |
| def _unpack_latents(latents, height, width, vae_scale_factor): | |
| """ | |
| Unpack latents from Flux's 2x2 patch format back to image space | |
| """ | |
| batch_size, num_patches, channels = latents.shape | |
| # Account for VAE compression and packing | |
| height = 2 * (int(height) // (vae_scale_factor * 2)) | |
| width = 2 * (int(width) // (vae_scale_factor * 2)) | |
| latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) | |
| latents = latents.permute(0, 3, 1, 4, 2, 5) | |
| latents = latents.reshape(batch_size, channels // (2 * 2), height, width) | |
| return latents | |
| def _prepare_latent_image_ids(batch_size, height, width, device, dtype): | |
| latent_image_ids = torch.zeros(height, width, 3) | |
| latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] | |
| latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] | |
| latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape | |
| latent_image_ids = latent_image_ids.reshape( | |
| latent_image_id_height * latent_image_id_width, latent_image_id_channels | |
| ) | |
| return latent_image_ids.to(device=device, dtype=dtype) | |
def prepare_latents_flux(
    image,
    timestep,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    vae_scale_factor=None,
    vae=None,
    scheduler=None
):
    """Encode ``image``, noise it to ``timestep`` and pack it into Flux format.

    If ``latents`` is given, it is returned as-is (only moved/cast), together with
    the position ids. Otherwise the image latents are tiled to ``batch_size``,
    noised with ``scheduler.scale_noise``, and packed.

    Returns:
        ``(latents, latent_image_ids)``.

    Raises:
        ValueError: If a generator list does not match ``batch_size``, or the
            image batch cannot be tiled evenly to ``batch_size``.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    # The VAE downsamples by vae_scale_factor, and packing needs even latent sides,
    # so each side is rounded down to a multiple of 2.
    latent_h = 2 * (int(height) // (vae_scale_factor * 2))
    latent_w = 2 * (int(width) // (vae_scale_factor * 2))
    latent_image_ids = _prepare_latent_image_ids(batch_size, latent_h // 2, latent_w // 2, device, dtype)

    if latents is not None:
        return latents.to(device=device, dtype=dtype), latent_image_ids

    image_latents = _encode_vae_image(
        vae=vae, image=image.to(device=device, dtype=dtype), generator=generator
    )

    if batch_size > image_latents.shape[0]:
        if batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # Tile the encoded images so there is one per batch entry.
        copies = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * copies, dim=0)
    else:
        image_latents = torch.cat([image_latents], dim=0)

    noise_shape = (batch_size, num_channels_latents, latent_h, latent_w)
    noise = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)
    noisy_latents = scheduler.scale_noise(image_latents, timestep, noise)
    packed = _pack_latents(noisy_latents, batch_size, num_channels_latents, latent_h, latent_w)
    return packed, latent_image_ids
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
    """Encode images with the VAE and apply the config's shift and scaling.

    With a list of generators, image ``k`` is encoded on its own using
    ``generator[k]``. Otherwise the whole batch is encoded at once.

    Returns:
        ``(latents - vae.config.shift_factor) * vae.config.scaling_factor``.
    """
    if not isinstance(generator, list):
        raw_latents = retrieve_latents(vae.encode(image), generator=generator)
    else:
        per_image = [
            retrieve_latents(vae.encode(image[k : k + 1]), generator=generator[k])
            for k in range(image.shape[0])
        ]
        raw_latents = torch.cat(per_image, dim=0)
    return (raw_latents - vae.config.shift_factor) * vae.config.scaling_factor
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from a VAE encode output.

    Despite the annotation, ``encoder_output`` is expected to be an object exposing
    ``latent_dist`` or ``latents``:
      * ``latent_dist`` with ``sample_mode == "sample"``: ``latent_dist.sample(generator)``.
      * ``latent_dist`` with ``sample_mode == "argmax"``: ``latent_dist.mode()``.
      * otherwise, ``latents`` if present.

    Raises:
        AttributeError: If none of the above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")