# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test utils for benchmark and manual tests."""
import argparse
import collections
import contextlib
import csv
import os
import platform
import random
import urllib.parse
import numpy as np
from PIL import Image


def ParseArgs():
  """Parses command-line flags for the benchmark and manual tests."""
  parser = argparse.ArgumentParser()
  parser.add_argument('--enable_assertion', dest='enable_assertion',
                      action='store_true', default=False)
  return parser.parse_args()
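
# Usage sketch (illustrative): a benchmark's main block would typically forward
# the flag to CheckResult().
#
#   args = ParseArgs()
#   # ... run benchmarks, build result_list ...
#   # CheckResult(reference, result_list, args.enable_assertion)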


def CheckCpuScalingGovernorStatus():
  """Checks whether CPU frequency scaling is enabled."""
  with open('/sys/devices/system/cpu/cpu0/cpufreq/scaling_governor') as f:
    status = f.read()
  if status.strip() != 'performance':
    print('************************ WARNING *****************************')
    print("CPU scaling is enabled! Please switch to 'performance' mode ")
    print('**************************************************************')


def MachineInfo():
  """Gets platform info used to choose the reference values."""
  machine = platform.machine()
  if machine == 'armv7l':
    with open('/proc/device-tree/model') as model_file:
      board_info = model_file.read()
    if 'Raspberry Pi 3 Model B Rev' in board_info:
      machine = 'rp3b'
    elif 'Raspberry Pi 3 Model B Plus Rev' in board_info:
      machine = 'rp3b+'
    else:
      machine = 'unknown'
  return machine


TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             '..', 'test_data')
REFERENCE_DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                  'reference')
BENCHMARK_RESULT_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                    'result')


def TestDataPath(path, *paths):
  """Returns absolute path for a given test file."""
  return os.path.abspath(os.path.join(TEST_DATA_DIR, path, *paths))


def ReferencePath(path, *paths):
  """Returns absolute path for a given benchmark reference file."""
  return os.path.abspath(os.path.join(REFERENCE_DATA_DIR, path, *paths))


def BenchmarkResultPath(path, *paths):
  """Returns absolute path for a given benchmark result file."""
  return os.path.abspath(os.path.join(BENCHMARK_RESULT_DIR, path, *paths))
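
# Path helper sketch (the file names below are hypothetical examples):
#
#   TestDataPath('cat.bmp')        -> '<repo>/test_data/cat.bmp'
#   ReferencePath('ref.csv')       -> '<this dir>/reference/ref.csv'
#   BenchmarkResultPath('out.csv') -> '<this dir>/result/out.csv'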


@contextlib.contextmanager
def TestImage(path, *paths):
  """Yields an opened test image."""
  with open(TestDataPath(path, *paths), 'rb') as f:
    with Image.open(f) as image:
      yield image
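
# Usage sketch (the image name is hypothetical); both the file and the PIL
# image are closed when the block exits:
#
#   with TestImage('cat.bmp') as img:
#     width, height = img.size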


def GenerateRandomInput(seed, n):
  """Generates a list of n random uint8 numbers from the given seed."""
  random.seed(a=seed)
  return [random.randint(0, 255) for _ in range(n)]


def PrepareClassificationDataSet(filename):
  """Prepares a classification data set.

  Args:
    filename: name of the csv file. It contains the filenames of images and
      the categories they belong to.

  Returns:
    Dict with format {category_name: list of filenames}.
  """
ret = collections.defaultdict(list)
with open(filename, mode='r') as csv_file:
for row in csv.DictReader(csv_file):
if not row['URL']:
continue
url = urllib.parse.urlparse(row['URL'])
filename = os.path.basename(url.path)
ret[row['Category']].append(filename)
return ret
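
# Usage sketch (the csv name is hypothetical; the 'URL' and 'Category' column
# names come from the expected csv header):
#
#   data_set = PrepareClassificationDataSet('labels.csv')
#   # data_set['cat'] -> ['image1.jpg', 'image2.jpg', ...]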


def PrepareImages(image_list, directory, shape):
  """Reads images and converts them to a numpy array with the specified shape.

  Missing files and images whose flattened size does not match the expected
  shape are skipped.

  Args:
    image_list: a list of strings storing file names.
    directory: string, path of the directory storing the input images.
    shape: a 2-element tuple (width, height), the required input image size.

  Returns:
    A 2-D numpy.ndarray in which each row is a flattened RGB image.
  """
ret = []
for filename in image_list:
file_path = os.path.join(directory, filename)
if not os.path.isfile(file_path):
continue
with Image.open(file_path) as img:
img = img.resize(shape, Image.NEAREST)
flat_img = np.asarray(img).flatten()
if flat_img.shape[0] == shape[0] * shape[1] * 3:
ret.append(flat_img)
return np.array(ret)
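
# Usage sketch (directory and shape are illustrative): each image is resized to
# 224x224 and stored as one flattened RGB row.
#
#   images = PrepareImages(data_set['cat'], '/tmp/images', (224, 224))
#   # images.shape == (num_images, 224 * 224 * 3)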


def ReadReference(file_name):
  """Reads the benchmark reference from a csv file.

  Args:
    file_name: string, name of the reference file.

  Returns:
    model_list: sorted list of model names (strings).
    reference: {environment: reference_time}, where environment is a tuple of
      strings and reference_time is a float.
  """
model_list = set()
reference = {}
with open(ReferencePath(file_name), newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=' ', quotechar='|')
    # Drop the first line (column names).
next(reader)
for row in reader:
reference[tuple(row[:-1])] = float(row[-1])
model_list.add(row[0])
return sorted(model_list), reference
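
# The reference csv is space-delimited; the first row holds the column names
# and the last column holds the reference latency in ms. Sketch (all values
# hypothetical):
#
#   MODEL MACHINE TIME
#   mobilenet_v1_1.0_224_quant.tflite rp3b 130.5
#
#   model_list, reference = ReadReference('inference_reference.csv')
#   # reference[('mobilenet_v1_1.0_224_quant.tflite', 'rp3b')] == 130.5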


def CheckResult(reference, result_list, enable_assertion):
  """Checks benchmark results and warns when latency is abnormal.

  Args:
    reference: {environment: reference_time}, where environment is a tuple of
      strings and reference_time is a float.
    result_list: a list of tuples; the first element is the header row and
      each remaining tuple ends with the measured inference time.
    enable_assertion: bool, raise an AssertionError when an unexpected latency
      is detected.
  """
  # Allow 30% variance from the reference latency.
  variance_threshold = 0.30
  print('******************** Check results *********************')
  cnt = 0
  # Drop the first line (column names).
for result in result_list[1:]:
environment = result[:-1]
inference_time = result[-1]
    if environment not in reference:
      print(' * No matching record for [%s].' % (','.join(environment)))
      cnt += 1
      continue
    reference_latency = reference[environment]
up_limit = reference_latency * (1 + variance_threshold)
down_limit = reference_latency * (1 - variance_threshold)
if inference_time > up_limit:
msg = ((' * Unexpected high latency! [%s]\n'
' Inference time: %s ms Reference time: %s ms') %
(','.join(environment), inference_time, reference_latency))
print(msg)
cnt += 1
if inference_time < down_limit:
msg = ((' * Unexpected low latency! [%s]\n'
' Inference time: %s ms Reference time: %s ms') %
(','.join(environment), inference_time, reference_latency))
print(msg)
cnt += 1
print('******************** Check finished! *******************')
if enable_assertion:
assert cnt == 0, 'Benchmark test failed!'
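
# Usage sketch (all values hypothetical): result_list mirrors the reference
# layout, with the header row first and the measured latency last in each row.
#
#   result_list = [('MODEL', 'MACHINE', 'TIME'),
#                  ('mobilenet_v1_1.0_224_quant.tflite', 'rp3b', 140.2)]
#   CheckResult(reference, result_list, enable_assertion=False)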


def SaveAsCsv(file_name, result):
  """Saves the benchmark result as a csv file.

  Args:
    file_name: string, name of the saved file.
    result: a list of tuples.
  """
os.makedirs(BENCHMARK_RESULT_DIR, exist_ok=True)
with open(BenchmarkResultPath(file_name), 'w', newline='') as csv_file:
writer = csv.writer(
csv_file, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
for line in result:
writer.writerow(line)
  print(file_name, 'saved!')
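
# Usage sketch (the file name pattern is hypothetical): results are written
# under the 'result' directory in the same space-delimited format that
# ReadReference() reads.
#
#   SaveAsCsv('inference_%s.csv' % MachineInfo(), result_list)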