LLVM 18.0.0git
ModelUnderTrainingRunner.h
Go to the documentation of this file.
1//===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9
10#ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
11#define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
12
13#include "llvm/ADT/STLExtras.h"
16#include "llvm/Config/llvm-config.h"
17
18#ifdef LLVM_HAVE_TFLITE
21#include "llvm/IR/LLVMContext.h"
22#include "llvm/IR/PassManager.h"
23
24namespace llvm {
25
26/// ModelUnderTrainingRunner - training mode implementation. It uses TF C APIs
27/// to dynamically load and evaluate a TF SavedModel
28/// (https://www.tensorflow.org/guide/saved_model). Runtime performance is
29/// sacrificed for ease of use while training.
30class ModelUnderTrainingRunner final : public MLModelRunner {
31public:
32 // Disallows copy and assign.
33 ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete;
34 ModelUnderTrainingRunner &
35 operator=(const ModelUnderTrainingRunner &) = delete;
36
37 const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const {
38 return ExtraOutputsForLogging;
39 }
40
41 const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const {
42 return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1);
43 }
44
45 const std::optional<TFModelEvaluator::EvaluationResult> &
46 lastEvaluationResult() const {
47 return LastEvaluationResult;
48 }
49 static bool classof(const MLModelRunner *R) {
50 return R->getKind() == MLModelRunner::Kind::Development;
51 }
52
53 static std::unique_ptr<ModelUnderTrainingRunner>
54 createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath,
55 StringRef DecisionName,
56 const std::vector<TensorSpec> &InputSpecs,
57 StringRef OutputSpecsPathOverride = "");
58
59 ModelUnderTrainingRunner(
60 LLVMContext &Ctx, const std::string &ModelPath,
61 const std::vector<TensorSpec> &InputSpecs,
62 const std::vector<TensorSpec> &OutputSpecs,
63 const std::vector<TensorSpec> &ExtraOutputsForLogging = {});
64
65 bool isValid() const { return !!Evaluator; }
66
67private:
68 std::unique_ptr<TFModelEvaluator> Evaluator;
69 const std::vector<TensorSpec> OutputSpecs;
70 const std::vector<TensorSpec> ExtraOutputsForLogging;
71 std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
72 void *evaluateUntyped() override;
73};
74
75} // namespace llvm
76#endif // defined(LLVM_HAVE_TFLITE)
77#endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
aa Exhaustive Alias Analysis Precision Evaluator
#define DecisionName
This header defines various interfaces for pass management in LLVM.
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
This file contains some templates that are useful if you are working with the STL at all.
This provides a very simple, boring adaptor for a begin and end iterator into a range type.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18