blob: 97d373ff6bf31bec533574fbb65ca9efc3dd38db [file] [log] [blame]
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Thin JavaScript wrapper around the `interpreter_*` functions exported by an
 * Emscripten-compiled TFLite module. Requires a global `Module` providing
 * `cwrap` and `HEAPU8`, and the global `writeArrayToMemory` helper.
 */
let tflite = {};
(function() {
  'use strict';

  /** Counter giving each Interpreter instance a unique callback id. */
  let nextId = 0;

  /**
   * Drops the alpha channel from an interleaved RGBA byte sequence.
   * @param {ArrayLike<number>} rgbaArray Interleaved R, G, B, A values.
   * @return {!Uint8Array} Interleaved R, G, B values (3 output bytes per 4
   *     input values; a trailing partial pixel is truncated by the length
   *     computation exactly as typed-array length coercion would).
   */
  function rgbaArrayToRgbArray(rgbaArray) {
    const rgbArray = new Uint8Array(Math.floor((3 * rgbaArray.length) / 4));
    for (let i = 0, j = 0; i < rgbaArray.length; i += 4, j += 3) {
      rgbArray[j] = rgbaArray[i];
      rgbArray[j + 1] = rgbaArray[i + 1];
      rgbArray[j + 2] = rgbaArray[i + 2];
    }
    return rgbArray;
  }

  /**
   * Number of elements in a tensor with the given shape.
   * The initial value of 1 makes a rank-0 (scalar) shape `[]` count as one
   * element instead of making `reduce` throw on an empty array.
   * @param {!Array<number>} shape
   * @return {number}
   */
  function numElements(shape) {
    return shape.reduce((a, b) => a * b, 1);
  }

  /**
   * Name of the `Module` property that holds the pending invoke callback for
   * the interpreter with this id.
   * @param {number} id
   * @return {string}
   */
  function callbackKey(id) {
    return 'interpreter_' + id;
  }

  /**
   * Completion hook installed as `Module['invokeDone']`. It receives the id
   * that was passed to `interpreter_invoke_async`; presumably the native side
   * calls it when an invocation finishes (the caller is not visible here).
   *
   * The pending callback is removed BEFORE it runs. A callback that
   * immediately calls `invoke()` again therefore keeps its freshly registered
   * callback, and a callback that throws does not leave a stale entry behind.
   * @param {number} id
   */
  function onInvokeDone(id) {
    const key = callbackKey(id);
    const callback = Module[key];
    delete Module[key];
    if (callback) callback();
  }

  /**
   * Binds the native interpreter entry points. Call `create()` before using
   * any other method.
   * @constructor
   */
  tflite.Interpreter = function() {
    // The native create is called with three numeric arguments (see create()).
    this.interpreter_create = Module.cwrap(
        'interpreter_create', 'number', ['number', 'number', 'number'],
        {async: true});
    this.interpreter_destroy = Module.cwrap('interpreter_destroy', null, ['number']);
    this.interpreter_num_inputs = Module.cwrap('interpreter_num_inputs', 'number', ['number']);
    this.interpreter_input_buffer = Module.cwrap('interpreter_input_buffer', 'number', ['number', 'number']);
    this.interpreter_num_input_dims = Module.cwrap('interpreter_num_input_dims', 'number', ['number', 'number']);
    this.interpreter_input_dim = Module.cwrap('interpreter_input_dim', 'number', ['number', 'number', 'number']);
    this.interpreter_num_outputs = Module.cwrap('interpreter_num_outputs', 'number', ['number']);
    this.interpreter_output_buffer = Module.cwrap('interpreter_output_buffer', 'number', ['number', 'number']);
    this.interpreter_num_output_dims = Module.cwrap('interpreter_num_output_dims', 'number', ['number', 'number']);
    this.interpreter_output_dim = Module.cwrap('interpreter_output_dim', 'number', ['number', 'number', 'number']);
    this.interpreter_invoke_async = Module.cwrap('interpreter_invoke_async', null, ['number', 'number']);
    this.id = nextId++;
    // One shared hook serves every instance; it dispatches on the id.
    Module['invokeDone'] = onInvokeDone;
  };

  /**
   * Creates the native interpreter from a model already in wasm memory.
   * @param {number} modelBuffer Pointer to the model bytes.
   * @param {number} modelBufferSize Size of the model in bytes.
   * @return {!Promise<boolean>} Whether a native interpreter was created.
   */
  tflite.Interpreter.prototype.create = async function(modelBuffer, modelBufferSize) {
    // NOTE(review): the trailing 0 is forwarded as-is; its meaning is defined
    // by the C side, which is not visible here.
    this.interpreter = await this.interpreter_create(modelBuffer, modelBufferSize, 0);
    // The native handle comes back as a number, so a NULL pointer is 0, not
    // JS null.
    return this.interpreter != null && this.interpreter !== 0;
  };

  /**
   * Releases the native interpreter. Calling it again, or before a
   * successful create(), does nothing.
   */
  tflite.Interpreter.prototype.destroy = function() {
    if (!this.interpreter) return;
    this.interpreter_destroy(this.interpreter);
    // Drop the dangling handle so a repeated destroy cannot double-free.
    this.interpreter = 0;
  };

  /** @return {number} Number of input tensors. */
  tflite.Interpreter.prototype.numInputs = function() {
    return this.interpreter_num_inputs(this.interpreter);
  };

  /**
   * @param {number} index Input tensor index.
   * @return {number} Pointer to that input tensor's data in wasm memory.
   */
  tflite.Interpreter.prototype.inputBuffer = function(index) {
    return this.interpreter_input_buffer(this.interpreter, index);
  };

  /**
   * @param {number} index Input tensor index.
   * @return {!Array<number>} Dimensions of that input tensor.
   */
  tflite.Interpreter.prototype.inputShape = function(index) {
    const rank = this.interpreter_num_input_dims(this.interpreter, index);
    const shape = [];
    for (let i = 0; i < rank; ++i) {
      shape.push(this.interpreter_input_dim(this.interpreter, index, i));
    }
    return shape;
  };

  /** @return {number} Number of output tensors. */
  tflite.Interpreter.prototype.numOutputs = function() {
    return this.interpreter_num_outputs(this.interpreter);
  };

  /**
   * @param {number} index Output tensor index.
   * @return {number} Pointer to that output tensor's data in wasm memory.
   */
  tflite.Interpreter.prototype.outputBuffer = function(index) {
    return this.interpreter_output_buffer(this.interpreter, index);
  };

  /**
   * @param {number} index Output tensor index.
   * @return {!Array<number>} Dimensions of that output tensor.
   */
  tflite.Interpreter.prototype.outputShape = function(index) {
    const rank = this.interpreter_num_output_dims(this.interpreter, index);
    const shape = [];
    for (let i = 0; i < rank; ++i) {
      shape.push(this.interpreter_output_dim(this.interpreter, index, i));
    }
    return shape;
  };

  /**
   * Starts an asynchronous invocation. `callback` runs once the completion
   * hook fires for this interpreter's id. Only one pending callback is kept
   * per interpreter; a second invoke() before completion replaces it.
   * @param {function()} callback
   */
  tflite.Interpreter.prototype.invoke = function(callback) {
    Module[callbackKey(this.id)] = callback;
    this.interpreter_invoke_async(this.interpreter, this.id);
  };

  /**
   * Copies RGB values into an input tensor.
   * NOTE(review): the size check counts elements, so this assumes a
   * one-byte-per-element (uint8) input tensor; confirm against the model.
   * @param {number} index Input tensor index.
   * @param {ArrayLike<number>} rgbArray Values to copy; the length must equal
   *     the tensor's element count.
   * @throws {Error} If the length does not match the tensor's element count.
   */
  tflite.Interpreter.prototype.setRgbInput = function(index, rgbArray) {
    const expected = numElements(this.inputShape(index));
    if (rgbArray.length !== expected) {
      throw new Error('Invalid input array size');
    }
    writeArrayToMemory(rgbArray, this.inputBuffer(index));
  };

  /**
   * Drops the alpha channel and writes the result with setRgbInput().
   * @param {number} index Input tensor index.
   * @param {ArrayLike<number>} rgbaArray Interleaved RGBA values.
   */
  tflite.Interpreter.prototype.setRgbaInput = function(index, rgbaArray) {
    this.setRgbInput(index, rgbaArrayToRgbArray(rgbaArray));
  };

  /**
   * Returns the index of the largest byte in an output tensor; on ties the
   * first maximum wins.
   * NOTE(review): reads the output as unsigned bytes (HEAPU8), which assumes
   * a uint8 output tensor.
   * @param {number} index Output tensor index.
   * @param {*=} rgbaArray Unused; kept for call-site compatibility.
   * @return {number} Argmax index, or -1 if the tensor has no elements.
   */
  tflite.Interpreter.prototype.getClassificationOutput = function(index, rgbaArray) {
    const size = numElements(this.outputShape(index));
    const buf = this.outputBuffer(index);
    // Use a view rather than a copy. A linear scan also avoids the engine's
    // argument-count limit that spreading into Math.max would hit.
    const scores = Module.HEAPU8.subarray(buf, buf + size);
    if (scores.length === 0) return -1;
    let best = 0;
    for (let i = 1; i < scores.length; ++i) {
      if (scores[i] > scores[best]) best = i;
    }
    return best;
  };
})();