xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/TensorSpec.cpp (revision 8311bc5f17dec348749f763b82dfe2737bc53cd7)
1 //===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation file for the abstraction of a tensor type, and JSON loading
10 // utils.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Config/config.h"
15 
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Analysis/TensorSpec.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/JSON.h"
22 #include "llvm/Support/ManagedStatic.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include <array>
25 #include <cassert>
26 #include <numeric>
27 
28 using namespace llvm;
29 
30 namespace llvm {
31 
32 #define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
33   template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
34 
35 SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
36 
37 #undef TFUTILS_GETDATATYPE_IMPL
38 
39 static std::array<std::string, static_cast<size_t>(TensorType::Total)>
40     TensorTypeNames{"INVALID",
41 #define TFUTILS_GETNAME_IMPL(T, _) #T,
42                     SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
43 #undef TFUTILS_GETNAME_IMPL
44     };
45 
46 StringRef toString(TensorType TT) {
47   return TensorTypeNames[static_cast<size_t>(TT)];
48 }
49 
50 void TensorSpec::toJSON(json::OStream &OS) const {
51   OS.object([&]() {
52     OS.attribute("name", name());
53     OS.attribute("type", toString(type()));
54     OS.attribute("port", port());
55     OS.attributeArray("shape", [&]() {
56       for (size_t D : shape())
57         OS.value(static_cast<int64_t>(D));
58     });
59   });
60 }
61 
62 TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
63                        size_t ElementSize, const std::vector<int64_t> &Shape)
64     : Name(Name), Port(Port), Type(Type), Shape(Shape),
65       ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
66                                    std::multiplies<int64_t>())),
67       ElementSize(ElementSize) {}
68 
69 std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
70                                                 const json::Value &Value) {
71   auto EmitError =
72       [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
73     std::string S;
74     llvm::raw_string_ostream OS(S);
75     OS << Value;
76     Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
77     return std::nullopt;
78   };
79   // FIXME: accept a Path as a parameter, and use it for error reporting.
80   json::Path::Root Root("tensor_spec");
81   json::ObjectMapper Mapper(Value, Root);
82   if (!Mapper)
83     return EmitError("Value is not a dict");
84 
85   std::string TensorName;
86   int TensorPort = -1;
87   std::string TensorType;
88   std::vector<int64_t> TensorShape;
89 
90   if (!Mapper.map<std::string>("name", TensorName))
91     return EmitError("'name' property not present or not a string");
92   if (!Mapper.map<std::string>("type", TensorType))
93     return EmitError("'type' property not present or not a string");
94   if (!Mapper.map<int>("port", TensorPort))
95     return EmitError("'port' property not present or not an int");
96   if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
97     return EmitError("'shape' property not present or not an int array");
98 
99 #define PARSE_TYPE(T, E)                                                       \
100   if (TensorType == #T)                                                        \
101     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
102   SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
103 #undef PARSE_TYPE
104   return std::nullopt;
105 }
106 
107 std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
108   switch (Spec.type()) {
109 #define _IMR_DBG_PRINTER(T, N)                                                 \
110   case TensorType::N: {                                                        \
111     const T *TypedBuff = reinterpret_cast<const T *>(Buffer);                  \
112     auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount());  \
113     return llvm::join(                                                         \
114         llvm::map_range(R, [](T V) { return std::to_string(V); }), ",");       \
115   }
116     SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER)
117 #undef _IMR_DBG_PRINTER
118   case TensorType::Total:
119   case TensorType::Invalid:
120     llvm_unreachable("invalid tensor type");
121   }
122   // To appease warnings about not all control paths returning a value.
123   return "";
124 }
125 
126 } // namespace llvm
127