| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | import torch |
| | from transformers import WhisperConfig |
| | from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput, Seq2SeqModelOutput |
| |
|
| |
|
@dataclass
class Seq2SeqLMOutputLosses(Seq2SeqLMOutput):
    """`Seq2SeqLMOutput` extended with separate loss components and encoder logits.

    The three extra fields are declared after the inherited `Seq2SeqLMOutput`
    fields (dataclass inheritance places subclass fields last), and each
    defaults to ``None``, like the other optional output fields.
    """

    # Encoder-side loss term. NOTE(review): presumably the auxiliary CTC loss
    # weighted by `DiCoWConfig.ctc_weight` — confirm against the model's forward().
    enc_loss: Optional[torch.FloatTensor] = None
    # Decoder-side loss term. NOTE(review): presumably the decoder's token-level
    # loss — confirm against the model's forward().
    dec_loss: Optional[torch.FloatTensor] = None
    # Logits derived from the encoder outputs. NOTE(review): presumably the CTC
    # head output; shape is not established here — confirm.
    encoder_logits: Optional[torch.FloatTensor] = None
| |
|
| |
|
@dataclass
class BaseModelOutputLogit(BaseModelOutput):
    """`BaseModelOutput` with an additional optional `logits` field.

    `logits` is declared after the inherited fields and defaults to ``None``.
    """

    # Optional logits returned alongside the hidden states. What produces them
    # is not visible here. NOTE(review): presumably an encoder CTC head — confirm.
    logits: Optional[torch.FloatTensor] = None
| |
|
| |
|
@dataclass
class Seq2SeqModelOutputLogit(Seq2SeqModelOutput):
    """`Seq2SeqModelOutput` with an additional optional `encoder_logits` field.

    `encoder_logits` is declared after the inherited fields and defaults to ``None``.
    """

    # Logits derived from the encoder outputs. NOTE(review): presumably the
    # encoder CTC head output; shape is not established here — confirm.
    encoder_logits: Optional[torch.FloatTensor] = None
