partition.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. // -*- C++ -*-
  2. // Copyright (C) 2007, 2008, 2009 Free Software Foundation, Inc.
  3. //
  4. // This file is part of the GNU ISO C++ Library. This library is free
  5. // software; you can redistribute it and/or modify it under the terms
  6. // of the GNU General Public License as published by the Free Software
  7. // Foundation; either version 3, or (at your option) any later
  8. // version.
  9. // This library is distributed in the hope that it will be useful, but
  10. // WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. // General Public License for more details.
  13. // Under Section 7 of GPL version 3, you are granted additional
  14. // permissions described in the GCC Runtime Library Exception, version
  15. // 3.1, as published by the Free Software Foundation.
  16. // You should have received a copy of the GNU General Public License and
  17. // a copy of the GCC Runtime Library Exception along with this program;
  18. // see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
  19. // <http://www.gnu.org/licenses/>.
  20. /** @file parallel/partition.h
  21. * @brief Parallel implementation of std::partition(),
  22. * std::nth_element(), and std::partial_sort().
  23. * This file is a GNU parallel extension to the Standard C++ Library.
  24. */
  25. // Written by Johannes Singler and Felix Putze.
  26. #ifndef _GLIBCXX_PARALLEL_PARTITION_H
  27. #define _GLIBCXX_PARALLEL_PARTITION_H 1
  28. #include <parallel/basic_iterator.h>
  29. #include <parallel/sort.h>
  30. #include <parallel/random_number.h>
  31. #include <bits/stl_algo.h>
  32. #include <parallel/parallel.h>
  33. /** @brief Decide whether to declare certain variables volatile. */
  34. #define _GLIBCXX_VOLATILE volatile
  35. namespace __gnu_parallel
  36. {
  37. /** @brief Parallel implementation of std::partition.
  38. * @param begin Begin iterator of input sequence to split.
  39. * @param end End iterator of input sequence to split.
  40. * @param pred Partition predicate, possibly including some kind of pivot.
  41. * @param num_threads Maximum number of threads to use for this task.
  42. * @return Number of elements not fulfilling the predicate. */
  43. template<typename RandomAccessIterator, typename Predicate>
  44. typename std::iterator_traits<RandomAccessIterator>::difference_type
  45. parallel_partition(RandomAccessIterator begin, RandomAccessIterator end,
  46. Predicate pred, thread_index_t num_threads)
  47. {
  48. typedef std::iterator_traits<RandomAccessIterator> traits_type;
  49. typedef typename traits_type::value_type value_type;
  50. typedef typename traits_type::difference_type difference_type;
  51. difference_type n = end - begin;
  52. _GLIBCXX_CALL(n)
  53. const _Settings& __s = _Settings::get();
  54. // Shared.
  55. _GLIBCXX_VOLATILE difference_type left = 0, right = n - 1;
  56. _GLIBCXX_VOLATILE difference_type leftover_left, leftover_right;
  57. _GLIBCXX_VOLATILE difference_type leftnew, rightnew;
  58. bool* reserved_left = NULL, * reserved_right = NULL;
  59. difference_type chunk_size = __s.partition_chunk_size;
  60. omp_lock_t result_lock;
  61. omp_init_lock(&result_lock);
  62. //at least two chunks per thread
  63. if(right - left + 1 >= 2 * num_threads * chunk_size)
  64. # pragma omp parallel num_threads(num_threads)
  65. {
  66. # pragma omp single
  67. {
  68. num_threads = omp_get_num_threads();
  69. reserved_left = new bool[num_threads];
  70. reserved_right = new bool[num_threads];
  71. if (__s.partition_chunk_share > 0.0)
  72. chunk_size = std::max<difference_type>(__s.partition_chunk_size,
  73. (double)n * __s.partition_chunk_share
  74. / (double)num_threads);
  75. else
  76. chunk_size = __s.partition_chunk_size;
  77. }
  78. while (right - left + 1 >= 2 * num_threads * chunk_size)
  79. {
  80. # pragma omp single
  81. {
  82. difference_type num_chunks = (right - left + 1) / chunk_size;
  83. for (int r = 0; r < num_threads; ++r)
  84. {
  85. reserved_left[r] = false;
  86. reserved_right[r] = false;
  87. }
  88. leftover_left = 0;
  89. leftover_right = 0;
  90. } //implicit barrier
  91. // Private.
  92. difference_type thread_left, thread_left_border,
  93. thread_right, thread_right_border;
  94. thread_left = left + 1;
  95. // Just to satisfy the condition below.
  96. thread_left_border = thread_left - 1;
  97. thread_right = n - 1;
  98. thread_right_border = thread_right + 1;
  99. bool iam_finished = false;
  100. while (!iam_finished)
  101. {
  102. if (thread_left > thread_left_border)
  103. {
  104. omp_set_lock(&result_lock);
  105. if (left + (chunk_size - 1) > right)
  106. iam_finished = true;
  107. else
  108. {
  109. thread_left = left;
  110. thread_left_border = left + (chunk_size - 1);
  111. left += chunk_size;
  112. }
  113. omp_unset_lock(&result_lock);
  114. }
  115. if (thread_right < thread_right_border)
  116. {
  117. omp_set_lock(&result_lock);
  118. if (left > right - (chunk_size - 1))
  119. iam_finished = true;
  120. else
  121. {
  122. thread_right = right;
  123. thread_right_border = right - (chunk_size - 1);
  124. right -= chunk_size;
  125. }
  126. omp_unset_lock(&result_lock);
  127. }
  128. if (iam_finished)
  129. break;
  130. // Swap as usual.
  131. while (thread_left < thread_right)
  132. {
  133. while (pred(begin[thread_left])
  134. && thread_left <= thread_left_border)
  135. ++thread_left;
  136. while (!pred(begin[thread_right])
  137. && thread_right >= thread_right_border)
  138. --thread_right;
  139. if (thread_left > thread_left_border
  140. || thread_right < thread_right_border)
  141. // Fetch new chunk(s).
  142. break;
  143. std::swap(begin[thread_left], begin[thread_right]);
  144. ++thread_left;
  145. --thread_right;
  146. }
  147. }
  148. // Now swap the leftover chunks to the right places.
  149. if (thread_left <= thread_left_border)
  150. # pragma omp atomic
  151. ++leftover_left;
  152. if (thread_right >= thread_right_border)
  153. # pragma omp atomic
  154. ++leftover_right;
  155. # pragma omp barrier
  156. # pragma omp single
  157. {
  158. leftnew = left - leftover_left * chunk_size;
  159. rightnew = right + leftover_right * chunk_size;
  160. }
  161. # pragma omp barrier
  162. // <=> thread_left_border + (chunk_size - 1) >= leftnew
  163. if (thread_left <= thread_left_border
  164. && thread_left_border >= leftnew)
  165. {
  166. // Chunk already in place, reserve spot.
  167. reserved_left[(left - (thread_left_border + 1)) / chunk_size]
  168. = true;
  169. }
  170. // <=> thread_right_border - (chunk_size - 1) <= rightnew
  171. if (thread_right >= thread_right_border
  172. && thread_right_border <= rightnew)
  173. {
  174. // Chunk already in place, reserve spot.
  175. reserved_right[((thread_right_border - 1) - right)
  176. / chunk_size] = true;
  177. }
  178. # pragma omp barrier
  179. if (thread_left <= thread_left_border
  180. && thread_left_border < leftnew)
  181. {
  182. // Find spot and swap.
  183. difference_type swapstart = -1;
  184. omp_set_lock(&result_lock);
  185. for (int r = 0; r < leftover_left; ++r)
  186. if (!reserved_left[r])
  187. {
  188. reserved_left[r] = true;
  189. swapstart = left - (r + 1) * chunk_size;
  190. break;
  191. }
  192. omp_unset_lock(&result_lock);
  193. #if _GLIBCXX_ASSERTIONS
  194. _GLIBCXX_PARALLEL_ASSERT(swapstart != -1);
  195. #endif
  196. std::swap_ranges(begin + thread_left_border
  197. - (chunk_size - 1),
  198. begin + thread_left_border + 1,
  199. begin + swapstart);
  200. }
  201. if (thread_right >= thread_right_border
  202. && thread_right_border > rightnew)
  203. {
  204. // Find spot and swap
  205. difference_type swapstart = -1;
  206. omp_set_lock(&result_lock);
  207. for (int r = 0; r < leftover_right; ++r)
  208. if (!reserved_right[r])
  209. {
  210. reserved_right[r] = true;
  211. swapstart = right + r * chunk_size + 1;
  212. break;
  213. }
  214. omp_unset_lock(&result_lock);
  215. #if _GLIBCXX_ASSERTIONS
  216. _GLIBCXX_PARALLEL_ASSERT(swapstart != -1);
  217. #endif
  218. std::swap_ranges(begin + thread_right_border,
  219. begin + thread_right_border + chunk_size,
  220. begin + swapstart);
  221. }
  222. #if _GLIBCXX_ASSERTIONS
  223. # pragma omp barrier
  224. # pragma omp single
  225. {
  226. for (int r = 0; r < leftover_left; ++r)
  227. _GLIBCXX_PARALLEL_ASSERT(reserved_left[r]);
  228. for (int r = 0; r < leftover_right; ++r)
  229. _GLIBCXX_PARALLEL_ASSERT(reserved_right[r]);
  230. }
  231. # pragma omp barrier
  232. #endif
  233. # pragma omp barrier
  234. left = leftnew;
  235. right = rightnew;
  236. }
  237. # pragma omp flush(left, right)
  238. } // end "recursion" //parallel
  239. difference_type final_left = left, final_right = right;
  240. while (final_left < final_right)
  241. {
  242. // Go right until key is geq than pivot.
  243. while (pred(begin[final_left]) && final_left < final_right)
  244. ++final_left;
  245. // Go left until key is less than pivot.
  246. while (!pred(begin[final_right]) && final_left < final_right)
  247. --final_right;
  248. if (final_left == final_right)
  249. break;
  250. std::swap(begin[final_left], begin[final_right]);
  251. ++final_left;
  252. --final_right;
  253. }
  254. // All elements on the left side are < piv, all elements on the
  255. // right are >= piv
  256. delete[] reserved_left;
  257. delete[] reserved_right;
  258. omp_destroy_lock(&result_lock);
  259. // Element "between" final_left and final_right might not have
  260. // been regarded yet
  261. if (final_left < n && !pred(begin[final_left]))
  262. // Really swapped.
  263. return final_left;
  264. else
  265. return final_left + 1;
  266. }
  267. /**
  268. * @brief Parallel implementation of std::nth_element().
  269. * @param begin Begin iterator of input sequence.
  270. * @param nth Iterator of element that must be in position afterwards.
  271. * @param end End iterator of input sequence.
  272. * @param comp Comparator.
  273. */
  274. template<typename RandomAccessIterator, typename Comparator>
  275. void
  276. parallel_nth_element(RandomAccessIterator begin, RandomAccessIterator nth,
  277. RandomAccessIterator end, Comparator comp)
  278. {
  279. typedef std::iterator_traits<RandomAccessIterator> traits_type;
  280. typedef typename traits_type::value_type value_type;
  281. typedef typename traits_type::difference_type difference_type;
  282. _GLIBCXX_CALL(end - begin)
  283. RandomAccessIterator split;
  284. random_number rng;
  285. const _Settings& __s = _Settings::get();
  286. difference_type minimum_length = std::max<difference_type>(2,
  287. std::max(__s.nth_element_minimal_n, __s.partition_minimal_n));
  288. // Break if input range to small.
  289. while (static_cast<sequence_index_t>(end - begin) >= minimum_length)
  290. {
  291. difference_type n = end - begin;
  292. RandomAccessIterator pivot_pos = begin + rng(n);
  293. // Swap pivot_pos value to end.
  294. if (pivot_pos != (end - 1))
  295. std::swap(*pivot_pos, *(end - 1));
  296. pivot_pos = end - 1;
  297. // XXX Comparator must have first_value_type, second_value_type,
  298. // result_type
  299. // Comparator == __gnu_parallel::lexicographic<S, int,
  300. // __gnu_parallel::less<S, S> >
  301. // pivot_pos == std::pair<S, int>*
  302. // XXX binder2nd only for RandomAccessIterators??
  303. __gnu_parallel::binder2nd<Comparator, value_type, value_type, bool>
  304. pred(comp, *pivot_pos);
  305. // Divide, leave pivot unchanged in last place.
  306. RandomAccessIterator split_pos1, split_pos2;
  307. split_pos1 = begin + parallel_partition(begin, end - 1, pred,
  308. get_max_threads());
  309. // Left side: < pivot_pos; right side: >= pivot_pos
  310. // Swap pivot back to middle.
  311. if (split_pos1 != pivot_pos)
  312. std::swap(*split_pos1, *pivot_pos);
  313. pivot_pos = split_pos1;
  314. // In case all elements are equal, split_pos1 == 0
  315. if ((split_pos1 + 1 - begin) < (n >> 7)
  316. || (end - split_pos1) < (n >> 7))
  317. {
  318. // Very unequal split, one part smaller than one 128th
  319. // elements not strictly larger than the pivot.
  320. __gnu_parallel::unary_negate<__gnu_parallel::
  321. binder1st<Comparator, value_type, value_type, bool>, value_type>
  322. pred(__gnu_parallel::binder1st<Comparator, value_type,
  323. value_type, bool>(comp, *pivot_pos));
  324. // Find other end of pivot-equal range.
  325. split_pos2 = __gnu_sequential::partition(split_pos1 + 1,
  326. end, pred);
  327. }
  328. else
  329. // Only skip the pivot.
  330. split_pos2 = split_pos1 + 1;
  331. // Compare iterators.
  332. if (split_pos2 <= nth)
  333. begin = split_pos2;
  334. else if (nth < split_pos1)
  335. end = split_pos1;
  336. else
  337. break;
  338. }
  339. // Only at most _Settings::partition_minimal_n elements left.
  340. __gnu_sequential::nth_element(begin, nth, end, comp);
  341. }
  /**
   *  @brief Parallel implementation of std::partial_sort().
   *
   *  Two phases.  A parallel selection first gathers, in [begin, middle),
   *  the elements that belong there (following the std::nth_element
   *  contract).  That prefix alone is then sorted with @c comp.
   *
   *  @param begin Begin iterator of input sequence.
   *  @param middle Sort until this position.
   *  @param end End iterator of input sequence.
   *  @param comp Comparator.
   */
  template<typename RandomAccessIterator, typename Comparator>
    void
    parallel_partial_sort(RandomAccessIterator begin,
                          RandomAccessIterator middle,
                          RandomAccessIterator end, Comparator comp)
    {
      // Boundary between the part that ends up sorted and the remainder.
      const RandomAccessIterator sorted_end = middle;

      // Phase 1: select the prefix.
      parallel_nth_element(begin, sorted_end, end, comp);

      // Phase 2: order the selected prefix.
      std::sort(begin, sorted_end, comp);
    }
  356. } //namespace __gnu_parallel
  357. #undef _GLIBCXX_VOLATILE
  358. #endif /* _GLIBCXX_PARALLEL_PARTITION_H */