-
Notifications
You must be signed in to change notification settings - Fork 126
/
base_task_api.h
152 lines (131 loc) · 6.16 KB
/
base_task_api.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_
#include <utility>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
namespace tflite {
namespace task {
namespace core {
// Non-templated base class shared by all Task APIs. It owns the TfLiteEngine
// wrapping the TFLite model/interpreter and exposes read-only access to the
// engine and to the model's metadata extractor.
class BaseUntypedTaskApi {
 public:
  // Takes ownership of the provided engine, which must remain the sole
  // interpreter used by this instance.
  explicit BaseUntypedTaskApi(std::unique_ptr<TfLiteEngine> engine)
      : engine_(std::move(engine)) {}

  virtual ~BaseUntypedTaskApi() = default;

  // Returns a non-owning pointer to the underlying engine.
  const TfLiteEngine* GetTfLiteEngine() const { return engine_.get(); }

  // Returns the metadata extractor associated with the engine's model.
  const metadata::ModelMetadataExtractor* GetMetadataExtractor() const {
    return engine_->metadata_extractor();
  }

 protected:
  // Owned engine, accessible to subclasses for preprocessing / inference.
  std::unique_ptr<TfLiteEngine> engine_;
};
// Templated base class for Task APIs. Subclasses implement Preprocess() and
// Postprocess() to convert between caller-facing types (InputTypes... /
// OutputType) and the raw TFLite input/output tensors; this class drives the
// actual inference via Infer() or InferWithFallback().
template <class OutputType, class... InputTypes>
class BaseTaskApi : public BaseUntypedTaskApi {
 public:
  explicit BaseTaskApi(std::unique_ptr<TfLiteEngine> engine)
      : BaseUntypedTaskApi(std::move(engine)) {}

  // BaseTaskApi is neither copyable nor movable.
  BaseTaskApi(const BaseTaskApi&) = delete;
  BaseTaskApi& operator=(const BaseTaskApi&) = delete;

  // Cancels the current running TFLite invocation on CPU.
  //
  // Usually called on a different thread than the one inference is running on.
  // Calling Cancel() will cause the underlying TFLite interpreter to return an
  // error, which will turn into a `CANCELLED` status and empty results. Calling
  // Cancel() at the other time will not take any effect on the current or
  // following invocation. It is perfectly fine to run inference again on the
  // same instance after a cancelled invocation. If the TFLite inference is
  // partially delegated on CPU, logs a warning message and only cancels the
  // invocation running on CPU. Other invocation which depends on the output of
  // the CPU invocation will not be executed.
  void Cancel() { engine_->Cancel(); }

 protected:
  // Subclasses need to populate input_tensors from api_inputs.
  virtual absl::Status Preprocess(
      const std::vector<TfLiteTensor*>& input_tensors,
      InputTypes... api_inputs) = 0;

  // Subclasses need to construct OutputType object from output_tensors.
  // Original inputs are also provided as they may be needed.
  virtual tflite::support::StatusOr<OutputType> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      InputTypes... api_inputs) = 0;

  // Returns the tensors associated with the given input/output indexes.
  // TensorType is either `TfLiteTensor` (inputs) or `const TfLiteTensor`
  // (outputs).
  template <typename TensorType>
  std::vector<TensorType*> GetTensors(const std::vector<int>& tensor_indices) {
    tflite::Interpreter* interpreter = engine_->interpreter();
    std::vector<TensorType*> tensors;
    tensors.reserve(tensor_indices.size());
    for (int index : tensor_indices) {
      tensors.push_back(interpreter->tensor(index));
    }
    return tensors;
  }

  // Convenience accessor for the interpreter's input tensors.
  std::vector<TfLiteTensor*> GetInputTensors() {
    return GetTensors<TfLiteTensor>(engine_->interpreter()->inputs());
  }

  // Convenience accessor for the interpreter's output tensors.
  std::vector<const TfLiteTensor*> GetOutputTensors() {
    return GetTensors<const TfLiteTensor>(engine_->interpreter()->outputs());
  }

  // Performs inference using tflite::support::TfLiteInterpreterWrapper
  // InvokeWithoutFallback().
  tflite::support::StatusOr<OutputType> Infer(InputTypes... args) {
    tflite::support::TfLiteInterpreterWrapper* interpreter_wrapper =
        engine_->interpreter_wrapper();
    // Note: AllocateTensors() is already performed by the interpreter wrapper
    // at InitInterpreter time (see TfLiteEngine).
    RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...));
    RETURN_IF_ERROR(WithTfLiteSupportPayload(
        interpreter_wrapper->InvokeWithoutFallback()));
    return Postprocess(GetOutputTensors(), args...);
  }

  // Performs inference using tflite::support::TfLiteInterpreterWrapper
  // InvokeWithFallback() to benefit from automatic fallback from delegation to
  // CPU where applicable.
  tflite::support::StatusOr<OutputType> InferWithFallback(InputTypes... args) {
    tflite::support::TfLiteInterpreterWrapper* interpreter_wrapper =
        engine_->interpreter_wrapper();
    // Note: AllocateTensors() is already performed by the interpreter wrapper
    // at InitInterpreter time (see TfLiteEngine).
    RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...));
    auto set_inputs_nop = [](tflite::Interpreter* interpreter) -> absl::Status {
      // NOP since inputs are populated at Preprocess() time.
      return absl::OkStatus();
    };
    RETURN_IF_ERROR(WithTfLiteSupportPayload(
        interpreter_wrapper->InvokeWithFallback(set_inputs_nop)));
    return Postprocess(GetOutputTensors(), args...);
  }

 private:
  // Ensures a non-OK invocation status carries the TFLite Support payload so
  // that callers get a uniformly-annotated status: an OK status, or one that
  // already carries kTfLiteSupportPayload, is returned unchanged; any other
  // error is re-created with the payload attached. Shared by Infer() and
  // InferWithFallback(), which previously duplicated this logic inline.
  static absl::Status WithTfLiteSupportPayload(absl::Status status) {
    if (status.ok() ||
        status.GetPayload(tflite::support::kTfLiteSupportPayload)
            .has_value()) {
      return status;
    }
    return tflite::support::CreateStatusWithPayload(status.code(),
                                                    status.message());
  }
};
} // namespace core
} // namespace task
} // namespace tflite
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_