Experiment Notes: Octave Convolution Experiments Based on CARN
This post follows the same approach as the previous one, "Experiment Notes: Octave Convolution Experiments Based on SRResNet", but swaps in the CARN architecture, replacing the ordinary convolutions with octave convolutions.
CARN code: https://github.com/nmhkahn/CARN-pytorch
Paper: https://arxiv.org/pdf/1803.08664.pdf
Theory
The authors frame the paper mainly around lightweight design. CARN (Cascading Residual Network) has the following three characteristics (a minimal sketch follows this list):
- global and local cascading connections;
- intermediate features are cascaded, then combined by 1×1 convolution blocks;
- multi-level representations and shortcut connections make information flow more efficient.
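To make the cascading mechanism concrete, here is a minimal, self-contained sketch using plain convolutions (hypothetical names such as CascadeBlockSketch; the octave variant actually used in this experiment appears in the Experiment section below). Each unit's output is concatenated onto a running feature buffer, and a 1×1 convolution fuses the buffer back to nf channels:

import torch
import torch.nn as nn

class CascadeBlockSketch(nn.Module):
    def __init__(self, nf=24, n_units=3):
        super().__init__()
        self.units = nn.ModuleList(
            [nn.Conv2d(nf, nf, 3, 1, 1) for _ in range(n_units)])
        # each 1x1 conv fuses the growing concat buffer back to nf channels
        self.fuse = nn.ModuleList(
            [nn.Conv2d((i + 2) * nf, nf, 1) for i in range(n_units)])

    def forward(self, x):
        buf = x
        for unit, fuse in zip(self.units, self.fuse):
            buf = torch.cat((buf, unit(x)), dim=1)  # cascade: keep every intermediate feature
            x = fuse(buf)                           # combine via a 1x1 convolution
        return x

print(CascadeBlockSketch()(torch.randn(1, 24, 32, 32)).shape)  # torch.Size([1, 24, 32, 32])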
Reference materials
https://blog.****.net/alxe_made/article/details/85839802
https://blog.****.net/Chaolei3/article/details/79374563
https://zhuanlan.zhihu.com/p/28749411
https://zybuluo.com/hanbingtao/note/626300
Experiment
First, the network structure:
##############################################################################################
# Octave CARN: the CARN backbone with every plain convolution replaced by an octave convolution.
# Assumed context: this sits in a BasicSR-style codebase, where `B` is the block-definition
# module (e.g. `import models.modules.block as B`).
import math
import torch
import torch.nn as nn

class Octave_CARN(nn.Module):  # nb=3 (3 cascade blocks), nf=24 channels
    def __init__(self, in_nc, out_nc, nf=24, nc=4, nb=3, alpha=0.75, upscale=4, norm_type='batch',
                 act_type='relu', mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(Octave_CARN, self).__init__()
        n_upscale = int(math.log(upscale, 2))  # computed but unused: upsampling is one PixelShuffle step
        if upscale == 3:
            n_upscale = 1
        self.nb = nb
        # shallow feature extraction, then split into (high, low) frequency branches
        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
        self.oct_first = B.FirstOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1,
                                           bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
        # global cascade: nb cascade blocks, each followed by a 1x1 fusion conv over the concat buffer
        self.CascadeBlocks = nn.ModuleList([B.OctaveCascadeBlock(nc, nf, kernel_size=3, alpha=alpha,
            norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)])
        self.CatBlocks = nn.ModuleList([B.OctaveConv((i + 2) * nf, nf, kernel_size=1, alpha=alpha,
            norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nb)])
        # merge the two branches back into one full-resolution feature map
        self.oct_last = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1,
                                         bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
        self.upsampler = nn.PixelShuffle(upscale)
        self.HR_conv1 = B.conv_block(nf, in_nc * (upscale ** 2), kernel_size=3, norm_type=None, act_type=None)

    def forward(self, x):
        x = self.fea_conv(x)
        x = self.oct_first(x)
        pre_fea = x
        for i in range(self.nb):
            res = self.CascadeBlocks[i](x)
            # global cascading: concatenate every block output, then fuse with a 1x1 conv
            pre_fea = (torch.cat((pre_fea[0], res[0]), dim=1),
                       torch.cat((pre_fea[1], res[1]), dim=1))
            x = self.CatBlocks[i](pre_fea)
        x = self.oct_last(x)
        x = self.HR_conv1(x)
        x = torch.sigmoid(self.upsampler(x))  # torch.sigmoid replaces the deprecated F.sigmoid
        return x
##############################################################################################
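One detail worth checking in the tail: HR_conv1 expands the nf feature channels to in_nc·upscale², and PixelShuffle then rearranges those channels into a 4×-larger image. A standalone shape check, a sketch assuming nf=24, in_nc=3, upscale=4 (matching the settings further down):

import torch
import torch.nn as nn

nf, in_nc, upscale = 24, 3, 4
tail = nn.Sequential(
    nn.Conv2d(nf, in_nc * upscale ** 2, 3, 1, 1),  # 24 -> 48 channels, same spatial size
    nn.PixelShuffle(upscale),                      # 48 ch @ HxW -> 3 ch @ 4H x 4W
)
print(tail(torch.randn(1, nf, 32, 32)).shape)  # torch.Size([1, 3, 128, 128])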
####################
class OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.7, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        # four paths: low->low, low->high, high->low, high->high
        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))   # low->high: conv, then upsample to full resolution
        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(self.h2g_pool(X_h))   # high->low: pool to half resolution, then conv
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l
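To make the (X_h, X_l) bookkeeping concrete, here is a stripped-down, self-contained sketch of the four-path exchange (a hypothetical MinimalOctaveConv without the norm/activation options), verifying the shapes with alpha=0.75 and nf=24 as used in this experiment:

import torch
import torch.nn as nn

class MinimalOctaveConv(nn.Module):
    # four conv paths only; the norm/activation of the full version are dropped
    def __init__(self, in_nc, out_nc, alpha=0.75, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        l_in, l_out = int(alpha * in_nc), int(alpha * out_nc)
        h_in, h_out = in_nc - l_in, out_nc - l_out
        self.pool = nn.AvgPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.l2l = nn.Conv2d(l_in, l_out, kernel_size, 1, pad)
        self.l2h = nn.Conv2d(l_in, h_out, kernel_size, 1, pad)
        self.h2l = nn.Conv2d(h_in, l_out, kernel_size, 1, pad)
        self.h2h = nn.Conv2d(h_in, h_out, kernel_size, 1, pad)

    def forward(self, x):
        x_h, x_l = x
        return (self.h2h(x_h) + self.up(self.l2h(x_l)),    # high out: h2h + upsampled l2h
                self.h2l(self.pool(x_h)) + self.l2l(x_l))  # low out: pooled h2l + l2l

conv = MinimalOctaveConv(24, 24, alpha=0.75)
x_h = torch.randn(1, 6, 64, 64)   # high branch: 24 - int(0.75*24) = 6 channels, full resolution
x_l = torch.randn(1, 18, 32, 32)  # low branch: int(0.75*24) = 18 channels, half resolution
y_h, y_l = conv((x_h, x_l))
print(y_h.shape, y_l.shape)  # torch.Size([1, 6, 64, 64]) torch.Size([1, 18, 32, 32])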
class FirstOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.7, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.stride = stride
        # entry conv: a single full-resolution input is split into low/high branches
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        if self.stride == 2:
            x = self.h2g_pool(x)
        X_h = self.h2h(x)
        X_l = self.h2l(self.h2g_pool(x))  # the low branch lives at half resolution
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l
class LastOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.7, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        # exit conv: both branches map to the full out_nc and are merged at full resolution
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(self.l2h(X_l))  # bring the low branch back to full resolution
        X_h = X_h2h + X_l2h
        if self.n_h:
            X_h = self.n_h(X_h)
        if self.a:
            X_h = self.a(X_h)
        return X_h
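A quick round-trip sketch of the entry and exit convs (assuming nf=24, alpha=0.75; plain convolutions standing in for FirstOctaveConv and LastOctaveConv): the entry splits one full-resolution tensor into a (high, low) pair, and the exit merges the pair back:

import torch
import torch.nn as nn

nf, alpha = 24, 0.75
l_nc, h_nc = int(alpha * nf), nf - int(alpha * nf)  # 18 low / 6 high channels
pool = nn.AvgPool2d(2, 2)
up = nn.Upsample(scale_factor=2, mode='nearest')

# entry: full-resolution input -> (X_h at HxW, X_l at H/2 x W/2)
first_h2h, first_h2l = nn.Conv2d(nf, h_nc, 3, 1, 1), nn.Conv2d(nf, l_nc, 3, 1, 1)
x = torch.randn(1, nf, 64, 64)
x_h, x_l = first_h2h(x), first_h2l(pool(x))
print(x_h.shape, x_l.shape)  # torch.Size([1, 6, 64, 64]) torch.Size([1, 18, 32, 32])

# exit: (X_h, X_l) -> a single full-resolution tensor with nf channels
last_h2h, last_l2h = nn.Conv2d(h_nc, nf, 3, 1, 1), nn.Conv2d(l_nc, nf, 3, 1, 1)
print((last_h2h(x_h) + up(last_l2h(x_l))).shape)  # torch.Size([1, 24, 64, 64])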
class OctaveCascadeBlock(nn.Module):
    """
    OctaveCascadeBlock, 3-3 style
    """
    def __init__(self, nc, gc, kernel_size=3, alpha=0.7, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA', res_scale=1):
        super(OctaveCascadeBlock, self).__init__()
        self.nc = nc
        self.ResBlocks = nn.ModuleList([OctaveResBlock(gc, gc, gc, kernel_size, alpha, stride, dilation,
            groups, bias, pad_type, norm_type, act_type, mode, res_scale) for _ in range(nc)])
        self.CatBlocks = nn.ModuleList([OctaveConv((i + 2) * gc, gc, kernel_size=1, alpha=alpha, bias=bias,
            pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nc)])

    def forward(self, x):
        pre_fea = x
        for i in range(self.nc):
            res = self.ResBlocks[i](x)
            # local cascading, mirroring the global loop in Octave_CARN
            pre_fea = (torch.cat((pre_fea[0], res[0]), dim=1),
                       torch.cat((pre_fea[1], res[1]), dim=1))
            x = self.CatBlocks[i](pre_fea)
        return x
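The (i + 2)*gc input width of each CatBlock follows directly from the growing concat buffer: the buffer starts with the gc-channel block input, and each OctaveResBlock appends gc more channels before the i-th 1×1 fusion. A quick check of the arithmetic, assuming gc = nf = 24 and nc = 4 as in this experiment:

gc, nc = 24, 4
buf = gc                        # pre_fea starts as the block input
for i in range(nc):
    buf += gc                   # one more OctaveResBlock output is concatenated on
    assert buf == (i + 2) * gc  # exactly the in_nc of CatBlocks[i]
    print(f"CatBlock {i}: 1x1 conv fuses {buf} -> {gc} channels")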
The settings:
{
"name": "octave_carn_DIV2K_alpha0.75" // please remove "debug_" during training
, "tb_logger_dir": "octave"
, "use_tb_logger": true
, "model":"sr"
, "scale": 4
, "crop_scale": 0
, "gpu_ids": [0,3]
// , "init_type": "kaiming"
//
// , "finetune_type": "sft"
// , "init_norm_type": "zero"
, "datasets": {
"train": {
"name": "DIV2K800"
, "mode": "LRHR"
, "dataroot_HR": "/media/sdc/wpguan/BasicSR_datasets/DIV2K800_sub"
, "dataroot_LR": "/media/sdc/wpguan/BasicSR_datasets/DIV2K800_sub_bicLRx4"
, "subset_file": null
, "use_shuffle": true
, "n_workers": 8
, "batch_size": 16 // 16
, "HR_size": 128 // 128 | 192 | 96
, "noise_gt": true
, "use_flip": true
, "use_rot": true
}
, "val": {
"name": "set5"
, "mode": "LRHR"
, "dataroot_HR": "/media/sdc/wpguan/BasicSR_datasets/val_set5/MSet5"
, "dataroot_LR": "/media/sdc/wpguan/BasicSR_datasets/val_set5/MSet5_bicLRx4"
, "noise_gt": false
}
}
, "path": {
"root": "/home/wpguan/SR_master/octave"
, "pretrain_model_G": null
}
//
, "network_G": {
"which_model_G": "octave_carn" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet octave_resnet, octave_carn
// , "norm_type": "adaptive_conv_res"
, "norm_type": null
, "mode": "CNA"
, "nf": 24//64
, "nb": 3//16
, "in_nc": 3
, "out_nc": 3
// , "gc": 32
, "group": 1
// , "gate_conv_bias": true
// , "ada_ksize": 1
// , "num_classes": 2
}
// , "network_G": {
// "which_model_G": "srcnn" // RRDB_net | sr_resnet
//// , "norm_type": null
// , "norm_type": "adaptive_conv_res"
// , "mode": "CNA"
// , "nf": 64
// , "in_nc": 1
// , "out_nc": 1
// , "ada_ksize": 5
// }
, "train": {
// "lr_G": 1e-3
"lr_G": 6e-4
, "lr_scheme": "MultiStepLR"
, "lr_steps": [200000, 400000, 600000, 800000]
// , "lr_steps": [500000]
, "lr_gamma": 0.5
, "pixel_criterion": "l2"
, "pixel_criterion_reg": "tv"
, "pixel_weight": 1.0
, "val_freq": 1e3
, "manual_seed": 0
, "niter": 1e6
}
, "logger": {
"print_freq": 200
, "save_checkpoint_freq": 1e3
}
}
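For reference, the MultiStepLR settings above halve the learning rate at each listed milestone; a quick sketch of the resulting schedule (lr_G = 6e-4, lr_gamma = 0.5):

lr = 6e-4
for step in [200000, 400000, 600000, 800000]:
    lr *= 0.5
    print(f"after iteration {step}: lr = {lr:.2e}")
# 3.00e-04, 1.50e-04, 7.50e-05, 3.75e-05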
The results are shown in the figure below.