TensorFlow 图变换:FoldBatchNorm
TensorFlow 的计算是按图(Graph)组织的,构建好的图有时需要根据需要做一些变换(例如将训练好的模型部署到生产环境时,去除无用的节点),在保证计算结果不变(或近似不变)的情况下优化计算速度或降低内存占用。Graph Transform Tool【1】是 TensorFlow 提供的一组可以修改 TensorFlow Graph 的工具,使用方便,易于扩展。
使用 Graph Transform Tool 时,它的操作对象为 GraphDef 对象,通常保存为二进制文件,后缀为 .pb。前面文章《TensorFlow 到底有几种模型格式?》介绍过这种文件的生成方式。
该工具调用方式如下:
bazel build tensorflow/tools/graph_transforms:transform_graph bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ --in_graph=tensorflow_inception_graph.pb \ --out_graph=optimized_inception_graph.pb \ --inputs='Mul:0' \ --outputs='softmax:0' \ --transforms=' strip_unused_nodes(type=float, shape="1,299,299,3") remove_nodes(op=Identity, op=CheckNumerics) fold_constants fold_batch_norms'
注意需要在 TensorFlow 源码根目录下运行。
其中参数 --in_graph 指定输入 GraphDef 文件名,--out_graph 制定输出 GraphDef 文件名,--inputs 指定输入 Node,--outputs 指定输出 Node,--transforms 指定变换类型。变换类型使用一串命令构成,每条命令都对应一种变换。
Batch Normalization(后面简称 BN)【2】是一种加速深度模型训练的技术,通过训练时对每个 mini-batch 内的 activations 做归一化降低 internal covariate shift,进而加速模型收敛。目前主流深度学习模型(ResNet,Inception,DenseNet,……)几乎都使用了 BN 技术。训练完毕,Batch Normalization 的参数(均值 E[x] 和方差 Val[x])不再更新,在后续推理计算时,可将这些常数参数通过 constant folding 来简化模型。BN 计算公式如下:
一般 BN 位置都在 Convolution 之后(DenseNet 例外),以 TensorFlow 实现的 Inception V3 模型【3】 为例,Conv-BN-Relu 表示为计算图如下:
其中 Conv2D 节点实现 Convolution 计算,Rsqrt 实现先求平方根再取倒数运算,Mul, Add, Sub 分别实现乘法、加法、减法计算,Const 为常数,代表计算参数。由于推理计算时 BN 输入参数均为常数,那么经过 constant folding, BN 可在算数上简化为:
y = x * a + b
进一步,当 x 为卷积输出,对卷积权值直接乘上 a,就可以在前向计算时直接得到 x * a 的结果,这一步称为 BN folding。经过两步简化后的计算图为:
此时节点数目也有大量缩减。在推理计算时,能降低运行时间和存储开销。
与上面 BN folding 优化对应的 Graph Transform 代码位于 tensorflow/tools/graph_transforms/fold_batch_norms.cc,其中使用了一个非常有用的函数:
Status ReplaceMatchingOpTypes(
const GraphDef& input_graph_def,
const OpTypePattern& pattern,
const std::function<Status(const NodeMatch&, const std::set<string>&,
const std::set<string>&, std::vector<NodeDef>*)>& node_generator,
const ReplaceMatchingOpTypesOptions& options,
GraphDef* output_graph_def);
该函数将 input_graph_def 中所有与 pattern 匹配的子图替换为 node_generator 产生的新 op,然后保存到 output_graph_def 中。
pattern 定义为:
{"Mul", // mul_node
{
{"Conv2D|MatMul", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
}
},
{"Const"}, // mul_values_node
}
}, // clang-format on
上述 pattern 能匹配原计算图中 Conv2D -> Mul 子图。node_generator 代码中将匹配后的子图直接替换为新的 Conv2D(权值常量更新为原权值与乘数因子 a 的乘积)。代码如下:
// 从匹配模式中得到 Mul、Conv、Input、weight、Mul value 节点
const NodeDef& mul_node = match.node;
const NodeDef& conv_node = match.inputs[0].node;
const NodeDef& input_node = match.inputs[0].inputs[0].node;
const NodeDef& weights_node = match.inputs[0].inputs[1].node;
const NodeDef& mul_values_node = match.inputs[1].node;
// 获取卷积权值、乘数因子数值
Tensor weights = GetNodeTensorAttr(weights_node, "value");
Tensor mul_values = GetNodeTensorAttr(mul_values_node, "value");
// 原始卷积权值乘上乘数因子
auto weights_matrix = weights.flat_inner_dims<float>();
Tensor scaled_weights(DT_FLOAT, weights.shape());
auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
for (int64 col = 0; col < weights_cols; ++col) {
scaled_weights_matrix(row, col) =
weights_matrix(row, col) * mul_values.flat<float>()(col);
}
}
// 构造新的卷积权值节点,填入更新后的权值
NodeDef scaled_weights_node;
scaled_weights_node.set_op("Const");
scaled_weights_node.set_name(weights_node.name());
SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node);
new_nodes->push_back(scaled_weights_node);
new_nodes->push_back(input_node);
// 构造新的卷积节点,复制旧卷积节点参数,改个名
NodeDef new_conv_node;
new_conv_node = conv_node;
new_conv_node.set_name(mul_node.name());
new_nodes->push_back(new_conv_node);
return Status::OK();
Graph Transform 工具为离线优化工具,优化后的 GraphDef 文件可以像原先模型一样部署,无需修改生产环境代码。
本文介绍的 BN folding 优化方法可适用于 CPU、GPU、移动端、嵌入式等各种需要推理加速的场景。
【1】 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms
【2】 Batch Normalization : Accelerating Deep Network Training by Reducing Internal Covariate Shift, arXiv:1502.03167
【3】 http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz