| # Copyright 2019 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Classification Engine used for classification tasks.""" |
| |
| from edgetpu.basic.basic_engine import BasicEngine |
| import numpy |
| from PIL import Image |
| |
| |
class ClassificationEngine(BasicEngine):
  """Engine used for classification tasks.

  Wraps a BasicEngine whose model has exactly one output tensor and treats
  each element of that tensor as the score of the class whose id equals the
  element's index.
  """

  def __init__(self, model_path, device_path=None):
    """Creates a ClassificationEngine with the given model.

    Args:
      model_path: String, path to TF-Lite Flatbuffer file.
      device_path: String, if specified (truthy), bind engine with Edge TPU at
        device_path. Otherwise BasicEngine is constructed with model_path only.

    Raises:
      ValueError: An error occurred when the output format of model is invalid
        (the model does not have exactly one output tensor).
    """
    if device_path:
      super().__init__(model_path, device_path)
    else:
      super().__init__(model_path)
    # Only single-output models are supported: the one output tensor is
    # interpreted as a vector of per-class scores.
    output_tensors_sizes = self.get_all_output_tensors_sizes()
    if output_tensors_sizes.size != 1:
      raise ValueError(
          'Classification model should have 1 output tensor only! '
          'This model has {}.'.format(output_tensors_sizes.size))

  def ClassifyWithImage(
      self, img, threshold=0.1, top_k=3, resample=Image.NEAREST):
    """Classifies image with PIL image object.

    This interface assumes the loaded model is trained for image
    classification and takes a single image shaped [1, height, width, 3].

    Args:
      img: PIL image object. It is resized to the model's input size and its
        pixel data is flattened as-is; no color-mode conversion is done here.
        NOTE(review): callers presumably need to pass a 3-channel (e.g. RGB)
        image so the flattened size matches the input tensor.
      threshold: float, only candidates whose score is strictly greater than
        this value are returned.
      top_k: int, keep at most the top_k highest-scoring candidates among
        those that exceed the threshold. By default we keep top 3.
      resample: An optional resampling filter on image resizing. By default it
        is PIL.Image.NEAREST. Complex filter such as PIL.Image.BICUBIC will
        bring extra latency, and slightly better accuracy.

    Returns:
      List of (id, score) tuples, sorted by score in descending order.

    Raises:
      RuntimeError: when model isn't used for image classification, i.e. its
        input tensor is not shaped [1, height, width, 3].
      ValueError: when top_k is not positive (raised by
        ClassifyWithInputTensor).
    """
    input_tensor_shape = self.get_input_tensor_shape()
    if (input_tensor_shape.size != 4 or input_tensor_shape[3] != 3 or
        input_tensor_shape[0] != 1):
      raise RuntimeError(
          'Invalid input tensor shape! Expected: [1, height, width, 3]')
    _, height, width, _ = input_tensor_shape
    # PIL takes sizes as (width, height), whereas the tensor shape lists
    # height before width.
    img = img.resize((width, height), resample)
    input_tensor = numpy.asarray(img).flatten()
    return self.ClassifyWithInputTensor(input_tensor, threshold, top_k)

  def ClassifyWithInputTensor(self, input_tensor, threshold=0.0, top_k=3):
    """Classifies with raw input tensor.

    This interface requires user to process input data themselves and convert
    it to formatted input tensor. The raw output of the inference is also
    stored in self._raw_result.

    Args:
      input_tensor: numpy.array represents the input tensor.
      threshold: float, only candidates whose score is strictly greater than
        this value are returned.
      top_k: int, keep at most the top_k highest-scoring candidates among
        those that exceed the threshold. By default we keep top 3.

    Returns:
      List of (id, score) tuples, where id is the index into the model's
      output vector and score is the value at that index, sorted by score in
      descending order. Empty if no candidate exceeds the threshold.

    Raises:
      ValueError: when input param is invalid (top_k is not positive).
    """
    if top_k <= 0:
      raise ValueError('top_k must be positive!')
    _, self._raw_result = self.RunInference(input_tensor)
    scores = self._raw_result
    # top_k must be less or equal to number of possible results.
    top_k = min(top_k, len(scores))
    if top_k == 0:
      # Empty output vector: nothing to rank, and argpartition would reject
      # kth=0 on an empty array.
      return []
    # argpartition selects the indices of the top_k largest scores without
    # fully sorting the vector; those indices come back in arbitrary order.
    candidates = numpy.argpartition(scores, -top_k)[-top_k:]
    result = [(i, scores[i]) for i in candidates if scores[i] > threshold]
    result.sort(key=lambda tup: -tup[1])
    return result