Tensorflow手写数字识别在android中的实现

说明

下载TensorFlow Android Demo

git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git

生成模型

运行附件压缩包里的python脚本convnet.py生成mnist_model_graph_convnet.pb文件和graph_label_strings.txt文件:
文件

编译jar包和so库

1. 下载TensorFlow Android Demo
git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git

备注:

--recurse-submodules 是为了避免一些protobuf 编译问题.

2. 修改WORKSPACE文件,指定SDK、NDK的版本和路径,请务必使用NDK r12b,下载路径为:
https://developer.android.com/ndk/downloads/older_releases.html  #ndk-12b-downloads

例如,我是这样配置的:

android_sdk_repository(
    name = "androidsdk",
    api_level = 25,
    # Ensure that you have the build_tools_version below installed in the
    # SDK manager as it updates periodically.
    build_tools_version = "25.0.3",
    # Replace with path to Android SDK on your system
    path = "/home/ckt/work/Android/Sdk",
)
#
# Android NDK r12b is recommended (higher may cause issues with Bazel)
android_ndk_repository(
    name="androidndk",
    path="/home/ckt/work/Android/ndk-r12b/",
    # This needs to be 14 or higher to compile TensorFlow.
    # Note that the NDK version is not the API level.
    api_level=14)

3. 编译jar包和so库
编译jar包和so库需要构建工具Bazel,Ubuntu环境下如何安装Bazel请参考网页:

https://bazel.build/versions/master/docs/install-ubuntu.html

编译jar包命令:

bazel build //tensorflow/contrib/android:android_tensorflow_inference_java


编译完成后,可以在以下路径找到libandroid_tensorflow_inference_java.jar文件:
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar

编译so库命令:

bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
  --crosstool_top=//external:android/crosstool \
  [email protected]_tools//tools/cpp:toolchain \
  --cpu=armeabi-v7a 

###cpu一定要适配自己的手机,否则找不到so文件###


编译完成后,可以在以下路径找到libtensorflow_inference.so文件:
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so

编写应用

1. 打开Android Studio,新建一个android工程
将jar包放入libs目录,将so库放入src/main/jniLibs/armeabi-v7a目录,将之前生成的pb文件和text文件放入src/main/assets目录

2. 将TensorFlow Android Demo中的Classifier.java和TensorFlowImageClassifier.java复制到工程,这2个文件在TensorFlow Android Demo中的的路径为:

/tensorflow/examples/android/src/org/tensorflow/demo

注意:
需要将这2个类的包名修改为自己工程的包名。

3.为了简便操作,我们将下面的mnist_test.png(一张灰度图,28×28像素,白字黑底)放到src/main/assets目录下

Tensorflow手写数字识别在android中的实现

备注:
IMAGE_MEAN和IMAGE_STD的值在本项目没有实际意义,可以随便设置。

4.在activity中调用TensorFlowImageClassifier.create()方法创建分类器:

Tensorflow手写数字识别在android中的实现

5. 将mnist_test.png图片转换成相应的bitmap(28x28),通过classifier.recognizeImage(bitmap)来取得预测结果

Tensorflow手写数字识别在android中的实现

注意:
因为我们的输入数据是28x28的灰度图,原始代码用到了rgb三个通道,我们只需要一个通道,所以需要修改TensorFlowImageClassifier类的recognizeImage方法来适应模型,代码如下:

Tensorflow手写数字识别在android中的实现


 bitmapToFloatArray()方法如下:
 /**
  * 将bitmap转为(按行优先)一个float数组。其中的每个像素点都归一化到0~1之间。
  * @param bitmap 灰度图,r,g,b分量都相等。
  * @return
  */
 public static float[] bitmapToFloatArray(Bitmap bitmap){
   int height = bitmap.getHeight();
   int width = bitmap.getWidth();
   float[] result = new float[height * width];

   int k = 0;
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       int argb = bitmap.getPixel(i, j);
       // 由于是灰度图,所以r,g,b分量是相等的。
       int r = Color.red(argb);
       result[k++] = r / 255.0f;
     }
   }
   return result;
 }