Pytorch文本行检测,深度学习网络结构CTPN

Pytorch文本行检测,深度学习网络结构CTPN

向AI转型的程序员都关注了这个号????????????

机器学习AI算法工程   公众号:datayx

这几天一直在用Pytorch来复现文本检测领域的CTPN论文,本文章将从数据处理、训练标签生成、神经网络搭建、损失函数设计、训练主过程编写等这几个方面来一步一步复现CTPN。CTPN算法理论可以参考这里。

https://www.cnblogs.com/skyfsm/p/9776611.html

本文项目代码 获取方式:

关注微信公众号 datayx  然后回复  CTPN  即可获取。

AI项目体验地址 https://loveai.tech

训练数据处理

我们的训练选择天池ICPR2018和MSRA_TD500两个数据集,天池ICPR的数据集为网络图像,都是一些淘宝商家上传到淘宝的一些商品介绍图像,其标签方式参考了ICDAR2015的数据标签格式,即一个文本框用4个坐标来表示,即左上、右上、右下、左下四个坐标,共八个值,记作[x1 y1 x2 y2 x3 y3 x4 y4]

Pytorch文本行检测,深度学习网络结构CTPN

天池ICPR2018数据集的风格如下,字体形态格式颜色多变,多嵌套于物体之中,识别难度大:

Pytorch文本行检测,深度学习网络结构CTPN

MSRA_TD500使微软收集的一个文本检测和识别的一个数据集,里面的图像多是街景图,背景比较复杂,但文本位置比较明显,一目了然。因为MSRA_TD500的标签格式不一样,最后一个参数表示矩形框的旋转角度。

Pytorch文本行检测,深度学习网络结构CTPN

所以我们第一步就是将这两个数据集的标签格式统一,我的做法是将MSRA数据集格式改为ICDAR格式,方便后面的模型训练。因为MSRA_TD500采取的标签格式是[index difficulty_label x y w h angle],所以我们需要根据这个文本框的旋转角度来求得水平文本框旋转后的4个坐标位置。实现如下:

Pytorch文本行检测,深度学习网络结构CTPN

Pytorch文本行检测,深度学习网络结构CTPN

经过格式处理后,我们两份数据集算是整理好了。当然我们还需要对整个数据集划分为训练集和测试集,我的文件组织习惯如下:train_im, test_im文件夹装的是训练和测试图像,train_gt和test_gt装的是训练和测试标签。

Pytorch文本行检测,深度学习网络结构CTPN

训练标签生成

因为CTPN的核心思想也是基于Faster RCNN中的region proposal机制的,所以原始数据标签需要转化为

第一步我们需要将原先每张图的bbox标签转化为每个anchor标签。为了实现该功能,我们先将一张图划分为宽度为16的各个anchor。

  • 首先计算一张图可以分为多少个宽度为16的acnhor(比如一张图的宽度为w,那么水平anchor总数为w/16),再计算出我们的文本框标签中含有几个acnhor,最左和最右的anchor又是哪几个;

  • 计算文本框内anchor的高度和中心是多少:此时我们可以在一个全黑的mask中把文本框label画上去(白色),然后从上往下和从下往上找到第一个白色像素点的位置作为该anchor的上下边界;

  • 最后将每个anchor的位置(水平ID)、anchor中心y坐标、anchor高度存储并返回

Pytorch文本行检测,深度学习网络结构CTPN

计算anchor上下边界的方法:

Pytorch文本行检测,深度学习网络结构CTPN

经过上面的标签处理,我们已经将原先的标准的文本框标签转化为一个一个小尺度anchor标签,以下是标签转化后的效果:

Pytorch文本行检测,深度学习网络结构CTPN

以上标签可视化后看来anchor标签做得不错,但是这里需要提出的是,我发现这种anchor生成方法是不太精准的,比如一个文本框边缘像素刚好落在一个新的anchor上,那么我们就要为这个像素分配一个16像素的anchor,显然导致了文本框标签的不准确,引入了15像素的误差,这个是需要思考的。这个问题我们先不做处理,继续下面的工作。

