| """This code is refer from: | |
| https://github.com/byeonghu-na/MATRN | |
| """ | |
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from openrec.modeling.decoders.abinet_decoder import BCNLanguage, PositionAttention, _get_length
from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock


class BaseSemanticVisual_backbone_feature(nn.Module):

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.0,
                 alignment_mask_example_prob=0.9,
                 alignment_mask_candidate_prob=0.9,
                 alignment_num_vis_mask=10,
                 max_length=25,
                 num_classes=37):
        super().__init__()
        self.mask_example_prob = alignment_mask_example_prob
        self.mask_candidate_prob = alignment_mask_candidate_prob
        self.num_vis_mask = alignment_num_vis_mask
        self.nhead = nhead
        self.d_model = d_model
        self.max_length = max_length + 1  # additional stop token
        # Self-attention blocks over the concatenated visual-semantic sequence.
        self.model1 = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.pos_encoder_tfm = PositionalEncoding(dim=d_model,
                                                  dropout=0,
                                                  max_len=1024)
        self.model2_vis = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.cls_vis = nn.Linear(d_model, num_classes)
        self.cls_sem = nn.Linear(d_model, num_classes)
        self.w_att = nn.Linear(2 * d_model, d_model)
        # Learnable token used to mask attended visual positions during training.
        v_token = torch.empty((1, d_model))
        self.v_token = nn.Parameter(v_token)
        torch.nn.init.uniform_(self.v_token, -0.001, 0.001)
        self.cls = nn.Linear(d_model, num_classes)

    def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None):
        """
        Args:
            l_feature: (N, T, E), where N is the batch size, T the sequence
                length and E the model dimension.
            v_feature: (N, E, H, W) visual feature map.
            lengths_l: (N,) predicted text lengths for the language stream.
            v_attn: (N, T, H, W) attention maps from the vision decoder.

        Returns:
            A dict with fused, visual and semantic logits of shape (N, T, C)
            and the corresponding predicted lengths.
        """
        N, E, H, W = v_feature.size()
        v_feature = v_feature.flatten(2, 3).transpose(1, 2)  # (N, H*W, E)
        v_attn = v_attn.flatten(2, 3)  # (N, T, H*W)

        if self.training:
            # Randomly replace the most attended visual positions of one
            # character per sample with the learnable mask token.
            for idx, length in enumerate(lengths_l):
                if np.random.random() <= self.mask_example_prob:
                    l_idx = np.random.randint(int(length))
                    v_random_idx = v_attn[idx, l_idx].argsort(
                        descending=True).cpu().numpy()[:self.num_vis_mask]
                    v_random_idx = v_random_idx[np.random.random(
                        v_random_idx.shape) <= self.mask_candidate_prob]
                    v_feature[idx, v_random_idx] = self.v_token

        # Spatial positional encoding, aggregated into the semantic stream
        # through the visual attention maps.
        zeros = v_feature.new_zeros((N, H * W, E))  # (N, H*W, E)
        base_pos = self.pos_encoder_tfm(zeros)  # (N, H*W, E)
        base_pos = torch.bmm(v_attn, base_pos)  # (N, T, E)
        l_feature = l_feature + base_pos

        # Joint self-attention over the concatenated visual and semantic tokens.
        sv_feature = torch.cat((v_feature, l_feature), dim=1)  # (N, H*W+T, E)
        for decoder_layer in self.model1:
            sv_feature = decoder_layer(sv_feature)  # (N, H*W+T, E)

        sv_to_v_feature = sv_feature[:, :H * W]  # (N, H*W, E)
        sv_to_s_feature = sv_feature[:, H * W:]  # (N, T, E)

        sv_to_v_feature = sv_to_v_feature.transpose(1, 2).reshape(N, E, H, W)
        sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature)  # (N, T, E)
        sv_to_v_logits = self.cls_vis(sv_to_v_feature)  # (N, T, C)
        pt_v_lengths = _get_length(sv_to_v_logits)  # (N,)

        sv_to_s_logits = self.cls_sem(sv_to_s_feature)  # (N, T, C)
        pt_s_lengths = _get_length(sv_to_s_logits)  # (N,)

        # Gated fusion of the enhanced visual and semantic features:
        # output = g * visual + (1 - g) * semantic, with g = sigmoid(W[v; s]).
        f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature

        logits = self.cls(output)  # (N, T, C)
        pt_lengths = _get_length(logits)

        return {
            'logits': logits,
            'pt_lengths': pt_lengths,
            'v_logits': sv_to_v_logits,
            'pt_v_lengths': pt_v_lengths,
            's_logits': sv_to_s_logits,
            'pt_s_lengths': pt_s_lengths,
            'name': 'alignment'
        }


class MATRNDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=3,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 iter_size=3,
                 **kwargs):
        super().__init__()
        self.max_length = max_length + 1  # additional stop token
        d_model = in_channels
        self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
        self.encoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.decoder = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.out_channels = out_channels
        self.cls = nn.Linear(d_model, self.out_channels)
        self.iter_size = iter_size
        if iter_size > 0:
            self.language = BCNLanguage(
                d_model=d_model,
                nhead=nhead,
                num_layers=4,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                max_length=max_length,
                num_classes=self.out_channels,
            )
            # alignment
            self.semantic_visual = BaseSemanticVisual_backbone_feature(
                d_model=d_model,
                nhead=nhead,
                num_layers=2,
                dim_feedforward=dim_feedforward,
                max_length=max_length,
                num_classes=self.out_channels)

    def forward(self, x, data=None):
        # x: (bs, c, h, w) backbone feature map
        x = x.permute([0, 2, 3, 1])  # (bs, h, w, c)
        _, H, W, C = x.shape
        # assert H % 8 == 0 and W % 16 == 0, 'The height and width should be multiples of 8 and 16.'
        feature = x.flatten(1, 2)  # (bs, h*w, c)
        feature = self.pos_encoder(feature)  # (bs, h*w, c)
        for encoder_layer in self.encoder:
            feature = encoder_layer(feature)  # (bs, h*w, c)
        feature = feature.reshape([-1, H, W, C]).permute(0, 3, 1, 2)  # (bs, c, h, w)
        v_feature, v_attn_input = self.decoder(feature)  # (bs[N], T, E)
        vis_logits = self.cls(v_feature)  # (bs[N], T, C)
        align_lengths = _get_length(vis_logits)
        align_logits = vis_logits
        all_l_res, all_a_res = [], []
        for _ in range(self.iter_size):
            tokens = F.softmax(align_logits, dim=-1)
            lengths = torch.clamp(align_lengths, 2,
                                  self.max_length)  # TODO: move to language model
            # Language model refinement.
            l_feature, l_logits = self.language(tokens, lengths)
            all_l_res.append(l_logits)
            # Semantic-visual alignment.
            lengths_l = _get_length(l_logits)
            lengths_l.clamp_(2, self.max_length)
            a_res = self.semantic_visual(l_feature,
                                         feature,
                                         lengths_l=lengths_l,
                                         v_attn=v_attn_input)
            all_a_res.append(a_res['v_logits'])
            all_a_res.append(a_res['s_logits'])
            align_logits = a_res['logits']
            all_a_res.append(align_logits)
            align_lengths = a_res['pt_lengths']
        if self.training:
            return {
                'align': all_a_res,
                'lang': all_l_res,
                'vision': vis_logits
            }
        else:
            return F.softmax(align_logits, -1)
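

# --- Minimal usage sketch (not part of the original module) ---
# Illustrates how MATRNDecoder consumes a backbone feature map. The channel
# width (512), charset size (37) and feature size (8 x 32) below are
# illustrative assumptions, not requirements of the decoder.
if __name__ == '__main__':
    decoder = MATRNDecoder(in_channels=512, out_channels=37, iter_size=3)
    decoder.eval()
    with torch.no_grad():
        feats = torch.randn(2, 512, 8, 32)  # (bs, c, h, w) backbone features
        probs = decoder(feats)  # (bs, max_length + 1, out_channels) char probabilities
    print(probs.shape)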