Reuse inference code in streaming server.

Change-Id: Ifb682e66633dafa77de5292d8f4e179e9210820a
diff --git a/edgetpuvision/classify.py b/edgetpuvision/classify.py
index 247e27a..8b6acd0 100644
--- a/edgetpuvision/classify.py
+++ b/edgetpuvision/classify.py
@@ -74,13 +74,7 @@
         elif command == 'n':
             engine = next(engines)
 
-def main():
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--source',
-                        help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
-                        default='/dev/video0:YUY2:1280x720:30/1')
-    parser.add_argument('--downscale', type=float, default=2.0,
-                        help='Downscale factor for .mp4 file rendering')
+def add_render_gen_args(parser):
     parser.add_argument('--model', required=True,
                         help='.tflite model path')
     parser.add_argument('--labels', required=True,
@@ -91,10 +85,19 @@
                         help='number of classes with highest score to display')
     parser.add_argument('--threshold', type=float, default=0.1,
                         help='class score threshold')
-    parser.add_argument('--display', type=Display, choices=Display, default=Display.FULLSCREEN,
-                        help='Display mode')
     parser.add_argument('--print', default=False, action='store_true',
                         help='Print inference results')
+
+def main():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--source',
+                        help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
+                        default='/dev/video0:YUY2:1280x720:30/1')
+    parser.add_argument('--downscale', type=float, default=2.0,
+                        help='Downscale factor for .mp4 file rendering')
+    parser.add_argument('--display', type=Display, choices=Display, default=Display.FULLSCREEN,
+                        help='Display mode')
+    add_render_gen_args(parser)
     args = parser.parse_args()
 
     if not run_gen(render_gen(args),
diff --git a/edgetpuvision/classify_server.py b/edgetpuvision/classify_server.py
index 4c43372..e5c9f97 100644
--- a/edgetpuvision/classify_server.py
+++ b/edgetpuvision/classify_server.py
@@ -6,53 +6,11 @@
 #   --model ${TEST_DATA}/mobilenet_v2_1.0_224_inat_bird_quant.tflite \
 #   --labels ${TEST_DATA}/inat_bird_labels.txt
 
-import argparse
-import logging
-import signal
-import time
-
-from edgetpu.classification.engine import ClassificationEngine
-
-from . import overlays
-from .camera import make_camera
-from .streaming.server import StreamingServer
-from .utils import load_labels, input_image_size
-
+from .classify import add_render_gen_args, render_gen
+from .server import run
 
 def main():
-    logging.basicConfig(level=logging.INFO)
-
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--source',
-                        help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
-                        default='/dev/video0:YUY2:1280x720:30/1')
-    parser.add_argument('--model', required=True,
-                        help='.tflite model path')
-    parser.add_argument('--labels', required=True,
-                        help='label file path')
-    parser.add_argument('--top_k', type=int, default=3,
-                        help='number of classes with highest score to display')
-    parser.add_argument('--threshold', type=float, default=0.1,
-                        help='class score threshold')
-    args = parser.parse_args()
-
-    engine = ClassificationEngine(args.model)
-    labels = load_labels(args.labels)
-
-    camera = make_camera(args.source, input_image_size(engine))
-    assert camera is not None
-
-    with StreamingServer(camera) as server:
-        def on_image(tensor, inference_fps, size, window):
-            start = time.monotonic()
-            results = engine.ClassifyWithInputTensor(tensor, threshold=args.threshold, top_k=args.top_k)
-            inference_time = time.monotonic() - start
-
-            results = [(labels[i], score) for i, score in results]
-            server.send_overlay(overlays.classification(results, inference_time, inference_fps, size, window))
-
-        camera.on_image = on_image
-        signal.pause()
+    run(add_render_gen_args, render_gen)
 
 if __name__ == '__main__':
     main()
diff --git a/edgetpuvision/detect.py b/edgetpuvision/detect.py
index b87697e..a3820d4 100644
--- a/edgetpuvision/detect.py
+++ b/edgetpuvision/detect.py
@@ -70,13 +70,7 @@
         elif command == 'n':
             engine = next(engines)
 
