//===- DXILEmitter.cpp - DXIL operation Emitter ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// DXILEmitter uses the descriptions of DXIL operation to construct enum and
// helper functions for DXIL operation.
//
//===----------------------------------------------------------------------===//

#include "Basic/SequenceToOffsetTable.h"
#include "Common/CodeGenTarget.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/Support/DXILABI.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include <string>

using namespace llvm;
using namespace llvm::dxil;

namespace {

/// A DXIL shader model version expressed as Major.Minor; both parts default
/// to 0.
struct DXILShaderModel {
  int Major = 0;
  int Minor = 0;
};

/// Information about one DXIL operation, extracted from a DXILOpMapping
/// TableGen record by the DXILOperationDesc(const Record *) constructor and
/// consumed by the emit* functions of this backend.
struct DXILOperationDesc {
  std::string OpName; // name of DXIL operation (the TableGen record name)
  int OpCode;         // ID of DXIL operation
  StringRef OpClass;  // name of the opcode class
  StringRef Doc;      // the documentation description of this instruction
  // Operand type records: the return type is at index 0 and the parameter
  // types follow. LLVMMatchType references are stored already resolved to
  // the type record they refer to.
  SmallVector<Record *> OpTypes;
  // Properties (e.g. "IntrNoMem") of the mapped LLVM intrinsic, as strings.
  SmallVector<std::string> OpAttributes;
  // Name of the LLVM intrinsic mapped to OpName, with the "int_" prefix
  // stripped. Empty ("") means no mapping exists.
  StringRef Intrinsic;
  // NOTE(review): IsDeriv .. ShaderModelTranslated and `counters` are not set
  // by the constructor in this file; they keep their default values.
  bool IsDeriv = false;    // whether this is some kind of derivative
  bool IsGradient = false; // whether this requires a gradient calculation
  bool IsFeedback = false; // whether this is a sampler feedback op
  // whether this requires in-wave, cross-lane functionality
  bool IsWave = false;
  // whether this operation requires that all of its inputs are uniform
  // across the wave
  bool RequiresUniformInputs = false;
  // shader stages to which this applies, empty for all.
  SmallVector<StringRef, 4> ShaderStages;
  DXILShaderModel ShaderModel;           // minimum shader model required
  DXILShaderModel ShaderModelTranslated; // minimum shader model required with
                                         // translation by linker
  // Index into OpTypes of the overloaded type; -1 when the operation has no
  // overload type.
  int OverloadParamIndex;
  SmallVector<StringRef, 4> counters; // counters for this inst.
  DXILOperationDesc(const Record *);
};
} // end anonymous namespace

/// Return the dxil::ParameterKind corresponding to an LLVMType record.
///
/// Fixed scalar types map to the matching ParameterKind. Overloaded types
/// (iAny/fAny, and the DXIL-specific isHalfOrFloat / isI16OrI32 types that
/// use MVT::Other) map to ParameterKind::Overload. Any other value type is
/// unsupported and aborts via llvm_unreachable.
///
/// \param R TableGen def record of class LLVMType
/// \return ParameterKind As defined in llvm/Support/DXILABI.h
static ParameterKind getParameterKind(const Record *R) {
  auto VTRec = R->getValueAsDef("VT");
  switch (getValueType(VTRec)) {
  case MVT::isVoid:
    return ParameterKind::Void;
  case MVT::f16:
    return ParameterKind::Half;
  case MVT::f32:
    return ParameterKind::Float;
  case MVT::f64:
    return ParameterKind::Double;
  case MVT::i1:
    return ParameterKind::I1;
  case MVT::i8:
    return ParameterKind::I8;
  case MVT::i16:
    return ParameterKind::I16;
  case MVT::i32:
    return ParameterKind::I32;
  case MVT::i64:
    // ParameterKind::I64 exists and getOverloadKindStr already accepts i64;
    // without this case an i64 operand would hit llvm_unreachable.
    return ParameterKind::I64;
  case MVT::fAny:
  case MVT::iAny:
    return ParameterKind::Overload;
  case MVT::Other:
    // Handle DXIL-specific overload types
    if (R->getValueAsInt("isHalfOrFloat") || R->getValueAsInt("isI16OrI32")) {
      return ParameterKind::Overload;
    }
    [[fallthrough]];
  default:
    llvm_unreachable("Support for specified DXIL Type not yet implemented");
  }
}

