181ad6265SDimitry Andric //===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
281ad6265SDimitry Andric //
381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681ad6265SDimitry Andric //
781ad6265SDimitry Andric //===----------------------------------------------------------------------===//
881ad6265SDimitry Andric //
981ad6265SDimitry Andric // Implementation file for the abstraction of a tensor type, and JSON loading
1081ad6265SDimitry Andric // utils.
1181ad6265SDimitry Andric //
1281ad6265SDimitry Andric //===----------------------------------------------------------------------===//
13*06c3fb27SDimitry Andric #include "llvm/ADT/STLExtras.h"
1481ad6265SDimitry Andric #include "llvm/Config/config.h"
1581ad6265SDimitry Andric
16*06c3fb27SDimitry Andric #include "llvm/ADT/StringExtras.h"
1781ad6265SDimitry Andric #include "llvm/ADT/Twine.h"
1881ad6265SDimitry Andric #include "llvm/Analysis/TensorSpec.h"
1981ad6265SDimitry Andric #include "llvm/Support/CommandLine.h"
2081ad6265SDimitry Andric #include "llvm/Support/Debug.h"
2181ad6265SDimitry Andric #include "llvm/Support/JSON.h"
2281ad6265SDimitry Andric #include "llvm/Support/ManagedStatic.h"
2381ad6265SDimitry Andric #include "llvm/Support/raw_ostream.h"
24bdd1243dSDimitry Andric #include <array>
2581ad6265SDimitry Andric #include <cassert>
2681ad6265SDimitry Andric #include <numeric>
2781ad6265SDimitry Andric
2881ad6265SDimitry Andric using namespace llvm;
2981ad6265SDimitry Andric
3081ad6265SDimitry Andric namespace llvm {
3181ad6265SDimitry Andric
3281ad6265SDimitry Andric #define TFUTILS_GETDATATYPE_IMPL(T, E) \
3381ad6265SDimitry Andric template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
3481ad6265SDimitry Andric
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)3581ad6265SDimitry Andric SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
3681ad6265SDimitry Andric
3781ad6265SDimitry Andric #undef TFUTILS_GETDATATYPE_IMPL
3881ad6265SDimitry Andric
39bdd1243dSDimitry Andric static std::array<std::string, static_cast<size_t>(TensorType::Total)>
40bdd1243dSDimitry Andric TensorTypeNames{"INVALID",
41bdd1243dSDimitry Andric #define TFUTILS_GETNAME_IMPL(T, _) #T,
42bdd1243dSDimitry Andric SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
43bdd1243dSDimitry Andric #undef TFUTILS_GETNAME_IMPL
44bdd1243dSDimitry Andric };
45bdd1243dSDimitry Andric
toString(TensorType TT)46bdd1243dSDimitry Andric StringRef toString(TensorType TT) {
47bdd1243dSDimitry Andric return TensorTypeNames[static_cast<size_t>(TT)];
48bdd1243dSDimitry Andric }
49bdd1243dSDimitry Andric
toJSON(json::OStream & OS) const50bdd1243dSDimitry Andric void TensorSpec::toJSON(json::OStream &OS) const {
51bdd1243dSDimitry Andric OS.object([&]() {
52bdd1243dSDimitry Andric OS.attribute("name", name());
53bdd1243dSDimitry Andric OS.attribute("type", toString(type()));
54bdd1243dSDimitry Andric OS.attribute("port", port());
55bdd1243dSDimitry Andric OS.attributeArray("shape", [&]() {
56bdd1243dSDimitry Andric for (size_t D : shape())
57bdd1243dSDimitry Andric OS.value(static_cast<int64_t>(D));
58bdd1243dSDimitry Andric });
59bdd1243dSDimitry Andric });
60bdd1243dSDimitry Andric }
61bdd1243dSDimitry Andric
TensorSpec(const std::string & Name,int Port,TensorType Type,size_t ElementSize,const std::vector<int64_t> & Shape)6281ad6265SDimitry Andric TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
6381ad6265SDimitry Andric size_t ElementSize, const std::vector<int64_t> &Shape)
6481ad6265SDimitry Andric : Name(Name), Port(Port), Type(Type), Shape(Shape),
6581ad6265SDimitry Andric ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
6681ad6265SDimitry Andric std::multiplies<int64_t>())),
6781ad6265SDimitry Andric ElementSize(ElementSize) {}
6881ad6265SDimitry Andric
getTensorSpecFromJSON(LLVMContext & Ctx,const json::Value & Value)69bdd1243dSDimitry Andric std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
7081ad6265SDimitry Andric const json::Value &Value) {
71bdd1243dSDimitry Andric auto EmitError =
72bdd1243dSDimitry Andric [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
7381ad6265SDimitry Andric std::string S;
7481ad6265SDimitry Andric llvm::raw_string_ostream OS(S);
7581ad6265SDimitry Andric OS << Value;
7681ad6265SDimitry Andric Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
77bdd1243dSDimitry Andric return std::nullopt;
7881ad6265SDimitry Andric };
7981ad6265SDimitry Andric // FIXME: accept a Path as a parameter, and use it for error reporting.
8081ad6265SDimitry Andric json::Path::Root Root("tensor_spec");
8181ad6265SDimitry Andric json::ObjectMapper Mapper(Value, Root);
8281ad6265SDimitry Andric if (!Mapper)
8381ad6265SDimitry Andric return EmitError("Value is not a dict");
8481ad6265SDimitry Andric
8581ad6265SDimitry Andric std::string TensorName;
8681ad6265SDimitry Andric int TensorPort = -1;
8781ad6265SDimitry Andric std::string TensorType;
8881ad6265SDimitry Andric std::vector<int64_t> TensorShape;
8981ad6265SDimitry Andric
9081ad6265SDimitry Andric if (!Mapper.map<std::string>("name", TensorName))
9181ad6265SDimitry Andric return EmitError("'name' property not present or not a string");
9281ad6265SDimitry Andric if (!Mapper.map<std::string>("type", TensorType))
9381ad6265SDimitry Andric return EmitError("'type' property not present or not a string");
9481ad6265SDimitry Andric if (!Mapper.map<int>("port", TensorPort))
9581ad6265SDimitry Andric return EmitError("'port' property not present or not an int");
9681ad6265SDimitry Andric if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
9781ad6265SDimitry Andric return EmitError("'shape' property not present or not an int array");
9881ad6265SDimitry Andric
9981ad6265SDimitry Andric #define PARSE_TYPE(T, E) \
10081ad6265SDimitry Andric if (TensorType == #T) \
10181ad6265SDimitry Andric return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
10281ad6265SDimitry Andric SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
10381ad6265SDimitry Andric #undef PARSE_TYPE
104bdd1243dSDimitry Andric return std::nullopt;
10581ad6265SDimitry Andric }
10681ad6265SDimitry Andric
tensorValueToString(const char * Buffer,const TensorSpec & Spec)107*06c3fb27SDimitry Andric std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
108*06c3fb27SDimitry Andric switch (Spec.type()) {
109*06c3fb27SDimitry Andric #define _IMR_DBG_PRINTER(T, N) \
110*06c3fb27SDimitry Andric case TensorType::N: { \
111*06c3fb27SDimitry Andric const T *TypedBuff = reinterpret_cast<const T *>(Buffer); \
112*06c3fb27SDimitry Andric auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \
113*06c3fb27SDimitry Andric return llvm::join( \
114*06c3fb27SDimitry Andric llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \
115*06c3fb27SDimitry Andric }
116*06c3fb27SDimitry Andric SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER)
117*06c3fb27SDimitry Andric #undef _IMR_DBG_PRINTER
118*06c3fb27SDimitry Andric case TensorType::Total:
119*06c3fb27SDimitry Andric case TensorType::Invalid:
120*06c3fb27SDimitry Andric llvm_unreachable("invalid tensor type");
121*06c3fb27SDimitry Andric }
122*06c3fb27SDimitry Andric // To appease warnings about not all control paths returning a value.
123*06c3fb27SDimitry Andric return "";
124*06c3fb27SDimitry Andric }
125*06c3fb27SDimitry Andric
12681ad6265SDimitry Andric } // namespace llvm
127