PaddlePaddle升级解读|十余行代码完成迁移学习 PaddleHub实战篇
迁移学习 (Transfer Learning) 是属于深度学习的一个子研究领域,该研究领域的目标在于利用数据、任务、或模型之间的相似性,将在旧领域学习过的知识,迁移应用于新领域中。迁移学习吸引了很多研究者投身其中,因为它能够很好的解决深度学习中的以下几个问题:
一些研究领域只有少量标注数据,且数据标注成本较高,不足以训练一个足够鲁棒的神经网络
大规模神经网络的训练依赖于大量的计算资源,这对于一般用户而言难以实现
应对于普适化需求的模型,在特定应用上表现不尽如人意
为了让开发者更便捷地应用迁移学习,百度 PaddlePaddle 开源了预训练模型管理工具 PaddleHub。开发者用使用仅仅十余行的代码,就能完成迁移学习。本文将全面介绍 PaddleHub 及其应用方法。
PaddleHub 介绍
PaddleHub 是基于 PaddlePaddle 开发的预训练模型管理工具,可以借助预训练模型更便捷地开展迁移学习工作,旨在让 PaddlePaddle 生态下的开发者更便捷体验到大规模预训练模型的价值。
PaddleHub 目前的预训练模型覆盖了图像分类、目标检测、词法分析、Transformer、情感分析五大类别。未来会持续开放更多类型的深度学习模型,如语言模型、视频分类、图像生成等预训练模型
图 1 PaddleHub 功能全景
PaddleHub 主要包括两个功能:命令行工具和 Fine-tune API。
命令行工具
PaddleHub 借鉴了 Anaconda 和 PIP 等软件包管理的理念,开发了命令行工具,可以方便快捷的完成模型的搜索、下载、安装、预测等功能,对应的关键的命令分别是 search,download,install,run 等。我们以 run 命令为例,介绍如何通过命令行工具进行预测。
Run 命令用于执行 Module 的预测,这里分别举一个 NLP 和 CV 的例子。
对于 NLP 任务:输入数据通过–input_text 指定。以百度 LAC 模型(中文词法分析)为例,可以通过以下命令实现单行文本分析。
对于 CV 任务:输入数据通过–input_path 指定。以 SSD 模型(单阶段目标检测)为例子,可以通过以下命令实现单张图片的预测。
更多的命令用法,请读者参考文首的 Github 项目链接。
Fine-tune API
PaddleHub 提供了基于 PaddlePaddle 实现的 Fine-tune API, 重点针对大规模预训练模型的 Fine-tune 任务做了高阶的抽象,让预训练模型能更好服务于用户特定场景的应用。通过大规模预训练模型结合 Fine-tune,可以在更短的时间完成模型的收敛,同时具备更好的泛化能力。
图 2 PaddleHub Fine-tune API 全景
Fine-tune :对一个 Task 进行 Fine-tune,并且定期进行验证集评估。在 Fine-tune 的过程中,接口会定期的保存 checkpoint(模型和运行数据),当运行被中断时,通过 RunConfig 指定上一次运行的 checkpoint 目录,可以直接从上一次运行的最后一次评估中恢复状态继续运行。
迁移任务 Task:在 PaddleHub 中,Task 代表了一个 Fine-tune 的任务。任务中包含了执行该任务相关的 program 以及和任务相关的一些度量指标(如分类准确率 accuracy、precision、 recall、 F1-score 等)、模型损失等。
运行配置 RunConfig:在 PaddleHub 中,RunConfig 代表了在对 Task 进行 Fine-tune 时的运行配置。包括运行的 epoch 次数、batch 的大小、是否使用 GPU 训练等。
优化策略 Strategy:在 PaddleHub 中,Strategy 类封装了一系列适用于迁移学习的 Fine-tune 策略。Strategy 包含了对预训练参数使用什么学习率变化策略,使用哪种类型的优化器,使用什么类型的正则化等。
预训练模型 Module :Module 代表了一个可执行的模型。这里的可执行指的是,Module 可以直接通过命令行 hub run {HOST_IP} 为本机 IP,需要用户自行指定
启动服务后,我们使用浏览器访问${HOST_IP}:8989,可以看到训练以及预测的 loss 曲线和 accuracy 曲线,如下图所示。
10. 使用模型进行预测
当 Fine-tune 完成后,我们使用模型来进行预测,整个预测流程大致可以分为以下几步:
1、构建网络
2、生成预测数据的 Reader
3、切换到预测的 Program
4、加载预训练好的参数
5、运行 Program 进行预测
通过以下命令来获取测试的图片(适用于猫狗分类的数据集)
注意:其他数据集所用的测试图片请自行准备。
完整预测代码如下