xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===--- SPIRVCommandLine.cpp ---- Command Line Options ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of classes and functions needed for
10 // processing, parsing, and using CLI options for the SPIR-V backend.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRVCommandLine.h"
15 #include "llvm/ADT/StringRef.h"
16 #include <algorithm>
17 #include <map>
18 
19 #define DEBUG_TYPE "spirv-commandline"
20 
21 using namespace llvm;
22 
23 static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
24     SPIRVExtensionMap = {
25         {"SPV_EXT_shader_atomic_float_add",
26          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_add},
27         {"SPV_EXT_shader_atomic_float16_add",
28          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add},
29         {"SPV_EXT_shader_atomic_float_min_max",
30          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
31         {"SPV_EXT_arithmetic_fence",
32          SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
33         {"SPV_EXT_demote_to_helper_invocation",
34          SPIRV::Extension::Extension::SPV_EXT_demote_to_helper_invocation},
35         {"SPV_INTEL_arbitrary_precision_integers",
36          SPIRV::Extension::Extension::SPV_INTEL_arbitrary_precision_integers},
37         {"SPV_INTEL_cache_controls",
38          SPIRV::Extension::Extension::SPV_INTEL_cache_controls},
39         {"SPV_INTEL_float_controls2",
40          SPIRV::Extension::Extension::SPV_INTEL_float_controls2},
41         {"SPV_INTEL_global_variable_fpga_decorations",
42          SPIRV::Extension::Extension::
43              SPV_INTEL_global_variable_fpga_decorations},
44         {"SPV_INTEL_global_variable_host_access",
45          SPIRV::Extension::Extension::SPV_INTEL_global_variable_host_access},
46         {"SPV_INTEL_optnone", SPIRV::Extension::Extension::SPV_INTEL_optnone},
47         {"SPV_EXT_optnone", SPIRV::Extension::Extension::SPV_EXT_optnone},
48         {"SPV_INTEL_usm_storage_classes",
49          SPIRV::Extension::Extension::SPV_INTEL_usm_storage_classes},
50         {"SPV_INTEL_split_barrier",
51          SPIRV::Extension::Extension::SPV_INTEL_split_barrier},
52         {"SPV_INTEL_subgroups",
53          SPIRV::Extension::Extension::SPV_INTEL_subgroups},
54         {"SPV_INTEL_media_block_io",
55          SPIRV::Extension::Extension::SPV_INTEL_media_block_io},
56         {"SPV_INTEL_memory_access_aliasing",
57          SPIRV::Extension::Extension::SPV_INTEL_memory_access_aliasing},
58         {"SPV_INTEL_joint_matrix",
59          SPIRV::Extension::Extension::SPV_INTEL_joint_matrix},
60         {"SPV_KHR_uniform_group_instructions",
61          SPIRV::Extension::Extension::SPV_KHR_uniform_group_instructions},
62         {"SPV_KHR_no_integer_wrap_decoration",
63          SPIRV::Extension::Extension::SPV_KHR_no_integer_wrap_decoration},
64         {"SPV_KHR_float_controls",
65          SPIRV::Extension::Extension::SPV_KHR_float_controls},
66         {"SPV_KHR_expect_assume",
67          SPIRV::Extension::Extension::SPV_KHR_expect_assume},
68         {"SPV_KHR_bit_instructions",
69          SPIRV::Extension::Extension::SPV_KHR_bit_instructions},
70         {"SPV_KHR_integer_dot_product",
71          SPIRV::Extension::Extension::SPV_KHR_integer_dot_product},
72         {"SPV_KHR_linkonce_odr",
73          SPIRV::Extension::Extension::SPV_KHR_linkonce_odr},
74         {"SPV_INTEL_inline_assembly",
75          SPIRV::Extension::Extension::SPV_INTEL_inline_assembly},
76         {"SPV_INTEL_bindless_images",
77          SPIRV::Extension::Extension::SPV_INTEL_bindless_images},
78         {"SPV_INTEL_bfloat16_conversion",
79          SPIRV::Extension::Extension::SPV_INTEL_bfloat16_conversion},
80         {"SPV_KHR_subgroup_rotate",
81          SPIRV::Extension::Extension::SPV_KHR_subgroup_rotate},
82         {"SPV_INTEL_variable_length_array",
83          SPIRV::Extension::Extension::SPV_INTEL_variable_length_array},
84         {"SPV_INTEL_function_pointers",
85          SPIRV::Extension::Extension::SPV_INTEL_function_pointers},
86         {"SPV_KHR_shader_clock",
87          SPIRV::Extension::Extension::SPV_KHR_shader_clock},
88         {"SPV_KHR_cooperative_matrix",
89          SPIRV::Extension::Extension::SPV_KHR_cooperative_matrix},
90         {"SPV_KHR_non_semantic_info",
91          SPIRV::Extension::Extension::SPV_KHR_non_semantic_info},
92         {"SPV_INTEL_long_composites",
93          SPIRV::Extension::Extension::SPV_INTEL_long_composites},
94         {"SPV_INTEL_fp_max_error",
95          SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
96         {"SPV_INTEL_subgroup_matrix_multiply_accumulate",
97          SPIRV::Extension::Extension::
98              SPV_INTEL_subgroup_matrix_multiply_accumulate},
99         {"SPV_INTEL_ternary_bitwise_function",
100          SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
101         {"SPV_INTEL_2d_block_io",
102          SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
103         {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
104         {"SPV_KHR_float_controls2",
105          SPIRV::Extension::Extension::SPV_KHR_float_controls2}};
106 
107 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
108                                   StringRef ArgValue,
109                                   std::set<SPIRV::Extension::Extension> &Vals) {
110   SmallVector<StringRef, 10> Tokens;
111   ArgValue.split(Tokens, ",", -1, false);
112   std::sort(Tokens.begin(), Tokens.end());
113 
114   std::set<SPIRV::Extension::Extension> EnabledExtensions;
115 
116   for (const auto &Token : Tokens) {
117     if (Token == "all") {
118       for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
119         EnabledExtensions.insert(ExtensionEnum);
120 
121       continue;
122     }
123 
124     if (Token.size() == 3 && Token.upper() == "KHR") {
125       for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
126         if (StringRef(ExtensionName).starts_with("SPV_KHR_"))
127           EnabledExtensions.insert(ExtensionEnum);
128       continue;
129     }
130 
131     if (Token.empty() || (!Token.starts_with("+") && !Token.starts_with("-")))
132       return O.error("Invalid extension list format: " + Token.str());
133 
134     StringRef ExtensionName = Token.substr(1);
135     auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
136 
137     if (NameValuePair == SPIRVExtensionMap.end())
138       return O.error("Unknown SPIR-V extension: " + Token.str());
139 
140     if (Token.starts_with("+")) {
141       EnabledExtensions.insert(NameValuePair->second);
142     } else if (EnabledExtensions.count(NameValuePair->second)) {
143       if (llvm::is_contained(Tokens, "+" + ExtensionName.str()))
144         return O.error(
145             "Extension cannot be allowed and disallowed at the same time: " +
146             ExtensionName.str());
147 
148       EnabledExtensions.erase(NameValuePair->second);
149     }
150   }
151 
152   Vals = std::move(EnabledExtensions);
153   return false;
154 }
155 
156 StringRef SPIRVExtensionsParser::checkExtensions(
157     const std::vector<std::string> &ExtNames,
158     std::set<SPIRV::Extension::Extension> &AllowedExtensions) {
159   for (const auto &Ext : ExtNames) {
160     if (Ext == "all") {
161       for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
162         AllowedExtensions.insert(ExtensionEnum);
163       break;
164     }
165     auto It = SPIRVExtensionMap.find(Ext);
166     if (It == SPIRVExtensionMap.end())
167       return Ext;
168     AllowedExtensions.insert(It->second);
169   }
170   return StringRef();
171 }
172