Reuse inference code in streaming server.
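
Extract the shared command-line setup and camera/streaming loop into a
new edgetpuvision/server.py. classify.py and detect.py now expose
add_render_gen_args(), so classify_server.py and detect_server.py can
reuse render_gen() instead of carrying their own copies of the
inference code.
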
Change-Id: Ifb682e66633dafa77de5292d8f4e179e9210820a
diff --git a/edgetpuvision/classify.py b/edgetpuvision/classify.py
index 247e27a..8b6acd0 100644
--- a/edgetpuvision/classify.py
+++ b/edgetpuvision/classify.py
@@ -74,13 +74,8 @@
elif command == 'n':
engine = next(engines)
-def main():
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument('--source',
- help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
- default='/dev/video0:YUY2:1280x720:30/1')
- parser.add_argument('--downscale', type=float, default=2.0,
- help='Downscale factor for .mp4 file rendering')
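+# Flags consumed by render_gen(); shared with classify_server.py.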
+def add_render_gen_args(parser):
parser.add_argument('--model', required=True,
help='.tflite model path')
parser.add_argument('--labels', required=True,
@@ -91,10 +86,19 @@
help='number of classes with highest score to display')
parser.add_argument('--threshold', type=float, default=0.1,
help='class score threshold')
- parser.add_argument('--display', type=Display, choices=Display, default=Display.FULLSCREEN,
- help='Display mode')
parser.add_argument('--print', default=False, action='store_true',
help='Print inference results')
+
+def main():
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--source',
+ help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
+ default='/dev/video0:YUY2:1280x720:30/1')
+ parser.add_argument('--downscale', type=float, default=2.0,
+ help='Downscale factor for .mp4 file rendering')
+ parser.add_argument('--display', type=Display, choices=Display, default=Display.FULLSCREEN,
+ help='Display mode')
+ add_render_gen_args(parser)
args = parser.parse_args()
if not run_gen(render_gen(args),
diff --git a/edgetpuvision/classify_server.py b/edgetpuvision/classify_server.py
index 4c43372..e5c9f97 100644
--- a/edgetpuvision/classify_server.py
+++ b/edgetpuvision/classify_server.py
@@ -6,53 +6,11 @@
# --model ${TEST_DATA}/mobilenet_v2_1.0_224_inat_bird_quant.tflite \
# --labels ${TEST_DATA}/inat_bird_labels.txt
-import argparse
-import logging
-import signal
-import time
-
-from edgetpu.classification.engine import ClassificationEngine
-
-from . import overlays
-from .camera import make_camera
-from .streaming.server import StreamingServer
-from .utils import load_labels, input_image_size
-
+from .classify import add_render_gen_args, render_gen
+from .server import run
def main():
- logging.basicConfig(level=logging.INFO)
-
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument('--source',
- help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
- default='/dev/video0:YUY2:1280x720:30/1')
- parser.add_argument('--model', required=True,
- help='.tflite model path')
- parser.add_argument('--labels', required=True,
- help='label file path')
- parser.add_argument('--top_k', type=int, default=3,
- help='number of classes with highest score to display')
- parser.add_argument('--threshold', type=float, default=0.1,
- help='class score threshold')
- args = parser.parse_args()
-
- engine = ClassificationEngine(args.model)
- labels = load_labels(args.labels)
-
- camera = make_camera(args.source, input_image_size(engine))
- assert camera is not None
-
- with StreamingServer(camera) as server:
- def on_image(tensor, inference_fps, size, window):
- start = time.monotonic()
- results = engine.ClassifyWithInputTensor(tensor, threshold=args.threshold, top_k=args.top_k)
- inference_time = time.monotonic() - start
-
- results = [(labels[i], score) for i, score in results]
- server.send_overlay(overlays.classification(results, inference_time, inference_fps, size, window))
-
- camera.on_image = on_image
- signal.pause()
+ run(add_render_gen_args, render_gen)
if __name__ == '__main__':
main()
diff --git a/edgetpuvision/detect.py b/edgetpuvision/detect.py
index b87697e..a3820d4 100644
--- a/edgetpuvision/detect.py
+++ b/edgetpuvision/detect.py
@@ -70,13 +70,8 @@
elif command == 'n':
engine = next(engines)
-def main():
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument('--source',
- help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
- default='/dev/video0:YUY2:1280x720:30/1')
- parser.add_argument('--downscale', type=float, default=2.0,
- help='Downscale factor for video/image file rendering')
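+# Flags consumed by render_gen(); shared with detect_server.py.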
+def add_render_gen_args(parser):
parser.add_argument('--model',
help='.tflite model path', required=True)
parser.add_argument('--labels',
@@ -91,10 +86,19 @@
help='Max bounding box area')
parser.add_argument('--filter', default=None,
help='Comma-separated list of allowed labels')
- parser.add_argument('--display', type=Display, choices=Display, default=Display.FULLSCREEN,
- help='Display mode')
parser.add_argument('--print', default=False, action='store_true',
help='Print inference results')
+
+def main():
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--source',
+ help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
+ default='/dev/video0:YUY2:1280x720:30/1')
+ parser.add_argument('--downscale', type=float, default=2.0,
+ help='Downscale factor for video/image file rendering')
+ parser.add_argument('--display', type=Display, choices=Display, default=Display.FULLSCREEN,
+ help='Display mode')
+ add_render_gen_args(parser)
args = parser.parse_args()
if not run_gen(render_gen(args),
diff --git a/edgetpuvision/detect_server.py b/edgetpuvision/detect_server.py
index 906020b..f711ba2 100644
--- a/edgetpuvision/detect_server.py
+++ b/edgetpuvision/detect_server.py
@@ -11,57 +11,11 @@
# --model ${TEST_DATA}/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite \
# --labels ${TEST_DATA}/coco_labels.txt
-import argparse
-import logging
-import os
-import signal
-import time
-
-from edgetpu.detection.engine import DetectionEngine
-
-from . import overlays
-from .camera import make_camera
-from .streaming.server import StreamingServer
-from .utils import load_labels, input_image_size
+from .detect import add_render_gen_args, render_gen
+from .server import run
def main():
- logging.basicConfig(level=logging.INFO)
-
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument('--source',
- help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
- default='/dev/video0:YUY2:1280x720:30/1')
- parser.add_argument('--model',
- help='.tflite model path', required=True)
- parser.add_argument('--labels',
- help='labels file path')
- parser.add_argument('--top_k', type=int, default=50,
- help='Max number of objects to detect')
- parser.add_argument('--threshold', type=float, default=0.1,
- help='Detection threshold')
- parser.add_argument('--filter', default=None)
- args = parser.parse_args()
-
- engine = DetectionEngine(args.model)
- labels = load_labels(args.labels) if args.labels else None
- filtered_labels = set(l.strip() for l in args.filter.split(',')) if args.filter else None
-
- camera = make_camera(args.source, input_image_size(engine))
- assert camera is not None
-
- with StreamingServer(camera) as server:
- def on_image(tensor, inference_fps, size, window):
- start = time.monotonic()
- objs = engine.DetectWithInputTensor(tensor, threshold=args.threshold, top_k=args.top_k)
- inference_time = time.monotonic() - start
-
- if labels and filtered_labels:
- objs = [obj for obj in objs if labels[obj.label_id] in filtered_labels]
-
- server.send_overlay(overlays.detection(objs, labels, inference_time, inference_fps, size, window))
-
- camera.on_image = on_image
- signal.pause()
+ run(add_render_gen_args, render_gen)
if __name__ == '__main__':
main()
diff --git a/edgetpuvision/server.py b/edgetpuvision/server.py
new file mode 100644
index 0000000..95f2e3c
--- /dev/null
+++ b/edgetpuvision/server.py
@@ -0,0 +1,35 @@
+import argparse
+import logging
+import signal
+
+from .camera import make_camera
+from .streaming.server import StreamingServer
+
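+# Shared streaming entry point: add_render_gen_args() registers the
+# model-specific flags and render_gen() turns incoming frames into overlays.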
+def run(add_render_gen_args, render_gen):
+ logging.basicConfig(level=logging.INFO)
+
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--source',
+ help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
+ default='/dev/video0:YUY2:1280x720:30/1')
+ add_render_gen_args(parser)
+ args = parser.parse_args()
+
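+    # render_gen() is a generator: its first yield is the model's input
+    # image size, which make_camera() uses to pick the capture resolution.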
+ gen = render_gen(args)
+ camera = make_camera(args.source, next(gen))
+ assert camera is not None
+
+ with StreamingServer(camera) as server:
+ def on_image(tensor, inference_rate, size, window):
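+            # Feed the frame to the generator and stream the overlay it
+            # yields back to clients; None fills the command slot used by
+            # the interactive display entry points.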
+ overlay = gen.send((tensor, size, window, inference_rate, None))
+ server.send_overlay(overlay)
+
+ camera.on_image = on_image
+ signal.pause()