Add object detection example and draw labels on bounding boxes
Also remove the custom load_labels function and use pycoral's util instead
(only the pycoral version properly handles the non-contiguous COCO label IDs)
Change-Id: I2405ce0150559f50a3951b08a86bb4133c82da9b
diff --git a/Makefile b/Makefile
index df15053..e9d2515 100644
--- a/Makefile
+++ b/Makefile
@@ -13,6 +13,12 @@
deb:
sudo apt-get install -y python3-numpy python3-pyaudio python3-opencv
+coco_labels.txt:
+ wget "$(TEST_DATA_URL)/$@"
+
+ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite:
+ wget "$(TEST_DATA_URL)/$@"
+
imagenet_labels.txt:
wget "$(TEST_DATA_URL)/$@"
@@ -28,7 +34,9 @@
voice_commands_v0.7_edgetpu.tflite:
wget "https://github.com/google-coral/project-keyword-spotter/raw/master/models/$@"
-download: imagenet_labels.txt \
+download: coco_labels.txt \
+ ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite \
+ imagenet_labels.txt \
ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite \
mobilenet_v2_1.0_224_quant_edgetpu.tflite \
labels_gc2.raw.txt \
diff --git a/raspimon_sees_things.py b/raspimon_sees_things.py
index d0d840b..3f4d5d7 100644
--- a/raspimon_sees_things.py
+++ b/raspimon_sees_things.py
@@ -1,6 +1,7 @@
from sense_hat import SenseHat
from threading import Thread
from queue import Queue
+from pycoral.utils.dataset import read_label_file
import vision
# Initialize SenseHat instance and clear the LED matrix
@@ -8,7 +9,7 @@
sense.clear()
# Load the neural network model
-labels = vision.load_labels(vision.CLASSIFICATION_LABELS)
+labels = read_label_file(vision.CLASSIFICATION_LABELS)
classifier = vision.Classifier(vision.CLASSIFICATION_MODEL)
diff --git a/vision.py b/vision.py
index 9701bb5..cb3e661 100644
--- a/vision.py
+++ b/vision.py
@@ -30,15 +30,13 @@
}[platform.system()]
FACE_DETECTION_MODEL = 'ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite'
+OBJECT_DETECTION_MODEL = 'ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite'
+OBJECT_DETECTION_LABELS = 'coco_labels.txt'
CLASSIFICATION_MODEL = 'tf2_mobilenet_v2_1.0_224_ptq_edgetpu.tflite'
CLASSIFICATION_LABELS = 'imagenet_labels.txt'
CORAL_COLOR = (86, 104, 237)
-def load_labels(filename, encoding='utf-8'):
- with open(filename, 'r', encoding=encoding) as f:
- return {index : line.strip() for (index, line) in enumerate(f.readlines())}
-
def make_interpreter(model_file):
model_file, *device = model_file.split('@')
return tflite.Interpreter(
@@ -70,10 +68,13 @@
self.interpreter.invoke()
return classify.get_classes(self.interpreter, top_k, threshold)
-def draw_objects(frame, objs, color=CORAL_COLOR, thickness=5):
+def draw_objects(frame, objs, labels=None, color=CORAL_COLOR, thickness=5):
for obj in objs:
bbox = obj.bbox
cv2.rectangle(frame, (bbox.xmin, bbox.ymin), (bbox.xmax, bbox.ymax), color, thickness)
+ if labels:
+ cv2.putText(frame, labels.get(obj.id), (bbox.xmin + thickness, bbox.ymax - thickness),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=CORAL_COLOR, thickness=2)
def draw_classes(frame, classes, labels, color=CORAL_COLOR):
for index, score in classes:
diff --git a/vision_example.py b/vision_example.py
index 93f9572..4977799 100644
--- a/vision_example.py
+++ b/vision_example.py
@@ -13,15 +13,23 @@
# limitations under the License.
import vision
+from pycoral.utils.dataset import read_label_file
-def run_detector_example():
+def run_object_detector_example():
+ detector = vision.Detector(vision.OBJECT_DETECTION_MODEL)
+ labels = read_label_file(vision.OBJECT_DETECTION_LABELS)
+ for frame in vision.get_frames('Object Detector', size=(640, 480)):
+ objects = detector.get_objects(frame, threshold=0.2)
+ vision.draw_objects(frame, objects, labels)
+
+def run_face_detector_example():
detector = vision.Detector(vision.FACE_DETECTION_MODEL)
for frame in vision.get_frames('Face Detector', size=(640, 480)):
faces = detector.get_objects(frame)
vision.draw_objects(frame, faces)
def run_classifier_example():
- labels = vision.load_labels(vision.CLASSIFICATION_LABELS)
+ labels = read_label_file(vision.CLASSIFICATION_LABELS)
classifier = vision.Classifier(vision.CLASSIFICATION_MODEL)
for frame in vision.get_frames('Object Classifier', size=(640, 480)):
classes = classifier.get_classes(frame)
@@ -29,4 +37,4 @@
if __name__ == '__main__':
#run_classifier_example()
- run_detector_example()
+ run_object_detector_example()