LLVM 19.0.0git
ModelUnderTrainingRunner.cpp
Go to the documentation of this file.
1//===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation of a MLModelRunner for 'development' mode, i.e. evaluation
10// happens off a model that's provided from the command line and is interpreted.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/Config/config.h"
16#if defined(LLVM_HAVE_TFLITE)
19#include "llvm/Support/Path.h"
20#include <optional>
21
22using namespace llvm;
23namespace {
24struct LoggedFeatureSpec {
26 std::optional<std::string> LoggingName;
27};
28
29std::optional<std::vector<LoggedFeatureSpec>>
30loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
31 StringRef ModelPath, StringRef SpecFileOverride) {
32 SmallVector<char, 128> OutputSpecsPath;
33 StringRef FileName = SpecFileOverride;
34 if (FileName.empty()) {
35 llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
36 FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
37 }
38
39 auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
40 if (!BufferOrError) {
41 Ctx.emitError("Error opening output specs file: " + FileName + " : " +
42 BufferOrError.getError().message());
43 return std::nullopt;
44 }
45 auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
46 if (!ParsedJSONValues) {
47 Ctx.emitError("Could not parse specs file: " + FileName);
48 return std::nullopt;
49 }
50 auto ValuesArray = ParsedJSONValues->getAsArray();
51 if (!ValuesArray) {
52 Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
53 "logging_name:<name>} dictionaries");
54 return std::nullopt;
55 }
56 std::vector<LoggedFeatureSpec> Ret;
57 for (const auto &Value : *ValuesArray)
58 if (const auto *Obj = Value.getAsObject())
59 if (const auto *SpecPart = Obj->get("tensor_spec"))
60 if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
61 if (auto LoggingName = Obj->getString("logging_name")) {
62 if (!TensorSpec->isElementType<int64_t>() &&
63 !TensorSpec->isElementType<int32_t>() &&
64 !TensorSpec->isElementType<float>()) {
65 Ctx.emitError(
66 "Only int64, int32, and float tensors are supported. "
67 "Found unsupported type for tensor named " +
68 TensorSpec->name());
69 return std::nullopt;
70 }
71 Ret.push_back({*TensorSpec, LoggingName->str()});
72 }
73
74 if (ValuesArray->size() != Ret.size()) {
75 Ctx.emitError(
76 "Unable to parse output spec. It should be a json file containing an "
77 "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
78 "with a json object describing a TensorSpec; and a 'logging_name' key, "
79 "which is a string to use as name when logging this tensor in the "
80 "training log.");
81 return std::nullopt;
82 }
83 if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
84 Ctx.emitError("The first output spec must describe the decision tensor, "
85 "and must have the logging_name " +
86 StringRef(ExpectedDecisionName));
87 return std::nullopt;
88 }
89 return Ret;
90}
91} // namespace
92
93ModelUnderTrainingRunner::ModelUnderTrainingRunner(
94 LLVMContext &Ctx, const std::string &ModelPath,
95 const std::vector<TensorSpec> &InputSpecs,
96 const std::vector<TensorSpec> &OutputSpecs,
97 const std::vector<TensorSpec> &ExtraOutputsForLogging)
98 : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
99 OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
100 Evaluator =
101 std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
102 if (!Evaluator || !Evaluator->isValid()) {
103 Ctx.emitError("Failed to create saved model evaluator");
104 Evaluator.reset();
105 return;
106 }
107
108 for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
109 setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
110 }
111}
112
113void *ModelUnderTrainingRunner::evaluateUntyped() {
114 LastEvaluationResult = Evaluator->evaluate();
115 if (!LastEvaluationResult.has_value()) {
116 Ctx.emitError("Error evaluating model.");
117 return nullptr;
118 }
119 return LastEvaluationResult->getUntypedTensorValue(0);
120}
121
122std::unique_ptr<ModelUnderTrainingRunner>
123ModelUnderTrainingRunner::createAndEnsureValid(
124 LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
125 const std::vector<TensorSpec> &InputSpecs,
126 StringRef OutputSpecsPathOverride) {
127 if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
128 OutputSpecsPathOverride)) {
129 std::unique_ptr<ModelUnderTrainingRunner> MUTR;
130 std::vector<TensorSpec> OutputSpecs;
131 std::vector<TensorSpec> ExtraOutputsForLogging;
132 append_range(OutputSpecs,
133 map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
134 return LFS.Spec;
135 }));
136 append_range(ExtraOutputsForLogging,
137 map_range(drop_begin(*MaybeOutputSpecs),
138 [](const LoggedFeatureSpec &LFS) {
139 return TensorSpec(LFS.LoggingName
140 ? *LFS.LoggingName
141 : LFS.Spec.name(),
142 LFS.Spec);
143 }));
144
145 MUTR.reset(new ModelUnderTrainingRunner(
146 Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
147 if (MUTR && MUTR->isValid())
148 return MUTR;
149
150 Ctx.emitError("Could not load or create model evaluator.");
151 return nullptr;
152 }
153 Ctx.emitError("Could not load the policy model from the provided path");
154 return nullptr;
155}
156
157#endif // defined(LLVM_HAVE_TFLITE)
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
#define I(x, y, z)
Definition: MD5.cpp:58
#define DecisionName
This file contains some templates that are useful if you are working with the STL at all.
This class evaluates LLVM IR, producing the Constant representing each SSA instruction.
Definition: Evaluator.h:37
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
void emitError(uint64_t LocCookie, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location information.
MLModelRunner interface: abstraction of a mechanism for evaluating a ML model.
Definition: MLModelRunner.h:26
size_t size() const
Definition: SmallVector.h:91
pointer data()
Return a pointer to the vector's buffer, even if empty().
Definition: SmallVector.h:299
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
constexpr bool empty() const
empty - Check if the string is empty.
Definition: StringRef.h:134
const std::string & name() const
Definition: TensorSpec.h:67
bool isElementType() const
Definition: TensorSpec.h:86
LLVM Value Representation.
Definition: Value.h:74
void append(SmallVectorImpl< char > &path, const Twine &a, const Twine &b="", const Twine &c="", const Twine &d="")
Append to path.
Definition: Path.cpp:457
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1680
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2073
std::optional< TensorSpec > getTensorSpecFromJSON(LLVMContext &Ctx, const json::Value &Value)
Construct a TensorSpec from a JSON dictionary of the form: { "name": <string>, "port": <int>,...
Definition: TensorSpec.cpp:69
auto map_range(ContainerTy &&C, FuncTy F)
Definition: STLExtras.h:377