| |
|
| |
|
class DiCoWConfig(WhisperConfig):
    """Configuration for the DiCoW model: a `WhisperConfig` with extra options.

    Extends `WhisperConfig` with settings for:

    - a CTC loss computed in the model's forward pass;
    - FDDT conditioning;
    - soft prompts;
    - SCB layers.

    Every argument listed below is stored unchanged as an attribute of the same
    name. All other keyword arguments are forwarded to `WhisperConfig.__init__`.
    The semantics of each option are implemented in the model code, which is
    not part of this class. Descriptions marked "presumably" are inferred from
    the parameter names and must be confirmed there.

    Args:
        ctc_loss_reduction: Reduction mode for the CTC loss. Presumably passed
            to ``torch.nn.CTCLoss(reduction=...)`` ("mean", "sum" or "none").
        final_dropout: Dropout probability used by the model. Presumably
            applied before the CTC head.
        ctc_zero_infinity: Presumably ``torch.nn.CTCLoss(zero_infinity=...)``.
        ctc_weight: Weight of the CTC loss term.
        blank_token_id: Token id used as the CTC blank. ``None`` if unset.
        additional_layer: Toggle for an extra layer in the model (presumably
            in the CTC head).
        additional_self_attention_layer: Toggle for an extra self-attention
            layer (presumably in the CTC head).
        sub_sample: Toggle for sub-sampling (presumably of encoder outputs
            before the CTC head).
        use_fddt: Master switch for FDDT conditioning.
        fddt_is_diagonal: FDDT variant flag (presumably a diagonal transform).
        fddt_bias_only: FDDT variant flag (presumably a bias-only transform).
        fddt_use_silence: Toggle for the FDDT silence class.
        fddt_use_target: Toggle for the FDDT target class.
        fddt_use_overlap: Toggle for the FDDT overlap class.
        fddt_use_non_target: Toggle for the FDDT non-target class.
        remove_timestamps_from_ctc: Presumably strips timestamp tokens from
            CTC targets.
        apply_fddt_to_n_layers: Number of encoder layers with FDDT. Presumably
            ``-1`` means all layers.
        fddt_init: Initialization scheme name for the FDDT parameters.
        n_soft_prompts: Number of soft prompts.
        mt_num_speakers: Speaker count (presumably for multi-talker decoding).
        non_target_fddt_value: Presumably a fill value used for the non-target
            class.
        use_initial_fddt: Presumably applies FDDT before the first encoder layer.
        scb_method: Name of the SCB method, or ``None``.
        scb_layers: Number of layers using SCB. Presumably ``-1`` means all.
        **kwargs: Forwarded to `WhisperConfig.__init__`.
    """

    # Identifier that transformers writes to and reads from the serialized
    # config's `model_type` field.
    model_type = "DiCoW"

    def __init__(
        self,
        ctc_loss_reduction: str = "mean",
        final_dropout: float = 0.0,
        ctc_zero_infinity: bool = False,
        ctc_weight: float = 0.0,
        blank_token_id: Optional[int] = None,
        additional_layer: bool = False,
        additional_self_attention_layer: bool = False,
        sub_sample: bool = False,
        use_fddt: bool = True,
        fddt_is_diagonal: bool = True,
        fddt_bias_only: bool = False,
        fddt_use_silence: bool = True,
        fddt_use_target: bool = True,
        fddt_use_overlap: bool = True,
        fddt_use_non_target: bool = True,
        remove_timestamps_from_ctc: bool = False,
        apply_fddt_to_n_layers: int = -1,
        fddt_init: str = 'non-disturbing',
        n_soft_prompts: int = 16,
        mt_num_speakers: int = 1,
        non_target_fddt_value: float = 0.0,
        use_initial_fddt: bool = False,
        scb_method: Optional[str] = None,  # was annotated `str` despite the None default
        scb_layers: int = -1,
        **kwargs,
    ):
        # Initialize all standard Whisper settings first. The DiCoW-specific
        # attributes below are then set on top of them.
        super().__init__(**kwargs)

        # CTC options.
        self.ctc_loss_reduction = ctc_loss_reduction
        self.final_dropout = final_dropout
        self.ctc_zero_infinity = ctc_zero_infinity
        self.ctc_weight = ctc_weight
        self.blank_token_id = blank_token_id

        # Extra-structure toggles.
        self.additional_layer = additional_layer
        self.additional_self_attention_layer = additional_self_attention_layer
        self.sub_sample = sub_sample

        # FDDT options.
        self.use_fddt = use_fddt
        self.fddt_is_diagonal = fddt_is_diagonal
        self.fddt_bias_only = fddt_bias_only
        self.fddt_use_silence = fddt_use_silence
        self.fddt_use_target = fddt_use_target
        self.fddt_use_overlap = fddt_use_overlap
        self.fddt_use_non_target = fddt_use_non_target
        self.remove_timestamps_from_ctc = remove_timestamps_from_ctc
        self.apply_fddt_to_n_layers = apply_fddt_to_n_layers
        self.fddt_init = fddt_init

        # Soft prompts, speaker count, remaining FDDT values and SCB.
        self.n_soft_prompts = n_soft_prompts
        self.mt_num_speakers = mt_num_speakers
        self.non_target_fddt_value = non_target_fddt_value
        self.use_initial_fddt = use_initial_fddt
        self.scb_method = scb_method
        self.scb_layers = scb_layers
| |
|
| |
|
# NOTE(review): no use is visible in this chunk. Presumably this is the index
# at which hidden states begin in tuple-form (return_dict=False) model outputs,
# used as `outputs[_HIDDEN_STATES_START_POSITION:]`. That would mirror the
# same-named constant in transformers' Wav2Vec2 modeling code — confirm.
_HIDDEN_STATES_START_POSITION = 2
| |
|