Network Structure
Network Overview

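The stage layout built by `_setup` below, assuming 224×224 inputs (implied by the final 7×7 average pooling; the input size itself is not fixed in this file):

- 7×7 conv, 64 channels, stride 2 → 3×3 max pooling, stride 2
- residual unit 64→256 → attention module 1 (skip_num=2)
- residual unit 256→512, stride 2 → attention module 2 (skip_num=1)
- residual unit 512→1024, stride 2 → attention module 3 (skip_num=0)
- residual unit 1024→2048, stride 2 → 7×7 average pooling → Dense(2)
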
Residual Unit

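Each residual unit adds a bottleneck transform to a shortcut, matching `residual_unit` below:

y = F(x) + shortcut(x)

where F(x) is BN → 1×1 conv (c_out/4, ReLU) → BN → 3×3 conv (c_out/4, given stride, ReLU) → BN → 1×1 conv (c_out, linear), and shortcut(x) is a 1×1 convolution whenever c_in ≠ c_out or stride > 1, otherwise the identity.
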
Attention Module

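Each attention module combines a trunk branch T(x) with a soft mask M(x) ∈ (0, 1) produced by an hourglass-shaped mask branch. The Multiply/Add fusion in `attention_module` below implements the attention residual learning form from the Residual Attention Network paper:

H(x) = (1 + M(x)) · T(x)

so the mask re-weights the trunk features without ever suppressing them to zero.
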
from tensorflow import keras as k
from tensorflow.contrib import layers
import tensorflow as tf

from bases.base_network import BaseNetwork
from utils.logger import logger


class ResidualAttentionNetwork(BaseNetwork):
    def __init__(self, inputs, is_training=True):
        super(ResidualAttentionNetwork, self).__init__(inputs, is_training)
    def _setup(self):
        # stem: 7x7 conv, stride 2, then 3x3 max pooling, stride 2
        self.pre_conv = k.layers.Conv2D(64, kernel_size=7, strides=(2, 2), padding="SAME", name="pre_conv")(self.inputs)
        logger("pre_conv: {}".format(self.pre_conv.shape))
        self.pre_pool = k.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding="SAME", name="pre_pool")(self.pre_conv)
        logger("pre_pool: {}".format(self.pre_pool.shape))
        # alternating residual units and attention modules; each stride-2 residual
        # unit halves the spatial size and doubles the channel count
        self.pre_res_1 = self.residual_unit(self.pre_pool, 64, 256, "pre_res_1")
        logger("pre_res_1: {}".format(self.pre_res_1.shape))
        self.attention_1 = self.attention_module(self.pre_res_1, 256, 256, "attention_1", skip_num=2)
        logger("attention_1: {}".format(self.attention_1.shape))
        self.pre_res_2 = self.residual_unit(self.attention_1, 256, 512, stride=2, name="pre_res_2")
        logger("pre_res_2: {}".format(self.pre_res_2.shape))
        self.attention_2 = self.attention_module(self.pre_res_2, 512, 512, "attention_2", skip_num=1)
        logger("attention_2: {}".format(self.attention_2.shape))
        self.pre_res_3 = self.residual_unit(self.attention_2, 512, 1024, stride=2, name="pre_res_3")
        logger("pre_res_3: {}".format(self.pre_res_3.shape))
        self.attention_3 = self.attention_module(self.pre_res_3, 1024, 1024, "attention_3", skip_num=0)
        logger("attention_3: {}".format(self.attention_3.shape))
        self.pre_res_4 = self.residual_unit(self.attention_3, 1024, 2048, stride=2, name="pre_res_4")
        logger("pre_res_4: {}".format(self.pre_res_4.shape))
        # classifier head: 7x7 average pooling, flatten, 2-way dense layer
        self.ave_pool = k.layers.AveragePooling2D(pool_size=(7, 7), strides=(1, 1), name="ave_pool")(self.pre_res_4)
        logger("ave_pool: {}".format(self.ave_pool.shape))
        pool_shape = self.ave_pool.get_shape().as_list()
        logger("pool_shape: {}".format(pool_shape))
        fc_input = k.layers.Reshape(target_shape=[pool_shape[1] * pool_shape[2] * pool_shape[3]],
                                    name="reshape")(self.ave_pool)
        logger("fc_input: {}".format(fc_input.shape))
        self.outputs = k.layers.Dense(2)(fc_input)
        logger("fc: {}".format(self.outputs.shape))
    def attention_module(self, x, c_in, c_out, name, p=1, t=2, r=1, skip_num=2):
        """
        Attention module: a trunk branch of residual units plus a soft mask
        branch (an hourglass with `skip_num` extra down/up-sampling steps),
        fused as (1 + mask) * trunk.
        p / t / r: number of residual units before and after the branches,
        inside the trunk branch, and per down- or up-sampling step of the
        mask branch, respectively.
        """
        with tf.name_scope(name):
            # p residual units shared by the trunk branch and the mask branch
            with tf.name_scope("pre_trunk"), tf.variable_scope("pre_trunk"):
                pre_trunk = x
                for idx in range(p):
                    unit_name = "trunk_1_{}".format(idx + 1)
                    pre_trunk = self.residual_unit(pre_trunk, c_in, c_out, unit_name)
            # trunk branch: t residual units
            with tf.name_scope("trunk_branch"):
                trunks = pre_trunk
                for idx in range(t):
                    unit_name = "trunk_res_{}".format(idx + 1)
                    trunks = self.residual_unit(trunks, c_in, c_out, unit_name)
            # soft mask branch: bottom-up down-sampling with skip connections,
            # then top-down bilinear up-sampling back to the trunk resolution
            with tf.name_scope("mask_branch"):
                size_1 = pre_trunk.get_shape().as_list()[1:3]
                max_pool_1 = k.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding="SAME", name="pool_1")(pre_trunk)
                down_res = max_pool_1
                skips = []
                sizes = []
                for skip_idx in range(skip_num):
                    for idx in range(r):
                        unit_name = "down_res{}_{}".format(skip_idx + 1, idx + 1)
                        down_res = self.residual_unit(down_res, c_in, c_out, unit_name)
                    skip_res = self.residual_unit(down_res, c_in, c_out, name="skip_res_{}".format(skip_idx + 1))
                    skips.append(skip_res)
                    size = down_res.get_shape().as_list()[1:3]
                    sizes.append(size)
                    down_res = k.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding="SAME",
                                                  name="res_pool_{}".format(skip_idx + 1))(down_res)
                middle_res = down_res
                for idx in range(2 * r):
                    unit_name = "down_res{}_{}".format(skip_num + 1, idx + 1)
                    middle_res = self.residual_unit(middle_res, c_in, c_out, unit_name)
                skips.reverse()
                sizes.reverse()
                up_res = middle_res
                for skip_idx, (skip, size) in enumerate(zip(skips, sizes)):
                    interp = tf.image.resize_bilinear(up_res, size, name="interp_{}".format(skip_num + 1 - skip_idx))
                    up_res = skip + interp
                    for idx in range(r):
                        unit_name = "up_res{}_{}".format(skip_num - skip_idx, idx + 1)
                        up_res = self.residual_unit(up_res, c_in, c_out, unit_name)
                interp = tf.image.resize_bilinear(up_res, size_1, name="interp_1")
                mask_bn_1 = k.layers.BatchNormalization(name="mask_bn_1")(interp, training=self.is_training)
                linear_1 = k.layers.Conv2D(c_out,
                                           kernel_size=1,
                                           strides=(1, 1),
                                           name="linear_1",
                                           activation="relu",
                                           activity_regularizer=layers.l2_regularizer(scale=0.001, scope="linear_1_l2")
                                           )(mask_bn_1)
                mask_bn_2 = k.layers.BatchNormalization(name="mask_bn_2")(linear_1, training=self.is_training)
                linear_2 = k.layers.Conv2D(c_out,
                                           kernel_size=1,
                                           strides=(1, 1),
                                           name="linear_2",
                                           activation="relu",
                                           activity_regularizer=layers.l2_regularizer(scale=0.001, scope="linear2_l2")
                                           )(mask_bn_2)
                sigmoid = tf.nn.sigmoid(linear_2, "mask_sigmoid")
            # attention residual learning: H(x) = (1 + M(x)) * T(x)
            with tf.name_scope("fusing"):
                outputs = k.layers.Multiply(name="fusing")([trunks, sigmoid])
                outputs = k.layers.Add(name="fuse_add")([outputs, trunks])
            # p residual units after the fusion
            with tf.name_scope("last_trunk"), tf.variable_scope("last_trunk"):
                for idx in range(p):
                    unit_name = "last_trunk_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_out, unit_name)
        return outputs
    def residual_unit(self, x, c_in, c_out, name, stride=1, padding="SAME", scale=0.001):
        """
        Residual Unit: bottleneck of BN + 1x1 conv (c_out/4) -> BN + 3x3 conv
        (c_out/4, stride) -> BN + 1x1 conv (c_out), added to a shortcut.
        The shortcut is a 1x1 convolution when c_in != c_out or stride > 1,
        otherwise the identity.
        """
        with tf.name_scope(name):
            """
            tf.name_scope() only manages the op name space; it has no effect on
            variables created with tf.get_variable() (use tf.variable_scope()
            for those).
            BatchNormalization: normalizes the activations of the previous layer
            over each batch, i.e. applies a transformation that keeps the mean
            activation close to 0 and the activation standard deviation close to 1.
            A big pitfall here:
            tf.layers.BatchNormalization and tf.layers.batch_normalization
            automatically add their update ops to the tf.GraphKeys.UPDATE_OPS
            collection (only when training=True), whereas
            tf.keras.layers.BatchNormalization does NOT add its update ops to
            tf.GraphKeys.UPDATE_OPS automatically.
            So when using tf.keras.layers.BatchNormalization in a TensorFlow
            training session, the layer's `updates` have to be added to the
            tf.GraphKeys.UPDATE_OPS collection manually.
            Set the `training` argument of BatchNormalization to True during
            training and to False at test time.
            When saving the model, the moving mean and moving variance must be
            saved as well; they live in tf.global_variables(), not in
            tf.trainable_variables():
                var_list = tf.trainable_variables()
                g_list = tf.global_variables()
                bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
                bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
                var_list += bn_moving_vars
                saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
            """
            bn_1 = k.layers.BatchNormalization(name="bn_1")(x, training=self.is_training)
            conv_1 = k.layers.Conv2D(c_out // 4,
                                     kernel_size=1,
                                     strides=(1, 1),
                                     padding=padding,
                                     name="conv_1",
                                     activation="relu",
                                     activity_regularizer=layers.l2_regularizer(scale=scale, scope="conv_1_l2")
                                     )(bn_1)
            bn_2 = k.layers.BatchNormalization(name="bn_2")(conv_1, training=self.is_training)
            conv_2 = k.layers.Conv2D(c_out // 4,
                                     kernel_size=3,
                                     strides=(stride, stride),
                                     padding=padding,
                                     name="conv_2",
                                     activation="relu",
                                     activity_regularizer=layers.l2_regularizer(scale=scale, scope="conv_2_l2")
                                     )(bn_2)
            bn_3 = k.layers.BatchNormalization(name="bn_3")(conv_2, training=self.is_training)
            conv_3 = k.layers.Conv2D(c_out,
                                     kernel_size=1,
                                     strides=(1, 1),
                                     padding=padding,
                                     name="conv_3",
                                     activation=None,
                                     activity_regularizer=layers.l2_regularizer(scale=scale, scope="conv_3_l2")
                                     )(bn_3)
            # 1x1 projection shortcut when the channel count or spatial size changes
            if c_out != c_in or stride > 1:
                skip = k.layers.Conv2D(c_out,
                                       kernel_size=1,
                                       strides=(stride, stride),
                                       padding=padding,
                                       name="conv_skip",
                                       activation=None,
                                       activity_regularizer=layers.l2_regularizer(scale=scale, scope="skip_l2")
                                       )(x)
            else:
                skip = x
            outputs = k.layers.Add(name="fuse")([conv_3, skip])
        return outputs
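
The comment inside residual_unit above describes the bookkeeping a TensorFlow 1.x training session needs around tf.keras.layers.BatchNormalization. Below is a minimal, self-contained sketch of that pattern; the placeholder names, the one-layer stand-in model, the cross-entropy loss, and the Adam optimizer are illustrative assumptions, not part of the code above.

import tensorflow as tf
from tensorflow import keras as k

# Hypothetical inputs; shapes and names are assumptions for this sketch.
images = tf.placeholder(tf.float32, [None, 224, 224, 3], name="images")
labels = tf.placeholder(tf.int64, [None], name="labels")
is_training = tf.placeholder_with_default(False, [], name="is_training")

# Stand-in for the network: one conv + one keras BatchNormalization layer.
# Keeping a reference to the layer object (bn) is what makes its update ops reachable.
conv = k.layers.Conv2D(64, kernel_size=7, strides=(2, 2), padding="SAME")(images)
bn = k.layers.BatchNormalization(name="bn_demo")
features = bn(conv, training=is_training)
logits = k.layers.Dense(2)(k.layers.GlobalAveragePooling2D()(features))

# tf.keras BN does not register its update ops in UPDATE_OPS, so add them by hand.
for update in bn.updates:
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update)

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# Run the moving-mean/variance updates together with every training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

# Save the BN moving statistics along with the trainable variables;
# they live in tf.global_variables(), not tf.trainable_variables().
var_list = tf.trainable_variables()
g_list = tf.global_variables()
var_list += [g for g in g_list if "moving_mean" in g.name]
var_list += [g for g in g_list if "moving_variance" in g.name]
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

To apply the same pattern to ResidualAttentionNetwork, the BatchNormalization layer objects created in residual_unit and attention_module would need to be kept around (for example, appended to a list on self) so that their .updates can be added to the collection in the same way.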