/// Construct an object using the DXIL Operation records specified
/// in DXIL.td. This serves as the single source of reference of
/// the information extracted from the specified Record R, for
/// C++ code generated by this TableGen backend.
///
/// OpTypes receives the return type followed by the parameter types, with
/// each LLVMMatchType reference replaced by the overloaded type it names.
/// Aborts via report_fatal_error when the record uses differing overload
/// types, or an LLVMMatchType that does not name a previously listed
/// overloaded type.
///
/// \param R Object representing TableGen record of a DXIL Operation
DXILOperationDesc::DXILOperationDesc(const Record *R) {
  OpName = R->getNameInitAsString();
  OpCode = R->getValueAsInt("OpCode");

  Doc = R->getValueAsString("Doc");

  auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
  unsigned TypeRecsSize = TypeRecs.size();

  // Indices into TypeRecs of the overloaded types, in the order LLVMMatchType
  // numbers them. Per the comment before class LLVMMatchType in
  // llvm/IR/Intrinsics.td, LLVMMatchType<N> refers to the N-th *overloaded*
  // type, not to the N-th entry of the type list.
  SmallVector<int> OverloadParamIndices;
  for (unsigned i = 0; i < TypeRecsSize; i++) {
    auto TR = TypeRecs[i];
    // Track operation parameter indices of any overload types
    if (TR->getValueAsInt("isAny") == 1) {
      // TODO: At present it is expected that all overload types in a DXIL Op
      // are of the same type. Hence, OverloadParamIndices will have only one
      // element. This implies we do not need a vector. However, until more
      // (all?) DXIL Ops are added in DXIL.td, a vector is being used to flag
      // cases this assumption would not hold.
      if (OverloadParamIndices.empty()) {
        OverloadParamIndices.push_back(i);
      } else if (llvm::any_of(OverloadParamIndices,
                              [&](int Idx) { return TR != TypeRecs[Idx]; })) {
        // A repeat of the already-registered overload type is accepted;
        // a different overload type is not.
        report_fatal_error("Specification of multiple differing overload "
                           "parameter types not yet supported",
                           false);
      }
    }
    // Populate OpTypes array according to the type specification
    if (TR->isAnonymous()) {
      // An anonymous type is an LLVMMatchType<N>; resolve N through the list
      // of overloaded types seen so far. Checked in all build modes so a bad
      // record cannot index out of range.
      auto MatchNum = TR->getValueAsInt("Number");
      if (MatchNum < 0 ||
          static_cast<uint64_t>(MatchNum) >= OverloadParamIndices.size())
        report_fatal_error(("LLVMMatchType<" + std::to_string(MatchNum) +
                            "> in DXIL operation " + OpName +
                            " does not refer to a prior overloaded type")
                               .c_str(),
                           false);
      OpTypes.emplace_back(TypeRecs[OverloadParamIndices[MatchNum]]);
    } else {
      // A non-anonymous type. Just record it in OpTypes
      OpTypes.emplace_back(TR);
    }
  }

  // Set the index of the overload parameter, if any (-1 indicates none).
  OverloadParamIndex = -1;
  if (!OverloadParamIndices.empty()) {
    if (OverloadParamIndices.size() > 1)
      report_fatal_error("Multiple overload type specification not supported",
                         false);
    OverloadParamIndex = OverloadParamIndices[0];
  }
  // Get the operation class
  OpClass = R->getValueAsDef("OpClass")->getName();

  if (R->getValue("LLVMIntrinsic")) {
    auto *IntrinsicDef = R->getValueAsDef("LLVMIntrinsic");
    auto DefName = IntrinsicDef->getName();
    assert(DefName.starts_with("int_") && "invalid intrinsic name");
    // Remove the int_ from intrinsic name.
    Intrinsic = DefName.substr(4);
    // TODO: For now, assume that attributes of DXIL Operation are the same as
    // that of the intrinsic. Deviations are expected to be encoded in TableGen
    // record specification and handled accordingly here. Support to be added
    // as needed.
    for (const Init *Prop :
         IntrinsicDef->getValueAsListInit("IntrProperties")->getValues())
      OpAttributes.emplace_back(Prop->getAsString());
  }
}

