blob: 44af5ff0c2acfb417e7c013ec525710b5710fa46 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluates the accuracy of imprinting based transfer learning model."""
import collections
import contextlib
import os
import unittest
from edgetpu.basic.basic_engine import BasicEngine
from edgetpu.classification.engine import ClassificationEngine
from edgetpu.learn.imprinting.engine import ImprintingEngine
from PIL import Image
from . import test_utils
@contextlib.contextmanager
def TestImage(path):
with open(path, 'rb') as f:
with Image.open(f) as image:
yield image
class ImprintingEngineEvaluationTest(unittest.TestCase):
def _GetInputTensorShape(self, model_path):
"""Gets input tensor shape of given model.
Args:
model_path: string, path of the model.
Returns:
List of integers.
"""
tmp = BasicEngine(model_path)
shape = tmp.get_input_tensor_shape()
return shape.copy()
def _TransferLearnAndEvaluate(self, extractor_path, dataset_path,
test_ratio, top_k_range):
"""Transfer-learns with given params and returns the evaluatoin result.
Args:
extractor_path: string, path of embedding extractor.
dataset_path: string, path to the directory of dataset. The images
should be put under sub-directory named by category.
test_ratio: float, the ratio of images used for test.
top_k_range: int, top_k range to be evaluated. The function will return
accuracy from top 1 to top k.
Returns:
list of float numbers.
"""
print('--------------- Parsing dataset ----------------')
print('Dataset path:', dataset_path)
# Train in fixed order, this is because many test images get same score for
# several categories. If we don't fix the order, the evaluation result will
# be slightly different.
#
# For example, a real daffodil image is classified as:
# 1.pansy(0.0625) 2.daffodil(0.0625) 3.crocus(0.0625) 4.iris(0.0625) ...
#
# The order of them depends on the training order. Hence if we train
# daffodil first, we'll get:
# 1.daffodil(0.0625) 2.pansy(0.0625) ..
#
# The top 1 accuracy increased while top 2 accuracy decreased.
train_set, test_set = test_utils.PrepareDataSetFromDirectory(
dataset_path, test_ratio, True)
print('Image list successfully parsed! Number of Categories = ',
len(train_set))
input_shape = self._GetInputTensorShape(extractor_path)
required_image_shape = (input_shape[2], input_shape[1]) # (width, height)
print('--------------- Processing training data ----------------')
print('This process may take more than 30 seconds.')
train_input = collections.OrderedDict()
for category, image_list in train_set.items():
print('Processing {} ({} images)'.format(category, len(image_list)))
train_input[category] = test_utils.PrepareImages(
image_list,
os.path.join(dataset_path, category),
required_image_shape
)
# Train
print('---------------- Start training -----------------')
imprinting_engine = ImprintingEngine(extractor_path)
labels_map = imprinting_engine.TrainAll(train_input)
print('---------------- Training finished -----------------')
imprinting_engine.SaveModel('model_for_evaluation.tflite')
# Evaluate
print('---------------- Start evaluating -----------------')
classification_engine = ClassificationEngine('model_for_evaluation.tflite')
# top[i] represents number of top (i+1) correct inference.
top_k_correct_count = [0] * top_k_range
image_num = 0
for category, image_list in test_set.items():
n = len(image_list)
print('Evaluating {} ({} images)'.format(category, n))
for image_name in image_list:
with TestImage(os.path.join(dataset_path, category, image_name)) as raw_image:
# Set threshold as a negative number to ensure we get top k candidates
# even if its score is 0.
candidates = classification_engine.ClassifyWithImage(
raw_image, threshold=-0.1, top_k=top_k_range)
for i in range(len(candidates)):
if labels_map[candidates[i][0]] == category:
top_k_correct_count[i] += 1
break
image_num += n
for i in range(1, top_k_range):
top_k_correct_count[i] += top_k_correct_count[i-1]
return [top_k_correct_count[i] / image_num for i in range(top_k_range)]
def testOxford17Flowers(self):
expected = [0.79, 0.89, 0.92, 0.93, 0.95]
top_k_range = len(expected)
ret = self._TransferLearnAndEvaluate(
test_utils.TestDataPath(
'imprinting/mobilenet_v1_1.0_224_quant_embedding_extractor.tflite'),
test_utils.TestDataPath('oxford_17flowers'),
0.25,
top_k_range
)
for i in range(top_k_range):
self.assertGreaterEqual(ret[i], expected[i])
if __name__ == '__main__':
unittest.main()