Use pycoral adapters instead of local classify/detect modules
Change-Id: I38a6112052a4b206824faf0868c945b73d1e224f
diff --git a/Makefile b/Makefile
index c368b19..df15053 100644
--- a/Makefile
+++ b/Makefile
@@ -2,13 +2,16 @@
TEST_DATA_URL=https://github.com/google-coral/edgetpu/raw/master/test_data
VENV_NAME=.env
-.PHONY: venv download clean
+.PHONY: venv deb download clean
venv:
rm -rf $(VENV_NAME)
python3 -m venv --system-site-packages $(VENV_NAME)
$(SHELL) -c "source $(VENV_NAME)/bin/activate && pip3 install --upgrade pip"
- $(SHELL) -c "source $(VENV_NAME)/bin/activate && pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite-runtime"
+ $(SHELL) -c "source $(VENV_NAME)/bin/activate && pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite-runtime pycoral"
+
+deb:
+ sudo apt-get install -y python3-numpy python3-pyaudio python3-opencv
imagenet_labels.txt:
wget "$(TEST_DATA_URL)/$@"
diff --git a/classify.py b/classify.py
deleted file mode 100644
index abcf62d..0000000
--- a/classify.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Lint as: python3
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Functions to work with classification models."""
-
-import collections
-import operator
-import numpy as np
-
-Class = collections.namedtuple('Class', ['id', 'score'])
-
-
-def input_details(interpreter, key):
- """Returns input details by specified key."""
- return interpreter.get_input_details()[0][key]
-
-
-def input_size(interpreter):
- """Returns input image size as (width, height) tuple."""
- _, height, width, _ = input_details(interpreter, 'shape')
- return width, height
-
-
-def input_tensor(interpreter):
- """Returns input tensor view as numpy array of shape (height, width, 3)."""
- tensor_index = input_details(interpreter, 'index')
- return interpreter.tensor(tensor_index)()[0]
-
-
-def output_tensor(interpreter, dequantize=True):
- """Returns output tensor of classification model.
-
- Integer output tensor is dequantized by default.
-
- Args:
- interpreter: tflite.Interpreter;
- dequantize: bool; whether to dequantize integer output tensor.
-
- Returns:
- Output tensor as numpy array.
- """
- output_details = interpreter.get_output_details()[0]
- output_data = np.squeeze(interpreter.tensor(output_details['index'])())
-
- if dequantize and np.issubdtype(output_details['dtype'], np.integer):
- scale, zero_point = output_details['quantization']
- return scale * (output_data - zero_point)
-
- return output_data
-
-
-def set_input(interpreter, data):
- """Copies data to input tensor."""
- input_tensor(interpreter)[:, :] = data
-
-
-def get_output(interpreter, top_k=1, score_threshold=0.0):
- """Returns no more than top_k classes with score >= score_threshold."""
- scores = output_tensor(interpreter)
- classes = [
- Class(i, scores[i])
- for i in np.argpartition(scores, -top_k)[-top_k:]
- if scores[i] >= score_threshold
- ]
- return sorted(classes, key=operator.itemgetter(1), reverse=True)
diff --git a/detect.py b/detect.py
deleted file mode 100644
index 7f35b6b..0000000
--- a/detect.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# Lint as: python3
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Functions to work with detection models."""
-
-import collections
-import numpy as np
-
-Object = collections.namedtuple('Object', ['id', 'score', 'bbox'])
-
-
-class BBox(collections.namedtuple('BBox', ['xmin', 'ymin', 'xmax', 'ymax'])):
- """Bounding box.
-
- Represents a rectangle which sides are either vertical or horizontal, parallel
- to the x or y axis.
- """
- __slots__ = ()
-
- @property
- def width(self):
- """Returns bounding box width."""
- return self.xmax - self.xmin
-
- @property
- def height(self):
- """Returns bounding box height."""
- return self.ymax - self.ymin
-
- @property
- def area(self):
- """Returns bound box area."""
- return self.width * self.height
-
- @property
- def valid(self):
- """Returns whether bounding box is valid or not.
-
- Valid bounding box has xmin <= xmax and ymin <= ymax which is equivalent to
- width >= 0 and height >= 0.
- """
- return self.width >= 0 and self.height >= 0
-
- def scale(self, sx, sy):
- """Returns scaled bounding box."""
- return BBox(xmin=sx * self.xmin,
- ymin=sy * self.ymin,
- xmax=sx * self.xmax,
- ymax=sy * self.ymax)
-
- def translate(self, dx, dy):
- """Returns translated bounding box."""
- return BBox(xmin=dx + self.xmin,
- ymin=dy + self.ymin,
- xmax=dx + self.xmax,
- ymax=dy + self.ymax)
-
- def map(self, f):
- """Returns bounding box modified by applying f for each coordinate."""
- return BBox(xmin=f(self.xmin),
- ymin=f(self.ymin),
- xmax=f(self.xmax),
- ymax=f(self.ymax))
-
- @staticmethod
- def intersect(a, b):
- """Returns the intersection of two bounding boxes (may be invalid)."""
- return BBox(xmin=max(a.xmin, b.xmin),
- ymin=max(a.ymin, b.ymin),
- xmax=min(a.xmax, b.xmax),
- ymax=min(a.ymax, b.ymax))
-
- @staticmethod
- def union(a, b):
- """Returns the union of two bounding boxes (always valid)."""
- return BBox(xmin=min(a.xmin, b.xmin),
- ymin=min(a.ymin, b.ymin),
- xmax=max(a.xmax, b.xmax),
- ymax=max(a.ymax, b.ymax))
-
- @staticmethod
- def iou(a, b):
- """Returns intersection-over-union value."""
- intersection = BBox.intersect(a, b)
- if not intersection.valid:
- return 0.0
- area = intersection.area
- return area / (a.area + b.area - area)
-
-
-def input_size(interpreter):
- """Returns input image size as (width, height) tuple."""
- _, height, width, _ = interpreter.get_input_details()[0]['shape']
- return width, height
-
-
-def input_tensor(interpreter):
- """Returns input tensor view as numpy array of shape (height, width, 3)."""
- tensor_index = interpreter.get_input_details()[0]['index']
- return interpreter.tensor(tensor_index)()[0]
-
-
-def set_input(interpreter, size, resize):
- """Copies a resized and properly zero-padded image to the input tensor.
-
- Args:
- interpreter: Interpreter object.
- size: original image size as (width, height) tuple.
- resize: a function that takes a (width, height) tuple, and returns an RGB
- image resized to those dimensions.
- Returns:
- Actual resize ratio, which should be passed to `get_output` function.
- """
- width, height = input_size(interpreter)
- w, h = size
- scale = min(width / w, height / h)
- w, h = int(w * scale), int(h * scale)
- tensor = input_tensor(interpreter)
- tensor.fill(0) # padding
- _, _, channel = tensor.shape
- tensor[:h, :w] = np.reshape(resize((w, h)), (h, w, channel))
- return scale, scale
-
-
-def output_tensor(interpreter, i):
- """Returns output tensor view."""
- tensor = interpreter.tensor(interpreter.get_output_details()[i]['index'])()
- return np.squeeze(tensor)
-
-
-def get_output(interpreter, score_threshold, image_scale=(1.0, 1.0)):
- """Returns list of detected objects."""
- boxes = output_tensor(interpreter, 0)
- class_ids = output_tensor(interpreter, 1)
- scores = output_tensor(interpreter, 2)
- count = int(output_tensor(interpreter, 3))
-
- width, height = input_size(interpreter)
- image_scale_x, image_scale_y = image_scale
- sx, sy = width / image_scale_x, height / image_scale_y
-
- def make(i):
- ymin, xmin, ymax, xmax = boxes[i]
- return Object(
- id=int(class_ids[i]),
- score=float(scores[i]),
- bbox=BBox(xmin=xmin,
- ymin=ymin,
- xmax=xmax,
- ymax=ymax).scale(sx, sy).map(int))
-
- return [make(i) for i in range(count) if scores[i] >= score_threshold]
diff --git a/vision.py b/vision.py
index 0cdc34b..8955783 100644
--- a/vision.py
+++ b/vision.py
@@ -2,8 +2,9 @@
import numpy as np
import tflite_runtime.interpreter as tflite
-import classify
-import detect
+from pycoral.adapters import common
+from pycoral.adapters import classify
+from pycoral.adapters import detect
FACE_DETECTION_MODEL = 'ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite'
CLASSIFICATION_MODEL = 'mobilenet_v2_1.0_224_quant_edgetpu.tflite'
@@ -29,10 +30,11 @@
def get_objects(self, frame, threshold=0.01):
height, width, _ = frame.shape
- scale = detect.set_input(self.interpreter, (width, height),
- lambda size: cv2.resize(frame, size, fx=0, fy=0, interpolation = cv2.INTER_CUBIC))
+ _, scale = common.set_resized_input(self.interpreter, (width, height),
+ lambda size: cv2.resize(frame, size, fx=0, fy=0,
+ interpolation=cv2.INTER_CUBIC))
self.interpreter.invoke()
- return detect.get_output(self.interpreter, threshold, scale)
+ return detect.get_objects(self.interpreter, threshold, scale)
class Classifier:
def __init__(self, model):
@@ -40,10 +42,10 @@
self.interpreter.allocate_tensors()
def get_classes(self, frame, top_k=1, threshold=0.0):
- size = classify.input_size(self.interpreter)
- classify.set_input(self.interpreter, cv2.resize(frame, size, fx=0, fy=0, interpolation = cv2.INTER_CUBIC))
+ size = common.input_size(self.interpreter)
+ common.set_input(self.interpreter, cv2.resize(frame, size, fx=0, fy=0, interpolation = cv2.INTER_CUBIC))
self.interpreter.invoke()
- return classify.get_output(self.interpreter, top_k, threshold)
+ return classify.get_classes(self.interpreter, top_k, threshold)
def draw_objects(frame, objs, color=CORAL_COLOR, thickness=5):
for obj in objs:
diff --git a/vision_example.py b/vision_example.py
index a0a691a..82786c4 100644
--- a/vision_example.py
+++ b/vision_example.py
@@ -14,5 +14,5 @@
vision.draw_classes(frame, classes, labels)
if __name__ == '__main__':
- run_classifier_example()
- #run_detector_example()
+ #run_classifier_example()
+ run_detector_example()