| import cv2 |
| import numpy as np |
| import tflite_runtime.interpreter as tflite |
| |
| import classify |
| import detect |
| |
# Edge TPU-compiled model files and the label file, resolved relative to the
# current working directory.
FACE_DETECTION_MODEL = 'ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite'
CLASSIFICATION_MODEL = 'mobilenet_v2_1.0_224_quant_edgetpu.tflite'
CLASSIFICATION_LABELS = 'imagenet_labels.txt'

# Default drawing color for boxes and text. Presumably in OpenCV's BGR order
# (R=237, G=104, B=86 reads as "coral") -- confirm if frames are ever RGB.
CORAL_COLOR = (86, 104, 237)
| |
def load_labels(filename, encoding='utf-8'):
    """Read a label file, one label per line.

    Args:
        filename: path of the text file to read.
        encoding: text encoding used to open the file.

    Returns:
        A dict mapping each line's 0-based position to that line with
        surrounding whitespace stripped. Blank lines are kept as '' entries.
    """
    with open(filename, 'r', encoding=encoding) as handle:
        # Iterating the file object yields the same lines readlines() would;
        # dict() consumes the generator before the file is closed.
        return dict(enumerate(text.strip() for text in handle))
| |
def make_interpreter(model_file):
    """Build a TFLite interpreter that delegates to the Edge TPU.

    Args:
        model_file: path to the model, optionally followed by '@' and a device
            identifier. The text before the first '@' is the model path; the
            text between the first and any second '@' is passed to the
            delegate as its 'device' option. With no '@', no options are passed.

    Returns:
        A tflite.Interpreter using the 'libedgetpu.so.1' delegate. Tensors are
        not yet allocated.
    """
    pieces = model_file.split('@')
    model_path = pieces[0]
    if len(pieces) > 1:
        delegate_options = {'device': pieces[1]}
    else:
        delegate_options = {}
    edgetpu = tflite.load_delegate('libedgetpu.so.1', delegate_options)
    return tflite.Interpreter(model_path=model_path, experimental_delegates=[edgetpu])
| |
class Detector:
    """Runs an Edge TPU detection model on individual image frames."""

    def __init__(self, model):
        """Load `model` (see make_interpreter for the path@device syntax)."""
        self.interpreter = make_interpreter(model)
        self.interpreter.allocate_tensors()

    def get_objects(self, frame, threshold=0.01):
        """Detect objects in `frame`.

        Args:
            frame: H x W x C image array (the 3-way shape unpack requires
                exactly three dimensions). NOTE(review): frames from
                cv2.VideoCapture are BGR; confirm whether the model expects RGB.
            threshold: minimum score forwarded to detect.get_output.

        Returns:
            Whatever detect.get_output returns for this frame, presumably a
            list of results exposing `.bbox` in frame coordinates (draw_objects
            reads obj.bbox) -- confirm against the detect module.
        """
        height, width, _ = frame.shape

        def resize_to(target_size):
            # Called back by detect.set_input with the size it wants.
            return cv2.resize(frame, target_size, fx=0, fy=0,
                              interpolation=cv2.INTER_CUBIC)

        interpreter = self.interpreter
        scale = detect.set_input(interpreter, (width, height), resize_to)
        interpreter.invoke()
        return detect.get_output(interpreter, threshold, scale)
| |
class Classifier:
    """Runs an Edge TPU image-classification model on individual frames."""

    def __init__(self, model):
        """Load `model` (see make_interpreter for the path@device syntax)."""
        self.interpreter = make_interpreter(model)
        self.interpreter.allocate_tensors()

    def get_classes(self, frame, top_k=1, threshold=0.0):
        """Classify `frame`.

        The frame is resized with bicubic interpolation to the model's input
        size before inference. NOTE(review): camera frames from cv2 are BGR;
        confirm whether the model expects RGB input.

        Args:
            frame: image array to classify.
            top_k: number of results forwarded to classify.get_output.
            threshold: minimum score forwarded to classify.get_output.

        Returns:
            The result of classify.get_output, presumably (label index, score)
            pairs, since draw_classes unpacks them that way -- confirm.
        """
        interpreter = self.interpreter
        model_size = classify.input_size(interpreter)
        resized = cv2.resize(frame, model_size, fx=0, fy=0,
                             interpolation=cv2.INTER_CUBIC)
        classify.set_input(interpreter, resized)
        interpreter.invoke()
        return classify.get_output(interpreter, top_k, threshold)
| |
def draw_objects(frame, objs, color=CORAL_COLOR, thickness=5):
    """Draw each object's bounding box onto `frame`, in place.

    Args:
        frame: image array to draw on; it is modified in place.
        objs: iterable of results whose `bbox` attribute has `xmin`, `ymin`,
            `xmax` and `ymax` fields (presumably what Detector.get_objects
            returns -- confirm against the detect module).
        color: rectangle color, in the frame's channel order.
        thickness: rectangle line thickness in pixels.
    """
    for obj in objs:
        bbox = obj.bbox
        # cv2.rectangle only accepts integer pixel coordinates; coerce so float
        # or NumPy-float bbox fields do not make OpenCV raise a parse error.
        top_left = (int(bbox.xmin), int(bbox.ymin))
        bottom_right = (int(bbox.xmax), int(bbox.ymax))
        cv2.rectangle(frame, top_left, bottom_right, color, thickness)
| |
def draw_classes(frame, classes, labels, color=CORAL_COLOR, origin=(10, 30),
                 line_spacing=30):
    """Write one "label (score)" line per classification result onto `frame`.

    Previously every result was drawn at the same point, so with top_k > 1 the
    labels overprinted each other. Lines now stack downward from `origin`. The
    first line is still drawn at (10, 30) by default.

    Args:
        frame: image array to draw on; it is modified in place.
        classes: iterable of (label index, score) pairs.
        labels: mapping from label index to text; unknown indices show 'n/a'.
        color: text color, in the frame's channel order.
        origin: (x, y) baseline position of the first line.
        line_spacing: vertical distance in pixels between consecutive lines.
    """
    x, y = origin
    for row, (index, score) in enumerate(classes):
        label = '%s (%.2f)' % (labels.get(index, 'n/a'), score)
        cv2.putText(frame, label, (x, y + row * line_spacing),
                    cv2.FONT_HERSHEY_PLAIN, 2.0, color, 2)
| |
def get_frames(title, size, camera_index=0):
    """Yield camera frames, then show each one after the consumer handles it.

    Each yielded frame is displayed in the window `title` only when the
    consumer resumes the generator. Anything the consumer draws on the frame
    (e.g. with draw_objects) therefore appears in the window. Pressing 'q'
    in the window ends the stream.

    The camera is released and all OpenCV windows are destroyed however the
    stream ends: the 'q' key, the consumer breaking out early, or an
    exception. Previously that cleanup was skipped in the latter two cases.

    Args:
        title: name of the display window.
        size: requested (width, height) capture resolution; the driver may
            not honor it.
        camera_index: device index passed to cv2.VideoCapture.

    Raises:
        RuntimeError: if the camera cannot be opened. Previously the generator
            spun forever in that case.
    """
    width, height = size
    cap = cv2.VideoCapture(camera_index)
    try:
        if not cap.isOpened():
            raise RuntimeError('Unable to open camera %r' % (camera_index,))
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
        while True:
            ret, frame = cap.read()
            if ret:
                yield frame
                # Same array object the consumer received, now possibly annotated.
                cv2.imshow(title, frame)
            # Poll the keyboard on every iteration so 'q' still works while
            # reads are failing.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()