# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt

"""
Created in September 2020
@author: davide.cozzolino
"""

import math
import torch.nn as nn


def conv_with_padding(in_planes, out_planes, kernelsize, stride=1, dilation=1, bias=False, padding=None):
    r"""
    Builds a 2D convolution layer.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param kernelsize: size of the (square) convolution kernel
    :param stride: convolution stride
    :param dilation: dilation factor
    :param bias: whether the convolution has a bias term
    :param padding: integer for padding; when None, ``kernelsize // 2`` is used
        (note that this default does not take ``dilation`` into account)
    :return: the ``nn.Conv2d`` layer
    """
    if padding is None:
        padding = kernelsize // 2
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernelsize, stride=stride,
                     dilation=dilation, padding=padding, bias=bias)


def conv_init(conv, act='linear'):
    r"""
    Reproduces conv initialization from DnCNN

    The weights are drawn from a zero-mean normal distribution whose standard
    deviation is ``sqrt(2 / (kernel_h * kernel_w * out_channels))``.
    The bias (if any) is left untouched.

    :param conv: convolution layer to initialize (modified in place)
    :param act: activation that follows the layer (currently unused)
    """
    n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
    nn.init.normal_(conv.weight, mean=0.0, std=math.sqrt(2. / n))


def batchnorm_init(m, kernelsize=3):
    r"""
    Reproduces batchnorm initialization from DnCNN

    The scale is drawn from a zero-mean normal distribution whose standard
    deviation is ``sqrt(2 / (kernelsize**2 * num_features))``; the shift is
    set to zero.

    :param m: batchnorm layer to initialize (modified in place)
    :param kernelsize: kernel size of the convolution preceding the batchnorm
    """
    n = kernelsize ** 2 * m.num_features
    nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2. / n))
    nn.init.zeros_(m.bias)


def make_activation(act):
    r"""
    Builds the activation layer identified by a name.

    :param act: one of ``None``, ``'linear'``, ``'relu'``, ``'tanh'``,
        ``'leaky_relu'`` or ``'softmax'``
    :return: the activation module, or None for ``None`` / ``'linear'``
    :raises ValueError: if the name is not recognized
    """
    if act is None or act == 'linear':
        return None

    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'tanh': nn.Tanh,
        'leaky_relu': lambda: nn.LeakyReLU(inplace=True),
        # Softmax2d normalizes over the channel dimension of a conv feature
        # map, i.e. the dimension the implicit-dim nn.Softmax() used to pick,
        # without triggering the implicit-dim deprecation warning.
        'softmax': nn.Softmax2d,
    }
    try:
        factory = factories[act]
    except (KeyError, TypeError):
        # An explicit exception (rather than an assert) is still raised
        # when Python runs with -O.
        raise ValueError("unknown activation: {!r}".format(act)) from None
    return factory()


def make_net(nplanes_in, kernels, features, bns, acts, dilats, bn_momentum=0.1, padding=None):
    r"""
    Builds a sequence of conv / (batchnorm) / (activation) stages.

    :param nplanes_in: number of input feature channels
    :param kernels: list of kernel size for convolution layers
    :param features: list of hidden layer feature channels
    :param bns: list of whether to add batchnorm layers
    :param acts: list of activations
    :param dilats: list of dilation factors
    :param bn_momentum: momentum of batchnorm
    :param padding: integer for padding (None for same padding)
    :return: ``nn.Sequential`` with the layers in order
    :raises ValueError: if the lists are inconsistent with ``len(features)``
    """
    depth = len(features)
    if len(kernels) != depth:
        raise ValueError("kernels and features must have the same length "
                         "({} != {})".format(len(kernels), depth))
    # The per-layer lists are indexed up to depth-1; fail early with a clear
    # message instead of an IndexError in the middle of the construction.
    for name, values in (('bns', bns), ('acts', acts), ('dilats', dilats)):
        if len(values) < depth:
            raise ValueError("{} must have at least {} elements, got {}".format(
                name, depth, len(values)))

    layers = []
    in_feats = nplanes_in
    # zip stops at the shortest list, which is guaranteed to be `depth` long.
    for out_feats, kernel, use_bn, act, dilation in zip(features, kernels, bns, acts, dilats):
        # The conv bias is redundant when a batchnorm follows.
        conv = conv_with_padding(in_feats, out_feats, kernelsize=kernel, dilation=dilation,
                                 padding=padding, bias=not use_bn)
        conv_init(conv, act=act)
        layers.append(conv)

        if use_bn:
            norm = nn.BatchNorm2d(out_feats, momentum=bn_momentum)
            batchnorm_init(norm, kernelsize=kernel)
            layers.append(norm)

        activation = make_activation(act)
        if activation is not None:
            layers.append(activation)

        in_feats = out_feats

    return nn.Sequential(*layers)


