카테고리 없음

[Android] Tensorflow Lite Model - Object Detect

Jun.LEE 2024. 5. 2. 21:26

Tensorflow Lite를 검색해보니 여러가지 pre-trained model을 받을 수 있었다.

 

그래서 사이트에서 제공하는 모델과 라이브러리를 통해서 한번 사용해 보았다.

https://www.tensorflow.org/lite/models/trained?hl=ko&_gl=1*1rxz61k*_up*MQ..*_ga*MTEzMTI5MjY0NC4xNzE0NjUwODE3*_ga_W0YLR4190T*MTcxNDY1MDgxNy4xLjAuMTcxNDY1MDgxNy4wLjAuMA..

 

TensorFlow Lite용 사전 학습된 모델

TensorFlow가 5월 14일 Google I/O로 돌아왔습니다! 지금 등록하세요 이 페이지는 Cloud Translation API를 통해 번역되었습니다. TensorFlow Lite용 사전 학습된 모델 컬렉션을 사용해 정리하기 내 환경설정을 기

www.tensorflow.org

 

위 모델 중에 Object Detection 모델을 사용했다.

모델을 다운 받은 후에 해당 모델을 앱모듈에 assets에 넣어주고 시작해보자.

 

이후 라이브러리들을 추가해주자.

 

앱 수준의 build.gradle 파일에 다음을 추가해주자.

camera는 모바일 디바이스의 카메라를 이용하여 object detection을 하기 위해 추가했다. 

val camerax_version = "1.1.0-beta01"
implementation ("androidx.camera:camera-core:${camerax_version}")
implementation ("androidx.camera:camera-camera2:${camerax_version}")
implementation ("androidx.camera:camera-lifecycle:${camerax_version}")
implementation ("androidx.camera:camera-video:${camerax_version}")
implementation ("androidx.camera:camera-view:${camerax_version}")
implementation ("androidx.camera:camera-extensions:${camerax_version}")

implementation("org.tensorflow:tensorflow-lite-task-vision:0.3.1")

 

카메라를 사용해야 하므로 manifest에 permission들을 추가해주고 

Runtime permission 코드들도 추가해주자.

이 부분은 생략하도록 하겠다.

 

이후 CameraX 라이브러리의 ImageAnalysis를 통해서 리스너에

카메라에 들어오는 이미지를 bitmap으로 바꾸어 모델에 입력해주도록 하자.

 

cameraProviderFuture.addListener({
        
        // 생략

        val imageAnalyzer = ImageAnalysis.Builder()
            .build()
            .also {
                it.setAnalyzer(
                    cameraExecutor,
                    MyAnalyzer(
                        context = this@MainActivity
                    ) { detectingObjectList ->
                        with(binding.indicatorView) {
                            setDetectingObjectList(detectingObjectList)
                            invalidate()
                        }
                    }
                )
            }
            
       	// 생략
        
        }
    }, ContextCompat.getMainExecutor(this))
}

 

 

class MyAnalyzer(
    private val context: Context,
    private val resultCallback: (List<DetectingObject>) -> Unit = {_->},
) : ImageAnalysis.Analyzer {

    @OptIn(ExperimentalGetImage::class)
    override fun analyze(image: ImageProxy) {

        image.image?.let {
            runObjectDetection(it.toBitmap())
        }
    }

    fun Image.toBitmap(): Bitmap {
        val yBuffer = planes[0].buffer // Y
        val vuBuffer = planes[2].buffer // VU

        val ySize = yBuffer.remaining()
        val vuSize = vuBuffer.remaining()

        val nv21 = ByteArray(ySize + vuSize)

        yBuffer.get(nv21, 0, ySize)
        vuBuffer.get(nv21, ySize, vuSize)

        val yuvImage = YuvImage(nv21, ImageFormat.NV21, this.width, this.height, null)
        val out = ByteArrayOutputStream()
        yuvImage.compressToJpeg(Rect(0, 0, yuvImage.width, yuvImage.height), 50, out)
        val imageBytes = out.toByteArray()
        return BitmapFactory.decodeByteArray(imageBytes, 0, imageBytes.size)
    }

    private fun runObjectDetection(bitmap: Bitmap) {
        val image = TensorImage.fromBitmap(bitmap)
        val options = ObjectDetector.ObjectDetectorOptions.builder()
            .setMaxResults(5)
            .setScoreThreshold(0.5f)
            .build()

        val detector = ObjectDetector.createFromFileAndOptions(
            context, 
            "model.tflite", 
            options
        )

        val results = detector.detect(image)
        debugPrint(results)
    }

    private fun debugPrint(results : List<Detection>) {
        val detectingObjectList = mutableListOf<DetectingObject>()
        for ((i, obj) in results.withIndex()) {
            val box = obj.boundingBox

            Log.d(TAG, "Detected object: ${i} ")
            Log.d(TAG, "  boundingBox: (${box.left}, ${box.top}) - (${box.right},${box.bottom})")

            detectingObjectList.add(
                DetectingObject(box.left, box.top, box.right, box.bottom)
            )

            for ((j, category) in obj.categories.withIndex()) {
                Log.d(TAG, "    Label $j: ${category.label}")
                val confidence: Int = category.score.times(100).toInt()
                Log.d(TAG, "    Confidence: ${confidence}%")
            }
        }
        resultCallback(detectingObjectList)
    }
}

 

 

어플리케이션을 실행시켜 아래의 사진처럼 카메라를 위치하면 

 

로그들 중에서 아래와 같은 항목을 발견할 수 있었다.

저 사진을 보고 모델은 TV와 CELL PHONE을 구분하기 어려워 하는 것 같다.

confidence가 아주 비슷했다.