Tensorflow手写数字识别在android中的实现
说明
下载TensorFlow Android Demo
git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git
生成模型
运行附件压缩包里的python脚本convnet.py生成mnist_model_graph_convnet.pb文件和graph_label_strings.txt文件:文件
编译jar包和so库
1. 下载TensorFlow Android Demo
git clone --recurse-submodules
https://github.com/tensorflow/tensorflow.git
备注:
--recurse-submodules
是为了避免一些protobuf 编译问题.
2. 修改WORKSPACE文件,指定SDK、NDK的版本和路径,请务必使用NDK r12b,下载路径为:
https://developer.android.com/ndk/downloads/older_releases.html #ndk-12b-downloads
例如,我是这样配置的:
android_sdk_repository( name = "androidsdk", api_level = 25, # Ensure that you have the build_tools_version below installed in the # SDK manager as it updates periodically. build_tools_version = "25.0.3", # Replace with path to Android SDK on your system path = "/home/ckt/work/Android/Sdk", ) # # Android NDK r12b is recommended (higher may cause issues with Bazel) android_ndk_repository( name="androidndk", path="/home/ckt/work/Android/ndk-r12b/", # This needs to be 14 or higher to compile TensorFlow. # Note that the NDK version is not the API level. api_level=14)
3. 编译jar包和so库
编译jar包和so库需要构建工具Bazel,Ubuntu环境下如何安装Bazel请参考网页:
https://bazel.build/versions/master/docs/install-ubuntu.html
编译jar包命令:
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
编译完成后,可以在以下路径找到libandroid_tensorflow_inference_java.jar文件:
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar
编译so库命令:
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \ --crosstool_top=//external:android/crosstool \ [email protected]_tools//tools/cpp:toolchain \ --cpu=armeabi-v7a
###cpu一定要适配自己的手机,否则找不到so文件###
编译完成后,可以在以下路径找到libtensorflow_inference.so文件:
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so
编写应用
1. 打开Android Studio,新建一个android工程将jar包放入libs目录,将so库放入src/main/jniLibs/armeabi-v7a目录,将之前生成的pb文件和text文件放入src/main/assets目录
2. 将TensorFlow Android Demo中的Classifier.java和TensorFlowImageClassifier.java复制到工程,这2个文件在TensorFlow Android Demo中的的路径为:
/tensorflow/examples/android/src/org/tensorflow/demo
注意:
需要将这2个类的包名修改为自己工程的包名。
3.为了简便操作,我们将下面的mnist_test.png(一张灰度图,28×28像素,白字黑底)放到src/main/assets目录下
备注:
IMAGE_MEAN和IMAGE_STD的值在本项目没有实际意义,可以随便设置。
4.在activity中调用TensorFlowImageClassifier.create()方法创建分类器:
5. 将mnist_test.png图片转换成相应的bitmap(28x28),通过classifier.recognizeImage(bitmap)来取得预测结果
注意:
因为我们的输入数据是28x28的灰度图,原始代码用到了rgb三个通道,我们只需要一个通道,所以需要修改TensorFlowImageClassifier类的recognizeImage方法来适应模型,代码如下:
bitmapToFloatArray()方法如下: /** * 将bitmap转为(按行优先)一个float数组。其中的每个像素点都归一化到0~1之间。 * @param bitmap 灰度图,r,g,b分量都相等。 * @return */ public static float[] bitmapToFloatArray(Bitmap bitmap){ int height = bitmap.getHeight(); int width = bitmap.getWidth(); float[] result = new float[height * width]; int k = 0; for (int j = 0; j < height; j++) { for (int i = 0; i < width; i++) { int argb = bitmap.getPixel(i, j); // 由于是灰度图,所以r,g,b分量是相等的。 int r = Color.red(argb); result[k++] = r / 255.0f; } } return result; }