"""A demo which runs object detection on camera frames."""
# export TEST_DATA=/usr/lib/python3/dist-packages/edgetpu/test_data
#
# Run face detection model:
# python3 -m edgetpuvision.detect \
# --model ${TEST_DATA}/mobilenet_ssd_v2_face_quant_postprocess_edgetpu.tflite
#
# Run coco model:
# python3 -m edgetpuvision.detect \
# --model ${TEST_DATA}/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite \
# --labels ${TEST_DATA}/coco_labels.txt

import argparse
import collections
import colorsys
import itertools
import time

from edgetpu.detection.engine import DetectionEngine

from . import svg
from .apps import run_app
from .utils import load_labels, input_image_size, same_input_image_sizes, avg_fps_counter
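
# Shared SVG styles for the overlay: white caption text, a solid black
# backdrop bar for the stats line, and transparent bounding box outlines.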
CSS_STYLES = str(svg.CssStyle({'.txt': svg.Style(fill='white'),
                               '.back': svg.Style(fill='black',
                                                  stroke='black',
                                                  stroke_width='1em'),
                               '.bbox': svg.Style(fill_opacity=0.0, stroke_width='2px')}))
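
# Bounding box in relative [0, 1] coordinates, as produced by the detection
# engine; scale() maps it into pixel space, area() is used for size filtering.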
BBox = collections.namedtuple('BBox', ('x', 'y', 'w', 'h'))
BBox.area = lambda self: self.w * self.h
BBox.scale = lambda self, sx, sy: BBox(x=self.x * sx, y=self.y * sy,
                                       w=self.w * sx, h=self.h * sy)
BBox.__str__ = lambda self: 'BBox(x=%.2f y=%.2f w=%.2f h=%.2f)' % self
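
# A single detection result: numeric label id, optional human-readable label,
# confidence score, and bounding box.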
Object = collections.namedtuple('Object', ('id', 'label', 'score', 'bbox'))
Object.__str__ = lambda self: 'Object(id=%d, label=%s, score=%.2f, %s)' % self
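
# Evenly spaced hues around the HSV color wheel, converted to 8-bit RGB.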
def color(i, total):
    return tuple(int(255.0 * c) for c in colorsys.hsv_to_rgb(i / total, 1.0, 1.0))

def make_palette(keys):
    return {key: svg.rgb(color(i, len(keys))) for i, key in enumerate(keys)}
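
# Color precedence: an explicit --color overrides everything, otherwise each
# label id gets its own palette entry, and with no labels everything is white.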
def make_get_color(color, labels):
    if color:
        return lambda obj_id: color
    if labels:
        palette = make_palette(labels.keys())
        return lambda obj_id: palette[obj_id]
    return lambda obj_id: 'white'
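
# Renders one frame's worth of detections as an SVG document: a caption and
# box per object, plus an inference-stats bar in the bottom-left corner.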
def overlay(objs, get_color, inference_time, inference_rate, layout):
    x0, y0, width, height = layout.window
    defs = svg.Defs()
    defs += CSS_STYLES

    doc = svg.Svg(width=width, height=height,
                  viewBox='%s %s %s %s' % layout.window,
                  font_size='1em', font_family='sans-serif', font_weight=600)
    doc += defs

    for obj in objs:
        percent = int(100 * obj.score)
        if obj.label:
            caption = '%d%% %s' % (percent, obj.label)
        else:
            caption = '%d%%' % percent

        x, y, w, h = obj.bbox.scale(*layout.size)
        doc += svg.Text(caption, x=x, y=y - 5, _class='txt')
        # A black rectangle offset by one pixel acts as a drop shadow behind
        # the colored bounding box, keeping it visible on any background.
        doc += svg.Rect(x=x + 1, y=y + 1, width=w, height=h, rx=2, ry=2,
                        _class='bbox', style='stroke:black')
        doc += svg.Rect(x=x, y=y, width=w, height=h, rx=2, ry=2,
                        _class='bbox', style='stroke:%s' % get_color(obj.id))

    # Stats bar anchored to the bottom-left corner of the window.
    ox, oy = x0 + 20, y0 + height - 20
    doc += svg.Rect(x=0, y=0, width='22em', height='2.2em',
                    transform='translate(%s, %s) scale(1,-1)' % (ox, oy), _class='back')
    t = svg.Text(y=oy, _class='txt')
    t += svg.TSpan('Objects: %d' % len(objs),
                   x=ox)
    perf = inference_time * 1000, 1.0 / inference_time
    t += svg.TSpan('Inference time: %.2f ms (%.2f fps)' % perf,
                   x=ox, dy='-1.2em')
    doc += t

    return str(doc)
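
# Converts a DetectionEngine result into our Object/BBox representation,
# turning the (x0, y0, x1, y1) corner coordinates into x/y/width/height.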
def convert(obj, labels):
    x0, y0, x1, y1 = obj.bounding_box.flatten().tolist()
    return Object(id=obj.label_id,
                  label=labels[obj.label_id] if labels else None,
                  score=obj.score,
                  bbox=BBox(x=x0, y=y0, w=x1 - x0, h=y1 - y0))

def print_results(inference_rate, objs):
    print('\nInference (rate=%.2f fps):' % inference_rate)
    for i, obj in enumerate(objs):
        print(' %d: %s, area=%.2f' % (i, obj, obj.bbox.area()))
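
# Generator driven by the camera pipeline (see apps.run_app): the first yield
# reports the model's expected input size; every later iteration receives a
# (tensor, layout, command) tuple and yields an SVG overlay string (or None
# when the overlay is toggled off).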
def render_gen(args):
    fps_counter = avg_fps_counter(30)

    engines = [DetectionEngine(m) for m in args.model.split(',')]
    assert same_input_image_sizes(engines)
    engines = itertools.cycle(engines)
    engine = next(engines)

    labels = load_labels(args.labels) if args.labels else None
    filtered_labels = set(l.strip() for l in args.filter.split(',')) if args.filter else None
    get_color = make_get_color(args.color, labels)

    draw_overlay = True

    yield input_image_size(engine)

    output = None
    while True:
        tensor, layout, command = (yield output)

        inference_rate = next(fps_counter)
        if draw_overlay:
            start = time.monotonic()
            objs = engine.DetectWithInputTensor(tensor, threshold=args.threshold, top_k=args.top_k)
            inference_time = time.monotonic() - start
            objs = [convert(obj, labels) for obj in objs]

            if labels and filtered_labels:
                objs = [obj for obj in objs if obj.label in filtered_labels]

            objs = [obj for obj in objs if args.min_area <= obj.bbox.area() <= args.max_area]

            if args.print:
                print_results(inference_rate, objs)

            output = overlay(objs, get_color, inference_time, inference_rate, layout)
        else:
            output = None

        # Commands forwarded by the app: 'o' toggles the overlay,
        # 'n' switches to the next model in the --model list.
        if command == 'o':
            draw_overlay = not draw_overlay
        elif command == 'n':
            engine = next(engines)
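
# Command-line flags consumed by render_gen(); run_app() calls this with its
# argument parser.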
def add_render_gen_args(parser):
    parser.add_argument('--model',
                        help='.tflite model path', required=True)
    parser.add_argument('--labels',
                        help='labels file path')
    parser.add_argument('--top_k', type=int, default=50,
                        help='Max number of objects to detect')
    parser.add_argument('--threshold', type=float, default=0.1,
                        help='Detection threshold')
    parser.add_argument('--min_area', type=float, default=0.0,
                        help='Min bounding box area')
    parser.add_argument('--max_area', type=float, default=1.0,
                        help='Max bounding box area')
    parser.add_argument('--filter', default=None,
                        help='Comma-separated list of allowed labels')
    parser.add_argument('--color', default=None,
                        help='Bounding box display color')
    parser.add_argument('--print', default=False, action='store_true',
                        help='Print inference results')

def main():
    run_app(add_render_gen_args, render_gen)

if __name__ == '__main__':
    main()