当然转化期间我们也遇到很多奇怪的问题,比如下图这种标签都已经超出图像范围的,我们必须做相应的特殊处理,比如限定标签横坐标的最大尺寸为图像宽度。

Pytorch文本行检测,深度学习网络结构CTPN

CTPN网络结构

因为CTPN用到了CNN+双向LSTM的网络结构,所以我们分步实现CTPN架构。

Pytorch文本行检测,深度学习网络结构CTPN

CNN部分CTPN采取了VGG16进行底层特征提取。

Pytorch文本行检测,深度学习网络结构CTPN

Pytorch文本行检测,深度学习网络结构CTPN

再实现双向LSTM,增强关联序列的信息学习。

Pytorch文本行检测,深度学习网络结构CTPN

这里实现多一层中间层,用于连接CNN和LSTM。将VGG最后一层卷积层输出的feature map转化为向量形式,用于接下来的LSTM训练。

Pytorch文本行检测,深度学习网络结构CTPN

最后将以上三部分拼接成一个完整的CTPN网络:底层使用VGG16做特征提取->lstm序列信息学习->output每个anchor分数,h, y, side_refinement

Pytorch文本行检测,深度学习网络结构CTPN

Pytorch文本行检测,深度学习网络结构CTPN

Pytorch文本行检测,深度学习网络结构CTPN

Pytorch文本行检测,深度学习网络结构CTPN

Pytorch文本行检测,深度学习网络结构CTPN

训练过程设计

训练:优化器我们选择SGD,learning rate我们设置了两个,前N个epoch使用较大的lr,后面的epoch使用较小的lr以更好地收敛。训练过程我们定义了4个loss,分别是total_cls_loss,total_v_reg_loss, total_o_reg_loss, total_loss(前面三个loss相加)。

Pytorch文本行检测,深度学习网络结构CTPN

检测效果和总结

首先看一下训练出来的模型的文字检测效果,为了便于观察,我把anchor和最终合并好的文本框一并画出:

Pytorch文本行检测,深度学习网络结构CTPN

Pytorch文本行检测,深度学习网络结构CTPN

下面再看看一些比较好的文字检测效果吧:

Pytorch文本行检测,深度学习网络结构CTPN

在实现过程中的一些总结和想法:

  1. CTPN对于带旋转角度的文本的检测效果不好,其实这是CTPN的算法特点决定的:一个个固定宽度的四边形是很难合并出一个准确的文本框,比如一些anchors很难组成一组,即使组成一组了也很难精确恢复成完整的精确的文本矩形框(推断阶段的缺点)。当然啦,对于水平排布的文本检测,个人认为这个算法思路还是很奏效的。

  2. CTPN中的side-refinement其实作用不大,如果我们检测出来的文本是直接拿出识别,这个side-refinement优化的几个像素差别其实可以忽略;

  3. CTPN的中间步骤有点多:从anchor标签的生成到中间计算loss再到最后推断的文本线生成步骤,都会引入一定的误差,这个缺点也是EAST论文中所提出的。训练的步骤越简洁,中间过程越少,精度更有保障。

  4. CTPN的算法得出的效果可以看出,准确率低但召回率高。这种基于16像素的anchor识别感觉对于一些大的非文字图标(比如路标)误判率相当高,这是源于其anchor的宽度实在太小了,尽管使用了lstm关联周围anchor,但是我还是认为有点“一叶障目”的感觉。所以CTPN对于过大或过小的文字检测效果不会太好。

  5. CTPN是个比较老的算法了(2016年),其思路在当年还是很创新的,但是也有很多弊端。现在提出的新方法已经基本解决了这些不足之处,比如EAST,PixelNet都是一些很优秀的新算法。

原文地址 https://www.cnblogs.com/skyfsm/p/10054386.html


阅读过本文的人还看了以下文章: