import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SAREncoder(nn.Module):
    """LSTM encoder that turns a backbone feature map into one holistic feature."""

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 in_channels=512,
                 d_enc=512,
                 **kwargs):
        super().__init__()
        # LSTM encoder
        self.rnn_encoder = nn.LSTM(input_size=in_channels,
                                   hidden_size=d_enc,
                                   num_layers=2,
                                   dropout=enc_drop_rnn,
                                   bidirectional=enc_bi_rnn,
                                   batch_first=True)
        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat):
        # feat: bsz * C * H * W feature map from the backbone
        h_feat = feat.shape[2]
        # collapse the height dimension with a vertical max-pool
        feat_v = F.max_pool2d(feat,
                              kernel_size=(h_feat, 1),
                              stride=1,
                              padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = feat_v.permute(0, 2, 1).contiguous()  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * hidden_size
        # keep the last time step as the holistic feature
        valid_hf = holistic_feat[:, -1, :]  # bsz * hidden_size
        holistic_feat = self.linear(valid_hf)  # bsz * C
        return holistic_feat


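# A minimal shape sketch for SAREncoder (the 2 x 512 x 6 x 40 input is an
# assumed backbone output, not taken from this file): the map is max-pooled
# over H, run through the LSTM along W, and the last step is projected:
#
#   enc = SAREncoder(in_channels=512, d_enc=512)
#   out = enc(torch.rand(2, 512, 6, 40))  # -> torch.Size([2, 512])

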
class SARDecoder(nn.Module):
    """2D-attention decoder of SAR (Show, Attend and Read)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 pred_dropout=0.1,
                 pred_concat=True,
                 mask=True,
                 use_lstm=True,
                 **kwargs):
        super().__init__()
        # token layout: index 0 is <end>, out_channels - 2 is <start>,
        # out_channels - 1 is <pad>
        self.num_classes = out_channels
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.end_idx = 0
        self.max_seq_len = max_len + 1
        self.pred_concat = pred_concat
        self.mask = mask
        enc_dim = in_channels
        d = in_channels
        embedding_dim = in_channels
        dec_dim = in_channels
        self.use_lstm = use_lstm
        if use_lstm:
            # encoder module
            self.encoder = SAREncoder(enc_bi_rnn=enc_bi_rnn,
                                      enc_drop_rnn=enc_drop_rnn,
                                      in_channels=in_channels,
                                      d_enc=enc_dim)
        # decoder module: 2D attention layer
        self.conv1x1_1 = nn.Linear(dec_dim, d)
        self.conv3x3_1 = nn.Conv2d(in_channels,
                                   d,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)
        self.conv1x1_2 = nn.Linear(d, 1)
        # decoder input embedding
        self.embedding = nn.Embedding(self.num_classes,
                                      embedding_dim,
                                      padding_idx=self.padding_idx)
        self.rnn_decoder = nn.LSTM(input_size=embedding_dim,
                                   hidden_size=dec_dim,
                                   num_layers=2,
                                   dropout=dec_drop_rnn,
                                   bidirectional=dec_bi_rnn,
                                   batch_first=True)
        # prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        if pred_concat:
            # hidden state + attended feature + holistic feature
            fc_in_channel = dec_dim + in_channels + in_channels
        else:
            fc_in_channel = in_channels
        self.prediction = nn.Linear(fc_in_channel, self.num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def _2d_attention(self, feat, tokens, data):
        # tokens: bsz * (seq_len + 1) * emb_dim decoder inputs
        hidden_state = self.rnn_decoder(tokens)[0]
        attn_query = self.conv1x1_1(hidden_state)
        bsz, seq_len, _ = attn_query.size()
        attn_query = attn_query.unsqueeze(-1).unsqueeze(-1)
        # bsz * (seq_len + 1) * attn_size * 1 * 1
        attn_key = self.conv3x3_1(feat).unsqueeze(1)
        # bsz * 1 * attn_size * h * w
        attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
        attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
        attn_weight = self.conv1x1_2(attn_weight)
        _, T, h, w, c = attn_weight.size()
        if self.mask:
            valid_ratios = data[-1]
            # mask attention weights that fall beyond each image's valid width
            attn_mask = torch.zeros_like(attn_weight)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                attn_mask[i, :, :, valid_width:, :] = 1
            attn_weight = attn_weight.masked_fill(attn_mask.bool(),
                                                  float('-inf'))
        attn_weight = attn_weight.view(bsz, T, -1)
        attn_weight = F.softmax(attn_weight, dim=-1)
        attn_weight = attn_weight.view(bsz, T, h, w,
                                       c).permute(0, 1, 4, 2, 3).contiguous()
        # attn_weight: bsz * T * 1 * h * w; feat: bsz * 1 * f_c * h * w
        attn_feat = torch.sum(torch.mul(feat.unsqueeze(1), attn_weight),
                              (3, 4),
                              keepdim=False)  # bsz * T * f_c
        return [hidden_state, attn_feat]

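    # In equation form, what _2d_attention computes for decoder step t and
    # feature-map position (i, j) is the standard SAR 2D attention:
    #   e_t(i, j) = w_e^T tanh(W_h h_t + W_f F(i, j))
    #   alpha_t   = softmax of e_t over all (i, j)
    #   g_t       = sum_ij alpha_t(i, j) * F(i, j)
    # with conv1x1_1 as W_h, conv3x3_1 as W_f, and conv1x1_2 as w_e.
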
    def forward_train(self, feat, holistic_feat, data):
        # data[0]: padded label indices, data[1]: valid label lengths
        max_len = data[1].max()
        label = data[0][:, :1 + max_len]  # bsz * (max_len + 1)
        label_embedding = self.embedding(label)
        holistic_feat = holistic_feat.unsqueeze(1)  # bsz * 1 * C
        # prepend the holistic feature as the first decoder input
        tokens = torch.cat((holistic_feat, label_embedding), dim=1)
        hidden_state, attn_feat = self._2d_attention(feat, tokens, data)
        bsz, seq_len, f_c = hidden_state.size()
        # linear transformation
        if self.pred_concat:
            f_c = holistic_feat.size(-1)
            holistic_feat = holistic_feat.expand(bsz, seq_len, f_c)
            preds = self.prediction(
                torch.cat((hidden_state, attn_feat, holistic_feat), 2))
        else:
            preds = self.prediction(attn_feat)
        # bsz * seq_len * num_classes
        preds = self.pred_dropout(preds)
        # drop the step that corresponds to the holistic-feature input
        return preds[:, 1:, :]

    def forward_test(self, feat, holistic_feat, data=None):
        bsz = feat.shape[0]
        seq_len = self.max_seq_len
        holistic_feat = holistic_feat.unsqueeze(1)  # bsz * 1 * C
        # start every sequence with the <start> token embedding
        tokens = torch.full((bsz, ),
                            self.start_idx,
                            device=feat.device,
                            dtype=torch.long)
        outputs = []
        tokens = self.embedding(tokens)
        tokens = tokens.unsqueeze(1).expand(-1, seq_len, -1)
        tokens = torch.cat((holistic_feat, tokens), dim=1)
        end_flag = torch.zeros(bsz, dtype=torch.bool, device=feat.device)
        for i in range(1, seq_len + 1):
            hidden_state, attn_feat = self._2d_attention(feat, tokens, data)
            if self.pred_concat:
                f_c = holistic_feat.size(-1)
                holistic_feat = holistic_feat.expand(bsz, seq_len + 1, f_c)
                preds = self.prediction(
                    torch.cat((hidden_state, attn_feat, holistic_feat), 2))
            else:
                preds = self.prediction(attn_feat)
            # preds: bsz * (seq_len + 1) * num_classes
            char_output = preds[:, i, :]
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            _, max_idx = torch.max(char_output, dim=1, keepdim=False)
            char_embedding = self.embedding(max_idx)
            # feed the predicted character back in as the next input
            if i < seq_len:
                tokens[:, i + 1, :] = char_embedding
            # stop early once every sequence has emitted the <end> token
            end_flag = end_flag | (max_idx == self.end_idx)
            if end_flag.all():
                break
        outputs = torch.stack(outputs, 1)
        return outputs

    def forward(self, feat, data=None):
        if self.use_lstm:
            holistic_feat = self.encoder(feat)  # bsz * C
        else:
            # flatten(1) keeps the batch dimension even when bsz == 1
            holistic_feat = F.adaptive_avg_pool2d(feat, (1, 1)).flatten(1)
        if self.training:
            preds = self.forward_train(feat, holistic_feat, data=data)
        else:
            preds = self.forward_test(feat, holistic_feat, data=data)
        # (bsz, seq_len, num_classes)
        return preds


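if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the vocabulary size (92) and the
    # layout of `data` (labels, lengths, valid width ratios) are assumptions
    # inferred from how forward_train/forward_test index into `data`, not
    # taken from an official config.
    torch.manual_seed(0)
    bsz, c, h, w = 2, 512, 6, 40
    num_classes = 92  # hypothetical charset size incl. <end>/<start>/<pad>
    feat = torch.rand(bsz, c, h, w)  # stands in for a backbone feature map
    model = SARDecoder(in_channels=c, out_channels=num_classes, max_len=25)

    # Inference only reads the valid width ratios from data[-1].
    model.eval()
    with torch.no_grad():
        preds = model(feat, data=[[1.0, 1.0]])
    print(preds.shape)  # (bsz, <= max_len + 1, num_classes)

    # Training additionally reads data[0] (padded labels) and
    # data[1] (label lengths).
    model.train()
    labels = torch.randint(0, num_classes, (bsz, 26))
    lengths = torch.tensor([10, 12])
    preds = model(feat, data=[labels, lengths, [1.0, 1.0]])
    print(preds.shape)  # (bsz, lengths.max() + 1, num_classes)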