Add Classifier and Detector classes; replace make_detector/make_classifier factories and rename Camera to get_frames
Change-Id: Ic8f344a982db43919244aa5238ac2f597b2115e7
diff --git a/example.py b/example.py
index 6dee6ae..6dd6f1a 100644
--- a/example.py
+++ b/example.py
@@ -1,16 +1,16 @@
import vision
def run_detector_example():
- detector = vision.make_detector('ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite')
- for frame in vision.Camera('Face Detector', size=(640, 480)):
- faces = detector(frame)
+ detector = vision.Detector('ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite')
+ for frame in vision.get_frames('Face Detector', size=(640, 480)):
+ faces = detector.get_objects(frame)
vision.draw_objects(frame, faces, color=(255, 0, 255), thickness=5)
def run_classifier_example():
labels = vision.load_labels('imagenet_labels.txt')
- classifier = vision.make_classifier('mobilenet_v2_1.0_224_quant_edgetpu.tflite')
- for frame in vision.Camera('Object Classifier', size=(640, 480)):
- classes = classifier(frame)
+ classifier = vision.Classifier('mobilenet_v2_1.0_224_quant_edgetpu.tflite')
+ for frame in vision.get_frames('Object Classifier', size=(640, 480)):
+ classes = classifier.get_classes(frame)
vision.draw_classes(frame, classes, labels, color=(255, 0, 255))
if __name__ == '__main__':
diff --git a/vision.py b/vision.py
index 957bf1d..d2262ae 100644
--- a/vision.py
+++ b/vision.py
@@ -16,26 +16,28 @@
experimental_delegates=[tflite.load_delegate('libedgetpu.so.1',
{'device': device[0]} if device else {})])
-def make_detector(model, threshold=0.01):
- interpreter = make_interpreter(model)
- interpreter.allocate_tensors()
- def process(frame):
- height, width, _ = frame.shape
- scale = detect.set_input(interpreter, (width, height),
- lambda size: cv2.resize(frame, size, fx=0, fy=0, interpolation = cv2.INTER_CUBIC))
- interpreter.invoke()
- return detect.get_output(interpreter, threshold, scale)
- return process
+class Detector:
+ def __init__(self, model):
+ self.interpreter = make_interpreter(model)
+ self.interpreter.allocate_tensors()
-def make_classifier(model, top_k=1, threshold=0.0):
- interpreter = make_interpreter(model)
- interpreter.allocate_tensors()
- size = classify.input_size(interpreter)
- def process(frame):
- classify.set_input(interpreter, cv2.resize(frame, size, fx=0, fy=0, interpolation = cv2.INTER_CUBIC))
- interpreter.invoke()
- return classify.get_output(interpreter, top_k, threshold)
- return process
+ def get_objects(self, frame, threshold=0.01):
+ height, width, _ = frame.shape
+ scale = detect.set_input(self.interpreter, (width, height),
+ lambda size: cv2.resize(frame, size, fx=0, fy=0, interpolation = cv2.INTER_CUBIC))
+ self.interpreter.invoke()
+ return detect.get_output(self.interpreter, threshold, scale)
+
+class Classifier:
+ def __init__(self, model):
+ self.interpreter = make_interpreter(model)
+ self.interpreter.allocate_tensors()
+
+ def get_classes(self, frame, top_k=1, threshold=0.0):
+ size = classify.input_size(self.interpreter)
+ classify.set_input(self.interpreter, cv2.resize(frame, size, fx=0, fy=0, interpolation = cv2.INTER_CUBIC))
+ self.interpreter.invoke()
+ return classify.get_output(self.interpreter, top_k, threshold)
def draw_objects(frame, objs, color, thickness):
for obj in objs:
@@ -47,7 +49,7 @@
label = '%s (%.2f)' % (labels.get(index, 'n/a'), score)
cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2.0, color, 2)
-def Camera(title, size):
+def get_frames(title, size):
width, height = size
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)