blob: 768e12fe6a097a1ed8c5bfabe383fd64e4685836 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmark on small data set."""
import os
import subprocess
import time
from edgetpu.basic.basic_engine import BasicEngine
from edgetpu.learn.imprinting.engine import ImprintingEngine
import test_utils
def _GetShape(extractor):
"""Gets images shape required by extractor.
Args:
extractor: string, file name of the embedding extractor.
Returns:
(width, height)
"""
basic_engine = BasicEngine(test_utils.TestDataPath('imprinting', extractor))
_, width, height, _ = basic_engine.get_input_tensor_shape()
return (width, height)
def _BenchmarkForTraining(extractor, data_set):
"""Measures training time for given extractor and data set.
Args:
extractor: string, file name of the embedding extractor.
data_set: string, name of the folder storing images. Labels file is also
named as '[data_set].csv'.
Returns:
float, training time.
"""
shape = _GetShape(extractor)
engine = ImprintingEngine(test_utils.TestDataPath('imprinting', extractor))
output_model_path = '/tmp/model_for_benchmark.tflite'
data_dir = test_utils.TestDataPath(data_set)
# The labels file is named as '[data_set].csv'.
image_list_by_category = test_utils.PrepareClassificationDataSet(
test_utils.TestDataPath(data_set + '.csv'))
start_time = time.monotonic()
for category, image_list in image_list_by_category.items():
category_dir = os.path.join(data_dir, category)
image_list_by_category[category] = test_utils.PrepareImages(
image_list, category_dir, shape)
end_time = time.monotonic()
print('Image pre-processing time: ', end_time - start_time, 's')
start_time = end_time
for category, tensors in image_list_by_category.items():
engine.Train(tensors)
engine.SaveModel(output_model_path)
training_time = time.monotonic() - start_time
print('Embedding Extractor: ', extractor)
print('Data set : ', data_set)
print('Training time : ', training_time, 's')
# Remove the model.
subprocess.call(['rm', output_model_path])
return training_time
if __name__ == '__main__':
args = test_utils.ParseArgs()
machine = test_utils.MachineInfo()
extractors, reference = test_utils.ReadReference(
'imprinting_reference_%s.csv' % machine)
extractor_num = len(extractors)
results = [('MODEL', 'DATA_SET', 'INFERENCE_TIME')]
for cnt, name in enumerate(extractors, start=1):
# 10 Categories, each has 20 images.
data = 'open_image_v4_subset'
print('---------------- ', cnt, '/', extractor_num, ' ----------------')
results.append((name, data, _BenchmarkForTraining(name, data)))
test_utils.SaveAsCsv('imprinting_benchmarks_%s_%s.csv' %
(machine, time.strftime('%Y%m%d-%H%M%S')), results)
test_utils.CheckResult(reference, results, args.enable_assertion)