-def main():
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--source',
-                        help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
-                        default='/dev/video0:YUY2:1280x720:30/1')
-    parser.add_argument('--downscale', type=float, default=2.0,
-                        help='Downscale factor for video/image file rendering')
+def add_render_gen_args(parser):
     parser.add_argument('--model',
                         help='.tflite model path', required=True)
     parser.add_argument('--labels',
@@ -91,10 +85,19 @@
                         help='Max bounding box area')
     parser.add_argument('--filter', default=None,
                         help='Comma-separated list of allowed labels')
-    parser.add_argument('--display', type=Display, choices=Display, default=Display.FULLSCREEN,
-                        help='Display mode')
     parser.add_argument('--print', default=False, action='store_true',
                         help='Print inference results')
+
+def main():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--source',
+                        help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
+                        default='/dev/video0:YUY2:1280x720:30/1')
+    parser.add_argument('--downscale', type=float, default=2.0,
+                        help='Downscale factor for video/image file rendering')
+    parser.add_argument('--display', type=Display, choices=Display, default=Display.FULLSCREEN,
+                        help='Display mode')
+    add_render_gen_args(parser)
     args = parser.parse_args()
 
     if not run_gen(render_gen(args),
diff --git a/edgetpuvision/detect_server.py b/edgetpuvision/detect_server.py
index 906020b..f711ba2 100644
--- a/edgetpuvision/detect_server.py
+++ b/edgetpuvision/detect_server.py
@@ -11,57 +11,11 @@
 #   --model ${TEST_DATA}/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite \
 #   --labels ${TEST_DATA}/coco_labels.txt
 
-import argparse
-import logging
-import os
-import signal
-import time
-
-from edgetpu.detection.engine import DetectionEngine
-
-from . import overlays
-from .camera import make_camera
-from .streaming.server import StreamingServer
-from .utils import load_labels, input_image_size
+from .detect import add_render_gen_args, render_gen
+from .server import run
 
 def main():
-    logging.basicConfig(level=logging.INFO)
-
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--source',
-                        help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
-                        default='/dev/video0:YUY2:1280x720:30/1')
-    parser.add_argument('--model',
-                        help='.tflite model path', required=True)
-    parser.add_argument('--labels',
-                        help='labels file path')
-    parser.add_argument('--top_k', type=int, default=50,
-                        help='Max number of objects to detect')
-    parser.add_argument('--threshold', type=float, default=0.1,
-                        help='Detection threshold')
-    parser.add_argument('--filter', default=None)
-    args = parser.parse_args()
-
-    engine = DetectionEngine(args.model)
-    labels = load_labels(args.labels) if args.labels else None
-    filtered_labels = set(l.strip() for l in args.filter.split(',')) if args.filter else None
-
-    camera = make_camera(args.source, input_image_size(engine))
-    assert camera is not None
-
-    with StreamingServer(camera) as server:
-        def on_image(tensor, inference_fps, size, window):
-            start = time.monotonic()
-            objs = engine.DetectWithInputTensor(tensor, threshold=args.threshold, top_k=args.top_k)
-            inference_time = time.monotonic() - start
-
-            if labels and filtered_labels:
-                objs = [obj for obj in objs if labels[obj.label_id] in filtered_labels]
-
-            server.send_overlay(overlays.detection(objs, labels, inference_time, inference_fps, size, window))
-
-        camera.on_image = on_image
-        signal.pause()
+    run(add_render_gen_args, render_gen)
 
 if __name__ == '__main__':
     main()
diff --git a/edgetpuvision/server.py b/edgetpuvision/server.py
new file mode 100644
index 0000000..95f2e3c
--- /dev/null
+++ b/edgetpuvision/server.py
@@ -0,0 +1,28 @@
+import argparse
+import logging
+import signal
+
+from .camera import make_camera
+from .streaming.server import StreamingServer
+
+def run(add_render_gen_args, render_gen):
+    logging.basicConfig(level=logging.INFO)
+
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--source',
+                        help='/dev/videoN:FMT:WxH:N/D or .mp4 file or image file',
+                        default='/dev/video0:YUY2:1280x720:30/1')
+    add_render_gen_args(parser)
+    args = parser.parse_args()
+
+    gen = render_gen(args)
+    camera = make_camera(args.source, next(gen))
+    assert camera is not None
+
+    with StreamingServer(camera) as server:
+        def on_image(tensor, inference_rate, size, window):
+            overlay = gen.send((tensor, size, window, inference_rate, None))
+            server.send_overlay(overlay)
+
+        camera.on_image = on_image
+        signal.pause()