/// Spell a ParameterKind enumerator by name, without the "ParameterKind::"
/// qualifier (e.g. ParameterKind::Float -> "Float").
/// \param Kind Parameter Kind enum value to spell
/// \return std::string the enumerator's name
static std::string getParameterKindStr(ParameterKind Kind) {
  const char *Spelling = nullptr;
  switch (Kind) {
  case ParameterKind::Invalid:
    Spelling = "Invalid";
    break;
  case ParameterKind::Void:
    Spelling = "Void";
    break;
  case ParameterKind::Half:
    Spelling = "Half";
    break;
  case ParameterKind::Float:
    Spelling = "Float";
    break;
  case ParameterKind::Double:
    Spelling = "Double";
    break;
  case ParameterKind::I1:
    Spelling = "I1";
    break;
  case ParameterKind::I8:
    Spelling = "I8";
    break;
  case ParameterKind::I16:
    Spelling = "I16";
    break;
  case ParameterKind::I32:
    Spelling = "I32";
    break;
  case ParameterKind::I64:
    Spelling = "I64";
    break;
  case ParameterKind::Overload:
    Spelling = "Overload";
    break;
  case ParameterKind::CBufferRet:
    Spelling = "CBufferRet";
    break;
  case ParameterKind::ResourceRet:
    Spelling = "ResourceRet";
    break;
  case ParameterKind::DXILHandle:
    Spelling = "DXILHandle";
    break;
  }
  // Only an out-of-range enum value leaves Spelling unset.
  if (!Spelling)
    llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
  return Spelling;
}

/// Return the C++ spelling of the OverloadKind flag(s) for the type described
/// by an LLVMType record. Fixed types yield a single flag; overloaded types
/// yield the bitwise OR of every flag they may resolve to. Unsupported types
/// abort via llvm_unreachable.
/// \param R TableGen def record of class LLVMType
/// \return std::string expression of one or more OverloadKind flags
static std::string getOverloadKindStr(const Record *R) {
  const auto SimpleVT = getValueType(R->getValueAsDef("VT"));
  switch (SimpleVT) {
  // Fixed (non-overloaded) types.
  case MVT::isVoid:
    return "OverloadKind::VOID";
  case MVT::f16:
    return "OverloadKind::HALF";
  case MVT::f32:
    return "OverloadKind::FLOAT";
  case MVT::f64:
    return "OverloadKind::DOUBLE";
  case MVT::i1:
    return "OverloadKind::I1";
  case MVT::i8:
    return "OverloadKind::I8";
  case MVT::i16:
    return "OverloadKind::I16";
  case MVT::i32:
    return "OverloadKind::I32";
  case MVT::i64:
    return "OverloadKind::I64";
  // Generic LLVM overload types.
  case MVT::iAny:
    return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
  case MVT::fAny:
    return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
  // DXIL-specific overload types are MVT::Other tagged by a record flag.
  case MVT::Other:
    if (R->getValueAsInt("isHalfOrFloat"))
      return "OverloadKind::HALF | OverloadKind::FLOAT";
    if (R->getValueAsInt("isI16OrI32"))
      return "OverloadKind::I16 | OverloadKind::I32";
    [[fallthrough]];
  default:
    llvm_unreachable(
        "Support for specified parameter OverloadKind not yet implemented");
  }
}

/// Emit Enums of DXIL Ops
/// \param A vector of DXIL Ops
/// \param Output stream
static void emitDXILEnums(std::vector<DXILOperationDesc> &Ops,
                          raw_ostream &OS) {
  // Sort by OpCode
  llvm::sort(Ops, [](DXILOperationDesc &A, DXILOperationDesc &B) {
    return A.OpCode < B.OpCode;
  });

  OS << "// Enumeration for operations specified by DXIL\n";
  OS << "enum class OpCode : unsigned {\n";

  for (auto &Op : Ops) {
    // Name = ID, // Doc
    OS << Op.OpName << " = " << Op.OpCode << ", // " << Op.Doc << "\n";
  }

  OS << "\n};\n\n";

  OS << "// Groups for DXIL operations with equivalent function templates\n";
  OS << "enum class OpCodeClass : unsigned {\n";
  // Build an OpClass set to print
  SmallSet<StringRef, 2> OpClassSet;
  for (auto &Op : Ops) {
    OpClassSet.insert(Op.OpClass);
  }
  for (auto &C : OpClassSet) {
    OS << C << ",\n";
  }
  OS << "\n};\n\n";
}

