TensorFlow Lite Model Android: Cannot find an axis to label. A valid axis to label should have size larger than 1

Question

I am trying to use a pre-trained TensorFlow Lite model in an Android application.

I downloaded the TensorFlow Lite image classification example application from here.

I changed the following code in all four model classifier files:

protected String getModelPath() {
    // you can download this file from
    // see build.gradle for where to obtain this file. It should be auto
    // downloaded into assets.
    //return "mobilenet_v1_1.0_224_quant.tflite";
    return "model_23072020.tflite";
}

The TensorFlow Lite model I am using is a pre-trained image classification model. It takes an image and outputs 0 or 1: 0 means the image is of poor quality, and 1 means it is of good quality.

The model uses dynamic range quantization.

When I run the application and print the value of outputProbabilityBuffer.getFloatArray(), I get the following result:

> I/tensorflow: Classifier: value of output [F@e08d0d3

I am trying to log the value with the code below:

tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());

Map<String, Float> labeledProbability = new HashMap<>();
labeledProbability.put("abc", 93.556f);

// Added logger for displaying value in console
LOGGER.i("value of output %s ", outputProbabilityBuffer.getFloatArray());

Update

I removed the logger above, and now I get an exception at this line:

Map<String, Float> labeledProbability = new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer))
        .getMapWithFloatValue();

The error is: Cannot find an axis to label. A valid axis to label should have size larger than 1.

The complete stack trace is below:

> java.lang.IllegalArgumentException: Cannot find an axis to label. A valid axis to label should have size larger than 1.
>     at org.tensorflow.lite.support.label.TensorLabel.getFirstAxisWithSizeGreaterThanOne(TensorLabel.java:214)
>     at org.tensorflow.lite.support.label.TensorLabel.<init>(TensorLabel.java:105)
>     at org.tensorflow.lite.examples.classification.tflite.Classifier.recognizeImage(Classifier.java:263)
>     at org.tensorflow.lite.examples.classification.ClassifierTest.classificationResultsShouldNotChange(ClassifierTest.java:67)
>     at java.lang.reflect.Method.invoke(Native Method)
>     at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50)
>     at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12)
>     at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47)
>     at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17)
>     at androidx.test.internal.runner.junit4.statement.RunBefores.evaluate(RunBefores.java:80)
>     at androidx.test.rule.ActivityTestRule$ActivityStatement.evaluate(ActivityTestRule.java:527)
>     at org.junit.rules.RunRules.evaluate(RunRules.java:20)
>     at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325)
>     at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78)
>     at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57)
>     at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290)
>     at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71)
>     at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288)
>     at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58)
>     at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268)
>     at org.junit.runners.ParentRunner.run(ParentRunner.java:363)
>     at org.junit.runners.Suite.runChild(Suite.java:128)
>     at org.junit.runners.Suite.runChild(Suite.java:27)
>     at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290)
>     at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71)
>     at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288)
>     at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58)
>     at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268)
>     at org.junit.runners.ParentRunner.run(ParentRunner.java:363)
>     at org.junit.runner.JUnitCore.run(JUnitCore.java:137)
>     at org.junit.runner.JUnitCore.run(JUnitCore.java:115)
>     at androidx.test.internal.runner.TestExecutor.execute(TestExecutor.java:56)
>     at androidx.test.runner.AndroidJUnitRunner.onStart(AndroidJUnitRunner.java:389)
>     at android.app.Instrumentation$InstrumentationThread.run(Instrumentation.java:2075)

Answer 1

Score: 0


You are trying to log an object. outputProbabilityBuffer.getFloatArray() returns a float array object; you need the values inside it.

In the Classifier class, try logging entry.getKey() and entry.getValue() at line 332. These should be the label name and the confidence.

There is also a getter, getTitle(), which is used, for example, in CameraActivity at line 522, where you can log the output.
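To illustrate the logging problem, here is a minimal standalone sketch in plain Java, with no TensorFlow Lite dependency. The array output is a made-up stand-in for outputProbabilityBuffer.getFloatArray(), and the class name LoggingSketch is hypothetical. Formatting a float[] with %s falls back to the default Object.toString(), which gives the type descriptor [F, an @, and a hash code, matching the [F@e08d0d3 in the log; Arrays.toString() prints the elements instead. The last part logs a label-to-confidence map entry by entry, as the answer suggests.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical standalone class; no TensorFlow Lite classes are used.
class LoggingSketch {
    // Default Object.toString() on an array: "[F" (float[]) + "@" + hex hash code.
    static String formatRaw(float[] values) {
        return String.format("value of output %s", (Object) values);
    }

    // Arrays.toString() prints the elements themselves, e.g. "[0.75]".
    static String formatValues(float[] values) {
        return String.format("value of output %s", Arrays.toString(values));
    }

    public static void main(String[] args) {
        // Made-up stand-in for outputProbabilityBuffer.getFloatArray().
        float[] output = {0.75f};
        System.out.println(formatRaw(output));    // e.g. "value of output [F@..." (hash varies)
        System.out.println(formatValues(output)); // value of output [0.75]

        // Logging label -> confidence pairs one entry at a time.
        Map<String, Float> labeledProbability = new LinkedHashMap<>();
        labeledProbability.put("abc", 0.75f);
        for (Map.Entry<String, Float> entry : labeledProbability.entrySet()) {
            System.out.println(entry.getKey() + " = " + entry.getValue()); // abc = 0.75
        }
    }
}
```

On Android the same idea applies to the LOGGER.i call: pass Arrays.toString(outputProbabilityBuffer.getFloatArray()) rather than the array itself.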

Answer 2

Score: 0


The issue is that you are using a binary classifier, but the TensorFlow Lite image classification example code expects a multi-class classifier (i.e., a classifier that outputs more than one probability).

The model you are using outputs only one probability (which you can read with outputProbabilityBuffer.getFloatArray()[0]), where values near 0 indicate one class and values near 1 indicate the other.
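The error message points straight at this mismatch. The sketch below mimics, in plain Java, the check the message describes: find the first dimension of the output shape whose size is greater than 1, and reject the tensor if there is none. It is an illustration written for this answer, not the TensorFlow Lite Support library's source code; the class name AxisCheckSketch and the example shapes are assumptions (a multi-class output such as [1, 1001] versus a single-probability output such as [1, 1]).

```java
import java.util.Arrays;

// Illustrative re-creation of the check described by the error message.
// NOT the TensorFlow Lite Support library's actual implementation.
class AxisCheckSketch {
    // Returns the index of the first dimension whose size is greater than 1.
    static int firstAxisWithSizeGreaterThanOne(int[] shape) {
        for (int axis = 0; axis < shape.length; axis++) {
            if (shape[axis] > 1) {
                return axis;
            }
        }
        throw new IllegalArgumentException(
            "Cannot find an axis to label. A valid axis to label should have size larger than 1.");
    }

    public static void main(String[] args) {
        // Multi-class output, e.g. one score per label: axis 1 can be labeled.
        System.out.println(firstAxisWithSizeGreaterThanOne(new int[] {1, 1001})); // prints 1

        // Single-probability binary output: no dimension is larger than 1.
        int[] binaryShape = {1, 1};
        try {
            firstAxisWithSizeGreaterThanOne(binaryShape);
        } catch (IllegalArgumentException e) {
            System.out.println("Rejected " + Arrays.toString(binaryShape) + ": " + e.getMessage());
        }
    }
}
```

With a [1, 1] output there is nothing to spread the labels across, which is why the labeling step fails regardless of what the labels file contains.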

To fix the issue, skip the example's code that post-processes the output and maps labels to probabilities after execution; instead, do something like:

// after execution
float result = outputProbabilityBuffer.getFloatArray()[0];
if (result < 0.5) {
  // means prediction was for category corresponding to 0
} else {
  // means prediction was for category corresponding to 1
}
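If downstream code still expects a label-to-probability map, one option is to build a two-entry map by hand from the single output. This is a minimal sketch under two assumptions: the output value is already a probability in [0, 1] for class 1 (good quality), as a sigmoid output layer would produce, and the label strings bad_quality and good_quality are hypothetical placeholders. If the model instead emits an unbounded logit, apply a sigmoid first, or equivalently compare the logit against 0 rather than 0.5.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper; assumes the model's single output is P(class 1) in [0, 1].
class BinaryLabelSketch {
    static final String LABEL_0 = "bad_quality";  // hypothetical label for class 0
    static final String LABEL_1 = "good_quality"; // hypothetical label for class 1

    // Builds a label -> probability map from one probability, in insertion order.
    static Map<String, Float> toLabeledProbability(float probabilityOfClass1) {
        Map<String, Float> labeled = new LinkedHashMap<>();
        labeled.put(LABEL_0, 1.0f - probabilityOfClass1);
        labeled.put(LABEL_1, probabilityOfClass1);
        return labeled;
    }

    // Same 0.5 threshold as the snippet in the answer.
    static String predictedLabel(float probabilityOfClass1) {
        return probabilityOfClass1 < 0.5f ? LABEL_0 : LABEL_1;
    }

    public static void main(String[] args) {
        // Made-up stand-in for outputProbabilityBuffer.getFloatArray()[0].
        float result = 0.75f;
        System.out.println(toLabeledProbability(result)); // {bad_quality=0.25, good_quality=0.75}
        System.out.println(predictedLabel(result));       // good_quality
    }
}
```

A LinkedHashMap is used so the entries print in class order (0, then 1), which keeps log output predictable.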

huangapple
  • Published on July 24, 2020, 17:47:02
  • Article link: https://java.coder-hub.com/63071067.html