# blob: 672ade296279b08562474d2aeb59848b1f27b508 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A demo to classify image."""
import argparse
import re
from edgetpu.classification.engine import ClassificationEngine
from PIL import Image
# Function to read labels from text files.
def ReadLabelFile(file_path):
    """Reads labels from a text file and stores them in a dict.

    Each non-blank line in the file contains an id and a description
    separated by a colon or whitespace. Example: '0:cat' or '0 cat'.
    Blank and whitespace-only lines are skipped.

    Args:
      file_path: String, path to the label file.

    Returns:
      Dict of (int, string) which maps label id to description.

    Raises:
      ValueError: If the id part of a line is not an integer.
      IndexError: If a non-blank line holds only an id with no separator.
    """
    ret = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        # Stream the file line by line; no need to hold every line in memory.
        for line in f:
            line = line.strip()
            # Blank lines carry no label; skip them rather than fail on
            # int('').
            if not line:
                continue
            # Split only once so the description may itself contain spaces
            # or colons.
            pair = re.split(r'[:\s]+', line, maxsplit=1)
            ret[int(pair[0])] = pair[1].strip()
    return ret
def main():
    """Classifies one image and prints the label and score of each result.

    Parses the required --model, --label and --image command-line arguments,
    loads the label map with ReadLabelFile, builds a ClassificationEngine
    from the model path, and runs it on the image with top_k=3.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', help='File path of Tflite model.', required=True)
    parser.add_argument(
        '--label', help='File path of label file.', required=True)
    parser.add_argument(
        '--image', help='File path of the image to be recognized.',
        required=True)
    args = parser.parse_args()

    # Prepare labels.
    labels = ReadLabelFile(args.label)
    # Initialize engine.
    engine = ClassificationEngine(args.model)
    # Run inference. The context manager closes the image file when done;
    # the loop stays inside it in case results are produced lazily.
    with Image.open(args.image) as img:
        for result in engine.ClassifyWithImage(img, top_k=3):
            print('---------------------------')
            # result[0] is used as the key into the label map and result[1]
            # is printed as the score.
            print(labels[result[0]])
            print('Score : ', result[1])
if __name__ == '__main__':
    # Run the demo only when this file is executed as a script.
    main()