blob: b5a3598cb411930121da697c4e7200cf824dc771 [file] [log] [blame] [edit]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A demo which runs object classification on camera frames.
export TEST_DATA=/usr/lib/python3/dist-packages/edgetpu/test_data
python3 -m edgetpuvision.classify \
--model ${TEST_DATA}/mobilenet_v2_1.0_224_inat_bird_quant.tflite \
--labels ${TEST_DATA}/inat_bird_labels.txt
"""
import argparse
import collections
import itertools
import numpy as np
import time
from pycoral.adapters import classify
from pycoral.utils import edgetpu
from PIL import Image
from . import svg
from . import utils
from .apps import run_app
CSS_STYLES = str(svg.CssStyle({'.back': svg.Style(fill='black',
stroke='black',
stroke_width='0.5em')}))
def size_em(length):
return '%sem' % str(0.6 * (length + 1))
def overlay(title, results, inference_time, inference_rate, layout):
x0, y0, width, height = layout.window
font_size = 0.03 * height
defs = svg.Defs()
defs += CSS_STYLES
doc = svg.Svg(width=width, height=height,
viewBox='%s %s %s %s' % layout.window,
font_size=font_size, font_family='monospace', font_weight=500)
doc += defs
ox1, ox2 = x0 + 20, x0 + width - 20
oy1, oy2 = y0 + 20 + font_size, y0 + height - 20
# Classes
lines = ['%s (%.2f)' % pair for pair in results]
for i, line in enumerate(lines):
y = oy2 - i * 1.7 * font_size
doc += svg.Rect(x=0, y=0, width=size_em(len(line)), height='1em',
transform='translate(%s, %s) scale(-1,-1)' % (ox2, y),
_class='back')
doc += svg.Text(line, text_anchor='end', x=ox2, y=y, fill='white')
# Title
if title:
doc += svg.Rect(x=0, y=0, width=size_em(len(title)), height='1em',
transform='translate(%s, %s) scale(1,-1)' % (ox1, oy1), _class='back')
doc += svg.Text(title, x=ox1, y=oy1, fill='white')
# Info
lines = [
'Inference time: %.2f ms (%.2f fps)' % (inference_time * 1000, 1.0 / inference_time)
]
for i, line in enumerate(reversed(lines)):
y = oy2 - i * 1.7 * font_size
doc += svg.Rect(x=0, y=0, width=size_em(len(line)), height='1em',
transform='translate(%s, %s) scale(1,-1)' % (ox1, y), _class='back')
doc += svg.Text(line, x=ox1, y=y, fill='white')
return str(doc)
def top_results(window, top_k):
total_scores = collections.defaultdict(lambda: 0.0)
for results in window:
for label, score in results:
total_scores[label] += score
return sorted(total_scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
def accumulator(size, top_k):
window = collections.deque(maxlen=size)
window.append((yield []))
while True:
window.append((yield top_results(window, top_k)))
def print_results(inference_rate, results):
print('\nInference (rate=%.2f fps):' % inference_rate)
print(results)
for label, score in results:
print(' %s, score=%.2f' % (label, score))
def render_gen(args):
acc = accumulator(size=args.window, top_k=args.top_k)
acc.send(None) # Initialize.
fps_counter = utils.avg_fps_counter(30)
interpreters, titles = utils.make_interpreters(args.model)
assert utils.same_input_image_sizes(interpreters)
interpreters = itertools.cycle(interpreters)
interpreter = next(interpreters)
labels = utils.load_labels(args.labels)
draw_overlay = True
yield utils.input_image_size(interpreter)
output = None
while True:
image, layout, command = (yield output)
inference_rate = next(fps_counter)
if draw_overlay:
input_shape = interpreter.get_input_details()[0]['shape']
if input_shape[3] == 1:
image = image.convert('L')
tensor = np.asarray(image).flatten()
start = time.monotonic()
edgetpu.run_inference(interpreter, tensor)
inference_time = time.monotonic() - start
classes = classify.get_classes(interpreter, top_k=args.top_k,
score_threshold=args.threshold)
results = [(labels[class_id], score) for class_id, score in classes]
results = acc.send(results)
if args.print:
print_results(inference_rate, results)
title = titles[interpreter]
output = overlay(title, results, inference_time, inference_rate, layout)
else:
output = None
if command == 'o':
draw_overlay = not draw_overlay
elif command == 'n':
interpreter = next(interpreters)
def add_render_gen_args(parser):
parser.add_argument('--model', required=True,
help='.tflite model path')
parser.add_argument('--labels', required=True,
help='label file path')
parser.add_argument('--window', type=int, default=10,
help='number of frames to accumulate inference results')
parser.add_argument('--top_k', type=int, default=3,
help='number of classes with highest score to display')
parser.add_argument('--threshold', type=float, default=0.1,
help='class score threshold')
parser.add_argument('--print', default=False, action='store_true',
help='Print inference results')
def main():
run_app(add_render_gen_args, render_gen)
if __name__ == '__main__':
main()