Spaces:
Sleeping
Sleeping
| # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
| # Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA'). | |
| # | |
| # All rights reserved. | |
| # This work should only be used for nonprofit purposes. | |
| # | |
| # By downloading and/or using any of these files, you implicitly agree to all the | |
| # terms of the license, as specified in the document LICENSE.txt | |
| # (included in this package) and online at | |
| # http://www.grip.unina.it/download/LICENSE_OPEN.txt | |
| """ | |
| Created in September 2020 | |
| @author: davide.cozzolino | |
| """ | |
| import math | |
| import torch.nn as nn | |
def conv_with_padding(in_planes, out_planes, kernelsize, stride=1, dilation=1, bias=False, padding = None):
    r"""
    Builds a 2D convolution whose padding defaults to half the kernel size.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param kernelsize: kernel size (floor-divided by 2 when padding is None)
    :param stride: convolution stride
    :param dilation: convolution dilation
    :param bias: whether the convolution has a bias term
    :param padding: explicit padding; None selects kernelsize//2
    :return: the configured nn.Conv2d module

    Note: the default padding does not take the dilation into account.
    """
    pad = kernelsize // 2 if padding is None else padding
    conv_kwargs = dict(
        kernel_size=kernelsize,
        stride=stride,
        padding=pad,
        dilation=dilation,
        bias=bias,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def conv_init(conv, act='linear'):
    r"""
    Reproduces conv initialization from DnCNN.

    The weights are redrawn in place from a zero-mean normal distribution
    with standard deviation sqrt(2 / n), where n is the product of the first
    two kernel dimensions and the number of output channels.
    The bias, if any, is left untouched.

    :param conv: convolution module whose weights are re-initialized
    :param act: accepted for interface compatibility; currently not used
    """
    kernel_area = conv.kernel_size[0] * conv.kernel_size[1]
    fan = kernel_area * conv.out_channels
    std = math.sqrt(2. / fan)
    conv.weight.data.normal_(0, std)
def batchnorm_init(m, kernelsize=3):
    r"""
    Reproduces batchnorm initialization from DnCNN.

    The bias is zeroed and the scale (weight) is redrawn in place from a
    zero-mean normal distribution with standard deviation sqrt(2 / n),
    where n = kernelsize**2 * m.num_features.

    :param m: batchnorm module exposing num_features, weight and bias
    :param kernelsize: kernel size used in the computation of n
    """
    fan = m.num_features * kernelsize ** 2
    std = math.sqrt(2. / fan)
    m.bias.data.zero_()
    m.weight.data.normal_(0, std)
def make_activation(act):
    r"""
    Creates the activation module that corresponds to a name.

    :param act: one of None, 'linear', 'relu', 'tanh', 'leaky_relu', 'softmax'
    :return: a new activation module, or None for None / 'linear'
             (i.e. no activation layer is needed)
    :raises AssertionError: if the name is not recognized
    """
    if act is None or act == 'linear':
        return None
    if act == 'relu':
        return nn.ReLU(inplace=True)
    if act == 'tanh':
        return nn.Tanh()
    if act == 'leaky_relu':
        return nn.LeakyReLU(inplace=True)
    if act == 'softmax':
        # Built without a dim argument, exactly as before.
        return nn.Softmax()
    # Raised explicitly instead of `assert False`: an assert statement is
    # removed under `python -O`, which would make an unknown name silently
    # behave like 'linear'. The exception type is kept for existing callers.
    raise AssertionError("unknown activation: {!r}".format(act))
def make_net(nplanes_in, kernels, features, bns, acts, dilats, bn_momentum = 0.1, padding=None):
    r"""
    Stacks convolution / batchnorm / activation layers into a sequential net.

    Layer i is a convolution from the previous layer's channels (nplanes_in
    for the first layer) to features[i], optionally followed by a batchnorm
    and by the activation named acts[i]. The convolution has a bias only
    when no batchnorm follows it.

    :param nplanes_in: number of input feature channels
    :param kernels: list of kernel sizes for the convolution layers
    :param features: list of hidden layer feature channels
    :param bns: list of whether to add batchnorm layers
    :param acts: list of activations
    :param dilats: list of dilation factors
    :param bn_momentum: momentum of batchnorm
    :param padding: integer for padding (None for the default kernelsize//2)
    :return: nn.Sequential holding the created layers in order
    """
    assert len(kernels) == len(features)

    modules = []
    prev_feats = nplanes_in
    for idx, out_feats in enumerate(features):
        use_bn = bns[idx]

        conv = conv_with_padding(
            prev_feats,
            out_feats,
            kernelsize=kernels[idx],
            dilation=dilats[idx],
            padding=padding,
            bias=not use_bn,
        )
        conv_init(conv, act=acts[idx])
        modules.append(conv)

        if use_bn:
            norm = nn.BatchNorm2d(out_feats, momentum=bn_momentum)
            batchnorm_init(norm, kernelsize=kernels[idx])
            modules.append(norm)

        activation = make_activation(acts[idx])
        if activation is not None:
            modules.append(activation)

        # The next convolution consumes what this layer produced.
        prev_feats = out_feats

    return nn.Sequential(*modules)
class DnCNN(nn.Module):
    r"""
    Implements a DnCNN network: a plain stack of convolutions built with
    make_net, with an optional residual connection from input to output.
    """

    def __init__(self, nplanes_in, nplanes_out, features, kernel, depth, activation, residual, bn, lastact=None, bn_momentum = 0.10, padding=None):
        r"""
        :param nplanes_in: number of input feature channels
        :param nplanes_out: number of output feature channels
        :param features: number of hidden layer feature channels
        :param kernel: kernel size of convolution layers
        :param depth: number of convolution layers (minimum 2)
        :param activation: name of the activation used after every layer but the last
        :param residual: whether to add a residual connection from input to output
        :param bn: whether to add batchnorm layers (never on the first and last layer)
        :param lastact: name of the activation used after the last layer (None for no activation)
        :param bn_momentum: momentum of batchnorm
        :param padding: integer for padding (None lets make_net choose the default)
        """
        super().__init__()

        self.residual = residual
        self.nplanes_out = nplanes_out
        self.nplanes_in = nplanes_in

        # Per-layer settings: the last layer maps to nplanes_out, uses
        # lastact, and neither the first nor the last layer gets a batchnorm.
        kernels = [kernel, ] * depth
        features = [features, ] * (depth-1) + [nplanes_out, ]
        bns = [False, ] + [bn,] * (depth - 2) + [False, ]
        dilats = [1, ] * depth
        acts = [activation, ] * (depth - 1) + [lastact, ]
        self.layers = make_net(nplanes_in, kernels, features, bns, acts, dilats=dilats, bn_momentum = bn_momentum, padding=padding)

    def forward(self, x):
        r"""
        Runs the layer stack and, if enabled, adds the residual shortcut.

        The shortcut is added on the first min(nplanes_in, nplanes_out)
        channels only; any remaining output channels are returned unchanged.

        :param x: input tensor indexed as (batch, channel, ...) by the shortcut
        :return: output tensor of the network
        """
        shortcut = x
        out = self.layers(x)
        if not self.residual:
            return out

        nshortcut = min(self.nplanes_in, self.nplanes_out)
        # Work on a copy: writing into the raw output of self.layers in place
        # would modify a tensor that the last layer may have saved for its
        # backward pass (e.g. tanh or an in-place ReLU), which makes autograd
        # fail. Values and dtype of the result are the same as before.
        out = out.clone()
        out[:, :nshortcut, :, :] += shortcut[:, :nshortcut, :, :]
        return out
def add_commandline_networkparams(parser, name, features, depth, kernel, activation, bn):
    r"""
    Registers the network hyper-parameters as "--<name>.<option>" arguments.

    Adds the typed options features, depth and kernel (int) and activation
    (str), plus the flag pair "--<name>.bn" / "--<name>.no-bn", which both
    write to the destination "<name>.bn".

    :param parser: argument parser to extend
    :param name: prefix used for every option
    :param features: default for "--<name>.features"
    :param depth: default for "--<name>.depth"
    :param kernel: default for "--<name>.kernel"
    :param activation: default for "--<name>.activation"
    :param bn: default value of the "<name>.bn" destination
    """
    typed_options = (
        ("features", int, features),
        ("depth", int, depth),
        ("kernel", int, kernel),
        ("activation", str, activation),
    )
    for suffix, arg_type, default in typed_options:
        parser.add_argument(f"--{name}.{suffix}", type=arg_type, default=default)

    # Two opposite flags share one destination; its default is set afterwards.
    bn_dest = f"{name}.bn"
    parser.add_argument(f"--{bn_dest}", action="store_true", dest=bn_dest)
    parser.add_argument(f"--{name}.no-bn", action="store_false", dest=bn_dest)
    parser.set_defaults(**{bn_dest: bn})