blob: c18b1606eca4e34f53ebcd9c07464e46e1413e39 [file] [log] [blame]
import cv2
import numpy as np
import tflite_runtime.interpreter as tflite
import classify
import detect
# File names of the model and label files used by this demo. None of them is
# referenced by the code visible in this part of the file; presumably callers
# pass them to Detector / Classifier / load_labels -- TODO confirm.
FACE_DETECTION_MODEL = 'ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite'
CLASSIFICATION_MODEL = 'mobilenet_v2_1.0_224_quant_edgetpu.tflite'
CLASSIFICATION_LABELS = 'imagenet_labels.txt'
# Default drawing color for draw_objects() and draw_classes(). The tuple is
# handed unchanged to cv2.rectangle / cv2.putText, so the channel order is
# presumably OpenCV's BGR -- verify against the intended color.
CORAL_COLOR = (86, 104, 237)
def load_labels(filename, encoding='utf-8'):
    """Read a label file into a ``{line_index: label}`` dictionary.

    Each line of the file becomes one entry, keyed by its zero-based line
    number, with leading and trailing whitespace stripped from the text.

    Args:
        filename: Path of the text file to read.
        encoding: Text encoding used to decode the file.

    Returns:
        A dict mapping line index (int) to the stripped line (str).
    """
    with open(filename, 'r', encoding=encoding) as label_file:
        stripped_lines = (raw_line.strip() for raw_line in label_file)
        return dict(enumerate(stripped_lines))
def make_interpreter(model_file):
    """Create a TFLite interpreter that uses the ``libedgetpu.so.1`` delegate.

    Args:
        model_file: Model path, optionally followed by ``@<device>``. When a
            device part is present, the text after the first ``@`` (up to any
            further ``@``) is passed to the delegate as its ``'device'``
            option; otherwise the delegate is loaded with no options.

    Returns:
        A ``tflite.Interpreter`` for the model. Tensors are not allocated
        here; the caller must do that.
    """
    model_path, *device_parts = model_file.split('@')
    if device_parts:
        delegate_options = {'device': device_parts[0]}
    else:
        delegate_options = {}
    edgetpu_delegate = tflite.load_delegate('libedgetpu.so.1', delegate_options)
    return tflite.Interpreter(
        model_path=model_path,
        experimental_delegates=[edgetpu_delegate])
class Detector:
    """Runs a detection model on frames via an interpreter from make_interpreter()."""

    def __init__(self, model):
        """Build the interpreter for ``model`` and allocate its tensors.

        Args:
            model: Model specification accepted by make_interpreter().
        """
        self.interpreter = interpreter = make_interpreter(model)
        interpreter.allocate_tensors()

    def get_objects(self, frame, threshold=0.01):
        """Run detection on one frame and return the detected objects.

        Args:
            frame: Image array whose ``shape`` unpacks as
                (height, width, channels).
            threshold: Value forwarded to ``detect.get_output``; presumably a
                minimum score -- confirm against the ``detect`` module.

        Returns:
            Whatever ``detect.get_output`` returns for the interpreter, the
            threshold and the scale reported by ``detect.set_input``.
        """
        rows, cols, _channels = frame.shape

        def resize_frame(size):
            # Called by detect.set_input with the size it wants the frame in.
            return cv2.resize(frame, size, fx=0, fy=0,
                              interpolation=cv2.INTER_CUBIC)

        interpreter = self.interpreter
        scale = detect.set_input(interpreter, (cols, rows), resize_frame)
        interpreter.invoke()
        return detect.get_output(interpreter, threshold, scale)
class Classifier:
    """Runs a classification model on frames via an interpreter from make_interpreter()."""

    def __init__(self, model):
        """Build the interpreter for ``model`` and allocate its tensors.

        Args:
            model: Model specification accepted by make_interpreter().
        """
        self.interpreter = interpreter = make_interpreter(model)
        interpreter.allocate_tensors()

    def get_classes(self, frame, top_k=1, threshold=0.0):
        """Classify one frame.

        The frame is resized (bicubic interpolation) to the size reported by
        ``classify.input_size`` before being fed to the interpreter.

        Args:
            frame: Image array accepted by ``cv2.resize``.
            top_k: Forwarded to ``classify.get_output``; presumably the
                maximum number of classes to return -- confirm.
            threshold: Forwarded to ``classify.get_output``; presumably a
                minimum score -- confirm.

        Returns:
            Whatever ``classify.get_output`` returns.
        """
        interpreter = self.interpreter
        target_size = classify.input_size(interpreter)
        resized = cv2.resize(frame, target_size, fx=0, fy=0,
                             interpolation=cv2.INTER_CUBIC)
        classify.set_input(interpreter, resized)
        interpreter.invoke()
        return classify.get_output(interpreter, top_k, threshold)
def draw_objects(frame, objs, color=CORAL_COLOR, thickness=5):
    """Draw a rectangle on ``frame`` for every object's bounding box.

    The frame is modified in place; nothing is returned.

    Args:
        frame: Image to draw on.
        objs: Iterable of objects exposing a ``bbox`` attribute with
            ``xmin``, ``ymin``, ``xmax`` and ``ymax``.
        color: Rectangle color passed to ``cv2.rectangle``.
        thickness: Line thickness passed to ``cv2.rectangle``.
    """
    for detection in objs:
        box = detection.bbox
        top_left = (box.xmin, box.ymin)
        bottom_right = (box.xmax, box.ymax)
        cv2.rectangle(frame, top_left, bottom_right, color, thickness)
def draw_classes(frame, classes, labels, color=CORAL_COLOR):
    """Write one ``"<label> (<score>)"`` line on ``frame`` per class.

    The frame is modified in place; nothing is returned. The first line is
    placed at (10, 30); each further line is placed 30 pixels lower so that
    several classes do not overwrite each other.

    Args:
        frame: Image to draw on.
        classes: Iterable of ``(index, score)`` pairs.
        labels: Mapping from class index to label text; indices missing from
            it are shown as ``'n/a'``.
        color: Text color passed to ``cv2.putText``.
    """
    left, first_baseline, line_step = 10, 30, 30
    for row, (index, score) in enumerate(classes):
        label = '%s (%.2f)' % (labels.get(index, 'n/a'), score)
        # Move each successive label down one line so they stay readable
        # instead of being drawn on top of one another.
        origin = (left, first_baseline + row * line_step)
        cv2.putText(frame, label, origin, cv2.FONT_HERSHEY_PLAIN, 2.0, color, 2)
def get_frames(title, size):
    """Yield frames from camera 0 and display each one after it is consumed.

    Each frame is yielded first and shown in the window afterwards, so
    anything the consumer draws on the frame in place appears in the
    display. Iteration ends when the 'q' key is pressed.

    The capture device and all OpenCV windows are released when the
    generator finishes for any reason, including the consumer abandoning
    the loop early or an exception being raised.

    Args:
        title: Window title used for ``cv2.imshow``.
        size: ``(width, height)`` requested from the capture device.

    Yields:
        Frames returned by ``cap.read()``; failed reads are skipped.
    """
    width, height = size
    cap = cv2.VideoCapture(0)
    try:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
        while True:
            ret, frame = cap.read()
            if ret:
                yield frame
                # Shown only after the consumer resumes the generator, so
                # in-place annotations are visible.
                cv2.imshow(title, frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Runs on normal exit, on generator close() and on errors, so the
        # camera is never left open.
        cap.release()
        cv2.destroyAllWindows()