Spaces:
Running
Running
| import torch.nn as nn | |
| from .nrtr_decoder import NRTRDecoder | |
class CAMDecoder(nn.Module):
    """Recognition decoder head that delegates to an ``NRTRDecoder``.

    This class is a thin adapter: it forwards all of its constructor
    arguments to an internal :class:`NRTRDecoder` and, at call time, feeds
    that decoder the ``'refined_feat'`` entry of the input dict.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Size of the output vocabulary / class count.
        nhead: Number of attention heads (passed through; default ``None``).
        num_encoder_layers: Encoder depth forwarded to the inner decoder.
        beam_size: Beam width forwarded to the inner decoder (0 = greedy).
        num_decoder_layers: Decoder depth forwarded to the inner decoder.
        max_len: Maximum decoded sequence length.
        attention_dropout_rate: Dropout applied inside attention.
        residual_dropout_rate: Dropout applied on residual connections.
        scale_embedding: Whether the inner decoder scales its embeddings.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super().__init__()
        # Collect the pass-through configuration once, then hand it to the
        # wrapped NRTR decoder in a single call.
        decoder_cfg = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            beam_size=beam_size,
            num_decoder_layers=num_decoder_layers,
            max_len=max_len,
            attention_dropout_rate=attention_dropout_rate,
            residual_dropout_rate=residual_dropout_rate,
            scale_embedding=scale_embedding,
        )
        self.decoder = NRTRDecoder(**decoder_cfg)

    def forward(self, x, data=None):
        """Run the inner decoder on ``x['refined_feat']``.

        Args:
            x: Dict carrying at least the key ``'refined_feat'``; the
               decoder output is written back under ``'rec_output'``.
            data: Optional extra payload forwarded to the inner decoder
                  (e.g. label information during training).

        Returns:
            The mutated dict ``x`` while training, otherwise just the
            decoder output — callers in eval mode get the prediction
            directly.
        """
        rec_output = self.decoder(x['refined_feat'], data=data)
        x['rec_output'] = rec_output
        return x if self.training else rec_output