Use different colors for different detected objects.
Change-Id: Iacefd5f0388b50f781aaed647795ac1396a2b4d1
diff --git a/edgetpuvision/classify.py b/edgetpuvision/classify.py
index 0a5c145..8ca2f55 100644
--- a/edgetpuvision/classify.py
+++ b/edgetpuvision/classify.py
@@ -13,11 +13,33 @@
from edgetpu.classification.engine import ClassificationEngine
-from . import overlays
+from . import svg
from .utils import load_labels, input_image_size, same_input_image_sizes, avg_fps_counter
from .gstreamer import Display, run_gen
+CSS_STYLES = str(svg.CssStyle({'.txt': svg.Style(fill='white'),
+ '.shd': svg.Style(fill='black', fill_opacity=0.6)}))
+
+def overlay(results, inference_time, inference_rate, layout):
+ x0, y0, w, h = layout.window
+
+ lines = [
+ 'Inference time: %.2f ms (%.2f fps)' % (inference_time * 1000, 1.0 / inference_time),
+ 'Inference frame rate: %.2f fps' % inference_rate
+ ]
+
+ for i, (label, score) in enumerate(results):
+ lines.append('%s (%.2f)' % (label, score))
+
+ defs = svg.Defs()
+ defs += CSS_STYLES
+
+ doc = svg.Svg(width=w, height=h, viewBox='%s %s %s %s' % layout.window, font_size='26px')
+ doc += defs
+ doc += svg.normal_text(lines, x=x0 + 10, y=y0 + 10, font_size_em=1.1)
+ return str(doc)
+
def top_results(window, top_k):
total_scores = collections.defaultdict(lambda: 0.0)
for results in window:
@@ -68,7 +90,7 @@
if args.print:
print_results(inference_rate, results)
- output = overlays.classification(results, inference_time, inference_rate, layout)
+ output = overlay(results, inference_time, inference_rate, layout)
else:
output = None
diff --git a/edgetpuvision/detect.py b/edgetpuvision/detect.py
index 6a38c3e..7b73166 100644
--- a/edgetpuvision/detect.py
+++ b/edgetpuvision/detect.py
@@ -13,15 +13,20 @@
import argparse
import collections
+import colorsys
import itertools
import time
from edgetpu.detection.engine import DetectionEngine
-from . import overlays
+from . import svg
from .utils import load_labels, input_image_size, same_input_image_sizes, avg_fps_counter
from .gstreamer import Display, run_gen
+CSS_STYLES = str(svg.CssStyle({'.txt': svg.Style(fill='white'),
+ '.shd': svg.Style(fill='black', fill_opacity=0.6),
+ 'rect': svg.Style(fill_opacity=0.1, stroke_width='2px')}))
+
BBox = collections.namedtuple('BBox', ('x', 'y', 'w', 'h'))
BBox.area = lambda self: self.w * self.h
BBox.scale = lambda self, sx, sy: BBox(x=self.x * sx, y=self.y * sy,
@@ -31,6 +36,41 @@
Object = collections.namedtuple('Object', ('id', 'label', 'score', 'bbox'))
Object.__str__ = lambda self: 'Object(id=%d, label=%s, score=%.2f, %s)' % self
+def color(i, total):
+ return tuple(int(255.0 * c) for c in colorsys.hsv_to_rgb(i / total, 1.0, 1.0))
+
+def gen_colors(keys):
+ return {key : color(i, len(keys)) for i, key in enumerate(keys)}
+
+def overlay(objs, colors, inference_time, inference_rate, layout):
+ x0, y0, w, h = layout.window
+
+ defs = svg.Defs()
+ defs += CSS_STYLES
+
+ doc = svg.Svg(width=w, height=h, viewBox='%s %s %s %s' % layout.window, font_size='26px')
+ doc += defs
+ doc += svg.normal_text((
+ 'Inference time: %.2f ms (%.2f fps)' % (inference_time * 1000, 1.0 / inference_time),
+ 'Inference frame rate: %.2f fps' % inference_rate,
+ 'Objects: %d' % len(objs),
+ ), x0 + 10, y0 + 10, font_size_em=1.1)
+
+ for obj in objs:
+ percent = int(100 * obj.score)
+ if obj.label:
+ caption = '%d%% %s' % (percent, obj.label)
+ else:
+ caption = '%d%%' % percent
+
+ x, y, w, h = obj.bbox.scale(*layout.size)
+ doc += svg.normal_text(caption, x, y - 5)
+ doc += svg.Rect(x=x, y=y, width=w, height=h, rx=2, ry=2,
+ style='stroke:%s' % svg.rgb(colors[obj.id]))
+
+ return str(doc)
+
+
def convert(obj, labels):
x0, y0, x1, y1 = obj.bounding_box.flatten().tolist()
return Object(id=obj.label_id,
@@ -53,6 +93,7 @@
labels = load_labels(args.labels) if args.labels else None
filtered_labels = set(l.strip() for l in args.filter.split(',')) if args.filter else None
+ colors = gen_colors(labels.keys())
draw_overlay = True
yield input_image_size(engine)
@@ -76,7 +117,7 @@
if args.print:
print_results(inference_rate, objs)
- output = overlays.detection(objs, inference_time, inference_rate, layout)
+ output = overlay(objs, colors, inference_time, inference_rate, layout)
else:
output = None
diff --git a/edgetpuvision/overlays.py b/edgetpuvision/overlays.py
deleted file mode 100644
index bcc8fd6..0000000
--- a/edgetpuvision/overlays.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from . import svg
-
-CSS_STYLES = str(svg.CssStyle({'.txt': svg.Style(fill='white'),
- '.shd': svg.Style(fill='black', fill_opacity=0.6),
- 'rect': svg.Style(fill='green', fill_opacity=0.3, stroke='white')}))
-
-def classification(results, inference_time, inference_rate, layout):
- x0, y0, w, h = layout.window
-
- lines = [
- 'Inference time: %.2f ms (%.2f fps)' % (inference_time * 1000, 1.0 / inference_time),
- 'Inference frame rate: %.2f fps' % inference_rate
- ]
-
- for i, (label, score) in enumerate(results):
- lines.append('%s (%.2f)' % (label, score))
-
- defs = svg.Defs()
- defs += CSS_STYLES
-
- doc = svg.Svg(width=w, height=h, viewBox='%s %s %s %s' % layout.window, font_size='26px')
- doc += defs
- doc += svg.normal_text(lines, x=x0 + 10, y=y0 + 10, font_size_em=1.1)
- return str(doc)
-
-def detection(objs, inference_time, inference_rate, layout):
- x0, y0, w, h = layout.window
-
- defs = svg.Defs()
- defs += CSS_STYLES
-
- doc = svg.Svg(width=w, height=h, viewBox='%s %s %s %s' % layout.window, font_size='26px')
- doc += defs
- doc += svg.normal_text((
- 'Inference time: %.2f ms (%.2f fps)' % (inference_time * 1000, 1.0 / inference_time),
- 'Inference frame rate: %.2f fps' % inference_rate,
- 'Objects: %d' % len(objs),
- ), x0 + 10, y0 + 10, font_size_em=1.1)
-
- for obj in objs:
- percent = int(100 * obj.score)
- if obj.label:
- caption = '%d%% %s' % (percent, obj.label)
- else:
- caption = '%d%%' % percent
-
- x, y, w, h = obj.bbox.scale(*layout.size)
- doc += svg.normal_text(caption, x, y - 5)
- doc += svg.Rect(x=x, y=y, width=w, height=h, rx=2, ry=2)
-
- return str(doc)