在Android上加载带有人工智能危害检测的TensorFlow模型
fullstacker 发布于 2021-01-18

在Android系列上的AI危害检测的这篇文章中,我们将把TensorFlow Lite模型添加到项目中,并准备进行处理。

在这里,我们将一个经过训练的模型添加到一个Android项目中,并创建了一个用于向其中传递图像的用户界面。

在本系列的上一篇文章中,我们创建了一个项目,用于对驾驶员进行实时危险检测,并准备了一个用于TensorFlow Lite的检测模型。在这里,我们将继续加载模型并为图像处理做准备。

要将模型添加到项目中,请在src/main中创建一个名为assets的新文件夹。将TensorFlow Lite模型和包含标签的文本文件复制到src/main/assets,使其成为项目的一部分。


为了利用模型,我们必须编写代码来加载模型并通过模型传递数据。检测代码将被放置在一个可由两个用户界面共享的类中,以便在静态图像(用于测试)和实时视频流上使用相同的代码。


为模型格式化数据

在开始为此编写代码之前,我们需要知道模型期望其输入数据的结构。数据作为多维数组传入和传出。这也称为数据的“形状”。通常,当您找到一个模型时,这些信息会被记录下来。

您还可以使用工具Netron检查数据。从该工具打开模型时,将显示构成网络的节点。单击输入节点(显示在图表顶部)显示输入数据(在本例中为图像)的信息格式和网络的输出。在本例中,我们看到输入数据是32位浮点数的数组。阵列的尺寸为1x416x416x3。这意味着网络将一次接受一个具有红色、绿色和蓝色分量的416×416像素的图像。如果要为此项目使用不同的模型,则需要检查模型的输入和输出,并相应地调整代码。在解释结果时,我们将更详细地检查输出数据。


向项目中添加一个名为Detector的新类。所有用于管理训练网络的代码都将添加到此类中。构建类时,它将接受图像并以更易于使用的格式提供结果。我们应该向类中添加一些常量和字段来开始使用它。这些字段包括一个包含训练网络的TensorFlow解释器对象、模型识别的对象类列表以及应用程序上下文。

 class Detector {
   val TF_MODEL_NAME = "yolov4.tflite"
   val IMAGE_WIDTH = 416
   val IMAGE_HEIGHT = 416
   val TAG = "Detector"
   val useGpuDelegate = false;
   val useNNAPI=true;
   val context: Context;
   lateinit var tfLiteInterpreter:Interpreter
   var labelList = Vector<String>()

   //These output values are structured to match the output of the trained model being used
   var buf0 = Array(1) { Array(52) { Array(52) { Array(3) { FloatArray(85) } } } }
   var buf1 = Array(1) { Array(26) { Array(26) { Array(3) { FloatArray(85) } } } }
   var buf2 = Array(1) { Array(13) { Array(13) { Array(3) { FloatArray(85) } } } }
   var outputBuffers: HashMap<Int, Any>? = null
}

该类的构造函数将创建输出缓冲区,加载网络模型,并从assets文件夹加载对象类的名称。

 class Detector {
   val TF_MODEL_NAME = "yolov4.tflite"
   val IMAGE_WIDTH = 416
   val IMAGE_HEIGHT = 416
   val TAG = "Detector"
   val useGpuDelegate = false;
   val useNNAPI=true;
   val context: Context;
   lateinit var tfLiteInterpreter:Interpreter
   var labelList = Vector<String>()

   //These output values are structured to match the output of the trained model being used
   var buf0 = Array(1) { Array(52) { Array(52) { Array(3) { FloatArray(85) } } } }
   var buf1 = Array(1) { Array(26) { Array(26) { Array(3) { FloatArray(85) } } } }
   var buf2 = Array(1) { Array(13) { Array(13) { Array(3) { FloatArray(85) } } } }
   var outputBuffers: HashMap<Int, Any>? = null
}

测试模型

执行网络模型只需要几行代码。当图像被提供给探测器类时,它将被调整大小以匹配网络的要求。位图图像中的数据被编码为字节。这些值必须转换为32位浮点值。TensorFlow Lite库包含的功能可以使像这样的常见转换变得简单。TensorImage类型还有一个方便的方法,可以将它用作需要输入缓冲区的方法的缓冲区。

 public fun processImage(sourceImage: Bitmap) {
   val imageProcessor = ImageProcessor.Builder()
           .add(ResizeOp(IMAGE_HEIGHT, IMAGE_WIDTH, ResizeOp.ResizeMethod.BILINEAR))
           .build()
   var tImage = TensorImage(DataType.FLOAT32)
   tImage.load(sourceImage)
   tImage = imageProcessor.process(tImage)
   tfLiteInterpreter.runForMultipleInputsOutputs(arrayOf<any>(tImage.buffer), outputBuffers!!)
}</any>
要对此进行测试,请向项目中添加新布局。布局将有一个简单的界面,允许从设备的图像被选择。探测器将对选定的图像进行处理。

 <?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout>
   <ImageView

       android:id="@+id/selected_image_view"

 />
   <Button

       android:id="@+id/select_image_button"

       android:onClick="onSelectImageClicked"

 />
</androidx.constraintlayout.widget.ConstraintLayout>
此活动的代码将打开系统图像选择器。当选择一个图像并将其传回应用程序时,它将图像传递给检测器。

 public override fun onActivityResult(reqCode: Int, resultCode: Int, data: Intent?) {
   super.onActivityResult(reqCode, resultCode, data)
   if (resultCode == RESULT_OK) {
       if (reqCode == SELECT_PICTURE) {
           val selectedUri = data!!.data
           val fileString = selectedUri!!.path
           selected_image_view!!.setImageURI(selectedUri)
           var sourceBitmap: Bitmap? = null
           try {
               sourceBitmap =
                   MediaStore.Images.Media.getBitmap(this.contentResolver, selectedUri)
               RunDetector(sourceBitmap)
           } catch (e: IOException) {
               e.printStackTrace()
           }
       }
   }
}

fun RunDetector(bitmap: Bitmap?) {
   if (detector == null) detector = Detector(this)
   detector!!.processImage(bitmap!!)
}


全栈者
关注 私信
文章
31
关注
0
粉丝
0