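Preface
TensorFlow Lite is a deep learning framework designed specifically for mobile devices. A mobile deep learning framework is deployed on small devices such as phones or a Raspberry Pi and uses a trained model to run inference on the device itself. With this kind of framework, inference tasks can run locally without calling a server's network API, which greatly reduces prediction time. Earlier articles in this series covered Baidu's paddle-mobile, Xiaomi's mace, and Tencent's ncnn. This chapter introduces Google's TensorFlow Lite.
TensorFlow Lite on GitHub: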
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
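Main Text
Converting the Model
Running predictions on a phone first requires a trained model. The model cannot be in TensorFlow's original format; TensorFlow Lite uses its own format, tflite.
This section shows how to obtain a model in that format. There are two main ways: save a tflite model at training time, or convert a TensorFlow model saved in another format into a tflite model.
The most convenient way is to save a tflite model at training time, mainly through the tf.contrib.lite.toco_convert() interface. Here is a simple example: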
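import tensorflow as tf

# Build a tiny graph: one image input, two constants added, one named output
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")

with tf.Session() as sess:
    # Convert the session graph into a tflite flatbuffer
    tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
    # Write the converted model to disk
    with open("converteds_model.tflite", "wb") as f:
        f.write(tflite_model)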
The resulting converteds_model.tflite file can be used directly with TensorFlow Lite.
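As a quick sanity check on the file written above, its header can be inspected. The sketch below assumes that a .tflite file is a FlatBuffers buffer carrying the file identifier TFL3 in bytes 4 to 7, as declared in the TensorFlow Lite schema. The helper names looks_like_tflite and check_model_file are made up for illustration and are not part of TensorFlow.

```python
def looks_like_tflite(data: bytes) -> bool:
    """Return True if the buffer carries the 'TFL3' FlatBuffers identifier."""
    # Bytes 0-3 hold the root table offset; bytes 4-7 hold the file identifier.
    return len(data) >= 8 and data[4:8] == b"TFL3"


def check_model_file(path: str) -> bool:
    """Check a model file on disk (hypothetical helper, not a TensorFlow API)."""
    # Only the first 8 bytes are needed for the check.
    with open(path, "rb") as f:
        return looks_like_tflite(f.read(8))
```

Calling check_model_file("converteds_model.tflite") after the conversion gives a cheap way to catch an empty or truncated file before copying it into the Android project. It does not validate the model itself.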
The second way is to convert a model that TensorFlow saved in another format into tflite. Pre-trained TensorFlow models can be downloaded from the following link:
https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models
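The models provided there already include a tflite version that can be used directly, but we can also convert from the other formats. For example, download mobilenet_v1_1.0_224.tgz; extracting it yields the following files: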
mobilenet_v1_1.0_224.ckpt.data-00000-of-00001
mobilenet_v1_1.0_224_eval.pbtxt
mobilenet_v1_1.0_224.tflite
mobilenet_v1_1.0_224.ckpt.index
mobilenet_v1_1.0_224_frozen.pb
mobilenet_v1_1.0_224.ckpt.meta
mobilenet_v1_1.0_224_info.txt
First install Bazel; see:
https://docs.bazel.build/versions/master/install-ubuntu.html
Only the "Installing using binary installer" section needs to be completed. Then clone the TensorFlow source code:
git clone https://github.com/tensorflow/tensorflow.git
Next, build the conversion tools. This build may take quite a while:
cd tensorflow/
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/contrib/lite/toco:toco
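Once the conversion tools are built, we can start converting the model. The command below freezes the graph.
Note that the model we downloaded is already frozen, so this step is not needed for it. For other models, freeze the graph first and then carry out the subsequent steps.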
./freeze_graph --input_graph=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb \
    --input_checkpoint=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
    --input_binary=true \
    --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
    --output_node_names=MobilenetV1/Predictions/Reshape_1
The following command converts the frozen graph into a .tflite file:
./toco --input_file=/tmp/mobilenet_v1_1.0_224_frozen.pb \
    --input_format=TENSORFLOW_GRAPHDEF \
    --output_format=TFLITE \
    --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
    --inference_type=FLOAT \
    --input_type=FLOAT \
    --input_arrays=input \
    --output_arrays=MobilenetV1/Predictions/Reshape_1 \
    --input_shapes=1,224,224,3
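The --input_shapes and --inference_type values above also determine how large the input buffer must be on the Android side, where the code allocates the product of the four dimensions times 4 bytes. The following is a small worked calculation, not part of toco. The function name input_byte_size and the byte-width table are assumptions made for illustration.

```python
from functools import reduce
from operator import mul

# Bytes per element for the two inference types considered in this sketch.
BYTES_PER_ELEMENT = {"FLOAT": 4, "QUANTIZED_UINT8": 1}


def input_byte_size(input_shapes: str, inference_type: str = "FLOAT") -> int:
    """Size in bytes of one input tensor given a toco-style shape string."""
    dims = [int(d) for d in input_shapes.split(",")]
    # Number of elements is the product of all dimensions.
    return reduce(mul, dims, 1) * BYTES_PER_ELEMENT[inference_type]


print(input_byte_size("1,224,224,3"))  # 602112
```

For the float MobileNet model this gives 1 × 224 × 224 × 3 × 4 = 602112 bytes per image.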
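After these steps we have the mobilenet_v1_1.0_224.tflite model, which we will use in the Android project below.
Developing the Android Project
With the model in hand, create an Android project in Android Studio. The defaults are fine throughout, and C++ support is not needed, because the TensorFlow Lite API we use is Java, which makes development very convenient.
1. After the project is created, add the following to the build.gradle file in the app directory. Under dependencies, add the package references: the first is the image loading library Glide, and the second is the core of this project, TensorFlow Lite: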
implementation 'com.github.bumptech.glide:glide:4.3.1'
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
Then add the following inside the android block. It prevents the TensorFlow Lite model from being compressed; a compressed model cannot be loaded:
// set no compress models
aaptOptions {
    noCompress "tflite"
}
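A likely reason for this requirement is that the app memory-maps the model straight out of the APK using a start offset and a length, which only makes sense if the bytes are stored uncompressed. An APK is a zip archive, so the effect can be illustrated with Python's zipfile module. This is only an illustration of the idea, not how aapt works internally. The entry name and the payload are placeholders.

```python
import io
import zipfile


def payload_is_verbatim(compression: int, payload: bytes) -> bool:
    """Write one entry into an in-memory zip and report whether the raw
    payload bytes appear unchanged inside the archive."""
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", compression=compression) as archive:
        # Placeholder entry name standing in for a model under assets/.
        archive.writestr("assets/model.tflite", payload)
    return payload in buf.getvalue()


payload = b"TFL3" * 4096  # highly compressible stand-in for model bytes
stored = payload_is_verbatim(zipfile.ZIP_STORED, payload)
deflated = payload_is_verbatim(zipfile.ZIP_DEFLATED, payload)
print(stored, deflated)  # True False
```

With ZIP_STORED the model bytes sit in the archive exactly as they are, so a region of the file can be mapped directly. With ZIP_DEFLATED they do not, and the data would have to be decompressed first.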
Create an assets folder under the main directory. It holds the tflite models and the label name file.
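The label file is read one label per line, and the line index is later used as the class index. A minimal sketch of that contract follows, with the assumption that the number of labels must match the size of the model's output (1001 for the MobileNet model used in this article). The function names and the sample labels are hypothetical.

```python
import io
from typing import List, TextIO


def read_labels(stream: TextIO) -> List[str]:
    """Read one label per line; the list index is the class index."""
    return [line.rstrip("\r\n") for line in stream]


def labels_match_output(labels: List[str], num_outputs: int) -> bool:
    """A lookup by class index is only safe when both sizes agree."""
    return len(labels) == num_outputs


# Placeholder content standing in for the label file in assets.
sample = io.StringIO("background\ntench\ngoldfish\n")
labels = read_labels(sample)
print(labels[2], labels_match_output(labels, 3))
```

Checking the sizes once at startup is cheaper than discovering a mismatch later as an out-of-range index while displaying a prediction.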
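Below is the main activity, MainActivity.java. The code is fairly long; the methods that matter most are loadModelFile(), which memory-maps the model from assets, load_model(), which creates the Interpreter, and predict_image(), which runs inference and displays the result: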
package com.yeyupiaoling.testtflite;

import android.Manifest;
import android.app.Activity;
import android.content.DialogInterface;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Bundle;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import android.support.v4.app.ActivityCompat;
import android.support.v4.content.ContextCompat;
import android.support.v7.app.AlertDialog;
import android.support.v7.app.AppCompatActivity;
import android.text.method.ScrollingMovementMethod;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

import com.bumptech.glide.Glide;
import com.bumptech.glide.load.engine.DiskCacheStrategy;
import com.bumptech.glide.request.RequestOptions;

import org.tensorflow.lite.Interpreter;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;

public class MainActivity extends AppCompatActivity {
    private static final String TAG = MainActivity.class.getName();
    private static final int USE_PHOTO = 1001;
    private static final int START_CAMERA = 1002;
    private String camera_image_path;
    private ImageView show_image;
    private TextView result_text;
    private String assets_path = "lite_images";
    private boolean load_result = false;
    private int[] ddims = {1, 3, 224, 224};
    private int model_index = 0;
    private List<String> resultLabel = new ArrayList<>();
    private Interpreter tflite = null;

    private static final String[] PADDLE_MODEL = {
            "mobilenet_v1",
            "mobilenet_v2"
    };

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        init_view();
        readCacheLabelFromLocalFile();
    }

    // initialize view
    private void init_view() {
        request_permissions();
        show_image = (ImageView) findViewById(R.id.show_image);
        result_text = (TextView) findViewById(R.id.result_text);
        result_text.setMovementMethod(ScrollingMovementMethod.getInstance());
        Button load_model = (Button) findViewById(R.id.load_model);
        Button use_photo = (Button) findViewById(R.id.use_photo);
        Button start_photo = (Button) findViewById(R.id.start_camera);

        load_model.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                showDialog();
            }
        });

        // use photo click
        use_photo.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                if (!load_result) {
                    Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
                    return;
                }
                PhotoUtil.use_photo(MainActivity.this, USE_PHOTO);
            }
        });

        // start camera click
        start_photo.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                if (!load_result) {
                    Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
                    return;
                }
                camera_image_path = PhotoUtil.start_camera(MainActivity.this, START_CAMERA);
            }
        });
    }

    /**
     * Memory-map the model file in Assets.
     */
    private MappedByteBuffer loadModelFile(String model) throws IOException {
        AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    // load infer model
    private void load_model(String model) {
        try {
            tflite = new Interpreter(loadModelFile(model));
            Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
            Log.d(TAG, model + " model load success");
            tflite.setNumThreads(4);
            load_result = true;
        } catch (IOException e) {
            Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
            Log.d(TAG, model + " model load fail");
            load_result = false;
            e.printStackTrace();
        }
    }

    public void showDialog() {
        AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
        // set dialog title
        builder.setTitle("Please select model");

        // set dialog icon
        builder.setIcon(android.R.drawable.ic_dialog_alert);

        // able click other will cancel
        builder.setCancelable(true);

        // cancel button
        builder.setNegativeButton("cancel", null);

        // set list
        builder.setSingleChoiceItems(PADDLE_MODEL, model_index, new DialogInterface.OnClickListener() {
            @Override
            public void onClick(DialogInterface dialog, int which) {
                model_index = which;
                load_model(PADDLE_MODEL[model_index]);
                dialog.dismiss();
            }
        });

        // show dialog
        builder.show();
    }

    // read the label names, one per line, from cacheLabel.txt in assets
    private void readCacheLabelFromLocalFile() {
        try {
            AssetManager assetManager = getApplicationContext().getAssets();
            BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel.txt")));
            String readLine = null;
            while ((readLine = reader.readLine()) != null) {
                resultLabel.add(readLine);
            }
            reader.close();
        } catch (Exception e) {
            Log.e("labelCache", "error " + e);
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
        String image_path;
        RequestOptions options = new RequestOptions().skipMemoryCache(true).diskCacheStrategy(DiskCacheStrategy.NONE);
        if (resultCode == Activity.RESULT_OK) {
            switch (requestCode) {
                case USE_PHOTO:
                    if (data == null) {
                        Log.w(TAG, "user photo data is null");
                        return;
                    }
                    Uri image_uri = data.getData();
                    Glide.with(MainActivity.this).load(image_uri).apply(options).into(show_image);
                    // get image path from uri
                    image_path = PhotoUtil.get_path_from_URI(MainActivity.this, image_uri);
                    // predict image
                    predict_image(image_path);
                    break;
                case START_CAMERA:
                    // show photo
                    Glide.with(MainActivity.this).load(camera_image_path).apply(options).into(show_image);
                    // predict image
                    predict_image(camera_image_path);
                    break;
            }
        }
    }

    // predict image
    private void predict_image(String image_path) {
        // picture to float array
        Bitmap bmp = PhotoUtil.getScaleBitmap(image_path);
        ByteBuffer inputData = PhotoUtil.getScaledMatrix(bmp, ddims);
        try {
            // Data format conversion takes too long
            // Log.d("inputData", Arrays.toString(inputData));
            float[][] labelProbArray = new float[1][1001];
            long start = System.currentTimeMillis();
            // get predict result
            tflite.run(inputData, labelProbArray);
            long end = System.currentTimeMillis();
            long time = end - start;
            float[] results = new float[labelProbArray[0].length];
            System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length);
            // show predict result and time
            int r = get_max_result(results);
            String show_text = "result:" + r + " name:" + resultLabel.get(r) + " probability:" + results[r] + " time:" + time + "ms";
            result_text.setText(show_text);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // get max probability label
    private int get_max_result(float[] result) {
        float probability = result[0];
        int r = 0;
        for (int i = 0; i < result.length; i++) {
            if (probability < result[i]) {
                probability = result[i];
                r = i;
            }
        }
        return r;
    }

    // request permissions
    private void request_permissions() {
        List<String> permissionList = new ArrayList<>();
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            permissionList.add(Manifest.permission.CAMERA);
        }

        if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
            permissionList.add(Manifest.permission.WRITE_EXTERNAL_STORAGE);
        }

        if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
            permissionList.add(Manifest.permission.READ_EXTERNAL_STORAGE);
        }

        // if list is not empty will request permissions
        if (!permissionList.isEmpty()) {
            ActivityCompat.requestPermissions(this, permissionList.toArray(new String[permissionList.size()]), 1);
        }
    }

    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        switch (requestCode) {
            case 1:
                if (grantResults.length > 0) {
                    for (int i = 0; i < grantResults.length; i++) {
                        int grantResult = grantResults[i];
                        if (grantResult == PackageManager.PERMISSION_DENIED) {
                            String s = permissions[i];
                            Toast.makeText(this, s + " permission was denied", Toast.LENGTH_SHORT).show();
                        }
                    }
                }
                break;
        }
    }
}
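The activity above reports only the single most probable class. As an illustration of the same selection step, and of how it could be extended to several candidates, here is a short sketch. The top_k function and the sample values are assumptions for illustration and do not appear in the Android project.

```python
from typing import List, Sequence, Tuple


def top_k(probabilities: Sequence[float],
          labels: Sequence[str],
          k: int = 3) -> List[Tuple[int, str, float]]:
    """Return (index, label, probability) for the k most probable classes."""
    # Sort class indices by descending probability; the sort is stable,
    # so the earliest index wins on ties, like the strict '<' in the loop.
    ranked = sorted(range(len(probabilities)),
                    key=lambda i: probabilities[i],
                    reverse=True)
    return [(i, labels[i], probabilities[i]) for i in ranked[:k]]


probs = [0.1, 0.7, 0.2]
names = ["a", "b", "c"]  # placeholder labels
print(top_k(probs, names, k=2))  # [(1, 'b', 0.7), (2, 'c', 0.2)]
```

With k=1 this picks the same index as get_max_result(). Showing a few runners-up is often more informative when the top probability is low.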
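The following is the utility class PhotoUtil.java. It starts the camera, opens the album, resolves a file path from a Uri, converts a bitmap into the model's input buffer, and downsamples large images: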
package com.yeyupiaoling.testtflite;

import android.app.Activity;
import android.content.Context;
import android.content.Intent;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.net.Uri;
import android.os.Build;
import android.os.Environment;
import android.provider.MediaStore;
import android.support.v4.content.FileProvider;
import android.util.Log;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class PhotoUtil {
    // start camera
    public static String start_camera(Activity activity, int requestCode) {
        Uri imageUri;
        // save image in cache path
        File outputImage = new File(Environment.getExternalStorageDirectory().getAbsolutePath()
                + "/lite_mobile/", System.currentTimeMillis() + ".jpg");
        Log.d("outputImage", outputImage.getAbsolutePath());
        try {
            if (outputImage.exists()) {
                outputImage.delete();
            }
            File out_path = new File(Environment.getExternalStorageDirectory().getAbsolutePath()
                    + "/lite_mobile/");
            if (!out_path.exists()) {
                out_path.mkdirs();
            }
            outputImage.createNewFile();
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (Build.VERSION.SDK_INT >= 24) {
            // compatible with Android 7.0 or over
            imageUri = FileProvider.getUriForFile(activity,
                    "com.yeyupiaoling.testtflite.fileprovider", outputImage);
        } else {
            imageUri = Uri.fromFile(outputImage);
        }
        // set system camera Action
        Intent intent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
        intent.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION);
        // set save photo path
        intent.putExtra(MediaStore.EXTRA_OUTPUT, imageUri);
        // set photo quality, min is 0, max is 1
        intent.putExtra(MediaStore.EXTRA_VIDEO_QUALITY, 0);
        activity.startActivityForResult(intent, requestCode);
        // return image absolute path
        return outputImage.getAbsolutePath();
    }

    // get picture in photo
    public static void use_photo(Activity activity, int requestCode) {
        Intent intent = new Intent(Intent.ACTION_PICK);
        intent.setType("image/*");
        activity.startActivityForResult(intent, requestCode);
    }

    // get photo from Uri
    public static String get_path_from_URI(Context context, Uri uri) {
        String result;
        Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);
        if (cursor == null) {
            result = uri.getPath();
        } else {
            cursor.moveToFirst();
            int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);
            result = cursor.getString(idx);
            cursor.close();
        }
        return result;
    }

    // TensorFlow model, get predict data
    public static ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) {
        // 4 bytes per float value
        ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4);
        imgData.order(ByteOrder.nativeOrder());
        // get image pixel
        int[] pixels = new int[ddims[2] * ddims[3]];
        Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[2], ddims[3], false);
        bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[2], ddims[3]);
        int pixel = 0;
        for (int i = 0; i < ddims[2]; ++i) {
            for (int j = 0; j < ddims[3]; ++j) {
                final int val = pixels[pixel++];
                // write R, G, B, each scaled from [0, 255] to roughly [-1, 1)
                imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
                imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
                imgData.putFloat((((val & 0xFF) - 128f) / 128f));
            }
        }

        // release the scaled bitmap once its pixels have been copied
        // (the original condition was inverted and never recycled it)
        if (!bm.isRecycled()) {
            bm.recycle();
        }
        return imgData;
    }

    // compress picture
    public static Bitmap getScaleBitmap(String filePath) {
        BitmapFactory.Options opt = new BitmapFactory.Options();
        // first pass: read only the image dimensions
        opt.inJustDecodeBounds = true;
        BitmapFactory.decodeFile(filePath, opt);

        int bmpWidth = opt.outWidth;
        int bmpHeight = opt.outHeight;

        int maxSize = 500;

        // compress picture with inSampleSize
        opt.inSampleSize = 1;
        while (true) {
            if (bmpWidth / opt.inSampleSize < maxSize || bmpHeight / opt.inSampleSize < maxSize) {
                break;
            }
            opt.inSampleSize *= 2;
        }
        opt.inJustDecodeBounds = false;
        return BitmapFactory.decodeFile(filePath, opt);
    }
}
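Two numeric steps in PhotoUtil are easy to get wrong: the per-pixel normalization in getScaledMatrix() and the power-of-two downsampling in getScaleBitmap(). The sketch below restates both in a few lines so the arithmetic can be checked by hand. The function names normalize_pixel and sample_size are made up for this illustration.

```python
from typing import Tuple


def normalize_pixel(argb: int) -> Tuple[float, float, float]:
    """Unpack a packed ARGB pixel and scale each of R, G, B
    with (value - 128) / 128, as getScaledMatrix() does."""
    r = (argb >> 16) & 0xFF
    g = (argb >> 8) & 0xFF
    b = argb & 0xFF
    return ((r - 128.0) / 128.0, (g - 128.0) / 128.0, (b - 128.0) / 128.0)


def sample_size(width: int, height: int, max_size: int = 500) -> int:
    """Double the sample size while both sides are still at least max_size,
    mirroring the loop in getScaleBitmap()."""
    size = 1
    while width // size >= max_size and height // size >= max_size:
        size *= 2
    return size


print(normalize_pixel(0xFF804020))  # (0.0, -0.5, -0.75)
print(sample_size(4000, 3000))      # 8
```

A channel value of 0 maps to -1.0 and 255 maps to 127/128, so the inputs lie in the range [-1, 1). A 4000 × 3000 photo is decoded at one eighth of its size, that is 500 × 375, before being scaled to the model's 224 × 224 input.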
Add the required permissions to AndroidManifest.xml. The app uses the camera and reads and writes external storage:
<uses-permission android:name="android.permission.CAMERA"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
Then add the following inside the application element, mainly for camera compatibility on Android 7.0 and above:
<!-- FileProvider access path configuration, for Android 7.0 and above -->
<provider
    android:name="android.support.v4.content.FileProvider"
    android:authorities="com.yeyupiaoling.testtflite.fileprovider"
    android:exported="false"
    android:grantUriPermissions="true">
    <meta-data
        android:name="android.support.FILE_PROVIDER_PATHS"
        android:resource="@xml/file_paths"/>
</provider>
Next, create an xml directory under res and a file_paths.xml file in it with the following content, which specifies where captured photos are stored:
<?xml version="1.0" encoding="utf-8"?>
<resources>
    <external-path
        name="images"
        path="lite_mobile/" />
</resources>
The layout for the main screen, activity_main.xml:
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <LinearLayout
        android:id="@+id/btn1_ll"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentBottom="true"
        android:orientation="horizontal">

        <Button
            android:id="@+id/use_photo"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="Album" />

        <Button
            android:id="@+id/start_camera"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="Camera" />
    </LinearLayout>

    <LinearLayout
        android:id="@+id/btn2_ll"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_above="@id/btn1_ll"
        android:orientation="horizontal">

        <Button
            android:id="@+id/load_model"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="Load model" />
    </LinearLayout>

    <TextView
        android:id="@+id/result_text"
        android:layout_width="match_parent"
        android:layout_height="150dp"
        android:layout_above="@id/btn2_ll"
        android:hint="Prediction results will be shown here"
        android:inputType="textMultiLine"
        android:textSize="16sp"
        tools:ignore="TextViewEdits" />

    <ImageView
        android:id="@+id/show_image"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:layout_above="@id/result_text"
        android:layout_alignParentTop="true" />
</RelativeLayout>
The result looks like this: