xref: /freebsd/contrib/llvm-project/llvm/include/llvm/ADT/IntervalTree.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===-- IntervalTree.h ------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an interval tree.
10 //
11 // Further information:
12 // https://en.wikipedia.org/wiki/Interval_tree
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_ADT_INTERVALTREE_H
17 #define LLVM_ADT_INTERVALTREE_H
18 
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <memory>
#include <type_traits>
26 
27 // IntervalTree is a light tree data structure to hold intervals. It allows
28 // finding all intervals that overlap with any given point. At this time,
29 // it does not support any deletion or rebalancing operations.
30 //
31 // The IntervalTree is designed to be set up once, and then queried without
32 // any further additions.
33 //
34 // Synopsis:
35 //   Closed intervals delimited by PointT objects are mapped to ValueT objects.
36 //
37 // Restrictions:
38 //   PointT must be a fundamental type.
39 //   ValueT must be a fundamental or pointer type.
40 //
// template <typename PointT, typename ValueT, typename DataT>
// class IntervalTree {
// public:
//
//   explicit IntervalTree(Allocator &NodeAllocator);
//   ~IntervalTree();
//
//   using IntervalReferences = SmallVector<const DataType *, 4>;
//
//   void create();
//   void insert(PointT Left, PointT Right, ValueT Value);
//
//   IntervalReferences getContaining(PointT Point) const;
//   static void sortIntervals(IntervalReferences &Intervals, Sorting Sort);
//
//   find_iterator find(PointType Point) const;
//   find_iterator find_end() const;
//
//   bool empty() const;
//   void clear();
//
//   void print(raw_ostream &OS, bool HexFormat = true);
// };
64 //
65 //===----------------------------------------------------------------------===//
66 //
67 // In the below given dataset
68 //
69 //   [a, b] <- (x)
70 //
71 // 'a' and 'b' describe a range and 'x' the value for that interval.
72 //
73 // The following data are purely for illustrative purposes:
74 //
75 // [30, 35] <- (3035),    [39, 50] <- (3950),    [55, 61] <- (5561),
76 // [31, 56] <- (3156),    [12, 21] <- (1221),    [25, 41] <- (2541),
77 // [49, 65] <- (4965),    [71, 79] <- (7179),    [11, 16] <- (1116),
78 // [20, 30] <- (2030),    [36, 54] <- (3654),    [60, 70] <- (6070),
79 // [74, 80] <- (7480),    [15, 40] <- (1540),    [43, 43] <- (4343),
80 // [50, 75] <- (5075),    [10, 85] <- (1085)
81 //
82 // The data represents a set of overlapping intervals:
83 //
84 //                    30--35  39------------50  55----61
85 //                      31------------------------56
86 //     12--------21 25------------41      49-------------65   71-----79
87 //   11----16  20-----30    36----------------54    60------70  74---- 80
88 //       15---------------------40  43--43  50--------------------75
89 // 10----------------------------------------------------------------------85
90 //
91 // The items are stored in a binary tree with each node storing:
92 //
93 // MP: A middle point.
// IL: All intervals that contain the middle point (their left value is at or
//     to the left of it), sorted in ascending order by their beginning point.
// IR: The same intervals (their right value is at or to the right of the
//     middle point), sorted in descending order by their ending point.
98 // LS: Left subtree.
99 // RS: Right subtree.
100 //
// As IL and IR contain the same intervals, in order to optimize space,
// instead of storing intervals on each node, we use two vectors that
// contain the intervals described by IL and IR. Each node stores a start
// index into those vectors (the global buckets) and the number of intervals
// assigned to the node.
106 //
107 // The following is the output from print():
108 //
109 // 0: MP:43 IR [10,85] [31,56] [36,54] [39,50] [43,43]
110 // 0: MP:43 IL [10,85] [31,56] [36,54] [39,50] [43,43]
111 // 1:   MP:25 IR [25,41] [15,40] [20,30]
112 // 1:   MP:25 IL [15,40] [20,30] [25,41]
113 // 2:     MP:15 IR [12,21] [11,16]
114 // 2:     MP:15 IL [11,16] [12,21]
115 // 2:     MP:36 IR []
116 // 2:     MP:36 IL []
117 // 3:       MP:31 IR [30,35]
118 // 3:       MP:31 IL [30,35]
119 // 1:   MP:61 IR [50,75] [60,70] [49,65] [55,61]
120 // 1:   MP:61 IL [49,65] [50,75] [55,61] [60,70]
121 // 2:     MP:74 IR [74,80] [71,79]
122 // 2:     MP:74 IL [71,79] [74,80]
123 //
124 // with:
125 //    0: Root Node.
126 //   MP: Middle point.
127 //   IL: Intervals to the left (in ascending order by beginning point).
128 //   IR: Intervals to the right (in descending order by ending point).
129 //
130 //                                    Root
131 //                                      |
132 //                                      V
133 //                       +------------MP:43------------+
134 //                       |            IL IR            |
135 //                       |       [10,85] [10,85]       |
136 //                    LS |       [31,56] [31,56]       | RS
137 //                       |       [36,54] [36,54]       |
138 //                       |       [39,50] [39,50]       |
139 //                       |       [43,43] [43,43]       |
140 //                       V                             V
141 //        +------------MP:25------------+            MP:61------------+
142 //        |            IL IR            |            IL IR            |
143 //        |       [15,40] [25,41]       |       [49,65] [50,75]       |
144 //     LS |       [20,30] [15,40]       | RS    [50,75] [60,70]       | RS
145 //        |       [25,41] [20,30]       |       [55,61] [49,65]       |
146 //        |                             |       [60,70] [55,61]       |
147 //        V                             V                             V
148 //      MP:15                 +-------MP:36                         MP:74
149 //      IL IR                 |       IL IR                         IL IR
150 // [11,16] [12,21]         LS |       [] []                    [71,79] [74,80]
151 // [12,21] [11,16]            |                                [74,80] [71,79]
152 //                            V
153 //                          MP:31
154 //                          IL IR
155 //                     [30,35] [30,35]
156 //
157 // The creation of an interval tree is done in 2 steps:
158 // 1) Insert the interval items by calling
159 //    void insert(PointT Left, PointT Right, ValueT Value);
160 //    Left, Right: the interval left and right limits.
161 //    Value: the data associated with that specific interval.
162 //
163 // 2) Create the interval tree by calling
164 //    void create();
165 //
166 // Once the tree is created, it is switched to query mode.
167 // Query the tree by using iterators or container.
168 //
169 // a) Iterators over intervals overlapping the given point with very weak
170 //    ordering guarantees.
//    find_iterator find(PointType Point) const;
//    find_iterator find_end() const;
173 //    Point: a target point to be tested for inclusion in any interval.
174 //
175 // b) Container:
176 //    IntervalReferences getContaining(PointT Point);
177 //    Point: a target point to be tested for inclusion in any interval.
178 //    Returns vector with all the intervals containing the target point.
179 //
// The returned intervals are in their natural tree location. They can be
// sorted by interval length (right - left):
//
// static void sortIntervals(IntervalReferences &Intervals, Sorting Sort);
//
// Ability to print the constructed interval tree:
//   void print(raw_ostream &OS, bool HexFormat = true);
// When HexFormat is true, the interval limits and middle points are displayed
// in hexadecimal format; otherwise they are displayed in decimal.
188 
189 namespace llvm {
190 
191 //===----------------------------------------------------------------------===//
192 //---                          IntervalData                               ----//
193 //===----------------------------------------------------------------------===//
/// One stored interval: the closed range [\a Left, \a Right] together with
/// the \a Value mapped to it.
/// \a PointT is the type of the two endpoints.
/// \a ValueT is the type of the associated value.
template <typename PointT, typename ValueT> class IntervalData {
protected:
  using PointType = PointT;
  using ValueType = ValueT;

private:
  PointType Left;  // Lower (inclusive) endpoint.
  PointType Right; // Upper (inclusive) endpoint.
  ValueType Value; // Value mapped to [Left, Right].

public:
  IntervalData() = delete;

  /// Store the mapping [\a Start, \a End] -> \a Payload. Assert-enabled
  /// builds check that \a Start <= \a End.
  IntervalData(PointType Start, PointType End, ValueType Payload)
      : Left(Start), Right(End), Value(Payload) {
    assert(Start <= End && "'Left' must be less or equal to 'Right'");
  }

  virtual ~IntervalData() = default;

  /// Accessors for the stored endpoints and value.
  PointType left() const { return Left; }
  PointType right() const { return Right; }
  ValueType value() const { return Value; }

  /// Lower-bound test of the closed interval: true when Left <= \a Point.
  bool left(const PointType &Point) const { return Left <= Point; }

  /// Upper-bound test of the closed interval: true when \a Point <= Right.
  bool right(const PointType &Point) const { return Point <= Right; }

  /// Membership test of the closed interval: Left <= \a Point <= Right.
  bool contains(const PointType &Point) const {
    return Left <= Point && Point <= Right;
  }
};
233 
234 //===----------------------------------------------------------------------===//
235 //---                          IntervalTree                               ----//
236 //===----------------------------------------------------------------------===//
237 // Helper class template that is used by the IntervalTree to ensure that one
238 // does instantiate using only fundamental and/or pointer types.
239 template <typename T>
240 using PointTypeIsValid = std::bool_constant<std::is_fundamental<T>::value>;
241 
242 template <typename T>
243 using ValueTypeIsValid = std::bool_constant<std::is_fundamental<T>::value ||
244                                             std::is_pointer<T>::value>;
245 
246 template <typename PointT, typename ValueT,
247           typename DataT = IntervalData<PointT, ValueT>>
248 class IntervalTree {
249   static_assert(PointTypeIsValid<PointT>::value,
250                 "PointT must be a fundamental type");
251   static_assert(ValueTypeIsValid<ValueT>::value,
252                 "ValueT must be a fundamental or pointer type");
253 
254 public:
255   using PointType = PointT;
256   using ValueType = ValueT;
257   using DataType = DataT;
258   using Allocator = BumpPtrAllocator;
259 
260   enum class Sorting { Ascending, Descending };
261   using IntervalReferences = SmallVector<const DataType *, 4>;
262 
263 private:
264   using IntervalVector = SmallVector<DataType, 4>;
265   using PointsVector = SmallVector<PointType, 4>;
266 
267   class IntervalNode {
268     PointType MiddlePoint;             // MP - Middle point.
269     IntervalNode *Left = nullptr;      // LS - Left subtree.
270     IntervalNode *Right = nullptr;     // RS - Right subtree.
271     unsigned BucketIntervalsStart = 0; // Starting index in global bucket.
272     unsigned BucketIntervalsSize = 0;  // Size of bucket.
273 
274   public:
middle()275     PointType middle() const { return MiddlePoint; }
start()276     unsigned start() const { return BucketIntervalsStart; }
size()277     unsigned size() const { return BucketIntervalsSize; }
278 
IntervalNode(PointType Point,unsigned Start)279     IntervalNode(PointType Point, unsigned Start)
280         : MiddlePoint(Point), BucketIntervalsStart(Start) {}
281 
282     friend IntervalTree;
283   };
284 
285   Allocator &NodeAllocator;     // Allocator used for creating interval nodes.
286   IntervalNode *Root = nullptr; // Interval tree root.
287   IntervalVector Intervals; // Storage for each interval and all of the fields
288                             // point back into it.
289   PointsVector EndPoints; // Sorted left and right points of all the intervals.
290 
291   // These vectors provide storage that nodes carve buckets of overlapping
292   // intervals out of. All intervals are recorded on each vector.
293   // The bucket with the intervals associated to a node, is determined by
294   // the fields 'BucketIntervalStart' and 'BucketIntervalSize' in the node.
295   // The buckets in the first vector are sorted in ascending order using
296   // the left value and the buckets in the second vector are sorted in
297   // descending order using the right value. Every interval in a bucket
298   // contains the middle point for the node.
299   IntervalReferences IntervalsLeft;  // Intervals to the left of middle point.
300   IntervalReferences IntervalsRight; // Intervals to the right of middle point.
301 
302   // Working vector used during the tree creation to sort the intervals. It is
303   // cleared once the tree is created.
304   IntervalReferences References;
305 
306   /// Recursively delete the constructed tree.
deleteTree(IntervalNode * Node)307   void deleteTree(IntervalNode *Node) {
308     if (Node) {
309       deleteTree(Node->Left);
310       deleteTree(Node->Right);
311       Node->~IntervalNode();
312       NodeAllocator.Deallocate(Node);
313     }
314   }
315 
316   /// Print the interval list (left and right) for a given \a Node.
317   static void printList(raw_ostream &OS, IntervalReferences &IntervalSet,
318                         unsigned Start, unsigned Size, bool HexFormat = true) {
319     assert(Start + Size <= IntervalSet.size() &&
320            "Start + Size must be in bounds of the IntervalSet");
321     const char *Format = HexFormat ? "[0x%08x,0x%08x] " : "[%2d,%2d] ";
322     if (Size) {
323       for (unsigned Position = Start; Position < Start + Size; ++Position)
324         OS << format(Format, IntervalSet[Position]->left(),
325                      IntervalSet[Position]->right());
326     } else {
327       OS << "[]";
328     }
329     OS << "\n";
330   }
331 
332   /// Print an interval tree \a Node.
333   void printNode(raw_ostream &OS, unsigned Level, IntervalNode *Node,
334                  bool HexFormat = true) {
335     const char *Format = HexFormat ? "MP:0x%08x " : "MP:%2d ";
336     auto PrintNodeData = [&](StringRef Text, IntervalReferences &IntervalSet) {
337       OS << format("%5d: ", Level);
338       OS.indent(Level * 2);
339       OS << format(Format, Node->middle()) << Text << " ";
340       printList(OS, IntervalSet, Node->start(), Node->size(), HexFormat);
341     };
342 
343     PrintNodeData("IR", IntervalsRight);
344     PrintNodeData("IL", IntervalsLeft);
345   }
346 
347   /// Recursively print all the interval nodes.
348   void printTree(raw_ostream &OS, unsigned Level, IntervalNode *Node,
349                  bool HexFormat = true) {
350     if (Node) {
351       printNode(OS, Level, Node, HexFormat);
352       ++Level;
353       printTree(OS, Level, Node->Left, HexFormat);
354       printTree(OS, Level, Node->Right, HexFormat);
355     }
356   }
357 
358   /// Recursively construct the interval tree.
359   /// IntervalsSize: Number of intervals that have been processed and it will
360   /// be used as the start for the intervals bucket for a node.
361   /// PointsBeginIndex, PointsEndIndex: Determine the range into the EndPoints
362   /// vector of end points to be processed.
363   /// ReferencesBeginIndex, ReferencesSize: Determine the range into the
364   /// intervals being processed.
createTree(unsigned & IntervalsSize,int PointsBeginIndex,int PointsEndIndex,int ReferencesBeginIndex,int ReferencesSize)365   IntervalNode *createTree(unsigned &IntervalsSize, int PointsBeginIndex,
366                            int PointsEndIndex, int ReferencesBeginIndex,
367                            int ReferencesSize) {
368     // We start by taking the entire range of all the intervals and dividing
369     // it in half at x_middle (in practice, x_middle should be picked to keep
370     // the tree relatively balanced).
371     // This gives three sets of intervals, those completely to the left of
372     // x_middle which we'll call S_left, those completely to the right of
373     // x_middle which we'll call S_right, and those overlapping x_middle
374     // which we'll call S_middle.
375     // The intervals in S_left and S_right are recursively divided in the
376     // same manner until there are no intervals remaining.
377 
378     if (PointsBeginIndex > PointsEndIndex ||
379         ReferencesBeginIndex >= ReferencesSize)
380       return nullptr;
381 
382     int MiddleIndex = (PointsBeginIndex + PointsEndIndex) / 2;
383     PointType MiddlePoint = EndPoints[MiddleIndex];
384 
385     unsigned NewBucketStart = IntervalsSize;
386     unsigned NewBucketSize = 0;
387     int ReferencesRightIndex = ReferencesSize;
388 
389     IntervalNode *Root =
390         new (NodeAllocator) IntervalNode(MiddlePoint, NewBucketStart);
391 
392     // A quicksort implementation where all the intervals that overlap
393     // with the pivot are put into the "bucket", and "References" is the
394     // partition space where we recursively sort the remaining intervals.
395     for (int Index = ReferencesBeginIndex; Index < ReferencesRightIndex;) {
396 
397       // Current interval contains the middle point.
398       if (References[Index]->contains(MiddlePoint)) {
399         IntervalsLeft[IntervalsSize] = References[Index];
400         IntervalsRight[IntervalsSize] = References[Index];
401         ++IntervalsSize;
402         Root->BucketIntervalsSize = ++NewBucketSize;
403 
404         if (Index < --ReferencesRightIndex)
405           std::swap(References[Index], References[ReferencesRightIndex]);
406         if (ReferencesRightIndex < --ReferencesSize)
407           std::swap(References[ReferencesRightIndex],
408                     References[ReferencesSize]);
409         continue;
410       }
411 
412       if (References[Index]->left() > MiddlePoint) {
413         if (Index < --ReferencesRightIndex)
414           std::swap(References[Index], References[ReferencesRightIndex]);
415         continue;
416       }
417       ++Index;
418     }
419 
420     // Sort intervals on the left and right of the middle point.
421     if (NewBucketSize > 1) {
422       // Sort the intervals in ascending order by their beginning point.
423       std::stable_sort(IntervalsLeft.begin() + NewBucketStart,
424                        IntervalsLeft.begin() + NewBucketStart + NewBucketSize,
425                        [](const DataType *LHS, const DataType *RHS) {
426                          return LHS->left() < RHS->left();
427                        });
428       // Sort the intervals in descending order by their ending point.
429       std::stable_sort(IntervalsRight.begin() + NewBucketStart,
430                        IntervalsRight.begin() + NewBucketStart + NewBucketSize,
431                        [](const DataType *LHS, const DataType *RHS) {
432                          return LHS->right() > RHS->right();
433                        });
434     }
435 
436     if (PointsBeginIndex <= MiddleIndex - 1) {
437       Root->Left = createTree(IntervalsSize, PointsBeginIndex, MiddleIndex - 1,
438                               ReferencesBeginIndex, ReferencesRightIndex);
439     }
440 
441     if (MiddleIndex + 1 <= PointsEndIndex) {
442       Root->Right = createTree(IntervalsSize, MiddleIndex + 1, PointsEndIndex,
443                                ReferencesRightIndex, ReferencesSize);
444     }
445 
446     return Root;
447   }
448 
449 public:
450   class find_iterator {
451   public:
452     using iterator_category = std::forward_iterator_tag;
453     using value_type = DataType;
454     using difference_type = DataType;
455     using pointer = DataType *;
456     using reference = DataType &;
457 
458   private:
459     const IntervalReferences *AscendingBuckets = nullptr;
460     const IntervalReferences *DescendingBuckets = nullptr;
461 
462     // Current node and index while traversing the intervals that contain
463     // the reference point.
464     IntervalNode *Node = nullptr;
465     PointType Point = {};
466     unsigned Index = 0;
467 
468     // For the current node, check if we have intervals that contain the
469     // reference point. We return when the node does have intervals that
470     // contain such point. Otherwise we keep descending on that branch.
initNode()471     void initNode() {
472       Index = 0;
473       while (Node) {
474         // Return if the reference point is the same as the middle point or
475         // the current node doesn't have any intervals at all.
476         if (Point == Node->middle()) {
477           if (Node->size() == 0) {
478             // No intervals that contain the reference point.
479             Node = nullptr;
480           }
481           return;
482         }
483 
484         if (Point < Node->middle()) {
485           // The reference point can be at the left or right of the middle
486           // point. Return if the current node has intervals that contain the
487           // reference point; otherwise descend on the respective branch.
488           if (Node->size() && (*AscendingBuckets)[Node->start()]->left(Point)) {
489             return;
490           }
491           Node = Node->Left;
492         } else {
493           if (Node->size() &&
494               (*DescendingBuckets)[Node->start()]->right(Point)) {
495             return;
496           }
497           Node = Node->Right;
498         }
499       }
500     }
501 
502     // Given the current node (which was initialized by initNode), move to
503     // the next interval in the list of intervals that contain the reference
504     // point. Otherwise move to the next node, as the intervals contained
505     // in that node, can contain the reference point.
nextInterval()506     void nextInterval() {
507       // If there are available intervals that contain the reference point,
508       // traverse them; otherwise move to the left or right node, depending
509       // on the middle point value.
510       if (++Index < Node->size()) {
511         if (Node->middle() == Point)
512           return;
513         if (Point < Node->middle()) {
514           // Reference point is on the left.
515           if (!(*AscendingBuckets)[Node->start() + Index]->left(Point)) {
516             // The intervals don't contain the reference point. Move to the
517             // next node, preserving the descending order.
518             Node = Node->Left;
519             initNode();
520           }
521         } else {
522           // Reference point is on the right.
523           if (!(*DescendingBuckets)[Node->start() + Index]->right(Point)) {
524             // The intervals don't contain the reference point. Move to the
525             // next node, preserving the ascending order.
526             Node = Node->Right;
527             initNode();
528           }
529         }
530       } else {
531         // We have traversed all the intervals in the current node.
532         if (Point == Node->middle()) {
533           Node = nullptr;
534           Index = 0;
535           return;
536         }
537         // Select a branch based on the middle point.
538         Node = Point < Node->middle() ? Node->Left : Node->Right;
539         initNode();
540       }
541     }
542 
543     find_iterator() = default;
find_iterator(const IntervalReferences * Left,const IntervalReferences * Right,IntervalNode * Node,PointType Point)544     explicit find_iterator(const IntervalReferences *Left,
545                            const IntervalReferences *Right, IntervalNode *Node,
546                            PointType Point)
547         : AscendingBuckets(Left), DescendingBuckets(Right), Node(Node),
548           Point(Point), Index(0) {
549       initNode();
550     }
551 
current()552     const DataType *current() const {
553       return (Point <= Node->middle())
554                  ? (*AscendingBuckets)[Node->start() + Index]
555                  : (*DescendingBuckets)[Node->start() + Index];
556     }
557 
558   public:
559     find_iterator &operator++() {
560       nextInterval();
561       return *this;
562     }
563 
564     find_iterator operator++(int) {
565       find_iterator Iter(*this);
566       nextInterval();
567       return Iter;
568     }
569 
570     /// Dereference operators.
571     const DataType *operator->() const { return current(); }
572     const DataType &operator*() const { return *(current()); }
573 
574     /// Comparison operators.
575     friend bool operator==(const find_iterator &LHS, const find_iterator &RHS) {
576       return (!LHS.Node && !RHS.Node && !LHS.Index && !RHS.Index) ||
577              (LHS.Point == RHS.Point && LHS.Node == RHS.Node &&
578               LHS.Index == RHS.Index);
579     }
580     friend bool operator!=(const find_iterator &LHS, const find_iterator &RHS) {
581       return !(LHS == RHS);
582     }
583 
584     friend IntervalTree;
585   };
586 
587 private:
588   find_iterator End;
589 
590 public:
IntervalTree(Allocator & NodeAllocator)591   explicit IntervalTree(Allocator &NodeAllocator)
592       : NodeAllocator(NodeAllocator) {}
~IntervalTree()593   ~IntervalTree() { clear(); }
594 
595   /// Return true when no intervals are mapped.
empty()596   bool empty() const { return Root == nullptr; }
597 
598   /// Remove all entries.
clear()599   void clear() {
600     deleteTree(Root);
601     Root = nullptr;
602     Intervals.clear();
603     IntervalsLeft.clear();
604     IntervalsRight.clear();
605     EndPoints.clear();
606   }
607 
608   /// Add a mapping of [Left;Right] to \a Value.
insert(PointType Left,PointType Right,ValueType Value)609   void insert(PointType Left, PointType Right, ValueType Value) {
610     assert(empty() && "Invalid insertion. Interval tree already constructed.");
611     Intervals.emplace_back(Left, Right, Value);
612   }
613 
614   /// Return all the intervals in their natural tree location, that
615   /// contain the given point.
getContaining(PointType Point)616   IntervalReferences getContaining(PointType Point) const {
617     assert(!empty() && "Interval tree it is not constructed.");
618     IntervalReferences IntervalSet;
619     for (find_iterator Iter = find(Point), E = find_end(); Iter != E; ++Iter)
620       IntervalSet.push_back(const_cast<DataType *>(&(*Iter)));
621     return IntervalSet;
622   }
623 
624   /// Sort the given intervals using the following sort options:
625   /// Ascending: return the intervals with the smallest at the front.
626   /// Descending: return the intervals with the biggest at the front.
sortIntervals(IntervalReferences & IntervalSet,Sorting Sort)627   static void sortIntervals(IntervalReferences &IntervalSet, Sorting Sort) {
628     std::stable_sort(IntervalSet.begin(), IntervalSet.end(),
629                      [Sort](const DataType *RHS, const DataType *LHS) {
630                        return Sort == Sorting::Ascending
631                                   ? (LHS->right() - LHS->left()) >
632                                         (RHS->right() - RHS->left())
633                                   : (LHS->right() - LHS->left()) <
634                                         (RHS->right() - RHS->left());
635                      });
636   }
637 
638   /// Print the interval tree.
639   /// When \a HexFormat is true, the interval tree interval ranges and
640   /// associated values are printed in hexadecimal format.
641   void print(raw_ostream &OS, bool HexFormat = true) {
642     printTree(OS, 0, Root, HexFormat);
643   }
644 
645   /// Create the interval tree.
create()646   void create() {
647     assert(empty() && "Interval tree already constructed.");
648     // Sorted vector of unique end points values of all the intervals.
649     // Records references to the collected intervals.
650     SmallVector<PointType, 4> Points;
651     for (const DataType &Data : Intervals) {
652       Points.push_back(Data.left());
653       Points.push_back(Data.right());
654       References.push_back(std::addressof(Data));
655     }
656     std::stable_sort(Points.begin(), Points.end());
657     auto Last = llvm::unique(Points);
658     Points.erase(Last, Points.end());
659 
660     EndPoints.assign(Points.begin(), Points.end());
661 
662     IntervalsLeft.resize(Intervals.size());
663     IntervalsRight.resize(Intervals.size());
664 
665     // Given a set of n intervals, construct a data structure so that
666     // we can efficiently retrieve all intervals overlapping another
667     // interval or point.
668     unsigned IntervalsSize = 0;
669     Root =
670         createTree(IntervalsSize, /*PointsBeginIndex=*/0, EndPoints.size() - 1,
671                    /*ReferencesBeginIndex=*/0, References.size());
672 
673     // Save to clear this storage, as it used only to sort the intervals.
674     References.clear();
675   }
676 
677   /// Iterator to start a find operation; it returns find_end() if the
678   /// tree has not been built.
679   /// There is no support to iterate over all the elements of the tree.
find(PointType Point)680   find_iterator find(PointType Point) const {
681     return empty()
682                ? find_end()
683                : find_iterator(&IntervalsLeft, &IntervalsRight, Root, Point);
684   }
685 
686   /// Iterator to end find operation.
find_end()687   find_iterator find_end() const { return End; }
688 };
689 
690 } // namespace llvm
691 
692 #endif // LLVM_ADT_INTERVALTREE_H
693