1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file 10 /// Implements a verifier for AMDGPU HSA metadata. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h" 15 16 #include "llvm/ADT/STLExtras.h" 17 #include "llvm/ADT/STLForwardCompat.h" 18 #include "llvm/ADT/StringSwitch.h" 19 #include "llvm/BinaryFormat/MsgPackDocument.h" 20 21 #include <map> 22 #include <utility> 23 24 namespace llvm { 25 namespace AMDGPU { 26 namespace HSAMD { 27 namespace V3 { 28 29 bool MetadataVerifier::verifyScalar( 30 msgpack::DocNode &Node, msgpack::Type SKind, 31 function_ref<bool(msgpack::DocNode &)> verifyValue) { 32 if (!Node.isScalar()) 33 return false; 34 if (Node.getKind() != SKind) { 35 if (Strict) 36 return false; 37 // If we are not strict, we interpret string values as "implicitly typed" 38 // and attempt to coerce them to the expected type here. 39 if (Node.getKind() != msgpack::Type::String) 40 return false; 41 StringRef StringValue = Node.getString(); 42 Node.fromString(StringValue); 43 if (Node.getKind() != SKind) 44 return false; 45 } 46 if (verifyValue) 47 return verifyValue(Node); 48 return true; 49 } 50 51 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) { 52 if (!verifyScalar(Node, msgpack::Type::UInt)) 53 if (!verifyScalar(Node, msgpack::Type::Int)) 54 return false; 55 return true; 56 } 57 58 bool MetadataVerifier::verifyArray( 59 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode, 60 Optional<size_t> Size) { 61 if (!Node.isArray()) 62 return false; 63 auto &Array = Node.getArray(); 64 if (Size && Array.size() != *Size) 65 return false; 66 return llvm::all_of(Array, verifyNode); 67 } 68 69 bool MetadataVerifier::verifyEntry( 70 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 71 function_ref<bool(msgpack::DocNode &)> verifyNode) { 72 auto Entry = MapNode.find(Key); 73 if (Entry == MapNode.end()) 74 return !Required; 75 return verifyNode(Entry->second); 76 } 77 78 bool MetadataVerifier::verifyScalarEntry( 79 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 80 msgpack::Type SKind, 81 function_ref<bool(msgpack::DocNode &)> verifyValue) { 82 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) { 83 return verifyScalar(Node, SKind, verifyValue); 84 }); 85 } 86 87 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode, 88 StringRef Key, bool Required) { 89 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) { 90 return verifyInteger(Node); 91 }); 92 } 93 94 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) { 95 if (!Node.isMap()) 96 return false; 97 auto &ArgsMap = Node.getMap(); 98 99 if (!verifyScalarEntry(ArgsMap, ".name", false, 100 msgpack::Type::String)) 101 return false; 102 if (!verifyScalarEntry(ArgsMap, ".type_name", false, 103 msgpack::Type::String)) 104 return false; 105 if (!verifyIntegerEntry(ArgsMap, ".size", true)) 106 return false; 107 if (!verifyIntegerEntry(ArgsMap, ".offset", true)) 108 return false; 109 if (!verifyScalarEntry(ArgsMap, ".value_kind", true, msgpack::Type::String, 110 [](msgpack::DocNode &SNode) { 111 return StringSwitch<bool>(SNode.getString()) 112 .Case("by_value", true) 113 .Case("global_buffer", true) 114 .Case("dynamic_shared_pointer", true) 115 .Case("sampler", true) 116 .Case("image", true) 117 .Case("pipe", true) 118 .Case("queue", true) 119 .Case("hidden_block_count_x", true) 120 .Case("hidden_block_count_y", true) 121 .Case("hidden_block_count_z", true) 122 .Case("hidden_group_size_x", true) 123 .Case("hidden_group_size_y", true) 124 .Case("hidden_group_size_z", true) 125 .Case("hidden_remainder_x", true) 126 .Case("hidden_remainder_y", true) 127 .Case("hidden_remainder_z", true) 128 .Case("hidden_global_offset_x", true) 129 .Case("hidden_global_offset_y", true) 130 .Case("hidden_global_offset_z", true) 131 .Case("hidden_grid_dims", true) 132 .Case("hidden_none", true) 133 .Case("hidden_printf_buffer", true) 134 .Case("hidden_hostcall_buffer", true) 135 .Case("hidden_heap_v1", true) 136 .Case("hidden_default_queue", true) 137 .Case("hidden_completion_action", true) 138 .Case("hidden_multigrid_sync_arg", true) 139 .Case("hidden_private_base", true) 140 .Case("hidden_shared_base", true) 141 .Case("hidden_queue_ptr", true) 142 .Default(false); 143 })) 144 return false; 145 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false)) 146 return false; 147 if (!verifyScalarEntry(ArgsMap, ".address_space", false, 148 msgpack::Type::String, 149 [](msgpack::DocNode &SNode) { 150 return StringSwitch<bool>(SNode.getString()) 151 .Case("private", true) 152 .Case("global", true) 153 .Case("constant", true) 154 .Case("local", true) 155 .Case("generic", true) 156 .Case("region", true) 157 .Default(false); 158 })) 159 return false; 160 if (!verifyScalarEntry(ArgsMap, ".access", false, 161 msgpack::Type::String, 162 [](msgpack::DocNode &SNode) { 163 return StringSwitch<bool>(SNode.getString()) 164 .Case("read_only", true) 165 .Case("write_only", true) 166 .Case("read_write", true) 167 .Default(false); 168 })) 169 return false; 170 if (!verifyScalarEntry(ArgsMap, ".actual_access", false, 171 msgpack::Type::String, 172 [](msgpack::DocNode &SNode) { 173 return StringSwitch<bool>(SNode.getString()) 174 .Case("read_only", true) 175 .Case("write_only", true) 176 .Case("read_write", true) 177 .Default(false); 178 })) 179 return false; 180 if (!verifyScalarEntry(ArgsMap, ".is_const", false, 181 msgpack::Type::Boolean)) 182 return false; 183 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false, 184 msgpack::Type::Boolean)) 185 return false; 186 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false, 187 msgpack::Type::Boolean)) 188 return false; 189 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false, 190 msgpack::Type::Boolean)) 191 return false; 192 193 return true; 194 } 195 196 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) { 197 if (!Node.isMap()) 198 return false; 199 auto &KernelMap = Node.getMap(); 200 201 if (!verifyScalarEntry(KernelMap, ".name", true, 202 msgpack::Type::String)) 203 return false; 204 if (!verifyScalarEntry(KernelMap, ".symbol", true, 205 msgpack::Type::String)) 206 return false; 207 if (!verifyScalarEntry(KernelMap, ".language", false, 208 msgpack::Type::String, 209 [](msgpack::DocNode &SNode) { 210 return StringSwitch<bool>(SNode.getString()) 211 .Case("OpenCL C", true) 212 .Case("OpenCL C++", true) 213 .Case("HCC", true) 214 .Case("HIP", true) 215 .Case("OpenMP", true) 216 .Case("Assembler", true) 217 .Default(false); 218 })) 219 return false; 220 if (!verifyEntry( 221 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) { 222 return verifyArray( 223 Node, 224 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 225 })) 226 return false; 227 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) { 228 return verifyArray(Node, [this](msgpack::DocNode &Node) { 229 return verifyKernelArgs(Node); 230 }); 231 })) 232 return false; 233 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false, 234 [this](msgpack::DocNode &Node) { 235 return verifyArray(Node, 236 [this](msgpack::DocNode &Node) { 237 return verifyInteger(Node); 238 }, 239 3); 240 })) 241 return false; 242 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false, 243 [this](msgpack::DocNode &Node) { 244 return verifyArray(Node, 245 [this](msgpack::DocNode &Node) { 246 return verifyInteger(Node); 247 }, 248 3); 249 })) 250 return false; 251 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false, 252 msgpack::Type::String)) 253 return false; 254 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false, 255 msgpack::Type::String)) 256 return false; 257 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true)) 258 return false; 259 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true)) 260 return false; 261 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true)) 262 return false; 263 if (!verifyScalarEntry(KernelMap, ".uses_dynamic_stack", false, 264 msgpack::Type::Boolean)) 265 return false; 266 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true)) 267 return false; 268 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true)) 269 return false; 270 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true)) 271 return false; 272 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true)) 273 return false; 274 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true)) 275 return false; 276 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false)) 277 return false; 278 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false)) 279 return false; 280 281 return true; 282 } 283 284 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) { 285 if (!HSAMetadataRoot.isMap()) 286 return false; 287 auto &RootMap = HSAMetadataRoot.getMap(); 288 289 if (!verifyEntry( 290 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) { 291 return verifyArray( 292 Node, 293 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 294 })) 295 return false; 296 if (!verifyEntry( 297 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) { 298 return verifyArray(Node, [this](msgpack::DocNode &Node) { 299 return verifyScalar(Node, msgpack::Type::String); 300 }); 301 })) 302 return false; 303 if (!verifyEntry(RootMap, "amdhsa.kernels", true, 304 [this](msgpack::DocNode &Node) { 305 return verifyArray(Node, [this](msgpack::DocNode &Node) { 306 return verifyKernel(Node); 307 }); 308 })) 309 return false; 310 311 return true; 312 } 313 314 } // end namespace V3 315 } // end namespace HSAMD 316 } // end namespace AMDGPU 317 } // end namespace llvm 318