blob: 3d409eb927c23d148e3df16c3b8e897965471a8d [file] [log] [blame]
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
#include "model.h"
namespace {
// File-scope state shared between setup() and loop(). All pointers are
// populated by hello_world_tflite_setup(); until then they are nullptr.
tflite::ErrorReporter *error_reporter = nullptr;  // Logging sink for all TF_LITE_REPORT_ERROR calls.
const tflite::Model *model = nullptr;             // Flatbuffer model mapped from g_model.
tflite::MicroInterpreter *interpreter = nullptr;  // Points at the function-local static in setup().
TfLiteTensor *input = nullptr;                    // interpreter->input(0), written each loop iteration.
TfLiteTensor *output = nullptr;                   // interpreter->output(0), read each loop iteration.
// Counts inferences so far in the current cycle; wraps at kInferencesPerCycle.
int inference_count = 0;
const int kInferencesPerCycle = 1000;
// One full period of the sine input: 2 * pi.
const float kXrange = 2.f * 3.14159265359f;
// Arena sizing: base model requirement plus headroom; sized empirically.
const int kModelArenaSize = 4096;
const int kExtraArenaSize = 4096;
const int kTensorArenaSize = kModelArenaSize + kExtraArenaSize;
// Scratch memory the interpreter allocates tensors from. 16-byte alignment
// is required by some kernels' vectorized loads.
uint8_t tensor_arena[kTensorArenaSize] __attribute__((aligned(16)));
} // namespace
// Maps a model prediction y_val, expected in [-1, 1], onto a brightness
// value in [0, 255] and logs it through the given reporter.
//
// @param error_reporter  sink for the formatted brightness value.
// @param x_val           model input for this inference; unused here but kept
//                        so the signature matches platform variants that plot
//                        (x, y) pairs.
// @param y_val           model output, expected in [-1, 1].
void HandleOutput(tflite::ErrorReporter *error_reporter, float x_val, float y_val) {
  (void)x_val;  // Silence -Wunused-parameter without changing the interface.
  // 127.5 * (y + 1) linearly maps [-1, 1] -> [0, 255]. Use a named cast
  // rather than a C-style cast.
  const int brightness = static_cast<int>(127.5f * (y_val + 1));
  TF_LITE_REPORT_ERROR(error_reporter, "%d", brightness);
}
extern "C" void hello_world_tflite_setup(void) {
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = &micro_error_reporter;
TF_LITE_REPORT_ERROR(error_reporter, "Hello from TFLite micro!");
model = tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model schema version is %d, supported is %d",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
static tflite::AllOpsResolver resolver;
static tflite::MicroInterpreter static_interpreter(
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors failed");
return;
}
input = interpreter->input(0);
output = interpreter->output(0);
TF_LITE_REPORT_ERROR(error_reporter, "setup() complete.");
}
extern "C" void hello_world_tflite_loop(void) {
float position = static_cast<float>(inference_count) /
static_cast<float>(kInferencesPerCycle);
float x_val = position * kXrange;
input->data.f[0] = x_val;
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x_val: %f",
static_cast<double>(x_val));
return;
}
float y_val = output->data.f[0];
HandleOutput(error_reporter, x_val, y_val);
inference_count += 1;
if (inference_count >= kInferencesPerCycle) {
inference_count = 0;
}
}