Add program to collect image files for imprinting
Change-Id: I1f9b6111c1cc0d523c256b571dd358cd2f731bda
diff --git a/collect_images.py b/collect_images.py
new file mode 100644
index 0000000..acce886
--- /dev/null
+++ b/collect_images.py
@@ -0,0 +1,106 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import queue
+import os.path
+import select
+import sys
+import termios
+import time
+import threading
+import tty
+
+import vision
+
+from pycoral.utils.dataset import read_label_file
+
@contextlib.contextmanager
def nonblocking(f):
    """Yield a function that polls `f` for one character without blocking.

    If `f` is a terminal, it is switched to cbreak mode for the duration of
    the context, so keystrokes arrive immediately without waiting for Enter.
    The previous terminal attributes are restored on exit. If `f` is not a
    terminal (e.g. a pipe), terminal setup is skipped and polling still works.

    Args:
      f: A file object backed by a real file descriptor (e.g. ``sys.stdin``).

    Yields:
      A zero-argument callable. It returns the next available byte of input
      as a one-character ``str``. It returns ``None`` when no input is ready
      or the input has reached EOF.
    """
    fd = f.fileno()

    def get_char():
        readable, _, _ = select.select([fd], [], [], 0)
        if not readable:
            return None
        # Read straight from the descriptor. A buffered f.read(1) can pull
        # several pending bytes into Python's buffer, where select() no
        # longer sees them, delaying those keystrokes until the next press.
        data = os.read(fd, 1)
        if not data:  # EOF, e.g. the write end of a pipe was closed.
            return None
        # One byte per call; 'latin-1' maps every byte to exactly one char.
        # Only ASCII keys ('0'..'9', 'q', 'Q') are acted on by main() here.
        return data.decode('latin-1')

    if not os.isatty(fd):
        # No terminal attributes to change (pipe, file, ...).
        yield get_char
        return

    old_settings = termios.tcgetattr(fd)
    try:
        tty.setcbreak(fd)
        yield get_char
    finally:
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+
@contextlib.contextmanager
def worker(process):
    """Run `process` on submitted requests in a single background thread.

    Requests are handled one at a time, in submission order. If `process`
    raises, the traceback is printed to stderr and the worker continues with
    the next request, so one failure does not silently drop all later work.
    On leaving the context, every request submitted so far is processed
    before the background thread is joined.

    Args:
      process: Callable invoked as ``process(request)`` for each request.

    Yields:
      ``submit(request)``, which enqueues `request` and returns immediately.
    """
    import traceback

    requests = queue.Queue()
    # Private sentinel, so any value (including None) can be submitted.
    stop = object()

    def run():
        while True:
            request = requests.get()
            if request is stop:
                break
            try:
                process(request)
            except Exception:
                # An uncaught exception would end this thread, and every
                # request submitted afterwards would never be processed.
                traceback.print_exc()
            finally:
                requests.task_done()

    def submit(request):
        requests.put(request)

    thread = threading.Thread(target=run)
    thread.start()
    try:
        yield submit
    finally:
        requests.put(stop)
        thread.join()
+
def save_frame(request):
    """Worker callback: write a ``(filename, frame)`` pair and report the path."""
    path, image = request
    vision.save_frame(path, image)
    print('Saved: %s' % (path,))
+
def main():
    """Save camera frames into per-label directories when digit keys are pressed.

    Keys are accepted from two sources: the OpenCV preview window (through
    the ``handle_key`` hook of ``vision.get_frames``) and this terminal.
    Pressing '0'..'9' saves the current frame as
    ``<output_dir>/<class_dir>/<milliseconds>.png``. ``class_dir`` is the
    label for that digit from ``--labels``, or the digit itself if the
    labels file has no entry for it. Pressing 'q' or 'Q' stops capturing.
    Frames are written by a background worker so disk I/O does not stall
    the capture loop.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--labels', '-l', type=str, default=None,
                        help='Labels file')
    parser.add_argument('--output_dir', '-d', type=str, default='capture',
                        help='Output directory')
    args = parser.parse_args()

    print("Press buttons '0' .. '9' to save images from the camera.")

    labels = {}
    if args.labels:
        labels = read_label_file(args.labels)
        for key in sorted(labels):
            print(key, '-', labels[key])

    with nonblocking(sys.stdin) as get_char, worker(save_frame) as submit:
        def handle_key(key, frame):
            """Handle one key code; return False to stop processing frames."""
            if key in (ord('q'), ord('Q')):
                return False  # Stop processing frames.
            if ord('0') <= key <= ord('9'):
                label_id = key - ord('0')
                class_dir = labels.get(label_id, str(label_id))
                name = str(round(time.time() * 1000)) + '.png'
                filename = os.path.join(args.output_dir, class_dir, name)
                # Hand the worker thread its own copy, presumably so later
                # use of `frame` by the capture loop cannot change what gets
                # written.
                submit((filename, frame.copy()))
            return True  # Keep processing frames.

        # Key events from the GUI window arrive through `handle_key` inside
        # get_frames; key events from the console are polled below.
        for frame in vision.get_frames(handle_key=handle_key):
            ch = get_char()
            if ch is not None and not handle_key(ord(ch), frame):
                break


if __name__ == '__main__':
    main()
diff --git a/vision.py b/vision.py
index cb3e661..2d1700f 100644
--- a/vision.py
+++ b/vision.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os.path
import platform
import sys
@@ -81,9 +82,15 @@
label = '%s (%.2f)' % (labels.get(index, 'n/a'), score)
cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2.0, color, 2)
-def get_frames(title='Raspimon camera', size=(640, 480)):
+def get_frames(title='Raspimon camera', size=(640, 480), handle_key=None):
width, height = size
+ if not handle_key:
+ def handle_key(key, frame):
+ if key == ord('q') or key == ord('Q'):
+ return False
+ return True
+
attempts = 5
while True:
cap = cv2.VideoCapture(0)
@@ -107,7 +114,14 @@
if success:
yield frame
cv2.imshow(title, frame)
- if cv2.waitKey(1) & 0xFF == ord('q'):
+
+ key = cv2.waitKey(1)
+ if key != -1 and not handle_key(key, frame):
break
+
cap.release()
cv2.destroyAllWindows()
+
def save_frame(filename, frame):
    """Write `frame` to `filename`, creating parent directories as needed.

    Returns:
      The boolean result of ``cv2.imwrite`` (False when the image could not
      be written). Callers that ignore the return value are unaffected.
    """
    directory = os.path.dirname(filename)
    # A bare filename has no directory part; os.makedirs('') would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    return cv2.imwrite(filename, frame)