from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperForConditionalGeneration,
    WhisperModel,
    shift_tokens_right,
    sinusoids,
)
from transformers.utils import logging

from .config import Seq2SeqLMOutputLosses, Seq2SeqModelOutputLogit, DiCoWConfig
from .encoder import CustomLinear, CustomDiagonalLinear, FDDT, DiCoWEncoder
from .generation import DiCoWGenerationMixin

logging.set_verbosity_debug()
logger = logging.get_logger("transformers")


class DiCoW(WhisperModel):
    """Whisper model whose encoder is replaced by the diarization-conditioned DiCoWEncoder."""

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.encoder = DiCoWEncoder(config)

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        per_group_sizes: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutputLogit]:
        r"""
        Returns:

        Example:
        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, WhisperModel
        >>> from datasets import load_dataset

        >>> model = WhisperModel.from_pretrained("openai/whisper-base")
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_features = inputs.input_features
        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
        >>> list(last_hidden_state.shape)
        [1, 2, 512]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

            encoder_outputs = self.encoder(
                input_features,
                output_attentions=output_attentions,
                output_hidden_states=True,  # hidden states are always needed for the decoder below
                head_mask=head_mask,
                return_dict=return_dict,
                stno_mask=stno_mask,
                per_group_sizes=per_group_sizes,
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            # condition the decoder on the last encoder layer's hidden states
            encoder_hidden_states=encoder_outputs.hidden_states[-1],
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutputLogit(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.hidden_states[-1],
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            encoder_logits=encoder_outputs.logits,
        )
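

# Illustrative sketch (the shape and channel order are assumptions, not
# checked against DiCoWEncoder): `stno_mask` is expected to carry frame-level
# diarization information for the four STNO classes -- silence, target,
# non-target, overlap -- which the FDDT layers use to bias the encoder
# towards the target speaker.
def _example_stno_mask(batch_size: int = 1, num_frames: int = 1500) -> torch.Tensor:
    stno_mask = torch.zeros(batch_size, 4, num_frames)
    stno_mask[:, 1, :] = 1.0  # pretend the target speaker is active in every frame
    return stno_mask

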
class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.model = DiCoW(config)
        self.encoder_logits = None
        self.tokenizer = None
        self.vad_seek_callback = None
        self.stno_mask = None
        self.stno_mask_seek = None

    def set_vad_seek_callback(self, vad_seek_callback):
        self.vad_seek_callback = vad_seek_callback

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def _init_weights(self, module):
        std = self.config.init_std
        fddt_init = self.config.fddt_init
        if isinstance(module, CustomLinear):
            with torch.no_grad():
                if fddt_init == 'random':
                    module.weight.data.normal_(mean=0.0, std=std)
                    if module.bias is not None:
                        module.bias.data.normal_(mean=0.0, std=std)
                elif fddt_init == 'non-disturbing':
                    # identity weights: the layer initially passes inputs through unchanged
                    module.weight.data = torch.eye(*module.weight.shape).data
                    if module.bias is not None:
                        module.bias.data.zero_()
                elif fddt_init == 'disparagement':
                    # scaled identity: pass-through damped by init_eye_val
                    eye = torch.eye(*module.weight.shape)
                    eye *= module.init_eye_val
                    module.weight.data = eye.data
                    if module.bias is not None:
                        module.bias.data.zero_()
        elif isinstance(module, CustomDiagonalLinear):
            with torch.no_grad():
                if fddt_init == 'random':
                    module.weight.data.normal_(mean=0.0, std=std)
                    if module.bias is not None:
                        module.bias.data.normal_(mean=0.0, std=std)
                elif fddt_init == 'non-disturbing':
                    module.weight.data = torch.ones_like(module.weight.data).data
                    if module.bias is not None:
                        module.bias.data.zero_()
                elif fddt_init == 'disparagement':
                    module.weight.data = module.init_eye_val * torch.ones_like(module.weight.data).data
                    if module.bias is not None:
                        module.bias.data.zero_()
        elif isinstance(module, FDDT):
            if module.bias_only:
                if fddt_init == 'random':
                    module.target_linear.data.normal_(mean=0.0, std=std)
                    module.non_target_linear.data.normal_(mean=0.0, std=std)
                    module.overlap_linear.data.normal_(mean=0.0, std=std)
                    module.silence_linear.data.normal_(mean=0.0, std=std)
                else:
                    module.target_linear.data.zero_()
                    module.non_target_linear.data.zero_()
                    module.overlap_linear.data.zero_()
                    module.silence_linear.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, WhisperEncoder):
            with torch.no_grad():
                embed_positions = module.embed_positions.weight
                embed_positions.copy_(sinusoids(*embed_positions.shape))
        elif isinstance(module, nn.LayerNorm):
            module.reset_parameters()
        elif isinstance(module, nn.MultiheadAttention):
            module._reset_parameters()
        elif isinstance(module, nn.ConvTranspose1d):
            module.reset_parameters()
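
    # Illustrative sketch (hypothetical helper, not called anywhere): the
    # 'non-disturbing' branch above amounts to an identity initialisation,
    # so a freshly initialised layer leaves hidden states untouched.
    @staticmethod
    def _example_non_disturbing_init(hidden_size: int = 4) -> None:
        linear = nn.Linear(hidden_size, hidden_size)
        with torch.no_grad():
            linear.weight.copy_(torch.eye(hidden_size))
            linear.bias.zero_()
        x = torch.randn(2, hidden_size)
        # identity weights and zero bias reproduce the input exactly
        assert torch.allclose(linear(x), x)
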
    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        per_group_sizes: Optional[torch.LongTensor] = None,
        attention_mask_enc: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        upp_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        is_valid: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_features = inputs.input_features

        >>> generated_ids = model.generate(inputs=input_features)

        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            stno_mask=stno_mask,
            per_group_sizes=per_group_sizes,
        )

        dec_lm_logits = self.proj_out(outputs.last_hidden_state)
        enc_lm_logits = outputs.encoder_logits

        loss = None
        ctc_loss = 0

        # optionally drop invalid samples before computing the losses
        if is_valid is not None:
            if self.config.ctc_weight > 0.0:
                enc_lm_logits = enc_lm_logits[is_valid]
            dec_lm_logits = dec_lm_logits[is_valid]
            labels = labels[is_valid]
            upp_labels = upp_labels[is_valid]

        if labels is not None and self.config.ctc_weight > 0.0:
            # strip the forced decoder prefix and mask EOS for the CTC targets
            enc_labels = labels.clone()
            for token in self.tokenizer.prefix_tokens:
                if (enc_labels[:, 0] == token).all():
                    enc_labels = enc_labels[:, 1:]
            enc_labels[enc_labels == self.config.eos_token_id] = -100

            ctc_loss = self.get_encoder().get_loss(enc_lm_logits, enc_labels)

        if labels is not None:
            loss_fct = CrossEntropyLoss(reduction='none')
            labels = labels.to(dec_lm_logits.device)
            upp_labels = upp_labels.to(dec_lm_logits.device)
            # score each token against both label variants and keep the smaller loss
            dec_loss1 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
            dec_loss2 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), upp_labels.reshape(-1))
            dec_loss = torch.hstack((dec_loss1[..., None], dec_loss2[..., None])).min(dim=-1).values.mean()
            loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss

        if not return_dict:
            output = (dec_lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutputLosses(
            loss=loss,
            logits=dec_lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            encoder_logits=enc_lm_logits,
        )
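
    # Illustrative sketch (hypothetical helper, not called by forward): the
    # decoder loss above scores each token against both `labels` and
    # `upp_labels` (an alternative-casing variant) and keeps the smaller of
    # the two losses, so casing mismatches are not penalised.
    @staticmethod
    def _example_min_label_loss(logits: torch.Tensor, labels: torch.Tensor,
                                upp_labels: torch.Tensor) -> torch.Tensor:
        loss_fct = CrossEntropyLoss(reduction='none')
        vocab_size = logits.shape[-1]
        loss_a = loss_fct(logits.view(-1, vocab_size), labels.reshape(-1))
        loss_b = loss_fct(logits.view(-1, vocab_size), upp_labels.reshape(-1))
        # elementwise minimum is equivalent to the hstack/min trick above
        return torch.minimum(loss_a, loss_b).mean()
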
    def _get_feat_extract_output_lengths(self, attention_mask: torch.Tensor) -> torch.Tensor:
        # the extra /4 accounts for additional subsampling applied in the DiCoW encoder
        return (self.model.encoder._get_feat_extract_output_lengths(attention_mask) / 4).ceil()

    def freeze_except(self, prefixes_to_preheat):
        for name, param in self.named_parameters():
            param.requires_grad = False
            for prefix in prefixes_to_preheat:
                if name.startswith(prefix):
                    param.requires_grad = True
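
    # Illustrative usage (the prefix below is an assumption, not one of the
    # project's actual parameter paths): freeze everything except the FDDT
    # layers before fine-tuning, then count what remains trainable.
    #
    #     model.freeze_except(["model.encoder.fddt"])
    #     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
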
    def suppress_interactions(self):
        """Suppress the final projection in CoAttention blocks so the original information flows through."""
        for name, param in self.named_parameters():
            if "interaction" in name and "cat_proj" in name:
                with torch.no_grad():
                    if "bias" in name:
                        param[:] = 0.
                    else:
                        param[:] *= 0.001
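

# Illustrative sketch (standalone helper mirroring suppress_interactions):
# scaling a projection's weights by 1e-3 and zeroing its bias makes the block
# nearly transparent, so residual paths dominate early in fine-tuning.
def _example_suppress_projection(proj: nn.Linear) -> None:
    with torch.no_grad():
        proj.weight.mul_(0.001)
        if proj.bias is not None:
            proj.bias.zero_()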