xref: /freebsd/contrib/llvm-project/llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp (revision 1165fc9a526630487a1feb63daef65c5aee1a583)
1 //===- AMDGPUMetadataVerifier.cpp - AMDGPU HSA metadata verifier -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/STLForwardCompat.h"
18 #include "llvm/ADT/StringSwitch.h"
19 #include "llvm/BinaryFormat/MsgPackDocument.h"
20 
21 #include <map>
22 #include <utility>
23 
24 namespace llvm {
25 namespace AMDGPU {
26 namespace HSAMD {
27 namespace V3 {
28 
29 bool MetadataVerifier::verifyScalar(
30     msgpack::DocNode &Node, msgpack::Type SKind,
31     function_ref<bool(msgpack::DocNode &)> verifyValue) {
32   if (!Node.isScalar())
33     return false;
34   if (Node.getKind() != SKind) {
35     if (Strict)
36       return false;
37     // If we are not strict, we interpret string values as "implicitly typed"
38     // and attempt to coerce them to the expected type here.
39     if (Node.getKind() != msgpack::Type::String)
40       return false;
41     StringRef StringValue = Node.getString();
42     Node.fromString(StringValue);
43     if (Node.getKind() != SKind)
44       return false;
45   }
46   if (verifyValue)
47     return verifyValue(Node);
48   return true;
49 }
50 
51 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
52   if (!verifyScalar(Node, msgpack::Type::UInt))
53     if (!verifyScalar(Node, msgpack::Type::Int))
54       return false;
55   return true;
56 }
57 
58 bool MetadataVerifier::verifyArray(
59     msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
60     Optional<size_t> Size) {
61   if (!Node.isArray())
62     return false;
63   auto &Array = Node.getArray();
64   if (Size && Array.size() != *Size)
65     return false;
66   return llvm::all_of(Array, verifyNode);
67 }
68 
69 bool MetadataVerifier::verifyEntry(
70     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
71     function_ref<bool(msgpack::DocNode &)> verifyNode) {
72   auto Entry = MapNode.find(Key);
73   if (Entry == MapNode.end())
74     return !Required;
75   return verifyNode(Entry->second);
76 }
77 
78 bool MetadataVerifier::verifyScalarEntry(
79     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
80     msgpack::Type SKind,
81     function_ref<bool(msgpack::DocNode &)> verifyValue) {
82   return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
83     return verifyScalar(Node, SKind, verifyValue);
84   });
85 }
86 
87 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
88                                           StringRef Key, bool Required) {
89   return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
90     return verifyInteger(Node);
91   });
92 }
93 
94 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
95   if (!Node.isMap())
96     return false;
97   auto &ArgsMap = Node.getMap();
98 
99   if (!verifyScalarEntry(ArgsMap, ".name", false,
100                          msgpack::Type::String))
101     return false;
102   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
103                          msgpack::Type::String))
104     return false;
105   if (!verifyIntegerEntry(ArgsMap, ".size", true))
106     return false;
107   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
108     return false;
109   if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
110                          msgpack::Type::String,
111                          [](msgpack::DocNode &SNode) {
112                            return StringSwitch<bool>(SNode.getString())
113                                .Case("by_value", true)
114                                .Case("global_buffer", true)
115                                .Case("dynamic_shared_pointer", true)
116                                .Case("sampler", true)
117                                .Case("image", true)
118                                .Case("pipe", true)
119                                .Case("queue", true)
120                                .Case("hidden_block_count_x", true)
121                                .Case("hidden_block_count_y", true)
122                                .Case("hidden_block_count_z", true)
123                                .Case("hidden_group_size_x", true)
124                                .Case("hidden_group_size_y", true)
125                                .Case("hidden_group_size_z", true)
126                                .Case("hidden_remainder_x", true)
127                                .Case("hidden_remainder_y", true)
128                                .Case("hidden_remainder_z", true)
129                                .Case("hidden_global_offset_x", true)
130                                .Case("hidden_global_offset_y", true)
131                                .Case("hidden_global_offset_z", true)
132                                .Case("hidden_grid_dims", true)
133                                .Case("hidden_none", true)
134                                .Case("hidden_printf_buffer", true)
135                                .Case("hidden_hostcall_buffer", true)
136                                .Case("hidden_default_queue", true)
137                                .Case("hidden_completion_action", true)
138                                .Case("hidden_multigrid_sync_arg", true)
139                                .Case("hidden_private_base", true)
140                                .Case("hidden_shared_base", true)
141                                .Case("hidden_queue_ptr", true)
142                                .Default(false);
143                          }))
144     return false;
145   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
146     return false;
147   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
148                          msgpack::Type::String,
149                          [](msgpack::DocNode &SNode) {
150                            return StringSwitch<bool>(SNode.getString())
151                                .Case("private", true)
152                                .Case("global", true)
153                                .Case("constant", true)
154                                .Case("local", true)
155                                .Case("generic", true)
156                                .Case("region", true)
157                                .Default(false);
158                          }))
159     return false;
160   if (!verifyScalarEntry(ArgsMap, ".access", false,
161                          msgpack::Type::String,
162                          [](msgpack::DocNode &SNode) {
163                            return StringSwitch<bool>(SNode.getString())
164                                .Case("read_only", true)
165                                .Case("write_only", true)
166                                .Case("read_write", true)
167                                .Default(false);
168                          }))
169     return false;
170   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
171                          msgpack::Type::String,
172                          [](msgpack::DocNode &SNode) {
173                            return StringSwitch<bool>(SNode.getString())
174                                .Case("read_only", true)
175                                .Case("write_only", true)
176                                .Case("read_write", true)
177                                .Default(false);
178                          }))
179     return false;
180   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
181                          msgpack::Type::Boolean))
182     return false;
183   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
184                          msgpack::Type::Boolean))
185     return false;
186   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
187                          msgpack::Type::Boolean))
188     return false;
189   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
190                          msgpack::Type::Boolean))
191     return false;
192 
193   return true;
194 }
195 
196 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
197   if (!Node.isMap())
198     return false;
199   auto &KernelMap = Node.getMap();
200 
201   if (!verifyScalarEntry(KernelMap, ".name", true,
202                          msgpack::Type::String))
203     return false;
204   if (!verifyScalarEntry(KernelMap, ".symbol", true,
205                          msgpack::Type::String))
206     return false;
207   if (!verifyScalarEntry(KernelMap, ".language", false,
208                          msgpack::Type::String,
209                          [](msgpack::DocNode &SNode) {
210                            return StringSwitch<bool>(SNode.getString())
211                                .Case("OpenCL C", true)
212                                .Case("OpenCL C++", true)
213                                .Case("HCC", true)
214                                .Case("HIP", true)
215                                .Case("OpenMP", true)
216                                .Case("Assembler", true)
217                                .Default(false);
218                          }))
219     return false;
220   if (!verifyEntry(
221           KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
222             return verifyArray(
223                 Node,
224                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
225           }))
226     return false;
227   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
228         return verifyArray(Node, [this](msgpack::DocNode &Node) {
229           return verifyKernelArgs(Node);
230         });
231       }))
232     return false;
233   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
234                    [this](msgpack::DocNode &Node) {
235                      return verifyArray(Node,
236                                         [this](msgpack::DocNode &Node) {
237                                           return verifyInteger(Node);
238                                         },
239                                         3);
240                    }))
241     return false;
242   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
243                    [this](msgpack::DocNode &Node) {
244                      return verifyArray(Node,
245                                         [this](msgpack::DocNode &Node) {
246                                           return verifyInteger(Node);
247                                         },
248                                         3);
249                    }))
250     return false;
251   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
252                          msgpack::Type::String))
253     return false;
254   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
255                          msgpack::Type::String))
256     return false;
257   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
258     return false;
259   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
260     return false;
261   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
262     return false;
263   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
264     return false;
265   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
266     return false;
267   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
268     return false;
269   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
270     return false;
271   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
272     return false;
273   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
274     return false;
275   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
276     return false;
277 
278   return true;
279 }
280 
281 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
282   if (!HSAMetadataRoot.isMap())
283     return false;
284   auto &RootMap = HSAMetadataRoot.getMap();
285 
286   if (!verifyEntry(
287           RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
288             return verifyArray(
289                 Node,
290                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
291           }))
292     return false;
293   if (!verifyEntry(
294           RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
295             return verifyArray(Node, [this](msgpack::DocNode &Node) {
296               return verifyScalar(Node, msgpack::Type::String);
297             });
298           }))
299     return false;
300   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
301                    [this](msgpack::DocNode &Node) {
302                      return verifyArray(Node, [this](msgpack::DocNode &Node) {
303                        return verifyKernel(Node);
304                      });
305                    }))
306     return false;
307 
308   return true;
309 }
310 
311 } // end namespace V3
312 } // end namespace HSAMD
313 } // end namespace AMDGPU
314 } // end namespace llvm
315