#include <cstdint>

#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
#include "model.h"

| namespace { |
| tflite::ErrorReporter *error_reporter = nullptr; |
| const tflite::Model *model = nullptr; |
| tflite::MicroInterpreter *interpreter = nullptr; |
| |
| TfLiteTensor *input = nullptr; |
| TfLiteTensor *output = nullptr; |
| int inference_count = 0; |
| const int kInferencesPerCycle = 1000; |
| const float kXrange = 2.f * 3.14159265359f; |
| |
| const int kModelArenaSize = 4096; |
| const int kExtraArenaSize = 4096; |
| const int kTensorArenaSize = kModelArenaSize + kExtraArenaSize; |
| uint8_t tensor_arena[kTensorArenaSize] __attribute__((aligned(16))); |
| } // namespace |
| |
| void HandleOutput(tflite::ErrorReporter *error_reporter, float x_val, float y_val) { |
| int brightness = (int)(127.5f * (y_val + 1)); |
| TF_LITE_REPORT_ERROR(error_reporter, "%d", brightness); |
| } |
| |
| extern "C" void hello_world_tflite_setup(void) { |
| static tflite::MicroErrorReporter micro_error_reporter; |
| error_reporter = µ_error_reporter; |
| TF_LITE_REPORT_ERROR(error_reporter, "Hello from TFLite micro!"); |
| |
| model = tflite::GetModel(g_model); |
| if (model->version() != TFLITE_SCHEMA_VERSION) { |
| TF_LITE_REPORT_ERROR(error_reporter, |
| "Model schema version is %d, supported is %d", |
| model->version(), TFLITE_SCHEMA_VERSION); |
| return; |
| } |
| |
| static tflite::AllOpsResolver resolver; |
| static tflite::MicroInterpreter static_interpreter( |
| model, resolver, tensor_arena, kTensorArenaSize, error_reporter); |
| interpreter = &static_interpreter; |
| |
| TfLiteStatus allocate_status = interpreter->AllocateTensors(); |
| if (allocate_status != kTfLiteOk) { |
| TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors failed"); |
| return; |
| } |
| |
| input = interpreter->input(0); |
| output = interpreter->output(0); |
| |
| TF_LITE_REPORT_ERROR(error_reporter, "setup() complete."); |
| } |
| |
| extern "C" void hello_world_tflite_loop(void) { |
| float position = static_cast<float>(inference_count) / |
| static_cast<float>(kInferencesPerCycle); |
| float x_val = position * kXrange; |
| |
| input->data.f[0] = x_val; |
| |
| TfLiteStatus invoke_status = interpreter->Invoke(); |
| if (invoke_status != kTfLiteOk) { |
| TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x_val: %f", |
| static_cast<double>(x_val)); |
| return; |
| } |
| |
| float y_val = output->data.f[0]; |
| |
| HandleOutput(error_reporter, x_val, y_val); |
| |
| inference_count += 1; |
| if (inference_count >= kInferencesPerCycle) { |
| inference_count = 0; |
| } |
| } |