1 // -*- C++ -*- 2 //===----------------------------------------------------------------------===// 3 // 4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 // Kokkos v. 4.0 9 // Copyright (2022) National Technology & Engineering 10 // Solutions of Sandia, LLC (NTESS). 11 // 12 // Under the terms of Contract DE-NA0003525 with NTESS, 13 // the U.S. Government retains certain rights in this software. 14 // 15 //===---------------------------------------------------------------------===// 16 17 #ifndef _LIBCPP___MDSPAN_LAYOUT_STRIDE_H 18 #define _LIBCPP___MDSPAN_LAYOUT_STRIDE_H 19 20 #include <__assert> 21 #include <__concepts/same_as.h> 22 #include <__config> 23 #include <__fwd/mdspan.h> 24 #include <__mdspan/extents.h> 25 #include <__memory/addressof.h> 26 #include <__type_traits/common_type.h> 27 #include <__type_traits/is_constructible.h> 28 #include <__type_traits/is_convertible.h> 29 #include <__type_traits/is_integral.h> 30 #include <__type_traits/is_nothrow_constructible.h> 31 #include <__type_traits/is_same.h> 32 #include <__utility/as_const.h> 33 #include <__utility/integer_sequence.h> 34 #include <__utility/swap.h> 35 #include <array> 36 #include <limits> 37 #include <span> 38 39 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) 40 # pragma GCC system_header 41 #endif 42 43 _LIBCPP_PUSH_MACROS 44 #include <__undef_macros> 45 46 _LIBCPP_BEGIN_NAMESPACE_STD 47 48 #if _LIBCPP_STD_VER >= 23 49 50 namespace __mdspan_detail { 51 template <class _Layout, class _Mapping> 52 constexpr bool __is_mapping_of = 53 is_same_v<typename _Layout::template mapping<typename _Mapping::extents_type>, _Mapping>; 54 55 template <class _Mapping> 56 concept __layout_mapping_alike = requires { 57 requires __is_mapping_of<typename _Mapping::layout_type, _Mapping>; 58 requires __is_extents_v<typename _Mapping::extents_type>; 59 { _Mapping::is_always_strided() } -> same_as<bool>; 60 { _Mapping::is_always_exhaustive() } -> same_as<bool>; 61 { _Mapping::is_always_unique() } -> same_as<bool>; 62 bool_constant<_Mapping::is_always_strided()>::value; 63 bool_constant<_Mapping::is_always_exhaustive()>::value; 64 bool_constant<_Mapping::is_always_unique()>::value; 65 }; 66 } // namespace __mdspan_detail 67 68 template <class _Extents> 69 class layout_stride::mapping { 70 public: 71 static_assert(__mdspan_detail::__is_extents<_Extents>::value, 72 "layout_stride::mapping template argument must be a specialization of extents."); 73 74 using extents_type = _Extents; 75 using index_type = typename extents_type::index_type; 76 using size_type = typename extents_type::size_type; 77 using rank_type = typename extents_type::rank_type; 78 using layout_type = layout_stride; 79 80 private: 81 static constexpr rank_type __rank_ = extents_type::rank(); 82 83 // Used for default construction check and mandates __required_span_size_is_representable(const extents_type & __ext)84 _LIBCPP_HIDE_FROM_ABI static constexpr bool __required_span_size_is_representable(const extents_type& __ext) { 85 if constexpr (__rank_ == 0) 86 return true; 87 88 index_type __prod = __ext.extent(0); 89 for (rank_type __r = 1; __r < __rank_; __r++) { 90 bool __overflowed = __builtin_mul_overflow(__prod, __ext.extent(__r), std::addressof(__prod)); 91 if (__overflowed) 92 return false; 93 } 94 return true; 95 } 96 97 template <class _OtherIndexType> 98 _LIBCPP_HIDE_FROM_ABI static constexpr bool __required_span_size_is_representable(const extents_type & __ext,span<_OtherIndexType,__rank_> __strides)99 __required_span_size_is_representable(const extents_type& __ext, span<_OtherIndexType, __rank_> __strides) { 100 if constexpr (__rank_ == 0) 101 return true; 102 103 index_type __size = 1; 104 for (rank_type __r = 0; __r < __rank_; __r++) { 105 // We can only check correct conversion of _OtherIndexType if it is an integral 106 if constexpr (is_integral_v<_OtherIndexType>) { 107 using _CommonType = common_type_t<index_type, _OtherIndexType>; 108 if (static_cast<_CommonType>(__strides[__r]) > static_cast<_CommonType>(numeric_limits<index_type>::max())) 109 return false; 110 } 111 if (__ext.extent(__r) == static_cast<index_type>(0)) 112 return true; 113 index_type __prod = (__ext.extent(__r) - 1); 114 bool __overflowed_mul = 115 __builtin_mul_overflow(__prod, static_cast<index_type>(__strides[__r]), std::addressof(__prod)); 116 if (__overflowed_mul) 117 return false; 118 bool __overflowed_add = __builtin_add_overflow(__size, __prod, std::addressof(__size)); 119 if (__overflowed_add) 120 return false; 121 } 122 return true; 123 } 124 125 // compute offset of a strided layout mapping 126 template <class _StridedMapping> __offset(const _StridedMapping & __mapping)127 _LIBCPP_HIDE_FROM_ABI static constexpr index_type __offset(const _StridedMapping& __mapping) { 128 if constexpr (_StridedMapping::extents_type::rank() == 0) { 129 return static_cast<index_type>(__mapping()); 130 } else if (__mapping.required_span_size() == static_cast<typename _StridedMapping::index_type>(0)) { 131 return static_cast<index_type>(0); 132 } else { 133 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 134 return static_cast<index_type>(__mapping((_Pos ? 0 : 0)...)); 135 }(make_index_sequence<__rank_>()); 136 } 137 } 138 139 // compute the permutation for sorting the stride array 140 // we never actually sort the stride array __bubble_sort_by_strides(array<rank_type,__rank_> & __permute)141 _LIBCPP_HIDE_FROM_ABI constexpr void __bubble_sort_by_strides(array<rank_type, __rank_>& __permute) const { 142 for (rank_type __i = __rank_ - 1; __i > 0; __i--) { 143 for (rank_type __r = 0; __r < __i; __r++) { 144 if (__strides_[__permute[__r]] > __strides_[__permute[__r + 1]]) { 145 swap(__permute[__r], __permute[__r + 1]); 146 } else { 147 // if two strides are the same then one of the associated extents must be 1 or 0 148 // both could be, but you can't have one larger than 1 come first 149 if ((__strides_[__permute[__r]] == __strides_[__permute[__r + 1]]) && 150 (__extents_.extent(__permute[__r]) > static_cast<index_type>(1))) 151 swap(__permute[__r], __permute[__r + 1]); 152 } 153 } 154 } 155 } 156 157 static_assert(extents_type::rank_dynamic() > 0 || __required_span_size_is_representable(extents_type()), 158 "layout_stride::mapping product of static extents must be representable as index_type."); 159 160 public: 161 // [mdspan.layout.stride.cons], constructors mapping()162 _LIBCPP_HIDE_FROM_ABI constexpr mapping() noexcept : __extents_(extents_type()) { 163 // Note the nominal precondition is covered by above static assert since 164 // if rank_dynamic is != 0 required_span_size is zero for default construction 165 if constexpr (__rank_ > 0) { 166 index_type __stride = 1; 167 for (rank_type __r = __rank_ - 1; __r > static_cast<rank_type>(0); __r--) { 168 __strides_[__r] = __stride; 169 __stride *= __extents_.extent(__r); 170 } 171 __strides_[0] = __stride; 172 } 173 } 174 175 _LIBCPP_HIDE_FROM_ABI constexpr mapping(const mapping&) noexcept = default; 176 177 template <class _OtherIndexType> requires(is_convertible_v<const _OtherIndexType &,index_type> && is_nothrow_constructible_v<index_type,const _OtherIndexType &>)178 requires(is_convertible_v<const _OtherIndexType&, index_type> && 179 is_nothrow_constructible_v<index_type, const _OtherIndexType&>) 180 _LIBCPP_HIDE_FROM_ABI constexpr mapping(const extents_type& __ext, span<_OtherIndexType, __rank_> __strides) noexcept 181 : __extents_(__ext), __strides_([&]<size_t... _Pos>(index_sequence<_Pos...>) { 182 return __mdspan_detail::__possibly_empty_array<index_type, __rank_>{ 183 static_cast<index_type>(std::as_const(__strides[_Pos]))...}; 184 }(make_index_sequence<__rank_>())) { 185 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 186 ([&]<size_t... _Pos>(index_sequence<_Pos...>) { 187 // For integrals we can do a pre-conversion check, for other types not 188 if constexpr (is_integral_v<_OtherIndexType>) { 189 return ((__strides[_Pos] > static_cast<_OtherIndexType>(0)) && ... && true); 190 } else { 191 return ((static_cast<index_type>(__strides[_Pos]) > static_cast<index_type>(0)) && ... && true); 192 } 193 }(make_index_sequence<__rank_>())), 194 "layout_stride::mapping ctor: all strides must be greater than 0"); 195 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 196 __required_span_size_is_representable(__ext, __strides), 197 "layout_stride::mapping ctor: required span size is not representable as index_type."); 198 if constexpr (__rank_ > 1) { 199 _LIBCPP_ASSERT_UNCATEGORIZED( 200 ([&]<size_t... _Pos>(index_sequence<_Pos...>) { 201 // basically sort the dimensions based on strides and extents, sorting is represented in permute array 202 array<rank_type, __rank_> __permute{_Pos...}; 203 __bubble_sort_by_strides(__permute); 204 205 // check that this permutations represents a growing set 206 for (rank_type __i = 1; __i < __rank_; __i++) 207 if (static_cast<index_type>(__strides[__permute[__i]]) < 208 static_cast<index_type>(__strides[__permute[__i - 1]]) * __extents_.extent(__permute[__i - 1])) 209 return false; 210 return true; 211 }(make_index_sequence<__rank_>())), 212 "layout_stride::mapping ctor: the provided extents and strides lead to a non-unique mapping"); 213 } 214 } 215 216 template <class _OtherIndexType> requires(is_convertible_v<const _OtherIndexType &,index_type> && is_nothrow_constructible_v<index_type,const _OtherIndexType &>)217 requires(is_convertible_v<const _OtherIndexType&, index_type> && 218 is_nothrow_constructible_v<index_type, const _OtherIndexType&>) 219 _LIBCPP_HIDE_FROM_ABI constexpr mapping(const extents_type& __ext, 220 const array<_OtherIndexType, __rank_>& __strides) noexcept 221 : mapping(__ext, span(__strides)) {} 222 223 template <class _StridedLayoutMapping> requires(__mdspan_detail::__layout_mapping_alike<_StridedLayoutMapping> && is_constructible_v<extents_type,typename _StridedLayoutMapping::extents_type> && _StridedLayoutMapping::is_always_unique ()&& _StridedLayoutMapping::is_always_strided ())224 requires(__mdspan_detail::__layout_mapping_alike<_StridedLayoutMapping> && 225 is_constructible_v<extents_type, typename _StridedLayoutMapping::extents_type> && 226 _StridedLayoutMapping::is_always_unique() && _StridedLayoutMapping::is_always_strided()) 227 _LIBCPP_HIDE_FROM_ABI constexpr explicit( 228 !(is_convertible_v<typename _StridedLayoutMapping::extents_type, extents_type> && 229 (__mdspan_detail::__is_mapping_of<layout_left, _StridedLayoutMapping> || 230 __mdspan_detail::__is_mapping_of<layout_right, _StridedLayoutMapping> || 231 __mdspan_detail::__is_mapping_of<layout_stride, _StridedLayoutMapping>))) 232 mapping(const _StridedLayoutMapping& __other) noexcept 233 : __extents_(__other.extents()), __strides_([&]<size_t... _Pos>(index_sequence<_Pos...>) { 234 // stride() only compiles for rank > 0 235 if constexpr (__rank_ > 0) { 236 return __mdspan_detail::__possibly_empty_array<index_type, __rank_>{ 237 static_cast<index_type>(__other.stride(_Pos))...}; 238 } else { 239 return __mdspan_detail::__possibly_empty_array<index_type, 0>{}; 240 } 241 }(make_index_sequence<__rank_>())) { 242 // stride() only compiles for rank > 0 243 if constexpr (__rank_ > 0) { 244 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 245 ([&]<size_t... _Pos>(index_sequence<_Pos...>) { 246 return ((static_cast<index_type>(__other.stride(_Pos)) > static_cast<index_type>(0)) && ... && true); 247 }(make_index_sequence<__rank_>())), 248 "layout_stride::mapping converting ctor: all strides must be greater than 0"); 249 } 250 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( 251 __mdspan_detail::__is_representable_as<index_type>(__other.required_span_size()), 252 "layout_stride::mapping converting ctor: other.required_span_size() must be representable as index_type."); 253 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(static_cast<index_type>(0) == __offset(__other), 254 "layout_stride::mapping converting ctor: base offset of mapping must be zero."); 255 } 256 257 _LIBCPP_HIDE_FROM_ABI constexpr mapping& operator=(const mapping&) noexcept = default; 258 259 // [mdspan.layout.stride.obs], observers extents()260 _LIBCPP_HIDE_FROM_ABI constexpr const extents_type& extents() const noexcept { return __extents_; } 261 strides()262 _LIBCPP_HIDE_FROM_ABI constexpr array<index_type, __rank_> strides() const noexcept { 263 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 264 return array<index_type, __rank_>{__strides_[_Pos]...}; 265 }(make_index_sequence<__rank_>()); 266 } 267 required_span_size()268 _LIBCPP_HIDE_FROM_ABI constexpr index_type required_span_size() const noexcept { 269 if constexpr (__rank_ == 0) { 270 return static_cast<index_type>(1); 271 } else { 272 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 273 if ((__extents_.extent(_Pos) * ... * 1) == 0) 274 return static_cast<index_type>(0); 275 else 276 return static_cast<index_type>( 277 static_cast<index_type>(1) + 278 (((__extents_.extent(_Pos) - static_cast<index_type>(1)) * __strides_[_Pos]) + ... + 279 static_cast<index_type>(0))); 280 }(make_index_sequence<__rank_>()); 281 } 282 } 283 284 template <class... _Indices> 285 requires((sizeof...(_Indices) == __rank_) && (is_convertible_v<_Indices, index_type> && ...) && 286 (is_nothrow_constructible_v<index_type, _Indices> && ...)) operator()287 _LIBCPP_HIDE_FROM_ABI constexpr index_type operator()(_Indices... __idx) const noexcept { 288 // Mappings are generally meant to be used for accessing allocations and are meant to guarantee to never 289 // return a value exceeding required_span_size(), which is used to know how large an allocation one needs 290 // Thus, this is a canonical point in multi-dimensional data structures to make invalid element access checks 291 // However, mdspan does check this on its own, so for now we avoid double checking in hardened mode 292 _LIBCPP_ASSERT_UNCATEGORIZED(__mdspan_detail::__is_multidimensional_index_in(__extents_, __idx...), 293 "layout_stride::mapping: out of bounds indexing"); 294 return [&]<size_t... _Pos>(index_sequence<_Pos...>) { 295 return ((static_cast<index_type>(__idx) * __strides_[_Pos]) + ... + index_type(0)); 296 }(make_index_sequence<sizeof...(_Indices)>()); 297 } 298 is_always_unique()299 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_unique() noexcept { return true; } is_always_exhaustive()300 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_exhaustive() noexcept { return false; } is_always_strided()301 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_strided() noexcept { return true; } 302 is_unique()303 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_unique() noexcept { return true; } 304 // The answer of this function is fairly complex in the case where one or more 305 // extents are zero. 306 // Technically it is meaningless to query is_exhaustive() in that case, but unfortunately 307 // the way the standard defines this function, we can't give a simple true or false then. is_exhaustive()308 _LIBCPP_HIDE_FROM_ABI constexpr bool is_exhaustive() const noexcept { 309 if constexpr (__rank_ == 0) 310 return true; 311 else { 312 index_type __span_size = required_span_size(); 313 if (__span_size == static_cast<index_type>(0)) { 314 if constexpr (__rank_ == 1) 315 return __strides_[0] == 1; 316 else { 317 rank_type __r_largest = 0; 318 for (rank_type __r = 1; __r < __rank_; __r++) 319 if (__strides_[__r] > __strides_[__r_largest]) 320 __r_largest = __r; 321 for (rank_type __r = 0; __r < __rank_; __r++) 322 if (__extents_.extent(__r) == 0 && __r != __r_largest) 323 return false; 324 return true; 325 } 326 } else { 327 return required_span_size() == [&]<size_t... _Pos>(index_sequence<_Pos...>) { 328 return (__extents_.extent(_Pos) * ... * static_cast<index_type>(1)); 329 }(make_index_sequence<__rank_>()); 330 } 331 } 332 } is_strided()333 _LIBCPP_HIDE_FROM_ABI static constexpr bool is_strided() noexcept { return true; } 334 335 // according to the standard layout_stride does not have a constraint on stride(r) for rank>0 336 // it still has the precondition though stride(rank_type __r)337 _LIBCPP_HIDE_FROM_ABI constexpr index_type stride(rank_type __r) const noexcept { 338 _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(__r < __rank_, "layout_stride::mapping::stride(): invalid rank index"); 339 return __strides_[__r]; 340 } 341 342 template <class _OtherMapping> 343 requires(__mdspan_detail::__layout_mapping_alike<_OtherMapping> && 344 (_OtherMapping::extents_type::rank() == __rank_) && _OtherMapping::is_always_strided()) 345 _LIBCPP_HIDE_FROM_ABI friend constexpr bool operator==(const mapping& __lhs, const _OtherMapping& __rhs) noexcept { 346 if (__offset(__rhs)) 347 return false; 348 if constexpr (__rank_ == 0) 349 return true; 350 else { 351 return __lhs.extents() == __rhs.extents() && [&]<size_t... _Pos>(index_sequence<_Pos...>) { 352 // avoid warning when comparing signed and unsigner integers and pick the wider of two types 353 using _CommonType = common_type_t<index_type, typename _OtherMapping::index_type>; 354 return ((static_cast<_CommonType>(__lhs.stride(_Pos)) == static_cast<_CommonType>(__rhs.stride(_Pos))) && ... && 355 true); 356 }(make_index_sequence<__rank_>()); 357 } 358 } 359 360 private: 361 _LIBCPP_NO_UNIQUE_ADDRESS extents_type __extents_{}; 362 _LIBCPP_NO_UNIQUE_ADDRESS __mdspan_detail::__possibly_empty_array<index_type, __rank_> __strides_{}; 363 }; 364 365 #endif // _LIBCPP_STD_VER >= 23 366 367 _LIBCPP_END_NAMESPACE_STD 368 369 _LIBCPP_POP_MACROS 370 371 #endif // _LIBCPP___MDSPAN_LAYOUT_STRIDE_H 372