Add program to collect image files for imprinting

Change-Id: I1f9b6111c1cc0d523c256b571dd358cd2f731bda
diff --git a/collect_images.py b/collect_images.py
new file mode 100644
index 0000000..acce886
--- /dev/null
+++ b/collect_images.py
@@ -0,0 +1,106 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import queue
+import os.path
+import select
+import sys
+import termios
+import time
+import threading
+import tty
+
+import vision
+
+from pycoral.utils.dataset import read_label_file
+
+@contextlib.contextmanager
+def nonblocking(f):
+  """Yield a function that polls `f` for a single character without blocking.
+
+  If `f` is a terminal, it is switched to cbreak mode for the duration of the
+  context, so keys are delivered immediately instead of after Enter. Its
+  previous settings are restored on exit. Non-terminal input such as a pipe is
+  polled as-is instead of failing in termios.tcgetattr.
+
+  The yielded get_char() returns one character, or None when no input is
+  currently available or the input is at EOF. Non-ASCII keys arrive as one
+  character per byte (latin-1 decoding).
+  """
+  fd = f.fileno()
+
+  def get_char():
+    readable, _, _ = select.select([fd], [], [], 0)
+    if not readable:
+      return None
+    # Read one byte straight from the descriptor. Reading through the buffered
+    # file object could pull several pending keys into its buffer, where
+    # select() can no longer see them.
+    data = os.read(fd, 1)
+    # latin-1 maps every byte to exactly one character; EOF (b'') -> None.
+    return data.decode('latin-1') if data else None
+
+  if not os.isatty(fd):
+    # No terminal settings to change or restore.
+    yield get_char
+    return
+
+  old_settings = termios.tcgetattr(fd)
+  try:
+    tty.setcbreak(fd)
+    yield get_char
+  finally:
+    # Restore the terminal even if the body of the `with` raised.
+    termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+
+@contextlib.contextmanager
+def worker(process):
+  """Run `process` on submitted requests in a single background thread.
+
+  Yields a `submit(request)` callable that enqueues a request. Requests are
+  handled one at a time, in submission order. On leaving the context, the
+  requests already submitted are processed before the thread is joined.
+
+  If `process` raises for a request, the traceback is printed and the worker
+  moves on to the next request. One failed request therefore does not
+  silently disable every later one.
+  """
+  import traceback
+
+  requests = queue.Queue()
+
+  def run():
+    while True:
+      request = requests.get()
+      try:
+        if request is None:
+          break  # Sentinel queued on context exit: stop the thread.
+        process(request)
+      except Exception:  # Report, then keep serving later requests.
+        traceback.print_exc()
+      finally:
+        requests.task_done()
+
+  def submit(request):
+    requests.put(request)
+
+  thread = threading.Thread(target=run)
+  thread.start()
+  try:
+    yield submit
+  finally:
+    # The sentinel goes in after every submitted request (FIFO), so pending
+    # requests are drained before the thread exits.
+    requests.put(None)
+    thread.join()
+
+def save_frame(request):
+  """Save one queued capture request and report the outcome on stdout.
+
+  Args:
+    request: A (filename, frame) tuple, as submitted by main().
+  """
+  filename, frame = request
+  # Only an explicit False result counts as failure (e.g. a failed
+  # cv2.imwrite, if vision.save_frame passes that flag through). Any other
+  # result, including None, counts as success.
+  if vision.save_frame(filename, frame) is False:
+    print('Failed to save: %s' % filename)
+  else:
+    print('Saved: %s' % filename)
+
+def main():
+  """Show camera frames and save the current frame when a digit key is pressed.
+
+  Keys are read both from the OpenCV window and from the console. Pressing
+  digit N queues the frame to be written on a background thread as
+  <output_dir>/<label for N, or "N">/<epoch milliseconds>.png.
+  Pressing 'q' or 'Q' stops capturing.
+  """
+  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument('--labels', '-l', type=str, default=None, help='Labels file')
+  parser.add_argument('--output_dir', '-d', type=str, default='capture', help='Output directory')
+  args = parser.parse_args()
+
+  print("Press buttons '0' .. '9' to save images from the camera, 'q' to quit.")
+
+  labels = {}
+  if args.labels:
+    labels = read_label_file(args.labels)
+    for key in sorted(labels):
+      print(key, '-', labels[key])
+
+  with nonblocking(sys.stdin) as get_char, worker(save_frame) as submit:
+    def handle_key(key, frame):
+      """Handle one key code; return False to stop processing frames."""
+      if key == ord('q') or key == ord('Q'):
+        return False  # Stop processing frames.
+      if ord('0') <= key <= ord('9'):
+        if frame is None:
+          # Nothing to save (presumably the camera read for this frame failed).
+          return True
+        label_id = key - ord('0')
+        # Ids with no name in the labels file use the digit itself.
+        class_dir = labels.get(label_id, str(label_id))
+        name = str(round(time.time() * 1000)) + '.png'
+        filename = os.path.join(args.output_dir, class_dir, name)
+        # Give the worker thread its own copy, independent of whatever the
+        # capture loop does with `frame` afterwards.
+        submit((filename, frame.copy()))
+      return True  # Keep processing frames.
+
+    def handle_gui_key(key, frame):
+      # Key events from the OpenCV window. Keep only the low byte, as in the
+      # usual `cv2.waitKey(...) & 0xFF` idiom. Some OpenCV builds report key
+      # codes with extra high bits set, which would never match '0'..'9' or 'q'.
+      return handle_key(key & 0xFF, frame)
+
+    for frame in vision.get_frames(handle_key=handle_gui_key):
+      # Key events from the console. An empty string (EOF) is skipped, since
+      # ord('') would raise.
+      ch = get_char()
+      if ch and not handle_key(ord(ch), frame):
+        break
+
+# Run only when executed as a script, not when imported as a module.
+if __name__ == '__main__':
+  main()
diff --git a/vision.py b/vision.py
index cb3e661..2d1700f 100644
--- a/vision.py
+++ b/vision.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os.path
 import platform
 import sys
 
@@ -81,9 +82,15 @@
     label = '%s (%.2f)' % (labels.get(index, 'n/a'), score)
     cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2.0, color, 2)
 
-def get_frames(title='Raspimon camera', size=(640, 480)):
+def get_frames(title='Raspimon camera', size=(640, 480), handle_key=None):
   width, height = size
 
+  if not handle_key:
+    def handle_key(key, frame):
+      if key == ord('q') or key == ord('Q'):
+        return False
+      return True
+
   attempts = 5
   while True:
     cap = cv2.VideoCapture(0)
@@ -107,7 +114,14 @@
     if success:
       yield frame
       cv2.imshow(title, frame)
-    if cv2.waitKey(1) & 0xFF == ord('q'):
+
+    key = cv2.waitKey(1)
+    if key != -1 and not handle_key(key, frame):
       break
+
   cap.release()
   cv2.destroyAllWindows()
+
+def save_frame(filename, frame):
+  """Write `frame` to `filename` with cv2.imwrite, creating parent directories.
+
+  Args:
+    filename: Destination path. Its extension selects the image format.
+    frame: The image to write, as accepted by cv2.imwrite.
+
+  Returns:
+    The success flag returned by cv2.imwrite.
+  """
+  directory = os.path.dirname(filename)
+  if directory:  # A bare filename has no directory part; os.makedirs('') raises.
+    os.makedirs(directory, exist_ok=True)
+  return cv2.imwrite(filename, frame)