1 //===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// A bitvector that uses an IntervalMap to coalesce adjacent elements 11 /// into intervals. 12 /// 13 //===----------------------------------------------------------------------===// 14 15 #ifndef LLVM_ADT_COALESCINGBITVECTOR_H 16 #define LLVM_ADT_COALESCINGBITVECTOR_H 17 18 #include "llvm/ADT/IntervalMap.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/iterator_range.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/raw_ostream.h" 24 25 #include <initializer_list> 26 27 namespace llvm { 28 29 /// A bitvector that, under the hood, relies on an IntervalMap to coalesce 30 /// elements into intervals. Good for representing sets which predominantly 31 /// contain contiguous ranges. Bad for representing sets with lots of gaps 32 /// between elements. 33 /// 34 /// Compared to SparseBitVector, CoalescingBitVector offers more predictable 35 /// performance for non-sequential find() operations. 36 /// 37 /// \tparam IndexT - The type of the index into the bitvector. 38 template <typename IndexT> class CoalescingBitVector { 39 static_assert(std::is_unsigned<IndexT>::value, 40 "Index must be an unsigned integer."); 41 42 using ThisT = CoalescingBitVector<IndexT>; 43 44 /// An interval map for closed integer ranges. The mapped values are unused. 45 using MapT = IntervalMap<IndexT, char>; 46 47 using UnderlyingIterator = typename MapT::const_iterator; 48 49 using IntervalT = std::pair<IndexT, IndexT>; 50 51 public: 52 using Allocator = typename MapT::Allocator; 53 54 /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator 55 /// reference. CoalescingBitVector(Allocator & Alloc)56 CoalescingBitVector(Allocator &Alloc) 57 : Alloc(&Alloc), Intervals(Alloc) {} 58 59 /// \name Copy/move constructors and assignment operators. 60 /// @{ 61 CoalescingBitVector(const ThisT & Other)62 CoalescingBitVector(const ThisT &Other) 63 : Alloc(Other.Alloc), Intervals(*Other.Alloc) { 64 set(Other); 65 } 66 67 ThisT &operator=(const ThisT &Other) { 68 clear(); 69 set(Other); 70 return *this; 71 } 72 73 CoalescingBitVector(ThisT &&Other) = delete; 74 ThisT &operator=(ThisT &&Other) = delete; 75 76 /// @} 77 78 /// Clear all the bits. clear()79 void clear() { Intervals.clear(); } 80 81 /// Check whether no bits are set. empty()82 bool empty() const { return Intervals.empty(); } 83 84 /// Count the number of set bits. count()85 unsigned count() const { 86 unsigned Bits = 0; 87 for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It) 88 Bits += 1 + It.stop() - It.start(); 89 return Bits; 90 } 91 92 /// Set the bit at \p Index. 93 /// 94 /// This method does /not/ support setting a bit that has already been set, 95 /// for efficiency reasons. If possible, restructure your code to not set the 96 /// same bit multiple times, or use \ref test_and_set. set(IndexT Index)97 void set(IndexT Index) { 98 assert(!test(Index) && "Setting already-set bits not supported/efficient, " 99 "IntervalMap will assert"); 100 insert(Index, Index); 101 } 102 103 /// Set the bits set in \p Other. 104 /// 105 /// This method does /not/ support setting already-set bits, see \ref set 106 /// for the rationale. For a safe set union operation, use \ref operator|=. set(const ThisT & Other)107 void set(const ThisT &Other) { 108 for (auto It = Other.Intervals.begin(), End = Other.Intervals.end(); 109 It != End; ++It) 110 insert(It.start(), It.stop()); 111 } 112 113 /// Set the bits at \p Indices. Used for testing, primarily. set(std::initializer_list<IndexT> Indices)114 void set(std::initializer_list<IndexT> Indices) { 115 for (IndexT Index : Indices) 116 set(Index); 117 } 118 119 /// Check whether the bit at \p Index is set. test(IndexT Index)120 bool test(IndexT Index) const { 121 const auto It = Intervals.find(Index); 122 if (It == Intervals.end()) 123 return false; 124 assert(It.stop() >= Index && "Interval must end after Index"); 125 return It.start() <= Index; 126 } 127 128 /// Set the bit at \p Index. Supports setting an already-set bit. test_and_set(IndexT Index)129 void test_and_set(IndexT Index) { 130 if (!test(Index)) 131 set(Index); 132 } 133 134 /// Reset the bit at \p Index. Supports resetting an already-unset bit. reset(IndexT Index)135 void reset(IndexT Index) { 136 auto It = Intervals.find(Index); 137 if (It == Intervals.end()) 138 return; 139 140 // Split the interval containing Index into up to two parts: one from 141 // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to 142 // either Start or Stop, we create one new interval. If Index is equal to 143 // both Start and Stop, we simply erase the existing interval. 144 IndexT Start = It.start(); 145 if (Index < Start) 146 // The index was not set. 147 return; 148 IndexT Stop = It.stop(); 149 assert(Index <= Stop && "Wrong interval for index"); 150 It.erase(); 151 if (Start < Index) 152 insert(Start, Index - 1); 153 if (Index < Stop) 154 insert(Index + 1, Stop); 155 } 156 157 /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may 158 /// be a faster alternative. 159 void operator|=(const ThisT &RHS) { 160 // Get the overlaps between the two interval maps. 161 SmallVector<IntervalT, 8> Overlaps; 162 getOverlaps(RHS, Overlaps); 163 164 // Insert the non-overlapping parts of all the intervals from RHS. 165 for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end(); 166 It != End; ++It) { 167 IndexT Start = It.start(); 168 IndexT Stop = It.stop(); 169 SmallVector<IntervalT, 8> NonOverlappingParts; 170 getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts); 171 for (IntervalT AdditivePortion : NonOverlappingParts) 172 insert(AdditivePortion.first, AdditivePortion.second); 173 } 174 } 175 176 /// Set intersection. 177 void operator&=(const ThisT &RHS) { 178 // Get the overlaps between the two interval maps (i.e. the intersection). 179 SmallVector<IntervalT, 8> Overlaps; 180 getOverlaps(RHS, Overlaps); 181 // Rebuild the interval map, including only the overlaps. 182 clear(); 183 for (IntervalT Overlap : Overlaps) 184 insert(Overlap.first, Overlap.second); 185 } 186 187 /// Reset all bits present in \p Other. intersectWithComplement(const ThisT & Other)188 void intersectWithComplement(const ThisT &Other) { 189 SmallVector<IntervalT, 8> Overlaps; 190 if (!getOverlaps(Other, Overlaps)) { 191 // If there is no overlap with Other, the intersection is empty. 192 return; 193 } 194 195 // Delete the overlapping intervals. Split up intervals that only partially 196 // intersect an overlap. 197 for (IntervalT Overlap : Overlaps) { 198 IndexT OlapStart, OlapStop; 199 std::tie(OlapStart, OlapStop) = Overlap; 200 201 auto It = Intervals.find(OlapStart); 202 IndexT CurrStart = It.start(); 203 IndexT CurrStop = It.stop(); 204 assert(CurrStart <= OlapStart && OlapStop <= CurrStop && 205 "Expected some intersection!"); 206 207 // Split the overlap interval into up to two parts: one from [CurrStart, 208 // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is 209 // equal to CurrStart, the first split interval is unnecessary. Ditto for 210 // when OlapStop is equal to CurrStop, we omit the second split interval. 211 It.erase(); 212 if (CurrStart < OlapStart) 213 insert(CurrStart, OlapStart - 1); 214 if (OlapStop < CurrStop) 215 insert(OlapStop + 1, CurrStop); 216 } 217 } 218 219 bool operator==(const ThisT &RHS) const { 220 // We cannot just use std::equal because it checks the dereferenced values 221 // of an iterator pair for equality, not the iterators themselves. In our 222 // case that results in comparison of the (unused) IntervalMap values. 223 auto ItL = Intervals.begin(); 224 auto ItR = RHS.Intervals.begin(); 225 while (ItL != Intervals.end() && ItR != RHS.Intervals.end() && 226 ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) { 227 ++ItL; 228 ++ItR; 229 } 230 return ItL == Intervals.end() && ItR == RHS.Intervals.end(); 231 } 232 233 bool operator!=(const ThisT &RHS) const { return !operator==(RHS); } 234 235 class const_iterator { 236 friend class CoalescingBitVector; 237 238 public: 239 using iterator_category = std::forward_iterator_tag; 240 using value_type = IndexT; 241 using difference_type = std::ptrdiff_t; 242 using pointer = value_type *; 243 using reference = value_type &; 244 245 private: 246 // For performance reasons, make the offset at the end different than the 247 // one used in \ref begin, to optimize the common `It == end()` pattern. 248 static constexpr unsigned kIteratorAtTheEndOffset = ~0u; 249 250 UnderlyingIterator MapIterator; 251 unsigned OffsetIntoMapIterator = 0; 252 253 // Querying the start/stop of an IntervalMap iterator can be very expensive. 254 // Cache these values for performance reasons. 255 IndexT CachedStart = IndexT(); 256 IndexT CachedStop = IndexT(); 257 setToEnd()258 void setToEnd() { 259 OffsetIntoMapIterator = kIteratorAtTheEndOffset; 260 CachedStart = IndexT(); 261 CachedStop = IndexT(); 262 } 263 264 /// MapIterator has just changed, reset the cached state to point to the 265 /// start of the new underlying iterator. resetCache()266 void resetCache() { 267 if (MapIterator.valid()) { 268 OffsetIntoMapIterator = 0; 269 CachedStart = MapIterator.start(); 270 CachedStop = MapIterator.stop(); 271 } else { 272 setToEnd(); 273 } 274 } 275 276 /// Advance the iterator to \p Index, if it is contained within the current 277 /// interval. The public-facing method which supports advancing past the 278 /// current interval is \ref advanceToLowerBound. advanceTo(IndexT Index)279 void advanceTo(IndexT Index) { 280 assert(Index <= CachedStop && "Cannot advance to OOB index"); 281 if (Index < CachedStart) 282 // We're already past this index. 283 return; 284 OffsetIntoMapIterator = Index - CachedStart; 285 } 286 const_iterator(UnderlyingIterator MapIt)287 const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) { 288 resetCache(); 289 } 290 291 public: const_iterator()292 const_iterator() { setToEnd(); } 293 294 bool operator==(const const_iterator &RHS) const { 295 // Do /not/ compare MapIterator for equality, as this is very expensive. 296 // The cached start/stop values make that check unnecessary. 297 return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) == 298 std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart, 299 RHS.CachedStop); 300 } 301 302 bool operator!=(const const_iterator &RHS) const { 303 return !operator==(RHS); 304 } 305 306 IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; } 307 308 const_iterator &operator++() { // Pre-increment (++It). 309 if (CachedStart + OffsetIntoMapIterator < CachedStop) { 310 // Keep going within the current interval. 311 ++OffsetIntoMapIterator; 312 } else { 313 // We reached the end of the current interval: advance. 314 ++MapIterator; 315 resetCache(); 316 } 317 return *this; 318 } 319 320 const_iterator operator++(int) { // Post-increment (It++). 321 const_iterator tmp = *this; 322 operator++(); 323 return tmp; 324 } 325 326 /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If 327 /// no such set bit exists, advance to end(). This is like std::lower_bound. 328 /// This is useful if \p Index is close to the current iterator position. 329 /// However, unlike \ref find(), this has worst-case O(n) performance. advanceToLowerBound(IndexT Index)330 void advanceToLowerBound(IndexT Index) { 331 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset) 332 return; 333 334 // Advance to the first interval containing (or past) Index, or to end(). 335 while (Index > CachedStop) { 336 ++MapIterator; 337 resetCache(); 338 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset) 339 return; 340 } 341 342 advanceTo(Index); 343 } 344 }; 345 begin()346 const_iterator begin() const { return const_iterator(Intervals.begin()); } 347 end()348 const_iterator end() const { return const_iterator(); } 349 350 /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index. 351 /// If no such set bit exists, return end(). This is like std::lower_bound. 352 /// This has worst-case logarithmic performance (roughly O(log(gaps between 353 /// contiguous ranges))). find(IndexT Index)354 const_iterator find(IndexT Index) const { 355 auto UnderlyingIt = Intervals.find(Index); 356 if (UnderlyingIt == Intervals.end()) 357 return end(); 358 auto It = const_iterator(UnderlyingIt); 359 It.advanceTo(Index); 360 return It; 361 } 362 363 /// Return a range iterator which iterates over all of the set bits in the 364 /// half-open range [Start, End). half_open_range(IndexT Start,IndexT End)365 iterator_range<const_iterator> half_open_range(IndexT Start, 366 IndexT End) const { 367 assert(Start < End && "Not a valid range"); 368 auto StartIt = find(Start); 369 if (StartIt == end() || *StartIt >= End) 370 return {end(), end()}; 371 auto EndIt = StartIt; 372 EndIt.advanceToLowerBound(End); 373 return {StartIt, EndIt}; 374 } 375 print(raw_ostream & OS)376 void print(raw_ostream &OS) const { 377 OS << "{"; 378 for (auto It = Intervals.begin(), End = Intervals.end(); It != End; 379 ++It) { 380 OS << "[" << It.start(); 381 if (It.start() != It.stop()) 382 OS << ", " << It.stop(); 383 OS << "]"; 384 } 385 OS << "}"; 386 } 387 388 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) dump()389 LLVM_DUMP_METHOD void dump() const { 390 // LLDB swallows the first line of output after callling dump(). Add 391 // newlines before/after the braces to work around this. 392 dbgs() << "\n"; 393 print(dbgs()); 394 dbgs() << "\n"; 395 } 396 #endif 397 398 private: insert(IndexT Start,IndexT End)399 void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); } 400 401 /// Record the overlaps between \p this and \p Other in \p Overlaps. Return 402 /// true if there is any overlap. getOverlaps(const ThisT & Other,SmallVectorImpl<IntervalT> & Overlaps)403 bool getOverlaps(const ThisT &Other, 404 SmallVectorImpl<IntervalT> &Overlaps) const { 405 for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals); 406 I.valid(); ++I) 407 Overlaps.emplace_back(I.start(), I.stop()); 408 assert(llvm::is_sorted(Overlaps, 409 [](IntervalT LHS, IntervalT RHS) { 410 return LHS.second < RHS.first; 411 }) && 412 "Overlaps must be sorted"); 413 return !Overlaps.empty(); 414 } 415 416 /// Given the set of overlaps between this and some other bitvector, and an 417 /// interval [Start, Stop] from that bitvector, determine the portions of the 418 /// interval which do not overlap with this. getNonOverlappingParts(IndexT Start,IndexT Stop,const SmallVectorImpl<IntervalT> & Overlaps,SmallVectorImpl<IntervalT> & NonOverlappingParts)419 void getNonOverlappingParts(IndexT Start, IndexT Stop, 420 const SmallVectorImpl<IntervalT> &Overlaps, 421 SmallVectorImpl<IntervalT> &NonOverlappingParts) { 422 IndexT NextUncoveredBit = Start; 423 for (IntervalT Overlap : Overlaps) { 424 IndexT OlapStart, OlapStop; 425 std::tie(OlapStart, OlapStop) = Overlap; 426 427 // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop 428 // and Start <= OlapStop. 429 bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop; 430 if (!DoesOverlap) 431 continue; 432 433 // Cover the range [NextUncoveredBit, OlapStart). This puts the start of 434 // the next uncovered range at OlapStop+1. 435 if (NextUncoveredBit < OlapStart) 436 NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1); 437 NextUncoveredBit = OlapStop + 1; 438 if (NextUncoveredBit > Stop) 439 break; 440 } 441 if (NextUncoveredBit <= Stop) 442 NonOverlappingParts.emplace_back(NextUncoveredBit, Stop); 443 } 444 445 Allocator *Alloc; 446 MapT Intervals; 447 }; 448 449 } // namespace llvm 450 451 #endif // LLVM_ADT_COALESCINGBITVECTOR_H 452