blob: fc67683bcdc04872284016104d57203a50a346b8 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classification Engine used for classification tasks."""
from edgetpu.basic.basic_engine import BasicEngine
import numpy
from PIL import Image
class ClassificationEngine(BasicEngine):
    """Engine used for classification tasks.

    Wraps a BasicEngine whose model must have exactly one output tensor,
    interpreted as one score per class id.
    """

    def __init__(self, model_path, device_path=None):
        """Creates a ClassificationEngine with the given model.

        Args:
          model_path: String, path to TF-Lite Flatbuffer file.
          device_path: String, if specified, bind engine with Edge TPU at
            device_path.

        Raises:
          ValueError: The model does not have exactly one output tensor.
        """
        # Only forward device_path when it is set (a truthy value); otherwise
        # use BasicEngine's single-argument constructor.
        if device_path:
            super().__init__(model_path, device_path)
        else:
            super().__init__(model_path)
        output_tensors_sizes = self.get_all_output_tensors_sizes()
        # A classification model must produce a single score vector.
        if output_tensors_sizes.size != 1:
            raise ValueError(
                ('Classification model should have 1 output tensor only! '
                 'This model has {}.'.format(output_tensors_sizes.size)))

    def ClassifyWithImage(
            self, img, threshold=0.1, top_k=3, resample=Image.NEAREST):
        """Classifies a PIL image object.

        This interface assumes the loaded model is trained for image
        classification. The image is resized to the model's input height and
        width, and its pixel array is flattened into the input tensor.

        Args:
          img: PIL image object. No color conversion is done here, so it
            presumably must already have 3 channels (e.g. RGB) to match the
            model input. TODO confirm callers convert beforehand.
          threshold: float, only results with a score strictly greater than
            this are returned. Note the default (0.1) differs from
            ClassifyWithInputTensor's default (0.0).
          top_k: keep at most the top k candidates among those whose score
            exceeds the threshold. By default we keep the top 3.
          resample: An optional resampling filter for image resizing. By
            default it is PIL.Image.NEAREST. A more complex filter such as
            PIL.Image.BICUBIC adds latency but may slightly improve accuracy.

        Returns:
          List of (id, score) tuples sorted by descending score.

        Raises:
          RuntimeError: when the model's input tensor is not shaped
            [1, height, width, 3].
          ValueError: when top_k is not positive.
        """
        input_tensor_shape = self.get_input_tensor_shape()
        if (input_tensor_shape.size != 4 or input_tensor_shape[3] != 3 or
                input_tensor_shape[0] != 1):
            raise RuntimeError(
                'Invalid input tensor shape! Expected: [1, height, width, 3]')
        _, height, width, _ = input_tensor_shape
        # PIL's resize takes (width, height), the reverse of the tensor's
        # [height, width] order.
        img = img.resize((width, height), resample)
        input_tensor = numpy.asarray(img).flatten()
        return self.ClassifyWithInputTensor(input_tensor, threshold, top_k)

    def ClassifyWithInputTensor(self, input_tensor, threshold=0.0, top_k=3):
        """Classifies with a raw input tensor.

        This interface requires the user to preprocess the input data and
        convert it into a correctly formatted input tensor themselves.

        Args:
          input_tensor: numpy.array representing the input tensor.
          threshold: float, only results with a score strictly greater than
            this are returned.
          top_k: keep at most the top k candidates among those whose score
            exceeds the threshold. By default we keep the top 3.

        Returns:
          List of (id, score) tuples sorted by descending score. Ids are the
          numpy integer indices produced by numpy.argpartition. Scores are
          elements of the raw inference output.

        Raises:
          ValueError: when top_k is not positive.
        """
        if top_k <= 0:
            raise ValueError('top_k must be positive!')
        # RunInference returns a pair; only the second element (the raw
        # output scores) is used. It is kept on the instance as before.
        _, self._raw_result = self.RunInference(input_tensor)
        scores = self._raw_result
        # top_k must be less than or equal to the number of possible results.
        top_k = min(top_k, len(scores))
        if top_k == 0:
            # Empty output: nothing to rank, and argpartition would reject
            # kth=0 on an empty array.
            return []
        # argpartition moves the indices of the top_k largest scores into the
        # last top_k slots (in arbitrary order) without a full sort.
        candidates = numpy.argpartition(scores, -top_k)[-top_k:]
        result = [(i, scores[i]) for i in candidates if scores[i] > threshold]
        # The sort is stable even with reverse=True, so ties keep the
        # argpartition order.
        result.sort(key=lambda pair: pair[1], reverse=True)
        return result