/// Emit map of DXIL operation to LLVM or DirectX intrinsic
/// \param A vector of DXIL Ops
/// \param Output stream
static void emitDXILIntrinsicMap(std::vector<DXILOperationDesc> &Ops,
                                 raw_ostream &OS) {
  OS << "\n";
  // FIXME: use array instead of SmallDenseMap.
  OS << "static const SmallDenseMap<Intrinsic::ID, dxil::OpCode> LowerMap = "
        "{\n";
  for (auto &Op : Ops) {
    if (Op.Intrinsic.empty())
      continue;
    // {Intrinsic::sin, dxil::OpCode::Sin},
    OS << "  { Intrinsic::" << Op.Intrinsic << ", dxil::OpCode::" << Op.OpName
       << "},\n";
  }
  OS << "};\n";
  OS << "\n";
}

/// Convert operation attribute string to Attribute enum
///
/// \param Attr string reference
/// \return std::string Attribute enum string

static std::string emitDXILOperationAttr(SmallVector<std::string> Attrs) {
  for (auto Attr : Attrs) {
    // TODO: For now just recognize IntrNoMem and IntrReadMem as valid and
    //  ignore others.
    if (Attr == "IntrNoMem") {
      return "Attribute::ReadNone";
    } else if (Attr == "IntrReadMem") {
      return "Attribute::ReadOnly";
    }
  }
  return "Attribute::None";
}

/// Emit DXIL operation table
/// \param A vector of DXIL Ops
/// \param Output stream
static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
                                   raw_ostream &OS) {
  // Sort by OpCode.
  llvm::sort(Ops, [](DXILOperationDesc &A, DXILOperationDesc &B) {
    return A.OpCode < B.OpCode;
  });

  // Collect Names.
  SequenceToOffsetTable<std::string> OpClassStrings;
  SequenceToOffsetTable<std::string> OpStrings;
  SequenceToOffsetTable<SmallVector<ParameterKind>> Parameters;

  StringMap<SmallVector<ParameterKind>> ParameterMap;
  StringSet<> ClassSet;
  for (auto &Op : Ops) {
    OpStrings.add(Op.OpName);

    if (ClassSet.contains(Op.OpClass))
      continue;
    ClassSet.insert(Op.OpClass);
    OpClassStrings.add(Op.OpClass.data());
    SmallVector<ParameterKind> ParamKindVec;
    // ParamKindVec is a vector of parameters. Skip return type at index 0
    for (unsigned i = 1; i < Op.OpTypes.size(); i++) {
      ParamKindVec.emplace_back(getParameterKind(Op.OpTypes[i]));
    }
    ParameterMap[Op.OpClass] = ParamKindVec;
    Parameters.add(ParamKindVec);
  }

  // Layout names.
  OpStrings.layout();
  OpClassStrings.layout();
  Parameters.layout();

  // Emit the DXIL operation table.
  //{dxil::OpCode::Sin, OpCodeNameIndex, OpCodeClass::unary,
  // OpCodeClassNameIndex,
  // OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone, 0,
  // 3, ParameterTableOffset},
  OS << "static const OpCodeProperty *getOpCodeProperty(dxil::OpCode Op) "
        "{\n";

  OS << "  static const OpCodeProperty OpCodeProps[] = {\n";
  for (auto &Op : Ops) {
    // Consider Op.OverloadParamIndex as the overload parameter index, by
    // default
    auto OLParamIdx = Op.OverloadParamIndex;
    // If no overload parameter index is set, treat first parameter type as
    // overload type - unless the Op has no parameters, in which case treat the
    // return type - as overload parameter to emit the appropriate overload kind
    // enum.
    if (OLParamIdx < 0) {
      OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
    }
    OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
       << ", OpCodeClass::" << Op.OpClass << ", "
       << OpClassStrings.get(Op.OpClass.data()) << ", "
       << getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", "
       << emitDXILOperationAttr(Op.OpAttributes) << ", "
       << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
       << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
  }
  OS << "  };\n";

  OS << "  // FIXME: change search to indexing with\n";
  OS << "  // Op once all DXIL operations are added.\n";
  OS << "  OpCodeProperty TmpProp;\n";
  OS << "  TmpProp.OpCode = Op;\n";
  OS << "  const OpCodeProperty *Prop =\n";
  OS << "      llvm::lower_bound(OpCodeProps, TmpProp,\n";
  OS << "                        [](const OpCodeProperty &A, const "
        "OpCodeProperty &B) {\n";
  OS << "                          return A.OpCode < B.OpCode;\n";
  OS << "                        });\n";
  OS << "  assert(Prop && \"failed to find OpCodeProperty\");\n";
  OS << "  return Prop;\n";
  OS << "}\n\n";

  // Emit the string tables.
  OS << "static const char *getOpCodeName(dxil::OpCode Op) {\n\n";

  OpStrings.emitStringLiteralDef(OS,
                                 "  static const char DXILOpCodeNameTable[]");

  OS << "  auto *Prop = getOpCodeProperty(Op);\n";
  OS << "  unsigned Index = Prop->OpCodeNameOffset;\n";
  OS << "  return DXILOpCodeNameTable + Index;\n";
  OS << "}\n\n";

  OS << "static const char *getOpCodeClassName(const OpCodeProperty &Prop) "
        "{\n\n";

  OpClassStrings.emitStringLiteralDef(
      OS, "  static const char DXILOpCodeClassNameTable[]");

  OS << "  unsigned Index = Prop.OpCodeClassNameOffset;\n";
  OS << "  return DXILOpCodeClassNameTable + Index;\n";
  OS << "}\n ";

  OS << "static const ParameterKind *getOpCodeParameterKind(const "
        "OpCodeProperty &Prop) "
        "{\n\n";
  OS << "  static const ParameterKind DXILOpParameterKindTable[] = {\n";
  Parameters.emit(
      OS,
      [](raw_ostream &ParamOS, ParameterKind Kind) {
        ParamOS << "ParameterKind::" << getParameterKindStr(Kind);
      },
      "ParameterKind::Invalid");
  OS << "  };\n\n";
  OS << "  unsigned Index = Prop.ParameterTableOffset;\n";
  OS << "  return DXILOpParameterKindTable + Index;\n";
  OS << "}\n ";
}

