Use different colors for different detected objects.

Change-Id: Iacefd5f0388b50f781aaed647795ac1396a2b4d1
diff --git a/edgetpuvision/classify.py b/edgetpuvision/classify.py
index 0a5c145..8ca2f55 100644
--- a/edgetpuvision/classify.py
+++ b/edgetpuvision/classify.py
@@ -13,11 +13,33 @@
 
 from edgetpu.classification.engine import ClassificationEngine
 
-from . import overlays
+from . import svg
 from .utils import load_labels, input_image_size, same_input_image_sizes, avg_fps_counter
 from .gstreamer import Display, run_gen
 
 
+CSS_STYLES = str(svg.CssStyle({'.txt': svg.Style(fill='white'),
+                               '.shd': svg.Style(fill='black', fill_opacity=0.6)}))
+
+def overlay(results, inference_time, inference_rate, layout):
+    x0, y0, w, h = layout.window
+
+    lines = [
+        'Inference time: %.2f ms (%.2f fps)' % (inference_time * 1000, 1.0 / inference_time),
+        'Inference frame rate: %.2f fps' % inference_rate
+    ]
+
+    for i, (label, score) in enumerate(results):
+        lines.append('%s (%.2f)' % (label, score))
+
+    defs = svg.Defs()
+    defs += CSS_STYLES
+
+    doc = svg.Svg(width=w, height=h, viewBox='%s %s %s %s' % layout.window, font_size='26px')
+    doc += defs
+    doc += svg.normal_text(lines, x=x0 + 10, y=y0 + 10, font_size_em=1.1)
+    return str(doc)
+
 def top_results(window, top_k):
     total_scores = collections.defaultdict(lambda: 0.0)
     for results in window:
@@ -68,7 +90,7 @@
             if args.print:
                 print_results(inference_rate, results)
 
-            output = overlays.classification(results, inference_time, inference_rate, layout)
+            output = overlay(results, inference_time, inference_rate, layout)
         else:
             output = None
 
diff --git a/edgetpuvision/detect.py b/edgetpuvision/detect.py
index 6a38c3e..7b73166 100644
--- a/edgetpuvision/detect.py
+++ b/edgetpuvision/detect.py
@@ -13,15 +13,20 @@
 
 import argparse
 import collections
+import colorsys
 import itertools
 import time
 
 from edgetpu.detection.engine import DetectionEngine
 
-from . import overlays
+from . import svg
 from .utils import load_labels, input_image_size, same_input_image_sizes, avg_fps_counter
 from .gstreamer import Display, run_gen
 
