import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
# transformers == 4.45.1
from .configuration_unirec import UniRecConfig
from transformers import M2M100PreTrainedModel
from transformers.models.m2m_100.modeling_m2m_100 import M2M100ScaledWordEmbedding, M2M100Decoder
from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput, BaseModelOutput
from transformers.generation import GenerationMixin
from openrec.modeling.encoders.focalsvtr import FocalSVTR
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # Optional flash-attention support; _flash_attention_forward is not
    # referenced directly in this file.
    from transformers.modeling_flash_attention_utils import _flash_attention_forward


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError('self.model.config.pad_token_id has to be defined.')
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
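
# Illustrative behavior (hypothetical ids; pad_token_id=1, decoder_start_token_id=2):
#   input_ids = [[5, 6, 7, 1]]  ->  shifted_input_ids = [[2, 5, 6, 7]]
# The decoder thus sees the start token followed by the targets shifted one step
# (standard teacher forcing). This helper is kept for reference only: the
# forward() of UniRecForConditionalGenerationNew below builds decoder inputs by
# slicing the labels instead.
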
class UniRecEncoder(M2M100PreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`M2M100EncoderLayer`].
Args:
config: UniRecConfig
"""
def __init__(self, config: UniRecConfig):
super().__init__(config)
self.config = config
        # FocalSVTR vision backbone; its hyperparameters are fixed here rather
        # than read from the config.
        self.vision_encoder = FocalSVTR(img_size=[1408, 960],
                                        depths=[2, 2, 9, 2],
                                        embed_dim=96,
                                        sub_k=[[2, 2], [2, 2], [2, 2],
                                               [-1, -1]],
                                        focal_levels=[3, 3, 3, 3],
                                        max_khs=[7, 3, 3, 3],
                                        focal_windows=[3, 3, 3, 3],
                                        last_stage=False,
                                        feat2d=False)
        # Project the vision features to the decoder's hidden width.
        self.vision_fc = nn.Linear(config.dims[-1], config.d_model)

def forward(
self,
        pixel_values: Optional[torch.Tensor] = None,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # Encode the image with the FocalSVTR backbone, then project the
        # features to the decoder's hidden width.
        hidden_states = self.vision_encoder(pixel_values)
        hidden_states = self.vision_fc(hidden_states)
        if output_hidden_states:
            # Only the final projected state is available from the backbone.
            encoder_states = (hidden_states, )
if not return_dict:
return tuple(
v for v in [hidden_states, encoder_states, all_attentions]
if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions)
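
# Shape sketch (illustrative; assumes the FocalSVTR defaults above with RGB
# input at the configured img_size): `pixel_values` of shape (B, 3, 1408, 960)
# is reduced by the backbone to a 1D token sequence of width config.dims[-1],
# then projected by `vision_fc` to (B, seq_len, config.d_model), which the
# M2M100 decoder cross-attends to.
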
class UniRecModel(M2M100PreTrainedModel):
    # Kept from the M2M100 parent; the vision encoder defines no embed_tokens,
    # so only the decoder embedding participates in weight tying.
    _tied_weights_keys = [
        'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'
    ]
def __init__(self, config: UniRecConfig):
super().__init__(config)
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.shared = M2M100ScaledWordEmbedding(vocab_size,
config.d_model,
padding_idx,
embed_scale=embed_scale)
self.encoder = UniRecEncoder(config)
self.decoder = M2M100Decoder(config, self.shared)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
self.decoder.embed_tokens = self.shared
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def forward(
self,
        pixel_values: Optional[torch.Tensor] = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
encoder_outputs = self.encoder(pixel_values)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1]
if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2]
if len(encoder_outputs) > 2 else None,
)
        # The vision features contain no padding, so a full cross-attention
        # mask is used (any user-supplied attention_mask is ignored here).
        attention_mask = torch.ones(encoder_outputs[0].shape[:2],
                                    dtype=torch.long,
                                    device=encoder_outputs[0].device)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
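
# Minimal forward sketch (hypothetical tensors; assumes a UniRecConfig `config`
# and images matching the encoder's expected size):
#
#   model = UniRecModel(config)
#   out = model(pixel_values=images,            # (B, C, H, W)
#               decoder_input_ids=prev_tokens)  # (B, T) teacher-forcing ids
#   out.last_hidden_state                       # (B, T, config.d_model)
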
class UniRecForConditionalGenerationNew(M2M100PreTrainedModel,
GenerationMixin):
base_model_prefix = 'model'
_tied_weights_keys = [
'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight',
'lm_head.weight'
]
def __init__(self, config: UniRecConfig):
super().__init__(config)
self.model = UniRecModel(config)
self.lm_head = nn.Linear(config.d_model,
self.model.shared.num_embeddings,
bias=False)
self.loss_fct = CrossEntropyLoss(
ignore_index=config.pad_token_id,
label_smoothing=config.label_smoothing)
# Initialize weights and apply final processing
self.post_init()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
pixel_values: torch.Tensor,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
length: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            if decoder_input_ids is None:
                # Labels are assumed to already begin with the decoder start
                # token, so teacher-forcing inputs and targets are built by
                # slicing rather than by shift_tokens_right().
                if length is not None:
                    # Trim to the longest target in the batch to save compute.
                    max_len = length.max()
                    decoder_input_ids = labels[:, :1 + max_len]
                    labels = labels[:, 1:2 + max_len]
                else:
                    decoder_input_ids = labels[:, :-1]
                    labels = labels[:, 1:]
        masked_lm_loss = None
        # The same forward pass serves both training and inference; only the
        # loss computation depends on the mode.
        outputs = self.model(
            pixel_values=pixel_values,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])
        if self.training and labels is not None:
            # Loss is computed only in training mode; evaluation goes through
            # generate() instead.
            masked_lm_loss = self.loss_fct(
                lm_logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1))
if not return_dict:
output = (lm_logits, ) + outputs[1:]
return ((masked_lm_loss, ) +
output) if masked_lm_loss is not None else output
return Seq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
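
    # Training sketch (hypothetical tensors; assumes `labels` already begins
    # with the decoder start token, as the slicing above expects):
    #
    #   model.train()
    #   out = model(pixel_values=images,        # (B, C, H, W) image batch
    #               labels=token_ids,           # (B, L) start token + targets
    #               length=target_lengths)      # (B,) per-sample target lengths
    #   out.loss.backward()
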
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
pixel_values=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
'input_ids':
None, # encoder_outputs is defined. input_ids not needed
'encoder_outputs': encoder_outputs,
'past_key_values': past_key_values,
'decoder_input_ids': decoder_input_ids,
'attention_mask': attention_mask,
'head_mask': head_mask,
'decoder_head_mask': decoder_head_mask,
'cross_attn_head_mask': cross_attn_head_mask,
'use_cache':
use_cache, # change this to avoid caching (presumably for debugging)
'pixel_values': pixel_values,
}
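
    # Generation sketch (hypothetical; generate() calls the method above each
    # step, trimming decoder_input_ids to the tokens not yet in the cache):
    #
    #   model.eval()
    #   ids = model.generate(pixel_values=images, max_new_tokens=128)
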
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past), )
return reordered_past
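

# Illustrative reorder (hypothetical shapes): during beam search each cached
# key/value tensor has shape (batch * num_beams, heads, seq_len, head_dim);
# with beam_idx = tensor([2, 0, 1]), rows are re-selected along dim 0 so every
# beam continues from the hypothesis it was forked from.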