class DnCNN(nn.Module):
    r"""
    Implements a DnCNN network
    """

    def __init__(self, nplanes_in, nplanes_out, features, kernel, depth, activation, residual, bn,
                 lastact=None, bn_momentum=0.10, padding=None):
        r"""
        :param nplanes_in: number of input feature channels
        :param nplanes_out: number of output feature channels
        :param features: number of hidden layer feature channels
        :param kernel: kernel size of convolution layers
        :param depth: number of convolution layers (minimum 2)
        :param activation: name of the activation after every layer but the last
        :param residual: whether to add a residual connection from input to output
        :param bn: whether to add batchnorm layers
            (never added after the first and the last convolution)
        :param lastact: name of the activation after the last layer (None for no activation)
        :param bn_momentum: momentum of batchnorm
        :param padding: integer for padding (None for same padding)
        """
        super().__init__()

        self.residual = residual
        self.nplanes_out = nplanes_out
        self.nplanes_in = nplanes_in

        kernels = [kernel, ] * depth
        features = [features, ] * (depth - 1) + [nplanes_out, ]
        bns = [False, ] + [bn, ] * (depth - 2) + [False, ]
        dilats = [1, ] * depth
        acts = [activation, ] * (depth - 1) + [lastact, ]
        self.layers = make_net(nplanes_in, kernels, features, bns, acts, dilats=dilats,
                               bn_momentum=bn_momentum, padding=padding)

    def forward(self, x):
        r"""
        Applies the network to a 4-D input (indexed as ``[:, channel, :, :]``).

        When ``residual`` is set, the first ``min(nplanes_in, nplanes_out)``
        input channels are added to the corresponding output channels.

        :param x: input tensor
        :return: output tensor
        """
        shortcut = x
        out = self.layers(x)
        if not self.residual:
            return out

        nshortcut = min(self.nplanes_in, self.nplanes_out)
        if nshortcut == self.nplanes_out:
            # Every output channel receives a shortcut: plain out-of-place sum.
            return out + shortcut[:, :nshortcut, :, :]

        # Only the leading channels receive a shortcut. Work on a copy: writing
        # in place into the output of self.layers would invalidate the tensor
        # saved for backward by a trailing activation (relu, tanh, ...).
        out = out.clone()
        out[:, :nshortcut, :, :] += shortcut[:, :nshortcut, :, :]
        return out


def add_commandline_networkparams(parser, name, features, depth, kernel, activation, bn):
    r"""
    Adds the options describing a network to a command line parser.

    The options are ``--<name>.features``, ``--<name>.depth``,
    ``--<name>.kernel``, ``--<name>.activation`` and the pair
    ``--<name>.bn`` / ``--<name>.no-bn``, which both write to the
    ``<name>.bn`` destination.

    :param parser: ``argparse`` parser to extend
    :param name: prefix of the options
    :param features: default for the number of hidden layer feature channels
    :param depth: default for the number of convolution layers
    :param kernel: default for the kernel size
    :param activation: default for the activation name
    :param bn: default for the batchnorm flag
    """
    typed_options = (
        ("features", int, features),
        ("depth", int, depth),
        ("kernel", int, kernel),
        ("activation", str, activation),
    )
    for option, option_type, default in typed_options:
        parser.add_argument("--{}.{}".format(name, option), type=option_type, default=default)

    # The destination contains a dot, so it has to be read back with
    # getattr(args, "<name>.bn").
    bnarg = "{}.{}".format(name, "bn")
    parser.add_argument("--" + bnarg, action="store_true", dest=bnarg)
    parser.add_argument("--{}.{}".format(name, "no-bn"), action="store_false", dest=bnarg)
    parser.set_defaults(**{bnarg: bn})