blob: 0e92bbcf89f60084d5291a3307bbd35003285bac [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests image classification accuracy with ImageNet validation data set.
Please download the validation image data from to edgetpu/test_data/imagenet/
"""
import unittest
from edgetpu.classification.engine import ClassificationEngine
from PIL import Image
from . import test_utils
class TestImagenetClassification(unittest.TestCase):
def _CropImage(self, image, cropping_fraction=0.75):
"""Crops an image in the center.
Args:
image: PIL image object.
cropping_fraction: Fraction of cropped image.
Returns:
Cropped image as PIL image object.
"""
init_width, init_height = image.size
new_width = round(init_width * cropping_fraction)
new_height = round(init_height * cropping_fraction)
width_offset = round((init_width - init_width * cropping_fraction) / 2.0)
height_offset = round((init_height - init_height * cropping_fraction) / 2.0)
return image.crop((width_offset, height_offset,
width_offset + new_width, height_offset + new_height))
def _TestModel(self, model_name, expected_top_1=None, expected_top_5=None):
engine = ClassificationEngine(test_utils.TestDataPath(model_name))
with open(test_utils.TestDataPath('imagenet/val.txt'), 'r') as gt_file:
gt = [line .strip().split(' ') for line in gt_file.readlines()]
top_1_count = 0
top_5_count = 0
print('Running inference for model %s...' % model_name)
for i in range(50000):
label = int(gt[i][1]) + 1
image_name = 'imagenet/ILSVRC2012_val_%s.JPEG' % str(i + 1).zfill(8)
with test_utils.TestImage(image_name) as image:
image = self._CropImage(image.convert('RGB'))
prediction = engine.ClassifyWithImage(image, threshold=0.0, top_k=5)
if prediction[0][0] == label:
top_1_count += 1
top_5_count += 1
else:
for j in range(1, len(prediction)):
if prediction[j][0] == label:
top_5_count += 1
top_1_accuracy = top_1_count / 50000.0
top_5_accuracy = top_5_count / 50000.0
print('Top 1 accuracy: %.2f%%' % (top_1_accuracy * 100))
print('Top 5 accuracy: %.2f%%' % (top_5_accuracy * 100))
if expected_top_1 is not None:
self.assertAlmostEqual(top_1_accuracy, expected_top_1, places=4)
if expected_top_5 is not None:
self.assertAlmostEqual(top_5_accuracy, expected_top_5, places=4)
def testMobilenetV1(self):
self._TestModel('mobilenet_v1_1.0_224_quant_edgetpu.tflite',
expected_top_1=0.6854, expected_top_5=0.8772)
def testMobilenetV2(self):
self._TestModel('mobilenet_v2_1.0_224_quant_edgetpu.tflite',
expected_top_1=0.6912, expected_top_5=0.8829)
if __name__ == '__main__':
unittest.main()