blob: ac3f674e4f711e90bc4204bf31e4109d0aeb2337 [file] [log] [blame]
<!DOCTYPE html>
<html>
<head>
<meta charset='UTF-8'>
<title>Coral USB Accelerator Demo</title>
<script src='dfu.js'></script>
<script src='dfuse.js'></script>
<script src='models.js'></script>
<script src='tflite.js'></script>
<script src='interpreter.js'></script>
</head>
<body>
<h1>Coral USB Accelerator Demo</h1>
<!-- Step 1: the click handler registered in the script below downloads
     /firmware.bin and flashes it over DFU; progress text is written to
     #firmware-status. -->
<button id='button-firmware'>1. Flash USB Firmware</button>&nbsp;<span id='firmware-status'></span>
<br/><br/>
<!-- Step 2: creates an interpreter for the model selected in #model. Each
     option value is used as a key into TFLITE_MODELS. -->
<button id='button-init'>2. Initialize Interpreter</button>
<select id='model'>
<option value='mobilenet_tpu'>MobileNet V1 (TPU)</option>
<option value='mobilenet_cpu'>MobileNet V1 (CPU)</option>
<option value='ssd_mobilenet_coco_tpu'>SSD MobileNet V2 - Coco (TPU)</option>
<option value='ssd_mobilenet_coco_cpu'>SSD MobileNet V2 - Coco (CPU)</option>
<option value='ssd_mobilenet_face_tpu'>SSD MobileNet V2 - Face (TPU)</option>
<option value='ssd_mobilenet_face_cpu'>SSD MobileNet V2 - Face (CPU)</option>
</select>
<br/><br/>
<!-- Step 3: stays disabled until the interpreter has been created. The visible
     button only forwards its click to the hidden file input, whose `change`
     event triggers inference. -->
<button id='button-image' onclick='document.getElementById("file").click();' disabled>3. Choose Image File</button>
<input type='file' id='file' style='display:none;'/>
<br/><br/>
<!-- #result receives the label (or object count) plus the inference time;
     #canvas is resized to the model input size and shows the image together
     with any detection boxes. -->
