xref: /freebsd/contrib/llvm-project/llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp (revision 96b917bfcf2c558340f8f6e620e64efededa1daf)
1  //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  /// \file
10  /// Implements a verifier for AMDGPU HSA metadata.
11  //
12  //===----------------------------------------------------------------------===//
13  
14  #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15  
16  #include "llvm/ADT/STLExtras.h"
17  #include "llvm/ADT/STLForwardCompat.h"
18  #include "llvm/ADT/StringSwitch.h"
19  #include "llvm/BinaryFormat/MsgPackDocument.h"
20  
21  #include <map>
22  #include <utility>
23  
24  namespace llvm {
25  namespace AMDGPU {
26  namespace HSAMD {
27  namespace V3 {
28  
29  bool MetadataVerifier::verifyScalar(
30      msgpack::DocNode &Node, msgpack::Type SKind,
31      function_ref<bool(msgpack::DocNode &)> verifyValue) {
32    if (!Node.isScalar())
33      return false;
34    if (Node.getKind() != SKind) {
35      if (Strict)
36        return false;
37      // If we are not strict, we interpret string values as "implicitly typed"
38      // and attempt to coerce them to the expected type here.
39      if (Node.getKind() != msgpack::Type::String)
40        return false;
41      StringRef StringValue = Node.getString();
42      Node.fromString(StringValue);
43      if (Node.getKind() != SKind)
44        return false;
45    }
46    if (verifyValue)
47      return verifyValue(Node);
48    return true;
49  }
50  
51  bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
52    if (!verifyScalar(Node, msgpack::Type::UInt))
53      if (!verifyScalar(Node, msgpack::Type::Int))
54        return false;
55    return true;
56  }
57  
58  bool MetadataVerifier::verifyArray(
59      msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
60      Optional<size_t> Size) {
61    if (!Node.isArray())
62      return false;
63    auto &Array = Node.getArray();
64    if (Size && Array.size() != *Size)
65      return false;
66    return llvm::all_of(Array, verifyNode);
67  }
68  
69  bool MetadataVerifier::verifyEntry(
70      msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
71      function_ref<bool(msgpack::DocNode &)> verifyNode) {
72    auto Entry = MapNode.find(Key);
73    if (Entry == MapNode.end())
74      return !Required;
75    return verifyNode(Entry->second);
76  }
77  
78  bool MetadataVerifier::verifyScalarEntry(
79      msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
80      msgpack::Type SKind,
81      function_ref<bool(msgpack::DocNode &)> verifyValue) {
82    return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
83      return verifyScalar(Node, SKind, verifyValue);
84    });
85  }
86  
87  bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
88                                            StringRef Key, bool Required) {
89    return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
90      return verifyInteger(Node);
91    });
92  }
93  
94  bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
95    if (!Node.isMap())
96      return false;
97    auto &ArgsMap = Node.getMap();
98  
99    if (!verifyScalarEntry(ArgsMap, ".name", false,
100                           msgpack::Type::String))
101      return false;
102    if (!verifyScalarEntry(ArgsMap, ".type_name", false,
103                           msgpack::Type::String))
104      return false;
105    if (!verifyIntegerEntry(ArgsMap, ".size", true))
106      return false;
107    if (!verifyIntegerEntry(ArgsMap, ".offset", true))
108      return false;
109    if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
110                           msgpack::Type::String,
111                           [](msgpack::DocNode &SNode) {
112                             return StringSwitch<bool>(SNode.getString())
113                                 .Case("by_value", true)
114                                 .Case("global_buffer", true)
115                                 .Case("dynamic_shared_pointer", true)
116                                 .Case("sampler", true)
117                                 .Case("image", true)
118                                 .Case("pipe", true)
119                                 .Case("queue", true)
120                                 .Case("hidden_block_count_x", true)
121                                 .Case("hidden_block_count_y", true)
122                                 .Case("hidden_block_count_z", true)
123                                 .Case("hidden_group_size_x", true)
124                                 .Case("hidden_group_size_y", true)
125                                 .Case("hidden_group_size_z", true)
126                                 .Case("hidden_remainder_x", true)
127                                 .Case("hidden_remainder_y", true)
128                                 .Case("hidden_remainder_z", true)
129                                 .Case("hidden_global_offset_x", true)
130                                 .Case("hidden_global_offset_y", true)
131                                 .Case("hidden_global_offset_z", true)
132                                 .Case("hidden_grid_dims", true)
133                                 .Case("hidden_none", true)
134                                 .Case("hidden_printf_buffer", true)
135                                 .Case("hidden_hostcall_buffer", true)
136                                 .Case("hidden_default_queue", true)
137                                 .Case("hidden_completion_action", true)
138                                 .Case("hidden_multigrid_sync_arg", true)
139                                 .Case("hidden_private_base", true)
140                                 .Case("hidden_shared_base", true)
141                                 .Case("hidden_queue_ptr", true)
142                                 .Default(false);
143                           }))
144      return false;
145    if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
146      return false;
147    if (!verifyScalarEntry(ArgsMap, ".address_space", false,
148                           msgpack::Type::String,
149                           [](msgpack::DocNode &SNode) {
150                             return StringSwitch<bool>(SNode.getString())
151                                 .Case("private", true)
152                                 .Case("global", true)
153                                 .Case("constant", true)
154                                 .Case("local", true)
155                                 .Case("generic", true)
156                                 .Case("region", true)
157                                 .Default(false);
158                           }))
159      return false;
160    if (!verifyScalarEntry(ArgsMap, ".access", false,
161                           msgpack::Type::String,
162                           [](msgpack::DocNode &SNode) {
163                             return StringSwitch<bool>(SNode.getString())
164                                 .Case("read_only", true)
165                                 .Case("write_only", true)
166                                 .Case("read_write", true)
167                                 .Default(false);
168                           }))
169      return false;
170    if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
171                           msgpack::Type::String,
172                           [](msgpack::DocNode &SNode) {
173                             return StringSwitch<bool>(SNode.getString())
174                                 .Case("read_only", true)
175                                 .Case("write_only", true)
176                                 .Case("read_write", true)
177                                 .Default(false);
178                           }))
179      return false;
180    if (!verifyScalarEntry(ArgsMap, ".is_const", false,
181                           msgpack::Type::Boolean))
182      return false;
183    if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
184                           msgpack::Type::Boolean))
185      return false;
186    if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
187                           msgpack::Type::Boolean))
188      return false;
189    if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
190                           msgpack::Type::Boolean))
191      return false;
192  
193    return true;
194  }
195  
196  bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
197    if (!Node.isMap())
198      return false;
199    auto &KernelMap = Node.getMap();
200  
201    if (!verifyScalarEntry(KernelMap, ".name", true,
202                           msgpack::Type::String))
203      return false;
204    if (!verifyScalarEntry(KernelMap, ".symbol", true,
205                           msgpack::Type::String))
206      return false;
207    if (!verifyScalarEntry(KernelMap, ".language", false,
208                           msgpack::Type::String,
209                           [](msgpack::DocNode &SNode) {
210                             return StringSwitch<bool>(SNode.getString())
211                                 .Case("OpenCL C", true)
212                                 .Case("OpenCL C++", true)
213                                 .Case("HCC", true)
214                                 .Case("HIP", true)
215                                 .Case("OpenMP", true)
216                                 .Case("Assembler", true)
217                                 .Default(false);
218                           }))
219      return false;
220    if (!verifyEntry(
221            KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
222              return verifyArray(
223                  Node,
224                  [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
225            }))
226      return false;
227    if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
228          return verifyArray(Node, [this](msgpack::DocNode &Node) {
229            return verifyKernelArgs(Node);
230          });
231        }))
232      return false;
233    if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
234                     [this](msgpack::DocNode &Node) {
235                       return verifyArray(Node,
236                                          [this](msgpack::DocNode &Node) {
237                                            return verifyInteger(Node);
238                                          },
239                                          3);
240                     }))
241      return false;
242    if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
243                     [this](msgpack::DocNode &Node) {
244                       return verifyArray(Node,
245                                          [this](msgpack::DocNode &Node) {
246                                            return verifyInteger(Node);
247                                          },
248                                          3);
249                     }))
250      return false;
251    if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
252                           msgpack::Type::String))
253      return false;
254    if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
255                           msgpack::Type::String))
256      return false;
257    if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
258      return false;
259    if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
260      return false;
261    if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
262      return false;
263    if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
264      return false;
265    if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
266      return false;
267    if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
268      return false;
269    if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
270      return false;
271    if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
272      return false;
273    if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
274      return false;
275    if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
276      return false;
277  
278    return true;
279  }
280  
281  bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
282    if (!HSAMetadataRoot.isMap())
283      return false;
284    auto &RootMap = HSAMetadataRoot.getMap();
285  
286    if (!verifyEntry(
287            RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
288              return verifyArray(
289                  Node,
290                  [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
291            }))
292      return false;
293    if (!verifyEntry(
294            RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
295              return verifyArray(Node, [this](msgpack::DocNode &Node) {
296                return verifyScalar(Node, msgpack::Type::String);
297              });
298            }))
299      return false;
300    if (!verifyEntry(RootMap, "amdhsa.kernels", true,
301                     [this](msgpack::DocNode &Node) {
302                       return verifyArray(Node, [this](msgpack::DocNode &Node) {
303                         return verifyKernel(Node);
304                       });
305                     }))
306      return false;
307  
308    return true;
309  }
310  
311  } // end namespace V3
312  } // end namespace HSAMD
313  } // end namespace AMDGPU
314  } // end namespace llvm
315