xref: /freebsd/contrib/llvm-project/llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp (revision 4b50c451720d8b427757a6da1dd2bb4c52cd9e35)
//===- AMDGPUMetadataVerifier.cpp - HSA Metadata Verifier -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 #include "llvm/Support/AMDGPUMetadata.h"
16 
17 namespace llvm {
18 namespace AMDGPU {
19 namespace HSAMD {
20 namespace V3 {
21 
22 bool MetadataVerifier::verifyScalar(
23     msgpack::DocNode &Node, msgpack::Type SKind,
24     function_ref<bool(msgpack::DocNode &)> verifyValue) {
25   if (!Node.isScalar())
26     return false;
27   if (Node.getKind() != SKind) {
28     if (Strict)
29       return false;
30     // If we are not strict, we interpret string values as "implicitly typed"
31     // and attempt to coerce them to the expected type here.
32     if (Node.getKind() != msgpack::Type::String)
33       return false;
34     StringRef StringValue = Node.getString();
35     Node.fromString(StringValue);
36     if (Node.getKind() != SKind)
37       return false;
38   }
39   if (verifyValue)
40     return verifyValue(Node);
41   return true;
42 }
43 
44 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
45   if (!verifyScalar(Node, msgpack::Type::UInt))
46     if (!verifyScalar(Node, msgpack::Type::Int))
47       return false;
48   return true;
49 }
50 
51 bool MetadataVerifier::verifyArray(
52     msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
53     Optional<size_t> Size) {
54   if (!Node.isArray())
55     return false;
56   auto &Array = Node.getArray();
57   if (Size && Array.size() != *Size)
58     return false;
59   for (auto &Item : Array)
60     if (!verifyNode(Item))
61       return false;
62 
63   return true;
64 }
65 
66 bool MetadataVerifier::verifyEntry(
67     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
68     function_ref<bool(msgpack::DocNode &)> verifyNode) {
69   auto Entry = MapNode.find(Key);
70   if (Entry == MapNode.end())
71     return !Required;
72   return verifyNode(Entry->second);
73 }
74 
75 bool MetadataVerifier::verifyScalarEntry(
76     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
77     msgpack::Type SKind,
78     function_ref<bool(msgpack::DocNode &)> verifyValue) {
79   return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
80     return verifyScalar(Node, SKind, verifyValue);
81   });
82 }
83 
84 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
85                                           StringRef Key, bool Required) {
86   return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
87     return verifyInteger(Node);
88   });
89 }
90 
91 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
92   if (!Node.isMap())
93     return false;
94   auto &ArgsMap = Node.getMap();
95 
96   if (!verifyScalarEntry(ArgsMap, ".name", false,
97                          msgpack::Type::String))
98     return false;
99   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
100                          msgpack::Type::String))
101     return false;
102   if (!verifyIntegerEntry(ArgsMap, ".size", true))
103     return false;
104   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
105     return false;
106   if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
107                          msgpack::Type::String,
108                          [](msgpack::DocNode &SNode) {
109                            return StringSwitch<bool>(SNode.getString())
110                                .Case("by_value", true)
111                                .Case("global_buffer", true)
112                                .Case("dynamic_shared_pointer", true)
113                                .Case("sampler", true)
114                                .Case("image", true)
115                                .Case("pipe", true)
116                                .Case("queue", true)
117                                .Case("hidden_global_offset_x", true)
118                                .Case("hidden_global_offset_y", true)
119                                .Case("hidden_global_offset_z", true)
120                                .Case("hidden_none", true)
121                                .Case("hidden_printf_buffer", true)
122                                .Case("hidden_default_queue", true)
123                                .Case("hidden_completion_action", true)
124                                .Case("hidden_multigrid_sync_arg", true)
125                                .Default(false);
126                          }))
127     return false;
128   if (!verifyScalarEntry(ArgsMap, ".value_type", true,
129                          msgpack::Type::String,
130                          [](msgpack::DocNode &SNode) {
131                            return StringSwitch<bool>(SNode.getString())
132                                .Case("struct", true)
133                                .Case("i8", true)
134                                .Case("u8", true)
135                                .Case("i16", true)
136                                .Case("u16", true)
137                                .Case("f16", true)
138                                .Case("i32", true)
139                                .Case("u32", true)
140                                .Case("f32", true)
141                                .Case("i64", true)
142                                .Case("u64", true)
143                                .Case("f64", true)
144                                .Default(false);
145                          }))
146     return false;
147   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
148     return false;
149   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
150                          msgpack::Type::String,
151                          [](msgpack::DocNode &SNode) {
152                            return StringSwitch<bool>(SNode.getString())
153                                .Case("private", true)
154                                .Case("global", true)
155                                .Case("constant", true)
156                                .Case("local", true)
157                                .Case("generic", true)
158                                .Case("region", true)
159                                .Default(false);
160                          }))
161     return false;
162   if (!verifyScalarEntry(ArgsMap, ".access", false,
163                          msgpack::Type::String,
164                          [](msgpack::DocNode &SNode) {
165                            return StringSwitch<bool>(SNode.getString())
166                                .Case("read_only", true)
167                                .Case("write_only", true)
168                                .Case("read_write", true)
169                                .Default(false);
170                          }))
171     return false;
172   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
173                          msgpack::Type::String,
174                          [](msgpack::DocNode &SNode) {
175                            return StringSwitch<bool>(SNode.getString())
176                                .Case("read_only", true)
177                                .Case("write_only", true)
178                                .Case("read_write", true)
179                                .Default(false);
180                          }))
181     return false;
182   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
183                          msgpack::Type::Boolean))
184     return false;
185   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
186                          msgpack::Type::Boolean))
187     return false;
188   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
189                          msgpack::Type::Boolean))
190     return false;
191   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
192                          msgpack::Type::Boolean))
193     return false;
194 
195   return true;
196 }
197 
198 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
199   if (!Node.isMap())
200     return false;
201   auto &KernelMap = Node.getMap();
202 
203   if (!verifyScalarEntry(KernelMap, ".name", true,
204                          msgpack::Type::String))
205     return false;
206   if (!verifyScalarEntry(KernelMap, ".symbol", true,
207                          msgpack::Type::String))
208     return false;
209   if (!verifyScalarEntry(KernelMap, ".language", false,
210                          msgpack::Type::String,
211                          [](msgpack::DocNode &SNode) {
212                            return StringSwitch<bool>(SNode.getString())
213                                .Case("OpenCL C", true)
214                                .Case("OpenCL C++", true)
215                                .Case("HCC", true)
216                                .Case("HIP", true)
217                                .Case("OpenMP", true)
218                                .Case("Assembler", true)
219                                .Default(false);
220                          }))
221     return false;
222   if (!verifyEntry(
223           KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
224             return verifyArray(
225                 Node,
226                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
227           }))
228     return false;
229   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
230         return verifyArray(Node, [this](msgpack::DocNode &Node) {
231           return verifyKernelArgs(Node);
232         });
233       }))
234     return false;
235   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
236                    [this](msgpack::DocNode &Node) {
237                      return verifyArray(Node,
238                                         [this](msgpack::DocNode &Node) {
239                                           return verifyInteger(Node);
240                                         },
241                                         3);
242                    }))
243     return false;
244   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
245                    [this](msgpack::DocNode &Node) {
246                      return verifyArray(Node,
247                                         [this](msgpack::DocNode &Node) {
248                                           return verifyInteger(Node);
249                                         },
250                                         3);
251                    }))
252     return false;
253   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
254                          msgpack::Type::String))
255     return false;
256   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
257                          msgpack::Type::String))
258     return false;
259   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
260     return false;
261   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
262     return false;
263   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
264     return false;
265   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
266     return false;
267   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
268     return false;
269   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
270     return false;
271   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
272     return false;
273   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
274     return false;
275   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
276     return false;
277   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
278     return false;
279 
280   return true;
281 }
282 
283 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
284   if (!HSAMetadataRoot.isMap())
285     return false;
286   auto &RootMap = HSAMetadataRoot.getMap();
287 
288   if (!verifyEntry(
289           RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
290             return verifyArray(
291                 Node,
292                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
293           }))
294     return false;
295   if (!verifyEntry(
296           RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
297             return verifyArray(Node, [this](msgpack::DocNode &Node) {
298               return verifyScalar(Node, msgpack::Type::String);
299             });
300           }))
301     return false;
302   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
303                    [this](msgpack::DocNode &Node) {
304                      return verifyArray(Node, [this](msgpack::DocNode &Node) {
305                        return verifyKernel(Node);
306                      });
307                    }))
308     return false;
309 
310   return true;
311 }
312 
313 } // end namespace V3
314 } // end namespace HSAMD
315 } // end namespace AMDGPU
316 } // end namespace llvm
317