import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

# transformers == 4.45.1
from .configuration_unirec import UniRecConfig
from transformers import M2M100PreTrainedModel
from transformers.models.m2m_100.modeling_m2m_100 import M2M100ScaledWordEmbedding, M2M100Decoder
from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput, BaseModelOutput
from transformers.generation import GenerationMixin
from openrec.modeling.encoders.focalsvtr import FocalSVTR
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
                       decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError('self.model.config.pad_token_id has to be defined.')
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
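
# Illustrative sketch (not part of the original file): shift_tokens_right turns gold
# labels into teacher-forcing decoder inputs by prepending the decoder start token
# and dropping the last position. With hypothetical ids pad=1, start=2:
#
#   labels = torch.tensor([[5, 6, 7, 2]])
#   shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
#   # -> tensor([[2, 5, 6, 7]])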


class UniRecEncoder(M2M100PreTrainedModel):
    """
    Vision encoder that stands in for the M2M-100 text encoder. It wraps a
    [`FocalSVTR`] backbone and projects its features from `config.dims[-1]` to
    `config.d_model` so they can be consumed by the decoder's cross-attention.

    Args:
        config: UniRecConfig
    """

    def __init__(self, config: UniRecConfig):
        super().__init__(config)
        self.config = config
        self.vision_encoder = FocalSVTR(img_size=[1408, 960],
                                        depths=[2, 2, 9, 2],
                                        embed_dim=96,
                                        sub_k=[[2, 2], [2, 2], [2, 2],
                                               [-1, -1]],
                                        focal_levels=[3, 3, 3, 3],
                                        max_khs=[7, 3, 3, 3],
                                        focal_windows=[3, 3, 3, 3],
                                        last_stage=False,
                                        feat2d=False)
        self.vision_fc = nn.Linear(config.dims[-1], config.d_model)

    def forward(
        self,
        pixel_values: torch.Tensor = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            pixel_values (`torch.FloatTensor`):
                Batch of input images fed to the FocalSVTR vision backbone. This is the only input the
                encoder actually consumes; the text-side arguments below are accepted for interface
                compatibility with the M2M-100 text encoder but are ignored.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = self.vision_encoder(pixel_values)
        hidden_states = self.vision_fc(hidden_states)
        # hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = (hidden_states, )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, encoder_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states,
                               hidden_states=encoder_states,
                               attentions=all_attentions)
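
# Hypothetical shape sketch (not from the original file), assuming pixel_values holds a
# batch of images: FocalSVTR (feat2d=False) yields a sequence of visual tokens, so the
# encoder output is roughly
#
#   feats = vision_encoder(pixel_values)   # (batch, num_tokens, config.dims[-1])
#   feats = vision_fc(feats)               # (batch, num_tokens, config.d_model)
#
# and `last_hidden_state` is what the decoder later cross-attends to.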


class UniRecModel(M2M100PreTrainedModel):
    _tied_weights_keys = [
        'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'
    ]

    def __init__(self, config: UniRecConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0
        self.shared = M2M100ScaledWordEmbedding(vocab_size,
                                                config.d_model,
                                                padding_idx,
                                                embed_scale=embed_scale)

        self.encoder = UniRecEncoder(config)
        self.decoder = M2M100Decoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value
        self.decoder.embed_tokens = self.shared

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        pixel_values: torch.Tensor = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(pixel_values)
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        attention_mask = torch.ones(encoder_outputs[0].shape[:2],
                                    dtype=torch.long,
                                    device=encoder_outputs[0].device)

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
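
# Rough data flow through UniRecModel (comment sketch, not original code):
#   pixel_values --UniRecEncoder--> encoder_outputs[0]   # (batch, num_visual_tokens, d_model)
#   decoder_input_ids --M2M100Decoder--> last_hidden_state,
# with the decoder cross-attending over the visual tokens under an all-ones encoder
# attention mask, i.e. every visual token is kept.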


class UniRecForConditionalGenerationNew(M2M100PreTrainedModel,
                                        GenerationMixin):
    base_model_prefix = 'model'
    _tied_weights_keys = [
        'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight',
        'lm_head.weight'
    ]

    def __init__(self, config: UniRecConfig):
        super().__init__(config)
        self.model = UniRecModel(config)
        self.lm_head = nn.Linear(config.d_model,
                                 self.model.shared.num_embeddings,
                                 bias=False)
        self.loss_fct = CrossEntropyLoss(
            ignore_index=config.pad_token_id,
            label_smoothing=config.label_smoothing)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        length: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None:
                # decoder_input_ids = shift_tokens_right(
                #     labels, self.config.pad_token_id, self.config.decoder_start_token_id
                # )
                if length is not None:
                    max_len = length.max()
                    decoder_input_ids = labels[:, :1 + max_len]
                    labels = labels[:, 1:2 + max_len]
                else:
                    decoder_input_ids = labels[:, :-1]
                    labels = labels[:, 1:]

        outputs = self.model(
            pixel_values=pixel_values,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])

        masked_lm_loss = None
        if self.training and labels is not None:
            masked_lm_loss = self.loss_fct(
                lm_logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1))

        if not return_dict:
            output = (lm_logits, ) + outputs[1:]
            return ((masked_lm_loss, ) +
                    output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
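
    # Comment sketch of the label handling above (illustrative, not original code),
    # assuming labels already begin with the decoder start token and `length` gives
    # the per-sample text length:
    #
    #   labels            = [BOS, y1, y2, ..., yL, EOS, PAD, ...]
    #   decoder_input_ids = labels[:, :1 + max_len]   # [BOS, y1, ..., y_max_len]
    #   labels            = labels[:, 1:2 + max_len]  # next-token targets, one step ahead
    #
    # so position i of the decoder predicts the token at position i + 1, and padding
    # positions are ignored by the loss via ignore_index=pad_token_id.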

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        pixel_values=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        return {
            'input_ids': None,  # encoder_outputs is defined. input_ids not needed
            'encoder_outputs': encoder_outputs,
            'past_key_values': past_key_values,
            'decoder_input_ids': decoder_input_ids,
            'attention_mask': attention_mask,
            'head_mask': head_mask,
            'decoder_head_mask': decoder_head_mask,
            'cross_attn_head_mask': cross_attn_head_mask,
            'use_cache': use_cache,  # change this to avoid caching (presumably for debugging)
            'pixel_values': pixel_values,
        }
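
    # Comment illustration (not original code): with a KV cache holding `past_length`
    # already-decoded positions, only the tokens after that prefix are fed back in, e.g.
    #   decoder_input_ids = [[2, 15, 37, 41]], past_length = 3
    #   -> decoder_input_ids[:, 3:] = [[41]]
    # so each generation step processes a single new token while reusing cached keys/values.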

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past), )
        return reordered_past
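

# Hypothetical inference sketch (not part of the original file); the checkpoint path and
# image shape are placeholders, and greedy decoding is shown by hand rather than relying
# on any particular generate() configuration:
#
#   config = UniRecConfig.from_pretrained('path/to/unirec')                      # hypothetical path
#   model = UniRecForConditionalGenerationNew.from_pretrained('path/to/unirec',  # hypothetical path
#                                                             config=config).eval()
#   pixel_values = torch.randn(1, 3, 1408, 960)                                  # assumed image layout
#
#   encoder_outputs = model.get_encoder()(pixel_values=pixel_values)
#   decoder_input_ids = torch.tensor([[config.decoder_start_token_id]])
#   out = model(pixel_values=pixel_values,
#               encoder_outputs=encoder_outputs,
#               decoder_input_ids=decoder_input_ids)
#   next_token = out.logits[:, -1].argmax(-1)  # greedy pick of the first output token
#
# Because the class mixes in GenerationMixin and implements prepare_inputs_for_generation,
# beam search / sampling via generate() can also be used on top of the same encoder outputs.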