在Android系列上的AI危害检测的这篇文章中,我们将把TensorFlow Lite模型添加到项目中,并准备进行处理。
在这里,我们将一个经过训练的模型添加到一个Android项目中,并创建了一个用于向其中传递图像的用户界面。
在本系列的上一篇文章中,我们创建了一个项目,用于对驾驶员进行实时危险检测,并准备了一个用于TensorFlow Lite的检测模型。在这里,我们将继续加载模型并为图像处理做准备。
要将模型添加到项目中,请在src/main中创建一个名为assets的新文件夹。将TensorFlow Lite模型和包含标签的文本文件复制到src/main/assets,使其成为项目的一部分。
为了利用模型,我们必须编写代码来加载模型并通过模型传递数据。检测代码将被放置在一个可由两个用户界面共享的类中,以便在静态图像(用于测试)和实时视频流上使用相同的代码。
为模型格式化数据
在开始为此编写代码之前,我们需要知道模型期望其输入数据的结构。数据作为多维数组传入和传出。这也称为数据的“形状”。通常,当您找到一个模型时,这些信息会被记录下来。
您还可以使用工具Netron检查数据。从该工具打开模型时,将显示构成网络的节点。单击输入节点(显示在图表顶部)显示输入数据(在本例中为图像)的信息格式和网络的输出。在本例中,我们看到输入数据是32位浮点数的数组。阵列的尺寸为1x416x416x3。这意味着网络将一次接受一个具有红色、绿色和蓝色分量的416×416像素的图像。如果要为此项目使用不同的模型,则需要检查模型的输入和输出,并相应地调整代码。在解释结果时,我们将更详细地检查输出数据。
向项目中添加一个名为Detector的新类。所有用于管理训练网络的代码都将添加到此类中。构建类时,它将接受图像并以更易于使用的格式提供结果。我们应该向类中添加一些常量和字段来开始使用它。这些字段包括一个包含训练网络的TensorFlow解释器对象、模型识别的对象类列表以及应用程序上下文。
class Detector {
val TF_MODEL_NAME = "yolov4.tflite"
val IMAGE_WIDTH = 416
val IMAGE_HEIGHT = 416
val TAG = "Detector"
val useGpuDelegate = false;
val useNNAPI=true;
val context: Context;
lateinit var tfLiteInterpreter:Interpreter
var labelList = Vector<String>()
//These output values are structured to match the output of the trained model being used
var buf0 = Array(1) { Array(52) { Array(52) { Array(3) { FloatArray(85) } } } }
var buf1 = Array(1) { Array(26) { Array(26) { Array(3) { FloatArray(85) } } } }
var buf2 = Array(1) { Array(13) { Array(13) { Array(3) { FloatArray(85) } } } }
var outputBuffers: HashMap<Int, Any>? = null
}
该类的构造函数将创建输出缓冲区,加载网络模型,并从assets文件夹加载对象类的名称。
class Detector {
val TF_MODEL_NAME = "yolov4.tflite"
val IMAGE_WIDTH = 416
val IMAGE_HEIGHT = 416
val TAG = "Detector"
val useGpuDelegate = false;
val useNNAPI=true;
val context: Context;
lateinit var tfLiteInterpreter:Interpreter
var labelList = Vector<String>()
//These output values are structured to match the output of the trained model being used
var buf0 = Array(1) { Array(52) { Array(52) { Array(3) { FloatArray(85) } } } }
var buf1 = Array(1) { Array(26) { Array(26) { Array(3) { FloatArray(85) } } } }
var buf2 = Array(1) { Array(13) { Array(13) { Array(3) { FloatArray(85) } } } }
var outputBuffers: HashMap<Int, Any>? = null
}
测试模型
执行网络模型只需要几行代码。当图像被提供给探测器类时,它将被调整大小以匹配网络的要求。位图图像中的数据被编码为字节。这些值必须转换为32位浮点值。TensorFlow Lite库包含的功能可以使像这样的常见转换变得简单。TensorImage类型还有一个方便的方法,可以将它用作需要输入缓冲区的方法的缓冲区。
public fun processImage(sourceImage: Bitmap) {
val imageProcessor = ImageProcessor.Builder()
.add(ResizeOp(IMAGE_HEIGHT, IMAGE_WIDTH, ResizeOp.ResizeMethod.BILINEAR))
.build()
var tImage = TensorImage(DataType.FLOAT32)
tImage.load(sourceImage)
tImage = imageProcessor.process(tImage)
tfLiteInterpreter.runForMultipleInputsOutputs(arrayOf<any>(tImage.buffer), outputBuffers!!)
}</any>
要对此进行测试,请向项目中添加新布局。布局将有一个简单的界面,允许从设备的图像被选择。探测器将对选定的图像进行处理。
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout>
<ImageView
android:id="@+id/selected_image_view"
/>
<Button
android:id="@+id/select_image_button"
android:onClick="onSelectImageClicked"
/>
</androidx.constraintlayout.widget.ConstraintLayout>
此活动的代码将打开系统图像选择器。当选择一个图像并将其传回应用程序时,它将图像传递给检测器。
public override fun onActivityResult(reqCode: Int, resultCode: Int, data: Intent?) {
super.onActivityResult(reqCode, resultCode, data)
if (resultCode == RESULT_OK) {
if (reqCode == SELECT_PICTURE) {
val selectedUri = data!!.data
val fileString = selectedUri!!.path
selected_image_view!!.setImageURI(selectedUri)
var sourceBitmap: Bitmap? = null
try {
sourceBitmap =
MediaStore.Images.Media.getBitmap(this.contentResolver, selectedUri)
RunDetector(sourceBitmap)
} catch (e: IOException) {
e.printStackTrace()
}
}
}
}
fun RunDetector(bitmap: Bitmap?) {
if (detector == null) detector = Detector(this)
detector!!.processImage(bitmap!!)
}