1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file 10 /// Implements a verifier for AMDGPU HSA metadata. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h" 15 16 #include "llvm/ADT/STLExtras.h" 17 #include "llvm/ADT/StringSwitch.h" 18 #include "llvm/BinaryFormat/MsgPackDocument.h" 19 20 #include <utility> 21 22 namespace llvm { 23 namespace AMDGPU { 24 namespace HSAMD { 25 namespace V3 { 26 27 bool MetadataVerifier::verifyScalar( 28 msgpack::DocNode &Node, msgpack::Type SKind, 29 function_ref<bool(msgpack::DocNode &)> verifyValue) { 30 if (!Node.isScalar()) 31 return false; 32 if (Node.getKind() != SKind) { 33 if (Strict) 34 return false; 35 // If we are not strict, we interpret string values as "implicitly typed" 36 // and attempt to coerce them to the expected type here. 37 if (Node.getKind() != msgpack::Type::String) 38 return false; 39 StringRef StringValue = Node.getString(); 40 Node.fromString(StringValue); 41 if (Node.getKind() != SKind) 42 return false; 43 } 44 if (verifyValue) 45 return verifyValue(Node); 46 return true; 47 } 48 49 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) { 50 if (!verifyScalar(Node, msgpack::Type::UInt)) 51 if (!verifyScalar(Node, msgpack::Type::Int)) 52 return false; 53 return true; 54 } 55 56 bool MetadataVerifier::verifyArray( 57 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode, 58 std::optional<size_t> Size) { 59 if (!Node.isArray()) 60 return false; 61 auto &Array = Node.getArray(); 62 if (Size && Array.size() != *Size) 63 return false; 64 return llvm::all_of(Array, verifyNode); 65 } 66 67 bool MetadataVerifier::verifyEntry( 68 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 69 function_ref<bool(msgpack::DocNode &)> verifyNode) { 70 auto Entry = MapNode.find(Key); 71 if (Entry == MapNode.end()) 72 return !Required; 73 return verifyNode(Entry->second); 74 } 75 76 bool MetadataVerifier::verifyScalarEntry( 77 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 78 msgpack::Type SKind, 79 function_ref<bool(msgpack::DocNode &)> verifyValue) { 80 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) { 81 return verifyScalar(Node, SKind, verifyValue); 82 }); 83 } 84 85 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode, 86 StringRef Key, bool Required) { 87 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) { 88 return verifyInteger(Node); 89 }); 90 } 91 92 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) { 93 if (!Node.isMap()) 94 return false; 95 auto &ArgsMap = Node.getMap(); 96 97 if (!verifyScalarEntry(ArgsMap, ".name", false, 98 msgpack::Type::String)) 99 return false; 100 if (!verifyScalarEntry(ArgsMap, ".type_name", false, 101 msgpack::Type::String)) 102 return false; 103 if (!verifyIntegerEntry(ArgsMap, ".size", true)) 104 return false; 105 if (!verifyIntegerEntry(ArgsMap, ".offset", true)) 106 return false; 107 if (!verifyScalarEntry(ArgsMap, ".value_kind", true, msgpack::Type::String, 108 [](msgpack::DocNode &SNode) { 109 return StringSwitch<bool>(SNode.getString()) 110 .Case("by_value", true) 111 .Case("global_buffer", true) 112 .Case("dynamic_shared_pointer", true) 113 .Case("sampler", true) 114 .Case("image", true) 115 .Case("pipe", true) 116 .Case("queue", true) 117 .Case("hidden_block_count_x", true) 118 .Case("hidden_block_count_y", true) 119 .Case("hidden_block_count_z", true) 120 .Case("hidden_group_size_x", true) 121 .Case("hidden_group_size_y", true) 122 .Case("hidden_group_size_z", true) 123 .Case("hidden_remainder_x", true) 124 .Case("hidden_remainder_y", true) 125 .Case("hidden_remainder_z", true) 126 .Case("hidden_global_offset_x", true) 127 .Case("hidden_global_offset_y", true) 128 .Case("hidden_global_offset_z", true) 129 .Case("hidden_grid_dims", true) 130 .Case("hidden_none", true) 131 .Case("hidden_printf_buffer", true) 132 .Case("hidden_hostcall_buffer", true) 133 .Case("hidden_heap_v1", true) 134 .Case("hidden_default_queue", true) 135 .Case("hidden_completion_action", true) 136 .Case("hidden_multigrid_sync_arg", true) 137 .Case("hidden_dynamic_lds_size", true) 138 .Case("hidden_private_base", true) 139 .Case("hidden_shared_base", true) 140 .Case("hidden_queue_ptr", true) 141 .Default(false); 142 })) 143 return false; 144 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false)) 145 return false; 146 if (!verifyScalarEntry(ArgsMap, ".address_space", false, 147 msgpack::Type::String, 148 [](msgpack::DocNode &SNode) { 149 return StringSwitch<bool>(SNode.getString()) 150 .Case("private", true) 151 .Case("global", true) 152 .Case("constant", true) 153 .Case("local", true) 154 .Case("generic", true) 155 .Case("region", true) 156 .Default(false); 157 })) 158 return false; 159 if (!verifyScalarEntry(ArgsMap, ".access", false, 160 msgpack::Type::String, 161 [](msgpack::DocNode &SNode) { 162 return StringSwitch<bool>(SNode.getString()) 163 .Case("read_only", true) 164 .Case("write_only", true) 165 .Case("read_write", true) 166 .Default(false); 167 })) 168 return false; 169 if (!verifyScalarEntry(ArgsMap, ".actual_access", false, 170 msgpack::Type::String, 171 [](msgpack::DocNode &SNode) { 172 return StringSwitch<bool>(SNode.getString()) 173 .Case("read_only", true) 174 .Case("write_only", true) 175 .Case("read_write", true) 176 .Default(false); 177 })) 178 return false; 179 if (!verifyScalarEntry(ArgsMap, ".is_const", false, 180 msgpack::Type::Boolean)) 181 return false; 182 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false, 183 msgpack::Type::Boolean)) 184 return false; 185 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false, 186 msgpack::Type::Boolean)) 187 return false; 188 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false, 189 msgpack::Type::Boolean)) 190 return false; 191 192 return true; 193 } 194 195 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) { 196 if (!Node.isMap()) 197 return false; 198 auto &KernelMap = Node.getMap(); 199 200 if (!verifyScalarEntry(KernelMap, ".name", true, 201 msgpack::Type::String)) 202 return false; 203 if (!verifyScalarEntry(KernelMap, ".symbol", true, 204 msgpack::Type::String)) 205 return false; 206 if (!verifyScalarEntry(KernelMap, ".language", false, 207 msgpack::Type::String, 208 [](msgpack::DocNode &SNode) { 209 return StringSwitch<bool>(SNode.getString()) 210 .Case("OpenCL C", true) 211 .Case("OpenCL C++", true) 212 .Case("HCC", true) 213 .Case("HIP", true) 214 .Case("OpenMP", true) 215 .Case("Assembler", true) 216 .Default(false); 217 })) 218 return false; 219 if (!verifyEntry( 220 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) { 221 return verifyArray( 222 Node, 223 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 224 })) 225 return false; 226 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) { 227 return verifyArray(Node, [this](msgpack::DocNode &Node) { 228 return verifyKernelArgs(Node); 229 }); 230 })) 231 return false; 232 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false, 233 [this](msgpack::DocNode &Node) { 234 return verifyArray(Node, 235 [this](msgpack::DocNode &Node) { 236 return verifyInteger(Node); 237 }, 238 3); 239 })) 240 return false; 241 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false, 242 [this](msgpack::DocNode &Node) { 243 return verifyArray(Node, 244 [this](msgpack::DocNode &Node) { 245 return verifyInteger(Node); 246 }, 247 3); 248 })) 249 return false; 250 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false, 251 msgpack::Type::String)) 252 return false; 253 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false, 254 msgpack::Type::String)) 255 return false; 256 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true)) 257 return false; 258 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true)) 259 return false; 260 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true)) 261 return false; 262 if (!verifyScalarEntry(KernelMap, ".uses_dynamic_stack", false, 263 msgpack::Type::Boolean)) 264 return false; 265 if (!verifyIntegerEntry(KernelMap, ".workgroup_processor_mode", false)) 266 return false; 267 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true)) 268 return false; 269 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true)) 270 return false; 271 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true)) 272 return false; 273 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true)) 274 return false; 275 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true)) 276 return false; 277 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false)) 278 return false; 279 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false)) 280 return false; 281 if (!verifyIntegerEntry(KernelMap, ".uniform_work_group_size", false)) 282 return false; 283 284 285 return true; 286 } 287 288 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) { 289 if (!HSAMetadataRoot.isMap()) 290 return false; 291 auto &RootMap = HSAMetadataRoot.getMap(); 292 293 if (!verifyEntry( 294 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) { 295 return verifyArray( 296 Node, 297 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 298 })) 299 return false; 300 if (!verifyEntry( 301 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) { 302 return verifyArray(Node, [this](msgpack::DocNode &Node) { 303 return verifyScalar(Node, msgpack::Type::String); 304 }); 305 })) 306 return false; 307 if (!verifyEntry(RootMap, "amdhsa.kernels", true, 308 [this](msgpack::DocNode &Node) { 309 return verifyArray(Node, [this](msgpack::DocNode &Node) { 310 return verifyKernel(Node); 311 }); 312 })) 313 return false; 314 315 return true; 316 } 317 318 } // end namespace V3 319 } // end namespace HSAMD 320 } // end namespace AMDGPU 321 } // end namespace llvm 322