LLVM 22.0.0git
TensorSpec.cpp
Go to the documentation of this file.
1//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation file for the abstraction of a tensor type, and JSON loading
10// utils.
11//
12//===----------------------------------------------------------------------===//
13#include "llvm/ADT/STLExtras.h"
14#include "llvm/Config/config.h"
15
17#include "llvm/ADT/Twine.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/JSON.h"
24#include <array>
25#include <cassert>
26#include <numeric>
27
28using namespace llvm;
29
30namespace llvm {
31
32#define TFUTILS_GETDATATYPE_IMPL(T, E) \
33 template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
34
36
37#undef TFUTILS_GETDATATYPE_IMPL
38
39static std::array<std::string, static_cast<size_t>(TensorType::Total)>
40 TensorTypeNames{"INVALID",
41#define TFUTILS_GETNAME_IMPL(T, _) #T,
43#undef TFUTILS_GETNAME_IMPL
44 };
45
47 return TensorTypeNames[static_cast<size_t>(TT)];
48}
49
51 OS.object([&]() {
52 OS.attribute("name", name());
53 OS.attribute("type", toString(type()));
54 OS.attribute("port", port());
55 OS.attributeArray("shape", [&]() {
56 for (size_t D : shape())
57 OS.value(static_cast<int64_t>(D));
58 });
59 });
60}
61
62TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
63 size_t ElementSize, const std::vector<int64_t> &Shape)
64 : Name(Name), Port(Port), Type(Type), Shape(Shape),
65 ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
66 std::multiplies<int64_t>())),
67 ElementSize(ElementSize) {}
68
69std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
70 const json::Value &Value) {
71 auto EmitError =
72 [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
73 std::string S;
75 OS << Value;
76 Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
77 return std::nullopt;
78 };
79 // FIXME: accept a Path as a parameter, and use it for error reporting.
80 json::Path::Root Root("tensor_spec");
81 json::ObjectMapper Mapper(Value, Root);
82 if (!Mapper)
83 return EmitError("Value is not a dict");
84
85 std::string TensorName;
86 int TensorPort = -1;
87 std::string TensorType;
88 std::vector<int64_t> TensorShape;
89
90 if (!Mapper.map<std::string>("name", TensorName))
91 return EmitError("'name' property not present or not a string");
92 if (!Mapper.map<std::string>("type", TensorType))
93 return EmitError("'type' property not present or not a string");
94 if (!Mapper.map<int>("port", TensorPort))
95 return EmitError("'port' property not present or not an int");
96 if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
97 return EmitError("'shape' property not present or not an int array");
98
99#define PARSE_TYPE(T, E) \
100 if (TensorType == #T) \
101 return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
103#undef PARSE_TYPE
104 return std::nullopt;
105}
106
107std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
108 switch (Spec.type()) {
109#define _IMR_DBG_PRINTER(T, N) \
110 case TensorType::N: { \
111 const T *TypedBuff = reinterpret_cast<const T *>(Buffer); \
112 auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \
113 return llvm::join( \
114 llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \
115 }
117#undef _IMR_DBG_PRINTER
120 llvm_unreachable("invalid tensor type");
121 }
122 // To appease warnings about not all control paths returning a value.
123 return "";
124}
125
126} // namespace llvm
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
This file supports working with JSON data.
This file contains some templates that are useful if you are working with the STL at all.
This file contains some functions that are useful when dealing with strings.
#define PARSE_TYPE(T, E)
#define TFUTILS_GETDATATYPE_IMPL(T, E)
#define TFUTILS_GETNAME_IMPL(T, _)
#define _IMR_DBG_PRINTER(T, N)
#define SUPPORTED_TENSOR_TYPES(M)
TensorSpec encapsulates the specification of a tensor: its dimensions, or "shape" (row-major),...
Definition TensorSpec.h:43
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition StringRef.h:55
const std::string & name() const
Definition TensorSpec.h:72
TensorType type() const
Definition TensorSpec.h:74
const std::vector< int64_t > & shape() const
Definition TensorSpec.h:75
int port() const
Definition TensorSpec.h:73
TensorSpec(const std::string &NewName, const TensorSpec &Other)
Definition TensorSpec.h:95
LLVM_ABI void toJSON(json::OStream &OS) const
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
Definition Twine.h:82
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
LLVM Value Representation.
Definition Value.h:75
json::OStream allows writing well-formed JSON without materializing all structures as json::Value ahead of time.
Definition JSON.h:998
void object(Block Contents)
Emit an object whose elements are emitted in the provided Block.
Definition JSON.h:1028
void attribute(llvm::StringRef Key, const Value &Contents)
Emit an attribute whose value is self-contained (number, vector<int> etc).
Definition JSON.h:1053
void attributeArray(llvm::StringRef Key, Block Contents)
Emit an attribute whose value is an array with elements from the Block.
Definition JSON.h:1057
LLVM_ABI void value(const Value &V)
Emit a self-contained value (number, string, vector<string> etc).
Definition JSON.cpp:747
Helper for mapping JSON objects onto protocol structs.
Definition JSON.h:861
The root is the trivial Path to the root value.
Definition JSON.h:713
A Value is a JSON value of unknown type.
Definition JSON.h:290
A raw_ostream that writes to an std::string.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
This is an optimization pass for GlobalISel generic memory operations.
auto accumulate(R &&Range, E &&Init)
Wrapper for std::accumulate.
Definition STLExtras.h:1690
LLVM_ABI std::optional< TensorSpec > getTensorSpecFromJSON(LLVMContext &Ctx, const json::Value &Value)
Construct a TensorSpec from a JSON dictionary of the form: { "name": <string>, "port": <int>,...
static std::array< std::string, static_cast< size_t >(TensorType::Total)> TensorTypeNames
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
LLVM_ABI std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec)
For debugging.
TensorType
Definition TensorSpec.h:55
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:851