1 //=- WebAssemblyISelLowering.cpp - WebAssembly DAG Lowering Implementation -==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the WebAssemblyTargetLowering class.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #include "WebAssemblyISelLowering.h"
15 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
16 #include "Utils/WebAssemblyTypeUtilities.h"
17 #include "WebAssemblyMachineFunctionInfo.h"
18 #include "WebAssemblySubtarget.h"
19 #include "WebAssemblyTargetMachine.h"
20 #include "WebAssemblyUtilities.h"
21 #include "llvm/CodeGen/CallingConvLower.h"
22 #include "llvm/CodeGen/MachineFrameInfo.h"
23 #include "llvm/CodeGen/MachineFunctionPass.h"
24 #include "llvm/CodeGen/MachineInstrBuilder.h"
25 #include "llvm/CodeGen/MachineJumpTableInfo.h"
26 #include "llvm/CodeGen/MachineModuleInfo.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/CodeGen/SelectionDAG.h"
29 #include "llvm/CodeGen/SelectionDAGNodes.h"
30 #include "llvm/IR/DiagnosticInfo.h"
31 #include "llvm/IR/DiagnosticPrinter.h"
32 #include "llvm/IR/Function.h"
33 #include "llvm/IR/Intrinsics.h"
34 #include "llvm/IR/IntrinsicsWebAssembly.h"
35 #include "llvm/IR/PatternMatch.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/ErrorHandling.h"
38 #include "llvm/Support/KnownBits.h"
39 #include "llvm/Support/MathExtras.h"
40 #include "llvm/Support/raw_ostream.h"
41 #include "llvm/Target/TargetOptions.h"
42 using namespace llvm;
43
44 #define DEBUG_TYPE "wasm-lower"
45
46 WebAssemblyTargetLowering::WebAssemblyTargetLowering(
47 const TargetMachine &TM, const WebAssemblySubtarget &STI)
48 : TargetLowering(TM), Subtarget(&STI) {
49 auto MVTPtr = Subtarget->hasAddr64() ? MVT::i64 : MVT::i32;
50
51 // Booleans always contain 0 or 1.
52 setBooleanContents(ZeroOrOneBooleanContent);
53 // Except in SIMD vectors
54 setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
55 // We don't know the microarchitecture here, so just reduce register pressure.
56 setSchedulingPreference(Sched::RegPressure);
57 // Tell ISel that we have a stack pointer.
58 setStackPointerRegisterToSaveRestore(
59 Subtarget->hasAddr64() ? WebAssembly::SP64 : WebAssembly::SP32);
60 // Set up the register classes.
61 addRegisterClass(MVT::i32, &WebAssembly::I32RegClass);
62 addRegisterClass(MVT::i64, &WebAssembly::I64RegClass);
63 addRegisterClass(MVT::f32, &WebAssembly::F32RegClass);
64 addRegisterClass(MVT::f64, &WebAssembly::F64RegClass);
65 if (Subtarget->hasSIMD128()) {
66 addRegisterClass(MVT::v16i8, &WebAssembly::V128RegClass);
67 addRegisterClass(MVT::v8i16, &WebAssembly::V128RegClass);
68 addRegisterClass(MVT::v4i32, &WebAssembly::V128RegClass);
69 addRegisterClass(MVT::v4f32, &WebAssembly::V128RegClass);
70 addRegisterClass(MVT::v2i64, &WebAssembly::V128RegClass);
71 addRegisterClass(MVT::v2f64, &WebAssembly::V128RegClass);
72 }
73 if (Subtarget->hasHalfPrecision()) {
74 addRegisterClass(MVT::v8f16, &WebAssembly::V128RegClass);
75 }
76 if (Subtarget->hasReferenceTypes()) {
77 addRegisterClass(MVT::externref, &WebAssembly::EXTERNREFRegClass);
78 addRegisterClass(MVT::funcref, &WebAssembly::FUNCREFRegClass);
79 if (Subtarget->hasExceptionHandling()) {
80 addRegisterClass(MVT::exnref, &WebAssembly::EXNREFRegClass);
81 }
82 }
83 // Compute derived properties from the register classes.
84 computeRegisterProperties(Subtarget->getRegisterInfo());
85
86 // Transform loads and stores to pointers in address space 1 to loads and
87 // stores to WebAssembly global variables, outside linear memory.
88 for (auto T : {MVT::i32, MVT::i64, MVT::f32, MVT::f64}) {
89 setOperationAction(ISD::LOAD, T, Custom);
90 setOperationAction(ISD::STORE, T, Custom);
91 }
92 if (Subtarget->hasSIMD128()) {
93 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
94 MVT::v2f64}) {
95 setOperationAction(ISD::LOAD, T, Custom);
96 setOperationAction(ISD::STORE, T, Custom);
97 }
98 }
99 if (Subtarget->hasReferenceTypes()) {
100 // We need custom load and store lowering for externref, funcref, and
101 // Other. The MVT::Other here represents tables of reference types.
102 for (auto T : {MVT::externref, MVT::funcref, MVT::Other}) {
103 setOperationAction(ISD::LOAD, T, Custom);
104 setOperationAction(ISD::STORE, T, Custom);
105 }
106 }
107
108 setOperationAction(ISD::GlobalAddress, MVTPtr, Custom);
109 setOperationAction(ISD::GlobalTLSAddress, MVTPtr, Custom);
110 setOperationAction(ISD::ExternalSymbol, MVTPtr, Custom);
111 setOperationAction(ISD::JumpTable, MVTPtr, Custom);
112 setOperationAction(ISD::BlockAddress, MVTPtr, Custom);
113 setOperationAction(ISD::BRIND, MVT::Other, Custom);
114 setOperationAction(ISD::CLEAR_CACHE, MVT::Other, Custom);
115
116 // Take the default expansion for va_arg, va_copy, and va_end. There is no
117 // default action for va_start, so we lower it as a custom operation.
118 setOperationAction(ISD::VASTART, MVT::Other, Custom);
119 setOperationAction(ISD::VAARG, MVT::Other, Expand);
120 setOperationAction(ISD::VACOPY, MVT::Other, Expand);
121 setOperationAction(ISD::VAEND, MVT::Other, Expand);
122
123 for (auto T : {MVT::f32, MVT::f64, MVT::v4f32, MVT::v2f64}) {
124 // Don't expand the floating-point types to constant pools.
125 setOperationAction(ISD::ConstantFP, T, Legal);
126 // Expand floating-point comparisons.
127 for (auto CC : {ISD::SETO, ISD::SETUO, ISD::SETUEQ, ISD::SETONE,
128 ISD::SETULT, ISD::SETULE, ISD::SETUGT, ISD::SETUGE})
129 setCondCodeAction(CC, T, Expand);
130 // Expand floating-point library function operators.
131 for (auto Op :
132 {ISD::FSIN, ISD::FCOS, ISD::FSINCOS, ISD::FPOW, ISD::FREM, ISD::FMA})
133 setOperationAction(Op, T, Expand);
134 // Mark as legal the supported floating-point library function operators
135 // that would otherwise default to expand.
136 for (auto Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FTRUNC, ISD::FNEARBYINT,
137 ISD::FRINT, ISD::FROUNDEVEN})
138 setOperationAction(Op, T, Legal);
139 // Support minimum and maximum, which otherwise default to expand.
140 setOperationAction(ISD::FMINIMUM, T, Legal);
141 setOperationAction(ISD::FMAXIMUM, T, Legal);
142 // WebAssembly currently has no builtin f16 support.
143 setOperationAction(ISD::FP16_TO_FP, T, Expand);
144 setOperationAction(ISD::FP_TO_FP16, T, Expand);
145 setLoadExtAction(ISD::EXTLOAD, T, MVT::f16, Expand);
146 setTruncStoreAction(T, MVT::f16, Expand);
147 }
148
149 if (Subtarget->hasHalfPrecision()) {
150 setOperationAction(ISD::FMINIMUM, MVT::v8f16, Legal);
151 setOperationAction(ISD::FMAXIMUM, MVT::v8f16, Legal);
152 }
153
154 // Expand unavailable integer operations.
155 for (auto Op :
156 {ISD::BSWAP, ISD::SMUL_LOHI, ISD::UMUL_LOHI, ISD::MULHS, ISD::MULHU,
157 ISD::SDIVREM, ISD::UDIVREM, ISD::SHL_PARTS, ISD::SRA_PARTS,
158 ISD::SRL_PARTS, ISD::ADDC, ISD::ADDE, ISD::SUBC, ISD::SUBE}) {
159 for (auto T : {MVT::i32, MVT::i64})
160 setOperationAction(Op, T, Expand);
161 if (Subtarget->hasSIMD128())
162 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64})
163 setOperationAction(Op, T, Expand);
164 }
165
166 if (Subtarget->hasNontrappingFPToInt())
167 for (auto Op : {ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT})
168 for (auto T : {MVT::i32, MVT::i64})
169 setOperationAction(Op, T, Custom);
170
171 // SIMD-specific configuration
172 if (Subtarget->hasSIMD128()) {
173 // Combine vector mask reductions into alltrue/anytrue
174 setTargetDAGCombine(ISD::SETCC);
175
176 // Convert vector to integer bitcasts to bitmask
177 setTargetDAGCombine(ISD::BITCAST);
178
179 // Hoist bitcasts out of shuffles
180 setTargetDAGCombine(ISD::VECTOR_SHUFFLE);
181
182 // Combine extends of extract_subvectors into widening ops
183 setTargetDAGCombine({ISD::SIGN_EXTEND, ISD::ZERO_EXTEND});
184
185 // Combine int_to_fp or fp_extend of extract_subvectors and vice versa into
186 // conversion ops
187 setTargetDAGCombine({ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_EXTEND,
188 ISD::EXTRACT_SUBVECTOR});
189
190 // Combine fp_to_{s,u}int_sat or fp_round of concat_vectors or vice versa
191 // into conversion ops
192 setTargetDAGCombine({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT,
193 ISD::FP_ROUND, ISD::CONCAT_VECTORS});
194
195 setTargetDAGCombine(ISD::TRUNCATE);
196
197 // Support saturating add for i8x16 and i16x8
198 for (auto Op : {ISD::SADDSAT, ISD::UADDSAT})
199 for (auto T : {MVT::v16i8, MVT::v8i16})
200 setOperationAction(Op, T, Legal);
201
202 // Support integer abs
203 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64})
204 setOperationAction(ISD::ABS, T, Legal);
205
206 // Custom lower BUILD_VECTORs to minimize number of replace_lanes
207 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
208 MVT::v2f64})
209 setOperationAction(ISD::BUILD_VECTOR, T, Custom);
210
211 // We have custom shuffle lowering to expose the shuffle mask
212 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
213 MVT::v2f64})
214 setOperationAction(ISD::VECTOR_SHUFFLE, T, Custom);
215
216 // Support splatting
217 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
218 MVT::v2f64})
219 setOperationAction(ISD::SPLAT_VECTOR, T, Legal);
220
221 // Custom lowering since wasm shifts must have a scalar shift amount
222 for (auto Op : {ISD::SHL, ISD::SRA, ISD::SRL})
223 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64})
224 setOperationAction(Op, T, Custom);
225
226 // Custom lower lane accesses to expand out variable indices
227 for (auto Op : {ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT})
228 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
229 MVT::v2f64})
230 setOperationAction(Op, T, Custom);
231
232 // There is no i8x16.mul instruction
233 setOperationAction(ISD::MUL, MVT::v16i8, Expand);
234
235 // There is no vector conditional select instruction
236 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
237 MVT::v2f64})
238 setOperationAction(ISD::SELECT_CC, T, Expand);
239
240 // Expand integer operations supported for scalars but not SIMD
241 for (auto Op :
242 {ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM, ISD::ROTL, ISD::ROTR})
243 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64})
244 setOperationAction(Op, T, Expand);
245
246 // But we do have integer min and max operations
247 for (auto Op : {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX})
248 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
249 setOperationAction(Op, T, Legal);
250
251 // And we have popcnt for i8x16. It can be used to expand ctlz/cttz.
252 setOperationAction(ISD::CTPOP, MVT::v16i8, Legal);
253 setOperationAction(ISD::CTLZ, MVT::v16i8, Expand);
254 setOperationAction(ISD::CTTZ, MVT::v16i8, Expand);
255
256 // Custom lower bit counting operations for other types to scalarize them.
257 for (auto Op : {ISD::CTLZ, ISD::CTTZ, ISD::CTPOP})
258 for (auto T : {MVT::v8i16, MVT::v4i32, MVT::v2i64})
259 setOperationAction(Op, T, Custom);
260
261 // Expand float operations supported for scalars but not SIMD
262 for (auto Op : {ISD::FCOPYSIGN, ISD::FLOG, ISD::FLOG2, ISD::FLOG10,
263 ISD::FEXP, ISD::FEXP2})
264 for (auto T : {MVT::v4f32, MVT::v2f64})
265 setOperationAction(Op, T, Expand);
266
267 // Unsigned comparison operations are unavailable for i64x2 vectors.
268 for (auto CC : {ISD::SETUGT, ISD::SETUGE, ISD::SETULT, ISD::SETULE})
269 setCondCodeAction(CC, MVT::v2i64, Custom);
270
271 // 64x2 conversions are not in the spec
272 for (auto Op :
273 {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT})
274 for (auto T : {MVT::v2i64, MVT::v2f64})
275 setOperationAction(Op, T, Expand);
276
277 // But saturating fp_to_int conversions are
278 for (auto Op : {ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT})
279 setOperationAction(Op, MVT::v4i32, Custom);
280
281 // Support vector extending
282 for (auto T : MVT::integer_fixedlen_vector_valuetypes()) {
283 setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
284 setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
285 }
286 }
287
288 // As a special case, these operators use the type to mean the type to
289 // sign-extend from.
290 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
291 if (!Subtarget->hasSignExt()) {
292 // Sign extends are legal only when extending a vector extract
293 auto Action = Subtarget->hasSIMD128() ? Custom : Expand;
294 for (auto T : {MVT::i8, MVT::i16, MVT::i32})
295 setOperationAction(ISD::SIGN_EXTEND_INREG, T, Action);
296 }
297 for (auto T : MVT::integer_fixedlen_vector_valuetypes())
298 setOperationAction(ISD::SIGN_EXTEND_INREG, T, Expand);
299
300 // Dynamic stack allocation: use the default expansion.
301 setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
302 setOperationAction(ISD::STACKRESTORE, MVT::Other, Expand);
303 setOperationAction(ISD::DYNAMIC_STACKALLOC, MVTPtr, Expand);
304
305 setOperationAction(ISD::FrameIndex, MVT::i32, Custom);
306 setOperationAction(ISD::FrameIndex, MVT::i64, Custom);
307 setOperationAction(ISD::CopyToReg, MVT::Other, Custom);
308
309 // Expand these forms; we pattern-match the forms that we can handle in isel.
310 for (auto T : {MVT::i32, MVT::i64, MVT::f32, MVT::f64})
311 for (auto Op : {ISD::BR_CC, ISD::SELECT_CC})
312 setOperationAction(Op, T, Expand);
313
314 // We have custom switch handling.
315 setOperationAction(ISD::BR_JT, MVT::Other, Custom);
316
317 // WebAssembly doesn't have:
318 // - Floating-point extending loads.
319 // - Floating-point truncating stores.
320 // - i1 extending loads.
321 // - truncating SIMD stores and most extending loads
322 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
323 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
324 for (auto T : MVT::integer_valuetypes())
325 for (auto Ext : {ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD})
326 setLoadExtAction(Ext, T, MVT::i1, Promote);
327 if (Subtarget->hasSIMD128()) {
328 for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64, MVT::v4f32,
329 MVT::v2f64}) {
330 for (auto MemT : MVT::fixedlen_vector_valuetypes()) {
331 if (MVT(T) != MemT) {
332 setTruncStoreAction(T, MemT, Expand);
333 for (auto Ext : {ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD})
334 setLoadExtAction(Ext, T, MemT, Expand);
335 }
336 }
337 }
338 // But some vector extending loads are legal
339 for (auto Ext : {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}) {
340 setLoadExtAction(Ext, MVT::v8i16, MVT::v8i8, Legal);
341 setLoadExtAction(Ext, MVT::v4i32, MVT::v4i16, Legal);
342 setLoadExtAction(Ext, MVT::v2i64, MVT::v2i32, Legal);
343 }
344 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Legal);
345 }
346
347 // Don't do anything clever with build_pairs
348 setOperationAction(ISD::BUILD_PAIR, MVT::i64, Expand);
349
350 // Trap lowers to wasm unreachable
351 setOperationAction(ISD::TRAP, MVT::Other, Legal);
352 setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
353
354 // Exception handling intrinsics
355 setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
356 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
357 setOperationAction(ISD::INTRINSIC_VOID, MVT::Other, Custom);
358
359 setMaxAtomicSizeInBitsSupported(64);
360
361 // Override the __gnu_f2h_ieee/__gnu_h2f_ieee names so that the f32 name is
362 // consistent with the f64 and f128 names.
363 setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
364 setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
365
366 // Define the emscripten name for return address helper.
367 // TODO: when implementing other Wasm backends, make this generic or only do
368 // this on emscripten depending on what they end up doing.
369 setLibcallName(RTLIB::RETURN_ADDRESS, "emscripten_return_address");
370
371 // Always convert switches to br_tables unless there is only one case, which
372 // is equivalent to a simple branch. This reduces code size for wasm, and we
373 // defer possible jump table optimizations to the VM.
374 setMinimumJumpTableEntries(2);
375 }
376
377 MVT WebAssemblyTargetLowering::getPointerTy(const DataLayout &DL,
378 uint32_t AS) const {
379 if (AS == WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_EXTERNREF)
380 return MVT::externref;
381 if (AS == WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_FUNCREF)
382 return MVT::funcref;
383 return TargetLowering::getPointerTy(DL, AS);
384 }
385
386 MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
387 uint32_t AS) const {
388 if (AS == WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_EXTERNREF)
389 return MVT::externref;
390 if (AS == WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_FUNCREF)
391 return MVT::funcref;
392 return TargetLowering::getPointerMemTy(DL, AS);
393 }
394
395 TargetLowering::AtomicExpansionKind
396 WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
397 // We have wasm instructions for these
398 switch (AI->getOperation()) {
399 case AtomicRMWInst::Add:
400 case AtomicRMWInst::Sub:
401 case AtomicRMWInst::And:
402 case AtomicRMWInst::Or:
403 case AtomicRMWInst::Xor:
404 case AtomicRMWInst::Xchg:
405 return AtomicExpansionKind::None;
406 default:
407 break;
408 }
409 return AtomicExpansionKind::CmpXChg;
410 }
411
412 bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
413 // Implementation copied from X86TargetLowering.
414 unsigned Opc = VecOp.getOpcode();
415
416 // Assume target opcodes can't be scalarized.
417 // TODO - do we have any exceptions?
418 if (Opc >= ISD::BUILTIN_OP_END)
419 return false;
420
421 // If the vector op is not supported, try to convert to scalar.
422 EVT VecVT = VecOp.getValueType();
423 if (!isOperationLegalOrCustomOrPromote(Opc, VecVT))
424 return true;
425
426 // If the vector op is supported, but the scalar op is not, the transform may
427 // not be worthwhile.
428 EVT ScalarVT = VecVT.getScalarType();
429 return isOperationLegalOrCustomOrPromote(Opc, ScalarVT);
430 }
431
432 FastISel *WebAssemblyTargetLowering::createFastISel(
433 FunctionLoweringInfo &FuncInfo, const TargetLibraryInfo *LibInfo) const {
434 return WebAssembly::createFastISel(FuncInfo, LibInfo);
435 }
436
437 MVT WebAssemblyTargetLowering::getScalarShiftAmountTy(const DataLayout & /*DL*/,
438 EVT VT) const {
439 unsigned BitWidth = NextPowerOf2(VT.getSizeInBits() - 1);
440 if (BitWidth > 1 && BitWidth < 8)
441 BitWidth = 8;
442
443 if (BitWidth > 64) {
444 // The shift will be lowered to a libcall, and compiler-rt libcalls expect
445 // the count to be an i32.
446 BitWidth = 32;
447 assert(BitWidth >= Log2_32_Ceil(VT.getSizeInBits()) &&
448 "32-bit shift counts ought to be enough for anyone");
449 }
450
451 MVT Result = MVT::getIntegerVT(BitWidth);
452 assert(Result != MVT::INVALID_SIMPLE_VALUE_TYPE &&
453 "Unable to represent scalar shift amount type");
454 return Result;
455 }
456
457 // Lower an fp-to-int conversion operator from the LLVM opcode, which has an
458 // undefined result on invalid/overflow, to the WebAssembly opcode, which
459 // traps on invalid/overflow.
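// The emitted code guards the trapping conversion with a range check and
// substitutes a constant for out-of-range inputs, roughly:
//
//   if (fabs(x) < 2^(N-1))        ; signed, N = destination integer width
//   if (x < 2^N && x >= 0.0)      ; unsigned
//     result = trunc(x)           ; the trapping Wasm conversion
//   else
//     result = Substitute         ; INT_MIN for signed, 0 for unsigned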
460 static MachineBasicBlock *LowerFPToInt(MachineInstr &MI, DebugLoc DL,
461 MachineBasicBlock *BB,
462 const TargetInstrInfo &TII,
463 bool IsUnsigned, bool Int64,
464 bool Float64, unsigned LoweredOpcode) {
465 MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
466
467 Register OutReg = MI.getOperand(0).getReg();
468 Register InReg = MI.getOperand(1).getReg();
469
470 unsigned Abs = Float64 ? WebAssembly::ABS_F64 : WebAssembly::ABS_F32;
471 unsigned FConst = Float64 ? WebAssembly::CONST_F64 : WebAssembly::CONST_F32;
472 unsigned LT = Float64 ? WebAssembly::LT_F64 : WebAssembly::LT_F32;
473 unsigned GE = Float64 ? WebAssembly::GE_F64 : WebAssembly::GE_F32;
474 unsigned IConst = Int64 ? WebAssembly::CONST_I64 : WebAssembly::CONST_I32;
475 unsigned Eqz = WebAssembly::EQZ_I32;
476 unsigned And = WebAssembly::AND_I32;
477 int64_t Limit = Int64 ? INT64_MIN : INT32_MIN;
478 int64_t Substitute = IsUnsigned ? 0 : Limit;
479 double CmpVal = IsUnsigned ? -(double)Limit * 2.0 : -(double)Limit;
480 auto &Context = BB->getParent()->getFunction().getContext();
481 Type *Ty = Float64 ? Type::getDoubleTy(Context) : Type::getFloatTy(Context);
482
483 const BasicBlock *LLVMBB = BB->getBasicBlock();
484 MachineFunction *F = BB->getParent();
485 MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
486 MachineBasicBlock *FalseMBB = F->CreateMachineBasicBlock(LLVMBB);
487 MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);
488
489 MachineFunction::iterator It = ++BB->getIterator();
490 F->insert(It, FalseMBB);
491 F->insert(It, TrueMBB);
492 F->insert(It, DoneMBB);
493
494 // Transfer the remainder of BB and its successor edges to DoneMBB.
495 DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
496 DoneMBB->transferSuccessorsAndUpdatePHIs(BB);
497
498 BB->addSuccessor(TrueMBB);
499 BB->addSuccessor(FalseMBB);
500 TrueMBB->addSuccessor(DoneMBB);
501 FalseMBB->addSuccessor(DoneMBB);
502
503 unsigned Tmp0, Tmp1, CmpReg, EqzReg, FalseReg, TrueReg;
504 Tmp0 = MRI.createVirtualRegister(MRI.getRegClass(InReg));
505 Tmp1 = MRI.createVirtualRegister(MRI.getRegClass(InReg));
506 CmpReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
507 EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
508 FalseReg = MRI.createVirtualRegister(MRI.getRegClass(OutReg));
509 TrueReg = MRI.createVirtualRegister(MRI.getRegClass(OutReg));
510
511 MI.eraseFromParent();
512 // For signed numbers, we can do a single comparison to determine whether
513 // fabs(x) is within range.
514 if (IsUnsigned) {
515 Tmp0 = InReg;
516 } else {
517 BuildMI(BB, DL, TII.get(Abs), Tmp0).addReg(InReg);
518 }
519 BuildMI(BB, DL, TII.get(FConst), Tmp1)
520 .addFPImm(cast<ConstantFP>(ConstantFP::get(Ty, CmpVal)));
521 BuildMI(BB, DL, TII.get(LT), CmpReg).addReg(Tmp0).addReg(Tmp1);
522
523 // For unsigned numbers, we have to do a separate comparison with zero.
524 if (IsUnsigned) {
525 Tmp1 = MRI.createVirtualRegister(MRI.getRegClass(InReg));
526 Register SecondCmpReg =
527 MRI.createVirtualRegister(&WebAssembly::I32RegClass);
528 Register AndReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
529 BuildMI(BB, DL, TII.get(FConst), Tmp1)
530 .addFPImm(cast<ConstantFP>(ConstantFP::get(Ty, 0.0)));
531 BuildMI(BB, DL, TII.get(GE), SecondCmpReg).addReg(Tmp0).addReg(Tmp1);
532 BuildMI(BB, DL, TII.get(And), AndReg).addReg(CmpReg).addReg(SecondCmpReg);
533 CmpReg = AndReg;
534 }
535
536 BuildMI(BB, DL, TII.get(Eqz), EqzReg).addReg(CmpReg);
537
538 // Create the CFG diamond to select between doing the conversion or using
539 // the substitute value.
540 BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(TrueMBB).addReg(EqzReg);
541 BuildMI(FalseMBB, DL, TII.get(LoweredOpcode), FalseReg).addReg(InReg);
542 BuildMI(FalseMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);
543 BuildMI(TrueMBB, DL, TII.get(IConst), TrueReg).addImm(Substitute);
544 BuildMI(*DoneMBB, DoneMBB->begin(), DL, TII.get(TargetOpcode::PHI), OutReg)
545 .addReg(FalseReg)
546 .addMBB(FalseMBB)
547 .addReg(TrueReg)
548 .addMBB(TrueMBB);
549
550 return DoneMBB;
551 }
552
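// Lower a CALL_PARAMS / CALL_RESULTS (or RET_CALL_RESULTS) pseudo-instruction
// pair produced by instruction selection into a single CALL, CALL_INDIRECT,
// RET_CALL, or RET_CALL_INDIRECT instruction, merging the defs of the results
// pseudo with the uses of the params pseudo.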
553 static MachineBasicBlock *
554 LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
555 const WebAssemblySubtarget *Subtarget,
556 const TargetInstrInfo &TII) {
557 MachineInstr &CallParams = *CallResults.getPrevNode();
558 assert(CallParams.getOpcode() == WebAssembly::CALL_PARAMS);
559 assert(CallResults.getOpcode() == WebAssembly::CALL_RESULTS ||
560 CallResults.getOpcode() == WebAssembly::RET_CALL_RESULTS);
561
562 bool IsIndirect =
563 CallParams.getOperand(0).isReg() || CallParams.getOperand(0).isFI();
564 bool IsRetCall = CallResults.getOpcode() == WebAssembly::RET_CALL_RESULTS;
565
566 bool IsFuncrefCall = false;
567 if (IsIndirect && CallParams.getOperand(0).isReg()) {
568 Register Reg = CallParams.getOperand(0).getReg();
569 const MachineFunction *MF = BB->getParent();
570 const MachineRegisterInfo &MRI = MF->getRegInfo();
571 const TargetRegisterClass *TRC = MRI.getRegClass(Reg);
572 IsFuncrefCall = (TRC == &WebAssembly::FUNCREFRegClass);
573 assert(!IsFuncrefCall || Subtarget->hasReferenceTypes());
574 }
575
576 unsigned CallOp;
577 if (IsIndirect && IsRetCall) {
578 CallOp = WebAssembly::RET_CALL_INDIRECT;
579 } else if (IsIndirect) {
580 CallOp = WebAssembly::CALL_INDIRECT;
581 } else if (IsRetCall) {
582 CallOp = WebAssembly::RET_CALL;
583 } else {
584 CallOp = WebAssembly::CALL;
585 }
586
587 MachineFunction &MF = *BB->getParent();
588 const MCInstrDesc &MCID = TII.get(CallOp);
589 MachineInstrBuilder MIB(MF, MF.CreateMachineInstr(MCID, DL));
590
591 // Move the function pointer to the end of the arguments for indirect calls
592 if (IsIndirect) {
593 auto FnPtr = CallParams.getOperand(0);
594 CallParams.removeOperand(0);
595
596 // For funcrefs, call_indirect is done through __funcref_call_table and the
597 // funcref is always installed in slot 0 of the table; therefore, instead of
598 // having the function pointer added at the end of the params list, a zero
599 // (the index in
600 // __funcref_call_table) is added.
601 if (IsFuncrefCall) {
602 Register RegZero =
603 MF.getRegInfo().createVirtualRegister(&WebAssembly::I32RegClass);
604 MachineInstrBuilder MIBC0 =
605 BuildMI(MF, DL, TII.get(WebAssembly::CONST_I32), RegZero).addImm(0);
606
607 BB->insert(CallResults.getIterator(), MIBC0);
608 MachineInstrBuilder(MF, CallParams).addReg(RegZero);
609 } else
610 CallParams.addOperand(FnPtr);
611 }
612
613 for (auto Def : CallResults.defs())
614 MIB.add(Def);
615
616 if (IsIndirect) {
617 // Placeholder for the type index.
618 MIB.addImm(0);
619 // The table into which this call_indirect indexes.
620 MCSymbolWasm *Table = IsFuncrefCall
621 ? WebAssembly::getOrCreateFuncrefCallTableSymbol(
622 MF.getContext(), Subtarget)
623 : WebAssembly::getOrCreateFunctionTableSymbol(
624 MF.getContext(), Subtarget);
625 if (Subtarget->hasReferenceTypes()) {
626 MIB.addSym(Table);
627 } else {
628 // For the MVP there is at most one table whose number is 0, but we can't
629 // write a table symbol or issue relocations. Instead we just ensure the
630 // table is live and write a zero.
631 Table->setNoStrip();
632 MIB.addImm(0);
633 }
634 }
635
636 for (auto Use : CallParams.uses())
637 MIB.add(Use);
638
639 BB->insert(CallResults.getIterator(), MIB);
640 CallParams.eraseFromParent();
641 CallResults.eraseFromParent();
642
643 // If this is a funcref call, to avoid hidden GC roots, we need to clear the
644 // table slot with ref.null upon call_indirect return.
645 //
646 // This generates the following code, which comes right after a call_indirect
647 // of a funcref:
648 //
649 // i32.const 0
650 // ref.null func
651 // table.set __funcref_call_table
652 if (IsIndirect && IsFuncrefCall) {
653 MCSymbolWasm *Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(
654 MF.getContext(), Subtarget);
655 Register RegZero =
656 MF.getRegInfo().createVirtualRegister(&WebAssembly::I32RegClass);
657 MachineInstr *Const0 =
658 BuildMI(MF, DL, TII.get(WebAssembly::CONST_I32), RegZero).addImm(0);
659 BB->insertAfter(MIB.getInstr()->getIterator(), Const0);
660
661 Register RegFuncref =
662 MF.getRegInfo().createVirtualRegister(&WebAssembly::FUNCREFRegClass);
663 MachineInstr *RefNull =
664 BuildMI(MF, DL, TII.get(WebAssembly::REF_NULL_FUNCREF), RegFuncref);
665 BB->insertAfter(Const0->getIterator(), RefNull);
666
667 MachineInstr *TableSet =
668 BuildMI(MF, DL, TII.get(WebAssembly::TABLE_SET_FUNCREF))
669 .addSym(Table)
670 .addReg(RegZero)
671 .addReg(RegFuncref);
672 BB->insertAfter(RefNull->getIterator(), TableSet);
673 }
674
675 return BB;
676 }
677
678 MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
679 MachineInstr &MI, MachineBasicBlock *BB) const {
680 const TargetInstrInfo &TII = *Subtarget->getInstrInfo();
681 DebugLoc DL = MI.getDebugLoc();
682
683 switch (MI.getOpcode()) {
684 default:
685 llvm_unreachable("Unexpected instr type to insert");
686 case WebAssembly::FP_TO_SINT_I32_F32:
687 return LowerFPToInt(MI, DL, BB, TII, false, false, false,
688 WebAssembly::I32_TRUNC_S_F32);
689 case WebAssembly::FP_TO_UINT_I32_F32:
690 return LowerFPToInt(MI, DL, BB, TII, true, false, false,
691 WebAssembly::I32_TRUNC_U_F32);
692 case WebAssembly::FP_TO_SINT_I64_F32:
693 return LowerFPToInt(MI, DL, BB, TII, false, true, false,
694 WebAssembly::I64_TRUNC_S_F32);
695 case WebAssembly::FP_TO_UINT_I64_F32:
696 return LowerFPToInt(MI, DL, BB, TII, true, true, false,
697 WebAssembly::I64_TRUNC_U_F32);
698 case WebAssembly::FP_TO_SINT_I32_F64:
699 return LowerFPToInt(MI, DL, BB, TII, false, false, true,
700 WebAssembly::I32_TRUNC_S_F64);
701 case WebAssembly::FP_TO_UINT_I32_F64:
702 return LowerFPToInt(MI, DL, BB, TII, true, false, true,
703 WebAssembly::I32_TRUNC_U_F64);
704 case WebAssembly::FP_TO_SINT_I64_F64:
705 return LowerFPToInt(MI, DL, BB, TII, false, true, true,
706 WebAssembly::I64_TRUNC_S_F64);
707 case WebAssembly::FP_TO_UINT_I64_F64:
708 return LowerFPToInt(MI, DL, BB, TII, true, true, true,
709 WebAssembly::I64_TRUNC_U_F64);
710 case WebAssembly::CALL_RESULTS:
711 case WebAssembly::RET_CALL_RESULTS:
712 return LowerCallResults(MI, DL, BB, Subtarget, TII);
713 }
714 }
715
716 const char *
717 WebAssemblyTargetLowering::getTargetNodeName(unsigned Opcode) const {
718 switch (static_cast<WebAssemblyISD::NodeType>(Opcode)) {
719 case WebAssemblyISD::FIRST_NUMBER:
720 case WebAssemblyISD::FIRST_MEM_OPCODE:
721 break;
722 #define HANDLE_NODETYPE(NODE) \
723 case WebAssemblyISD::NODE: \
724 return "WebAssemblyISD::" #NODE;
725 #define HANDLE_MEM_NODETYPE(NODE) HANDLE_NODETYPE(NODE)
726 #include "WebAssemblyISD.def"
727 #undef HANDLE_MEM_NODETYPE
728 #undef HANDLE_NODETYPE
729 }
730 return nullptr;
731 }
732
733 std::pair<unsigned, const TargetRegisterClass *>
734 WebAssemblyTargetLowering::getRegForInlineAsmConstraint(
735 const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const {
736 // First, see if this is a constraint that directly corresponds to a
737 // WebAssembly register class.
738 if (Constraint.size() == 1) {
739 switch (Constraint[0]) {
740 case 'r':
741 assert(VT != MVT::iPTR && "Pointer MVT not expected here");
742 if (Subtarget->hasSIMD128() && VT.isVector()) {
743 if (VT.getSizeInBits() == 128)
744 return std::make_pair(0U, &WebAssembly::V128RegClass);
745 }
746 if (VT.isInteger() && !VT.isVector()) {
747 if (VT.getSizeInBits() <= 32)
748 return std::make_pair(0U, &WebAssembly::I32RegClass);
749 if (VT.getSizeInBits() <= 64)
750 return std::make_pair(0U, &WebAssembly::I64RegClass);
751 }
752 if (VT.isFloatingPoint() && !VT.isVector()) {
753 switch (VT.getSizeInBits()) {
754 case 32:
755 return std::make_pair(0U, &WebAssembly::F32RegClass);
756 case 64:
757 return std::make_pair(0U, &WebAssembly::F64RegClass);
758 default:
759 break;
760 }
761 }
762 break;
763 default:
764 break;
765 }
766 }
767
768 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
769 }
770
771 bool WebAssemblyTargetLowering::isCheapToSpeculateCttz(Type *Ty) const {
772 // Assume ctz is a relatively cheap operation.
773 return true;
774 }
775
776 bool WebAssemblyTargetLowering::isCheapToSpeculateCtlz(Type *Ty) const {
777 // Assume clz is a relatively cheap operation.
778 return true;
779 }
780
781 bool WebAssemblyTargetLowering::isLegalAddressingMode(const DataLayout &DL,
782 const AddrMode &AM,
783 Type *Ty, unsigned AS,
784 Instruction *I) const {
785 // WebAssembly offsets are added as unsigned without wrapping. The
786 // isLegalAddressingMode gives us no way to determine if wrapping could be
787 // happening, so we approximate this by accepting only non-negative offsets.
788 if (AM.BaseOffs < 0)
789 return false;
790
791 // WebAssembly has no scale register operands.
792 if (AM.Scale != 0)
793 return false;
794
795 // Everything else is legal.
796 return true;
797 }
798
799 bool WebAssemblyTargetLowering::allowsMisalignedMemoryAccesses(
800 EVT /*VT*/, unsigned /*AddrSpace*/, Align /*Align*/,
801 MachineMemOperand::Flags /*Flags*/, unsigned *Fast) const {
802 // WebAssembly supports unaligned accesses, though it should be declared
803 // with the p2align attribute on loads and stores which do so, and there
804 // may be a performance impact. We tell LLVM they're "fast" because
805 // for the kinds of things that LLVM uses this for (merging adjacent stores
806 // of constants, etc.), WebAssembly implementations will either want the
807 // unaligned access or they'll split anyway.
808 if (Fast)
809 *Fast = 1;
810 return true;
811 }
812
813 bool WebAssemblyTargetLowering::isIntDivCheap(EVT VT,
814 AttributeList Attr) const {
815 // The current thinking is that wasm engines will perform this optimization,
816 // so we can save on code size.
817 return true;
818 }
819
820 bool WebAssemblyTargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
821 EVT ExtT = ExtVal.getValueType();
822 EVT MemT = cast<LoadSDNode>(ExtVal->getOperand(0))->getValueType(0);
823 return (ExtT == MVT::v8i16 && MemT == MVT::v8i8) ||
824 (ExtT == MVT::v4i32 && MemT == MVT::v4i16) ||
825 (ExtT == MVT::v2i64 && MemT == MVT::v2i32);
826 }
827
828 bool WebAssemblyTargetLowering::isOffsetFoldingLegal(
829 const GlobalAddressSDNode *GA) const {
830 // Wasm doesn't support function addresses with offsets
831 const GlobalValue *GV = GA->getGlobal();
832 return isa<Function>(GV) ? false : TargetLowering::isOffsetFoldingLegal(GA);
833 }
834
835 bool WebAssemblyTargetLowering::shouldSinkOperands(
836 Instruction *I, SmallVectorImpl<Use *> &Ops) const {
837 using namespace llvm::PatternMatch;
838
839 if (!I->getType()->isVectorTy() || !I->isShift())
840 return false;
841
842 Value *V = I->getOperand(1);
843 // We don't need to sink a constant splat.
844 if (isa<Constant>(V))
845 return false;
846
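// Match a splat of a variable shift amount: a shufflevector with an all-zero
// mask whose first operand is an insertelement into lane 0. Sinking the
// insert and the shuffle next to the shift lets instruction selection use the
// scalar value directly as the Wasm shift-amount operand.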
847 if (match(V, m_Shuffle(m_InsertElt(m_Value(), m_Value(), m_ZeroInt()),
848 m_Value(), m_ZeroMask()))) {
849 // Sink insert
850 Ops.push_back(&cast<Instruction>(V)->getOperandUse(0));
851 // Sink shuffle
852 Ops.push_back(&I->getOperandUse(1));
853 return true;
854 }
855
856 return false;
857 }
858
859 EVT WebAssemblyTargetLowering::getSetCCResultType(const DataLayout &DL,
860 LLVMContext &C,
861 EVT VT) const {
862 if (VT.isVector())
863 return VT.changeVectorElementTypeToInteger();
864
865 // So far, all branch instructions in Wasm take an I32 condition.
866 // The default TargetLowering::getSetCCResultType returns the pointer size,
867 // which would be useful to reduce instruction counts when testing
868 // against 64-bit pointers/values if at some point Wasm supports that.
869 return EVT::getIntegerVT(C, 32);
870 }
871
872 bool WebAssemblyTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
873 const CallInst &I,
874 MachineFunction &MF,
875 unsigned Intrinsic) const {
876 switch (Intrinsic) {
877 case Intrinsic::wasm_memory_atomic_notify:
878 Info.opc = ISD::INTRINSIC_W_CHAIN;
879 Info.memVT = MVT::i32;
880 Info.ptrVal = I.getArgOperand(0);
881 Info.offset = 0;
882 Info.align = Align(4);
883 // The atomic.notify instruction does not really load the memory specified with
884 // this argument, but MachineMemOperand should either be load or store, so
885 // we set this to a load.
886 // FIXME Volatile isn't really correct, but currently all LLVM atomic
887 // instructions are treated as volatiles in the backend, so we should be
888 // consistent. The same applies for wasm_atomic_wait intrinsics too.
889 Info.flags = MachineMemOperand::MOVolatile | MachineMemOperand::MOLoad;
890 return true;
891 case Intrinsic::wasm_memory_atomic_wait32:
892 Info.opc = ISD::INTRINSIC_W_CHAIN;
893 Info.memVT = MVT::i32;
894 Info.ptrVal = I.getArgOperand(0);
895 Info.offset = 0;
896 Info.align = Align(4);
897 Info.flags = MachineMemOperand::MOVolatile | MachineMemOperand::MOLoad;
898 return true;
899 case Intrinsic::wasm_memory_atomic_wait64:
900 Info.opc = ISD::INTRINSIC_W_CHAIN;
901 Info.memVT = MVT::i64;
902 Info.ptrVal = I.getArgOperand(0);
903 Info.offset = 0;
904 Info.align = Align(8);
905 Info.flags = MachineMemOperand::MOVolatile | MachineMemOperand::MOLoad;
906 return true;
907 case Intrinsic::wasm_loadf16_f32:
908 Info.opc = ISD::INTRINSIC_W_CHAIN;
909 Info.memVT = MVT::f16;
910 Info.ptrVal = I.getArgOperand(0);
911 Info.offset = 0;
912 Info.align = Align(2);
913 Info.flags = MachineMemOperand::MOLoad;
914 return true;
915 case Intrinsic::wasm_storef16_f32:
916 Info.opc = ISD::INTRINSIC_VOID;
917 Info.memVT = MVT::f16;
918 Info.ptrVal = I.getArgOperand(1);
919 Info.offset = 0;
920 Info.align = Align(2);
921 Info.flags = MachineMemOperand::MOStore;
922 return true;
923 default:
924 return false;
925 }
926 }
927
928 void WebAssemblyTargetLowering::computeKnownBitsForTargetNode(
929 const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
930 const SelectionDAG &DAG, unsigned Depth) const {
931 switch (Op.getOpcode()) {
932 default:
933 break;
934 case ISD::INTRINSIC_WO_CHAIN: {
935 unsigned IntNo = Op.getConstantOperandVal(0);
936 switch (IntNo) {
937 default:
938 break;
939 case Intrinsic::wasm_bitmask: {
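// wasm.bitmask produces one bit per vector lane, so every bit above the
// lane count of the input vector is known to be zero.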
940 unsigned BitWidth = Known.getBitWidth();
941 EVT VT = Op.getOperand(1).getSimpleValueType();
942 unsigned PossibleBits = VT.getVectorNumElements();
943 APInt ZeroMask = APInt::getHighBitsSet(BitWidth, BitWidth - PossibleBits);
944 Known.Zero |= ZeroMask;
945 break;
946 }
947 }
948 }
949 }
950 }
951
952 TargetLoweringBase::LegalizeTypeAction
953 WebAssemblyTargetLowering::getPreferredVectorAction(MVT VT) const {
954 if (VT.isFixedLengthVector()) {
955 MVT EltVT = VT.getVectorElementType();
956 // We have legal vector types with these lane types, so widening the
957 // vector would let us use some of the lanes directly without having to
958 // extend or truncate values.
959 if (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
960 EltVT == MVT::i64 || EltVT == MVT::f32 || EltVT == MVT::f64)
961 return TypeWidenVector;
962 }
963
964 return TargetLoweringBase::getPreferredVectorAction(VT);
965 }
966
967 bool WebAssemblyTargetLowering::shouldSimplifyDemandedVectorElts(
968 SDValue Op, const TargetLoweringOpt &TLO) const {
969 // The ISel process runs DAGCombiner after legalization; this step is called
970 // the SelectionDAG optimization phase. This post-legalization combining
971 // process runs DAGCombiner on each node, and if there was a change to be
972 // made, re-runs legalization on it and its user nodes to make sure
973 // everything is in a legalized state.
974 //
975 // The legalization calls lowering routines, and we do our custom lowering for
976 // build_vectors (LowerBUILD_VECTOR), which converts undef vector elements
977 // into zeros. But there is a set of routines in DAGCombiner that turns unused
978 // (= not demanded) nodes into undef, among which SimplifyDemandedVectorElts
979 // turns unused vector elements into undefs. But this routine does not work
980 // with our custom LowerBUILD_VECTOR, which turns undefs into zeros. This
981 // combination can result in an infinite loop, in which undefs are converted to
982 // zeros in legalization and back to undefs in combining.
983 //
984 // So after DAG is legalized, we prevent SimplifyDemandedVectorElts from
985 // running for build_vectors.
986 if (Op.getOpcode() == ISD::BUILD_VECTOR && TLO.LegalOps && TLO.LegalTys)
987 return false;
988 return true;
989 }
990
991 //===----------------------------------------------------------------------===//
992 // WebAssembly Lowering private implementation.
993 //===----------------------------------------------------------------------===//
994
995 //===----------------------------------------------------------------------===//
996 // Lowering Code
997 //===----------------------------------------------------------------------===//
998
999 static void fail(const SDLoc &DL, SelectionDAG &DAG, const char *Msg) {
1000 MachineFunction &MF = DAG.getMachineFunction();
1001 DAG.getContext()->diagnose(
1002 DiagnosticInfoUnsupported(MF.getFunction(), Msg, DL.getDebugLoc()));
1003 }
1004
1005 // Test whether the given calling convention is supported.
1006 static bool callingConvSupported(CallingConv::ID CallConv) {
1007 // We currently support the language-independent target-independent
1008 // conventions. We don't yet have a way to annotate calls with properties like
1009 // "cold", and we don't have any call-clobbered registers, so these are mostly
1010 // all handled the same.
1011 return CallConv == CallingConv::C || CallConv == CallingConv::Fast ||
1012 CallConv == CallingConv::Cold ||
1013 CallConv == CallingConv::PreserveMost ||
1014 CallConv == CallingConv::PreserveAll ||
1015 CallConv == CallingConv::CXX_FAST_TLS ||
1016 CallConv == CallingConv::WASM_EmscriptenInvoke ||
1017 CallConv == CallingConv::Swift;
1018 }
1019
1020 SDValue
1021 WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
1022 SmallVectorImpl<SDValue> &InVals) const {
1023 SelectionDAG &DAG = CLI.DAG;
1024 SDLoc DL = CLI.DL;
1025 SDValue Chain = CLI.Chain;
1026 SDValue Callee = CLI.Callee;
1027 MachineFunction &MF = DAG.getMachineFunction();
1028 auto Layout = MF.getDataLayout();
1029
1030 CallingConv::ID CallConv = CLI.CallConv;
1031 if (!callingConvSupported(CallConv))
1032 fail(DL, DAG,
1033 "WebAssembly doesn't support language-specific or target-specific "
1034 "calling conventions yet");
1035 if (CLI.IsPatchPoint)
1036 fail(DL, DAG, "WebAssembly doesn't support patch point yet");
1037
1038 if (CLI.IsTailCall) {
1039 auto NoTail = [&](const char *Msg) {
1040 if (CLI.CB && CLI.CB->isMustTailCall())
1041 fail(DL, DAG, Msg);
1042 CLI.IsTailCall = false;
1043 };
1044
1045 if (!Subtarget->hasTailCall())
1046 NoTail("WebAssembly 'tail-call' feature not enabled");
1047
1048 // Varargs calls cannot be tail calls because the buffer is on the stack
1049 if (CLI.IsVarArg)
1050 NoTail("WebAssembly does not support varargs tail calls");
1051
1052 // Do not tail call unless caller and callee return types match
1053 const Function &F = MF.getFunction();
1054 const TargetMachine &TM = getTargetMachine();
1055 Type *RetTy = F.getReturnType();
1056 SmallVector<MVT, 4> CallerRetTys;
1057 SmallVector<MVT, 4> CalleeRetTys;
1058 computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
1059 computeLegalValueVTs(F, TM, CLI.RetTy, CalleeRetTys);
1060 bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() &&
1061 std::equal(CallerRetTys.begin(), CallerRetTys.end(),
1062 CalleeRetTys.begin());
1063 if (!TypesMatch)
1064 NoTail("WebAssembly tail call requires caller and callee return types to "
1065 "match");
1066
1067 // If pointers to local stack values are passed, we cannot tail call
1068 if (CLI.CB) {
1069 for (auto &Arg : CLI.CB->args()) {
1070 Value *Val = Arg.get();
1071 // Trace the value back through pointer operations
1072 while (true) {
1073 Value *Src = Val->stripPointerCastsAndAliases();
1074 if (auto *GEP = dyn_cast<GetElementPtrInst>(Src))
1075 Src = GEP->getPointerOperand();
1076 if (Val == Src)
1077 break;
1078 Val = Src;
1079 }
1080 if (isa<AllocaInst>(Val)) {
1081 NoTail(
1082 "WebAssembly does not support tail calling with stack arguments");
1083 break;
1084 }
1085 }
1086 }
1087 }
1088
1089 SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1090 SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
1091 SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
1092
1093 // The generic code may have added an sret argument. If we're lowering an
1094 // invoke function, the ABI requires that the function pointer be the first
1095 // argument, so we may have to swap the arguments.
1096 if (CallConv == CallingConv::WASM_EmscriptenInvoke && Outs.size() >= 2 &&
1097 Outs[0].Flags.isSRet()) {
1098 std::swap(Outs[0], Outs[1]);
1099 std::swap(OutVals[0], OutVals[1]);
1100 }
1101
1102 bool HasSwiftSelfArg = false;
1103 bool HasSwiftErrorArg = false;
1104 unsigned NumFixedArgs = 0;
1105 for (unsigned I = 0; I < Outs.size(); ++I) {
1106 const ISD::OutputArg &Out = Outs[I];
1107 SDValue &OutVal = OutVals[I];
1108 HasSwiftSelfArg |= Out.Flags.isSwiftSelf();
1109 HasSwiftErrorArg |= Out.Flags.isSwiftError();
1110 if (Out.Flags.isNest())
1111 fail(DL, DAG, "WebAssembly hasn't implemented nest arguments");
1112 if (Out.Flags.isInAlloca())
1113 fail(DL, DAG, "WebAssembly hasn't implemented inalloca arguments");
1114 if (Out.Flags.isInConsecutiveRegs())
1115 fail(DL, DAG, "WebAssembly hasn't implemented cons regs arguments");
1116 if (Out.Flags.isInConsecutiveRegsLast())
1117 fail(DL, DAG, "WebAssembly hasn't implemented cons regs last arguments");
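// For byval arguments, make a copy in a caller-allocated stack object and
// pass a pointer to that copy instead of the original value.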
1118 if (Out.Flags.isByVal() && Out.Flags.getByValSize() != 0) {
1119 auto &MFI = MF.getFrameInfo();
1120 int FI = MFI.CreateStackObject(Out.Flags.getByValSize(),
1121 Out.Flags.getNonZeroByValAlign(),
1122 /*isSS=*/false);
1123 SDValue SizeNode =
1124 DAG.getConstant(Out.Flags.getByValSize(), DL, MVT::i32);
1125 SDValue FINode = DAG.getFrameIndex(FI, getPointerTy(Layout));
1126 Chain = DAG.getMemcpy(Chain, DL, FINode, OutVal, SizeNode,
1127 Out.Flags.getNonZeroByValAlign(),
1128 /*isVolatile*/ false, /*AlwaysInline=*/false,
1129 /*CI=*/nullptr, std::nullopt, MachinePointerInfo(),
1130 MachinePointerInfo());
1131 OutVal = FINode;
1132 }
1133 // Count the number of fixed args *after* legalization.
1134 NumFixedArgs += Out.IsFixed;
1135 }
1136
1137 bool IsVarArg = CLI.IsVarArg;
1138 auto PtrVT = getPointerTy(Layout);
1139
1140 // For swiftcc, emit additional swiftself and swifterror arguments if they
1141 // aren't present. These additional arguments are also added to the callee
1142 // signature. They are necessary to match the callee and caller signatures
1143 // for indirect calls.
1144 if (CallConv == CallingConv::Swift) {
1145 if (!HasSwiftSelfArg) {
1146 NumFixedArgs++;
1147 ISD::OutputArg Arg;
1148 Arg.Flags.setSwiftSelf();
1149 CLI.Outs.push_back(Arg);
1150 SDValue ArgVal = DAG.getUNDEF(PtrVT);
1151 CLI.OutVals.push_back(ArgVal);
1152 }
1153 if (!HasSwiftErrorArg) {
1154 NumFixedArgs++;
1155 ISD::OutputArg Arg;
1156 Arg.Flags.setSwiftError();
1157 CLI.Outs.push_back(Arg);
1158 SDValue ArgVal = DAG.getUNDEF(PtrVT);
1159 CLI.OutVals.push_back(ArgVal);
1160 }
1161 }
1162
1163 // Analyze operands of the call, assigning locations to each operand.
1164 SmallVector<CCValAssign, 16> ArgLocs;
1165 CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
1166
1167 if (IsVarArg) {
1168 // Outgoing non-fixed arguments are placed in a buffer. First
1169 // compute their offsets and the total amount of buffer space needed.
1170 for (unsigned I = NumFixedArgs; I < Outs.size(); ++I) {
1171 const ISD::OutputArg &Out = Outs[I];
1172 SDValue &Arg = OutVals[I];
1173 EVT VT = Arg.getValueType();
1174 assert(VT != MVT::iPTR && "Legalized args should be concrete");
1175 Type *Ty = VT.getTypeForEVT(*DAG.getContext());
1176 Align Alignment =
1177 std::max(Out.Flags.getNonZeroOrigAlign(), Layout.getABITypeAlign(Ty));
1178 unsigned Offset =
1179 CCInfo.AllocateStack(Layout.getTypeAllocSize(Ty), Alignment);
1180 CCInfo.addLoc(CCValAssign::getMem(ArgLocs.size(), VT.getSimpleVT(),
1181 Offset, VT.getSimpleVT(),
1182 CCValAssign::Full));
1183 }
1184 }
1185
1186 unsigned NumBytes = CCInfo.getAlignedCallFrameSize();
1187
1188 SDValue FINode;
1189 if (IsVarArg && NumBytes) {
1190 // For non-fixed arguments, next emit stores to store the argument values
1191 // to the stack buffer at the offsets computed above.
1192 int FI = MF.getFrameInfo().CreateStackObject(NumBytes,
1193 Layout.getStackAlignment(),
1194 /*isSS=*/false);
1195 unsigned ValNo = 0;
1196 SmallVector<SDValue, 8> Chains;
1197 for (SDValue Arg : drop_begin(OutVals, NumFixedArgs)) {
1198 assert(ArgLocs[ValNo].getValNo() == ValNo &&
1199 "ArgLocs should remain in order and only hold varargs args");
1200 unsigned Offset = ArgLocs[ValNo++].getLocMemOffset();
1201 FINode = DAG.getFrameIndex(FI, getPointerTy(Layout));
1202 SDValue Add = DAG.getNode(ISD::ADD, DL, PtrVT, FINode,
1203 DAG.getConstant(Offset, DL, PtrVT));
1204 Chains.push_back(
1205 DAG.getStore(Chain, DL, Arg, Add,
1206 MachinePointerInfo::getFixedStack(MF, FI, Offset)));
1207 }
1208 if (!Chains.empty())
1209 Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
1210 } else if (IsVarArg) {
1211 FINode = DAG.getIntPtrConstant(0, DL);
1212 }
1213
1214 if (Callee->getOpcode() == ISD::GlobalAddress) {
1215 // If the callee is a GlobalAddress node (quite common, every direct call
1216 // is) turn it into a TargetGlobalAddress node so that LowerGlobalAddress
1217 // doesn't add MO_GOT, which is not needed for direct calls.
1218 GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Callee);
1219 Callee = DAG.getTargetGlobalAddress(GA->getGlobal(), DL,
1220 getPointerTy(DAG.getDataLayout()),
1221 GA->getOffset());
1222 Callee = DAG.getNode(WebAssemblyISD::Wrapper, DL,
1223 getPointerTy(DAG.getDataLayout()), Callee);
1224 }
1225
1226 // Compute the operands for the CALLn node.
1227 SmallVector<SDValue, 16> Ops;
1228 Ops.push_back(Chain);
1229 Ops.push_back(Callee);
1230
1231 // Add all fixed arguments. Note that for non-varargs calls, NumFixedArgs
1232 // isn't reliable.
1233 Ops.append(OutVals.begin(),
1234 IsVarArg ? OutVals.begin() + NumFixedArgs : OutVals.end());
1235 // Add a pointer to the vararg buffer.
1236 if (IsVarArg)
1237 Ops.push_back(FINode);
1238
1239 SmallVector<EVT, 8> InTys;
1240 for (const auto &In : Ins) {
1241 assert(!In.Flags.isByVal() && "byval is not valid for return values");
1242 assert(!In.Flags.isNest() && "nest is not valid for return values");
1243 if (In.Flags.isInAlloca())
1244 fail(DL, DAG, "WebAssembly hasn't implemented inalloca return values");
1245 if (In.Flags.isInConsecutiveRegs())
1246 fail(DL, DAG, "WebAssembly hasn't implemented cons regs return values");
1247 if (In.Flags.isInConsecutiveRegsLast())
1248 fail(DL, DAG,
1249 "WebAssembly hasn't implemented cons regs last return values");
1250 // Ignore In.getNonZeroOrigAlign() because all our arguments are passed in
1251 // registers.
1252 InTys.push_back(In.VT);
1253 }
1254
1255 // Lastly, if this is a call to a funcref we need to add an instruction
1256 // table.set to the chain and transform the call.
1257 if (CLI.CB && WebAssembly::isWebAssemblyFuncrefType(
1258 CLI.CB->getCalledOperand()->getType())) {
1259 // In the absence of the function references proposal, where a funcref call
1260 // would be lowered to call_ref, we use reference types: we generate a
1261 // table.set to install the funcref in a special table used solely for this
1262 // purpose, followed by a call_indirect. Here we just generate the table.set
1263 // and make its SDValue the new chain so that the rest of the lowering can
1264 // finalize the call by generating the call_indirect.
1265 SDValue Chain = Ops[0];
1266
1267 MCSymbolWasm *Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(
1268 MF.getContext(), Subtarget);
1269 SDValue Sym = DAG.getMCSymbol(Table, PtrVT);
1270 SDValue TableSlot = DAG.getConstant(0, DL, MVT::i32);
1271 SDValue TableSetOps[] = {Chain, Sym, TableSlot, Callee};
1272 SDValue TableSet = DAG.getMemIntrinsicNode(
1273 WebAssemblyISD::TABLE_SET, DL, DAG.getVTList(MVT::Other), TableSetOps,
1274 MVT::funcref,
1275 // Machine Mem Operand args
1276 MachinePointerInfo(
1277 WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_FUNCREF),
1278 CLI.CB->getCalledOperand()->getPointerAlignment(DAG.getDataLayout()),
1279 MachineMemOperand::MOStore);
1280
1281 Ops[0] = TableSet; // The new chain is the TableSet itself
1282 }
1283
1284 if (CLI.IsTailCall) {
1285 // ret_calls do not return values to the current frame
1286 SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
1287 return DAG.getNode(WebAssemblyISD::RET_CALL, DL, NodeTys, Ops);
1288 }
1289
1290 InTys.push_back(MVT::Other);
1291 SDVTList InTyList = DAG.getVTList(InTys);
1292 SDValue Res = DAG.getNode(WebAssemblyISD::CALL, DL, InTyList, Ops);
1293
1294 for (size_t I = 0; I < Ins.size(); ++I)
1295 InVals.push_back(Res.getValue(I));
1296
1297 // Return the chain
1298 return Res.getValue(Ins.size());
1299 }
1300
1301 bool WebAssemblyTargetLowering::CanLowerReturn(
1302 CallingConv::ID /*CallConv*/, MachineFunction & /*MF*/, bool /*IsVarArg*/,
1303 const SmallVectorImpl<ISD::OutputArg> &Outs,
1304 LLVMContext & /*Context*/) const {
1305 // WebAssembly can only handle returning tuples with multivalue enabled
1306 return WebAssembly::canLowerReturn(Outs.size(), Subtarget);
1307 }
1308
1309 SDValue WebAssemblyTargetLowering::LowerReturn(
1310 SDValue Chain, CallingConv::ID CallConv, bool /*IsVarArg*/,
1311 const SmallVectorImpl<ISD::OutputArg> &Outs,
1312 const SmallVectorImpl<SDValue> &OutVals, const SDLoc &DL,
1313 SelectionDAG &DAG) const {
1314 assert(WebAssembly::canLowerReturn(Outs.size(), Subtarget) &&
1315 "MVP WebAssembly can only return up to one value");
1316 if (!callingConvSupported(CallConv))
1317 fail(DL, DAG, "WebAssembly doesn't support non-C calling conventions");
1318
1319 SmallVector<SDValue, 4> RetOps(1, Chain);
1320 RetOps.append(OutVals.begin(), OutVals.end());
1321 Chain = DAG.getNode(WebAssemblyISD::RETURN, DL, MVT::Other, RetOps);
1322
1323 // Record the number and types of the return values.
1324 for (const ISD::OutputArg &Out : Outs) {
1325 assert(!Out.Flags.isByVal() && "byval is not valid for return values");
1326 assert(!Out.Flags.isNest() && "nest is not valid for return values");
1327 assert(Out.IsFixed && "non-fixed return value is not valid");
1328 if (Out.Flags.isInAlloca())
1329 fail(DL, DAG, "WebAssembly hasn't implemented inalloca results");
1330 if (Out.Flags.isInConsecutiveRegs())
1331 fail(DL, DAG, "WebAssembly hasn't implemented cons regs results");
1332 if (Out.Flags.isInConsecutiveRegsLast())
1333 fail(DL, DAG, "WebAssembly hasn't implemented cons regs last results");
1334 }
1335
1336 return Chain;
1337 }
1338
1339 SDValue WebAssemblyTargetLowering::LowerFormalArguments(
1340 SDValue Chain, CallingConv::ID CallConv, bool IsVarArg,
1341 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
1342 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
1343 if (!callingConvSupported(CallConv))
1344 fail(DL, DAG, "WebAssembly doesn't support non-C calling conventions");
1345
1346 MachineFunction &MF = DAG.getMachineFunction();
1347 auto *MFI = MF.getInfo<WebAssemblyFunctionInfo>();
1348
1349 // Set up the incoming ARGUMENTS value, which serves to represent the liveness
1350 // of the incoming values before they're represented by virtual registers.
1351 MF.getRegInfo().addLiveIn(WebAssembly::ARGUMENTS);
1352
1353 bool HasSwiftErrorArg = false;
1354 bool HasSwiftSelfArg = false;
1355 for (const ISD::InputArg &In : Ins) {
1356 HasSwiftSelfArg |= In.Flags.isSwiftSelf();
1357 HasSwiftErrorArg |= In.Flags.isSwiftError();
1358 if (In.Flags.isInAlloca())
1359 fail(DL, DAG, "WebAssembly hasn't implemented inalloca arguments");
1360 if (In.Flags.isNest())
1361 fail(DL, DAG, "WebAssembly hasn't implemented nest arguments");
1362 if (In.Flags.isInConsecutiveRegs())
1363 fail(DL, DAG, "WebAssembly hasn't implemented cons regs arguments");
1364 if (In.Flags.isInConsecutiveRegsLast())
1365 fail(DL, DAG, "WebAssembly hasn't implemented cons regs last arguments");
1366 // Ignore In.getNonZeroOrigAlign() because all our arguments are passed in
1367 // registers.
1368 InVals.push_back(In.Used ? DAG.getNode(WebAssemblyISD::ARGUMENT, DL, In.VT,
1369 DAG.getTargetConstant(InVals.size(),
1370 DL, MVT::i32))
1371 : DAG.getUNDEF(In.VT));
1372
1373 // Record the number and types of arguments.
1374 MFI->addParam(In.VT);
1375 }
1376
1377 // For swiftcc, emit additional swiftself and swifterror arguments
1378 // if they are not already present. These additional arguments are also
1379 // added to the callee signature; they are necessary to keep the caller and
1380 // callee signatures in sync for indirect calls.
1381 auto PtrVT = getPointerTy(MF.getDataLayout());
1382 if (CallConv == CallingConv::Swift) {
1383 if (!HasSwiftSelfArg) {
1384 MFI->addParam(PtrVT);
1385 }
1386 if (!HasSwiftErrorArg) {
1387 MFI->addParam(PtrVT);
1388 }
1389 }
1390 // Varargs are copied into a buffer allocated by the caller, and a pointer to
1391 // the buffer is passed as an argument.
1392 if (IsVarArg) {
1393 MVT PtrVT = getPointerTy(MF.getDataLayout());
1394 Register VarargVreg =
1395 MF.getRegInfo().createVirtualRegister(getRegClassFor(PtrVT));
1396 MFI->setVarargBufferVreg(VarargVreg);
1397 Chain = DAG.getCopyToReg(
1398 Chain, DL, VarargVreg,
1399 DAG.getNode(WebAssemblyISD::ARGUMENT, DL, PtrVT,
1400 DAG.getTargetConstant(Ins.size(), DL, MVT::i32)));
1401 MFI->addParam(PtrVT);
1402 }
1403
1404 // Record the number and types of arguments and results.
1405 SmallVector<MVT, 4> Params;
1406 SmallVector<MVT, 4> Results;
1407 computeSignatureVTs(MF.getFunction().getFunctionType(), &MF.getFunction(),
1408 MF.getFunction(), DAG.getTarget(), Params, Results);
1409 for (MVT VT : Results)
1410 MFI->addResult(VT);
1411 // TODO: Use signatures in WebAssemblyMachineFunctionInfo too and unify
1412 // the param logic here with ComputeSignatureVTs
1413 assert(MFI->getParams().size() == Params.size() &&
1414 std::equal(MFI->getParams().begin(), MFI->getParams().end(),
1415 Params.begin()));
1416
1417 return Chain;
1418 }
1419
1420 void WebAssemblyTargetLowering::ReplaceNodeResults(
1421 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
1422 switch (N->getOpcode()) {
1423 case ISD::SIGN_EXTEND_INREG:
1424 // Do not add any results, signifying that N should not be custom lowered
1425 // after all. This happens because simd128 turns on custom lowering for
1426 // SIGN_EXTEND_INREG, but for non-vector sign extends the result might be an
1427 // illegal type.
1428 break;
1429 case ISD::SIGN_EXTEND_VECTOR_INREG:
1430 case ISD::ZERO_EXTEND_VECTOR_INREG:
1431 // Do not add any results, signifying that N should not be custom lowered.
1432 // EXTEND_VECTOR_INREG is implemented for some vectors, but not all.
1433 break;
1434 default:
1435 llvm_unreachable(
1436 "ReplaceNodeResults not implemented for this op for WebAssembly!");
1437 }
1438 }
1439
1440 //===----------------------------------------------------------------------===//
1441 // Custom lowering hooks.
1442 //===----------------------------------------------------------------------===//
1443
1444 SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op,
1445 SelectionDAG &DAG) const {
1446 SDLoc DL(Op);
1447 switch (Op.getOpcode()) {
1448 default:
1449 llvm_unreachable("unimplemented operation lowering");
1450 return SDValue();
1451 case ISD::FrameIndex:
1452 return LowerFrameIndex(Op, DAG);
1453 case ISD::GlobalAddress:
1454 return LowerGlobalAddress(Op, DAG);
1455 case ISD::GlobalTLSAddress:
1456 return LowerGlobalTLSAddress(Op, DAG);
1457 case ISD::ExternalSymbol:
1458 return LowerExternalSymbol(Op, DAG);
1459 case ISD::JumpTable:
1460 return LowerJumpTable(Op, DAG);
1461 case ISD::BR_JT:
1462 return LowerBR_JT(Op, DAG);
1463 case ISD::VASTART:
1464 return LowerVASTART(Op, DAG);
1465 case ISD::BlockAddress:
1466 case ISD::BRIND:
1467 fail(DL, DAG, "WebAssembly hasn't implemented computed gotos");
1468 return SDValue();
1469 case ISD::RETURNADDR:
1470 return LowerRETURNADDR(Op, DAG);
1471 case ISD::FRAMEADDR:
1472 return LowerFRAMEADDR(Op, DAG);
1473 case ISD::CopyToReg:
1474 return LowerCopyToReg(Op, DAG);
1475 case ISD::EXTRACT_VECTOR_ELT:
1476 case ISD::INSERT_VECTOR_ELT:
1477 return LowerAccessVectorElement(Op, DAG);
1478 case ISD::INTRINSIC_VOID:
1479 case ISD::INTRINSIC_WO_CHAIN:
1480 case ISD::INTRINSIC_W_CHAIN:
1481 return LowerIntrinsic(Op, DAG);
1482 case ISD::SIGN_EXTEND_INREG:
1483 return LowerSIGN_EXTEND_INREG(Op, DAG);
1484 case ISD::ZERO_EXTEND_VECTOR_INREG:
1485 case ISD::SIGN_EXTEND_VECTOR_INREG:
1486 return LowerEXTEND_VECTOR_INREG(Op, DAG);
1487 case ISD::BUILD_VECTOR:
1488 return LowerBUILD_VECTOR(Op, DAG);
1489 case ISD::VECTOR_SHUFFLE:
1490 return LowerVECTOR_SHUFFLE(Op, DAG);
1491 case ISD::SETCC:
1492 return LowerSETCC(Op, DAG);
1493 case ISD::SHL:
1494 case ISD::SRA:
1495 case ISD::SRL:
1496 return LowerShift(Op, DAG);
1497 case ISD::FP_TO_SINT_SAT:
1498 case ISD::FP_TO_UINT_SAT:
1499 return LowerFP_TO_INT_SAT(Op, DAG);
1500 case ISD::LOAD:
1501 return LowerLoad(Op, DAG);
1502 case ISD::STORE:
1503 return LowerStore(Op, DAG);
1504 case ISD::CTPOP:
1505 case ISD::CTLZ:
1506 case ISD::CTTZ:
1507 return DAG.UnrollVectorOp(Op.getNode());
1508 case ISD::CLEAR_CACHE:
1509 report_fatal_error("llvm.clear_cache is not supported on wasm");
1510 }
1511 }
1512
1513 static bool IsWebAssemblyGlobal(SDValue Op) {
1514 if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op))
1515 return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
1516
1517 return false;
1518 }
1519
1520 static std::optional<unsigned> IsWebAssemblyLocal(SDValue Op,
1521 SelectionDAG &DAG) {
1522 const FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(Op);
1523 if (!FI)
1524 return std::nullopt;
1525
1526 auto &MF = DAG.getMachineFunction();
1527 return WebAssemblyFrameLowering::getLocalForStackObject(MF, FI->getIndex());
1528 }
1529
1530 SDValue WebAssemblyTargetLowering::LowerStore(SDValue Op,
1531 SelectionDAG &DAG) const {
1532 SDLoc DL(Op);
1533 StoreSDNode *SN = cast<StoreSDNode>(Op.getNode());
1534 const SDValue &Value = SN->getValue();
1535 const SDValue &Base = SN->getBasePtr();
1536 const SDValue &Offset = SN->getOffset();
1537
1538 if (IsWebAssemblyGlobal(Base)) {
1539 if (!Offset->isUndef())
1540 report_fatal_error("unexpected offset when storing to webassembly global",
1541 false);
1542
1543 SDVTList Tys = DAG.getVTList(MVT::Other);
1544 SDValue Ops[] = {SN->getChain(), Value, Base};
1545 return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_SET, DL, Tys, Ops,
1546 SN->getMemoryVT(), SN->getMemOperand());
1547 }
1548
1549 if (std::optional<unsigned> Local = IsWebAssemblyLocal(Base, DAG)) {
1550 if (!Offset->isUndef())
1551 report_fatal_error("unexpected offset when storing to webassembly local",
1552 false);
1553
1554 SDValue Idx = DAG.getTargetConstant(*Local, Base, MVT::i32);
1555 SDVTList Tys = DAG.getVTList(MVT::Other); // The chain.
1556 SDValue Ops[] = {SN->getChain(), Idx, Value};
1557 return DAG.getNode(WebAssemblyISD::LOCAL_SET, DL, Tys, Ops);
1558 }
1559
1560 if (WebAssembly::isWasmVarAddressSpace(SN->getAddressSpace()))
1561 report_fatal_error(
1562 "Encountered an unlowerable store to the wasm_var address space",
1563 false);
1564
1565 return Op;
1566 }
1567
1568 SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op,
1569 SelectionDAG &DAG) const {
1570 SDLoc DL(Op);
1571 LoadSDNode *LN = cast<LoadSDNode>(Op.getNode());
1572 const SDValue &Base = LN->getBasePtr();
1573 const SDValue &Offset = LN->getOffset();
1574
1575 if (IsWebAssemblyGlobal(Base)) {
1576 if (!Offset->isUndef())
1577 report_fatal_error(
1578 "unexpected offset when loading from webassembly global", false);
1579
1580 SDVTList Tys = DAG.getVTList(LN->getValueType(0), MVT::Other);
1581 SDValue Ops[] = {LN->getChain(), Base};
1582 return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
1583 LN->getMemoryVT(), LN->getMemOperand());
1584 }
1585
1586 if (std::optional<unsigned> Local = IsWebAssemblyLocal(Base, DAG)) {
1587 if (!Offset->isUndef())
1588 report_fatal_error(
1589 "unexpected offset when loading from webassembly local", false);
1590
1591 SDValue Idx = DAG.getTargetConstant(*Local, Base, MVT::i32);
1592 EVT LocalVT = LN->getValueType(0);
1593 SDValue LocalGet = DAG.getNode(WebAssemblyISD::LOCAL_GET, DL, LocalVT,
1594 {LN->getChain(), Idx});
1595 SDValue Result = DAG.getMergeValues({LocalGet, LN->getChain()}, DL);
1596 assert(Result->getNumValues() == 2 && "Loads must carry a chain!");
1597 return Result;
1598 }
1599
1600 if (WebAssembly::isWasmVarAddressSpace(LN->getAddressSpace()))
1601 report_fatal_error(
1602 "Encountered an unlowerable load from the wasm_var address space",
1603 false);
1604
1605 return Op;
1606 }
1607
1608 SDValue WebAssemblyTargetLowering::LowerCopyToReg(SDValue Op,
1609 SelectionDAG &DAG) const {
1610 SDValue Src = Op.getOperand(2);
1611 if (isa<FrameIndexSDNode>(Src.getNode())) {
1612 // CopyToReg nodes don't support FrameIndex operands. Other targets select
1613 // the FI to some LEA-like instruction, but since we don't have that, we
1614 // need to insert some kind of instruction that can take an FI operand and
1615 // produces a value usable by CopyToReg (i.e. in a vreg). So insert a dummy
1616 // local.copy between Op and its FI operand.
1617 SDValue Chain = Op.getOperand(0);
1618 SDLoc DL(Op);
1619 Register Reg = cast<RegisterSDNode>(Op.getOperand(1))->getReg();
1620 EVT VT = Src.getValueType();
1621 SDValue Copy(DAG.getMachineNode(VT == MVT::i32 ? WebAssembly::COPY_I32
1622 : WebAssembly::COPY_I64,
1623 DL, VT, Src),
1624 0);
1625 return Op.getNode()->getNumValues() == 1
1626 ? DAG.getCopyToReg(Chain, DL, Reg, Copy)
1627 : DAG.getCopyToReg(Chain, DL, Reg, Copy,
1628 Op.getNumOperands() == 4 ? Op.getOperand(3)
1629 : SDValue());
1630 }
1631 return SDValue();
1632 }
1633
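// Note: a FrameIndex is just rewrapped as a TargetFrameIndex so the legalizer
// leaves it alone and instruction selection can match it directly; the actual
// stack offsets are resolved later by frame lowering.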
1634 SDValue WebAssemblyTargetLowering::LowerFrameIndex(SDValue Op,
1635 SelectionDAG &DAG) const {
1636 int FI = cast<FrameIndexSDNode>(Op)->getIndex();
1637 return DAG.getTargetFrameIndex(FI, Op.getValueType());
1638 }
1639
1640 SDValue WebAssemblyTargetLowering::LowerRETURNADDR(SDValue Op,
1641 SelectionDAG &DAG) const {
1642 SDLoc DL(Op);
1643
1644 if (!Subtarget->getTargetTriple().isOSEmscripten()) {
1645 fail(DL, DAG,
1646 "Non-Emscripten WebAssembly hasn't implemented "
1647 "__builtin_return_address");
1648 return SDValue();
1649 }
1650
1651 if (verifyReturnAddressArgumentIsConstant(Op, DAG))
1652 return SDValue();
1653
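// Lower to a libcall taking the requested depth. The libcall name for
// RTLIB::RETURN_ADDRESS is configured elsewhere in this target and is expected
// to resolve to an Emscripten runtime helper (non-Emscripten targets were
// rejected above).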
1654 unsigned Depth = Op.getConstantOperandVal(0);
1655 MakeLibCallOptions CallOptions;
1656 return makeLibCall(DAG, RTLIB::RETURN_ADDRESS, Op.getValueType(),
1657 {DAG.getConstant(Depth, DL, MVT::i32)}, CallOptions, DL)
1658 .first;
1659 }
1660
1661 SDValue WebAssemblyTargetLowering::LowerFRAMEADDR(SDValue Op,
1662 SelectionDAG &DAG) const {
1663 // Non-zero depths are not supported by WebAssembly currently. Use the
1664 // legalizer's default expansion, which is to return 0 (what this function is
1665 // documented to do).
1666 if (Op.getConstantOperandVal(0) > 0)
1667 return SDValue();
1668
1669 DAG.getMachineFunction().getFrameInfo().setFrameAddressIsTaken(true);
1670 EVT VT = Op.getValueType();
1671 Register FP =
1672 Subtarget->getRegisterInfo()->getFrameRegister(DAG.getMachineFunction());
1673 return DAG.getCopyFromReg(DAG.getEntryNode(), SDLoc(Op), FP, VT);
1674 }
1675
1676 SDValue
1677 WebAssemblyTargetLowering::LowerGlobalTLSAddress(SDValue Op,
1678 SelectionDAG &DAG) const {
1679 SDLoc DL(Op);
1680 const auto *GA = cast<GlobalAddressSDNode>(Op);
1681
1682 MachineFunction &MF = DAG.getMachineFunction();
1683 if (!MF.getSubtarget<WebAssemblySubtarget>().hasBulkMemory())
1684 report_fatal_error("cannot use thread-local storage without bulk memory",
1685 false);
1686
1687 const GlobalValue *GV = GA->getGlobal();
1688
1689 // Currently only Emscripten supports dynamic linking with threads. Therefore,
1690 // on other targets, if we have thread-local storage, only the local-exec
1691 // model is possible.
1692 auto model = Subtarget->getTargetTriple().isOSEmscripten()
1693 ? GV->getThreadLocalMode()
1694 : GlobalValue::LocalExecTLSModel;
1695
1696 // Unsupported TLS modes
1697 assert(model != GlobalValue::NotThreadLocal);
1698 assert(model != GlobalValue::InitialExecTLSModel);
1699
1700 if (model == GlobalValue::LocalExecTLSModel ||
1701 model == GlobalValue::LocalDynamicTLSModel ||
1702 (model == GlobalValue::GeneralDynamicTLSModel &&
1703 getTargetMachine().shouldAssumeDSOLocal(GV))) {
1704 // For DSO-local TLS variables we use offset from __tls_base
1705
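// The resulting sequence is roughly (global.get __tls_base) + sym@TLSREL,
// i.e. the thread's TLS block base plus the symbol's offset within that block.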
1706 MVT PtrVT = getPointerTy(DAG.getDataLayout());
1707 auto GlobalGet = PtrVT == MVT::i64 ? WebAssembly::GLOBAL_GET_I64
1708 : WebAssembly::GLOBAL_GET_I32;
1709 const char *BaseName = MF.createExternalSymbolName("__tls_base");
1710
1711 SDValue BaseAddr(
1712 DAG.getMachineNode(GlobalGet, DL, PtrVT,
1713 DAG.getTargetExternalSymbol(BaseName, PtrVT)),
1714 0);
1715
1716 SDValue TLSOffset = DAG.getTargetGlobalAddress(
1717 GV, DL, PtrVT, GA->getOffset(), WebAssemblyII::MO_TLS_BASE_REL);
1718 SDValue SymOffset =
1719 DAG.getNode(WebAssemblyISD::WrapperREL, DL, PtrVT, TLSOffset);
1720
1721 return DAG.getNode(ISD::ADD, DL, PtrVT, BaseAddr, SymOffset);
1722 }
1723
1724 assert(model == GlobalValue::GeneralDynamicTLSModel);
1725
1726 EVT VT = Op.getValueType();
1727 return DAG.getNode(WebAssemblyISD::Wrapper, DL, VT,
1728 DAG.getTargetGlobalAddress(GA->getGlobal(), DL, VT,
1729 GA->getOffset(),
1730 WebAssemblyII::MO_GOT_TLS));
1731 }
1732
1733 SDValue WebAssemblyTargetLowering::LowerGlobalAddress(SDValue Op,
1734 SelectionDAG &DAG) const {
1735 SDLoc DL(Op);
1736 const auto *GA = cast<GlobalAddressSDNode>(Op);
1737 EVT VT = Op.getValueType();
1738 assert(GA->getTargetFlags() == 0 &&
1739 "Unexpected target flags on generic GlobalAddressSDNode");
1740 if (!WebAssembly::isValidAddressSpace(GA->getAddressSpace()))
1741 fail(DL, DAG, "Invalid address space for WebAssembly target");
1742
1743 unsigned OperandFlags = 0;
1744 const GlobalValue *GV = GA->getGlobal();
1745 // Since WebAssembly tables cannot yet be shared across modules, we don't
1746 // need special treatment for tables in PIC mode.
1747 if (isPositionIndependent() &&
1748 !WebAssembly::isWebAssemblyTableType(GV->getValueType())) {
1749 if (getTargetMachine().shouldAssumeDSOLocal(GV)) {
1750 MachineFunction &MF = DAG.getMachineFunction();
1751 MVT PtrVT = getPointerTy(MF.getDataLayout());
1752 const char *BaseName;
1753 if (GV->getValueType()->isFunctionTy()) {
1754 BaseName = MF.createExternalSymbolName("__table_base");
1755 OperandFlags = WebAssemblyII::MO_TABLE_BASE_REL;
1756 } else {
1757 BaseName = MF.createExternalSymbolName("__memory_base");
1758 OperandFlags = WebAssemblyII::MO_MEMORY_BASE_REL;
1759 }
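// A DSO-local symbol's PIC address is then the chosen base global plus the
// symbol's base-relative offset.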
1760 SDValue BaseAddr =
1761 DAG.getNode(WebAssemblyISD::Wrapper, DL, PtrVT,
1762 DAG.getTargetExternalSymbol(BaseName, PtrVT));
1763
1764 SDValue SymAddr = DAG.getNode(
1765 WebAssemblyISD::WrapperREL, DL, VT,
1766 DAG.getTargetGlobalAddress(GA->getGlobal(), DL, VT, GA->getOffset(),
1767 OperandFlags));
1768
1769 return DAG.getNode(ISD::ADD, DL, VT, BaseAddr, SymAddr);
1770 }
1771 OperandFlags = WebAssemblyII::MO_GOT;
1772 }
1773
1774 return DAG.getNode(WebAssemblyISD::Wrapper, DL, VT,
1775 DAG.getTargetGlobalAddress(GA->getGlobal(), DL, VT,
1776 GA->getOffset(), OperandFlags));
1777 }
1778
1779 SDValue
1780 WebAssemblyTargetLowering::LowerExternalSymbol(SDValue Op,
1781 SelectionDAG &DAG) const {
1782 SDLoc DL(Op);
1783 const auto *ES = cast<ExternalSymbolSDNode>(Op);
1784 EVT VT = Op.getValueType();
1785 assert(ES->getTargetFlags() == 0 &&
1786 "Unexpected target flags on generic ExternalSymbolSDNode");
1787 return DAG.getNode(WebAssemblyISD::Wrapper, DL, VT,
1788 DAG.getTargetExternalSymbol(ES->getSymbol(), VT));
1789 }
1790
1791 SDValue WebAssemblyTargetLowering::LowerJumpTable(SDValue Op,
1792 SelectionDAG &DAG) const {
1793 // There's no need for a Wrapper node because we always incorporate a jump
1794 // table operand into a BR_TABLE instruction, rather than ever
1795 // materializing it in a register.
1796 const JumpTableSDNode *JT = cast<JumpTableSDNode>(Op);
1797 return DAG.getTargetJumpTable(JT->getIndex(), Op.getValueType(),
1798 JT->getTargetFlags());
1799 }
1800
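// BR_JT is lowered to a WebAssembly BR_TABLE node whose operands are the
// chain, the index, one basic block per jump-table entry, and finally a
// default destination block.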
1801 SDValue WebAssemblyTargetLowering::LowerBR_JT(SDValue Op,
1802 SelectionDAG &DAG) const {
1803 SDLoc DL(Op);
1804 SDValue Chain = Op.getOperand(0);
1805 const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
1806 SDValue Index = Op.getOperand(2);
1807 assert(JT->getTargetFlags() == 0 && "WebAssembly doesn't set target flags");
1808
1809 SmallVector<SDValue, 8> Ops;
1810 Ops.push_back(Chain);
1811 Ops.push_back(Index);
1812
1813 MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
1814 const auto &MBBs = MJTI->getJumpTables()[JT->getIndex()].MBBs;
1815
1816 // Add an operand for each case.
1817 for (auto *MBB : MBBs)
1818 Ops.push_back(DAG.getBasicBlock(MBB));
1819
1820 // Add the first MBB as a dummy default target for now. This will be replaced
1821 // with the proper default target (and the preceding range check eliminated)
1822 // if possible by WebAssemblyFixBrTableDefaults.
1823 Ops.push_back(DAG.getBasicBlock(*MBBs.begin()));
1824 return DAG.getNode(WebAssemblyISD::BR_TABLE, DL, MVT::Other, Ops);
1825 }
1826
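// va_start stores the pointer to the vararg buffer (materialized as an extra
// incoming argument in LowerFormalArguments) into the va_list object.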
1827 SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
1828 SelectionDAG &DAG) const {
1829 SDLoc DL(Op);
1830 EVT PtrVT = getPointerTy(DAG.getMachineFunction().getDataLayout());
1831
1832 auto *MFI = DAG.getMachineFunction().getInfo<WebAssemblyFunctionInfo>();
1833 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
1834
1835 SDValue ArgN = DAG.getCopyFromReg(DAG.getEntryNode(), DL,
1836 MFI->getVarargBufferVreg(), PtrVT);
1837 return DAG.getStore(Op.getOperand(0), DL, ArgN, Op.getOperand(1),
1838 MachinePointerInfo(SV));
1839 }
1840
1841 SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
1842 SelectionDAG &DAG) const {
1843 MachineFunction &MF = DAG.getMachineFunction();
1844 unsigned IntNo;
1845 switch (Op.getOpcode()) {
1846 case ISD::INTRINSIC_VOID:
1847 case ISD::INTRINSIC_W_CHAIN:
1848 IntNo = Op.getConstantOperandVal(1);
1849 break;
1850 case ISD::INTRINSIC_WO_CHAIN:
1851 IntNo = Op.getConstantOperandVal(0);
1852 break;
1853 default:
1854 llvm_unreachable("Invalid intrinsic");
1855 }
1856 SDLoc DL(Op);
1857
1858 switch (IntNo) {
1859 default:
1860 return SDValue(); // Don't custom lower most intrinsics.
1861
1862 case Intrinsic::wasm_lsda: {
1863 auto PtrVT = getPointerTy(MF.getDataLayout());
1864 const char *SymName = MF.createExternalSymbolName(
1865 "GCC_except_table" + std::to_string(MF.getFunctionNumber()));
1866 if (isPositionIndependent()) {
1867 SDValue Node = DAG.getTargetExternalSymbol(
1868 SymName, PtrVT, WebAssemblyII::MO_MEMORY_BASE_REL);
1869 const char *BaseName = MF.createExternalSymbolName("__memory_base");
1870 SDValue BaseAddr =
1871 DAG.getNode(WebAssemblyISD::Wrapper, DL, PtrVT,
1872 DAG.getTargetExternalSymbol(BaseName, PtrVT));
1873 SDValue SymAddr =
1874 DAG.getNode(WebAssemblyISD::WrapperREL, DL, PtrVT, Node);
1875 return DAG.getNode(ISD::ADD, DL, PtrVT, BaseAddr, SymAddr);
1876 }
1877 SDValue Node = DAG.getTargetExternalSymbol(SymName, PtrVT);
1878 return DAG.getNode(WebAssemblyISD::Wrapper, DL, PtrVT, Node);
1879 }
1880
1881 case Intrinsic::wasm_shuffle: {
1882 // Drop in-chain and replace undefs, but otherwise pass through unchanged
1883 SDValue Ops[18];
1884 size_t OpIdx = 0;
1885 Ops[OpIdx++] = Op.getOperand(1);
1886 Ops[OpIdx++] = Op.getOperand(2);
1887 while (OpIdx < 18) {
1888 const SDValue &MaskIdx = Op.getOperand(OpIdx + 1);
1889 if (MaskIdx.isUndef() || MaskIdx.getNode()->getAsZExtVal() >= 32) {
1890 bool isTarget = MaskIdx.getNode()->getOpcode() == ISD::TargetConstant;
1891 Ops[OpIdx++] = DAG.getConstant(0, DL, MVT::i32, isTarget);
1892 } else {
1893 Ops[OpIdx++] = MaskIdx;
1894 }
1895 }
1896 return DAG.getNode(WebAssemblyISD::SHUFFLE, DL, Op.getValueType(), Ops);
1897 }
1898 }
1899 }
1900
1901 SDValue
1902 WebAssemblyTargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
1903 SelectionDAG &DAG) const {
1904 SDLoc DL(Op);
1905 // If sign extension operations are disabled, allow sext_inreg only if operand
1906 // is a vector extract of an i8 or i16 lane. SIMD does not depend on sign
1907 // extension operations, but allowing sext_inreg in this context lets us have
1908 // simple patterns to select extract_lane_s instructions. Expanding sext_inreg
1909 // everywhere would be simpler in this file, but would necessitate large and
1910 // brittle patterns to undo the expansion and select extract_lane_s
1911 // instructions.
1912 assert(!Subtarget->hasSignExt() && Subtarget->hasSIMD128());
1913 if (Op.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT)
1914 return SDValue();
1915
1916 const SDValue &Extract = Op.getOperand(0);
1917 MVT VecT = Extract.getOperand(0).getSimpleValueType();
1918 if (VecT.getVectorElementType().getSizeInBits() > 32)
1919 return SDValue();
1920 MVT ExtractedLaneT =
1921 cast<VTSDNode>(Op.getOperand(1).getNode())->getVT().getSimpleVT();
1922 MVT ExtractedVecT =
1923 MVT::getVectorVT(ExtractedLaneT, 128 / ExtractedLaneT.getSizeInBits());
1924 if (ExtractedVecT == VecT)
1925 return Op;
1926
1927 // Bitcast vector to appropriate type to ensure ISel pattern coverage
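// For example, a sext_inreg from i8 of (extract_vector_elt (v8i16 $v), 1)
// becomes a sext_inreg from i8 of (extract_vector_elt (v16i8 (bitcast $v)), 2)
// given WebAssembly's little-endian lane numbering.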
1928 const SDNode *Index = Extract.getOperand(1).getNode();
1929 if (!isa<ConstantSDNode>(Index))
1930 return SDValue();
1931 unsigned IndexVal = Index->getAsZExtVal();
1932 unsigned Scale =
1933 ExtractedVecT.getVectorNumElements() / VecT.getVectorNumElements();
1934 assert(Scale > 1);
1935 SDValue NewIndex =
1936 DAG.getConstant(IndexVal * Scale, DL, Index->getValueType(0));
1937 SDValue NewExtract = DAG.getNode(
1938 ISD::EXTRACT_VECTOR_ELT, DL, Extract.getValueType(),
1939 DAG.getBitcast(ExtractedVecT, Extract.getOperand(0)), NewIndex);
1940 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Op.getValueType(), NewExtract,
1941 Op.getOperand(1));
1942 }
1943
1944 SDValue
1945 WebAssemblyTargetLowering::LowerEXTEND_VECTOR_INREG(SDValue Op,
1946 SelectionDAG &DAG) const {
1947 SDLoc DL(Op);
1948 EVT VT = Op.getValueType();
1949 SDValue Src = Op.getOperand(0);
1950 EVT SrcVT = Src.getValueType();
1951
1952 if (SrcVT.getVectorElementType() == MVT::i1 ||
1953 SrcVT.getVectorElementType() == MVT::i64)
1954 return SDValue();
1955
1956 assert(VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits() == 0 &&
1957 "Unexpected extension factor.");
1958 unsigned Scale = VT.getScalarSizeInBits() / SrcVT.getScalarSizeInBits();
1959
1960 if (Scale != 2 && Scale != 4 && Scale != 8)
1961 return SDValue();
1962
1963 unsigned Ext;
1964 switch (Op.getOpcode()) {
1965 case ISD::ZERO_EXTEND_VECTOR_INREG:
1966 Ext = WebAssemblyISD::EXTEND_LOW_U;
1967 break;
1968 case ISD::SIGN_EXTEND_VECTOR_INREG:
1969 Ext = WebAssemblyISD::EXTEND_LOW_S;
1970 break;
1971 }
1972
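// Repeatedly widen the low lanes with extend_low until the requested lane
// width is reached, e.g. extending the low four i8 lanes of a v16i8 to v4i32
// emits two extend_low operations (v16i8 -> v8i16 -> v4i32).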
1973 SDValue Ret = Src;
1974 while (Scale != 1) {
1975 Ret = DAG.getNode(Ext, DL,
1976 Ret.getValueType()
1977 .widenIntegerVectorElementType(*DAG.getContext())
1978 .getHalfNumVectorElementsVT(*DAG.getContext()),
1979 Ret);
1980 Scale /= 2;
1981 }
1982 assert(Ret.getValueType() == VT);
1983 return Ret;
1984 }
1985
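// Try to match a v2f64 BUILD_VECTOR whose two lanes are conversions of lanes
// of a single source vector, and lower it to convert_low_i32x4_{s,u} or
// promote_low_f32x4, shuffling the source first if the converted lanes are not
// already lanes 0 and 1.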
1986 static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
1987 SDLoc DL(Op);
1988 if (Op.getValueType() != MVT::v2f64)
1989 return SDValue();
1990
1991 auto GetConvertedLane = [](SDValue Op, unsigned &Opcode, SDValue &SrcVec,
1992 unsigned &Index) -> bool {
1993 switch (Op.getOpcode()) {
1994 case ISD::SINT_TO_FP:
1995 Opcode = WebAssemblyISD::CONVERT_LOW_S;
1996 break;
1997 case ISD::UINT_TO_FP:
1998 Opcode = WebAssemblyISD::CONVERT_LOW_U;
1999 break;
2000 case ISD::FP_EXTEND:
2001 Opcode = WebAssemblyISD::PROMOTE_LOW;
2002 break;
2003 default:
2004 return false;
2005 }
2006
2007 auto ExtractVector = Op.getOperand(0);
2008 if (ExtractVector.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
2009 return false;
2010
2011 if (!isa<ConstantSDNode>(ExtractVector.getOperand(1).getNode()))
2012 return false;
2013
2014 SrcVec = ExtractVector.getOperand(0);
2015 Index = ExtractVector.getConstantOperandVal(1);
2016 return true;
2017 };
2018
2019 unsigned LHSOpcode, RHSOpcode, LHSIndex, RHSIndex;
2020 SDValue LHSSrcVec, RHSSrcVec;
2021 if (!GetConvertedLane(Op.getOperand(0), LHSOpcode, LHSSrcVec, LHSIndex) ||
2022 !GetConvertedLane(Op.getOperand(1), RHSOpcode, RHSSrcVec, RHSIndex))
2023 return SDValue();
2024
2025 if (LHSOpcode != RHSOpcode)
2026 return SDValue();
2027
2028 MVT ExpectedSrcVT;
2029 switch (LHSOpcode) {
2030 case WebAssemblyISD::CONVERT_LOW_S:
2031 case WebAssemblyISD::CONVERT_LOW_U:
2032 ExpectedSrcVT = MVT::v4i32;
2033 break;
2034 case WebAssemblyISD::PROMOTE_LOW:
2035 ExpectedSrcVT = MVT::v4f32;
2036 break;
2037 }
2038 if (LHSSrcVec.getValueType() != ExpectedSrcVT)
2039 return SDValue();
2040
2041 auto Src = LHSSrcVec;
2042 if (LHSIndex != 0 || RHSIndex != 1 || LHSSrcVec != RHSSrcVec) {
2043 // Shuffle the source vector so that the converted lanes are the low lanes.
2044 Src = DAG.getVectorShuffle(
2045 ExpectedSrcVT, DL, LHSSrcVec, RHSSrcVec,
2046 {static_cast<int>(LHSIndex), static_cast<int>(RHSIndex) + 4, -1, -1});
2047 }
2048 return DAG.getNode(LHSOpcode, DL, MVT::v2f64, Src);
2049 }
2050
2051 SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2052 SelectionDAG &DAG) const {
2053 if (auto ConvertLow = LowerConvertLow(Op, DAG))
2054 return ConvertLow;
2055
2056 SDLoc DL(Op);
2057 const EVT VecT = Op.getValueType();
2058 const EVT LaneT = Op.getOperand(0).getValueType();
2059 const size_t Lanes = Op.getNumOperands();
2060 bool CanSwizzle = VecT == MVT::v16i8;
2061
2062 // BUILD_VECTORs are lowered to the instruction that initializes the highest
2063 // possible number of lanes at once followed by a sequence of replace_lane
2064 // instructions to individually initialize any remaining lanes.
2065
2066 // TODO: Tune this. For example, lanewise swizzling is very expensive, so
2067 // swizzled lanes should be given greater weight.
2068
2069 // TODO: Investigate looping rather than always extracting/replacing specific
2070 // lanes to fill gaps.
2071
2072 auto IsConstant = [](const SDValue &V) {
2073 return V.getOpcode() == ISD::Constant || V.getOpcode() == ISD::ConstantFP;
2074 };
2075
2076 // Returns the source vector and index vector pair if they exist. Checks for:
2077 // (extract_vector_elt
2078 // $src,
2079 // (sign_extend_inreg (extract_vector_elt $indices, $i))
2080 // )
2081 auto GetSwizzleSrcs = [](size_t I, const SDValue &Lane) {
2082 auto Bail = std::make_pair(SDValue(), SDValue());
2083 if (Lane->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
2084 return Bail;
2085 const SDValue &SwizzleSrc = Lane->getOperand(0);
2086 const SDValue &IndexExt = Lane->getOperand(1);
2087 if (IndexExt->getOpcode() != ISD::SIGN_EXTEND_INREG)
2088 return Bail;
2089 const SDValue &Index = IndexExt->getOperand(0);
2090 if (Index->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
2091 return Bail;
2092 const SDValue &SwizzleIndices = Index->getOperand(0);
2093 if (SwizzleSrc.getValueType() != MVT::v16i8 ||
2094 SwizzleIndices.getValueType() != MVT::v16i8 ||
2095 Index->getOperand(1)->getOpcode() != ISD::Constant ||
2096 Index->getConstantOperandVal(1) != I)
2097 return Bail;
2098 return std::make_pair(SwizzleSrc, SwizzleIndices);
2099 };
2100
2101 // If the lane is extracted from another vector at a constant index, return
2102 // that vector. The source vector must not have more lanes than the dest
2103 // because the shufflevector indices are in terms of the destination lanes and
2104 // would not be able to address the smaller individual source lanes.
2105 auto GetShuffleSrc = [&](const SDValue &Lane) {
2106 if (Lane->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
2107 return SDValue();
2108 if (!isa<ConstantSDNode>(Lane->getOperand(1).getNode()))
2109 return SDValue();
2110 if (Lane->getOperand(0).getValueType().getVectorNumElements() >
2111 VecT.getVectorNumElements())
2112 return SDValue();
2113 return Lane->getOperand(0);
2114 };
2115
2116 using ValueEntry = std::pair<SDValue, size_t>;
2117 SmallVector<ValueEntry, 16> SplatValueCounts;
2118
2119 using SwizzleEntry = std::pair<std::pair<SDValue, SDValue>, size_t>;
2120 SmallVector<SwizzleEntry, 16> SwizzleCounts;
2121
2122 using ShuffleEntry = std::pair<SDValue, size_t>;
2123 SmallVector<ShuffleEntry, 16> ShuffleCounts;
2124
2125 auto AddCount = [](auto &Counts, const auto &Val) {
2126 auto CountIt =
2127 llvm::find_if(Counts, [&Val](auto E) { return E.first == Val; });
2128 if (CountIt == Counts.end()) {
2129 Counts.emplace_back(Val, 1);
2130 } else {
2131 CountIt->second++;
2132 }
2133 };
2134
2135 auto GetMostCommon = [](auto &Counts) {
2136 auto CommonIt =
2137 std::max_element(Counts.begin(), Counts.end(), llvm::less_second());
2138 assert(CommonIt != Counts.end() && "Unexpected all-undef build_vector");
2139 return *CommonIt;
2140 };
2141
2142 size_t NumConstantLanes = 0;
2143
2144 // Count eligible lanes for each type of vector creation op
2145 for (size_t I = 0; I < Lanes; ++I) {
2146 const SDValue &Lane = Op->getOperand(I);
2147 if (Lane.isUndef())
2148 continue;
2149
2150 AddCount(SplatValueCounts, Lane);
2151
2152 if (IsConstant(Lane))
2153 NumConstantLanes++;
2154 if (auto ShuffleSrc = GetShuffleSrc(Lane))
2155 AddCount(ShuffleCounts, ShuffleSrc);
2156 if (CanSwizzle) {
2157 auto SwizzleSrcs = GetSwizzleSrcs(I, Lane);
2158 if (SwizzleSrcs.first)
2159 AddCount(SwizzleCounts, SwizzleSrcs);
2160 }
2161 }
2162
2163 SDValue SplatValue;
2164 size_t NumSplatLanes;
2165 std::tie(SplatValue, NumSplatLanes) = GetMostCommon(SplatValueCounts);
2166
2167 SDValue SwizzleSrc;
2168 SDValue SwizzleIndices;
2169 size_t NumSwizzleLanes = 0;
2170 if (SwizzleCounts.size())
2171 std::forward_as_tuple(std::tie(SwizzleSrc, SwizzleIndices),
2172 NumSwizzleLanes) = GetMostCommon(SwizzleCounts);
2173
2174 // Shuffles can draw from up to two vectors, so find the two most common
2175 // sources.
2176 SDValue ShuffleSrc1, ShuffleSrc2;
2177 size_t NumShuffleLanes = 0;
2178 if (ShuffleCounts.size()) {
2179 std::tie(ShuffleSrc1, NumShuffleLanes) = GetMostCommon(ShuffleCounts);
2180 llvm::erase_if(ShuffleCounts,
2181 [&](const auto &Pair) { return Pair.first == ShuffleSrc1; });
2182 }
2183 if (ShuffleCounts.size()) {
2184 size_t AdditionalShuffleLanes;
2185 std::tie(ShuffleSrc2, AdditionalShuffleLanes) =
2186 GetMostCommon(ShuffleCounts);
2187 NumShuffleLanes += AdditionalShuffleLanes;
2188 }
2189
2190 // Predicate returning true if the lane is properly initialized by the
2191 // original instruction
2192 std::function<bool(size_t, const SDValue &)> IsLaneConstructed;
2193 SDValue Result;
2194 // Prefer swizzles over shuffles over vector consts over splats
2195 if (NumSwizzleLanes >= NumShuffleLanes &&
2196 NumSwizzleLanes >= NumConstantLanes && NumSwizzleLanes >= NumSplatLanes) {
2197 Result = DAG.getNode(WebAssemblyISD::SWIZZLE, DL, VecT, SwizzleSrc,
2198 SwizzleIndices);
2199 auto Swizzled = std::make_pair(SwizzleSrc, SwizzleIndices);
2200 IsLaneConstructed = [&, Swizzled](size_t I, const SDValue &Lane) {
2201 return Swizzled == GetSwizzleSrcs(I, Lane);
2202 };
2203 } else if (NumShuffleLanes >= NumConstantLanes &&
2204 NumShuffleLanes >= NumSplatLanes) {
2205 size_t DestLaneSize = VecT.getVectorElementType().getFixedSizeInBits() / 8;
2206 size_t DestLaneCount = VecT.getVectorNumElements();
2207 size_t Scale1 = 1;
2208 size_t Scale2 = 1;
2209 SDValue Src1 = ShuffleSrc1;
2210 SDValue Src2 = ShuffleSrc2 ? ShuffleSrc2 : DAG.getUNDEF(VecT);
2211 if (Src1.getValueType() != VecT) {
2212 size_t LaneSize =
2213 Src1.getValueType().getVectorElementType().getFixedSizeInBits() / 8;
2214 assert(LaneSize > DestLaneSize);
2215 Scale1 = LaneSize / DestLaneSize;
2216 Src1 = DAG.getBitcast(VecT, Src1);
2217 }
2218 if (Src2.getValueType() != VecT) {
2219 size_t LaneSize =
2220 Src2.getValueType().getVectorElementType().getFixedSizeInBits() / 8;
2221 assert(LaneSize > DestLaneSize);
2222 Scale2 = LaneSize / DestLaneSize;
2223 Src2 = DAG.getBitcast(VecT, Src2);
2224 }
2225
2226 int Mask[16];
2227 assert(DestLaneCount <= 16);
2228 for (size_t I = 0; I < DestLaneCount; ++I) {
2229 const SDValue &Lane = Op->getOperand(I);
2230 SDValue Src = GetShuffleSrc(Lane);
2231 if (Src == ShuffleSrc1) {
2232 Mask[I] = Lane->getConstantOperandVal(1) * Scale1;
2233 } else if (Src && Src == ShuffleSrc2) {
2234 Mask[I] = DestLaneCount + Lane->getConstantOperandVal(1) * Scale2;
2235 } else {
2236 Mask[I] = -1;
2237 }
2238 }
2239 ArrayRef<int> MaskRef(Mask, DestLaneCount);
2240 Result = DAG.getVectorShuffle(VecT, DL, Src1, Src2, MaskRef);
2241 IsLaneConstructed = [&](size_t, const SDValue &Lane) {
2242 auto Src = GetShuffleSrc(Lane);
2243 return Src == ShuffleSrc1 || (Src && Src == ShuffleSrc2);
2244 };
2245 } else if (NumConstantLanes >= NumSplatLanes) {
2246 SmallVector<SDValue, 16> ConstLanes;
2247 for (const SDValue &Lane : Op->op_values()) {
2248 if (IsConstant(Lane)) {
2249 // Values may need to be fixed so that they will sign extend to be
2250 // within the expected range during ISel. Check whether the value is in
2251 // bounds based on the lane bit width and if it is out of bounds, lop
2252 // off the extra bits and subtract 2^n to reflect giving the high bit
2253 // value -2^(n-1) rather than +2^(n-1). Skip the i64 case because it
2254 // cannot possibly be out of range.
2255 auto *Const = dyn_cast<ConstantSDNode>(Lane.getNode());
2256 int64_t Val = Const ? Const->getSExtValue() : 0;
2257 uint64_t LaneBits = 128 / Lanes;
2258 assert((LaneBits == 64 || Val >= -(1ll << (LaneBits - 1))) &&
2259 "Unexpected out of bounds negative value");
2260 if (Const && LaneBits != 64 && Val > (1ll << (LaneBits - 1)) - 1) {
2261 uint64_t Mask = (1ll << LaneBits) - 1;
2262 auto NewVal = (((uint64_t)Val & Mask) - (1ll << LaneBits)) & Mask;
2263 ConstLanes.push_back(DAG.getConstant(NewVal, SDLoc(Lane), LaneT));
2264 } else {
2265 ConstLanes.push_back(Lane);
2266 }
2267 } else if (LaneT.isFloatingPoint()) {
2268 ConstLanes.push_back(DAG.getConstantFP(0, DL, LaneT));
2269 } else {
2270 ConstLanes.push_back(DAG.getConstant(0, DL, LaneT));
2271 }
2272 }
2273 Result = DAG.getBuildVector(VecT, DL, ConstLanes);
2274 IsLaneConstructed = [&IsConstant](size_t _, const SDValue &Lane) {
2275 return IsConstant(Lane);
2276 };
2277 } else {
2278 // Use a splat (which might be selected as a load splat)
2279 Result = DAG.getSplatBuildVector(VecT, DL, SplatValue);
2280 IsLaneConstructed = [&SplatValue](size_t _, const SDValue &Lane) {
2281 return Lane == SplatValue;
2282 };
2283 }
2284
2285 assert(Result);
2286 assert(IsLaneConstructed);
2287
2288 // Add replace_lane instructions for any unhandled values
2289 for (size_t I = 0; I < Lanes; ++I) {
2290 const SDValue &Lane = Op->getOperand(I);
2291 if (!Lane.isUndef() && !IsLaneConstructed(I, Lane))
2292 Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecT, Result, Lane,
2293 DAG.getConstant(I, DL, MVT::i32));
2294 }
2295
2296 return Result;
2297 }
2298
2299 SDValue
2300 WebAssemblyTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2301 SelectionDAG &DAG) const {
2302 SDLoc DL(Op);
2303 ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op.getNode())->getMask();
2304 MVT VecType = Op.getOperand(0).getSimpleValueType();
2305 assert(VecType.is128BitVector() && "Unexpected shuffle vector type");
2306 size_t LaneBytes = VecType.getVectorElementType().getSizeInBits() / 8;
2307
2308 // Space for two vector args and sixteen mask indices
2309 SDValue Ops[18];
2310 size_t OpIdx = 0;
2311 Ops[OpIdx++] = Op.getOperand(0);
2312 Ops[OpIdx++] = Op.getOperand(1);
2313
2314 // Expand mask indices to byte indices and materialize them as operands
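// e.g. with 4-byte lanes, a mask entry M expands to the byte indices
// 4*M, 4*M+1, 4*M+2, and 4*M+3.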
2315 for (int M : Mask) {
2316 for (size_t J = 0; J < LaneBytes; ++J) {
2317 // Lower undefs (represented by -1 in mask) to {0..J}, which use a
2318 // whole lane of vector input, to allow further reduction at VM. E.g.
2319 // match an 8x16 byte shuffle to an equivalent cheaper 32x4 shuffle.
2320 uint64_t ByteIndex = M == -1 ? J : (uint64_t)M * LaneBytes + J;
2321 Ops[OpIdx++] = DAG.getConstant(ByteIndex, DL, MVT::i32);
2322 }
2323 }
2324
2325 return DAG.getNode(WebAssemblyISD::SHUFFLE, DL, Op.getValueType(), Ops);
2326 }
2327
2328 SDValue WebAssemblyTargetLowering::LowerSETCC(SDValue Op,
2329 SelectionDAG &DAG) const {
2330 SDLoc DL(Op);
2331 // The legalizer does not know how to expand the unsupported comparison modes
2332 // of i64x2 vectors, so we manually unroll them here.
2333 assert(Op->getOperand(0)->getSimpleValueType(0) == MVT::v2i64);
2334 SmallVector<SDValue, 2> LHS, RHS;
2335 DAG.ExtractVectorElements(Op->getOperand(0), LHS);
2336 DAG.ExtractVectorElements(Op->getOperand(1), RHS);
2337 const SDValue &CC = Op->getOperand(2);
2338 auto MakeLane = [&](unsigned I) {
2339 return DAG.getNode(ISD::SELECT_CC, DL, MVT::i64, LHS[I], RHS[I],
2340 DAG.getConstant(uint64_t(-1), DL, MVT::i64),
2341 DAG.getConstant(uint64_t(0), DL, MVT::i64), CC);
2342 };
2343 return DAG.getBuildVector(Op->getValueType(0), DL,
2344 {MakeLane(0), MakeLane(1)});
2345 }
2346
2347 SDValue
2348 WebAssemblyTargetLowering::LowerAccessVectorElement(SDValue Op,
2349 SelectionDAG &DAG) const {
2350 // Allow constant lane indices, expand variable lane indices
2351 SDNode *IdxNode = Op.getOperand(Op.getNumOperands() - 1).getNode();
2352 if (isa<ConstantSDNode>(IdxNode)) {
2353 // Ensure the index type is i32 to match the tablegen patterns
2354 uint64_t Idx = IdxNode->getAsZExtVal();
2355 SmallVector<SDValue, 3> Ops(Op.getNode()->ops());
2356 Ops[Op.getNumOperands() - 1] =
2357 DAG.getConstant(Idx, SDLoc(IdxNode), MVT::i32);
2358 return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), Ops);
2359 }
2360 // Perform default expansion
2361 return SDValue();
2362 }
2363
2364 static SDValue unrollVectorShift(SDValue Op, SelectionDAG &DAG) {
2365 EVT LaneT = Op.getSimpleValueType().getVectorElementType();
2366 // 32-bit and 64-bit unrolled shifts will have proper semantics
2367 if (LaneT.bitsGE(MVT::i32))
2368 return DAG.UnrollVectorOp(Op.getNode());
2369 // Otherwise mask the shift value to get proper semantics from 32-bit shift
2370 SDLoc DL(Op);
2371 size_t NumLanes = Op.getSimpleValueType().getVectorNumElements();
2372 SDValue Mask = DAG.getConstant(LaneT.getSizeInBits() - 1, DL, MVT::i32);
2373 unsigned ShiftOpcode = Op.getOpcode();
2374 SmallVector<SDValue, 16> ShiftedElements;
2375 DAG.ExtractVectorElements(Op.getOperand(0), ShiftedElements, 0, 0, MVT::i32);
2376 SmallVector<SDValue, 16> ShiftElements;
2377 DAG.ExtractVectorElements(Op.getOperand(1), ShiftElements, 0, 0, MVT::i32);
2378 SmallVector<SDValue, 16> UnrolledOps;
2379 for (size_t i = 0; i < NumLanes; ++i) {
2380 SDValue MaskedShiftValue =
2381 DAG.getNode(ISD::AND, DL, MVT::i32, ShiftElements[i], Mask);
2382 SDValue ShiftedValue = ShiftedElements[i];
2383 if (ShiftOpcode == ISD::SRA)
2384 ShiftedValue = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i32,
2385 ShiftedValue, DAG.getValueType(LaneT));
2386 UnrolledOps.push_back(
2387 DAG.getNode(ShiftOpcode, DL, MVT::i32, ShiftedValue, MaskedShiftValue));
2388 }
2389 return DAG.getBuildVector(Op.getValueType(), DL, UnrolledOps);
2390 }
2391
2392 SDValue WebAssemblyTargetLowering::LowerShift(SDValue Op,
2393 SelectionDAG &DAG) const {
2394 SDLoc DL(Op);
2395
2396 // Only manually lower vector shifts
2397 assert(Op.getSimpleValueType().isVector());
2398
2399 uint64_t LaneBits = Op.getValueType().getScalarSizeInBits();
2400 auto ShiftVal = Op.getOperand(1);
2401
2402 // Try to skip bitmask operation since it is implied inside shift instruction
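// e.g. for an i8x16 shift, (shl $v, (and $x, 7)) can shift by $x directly,
// since WebAssembly shift instructions already take the shift count modulo the
// lane width.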
2403 auto SkipImpliedMask = [](SDValue MaskOp, uint64_t MaskBits) {
2404 if (MaskOp.getOpcode() != ISD::AND)
2405 return MaskOp;
2406 SDValue LHS = MaskOp.getOperand(0);
2407 SDValue RHS = MaskOp.getOperand(1);
2408 if (MaskOp.getValueType().isVector()) {
2409 APInt MaskVal;
2410 if (!ISD::isConstantSplatVector(RHS.getNode(), MaskVal))
2411 std::swap(LHS, RHS);
2412
2413 if (ISD::isConstantSplatVector(RHS.getNode(), MaskVal) &&
2414 MaskVal == MaskBits)
2415 MaskOp = LHS;
2416 } else {
2417 if (!isa<ConstantSDNode>(RHS.getNode()))
2418 std::swap(LHS, RHS);
2419
2420 auto ConstantRHS = dyn_cast<ConstantSDNode>(RHS.getNode());
2421 if (ConstantRHS && ConstantRHS->getAPIntValue() == MaskBits)
2422 MaskOp = LHS;
2423 }
2424
2425 return MaskOp;
2426 };
2427
2428 // Skip vector and operation
2429 ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
2430 ShiftVal = DAG.getSplatValue(ShiftVal);
2431 if (!ShiftVal)
2432 return unrollVectorShift(Op, DAG);
2433
2434 // Skip scalar and operation
2435 ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
2436 // Use anyext because none of the high bits can affect the shift
2437 ShiftVal = DAG.getAnyExtOrTrunc(ShiftVal, DL, MVT::i32);
2438
2439 unsigned Opcode;
2440 switch (Op.getOpcode()) {
2441 case ISD::SHL:
2442 Opcode = WebAssemblyISD::VEC_SHL;
2443 break;
2444 case ISD::SRA:
2445 Opcode = WebAssemblyISD::VEC_SHR_S;
2446 break;
2447 case ISD::SRL:
2448 Opcode = WebAssemblyISD::VEC_SHR_U;
2449 break;
2450 default:
2451 llvm_unreachable("unexpected opcode");
2452 }
2453
2454 return DAG.getNode(Opcode, DL, Op.getValueType(), Op.getOperand(0), ShiftVal);
2455 }
2456
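// Keep the saturating conversions that the instruction patterns can select
// directly (scalar i32/i64 results with i32/i64 saturation widths, and v4i32
// with 32-bit saturation); everything else gets the default expansion.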
2457 SDValue WebAssemblyTargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
2458 SelectionDAG &DAG) const {
2459 SDLoc DL(Op);
2460 EVT ResT = Op.getValueType();
2461 EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2462
2463 if ((ResT == MVT::i32 || ResT == MVT::i64) &&
2464 (SatVT == MVT::i32 || SatVT == MVT::i64))
2465 return Op;
2466
2467 if (ResT == MVT::v4i32 && SatVT == MVT::i32)
2468 return Op;
2469
2470 return SDValue();
2471 }
2472
2473 //===----------------------------------------------------------------------===//
2474 // Custom DAG combine hooks
2475 //===----------------------------------------------------------------------===//
2476 static SDValue
2477 performVECTOR_SHUFFLECombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
2478 auto &DAG = DCI.DAG;
2479 auto Shuffle = cast<ShuffleVectorSDNode>(N);
2480
2481 // Hoist vector bitcasts that don't change the number of lanes out of unary
2482 // shuffles, where they are less likely to get in the way of other combines.
2483 // (shuffle (vNxT1 (bitcast (vNxT0 x))), undef, mask) ->
2484 // (vNxT1 (bitcast (vNxT0 (shuffle x, undef, mask))))
2485 SDValue Bitcast = N->getOperand(0);
2486 if (Bitcast.getOpcode() != ISD::BITCAST)
2487 return SDValue();
2488 if (!N->getOperand(1).isUndef())
2489 return SDValue();
2490 SDValue CastOp = Bitcast.getOperand(0);
2491 EVT SrcType = CastOp.getValueType();
2492 EVT DstType = Bitcast.getValueType();
2493 if (!SrcType.is128BitVector() ||
2494 SrcType.getVectorNumElements() != DstType.getVectorNumElements())
2495 return SDValue();
2496 SDValue NewShuffle = DAG.getVectorShuffle(
2497 SrcType, SDLoc(N), CastOp, DAG.getUNDEF(SrcType), Shuffle->getMask());
2498 return DAG.getBitcast(DstType, NewShuffle);
2499 }
2500
2501 /// Convert ({u,s}itofp vec) --> ({u,s}itofp ({s,z}ext vec)) so it doesn't get
2502 /// split up into scalar instructions during legalization, and the vector
2503 /// extending instructions are selected in performVectorExtendCombine below.
2504 static SDValue
2505 performVectorExtendToFPCombine(SDNode *N,
2506 TargetLowering::DAGCombinerInfo &DCI) {
2507 auto &DAG = DCI.DAG;
2508 assert(N->getOpcode() == ISD::UINT_TO_FP ||
2509 N->getOpcode() == ISD::SINT_TO_FP);
2510
2511 EVT InVT = N->getOperand(0)->getValueType(0);
2512 EVT ResVT = N->getValueType(0);
2513 MVT ExtVT;
2514 if (ResVT == MVT::v4f32 && (InVT == MVT::v4i16 || InVT == MVT::v4i8))
2515 ExtVT = MVT::v4i32;
2516 else if (ResVT == MVT::v2f64 && (InVT == MVT::v2i16 || InVT == MVT::v2i8))
2517 ExtVT = MVT::v2i32;
2518 else
2519 return SDValue();
2520
2521 unsigned Op =
2522 N->getOpcode() == ISD::UINT_TO_FP ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
2523 SDValue Conv = DAG.getNode(Op, SDLoc(N), ExtVT, N->getOperand(0));
2524 return DAG.getNode(N->getOpcode(), SDLoc(N), ResVT, Conv);
2525 }
2526
2527 static SDValue
2528 performVectorExtendCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
2529 auto &DAG = DCI.DAG;
2530 assert(N->getOpcode() == ISD::SIGN_EXTEND ||
2531 N->getOpcode() == ISD::ZERO_EXTEND);
2532
2533 // Combine ({s,z}ext (extract_subvector src, i)) into a widening operation if
2534 // possible before the extract_subvector can be expanded.
2535 auto Extract = N->getOperand(0);
2536 if (Extract.getOpcode() != ISD::EXTRACT_SUBVECTOR)
2537 return SDValue();
2538 auto Source = Extract.getOperand(0);
2539 auto *IndexNode = dyn_cast<ConstantSDNode>(Extract.getOperand(1));
2540 if (IndexNode == nullptr)
2541 return SDValue();
2542 auto Index = IndexNode->getZExtValue();
2543
2544 // Only v8i8, v4i16, and v2i32 extracts can be widened, and only if the
2545 // extracted subvector is the low or high half of its source.
2546 EVT ResVT = N->getValueType(0);
2547 if (ResVT == MVT::v8i16) {
2548 if (Extract.getValueType() != MVT::v8i8 ||
2549 Source.getValueType() != MVT::v16i8 || (Index != 0 && Index != 8))
2550 return SDValue();
2551 } else if (ResVT == MVT::v4i32) {
2552 if (Extract.getValueType() != MVT::v4i16 ||
2553 Source.getValueType() != MVT::v8i16 || (Index != 0 && Index != 4))
2554 return SDValue();
2555 } else if (ResVT == MVT::v2i64) {
2556 if (Extract.getValueType() != MVT::v2i32 ||
2557 Source.getValueType() != MVT::v4i32 || (Index != 0 && Index != 2))
2558 return SDValue();
2559 } else {
2560 return SDValue();
2561 }
2562
2563 bool IsSext = N->getOpcode() == ISD::SIGN_EXTEND;
2564 bool IsLow = Index == 0;
2565
2566 unsigned Op = IsSext ? (IsLow ? WebAssemblyISD::EXTEND_LOW_S
2567 : WebAssemblyISD::EXTEND_HIGH_S)
2568 : (IsLow ? WebAssemblyISD::EXTEND_LOW_U
2569 : WebAssemblyISD::EXTEND_HIGH_U);
2570
2571 return DAG.getNode(Op, SDLoc(N), ResVT, Source);
2572 }
2573
2574 static SDValue
2575 performVectorTruncZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
2576 auto &DAG = DCI.DAG;
2577
2578 auto GetWasmConversionOp = [](unsigned Op) {
2579 switch (Op) {
2580 case ISD::FP_TO_SINT_SAT:
2581 return WebAssemblyISD::TRUNC_SAT_ZERO_S;
2582 case ISD::FP_TO_UINT_SAT:
2583 return WebAssemblyISD::TRUNC_SAT_ZERO_U;
2584 case ISD::FP_ROUND:
2585 return WebAssemblyISD::DEMOTE_ZERO;
2586 }
2587 llvm_unreachable("unexpected op");
2588 };
2589
2590 auto IsZeroSplat = [](SDValue SplatVal) {
2591 auto *Splat = dyn_cast<BuildVectorSDNode>(SplatVal.getNode());
2592 APInt SplatValue, SplatUndef;
2593 unsigned SplatBitSize;
2594 bool HasAnyUndefs;
2595 // Endianness doesn't matter in this context because we are looking for
2596 // an all-zero value.
2597 return Splat &&
2598 Splat->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
2599 HasAnyUndefs) &&
2600 SplatValue == 0;
2601 };
2602
2603 if (N->getOpcode() == ISD::CONCAT_VECTORS) {
2604 // Combine this:
2605 //
2606 // (concat_vectors (v2i32 (fp_to_{s,u}int_sat $x, 32)), (v2i32 (splat 0)))
2607 //
2608 // into (i32x4.trunc_sat_f64x2_zero_{s,u} $x).
2609 //
2610 // Or this:
2611 //
2612 // (concat_vectors (v2f32 (fp_round (v2f64 $x))), (v2f32 (splat 0)))
2613 //
2614 // into (f32x4.demote_zero_f64x2 $x).
2615 EVT ResVT;
2616 EVT ExpectedConversionType;
2617 auto Conversion = N->getOperand(0);
2618 auto ConversionOp = Conversion.getOpcode();
2619 switch (ConversionOp) {
2620 case ISD::FP_TO_SINT_SAT:
2621 case ISD::FP_TO_UINT_SAT:
2622 ResVT = MVT::v4i32;
2623 ExpectedConversionType = MVT::v2i32;
2624 break;
2625 case ISD::FP_ROUND:
2626 ResVT = MVT::v4f32;
2627 ExpectedConversionType = MVT::v2f32;
2628 break;
2629 default:
2630 return SDValue();
2631 }
2632
2633 if (N->getValueType(0) != ResVT)
2634 return SDValue();
2635
2636 if (Conversion.getValueType() != ExpectedConversionType)
2637 return SDValue();
2638
2639 auto Source = Conversion.getOperand(0);
2640 if (Source.getValueType() != MVT::v2f64)
2641 return SDValue();
2642
2643 if (!IsZeroSplat(N->getOperand(1)) ||
2644 N->getOperand(1).getValueType() != ExpectedConversionType)
2645 return SDValue();
2646
2647 unsigned Op = GetWasmConversionOp(ConversionOp);
2648 return DAG.getNode(Op, SDLoc(N), ResVT, Source);
2649 }
2650
2651 // Combine this:
2652 //
2653 // (fp_to_{s,u}int_sat (concat_vectors $x, (v2f64 (splat 0))), 32)
2654 //
2655 // into (i32x4.trunc_sat_f64x2_zero_{s,u} $x).
2656 //
2657 // Or this:
2658 //
2659 // (v4f32 (fp_round (concat_vectors $x, (v2f64 (splat 0)))))
2660 //
2661 // into (f32x4.demote_zero_f64x2 $x).
2662 EVT ResVT;
2663 auto ConversionOp = N->getOpcode();
2664 switch (ConversionOp) {
2665 case ISD::FP_TO_SINT_SAT:
2666 case ISD::FP_TO_UINT_SAT:
2667 ResVT = MVT::v4i32;
2668 break;
2669 case ISD::FP_ROUND:
2670 ResVT = MVT::v4f32;
2671 break;
2672 default:
2673 llvm_unreachable("unexpected op");
2674 }
2675
2676 if (N->getValueType(0) != ResVT)
2677 return SDValue();
2678
2679 auto Concat = N->getOperand(0);
2680 if (Concat.getValueType() != MVT::v4f64)
2681 return SDValue();
2682
2683 auto Source = Concat.getOperand(0);
2684 if (Source.getValueType() != MVT::v2f64)
2685 return SDValue();
2686
2687 if (!IsZeroSplat(Concat.getOperand(1)) ||
2688 Concat.getOperand(1).getValueType() != MVT::v2f64)
2689 return SDValue();
2690
2691 unsigned Op = GetWasmConversionOp(ConversionOp);
2692 return DAG.getNode(Op, SDLoc(N), ResVT, Source);
2693 }
2694
2695 // Helper to extract VectorWidth bits from Vec, starting from IdxVal.
2696 static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
2697 const SDLoc &DL, unsigned VectorWidth) {
2698 EVT VT = Vec.getValueType();
2699 EVT ElVT = VT.getVectorElementType();
2700 unsigned Factor = VT.getSizeInBits() / VectorWidth;
2701 EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT,
2702 VT.getVectorNumElements() / Factor);
2703
2704 // Extract the relevant VectorWidth bits. Generate an EXTRACT_SUBVECTOR
2705 unsigned ElemsPerChunk = VectorWidth / ElVT.getSizeInBits();
2706 assert(isPowerOf2_32(ElemsPerChunk) && "Elements per chunk not power of 2");
2707
2708 // This is the index of the first element of the VectorWidth-bit chunk
2709 // we want. Since ElemsPerChunk is a power of 2 just need to clear bits.
2710 IdxVal &= ~(ElemsPerChunk - 1);
2711
2712 // If the input is a buildvector just emit a smaller one.
2713 if (Vec.getOpcode() == ISD::BUILD_VECTOR)
2714 return DAG.getBuildVector(ResultVT, DL,
2715 Vec->ops().slice(IdxVal, ElemsPerChunk));
2716
2717 SDValue VecIdx = DAG.getIntPtrConstant(IdxVal, DL);
2718 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResultVT, Vec, VecIdx);
2719 }
2720
2721 // Helper to recursively truncate vector elements in half with NARROW_U. DstVT
2722 // is the expected destination value type after recursion. In is the initial
2723 // input. Note that the input should have enough leading zero bits to prevent
2724 // NARROW_U from saturating results.
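// e.g. a v8i32 -> v8i16 truncation extracts the two v4i32 halves and combines
// them with a single i16x8.narrow_i32x4_u.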
2725 static SDValue truncateVectorWithNARROW(EVT DstVT, SDValue In, const SDLoc &DL,
2726 SelectionDAG &DAG) {
2727 EVT SrcVT = In.getValueType();
2728
2729 // No truncation required, we might get here due to recursive calls.
2730 if (SrcVT == DstVT)
2731 return In;
2732
2733 unsigned SrcSizeInBits = SrcVT.getSizeInBits();
2734 unsigned NumElems = SrcVT.getVectorNumElements();
2735 if (!isPowerOf2_32(NumElems))
2736 return SDValue();
2737 assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
2738 assert(SrcSizeInBits > DstVT.getSizeInBits() && "Illegal truncation");
2739
2740 LLVMContext &Ctx = *DAG.getContext();
2741 EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2);
2742
2743 // Narrow to the largest type possible:
2744 // vXi64/vXi32 -> i16x8.narrow_i32x4_u and vXi16 -> i8x16.narrow_i16x8_u.
2745 EVT InVT = MVT::i16, OutVT = MVT::i8;
2746 if (SrcVT.getScalarSizeInBits() > 16) {
2747 InVT = MVT::i32;
2748 OutVT = MVT::i16;
2749 }
2750 unsigned SubSizeInBits = SrcSizeInBits / 2;
2751 InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits());
2752 OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits());
2753
2754 // Split lower/upper subvectors.
2755 SDValue Lo = extractSubVector(In, 0, DAG, DL, SubSizeInBits);
2756 SDValue Hi = extractSubVector(In, NumElems / 2, DAG, DL, SubSizeInBits);
2757
2758 // 256bit -> 128bit truncate - Narrow lower/upper 128-bit subvectors.
2759 if (SrcVT.is256BitVector() && DstVT.is128BitVector()) {
2760 Lo = DAG.getBitcast(InVT, Lo);
2761 Hi = DAG.getBitcast(InVT, Hi);
2762 SDValue Res = DAG.getNode(WebAssemblyISD::NARROW_U, DL, OutVT, Lo, Hi);
2763 return DAG.getBitcast(DstVT, Res);
2764 }
2765
2766 // Recursively narrow lower/upper subvectors, concat result and narrow again.
2767 EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems / 2);
2768 Lo = truncateVectorWithNARROW(PackedVT, Lo, DL, DAG);
2769 Hi = truncateVectorWithNARROW(PackedVT, Hi, DL, DAG);
2770
2771 PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
2772 SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi);
2773 return truncateVectorWithNARROW(DstVT, Res, DL, DAG);
2774 }
2775
2776 static SDValue performTruncateCombine(SDNode *N,
2777 TargetLowering::DAGCombinerInfo &DCI) {
2778 auto &DAG = DCI.DAG;
2779
2780 SDValue In = N->getOperand(0);
2781 EVT InVT = In.getValueType();
2782 if (!InVT.isSimple())
2783 return SDValue();
2784
2785 EVT OutVT = N->getValueType(0);
2786 if (!OutVT.isVector())
2787 return SDValue();
2788
2789 EVT OutSVT = OutVT.getVectorElementType();
2790 EVT InSVT = InVT.getVectorElementType();
2791 // Currently only cover truncate to v16i8 or v8i16.
2792 if (!((InSVT == MVT::i16 || InSVT == MVT::i32 || InSVT == MVT::i64) &&
2793 (OutSVT == MVT::i8 || OutSVT == MVT::i16) && OutVT.is128BitVector()))
2794 return SDValue();
2795
2796 SDLoc DL(N);
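// Clear the bits above the output lane width first so that the unsigned
// NARROW_U operations below cannot saturate (see truncateVectorWithNARROW).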
  APInt Mask = APInt::getLowBitsSet(InVT.getScalarSizeInBits(),
                                    OutVT.getScalarSizeInBits());
  In = DAG.getNode(ISD::AND, DL, InVT, In, DAG.getConstant(Mask, DL, InVT));
  return truncateVectorWithNARROW(OutVT, In, DL, DAG);
}

static SDValue performBitcastCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
  auto &DAG = DCI.DAG;
  SDLoc DL(N);
  SDValue Src = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT SrcVT = Src.getValueType();

  // bitcast <N x i1> to iN
  //   ==> bitmask
  if (DCI.isBeforeLegalize() && VT.isScalarInteger() &&
      SrcVT.isFixedLengthVector() && SrcVT.getScalarType() == MVT::i1) {
    unsigned NumElts = SrcVT.getVectorNumElements();
    if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
      return SDValue();
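    // Sign-extend the i1 lanes to the lane width of a full 128-bit vector, so
    // each lane becomes all-ones or all-zeros; the bitmask intrinsic then
    // packs the top bit of every lane into the scalar result.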
    EVT Width = MVT::getIntegerVT(128 / NumElts);
    return DAG.getZExtOrTrunc(
        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
                    {DAG.getConstant(Intrinsic::wasm_bitmask, DL, MVT::i32),
                     DAG.getSExtOrTrunc(N->getOperand(0), DL,
                                        SrcVT.changeVectorElementType(Width))}),
        DL, VT);
  }

  return SDValue();
}

static SDValue performSETCCCombine(SDNode *N,
                                   TargetLowering::DAGCombinerInfo &DCI) {
  auto &DAG = DCI.DAG;

  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
  SDLoc DL(N);
  EVT VT = N->getValueType(0);

  // setcc (iN (bitcast (vNi1 X))), 0, ne
  //   ==> any_true (vNi1 X)
  // setcc (iN (bitcast (vNi1 X))), 0, eq
  //   ==> xor (any_true (vNi1 X)), -1
  // setcc (iN (bitcast (vNi1 X))), -1, eq
  //   ==> all_true (vNi1 X)
  // setcc (iN (bitcast (vNi1 X))), -1, ne
  //   ==> xor (all_true (vNi1 X)), -1
  if (DCI.isBeforeLegalize() && VT.isScalarInteger() &&
      (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
      (isNullConstant(RHS) || isAllOnesConstant(RHS)) &&
      LHS->getOpcode() == ISD::BITCAST) {
    EVT FromVT = LHS->getOperand(0).getValueType();
    if (FromVT.isFixedLengthVector() &&
        FromVT.getVectorElementType() == MVT::i1) {
      int Intrin = isNullConstant(RHS) ? Intrinsic::wasm_anytrue
                                       : Intrinsic::wasm_alltrue;
      unsigned NumElts = FromVT.getVectorNumElements();
      if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
        return SDValue();
      EVT Width = MVT::getIntegerVT(128 / NumElts);
      SDValue Ret = DAG.getZExtOrTrunc(
          DAG.getNode(
              ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
              {DAG.getConstant(Intrin, DL, MVT::i32),
               DAG.getSExtOrTrunc(LHS->getOperand(0), DL,
                                  FromVT.changeVectorElementType(Width))}),
          DL, MVT::i1);
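      // any_true/all_true directly cover the "ne 0" and "eq -1" comparisons;
      // the other two forms are their logical negations, so invert the result
      // before widening it back to the setcc's result type.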
      if ((isNullConstant(RHS) && (Cond == ISD::SETEQ)) ||
          (isAllOnesConstant(RHS) && (Cond == ISD::SETNE))) {
        Ret = DAG.getNOT(DL, Ret, MVT::i1);
      }
      return DAG.getZExtOrTrunc(Ret, DL, VT);
    }
  }

  return SDValue();
}

SDValue
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
                                             DAGCombinerInfo &DCI) const {
  switch (N->getOpcode()) {
  default:
    return SDValue();
  case ISD::BITCAST:
    return performBitcastCombine(N, DCI);
  case ISD::SETCC:
    return performSETCCCombine(N, DCI);
  case ISD::VECTOR_SHUFFLE:
    return performVECTOR_SHUFFLECombine(N, DCI);
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND:
    return performVectorExtendCombine(N, DCI);
  case ISD::UINT_TO_FP:
  case ISD::SINT_TO_FP:
    return performVectorExtendToFPCombine(N, DCI);
  case ISD::FP_TO_SINT_SAT:
  case ISD::FP_TO_UINT_SAT:
  case ISD::FP_ROUND:
  case ISD::CONCAT_VECTORS:
    return performVectorTruncZeroCombine(N, DCI);
  case ISD::TRUNCATE:
    return performTruncateCombine(N, DCI);
  }
}