<h2 id='result'></h2>
<canvas id='canvas'></canvas>
<br/><br/>
<script>
// DFU functions are based on the code from
// https://github.com/devanlai/webdfu/blob/gh-pages/dfu-util/dfu-util.js
// Reads configuration descriptor 0 from `device`, parses it with
// dfu.parseConfigurationDescriptor() and decodes the DFU functional descriptor
// (descriptor type 0x21 carrying a `bcdDFUVersion` field) into a plain object.
//
// Resolves to an empty object when the parsed descriptor does not belong to
// the configuration currently selected on the device, or when it contains no
// DFU functional descriptor.
async function getDFUDescriptorProperties(device) {
  const rawDescriptor = await device.readConfigurationDescriptor(0);
  const parsed = dfu.parseConfigurationDescriptor(rawDescriptor);

  // Descriptors of a configuration other than the active one are ignored.
  if (parsed.bConfigurationValue != device.settings.configuration.configurationValue) {
    return {};
  }

  // Pick the first DFU functional descriptor, if there is one.
  let functional = null;
  for (const candidate of parsed.descriptors) {
    const isDfuFunctional =
        candidate.bDescriptorType == 0x21 && candidate.hasOwnProperty('bcdDFUVersion');
    if (isDfuFunctional) {
      functional = candidate;
      break;
    }
  }
  if (!functional) {
    return {};
  }

  // Each capability is a single bit in bmAttributes.
  const hasAttribute = (mask) => (functional.bmAttributes & mask) != 0;
  return {
    WillDetach: hasAttribute(0x08),
    ManifestationTolerant: hasAttribute(0x04),
    CanUpload: hasAttribute(0x02),
    CanDnload: hasAttribute(0x01),
    TransferSize: functional.wTransferSize,
    DetachTimeOut: functional.wDetachTimeOut,
    DFUVersion: functional.bcdDFUVersion
  };
}
// Opens `device`, attaches its DFU descriptor properties (only when some were
// found) and points every log hook of the device at console.log.
// Resolves to the same device object that was passed in.
async function connect(device) {
  await device.open();

  const properties = await getDFUDescriptorProperties(device);
  const hasProperties = Boolean(properties) && Object.keys(properties).length > 0;
  if (hasProperties) {
    device.properties = properties;
  }

  // Every log channel of the device ends up in the browser console.
  const sink = console.log;
  Object.assign(device, {
    logDebug: sink,
    logInfo: sink,
    logWarning: sink,
    logError: sink,
    logProgress: sink,
  });
  return device;
}
// Asks the user to pick the USB device matching the vendor/product filter
// below, then wraps it in a connected dfu.Device.
//
// Resolves to the connected device, or to null when the request fails, the
// device does not expose exactly one DFU interface, or connecting fails.
// Failures are logged to the console instead of being dropped silently, so a
// missing WebUSB API or a cancelled chooser can be told apart while debugging.
async function openDevice() {
  // NOTE(review): presumably the Coral USB Accelerator in DFU mode - confirm.
  const deviceFilter = {vendorId: 0x1a6e, productId: 0x089a};
  try {
    const usbDevice = await navigator.usb.requestDevice({filters: [deviceFilter]});
    const interfaces = dfu.findDeviceDfuInterfaces(usbDevice);
    if (interfaces.length != 1) {
      console.warn(`Expected exactly one DFU interface, found ${interfaces.length}.`);
      return null;
    }
    return await connect(new dfu.Device(usbDevice, interfaces[0]));
  } catch (error) {
    console.error('Failed to open DFU device:', error);
    return null;
  }
}
// Reads a user-selected file as a data URL and decodes it into an Image.
//
// Returns a promise that resolves to the loaded Image. It rejects with an
// Error when the file cannot be read or its content cannot be decoded as an
// image; without these rejections the promise would stay pending forever and
// the caller would hang on a bad file.
function loadLocalImage(file) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = () => {
      const img = new Image();
      img.onload = () => { resolve(img); };
      img.onerror = () => { reject(new Error('Failed to decode the selected file as an image.')); };
      // Handlers are attached before `src` is set so no event can be missed.
      img.src = reader.result;
    };
    reader.onerror = () => {
      reject(reader.error || new Error('Failed to read the selected file.'));
    };
    reader.readAsDataURL(file);
  });
}
// Downloads `url` with a GET request and resolves to the response body as an
// ArrayBuffer.
//
// The returned promise rejects with an Error when the server answers with a
// non-2xx status, or when the request fails or is aborted. This keeps an HTTP
// error page from being handed to callers as if it were the requested binary
// (for example a firmware image), and keeps callers from waiting forever on a
// network failure.
function loadFile(url) {
  return new Promise((resolve, reject) => {
    const req = new XMLHttpRequest();
    req.open('GET', url, true);
    req.responseType = 'arraybuffer';
    req.onload = () => {
      const succeeded = req.status >= 200 && req.status < 300;
      if (succeeded) {
        resolve(req.response);
      } else {
        reject(new Error(`Failed to load ${url}: HTTP status ${req.status}`));
      }
    };
    req.onerror = () => { reject(new Error(`Failed to load ${url}: network error`)); };
    req.onabort = () => { reject(new Error(`Failed to load ${url}: request aborted`)); };
    req.send(null);
  });
}
// Step 1 of the demo: once the DOM is ready, make the firmware button download
// /firmware.bin, open the DFU device and flash the image onto it. Progress and
// the final outcome are shown in #firmware-status; failures are additionally
// logged to the console so they can be diagnosed.
document.addEventListener('DOMContentLoaded', () => {
  const firmwareStatus = document.getElementById('firmware-status');
  const downloadButton = document.querySelector('#button-firmware');

  downloadButton.addEventListener('click', async () => {
    let firmwareFile;
    try {
      firmwareFile = await loadFile('/firmware.bin');
    } catch (error) {
      // Without a firmware image there is nothing to flash.
      console.log('Cannot load firmware: ', error);
      firmwareStatus.textContent = 'Failed :(';
      return;
    }

    const device = await openDevice();
    if (device) {
      console.log('Device: ', device.properties);
    } else {
      console.log('Cannot open device.');
    }
    if (!device || !firmwareFile) return;

    firmwareStatus.textContent = 'Flashing...';

    // Best effort: leave a pre-existing DFU error state before downloading.
    try {
      const status = await device.getStatus();
      if (status.state == dfu.dfuERROR) {
        await device.clearStatus();
      }
    } catch (error) {
      device.logWarning('Failed to clear status');
    }

    try {
      const manifestationTolerant = device.properties.ManifestationTolerant;
      await device.do_download(device.properties.TransferSize, firmwareFile, manifestationTolerant);
      // When the descriptor says the device is not manifestation tolerant,
      // wait (up to 5000 ms) for it to disconnect.
      if (!manifestationTolerant) {
        await device.waitDisconnected(5000);
      }
      firmwareStatus.textContent = 'Ready :)';
    } catch (error) {
      console.log('Firmware flashing failed: ', error);
      firmwareStatus.textContent = 'Failed :(';
    }
  });
});
// Steps 2 and 3 of the demo, wired up when `Module` reports that its runtime
// is initialized:
//  - #button-init creates a tflite.Interpreter for the selected model and, on
//    success, locks the setup controls and enables the image button;
//  - a `change` on #file draws the chosen image onto #canvas at the model
//    input size, runs inference and shows the result plus the elapsed time.
Module['onRuntimeInitialized'] = () => {
  // Shared between the two handlers; both are set by the init handler.
  let interpreter;
  let model;

  document.querySelector('#button-init').addEventListener('click', async () => {
    model = TFLITE_MODELS[document.getElementById('model').value];
    console.log(model);
    interpreter = new tflite.Interpreter();
    if (await interpreter.createFromBuffer(await loadFile(model.url))) {
      document.getElementById('button-firmware').disabled = true;
      document.getElementById('button-init').disabled = true;
      document.getElementById('model').disabled = true;
      document.getElementById('button-image').disabled = false;
    }
  });

  document.querySelector('#file').addEventListener('change', async () => {
    const input = document.getElementById('file');
    const file = input.files[0];
    if (!file) return;

    // Only the 2nd and 3rd entries of the input shape are used (height, width).
    // They are declared locally so nothing leaks into the global scope.
    const [, height, width] = interpreter.inputShape(0);
    console.log('SHAPE:', width, 'x', height);

    const img = await loadLocalImage(file);
    const c = document.getElementById('canvas');
    c.width = width;
    c.height = height;
    const ctx = c.getContext('2d');
    const color = 'red';
    ctx.strokeStyle = color;
    ctx.fillStyle = color;
    ctx.textBaseline = 'top';

    switch (model.type) {
      case 'classification': {
        // Stretch the image over the whole model input.
        ctx.drawImage(img, 0, 0, img.width, img.height,
                      0, 0, width, height);
        break;
      }
      case 'detection': {
        // Scale uniformly so the whole image fits inside the model input.
        const alpha = Math.min(width / img.width, height / img.height);
        ctx.drawImage(img, 0, 0, img.width, img.height,
                      0, 0, alpha * img.width, alpha * img.height);
        break;
      }
    }

    const imageData = ctx.getImageData(0, 0, width, height);
    tflite.setRgbaInput(interpreter, imageData.data);

    document.getElementById('result').textContent = 'Recognizing...';
    const inferenceStart = Date.now();
    await interpreter.invoke();
    const inferenceTime = Date.now() - inferenceStart;

    let label = null;
    switch (model.type) {
      case 'classification': {
        const maxIndex = tflite.getClassificationOutput(interpreter);
        label = model.labels[maxIndex];
        break;
      }
      case 'detection': {
        const threshold = 0.5;
        const objects = tflite.getDetectionOutput(interpreter, threshold);
        for (const obj of objects) {
          // Box coordinates are multiplied by the canvas size to get pixels.
          const x = obj.bbox.xmin * width;
          const y = obj.bbox.ymin * height;
          const w = obj.bbox.xmax * width - x;
          const h = obj.bbox.ymax * height - y;
          ctx.strokeRect(x, y, w, h);
          ctx.fillText(model.labels[obj.id], x + 5, y + 5);
        }
        label = `${objects.length} ${objects.length == 1 ? 'object' : 'objects'}`;
        break;
      }
    }
    document.getElementById('result').textContent = `${label}: ${inferenceTime} ms`;
  });
};
// Route text handed to the `print` hook of `Module` to the console.
// NOTE(review): presumably this is the runtime's stdout hook - confirm against tflite.js.
Module.print = (line) => console.log(line);
</script>
</body>
</html>