1 // -*- C++ -*- 2 //===----------------------------------------------------------------------===// 3 // 4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 // Kokkos v. 4.0 9 // Copyright (2022) National Technology & Engineering 10 // Solutions of Sandia, LLC (NTESS). 11 // 12 // Under the terms of Contract DE-NA0003525 with NTESS, 13 // the U.S. Government retains certain rights in this software. 14 // 15 //===---------------------------------------------------------------------===// 16 17 #ifndef _LIBCPP___MDSPAN_LAYOUT_STRIDE_H 18 #define _LIBCPP___MDSPAN_LAYOUT_STRIDE_H 19 20 #include <__assert> 21 #include <__config> 22 #include <__fwd/mdspan.h> 23 #include <__mdspan/extents.h> 24 #include <__type_traits/is_constructible.h> 25 #include <__type_traits/is_convertible.h> 26 #include <__type_traits/is_nothrow_constructible.h> 27 #include <__utility/as_const.h> 28 #include <__utility/integer_sequence.h> 29 #include <__utility/swap.h> 30 #include <array> 31 #include <cinttypes> 32 #include <cstddef> 33 #include <limits> 34 35 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) 36 # pragma GCC system_header 37 #endif 38 39 _LIBCPP_PUSH_MACROS 40 #include <__undef_macros> 41 42 _LIBCPP_BEGIN_NAMESPACE_STD 43 44 #if _LIBCPP_STD_VER >= 23 45 46 namespace __mdspan_detail { 47 template <class _Layout, class _Mapping> 48 constexpr bool __is_mapping_of = 49 is_same_v<typename _Layout::template mapping<typename _Mapping::extents_type>, _Mapping>; 50 51 template <class _Mapping> 52 concept __layout_mapping_alike = requires { 53 requires __is_mapping_of<typename _Mapping::layout_type, _Mapping>; 54 requires __is_extents_v<typename _Mapping::extents_type>; 55 { _Mapping::is_always_strided() } -> same_as<bool>; 56 { _Mapping::is_always_exhaustive() } -> same_as<bool>; 57 { _Mapping::is_always_unique() } -> same_as<bool>; 58 bool_constant<_Mapping::is_always_strided()>::value; 59 bool_constant<_Mapping::is_always_exhaustive()>::value; 60 bool_constant<_Mapping::is_always_unique()>::value; 61 }; 62 } // namespace __mdspan_detail 63 64 template <class _Extents> 65 class layout_stride::mapping { 66 public: 67 static_assert(__mdspan_detail::__is_extents<_Extents>::value, 68 "layout_stride::mapping template argument must be a specialization of extents."); 69 70 using extents_type = _Extents; 71 using index_type = typename extents_type::index_type; 72 using size_type = typename extents_type::size_type; 73 using rank_type = typename extents_type::rank_type; 74 using layout_type = layout_stride; 75 76 private: 77 static constexpr rank_type __rank_ = extents_type::rank(); 78 79 // Used for default construction check and mandates 80 _LIBCPP_HIDE_FROM_ABI static constexpr bool __required_span_size_is_representable(const extents_type& __ext) { 81 if constexpr (__rank_ == 0) 82 return true; 83 84 index_type __prod = __ext.extent(0); 85 for (rank_type __r = 1; __r < __rank_; __r++) { 86 bool __overflowed = __builtin_mul_overflow(__prod, __ext.extent(__r), &__prod); 87 if (__overflowed) 88 return false; 89 } 90 return true; 91 } 92 93 template <class _OtherIndexType> 94 _LIBCPP_HIDE_FROM_ABI static constexpr bool 95 __required_span_size_is_representable(const extents_type& __ext, span<_OtherIndexType, __rank_> __strides) { 96 if constexpr (__rank_ == 0) 97 return true; 98 99 index_type __size = 1; 100 for (rank_type __r = 0; __r < __rank_; __r++) { 101 // We can only check correct conversion of _OtherIndexType if it is an integral 102 if constexpr (is_integral_v<_OtherIndexType>) { 103 using _CommonType = common_type_t<index_type, _OtherIndexType>; 104 if (static_cast<_CommonType>(__strides[__r]) > static_cast<_CommonType>(numeric_limits<index_type>::max())) 105 return false; 106 } 107 if (__ext.extent(__r) == static_cast<index_type>(0)) 108 return true; 109 index_type __prod = (__ext.extent(__r) - 1); 110 bool __overflowed_mul = __builtin_mul_overflow(__prod, static_cast<index_type>(__strides[__r]), &__prod); 111 if (__overflowed_mul) 112 return false; 113 bool __overflowed_add = __builtin_add_overflow(__size, __prod, &__size); 114 if (__overflowed_add) 115 return false; 116 } 117 return true; 118 } 119 120 // compute offset of a strided layout mapping 121 template <class _StridedMapping> 122 _LIBCPP_HIDE_FROM_ABI static constexpr index_type __offset(const _StridedMapping& __mapping) { 123 if constexpr (_StridedMapping::extents_type::rank() == 0) { 124 return static_cast<index_type>(__mapping()); 125 } else if (__mapping.required_span_size() == static_cast<typename _StridedMapping::index_type>(0)) { 126 return static_cast<index_type>(0); 127 } else { 128 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 129 return static_cast<index_type>(__mapping((_Pos ? 0 : 0)...)); 130 }(make_index_sequence<__rank_>()); 131 } 132 } 133 134 // compute the permutation for sorting the stride array 135 // we never actually sort the stride array 136 _LIBCPP_HIDE_FROM_ABI constexpr void __bubble_sort_by_strides(array<rank_type, __rank_>& __permute) const { 137 for (rank_type __i = __rank_ - 1; __i > 0; __i--) { 138 for (rank_type __r = 0; __r < __i; __r++) { 139 if (__strides_[__permute[__r]] > __strides_[__permute[__r + 1]]) { 140 swap(__permute[__r], __permute[__r + 1]); 141 } else { 142 // if two strides are the same then one of the associated extents must be 1 or 0 143 // both could be, but you can't have one larger than 1 come first 144 if ((__strides_[__permute[__r]] == __strides_[__permute[__r + 1]]) && 145 (__extents_.extent(__permute[__r]) > static_cast<index_type>(1))) 146 swap(__permute[__r], __permute[__r + 1]); 147 } 148 } 149 } 150 } 151 152 static_assert(extents_type::rank_dynamic() > 0 || __required_span_size_is_representable(extents_type()), 153 "layout_stride::mapping product of static extents must be representable as index_type."); 154 155 public: 156 // [mdspan.layout.stride.cons], constructors 157 _LIBCPP_HIDE_FROM_ABI constexpr mapping() noexcept : __extents_(extents_type()) { 158 // Note the nominal precondition is covered by above static assert since 159 // if rank_dynamic is != 0 required_span_size is zero for default construction 160 if constexpr (__rank_ > 0) { 161 index_type __stride = 1; 162 for (rank_type __r = __rank_ - 1; __r > static_cast<rank_type>(0); __r--) { 163 __strides_[__r] = __stride; 164 __stride *= __extents_.extent(__r); 165 } 166 __strides_[0] = __stride; 167 } 168 } 169 170 _LIBCPP_HIDE_FROM_ABI constexpr mapping(const mapping&) noexcept = default; 171 172 template <class _OtherIndexType> 173 requires(is_convertible_v<const _OtherIndexType&, index_type> && 174 is_nothrow_constructible_v<index_type, const _OtherIndexType&>) 175 _LIBCPP_HIDE_FROM_ABI constexpr mapping(const extents_type& __ext, span<_OtherIndexType, __rank_> __strides) noexcept 176 : __extents_(__ext), __strides_([&]<size_t... _Pos>(index_sequence<_Pos...>) { 177 return __mdspan_detail::__possibly_empty_array<index_type, __rank_>{ 178 static_cast<index_type>(std::as_const(__strides[_Pos]))...}; 179 }(make_index_sequence<__rank_>())) { 180 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 181 ([&]<size_t... _Pos>(index_sequence<_Pos...>) { 182 // For integrals we can do a pre-conversion check, for other types not 183 if constexpr (is_integral_v<_OtherIndexType>) { 184 return ((__strides[_Pos] > static_cast<_OtherIndexType>(0)) && ... && true); 185 } else { 186 return ((static_cast<index_type>(__strides[_Pos]) > static_cast<index_type>(0)) && ... && true); 187 } 188 }(make_index_sequence<__rank_>())), 189 "layout_stride::mapping ctor: all strides must be greater than 0"); 190 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 191 __required_span_size_is_representable(__ext, __strides), 192 "layout_stride::mapping ctor: required span size is not representable as index_type."); 193 if constexpr (__rank_ > 1) { 194 _LIBCPP_ASSERT_UNCATEGORIZED( 195 ([&]<size_t... _Pos>(index_sequence<_Pos...>) { 196 // basically sort the dimensions based on strides and extents, sorting is represented in permute array 197 array<rank_type, __rank_> __permute{_Pos...}; 198 __bubble_sort_by_strides(__permute); 199 200 // check that this permutations represents a growing set 201 for (rank_type __i = 1; __i < __rank_; __i++) 202 if (static_cast<index_type>(__strides[__permute[__i]]) < 203 static_cast<index_type>(__strides[__permute[__i - 1]]) * __extents_.extent(__permute[__i - 1])) 204 return false; 205 return true; 206 }(make_index_sequence<__rank_>())), 207 "layout_stride::mapping ctor: the provided extents and strides lead to a non-unique mapping"); 208 } 209 } 210 211 template <class _OtherIndexType> 212 requires(is_convertible_v<const _OtherIndexType&, index_type> && 213 is_nothrow_constructible_v<index_type, const _OtherIndexType&>) 214 _LIBCPP_HIDE_FROM_ABI constexpr mapping(const extents_type& __ext, 215 const array<_OtherIndexType, __rank_>& __strides) noexcept 216 : mapping(__ext, span(__strides)) {} 217 218 template <class _StridedLayoutMapping> 219 requires(__mdspan_detail::__layout_mapping_alike<_StridedLayoutMapping> && 220 is_constructible_v<extents_type, typename _StridedLayoutMapping::extents_type> && 221 _StridedLayoutMapping::is_always_unique() && _StridedLayoutMapping::is_always_strided()) 222 _LIBCPP_HIDE_FROM_ABI constexpr explicit( 223 !(is_convertible_v<typename _StridedLayoutMapping::extents_type, extents_type> && 224 (__mdspan_detail::__is_mapping_of<layout_left, _StridedLayoutMapping> || 225 __mdspan_detail::__is_mapping_of<layout_right, _StridedLayoutMapping> || 226 __mdspan_detail::__is_mapping_of<layout_stride, _StridedLayoutMapping>))) 227 mapping(const _StridedLayoutMapping& __other) noexcept 228 : __extents_(__other.extents()), __strides_([&]<size_t... _Pos>(index_sequence<_Pos...>) { 229 // stride() only compiles for rank > 0 230 if constexpr (__rank_ > 0) { 231 return __mdspan_detail::__possibly_empty_array<index_type, __rank_>{ 232 static_cast<index_type>(__other.stride(_Pos))...}; 233 } else { 234 return __mdspan_detail::__possibly_empty_array<index_type, 0>{}; 235 } 236 }(make_index_sequence<__rank_>())) { 237 // stride() only compiles for rank > 0 238 if constexpr (__rank_ > 0) { 239 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 240 ([&]<size_t... _Pos>(index_sequence<_Pos...>) { 241 return ((static_cast<index_type>(__other.stride(_Pos)) > static_cast<index_type>(0)) && ... && true); 242 }(make_index_sequence<__rank_>())), 243 "layout_stride::mapping converting ctor: all strides must be greater than 0"); 244 } 245 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 246 __mdspan_detail::__is_representable_as<index_type>(__other.required_span_size()), 247 "layout_stride::mapping converting ctor: other.required_span_size() must be representable as index_type."); 248 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(static_cast<index_type>(0) == __offset(__other), 249 "layout_stride::mapping converting ctor: base offset of mapping must be zero."); 250 } 251 252 _LIBCPP_HIDE_FROM_ABI constexpr mapping& operator=(const mapping&) noexcept = default; 253 254 // [mdspan.layout.stride.obs], observers 255 _LIBCPP_HIDE_FROM_ABI constexpr const extents_type& extents() const noexcept { return __extents_; } 256 257 _LIBCPP_HIDE_FROM_ABI constexpr array<index_type, __rank_> strides() const noexcept { 258 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 259 return array<index_type, __rank_>{__strides_[_Pos]...}; 260 }(make_index_sequence<__rank_>()); 261 } 262 263 _LIBCPP_HIDE_FROM_ABI constexpr index_type required_span_size() const noexcept { 264 if constexpr (__rank_ == 0) { 265 return static_cast<index_type>(1); 266 } else { 267 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 268 if ((__extents_.extent(_Pos) * ... * 1) == 0) 269 return static_cast<index_type>(0); 270 else 271 return static_cast<index_type>( 272 static_cast<index_type>(1) + 273 (((__extents_.extent(_Pos) - static_cast<index_type>(1)) * __strides_[_Pos]) + ... + 274 static_cast<index_type>(0))); 275 }(make_index_sequence<__rank_>()); 276 } 277 } 278 279 template <class... _Indices> 280 requires((sizeof...(_Indices) == __rank_) && (is_convertible_v<_Indices, index_type> && ...) && 281 (is_nothrow_constructible_v<index_type, _Indices> && ...)) 282 _LIBCPP_HIDE_FROM_ABI constexpr index_type operator()(_Indices... __idx) const noexcept { 283 // Mappings are generally meant to be used for accessing allocations and are meant to guarantee to never 284 // return a value exceeding required_span_size(), which is used to know how large an allocation one needs 285 // Thus, this is a canonical point in multi-dimensional data structures to make invalid element access checks 286 // However, mdspan does check this on its own, so for now we avoid double checking in hardened mode 287 _LIBCPP_ASSERT_UNCATEGORIZED(__mdspan_detail::__is_multidimensional_index_in(__extents_, __idx...), 288 "layout_stride::mapping: out of bounds indexing"); 289 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 290 return ((static_cast<index_type>(__idx) * __strides_[_Pos]) + ... + index_type(0)); 291 }(make_index_sequence<sizeof...(_Indices)>()); 292 } 293 294 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_unique() noexcept { return true; } 295 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_exhaustive() noexcept { return false; } 296 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_strided() noexcept { return true; } 297 298 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_unique() noexcept { return true; } 299 // The answer of this function is fairly complex in the case where one or more 300 // extents are zero. 301 // Technically it is meaningless to query is_exhaustive() in that case, but unfortunately 302 // the way the standard defines this function, we can't give a simple true or false then. 303 _LIBCPP_HIDE_FROM_ABI constexpr bool is_exhaustive() const noexcept { 304 if constexpr (__rank_ == 0) 305 return true; 306 else { 307 index_type __span_size = required_span_size(); 308 if (__span_size == static_cast<index_type>(0)) { 309 if constexpr (__rank_ == 1) 310 return __strides_[0] == 1; 311 else { 312 rank_type __r_largest = 0; 313 for (rank_type __r = 1; __r < __rank_; __r++) 314 if (__strides_[__r] > __strides_[__r_largest]) 315 __r_largest = __r; 316 for (rank_type __r = 0; __r < __rank_; __r++) 317 if (__extents_.extent(__r) == 0 && __r != __r_largest) 318 return false; 319 return true; 320 } 321 } else { 322 return required_span_size() == [&]<size_t... _Pos>(index_sequence<_Pos...>) { 323 return (__extents_.extent(_Pos) * ... * static_cast<index_type>(1)); 324 }(make_index_sequence<__rank_>()); 325 } 326 } 327 } 328 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_strided() noexcept { return true; } 329 330 // according to the standard layout_stride does not have a constraint on stride(r) for rank>0 331 // it still has the precondition though 332 _LIBCPP_HIDE_FROM_ABI constexpr index_type stride(rank_type __r) const noexcept { 333 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(__r < __rank_, "layout_stride::mapping::stride(): invalid rank index"); 334 return __strides_[__r]; 335 } 336 337 template <class _OtherMapping> 338 requires(__mdspan_detail::__layout_mapping_alike<_OtherMapping> && 339 (_OtherMapping::extents_type::rank() == __rank_) && _OtherMapping::is_always_strided()) 340 _LIBCPP_HIDE_FROM_ABI friend constexpr bool operator==(const mapping& __lhs, const _OtherMapping& __rhs) noexcept { 341 if (__offset(__rhs)) 342 return false; 343 if constexpr (__rank_ == 0) 344 return true; 345 else { 346 return __lhs.extents() == __rhs.extents() && [&]<size_t... _Pos>(index_sequence<_Pos...>) { 347 // avoid warning when comparing signed and unsigner integers and pick the wider of two types 348 using _CommonType = common_type_t<index_type, typename _OtherMapping::index_type>; 349 return ((static_cast<_CommonType>(__lhs.stride(_Pos)) == static_cast<_CommonType>(__rhs.stride(_Pos))) && ... && 350 true); 351 }(make_index_sequence<__rank_>()); 352 } 353 } 354 355 private: 356 _LIBCPP_NO_UNIQUE_ADDRESS extents_type __extents_{}; 357 _LIBCPP_NO_UNIQUE_ADDRESS __mdspan_detail::__possibly_empty_array<index_type, __rank_> __strides_{}; 358 }; 359 360 #endif // _LIBCPP_STD_VER >= 23 361 362 _LIBCPP_END_NAMESPACE_STD 363 364 _LIBCPP_POP_MACROS 365 366 #endif // _LIBCPP___MDSPAN_LAYOUT_STRIDE_H 367