blob: 63cf602cedae8af05bbd45f6fe94c987681b47e5 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import unittest
from . import test_utils
from edgetpu.basic import edgetpu_utils
from edgetpu.classification.engine import BasicEngine
from edgetpu.classification.engine import ClassificationEngine
from edgetpu.detection.engine import DetectionEngine
import numpy as np
class MultipleTpusTest(unittest.TestCase):
  """Engine tests for hosts that expose at least one Edge TPU device."""

  def testCreateBasicEngineWithSpecificPath(self):
    """Creates a BasicEngine on an explicitly selected Edge TPU path.

    Uses the first path listed as unassigned and checks that the engine
    reports that same path from device_path().
    """
    edge_tpus = edgetpu_utils.ListEdgeTpuPaths(
        edgetpu_utils.EDGE_TPU_STATE_UNASSIGNED)
    self.assertGreater(len(edge_tpus), 0)
    model_path = test_utils.TestDataPath(
        'mobilenet_v1_1.0_224_quant_edgetpu.tflite')
    basic_engine = BasicEngine(model_path, edge_tpus[0])
    self.assertEqual(edge_tpus[0], basic_engine.device_path())

  def testRunClassificationAndDetectionEngine(self):
    """Runs a classification and a detection engine on two threads at once.

    Each task builds its own engine and runs repeated inferences on
    'cat.bmp'. Any exception raised inside a worker thread (including failed
    assertions) is recorded and re-raised on the calling thread after both
    workers have been joined, so that it actually fails the test.
    """

    def classification_task(num_inferences):
      """Classifies 'cat.bmp' num_inferences times and checks the top label."""
      tid = threading.get_ident()
      print('Thread: %d, %d inferences for classification task' %
            (tid, num_inferences))
      labels = test_utils.ReadLabelFile(
          test_utils.TestDataPath('imagenet_labels.txt'))
      model_name = 'mobilenet_v1_1.0_224_quant_edgetpu.tflite'
      engine = ClassificationEngine(test_utils.TestDataPath(model_name))
      print('Thread: %d, using device %s' % (tid, engine.device_path()))
      with test_utils.TestImage('cat.bmp') as img:
        for _ in range(num_inferences):
          ret = engine.ClassifyWithImage(img, top_k=1)
          self.assertEqual(len(ret), 1)
          self.assertEqual(labels[ret[0][0]], 'Egyptian cat')
      print('Thread: %d, done classification task' % tid)

    def detection_task(num_inferences):
      """Detects on 'cat.bmp' num_inferences times and checks the top hit."""
      tid = threading.get_ident()
      print('Thread: %d, %d inferences for detection task' %
            (tid, num_inferences))
      model_name = 'mobilenet_ssd_v1_coco_quant_postprocess_edgetpu.tflite'
      engine = DetectionEngine(test_utils.TestDataPath(model_name))
      print('Thread: %d, using device %s' % (tid, engine.device_path()))
      with test_utils.TestImage('cat.bmp') as img:
        for _ in range(num_inferences):
          ret = engine.DetectWithImage(img, top_k=1)
          self.assertEqual(len(ret), 1)
          self.assertEqual(ret[0].label_id, 16)  # cat
          self.assertGreater(ret[0].score, 0.7)
          self.assertGreater(
              test_utils.IOU(
                  np.array([[0.1, 0.1], [0.7, 1.0]]), ret[0].bounding_box),
              0.88)
      print('Thread: %d, done detection task' % tid)

    # An exception that escapes a Thread target is not propagated to the
    # thread that joins it, so without this bookkeeping a failed assertion in
    # a worker would leave the test passing.
    failures = []

    def run_and_record(task, *args):
      """Runs task(*args), storing any raised exception in `failures`."""
      try:
        task(*args)
      except Exception as error:  # Deliberately broad: re-raised below.
        failures.append(error)

    num_inferences = 10
    t1 = threading.Thread(
        target=run_and_record, args=(classification_task, num_inferences))
    t2 = threading.Thread(
        target=run_and_record, args=(detection_task, num_inferences))
    t1.start()
    t2.start()
    t1.join()
    t2.join()

    # Surface the first recorded worker failure on the test thread.
    if failures:
      raise failures[0]
# Script entry point: run every test case defined in this module.
if __name__ == '__main__':
  unittest.main()