import cv2
import numpy as np
import tflite_runtime.interpreter as tflite

# Local helper modules expected alongside this file (as in the Coral Edge TPU
# examples); they handle pre/post-processing for the two model types.
import classify
import detect

def load_labels(filename, encoding='utf-8'):
    """Loads labels from a text file, one label per line, keyed by line index."""
    with open(filename, 'r', encoding=encoding) as f:
        return {index: line.strip() for index, line in enumerate(f)}
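
# Example: a labels file containing "cat\ndog\n" yields {0: 'cat', 1: 'dog'}.
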
def make_interpreter(model_file):
    """Creates a TFLite Interpreter using the Edge TPU delegate.

    Append '@<device>' to the model path (e.g. '@usb:0') to pin a device.
    """
    model_file, *device = model_file.split('@')
    return tflite.Interpreter(
        model_path=model_file,
        experimental_delegates=[
            tflite.load_delegate('libedgetpu.so.1',
                                 {'device': device[0]} if device else {})
        ])

def make_detector(model, threshold=0.01):
    """Returns a function that runs object detection on a single BGR frame."""
    interpreter = make_interpreter(model)
    interpreter.allocate_tensors()

    def process(frame):
        height, width, _ = frame.shape
        # detect.set_input resizes the frame to the model's input size and
        # returns the scale for mapping boxes back onto the source frame.
        scale = detect.set_input(
            interpreter, (width, height),
            lambda size: cv2.resize(frame, size, interpolation=cv2.INTER_CUBIC))
        interpreter.invoke()
        return detect.get_output(interpreter, threshold, scale)

    return process

def make_classifier(model, top_k=1, threshold=0.0):
    """Returns a function that classifies a single BGR frame."""
    interpreter = make_interpreter(model)
    interpreter.allocate_tensors()
    size = classify.input_size(interpreter)

    def process(frame):
        classify.set_input(
            interpreter, cv2.resize(frame, size, interpolation=cv2.INTER_CUBIC))
        interpreter.invoke()
        return classify.get_output(interpreter, top_k, threshold)

    return process

def Camera(title, size):
    """Yields frames from the default camera, displaying each in a window.

    Stops when a frame cannot be read or the user presses 'q'.
    """
    width, height = size
    cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # camera read failed
            yield frame
            cv2.imshow(title, frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
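

# Example usage (a minimal sketch): wires Camera to a detector and prints one
# line per detected object. The model and label file names below are
# placeholders, not assets shipped with this module; detect.get_output is
# assumed to return objects carrying .id and .score fields.
if __name__ == '__main__':
    labels = load_labels('coco_labels.txt')         # hypothetical path
    detector = make_detector('ssd_edgetpu.tflite')  # hypothetical path
    for frame in Camera('Detection demo', (640, 480)):
        for obj in detector(frame):
            print(labels.get(obj.id, obj.id), '%.2f' % obj.score)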