/// Entry point of this TableGen backend. It builds a DXILOperationDesc for
/// every DXILOpMapping record, then emits the enum, intrinsic-map and
/// operation-table sections. Each section is wrapped in its own
/// DXIL_OP_ENUM / DXIL_OP_INTRINSIC_MAP / DXIL_OP_OPERATION_TABLE guard.
/// \param Records TableGen records of DXIL Operations defined in DXIL.td
/// \param OS output stream
static void EmitDXILOperation(RecordKeeper &Records, raw_ostream &OS) {
  OS << "// Generated code, do not edit.\n";
  OS << "\n";
  // Get all DXIL Ops to intrinsic mapping records
  std::vector<Record *> OpIntrMaps =
      Records.getAllDerivedDefinitions("DXILOpMapping");
  std::vector<DXILOperationDesc> DXILOps;
  DXILOps.reserve(OpIntrMaps.size());
  // Construct each descriptor in place rather than moving a temporary. The
  // loop variable is not named `Record`, so it does not shadow llvm::Record.
  for (const Record *OpRec : OpIntrMaps)
    DXILOps.emplace_back(OpRec);
  OS << "#ifdef DXIL_OP_ENUM\n";
  emitDXILEnums(DXILOps, OS);
  OS << "#endif\n\n";
  OS << "#ifdef DXIL_OP_INTRINSIC_MAP\n";
  emitDXILIntrinsicMap(DXILOps, OS);
  OS << "#endif\n\n";
  OS << "#ifdef DXIL_OP_OPERATION_TABLE\n";
  emitDXILOperationTable(DXILOps, OS);
  OS << "#endif\n\n";
}

// Register EmitDXILOperation with TableGen under the backend name
// "gen-dxil-operation".
static TableGen::Emitter::Opt X{"gen-dxil-operation", EmitDXILOperation,
                                "Generate DXIL operation information"};