+CSS_STYLES = str(svg.CssStyle({'.txt': svg.Style(fill='white'),
+                               '.shd': svg.Style(fill='black', fill_opacity=0.6),
+                               'rect': svg.Style(fill_opacity=0.1, stroke_width='2px')}))
+
 BBox = collections.namedtuple('BBox', ('x', 'y', 'w', 'h'))
 BBox.area = lambda self: self.w * self.h
 BBox.scale = lambda self, sx, sy: BBox(x=self.x * sx, y=self.y * sy,
@@ -31,6 +36,41 @@
 Object = collections.namedtuple('Object', ('id', 'label', 'score', 'bbox'))
 Object.__str__ = lambda self: 'Object(id=%d, label=%s, score=%.2f, %s)' % self
 
+def color(i, total):
+    return tuple(int(255.0 * c) for c in colorsys.hsv_to_rgb(i / total, 1.0, 1.0))
+
+def gen_colors(keys):
+    return {key : color(i, len(keys)) for i, key in enumerate(keys)}
+
+def overlay(objs, colors, inference_time, inference_rate, layout):
+    x0, y0, w, h = layout.window
+
+    defs = svg.Defs()
+    defs += CSS_STYLES
+
+    doc = svg.Svg(width=w, height=h, viewBox='%s %s %s %s' % layout.window, font_size='26px')
+    doc += defs
+    doc += svg.normal_text((
+        'Inference time: %.2f ms (%.2f fps)' % (inference_time * 1000, 1.0 / inference_time),
+        'Inference frame rate: %.2f fps' % inference_rate,
+        'Objects: %d' % len(objs),
+    ), x0 + 10, y0 + 10, font_size_em=1.1)
+
+    for obj in objs:
+        percent = int(100 * obj.score)
+        if obj.label:
+            caption = '%d%% %s' % (percent, obj.label)
+        else:
+            caption = '%d%%' % percent
+
+        x, y, w, h = obj.bbox.scale(*layout.size)
+        doc += svg.normal_text(caption, x, y - 5)
+        doc += svg.Rect(x=x, y=y, width=w, height=h, rx=2, ry=2,
+                        style='stroke:%s' % svg.rgb(colors[obj.id]))
+
+    return str(doc)
+
+
 def convert(obj, labels):
     x0, y0, x1, y1 = obj.bounding_box.flatten().tolist()
     return Object(id=obj.label_id,
@@ -53,6 +93,7 @@
 
     labels = load_labels(args.labels) if args.labels else None
     filtered_labels = set(l.strip() for l in args.filter.split(',')) if args.filter else None
+    colors = gen_colors(labels.keys())
     draw_overlay = True
 
     yield input_image_size(engine)
@@ -76,7 +117,7 @@
             if args.print:
                 print_results(inference_rate, objs)
 
-            output = overlays.detection(objs, inference_time, inference_rate, layout)
+            output = overlay(objs, colors, inference_time, inference_rate, layout)
         else:
             output = None
 
diff --git a/edgetpuvision/overlays.py b/edgetpuvision/overlays.py
deleted file mode 100644
index bcc8fd6..0000000
--- a/edgetpuvision/overlays.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from . import svg
-
-CSS_STYLES = str(svg.CssStyle({'.txt': svg.Style(fill='white'),
-                               '.shd': svg.Style(fill='black', fill_opacity=0.6),
-                               'rect': svg.Style(fill='green', fill_opacity=0.3, stroke='white')}))
-
-def classification(results, inference_time, inference_rate, layout):
-    x0, y0, w, h = layout.window
-
-    lines = [
-        'Inference time: %.2f ms (%.2f fps)' % (inference_time * 1000, 1.0 / inference_time),
-        'Inference frame rate: %.2f fps' % inference_rate
-    ]
-
-    for i, (label, score) in enumerate(results):
-        lines.append('%s (%.2f)' % (label, score))
-
-    defs = svg.Defs()
-    defs += CSS_STYLES
-
-    doc = svg.Svg(width=w, height=h, viewBox='%s %s %s %s' % layout.window, font_size='26px')
-    doc += defs
-    doc += svg.normal_text(lines, x=x0 + 10, y=y0 + 10, font_size_em=1.1)
-    return str(doc)
-
-def detection(objs, inference_time, inference_rate, layout):
-    x0, y0, w, h = layout.window
-
-    defs = svg.Defs()
-    defs += CSS_STYLES
-
-    doc = svg.Svg(width=w, height=h, viewBox='%s %s %s %s' % layout.window, font_size='26px')
-    doc += defs
-    doc += svg.normal_text((
-        'Inference time: %.2f ms (%.2f fps)' % (inference_time * 1000, 1.0 / inference_time),
-        'Inference frame rate: %.2f fps' % inference_rate,
-        'Objects: %d' % len(objs),
-    ), x0 + 10, y0 + 10, font_size_em=1.1)
-
-    for obj in objs:
-        percent = int(100 * obj.score)
-        if obj.label:
-            caption = '%d%% %s' % (percent, obj.label)
-        else:
-            caption = '%d%%' % percent
-
-        x, y, w, h = obj.bbox.scale(*layout.size)
-        doc += svg.normal_text(caption, x, y - 5)
-        doc += svg.Rect(x=x, y=y, width=w, height=h, rx=2, ry=2)
-
-    return str(doc)