blob: e3fb8b4d75c794bd6c1e384c361fb06f4237cba9 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A demo to classify Raspberry Pi camera stream."""
import argparse
import io
import time
import numpy as np
import picamera
import edgetpu.classification.engine
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model', help='File path of Tflite model.', required=True)
parser.add_argument(
'--label', help='File path of label file.', required=True)
args = parser.parse_args()
with open(args.label, 'r', encoding="utf-8") as f:
pairs = (l.strip().split(maxsplit=1) for l in f.readlines())
labels = dict((int(k), v) for k, v in pairs)
engine = edgetpu.classification.engine.ClassificationEngine(args.model)
with picamera.PiCamera() as camera:
camera.resolution = (640, 480)
camera.framerate = 30
_, width, height, channels = engine.get_input_tensor_shape()
camera.start_preview()
try:
stream = io.BytesIO()
for foo in camera.capture_continuous(stream,
format='rgb',
use_video_port=True,
resize=(width, height)):
stream.truncate()
stream.seek(0)
input = np.frombuffer(stream.getvalue(), dtype=np.uint8)
start_ms = time.time()
results = engine.ClassifyWithInputTensor(input, top_k=1)
elapsed_ms = time.time() - start_ms
if results:
camera.annotate_text = "%s %.2f\n%.2fms" % (
labels[results[0][0]], results[0][1], elapsed_ms*1000.0)
finally:
camera.stop_preview()
if __name__ == '__main__':
main()