import logging
import warnings

import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence, Union, List, Tuple, Any
from transformers import (
    LlamaForCausalLM,
    Blip2PreTrainedModel,
    Blip2VisionModel,
    Blip2Config,
    Blip2QFormerModel,
    GenerationConfig,
)
from transformers.utils import ModelOutput

warnings.filterwarnings('ignore')

logger = logging.getLogger(__name__)


@dataclass
class Blip2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Class defining the outputs of [`Blip2ForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k]
            if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class Blip2LlaMAForConditionalGeneration(Blip2PreTrainedModel):
    config_class = Blip2Config
    main_input_name = "pixel_values"

    def __init__(self, config: Blip2Config):
        super().__init__(config)

        self.vision_model = Blip2VisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = Blip2QFormerModel(config.qformer_config)

        language_model = LlamaForCausalLM(config.text_config)
        self.language_model = language_model
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, language_model.config.hidden_size)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Index of the first image placeholder token in the prompt; the projected query
        # embeddings are written into positions [offset, offset + num_queries).
        self.offset = 5

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def extract_feature(
        self,
        pixel_values: torch.FloatTensor,
    ):
        # Encode the image, run the query tokens through the Q-Former, and project
        # the result into the language model's embedding space.
        image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state

        language_model_inputs = self.language_projection(query_output)
        return language_model_inputs

    def _tie_weights(self):
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models."
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Prompt template: "Human: <img><IMAGE></img>. Give the describe Assistant:"
        # position of <IMAGE> placeholders: [offset: offset + num_queries]
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device).to(torch.long)

            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(reduction="mean")
            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        language_model_inputs: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides the `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input images to be processed.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices.
            language_model_inputs (`torch.FloatTensor`, *optional*):
                Precomputed, projected Q-Former features (e.g. from `extract_feature`). If not provided,
                they are computed from `pixel_values`.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.

        Returns:
            `torch.LongTensor`: The generated token ids.
        """
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        batch_size = pixel_values.shape[0]
        if language_model_inputs is None:
            image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
            image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_outputs = self.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_attention_mask,
                return_dict=True,
            )
            query_output = query_outputs.last_hidden_state

            language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(language_model_inputs.device)
            )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # position of <IMAGE> placeholders: [offset: offset + num_queries]
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **generate_kwargs,
        )

        return outputs
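

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the model definition). The checkpoint
# path, the choice of AutoImageProcessor/LlamaTokenizer, and the assumption
# that `<img>`, `</img>`, and `<IMAGE>` are added tokens so that the first
# placeholder lands exactly at index `model.offset` are all hypothetical;
# adapt them to the actual checkpoint and tokenizer in use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoImageProcessor, LlamaTokenizer

    checkpoint = "path/to/blip2-llama-checkpoint"  # hypothetical local path
    model = Blip2LlaMAForConditionalGeneration.from_pretrained(checkpoint).eval()
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    tokenizer = LlamaTokenizer.from_pretrained(checkpoint)

    image = Image.open("example.jpg").convert("RGB")
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    # The prompt must reserve `model.num_queries` placeholder positions starting at
    # `model.offset`, matching the template documented in `forward`.
    prompt = "Human: <img>" + "<IMAGE>" * model.num_queries + "</img>. Give the describe Assistant:"
    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=pixel_values,
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=64,
        )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])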