Android配置tensorflow lite
按照官方網站的指導在專案的模塊的構建檔案build.gradle中配置中增加如下配置:
implementation 'org.tensorflow:tensorflow-lite:2.7.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.7.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'
implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'
android{
aaptOptions {
noCompress "tflite"
}
defaultConfig {
ndk {
abiFilters 'armeabi-v7a', 'arm64-v8a'
}
}
}
匯入模型資源資源
創建將文《關于將Tesorflow的SavedModel模型轉換成tflite模型》創建的模型model.tflite,匯入到Android專案的assets目錄中,
定義模型基本配置類BaseModelConfig
/**
* 定義模型的基本配置類
*/
public abstract class BaseModelConfig{
//每通道處理的位元組數
var numBytesPerChannel:Int = 0
//定義批處理的個數
var dimBatchSize:Int = 0
//定義像素個數
var dimPixelSize:Int = 0
//定義圖片的寬度
var dimImgWidth:Int = 0
//定義圖片的高度
var dimImgHeight:Int = 0
//定義平均差
var imageMean=0
//定義圖片的標準差
var imageSTD:Float = 0.0F
//定義模型的名稱
lateinit var modelName:String
constructor() : super() {
setConfigs()
}
/**
* 將像素值轉換成ByteBuffer
* 增加圖片的值
*/
public abstract fun addImgValue(buffer: ByteBuffer,pixel:Int)
/**
* 配置
*/
public abstract fun setConfigs()
}
定義FloatSavedModelConfig類
class FloatSavedModelConfig: BaseModelConfig() {
public override fun setConfigs() {
modelName="model.tflite"
numBytesPerChannel = 4
dimBatchSize = 1
dimPixelSize = 1
dimImgWidth = 28
dimImgHeight = 28
imageMean = 0
imageSTD = 255.0f
}
override fun addImgValue(imgData: ByteBuffer, pixel: Int) {
imgData.putFloat(((pixel and 0xFF) - imageMean) / imageSTD)
}
}
創建配置模型引數的工廠類
object ModelConfigFactory {
const val FLOAT_SAVED_MODEL = "float_saved_model"
const val QUANT_SAVED_MODEL = "quant_saved_model"
fun getModelConfig(model: String): BaseModelConfig? =
when(model) {
FLOAT_SAVED_MODEL-> FloatSavedModelConfig()
QUANT_SAVED_MODEL-> QuantSavedModelConfig()
else->null
}
}
定義影像分類器
class ImageClassifier {
private val TAG = "FashionMNIST"
private val RESULTS_TO_SHOW = 3
lateinit var mTFLite: Interpreter
lateinit var mModelPath:String
var mNumBytesPerChannel = 0
var mDimBatchSize = 0
var mDimPixelSize = 0
var mDimImgWidth = 0
var mDimImgHeight = 0
lateinit var mModelConfig:BaseModelConfig
//定義標簽檢測的二維陣列1x10
val mLabelProbArray = Array(1) {
FloatArray(
10
)
}
val labels = arrayListOf("T恤","褲子","帽頭衫","連衣裙","外套","涼鞋","襯衫","運動鞋","包","靴子")
//定義檢測結果保持到優先佇列中
var mSortedLabels = PriorityQueue<Map.Entry<String, Float>>(
RESULTS_TO_SHOW) {
o1, o2 -> o1?.value!!.compareTo(o2?.value!!)
}
/**
* 配置引數
*/
private fun initConfig(config: BaseModelConfig) {
mModelConfig = config
mNumBytesPerChannel = config.numBytesPerChannel
mDimBatchSize = config.dimBatchSize
mDimPixelSize = config.dimPixelSize
mDimImgWidth = config.dimImgWidth
mDimImgHeight = config.dimImgHeight
mModelPath = config.modelName
}
constructor(modelConfig: String, activity: Activity) {
// 初始化分類器的相關引數
initConfig(ModelConfigFactory.getModelConfig(modelConfig)!!)
// 使用配置引數初始化翻譯器
mTFLite = Interpreter(loadModelFile(activity)!!)
}
/**
* 在Assets中的模型檔案映射到記憶體中
* */
private fun loadModelFile(activity: Activity): MappedByteBuffer? {
val fileDescriptor = activity.assets.openFd(mModelPath)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
/**
* 將圖片資料寫入到ByteBuffer,加載到記憶體中
* */
protected fun convertBitmapToByteBuffer(bitmap: Bitmap?): ByteBuffer {
val intValues = IntArray(mDimImgWidth * mDimImgHeight)
//調整要處理的圖片為28x28
var tmp = scaleBitmap(bitmap)
//將圖片二值化
tmp = binarized(tmp)
//將二值化的圖片加載到記憶體中
tmp.getPixels(intValues,
0, tmp.width, 0, 0, tmp.width, tmp.height
)
val imgData = ByteBuffer.allocateDirect(
mNumBytesPerChannel * mDimBatchSize * mDimImgWidth * mDimImgHeight * mDimPixelSize
)
imgData.order(ByteOrder.nativeOrder())
imgData.rewind()
//將圖片轉換成像素實數資料
var pixel = 0
for (i in 0 until mDimImgWidth) {
for (j in 0 until mDimImgHeight) {
var value = intValues[pixel++]
mModelConfig.addImgValue(imgData, value)
}
}
return imgData
}
/**
* 將圖片二值化處理
* 轉換成二值影像
* @param bmp
* @return
*/
fun binarized(bmp: Bitmap): Bitmap {
val width = bmp.width
val height = bmp.height
val pixels = IntArray(width * height)
//將圖片的像素加載到陣列中
bmp.getPixels(pixels, 0, width, 0, 0, width, height)
var alpha = 0xFF shl 24
for (i in 0 until height) {
for (j in 0 until width) {
val grey = pixels[width * i + j]
// 分離三原色
alpha = grey and -0x1000000 shr 24
var red = grey and 0x00FF0000 shr 16
var green = grey and 0x0000FF00 shr 8
var blue = grey and 0x000000FF
val tmp = 180
red = if (red > tmp) 255 else 0
blue = if (blue > tmp) 255 else 0
green = if (green > tmp) 255 else 0
pixels[width * i + j] = alpha shl 24 or (red shl 16) or (green shl 8) or blue
if (pixels[width * i + j] == -1) {
pixels[width * i + j] = -1
} else {
pixels[width * i + j] = -16777216
}
}
}
// 新建圖片
val newBmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
// 設定圖片資料
newBmp.setPixels(pixels, 0, width, 0, 0, width, height)
return newBmp
}
/**
* 將圖片調整到規定的大小28x28
*/
fun scaleBitmap(bmp: Bitmap?): Bitmap {
return Bitmap.createScaledBitmap(bmp!!, mDimImgWidth, mDimImgHeight, true)
}
/**
* 分類處理
*/
fun doClassify(bitmap: Bitmap?): String? {
// 將Bitmap圖片轉換成TFLite翻譯器的可讀的ByteBuffer
val imgData = convertBitmapToByteBuffer(bitmap)
// do run interpreter
val startTime = System.nanoTime()
mTFLite.run(imgData, mLabelProbArray)
val endTime = System.nanoTime()
Log.i(TAG, String.format(
"運行識別的時間: %f ms",
(endTime - startTime).toFloat() / 1000000.0f
)
)
// 生成并回傳結果
return printTopKLabels()
}
/**
* 列印檢測排序在前幾位的標簽,并作為結果顯示在UI界面中,
*/
fun printTopKLabels(): String? {
for (i in 0..9) {
mSortedLabels.add(
AbstractMap.SimpleEntry(
labels[i],
mLabelProbArray[0][i]
)
)
if (mSortedLabels.size > RESULTS_TO_SHOW) {
mSortedLabels.poll()
}
}
val textToShow = StringBuffer()
val size = mSortedLabels.size
for (i in 0 until size) {
val label = mSortedLabels.poll()
textToShow.insert(0, String.format("\n%s %4.8f", label.key, label.value))
}
return textToShow.toString()
}
}
定義主活動MainActivity
在主活動中,主要處理如下操作:
(1)從圖庫中選擇圖片
(2)利用影像分類器檢測圖片中的內容,判斷是FashionMnist資料集的哪種標簽
(3)將檢測的結果在移動終端的GUI界面中顯示出來,
class MainActivity : AppCompatActivity() {
private lateinit var binding: ActivityMainBinding
val RequestCameraCode = 1
val TAG = "FashionMNIST"
companion object{
var mIsFloat = true
}
private var bitmap: Bitmap? = null
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
//生成視圖系結物件
binding = ActivityMainBinding.inflate(layoutInflater)
//設定視圖的根視圖
setContentView(binding.root)
binding.imageView.setOnClickListener {
val intent = Intent()
intent.type = "image/*"
intent.action = Intent.ACTION_GET_CONTENT
startActivityForResult(intent,RequestCameraCode)
}
val spinnerAdapter = ArrayAdapter<String>(this,android.R.layout.simple_spinner_item,getChoices())
binding.typeSpinner.adapter = spinnerAdapter
binding.typeSpinner.onItemSelectedListener = object : OnItemSelectedListener {
override fun onItemSelected(
parent: AdapterView<*>?,
view: View,
position: Int,
id: Long
) {
mIsFloat = position == 0
}
override fun onNothingSelected(parent: AdapterView<*>?) {}
}
}
override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
super.onActivityResult(requestCode, resultCode, data)
if(resultCode == RESULT_OK && requestCode == RequestCameraCode){
val uri = data?.data
try{
//從圖庫中讀取圖片
var bitmap = BitmapFactory.decodeStream(contentResolver.openInputStream(uri!!))
//在影像視圖ImageView中顯示圖片
binding.imageView.setImageBitmap(bitmap)
//判斷模型型別
val config = when(mIsFloat){
true->ModelConfigFactory.FLOAT_SAVED_MODEL
else->ModelConfigFactory.QUANT_SAVED_MODEL
}
//根據模型型別創建影像識別器
val classifier = ImageClassifier(config,this)
//檢測并判斷影像的類別
val result = classifier.doClassify(bitmap)
binding.labelTxt.text = result
binding.tipTxt.visibility = View.GONE
}catch(e: FileNotFoundException){
Log.d(TAG,"沒有找到指定的影像檔案")
}catch(e: IOException){
Log.e(TAG,"初始化影像識別器失敗")
}
}
}
/**
* 回傳可用模型的名稱
*/
private fun getChoices()= resources.getStringArray(R.array.model_names)
}
參考文獻
李錫涵等 《簡明的Tensorflow 2》人民郵電出版社 北京 P91-P96
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/398038.html
標籤:AI
上一篇:video播放視頻
