1 //===- TensorSpec.cpp - tensor type abstraction ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implementation file for the abstraction of a tensor type, and JSON loading 10 // utils. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "llvm/Config/config.h" 14 15 #include "llvm/ADT/Twine.h" 16 #include "llvm/Analysis/TensorSpec.h" 17 #include "llvm/Support/CommandLine.h" 18 #include "llvm/Support/Debug.h" 19 #include "llvm/Support/JSON.h" 20 #include "llvm/Support/ManagedStatic.h" 21 #include "llvm/Support/raw_ostream.h" 22 #include <array> 23 #include <cassert> 24 #include <numeric> 25 26 using namespace llvm; 27 28 namespace llvm { 29 30 #define TFUTILS_GETDATATYPE_IMPL(T, E) \ 31 template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; } 32 33 SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) 34 35 #undef TFUTILS_GETDATATYPE_IMPL 36 37 static std::array<std::string, static_cast<size_t>(TensorType::Total)> 38 TensorTypeNames{"INVALID", 39 #define TFUTILS_GETNAME_IMPL(T, _) #T, 40 SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL) 41 #undef TFUTILS_GETNAME_IMPL 42 }; 43 44 StringRef toString(TensorType TT) { 45 return TensorTypeNames[static_cast<size_t>(TT)]; 46 } 47 48 void TensorSpec::toJSON(json::OStream &OS) const { 49 OS.object([&]() { 50 OS.attribute("name", name()); 51 OS.attribute("type", toString(type())); 52 OS.attribute("port", port()); 53 OS.attributeArray("shape", [&]() { 54 for (size_t D : shape()) 55 OS.value(static_cast<int64_t>(D)); 56 }); 57 }); 58 } 59 60 TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type, 61 size_t ElementSize, const std::vector<int64_t> &Shape) 62 : Name(Name), Port(Port), Type(Type), Shape(Shape), 63 ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1, 64 std::multiplies<int64_t>())), 65 ElementSize(ElementSize) {} 66 67 std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, 68 const json::Value &Value) { 69 auto EmitError = 70 [&](const llvm::Twine &Message) -> std::optional<TensorSpec> { 71 std::string S; 72 llvm::raw_string_ostream OS(S); 73 OS << Value; 74 Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S); 75 return std::nullopt; 76 }; 77 // FIXME: accept a Path as a parameter, and use it for error reporting. 78 json::Path::Root Root("tensor_spec"); 79 json::ObjectMapper Mapper(Value, Root); 80 if (!Mapper) 81 return EmitError("Value is not a dict"); 82 83 std::string TensorName; 84 int TensorPort = -1; 85 std::string TensorType; 86 std::vector<int64_t> TensorShape; 87 88 if (!Mapper.map<std::string>("name", TensorName)) 89 return EmitError("'name' property not present or not a string"); 90 if (!Mapper.map<std::string>("type", TensorType)) 91 return EmitError("'type' property not present or not a string"); 92 if (!Mapper.map<int>("port", TensorPort)) 93 return EmitError("'port' property not present or not an int"); 94 if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape)) 95 return EmitError("'shape' property not present or not an int array"); 96 97 #define PARSE_TYPE(T, E) \ 98 if (TensorType == #T) \ 99 return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort); 100 SUPPORTED_TENSOR_TYPES(PARSE_TYPE) 101 #undef PARSE_TYPE 102 return std::nullopt; 103 } 104 105 } // namespace llvm 106