[Android] Tensorflow Lite Model - Object Detect
Tensorflow Lite를 검색해보니 여러가지 pre-trained model을 받을 수 있었다.
그래서 사이트에서 제공하는 모델과 라이브러리를 통해서 한번 사용해 보았다.
TensorFlow Lite용 사전 학습된 모델
TensorFlow가 5월 14일 Google I/O로 돌아왔습니다! 지금 등록하세요 이 페이지는 Cloud Translation API를 통해 번역되었습니다. TensorFlow Lite용 사전 학습된 모델 컬렉션을 사용해 정리하기 내 환경설정을 기
www.tensorflow.org
위 모델 중에 Object Detection 모델을 사용했다.
모델을 다운 받은 후에 해당 모델을 앱모듈에 assets에 넣어주고 시작해보자.
이후 라이브러리들을 추가해주자.
앱 수준의 build.gradle 파일에 다음을 추가해주자.
camera는 모바일 디바이스의 카메라를 이용하여 object detection을 하기 위해 추가했다.
val camerax_version = "1.1.0-beta01"
implementation ("androidx.camera:camera-core:${camerax_version}")
implementation ("androidx.camera:camera-camera2:${camerax_version}")
implementation ("androidx.camera:camera-lifecycle:${camerax_version}")
implementation ("androidx.camera:camera-video:${camerax_version}")
implementation ("androidx.camera:camera-view:${camerax_version}")
implementation ("androidx.camera:camera-extensions:${camerax_version}")
implementation("org.tensorflow:tensorflow-lite-task-vision:0.3.1")
카메라를 사용해야 하므로 manifest에 permission들을 추가해주고
Runtime permission 코드들도 추가해주자.
이 부분은 생략하도록 하겠다.
이후 CameraX 라이브러리의 ImageAnalysis를 통해서 리스너에
카메라에 들어오는 이미지를 bitmap으로 바꾸어 모델에 입력해주도록 하자.
cameraProviderFuture.addListener({
// 생략
val imageAnalyzer = ImageAnalysis.Builder()
.build()
.also {
it.setAnalyzer(
cameraExecutor,
MyAnalyzer(
context = this@MainActivity
) { detectingObjectList ->
with(binding.indicatorView) {
setDetectingObjectList(detectingObjectList)
invalidate()
}
}
)
}
// 생략
}
}, ContextCompat.getMainExecutor(this))
}
class MyAnalyzer(
private val context: Context,
private val resultCallback: (List<DetectingObject>) -> Unit = {_->},
) : ImageAnalysis.Analyzer {
@OptIn(ExperimentalGetImage::class)
override fun analyze(image: ImageProxy) {
image.image?.let {
runObjectDetection(it.toBitmap())
}
}
fun Image.toBitmap(): Bitmap {
val yBuffer = planes[0].buffer // Y
val vuBuffer = planes[2].buffer // VU
val ySize = yBuffer.remaining()
val vuSize = vuBuffer.remaining()
val nv21 = ByteArray(ySize + vuSize)
yBuffer.get(nv21, 0, ySize)
vuBuffer.get(nv21, ySize, vuSize)
val yuvImage = YuvImage(nv21, ImageFormat.NV21, this.width, this.height, null)
val out = ByteArrayOutputStream()
yuvImage.compressToJpeg(Rect(0, 0, yuvImage.width, yuvImage.height), 50, out)
val imageBytes = out.toByteArray()
return BitmapFactory.decodeByteArray(imageBytes, 0, imageBytes.size)
}
private fun runObjectDetection(bitmap: Bitmap) {
val image = TensorImage.fromBitmap(bitmap)
val options = ObjectDetector.ObjectDetectorOptions.builder()
.setMaxResults(5)
.setScoreThreshold(0.5f)
.build()
val detector = ObjectDetector.createFromFileAndOptions(
context,
"model.tflite",
options
)
val results = detector.detect(image)
debugPrint(results)
}
private fun debugPrint(results : List<Detection>) {
val detectingObjectList = mutableListOf<DetectingObject>()
for ((i, obj) in results.withIndex()) {
val box = obj.boundingBox
Log.d(TAG, "Detected object: ${i} ")
Log.d(TAG, " boundingBox: (${box.left}, ${box.top}) - (${box.right},${box.bottom})")
detectingObjectList.add(
DetectingObject(box.left, box.top, box.right, box.bottom)
)
for ((j, category) in obj.categories.withIndex()) {
Log.d(TAG, " Label $j: ${category.label}")
val confidence: Int = category.score.times(100).toInt()
Log.d(TAG, " Confidence: ${confidence}%")
}
}
resultCallback(detectingObjectList)
}
}
어플리케이션을 실행시켜 아래의 사진처럼 카메라를 위치하면
로그들 중에서 아래와 같은 항목을 발견할 수 있었다.
저 사진을 보고 모델은 TV와 CELL PHONE을 구분하기 어려워 하는 것 같다.
confidence가 아주 비슷했다.