Spaces:
Running
Running
File size: 1,395 Bytes
65d7391 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: davide.cozzolino
"""
import torch
import torch.nn.functional as F
def weighted_statistics_pooling(x, log_w=None):
b = x.shape[0]
c = x.shape[1]
x = x.view(b,c,-1)
if log_w is None:
log_w = torch.zeros((b,1,x.shape[-1]), device=x.device)
else:
assert log_w.shape[0]==b
assert log_w.shape[1]==1
log_w = log_w.view(b,1,-1)
assert log_w.shape[-1]==x.shape[-1]
log_w = F.log_softmax(log_w, dim=-1)
x_min = -torch.logsumexp(log_w-x, dim=-1)
x_max = torch.logsumexp(log_w+x, dim=-1)
w = torch.exp(log_w)
x_avg = torch.sum(w*x , dim=-1)
x_msq = torch.sum(w*x*x, dim=-1)
x = torch.cat((x_min, x_max, x_avg, x_msq), dim=1)
return x |