1 //===- TensorSpec.cpp - tensor type abstraction ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implementation file for the abstraction of a tensor type, and JSON loading 10 // utils. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "llvm/ADT/STLExtras.h" 14 #include "llvm/Config/config.h" 15 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/ADT/Twine.h" 18 #include "llvm/Analysis/TensorSpec.h" 19 #include "llvm/Support/CommandLine.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/JSON.h" 22 #include "llvm/Support/ManagedStatic.h" 23 #include "llvm/Support/raw_ostream.h" 24 #include <array> 25 #include <cassert> 26 #include <numeric> 27 28 using namespace llvm; 29 30 namespace llvm { 31 32 #define TFUTILS_GETDATATYPE_IMPL(T, E) \ 33 template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; } 34 35 SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) 36 37 #undef TFUTILS_GETDATATYPE_IMPL 38 39 static std::array<std::string, static_cast<size_t>(TensorType::Total)> 40 TensorTypeNames{"INVALID", 41 #define TFUTILS_GETNAME_IMPL(T, _) #T, 42 SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL) 43 #undef TFUTILS_GETNAME_IMPL 44 }; 45 46 StringRef toString(TensorType TT) { 47 return TensorTypeNames[static_cast<size_t>(TT)]; 48 } 49 50 void TensorSpec::toJSON(json::OStream &OS) const { 51 OS.object([&]() { 52 OS.attribute("name", name()); 53 OS.attribute("type", toString(type())); 54 OS.attribute("port", port()); 55 OS.attributeArray("shape", [&]() { 56 for (size_t D : shape()) 57 OS.value(static_cast<int64_t>(D)); 58 }); 59 }); 60 } 61 62 TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type, 63 size_t ElementSize, const std::vector<int64_t> &Shape) 64 : Name(Name), Port(Port), Type(Type), Shape(Shape), 65 ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1, 66 std::multiplies<int64_t>())), 67 ElementSize(ElementSize) {} 68 69 std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, 70 const json::Value &Value) { 71 auto EmitError = 72 [&](const llvm::Twine &Message) -> std::optional<TensorSpec> { 73 std::string S; 74 llvm::raw_string_ostream OS(S); 75 OS << Value; 76 Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S); 77 return std::nullopt; 78 }; 79 // FIXME: accept a Path as a parameter, and use it for error reporting. 80 json::Path::Root Root("tensor_spec"); 81 json::ObjectMapper Mapper(Value, Root); 82 if (!Mapper) 83 return EmitError("Value is not a dict"); 84 85 std::string TensorName; 86 int TensorPort = -1; 87 std::string TensorType; 88 std::vector<int64_t> TensorShape; 89 90 if (!Mapper.map<std::string>("name", TensorName)) 91 return EmitError("'name' property not present or not a string"); 92 if (!Mapper.map<std::string>("type", TensorType)) 93 return EmitError("'type' property not present or not a string"); 94 if (!Mapper.map<int>("port", TensorPort)) 95 return EmitError("'port' property not present or not an int"); 96 if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape)) 97 return EmitError("'shape' property not present or not an int array"); 98 99 #define PARSE_TYPE(T, E) \ 100 if (TensorType == #T) \ 101 return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort); 102 SUPPORTED_TENSOR_TYPES(PARSE_TYPE) 103 #undef PARSE_TYPE 104 return std::nullopt; 105 } 106 107 std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) { 108 switch (Spec.type()) { 109 #define _IMR_DBG_PRINTER(T, N) \ 110 case TensorType::N: { \ 111 const T *TypedBuff = reinterpret_cast<const T *>(Buffer); \ 112 auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \ 113 return llvm::join( \ 114 llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \ 115 } 116 SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER) 117 #undef _IMR_DBG_PRINTER 118 case TensorType::Total: 119 case TensorType::Invalid: 120 llvm_unreachable("invalid tensor type"); 121 } 122 // To appease warnings about not all control paths returning a value. 123 return ""; 124 } 125 126 } // namespace llvm 127