您说:
C++中nth_element方法是怎么实现的?
std::nth_element
并不会把整个区间排好序;它只保证第 n
个位置上的元素,等于把整个区间完全排序后会出现在那个位置的元素,并且使得:
[first, nth)
中的每个元素都 不大于 *nth
(按比较器 comp
判断);
(nth, last)
中的每个元素都 不小于 *nth
;
但这两个子区间内部的相对次序是未指定的(所以它不稳定)。
典型实现基于 Quickselect(快速选择),也就是“快速排序 + 只递归一边”的做法,常见再加上“内省式选择(introselect)”以避免极端退化:
选主元(pivot):从区间中取一个代表性元素作为主元(很多实现用“三数取中”(median-of-three),有的更激进用 ninther)。
划分(partition):像快速排序那样,以主元把区间划成“小于主元 | 等于主元 | 大于主元”的三段,并得到主元块的范围 [L, R]
。
只进入一边:
若 nth
落在 [L, R]
中,工作完成;
若 nth < L
,在左半段递归/迭代;
若 nth > R
,在右半段递归/迭代。
防退化(内省式):若划分深度超出阈值(通常是 2*log2(N)
),改用更稳健的策略(不同库不同,可能换成堆选择/median-of-medians/小段排序等),以避免极端数据导致的二次方时间。
期望时间复杂度:O(N)
(良好主元选择下的平均情况)。
最坏情况:朴素 quickselect 为 O(N^2)
;带“内省式”退避后,最坏情况会被限制(不同实现的上界略有差异)。
空间复杂度:O(1)
额外空间(就地划分,只需常数额外存储)。
稳定性:不稳定(等值元素的相对次序不保留)。
下面是一个基于 quickselect 的简化实现示意(只为讲解思路,未做所有边界与微优化处理):
cpp#include <algorithm>
#include <iterator>
#include <utility>
template<class RandomIt, class Compare>
RandomIt partition3(RandomIt first, RandomIt last, typename std::iterator_traits<RandomIt>::value_type const& pivot, Compare comp) {
// 三向划分:返回等于 pivot 的区间 [m1, m2)
auto i = first, lt = first, gt = last;
while (i < gt) {
if (comp(*i, pivot)) {
std::iter_swap(i, lt);
++i; ++lt;
} else if (comp(pivot, *i)) {
--gt;
std::iter_swap(i, gt);
} else {
++i;
}
}
// 现在 [first, lt) < pivot, [lt, gt) == pivot, [gt, last) > pivot
return lt; // 返回等于块的起点;等于块的终点是 gt
}
template<class RandomIt, class Compare>
void nth_element_simple(RandomIt first, RandomIt nth, RandomIt last, Compare comp) {
while (last - first > 1) {
// 选主元:三数取中
RandomIt a = first;
RandomIt b = first + (last - first) / 2;
RandomIt c = last - 1;
if (comp(*b, *a)) std::iter_swap(a, b);
if (comp(*c, *a)) std::iter_swap(a, c);
if (comp(*c, *b)) std::iter_swap(b, c);
auto pivot = *b;
// 三向划分
auto lt = first, i = first, gt = last;
while (i < gt) {
if (comp(*i, pivot)) {
std::iter_swap(i, lt);
++i; ++lt;
} else if (comp(pivot, *i)) {
--gt;
std::iter_swap(i, gt);
} else {
++i;
}
}
// 决定进入哪一边
if (nth < lt) {
last = lt; // 目标在左段
} else if (nth >= gt) {
first = gt; // 目标在右段
} else {
return; // 目标在等于段内,完成
}
// (可选)小区间时直接插排/排序以减少常数因子
if (last - first <= 16) {
std::sort(first, last, comp);
return;
}
}
}
// 便捷重载(使用 operator<)
template<class RandomIt>
void nth_element_simple(RandomIt first, RandomIt nth, RandomIt last) {
nth_element_simple(first, nth, last, std::less<typename std::iterator_traits<RandomIt>::value_type>{});
}
libstdc++ / libc++ / MSVC 都采用“快速选择 + 好的主元选择 + 小段优化”,并在“坏情况”时切换到更稳健的方案(例如堆选择、median-of-medians 或对小范围直接排序)。
主元选择常见是三数取中,有的实现会使用 ninther 等增强版本;小段通常用插入排序或直接 std::sort
以降低常数。
这些都属于实现细节,标准并不强制具体方法;只规定行为与复杂度语义(期望线性时间、就地、非稳定)。
nth_element
找第 k
小/大元素(如中位数、分位数);
快速得到前 k
小(配合把 nth
设为 first + k
后,再对 [first, first+k)
做一次 sort
);
做选择型统计而不是完全排序时,能显著减少开销。
如果你想,我也可以:
给你一版可编译、含单元测试的实现;
展示它与 std::sort
在不同 N
/分布下的性能对比;
演示如何用它求中位数/分位数与 Top-K。
下面把“坏情况”时常见的三类稳健退避策略讲清楚:触发条件、核心思路、复杂度与适用性,并给出可嵌入的示意代码。
实作里常见两类触发器(任一满足就退避):
深度超限(内省式 selection)
设 N = last-first
,令 depth_limit = 2*floor(log2(N))
。每做一次划分(partition)就把深度上限减 1;若减到 0,认为主元选择一直不佳,进入稳健方案。
连续严重失衡的划分
例如某次 partition 后,较小一侧的元素数 < N/16
(阈值可调)。连续几次都这样,也触发退避。
这两类触发避免了“几乎有序”“重复值极多”“对手构造数据”等导致的
O(N^2)
退化。
思路
当只要第 k
小(k = nth-first
)时,用大小为 k+1
的最大堆保留当前见过的最小 k+1
个元素;遍历完后,堆顶就是第 k
小的值。若 k
接近右边(k > N/2
),则对“右侧规模”做对偶处理:用大小为 N-k
的最小堆找第 k
小(= 第 N-k-1
大)。
如何满足 nth_element
的“左右分区”语义?
先用堆找出阈值 pivot_value
,再对区间线性扫描两次 partition
:
把 < pivot_value
移到左边;
在剩余区间里把 == pivot_value
移到中间(这样 nth
必落在等值块内)。
复杂度
时间:O(N log min(k+1, N-k)) + O(N)
(第二段是两次线性 partition
)。
额外空间:用于堆的 O(min(k+1, N-k))
(退避时通常接受非常数空间的折衷)。
优缺点
优点:实现简单、常数小;当 k
很小或很靠近末尾时尤其快。
缺点:不是原地常数空间;当 k ≈ N/2
时对数因子较大。
嵌入式示意代码
cpptemplate<class It, class Compare>
auto heap_select_value(It first, It last, size_t k, Compare comp)
-> typename std::iterator_traits<It>::value_type
{
using T = typename std::iterator_traits<It>::value_type;
const size_t N = last - first;
if (k <= N/2) {
// 保留 k+1 个最小值:最大堆
std::vector<T> h; h.reserve(k+1);
for (It it = first; it != last; ++it) {
if (h.size() < k+1) { h.push_back(*it); std::push_heap(h.begin(), h.end(),
[&](const T& a, const T& b){ return comp(b,a); }); }
else if (comp(*it, h.front())) {
std::pop_heap(h.begin(), h.end(), [&](const T& a, const T& b){ return comp(b,a); });
h.back() = *it;
std::push_heap(h.begin(), h.end(), [&](const T& a, const T& b){ return comp(b,a); });
}
}
return h.front(); // 第 k 小
} else {
// 保留 N-k 个最大值:最小堆
size_t m = N - k;
std::vector<T> h; h.reserve(m);
for (It it = first; it != last; ++it) {
if (h.size() < m) { h.push_back(*it); std::push_heap(h.begin(), h.end(), comp); }
else if (comp(h.front(), *it)) {
std::pop_heap(h.begin(), h.end(), comp);
h.back() = *it;
std::push_heap(h.begin(), h.end(), comp);
}
}
// 这里 h.front() 是第 N-m (=k) 小
return h.front();
}
}
template<class It, class Compare>
void nth_element_heap_fallback(It first, It nth, It last, Compare comp) {
const size_t k = size_t(nth - first);
auto pivot_value = heap_select_value(first, last, k, comp);
// 两阶段分区:< pivot | == pivot | > pivot
It mid1 = std::partition(first, last, [&](const auto& x){ return comp(x, pivot_value); });
It mid2 = std::partition(mid1, last, [&](const auto& x){
return !comp(pivot_value, x) && !comp(x, pivot_value); // x == pivot_value
});
// 若 nth 落在 [mid1, mid2) 外,是因为比较器存在等价类定义问题;一般会落在中间块。
}
思路
当检测到主元质量持续很差时,不再靠“经验型”主元(如三数取中、ninther),而是用 Blum–Floyd–Pratt–Rivest–Tarjan 的分组取中位数法来构造一个保证良好的主元:
把区间分成若干组(通常每组 5 个);各组就地排序,取出每组中位数;
递归地在“中位数们”的数组上找中位数,作为 pivot
;
用这个 pivot
做三向划分,再只递归一边。
性质
保证每次划分至少 ~30% : 70% 的分割(常见教科书界限),因此最坏时间是 O(N)
。
额外空间:可以做到 O(1)
(把“各组中位数”搬到数组前部使用),但实现稍复杂。
优缺点
优点:给出严格的线性最坏界,彻底杜绝 O(N^2)
。
缺点:常数因子明显偏大,平均性能往往不如 quickselect/堆选择;通常只在退避时启用。
嵌入式示意代码(主元选取片段)
cpptemplate<class It, class Compare>
auto select_pivot_mom(It first, It last, Compare comp)
-> typename std::iterator_traits<It>::value_type
{
using T = typename std::iterator_traits<It>::value_type;
size_t n = last - first;
if (n <= 5) {
std::sort(first, last, comp);
return *(first + n/2);
}
// 将每组5个的中位数“搬到前面”
It write = first;
for (It g = first; g < last; g += 5) {
It g_end = std::min(g + 5, last);
std::sort(g, g_end, comp);
std::iter_swap(write++, g + (g_end - g)/2);
}
// 递归地在 [first, write) 找其中位数
return select_pivot_mom(first, write, comp);
}
在
nth_element
的循环里,当触发退避时,把pivot = select_pivot_mom(...)
后做一次三向划分;由于 MoM 保证划分比例良好,后续递归深度迅速回落。
思路
当当前处理的子区间长度 len <= cutoff
(如 16、24、32)时,直接排序整个子区间(std::sort
或插入排序),因为:
小 N
时比较/交换的常数主导,分支与函数调用的开销比排序本身还大;
排好后 nth
自然在正确位置,满足语义。
复杂度
时间:O(len log len)
或插排的 O(len^2)
(但 len
很小,整体影响不大);
空间:O(1)
;
优点:实现简单、常数极小,是所有实作都会使用的微优化。
嵌入式示意代码
cpptemplate<class It, class Compare>
bool maybe_sort_small(It& first, It& nth, It& last, Compare comp, int cutoff = 24) {
if (last - first <= cutoff) {
std::sort(first, last, comp);
return true;
}
return false;
}
cpptemplate<class It, class Compare>
void nth_element_introselect(It first, It nth, It last, Compare comp) {
auto n0 = last - first;
int depth_limit = 2 * (int)std::log2(std::max<ptrdiff_t>(n0, 1));
while (true) {
if (first >= last || nth >= last) return;
if (maybe_sort_small(first, nth, last, comp)) return;
// 经验型主元(例如三数取中 / ninther)
It a = first, b = first + (last-first)/2, c = last-1;
if (comp(*b,*a)) std::iter_swap(a,b);
if (comp(*c,*a)) std::iter_swap(a,c);
if (comp(*c,*b)) std::iter_swap(b,c);
auto pivot = *b;
// 三向划分
It lt = first, i = first, gt = last;
while (i < gt) {
if (comp(*i, pivot)) { std::iter_swap(i++, lt++); }
else if (comp(pivot, *i)) { std::iter_swap(i, --gt); }
else { ++i; }
}
// 判断是否完成
if (nth < lt) { last = lt; }
else if (nth >= gt) { first = gt; }
else { return; }
// 退避触发:深度或分割比例
--depth_limit;
ptrdiff_t L = lt - first, R = last - gt; // 进入前的左右长度估计
bool bad_split = (std::min(L, R) * 16) < (last - first); // 比如 < N/16
if (depth_limit < 0 || bad_split) {
// 任选一种稳健方案:MoM 或 堆选择
// A) MoM:强力改善下一次划分
auto pv = select_pivot_mom(first, last, comp);
// 用 pv 再做一次三向划分(略,与上同),再回到 while
// B) 或直接用堆得到值并两次 partition,然后 return:
// nth_element_heap_fallback(first, nth, last, comp); return;
}
}
}
实战里常把 C(小段排序) 放在最外层微优化;
A(MoM) 与 B(堆选择) 只在“很坏”时启用:
如果你更在乎最坏界,倾向用 MoM;
如果你更在乎工程常数且
k
偏小/偏大,倾向用 堆选择。
默认:quickselect + 三向划分 + 小段直接排序。
加一道保险:depth_limit
触发后,使用 MoM 改善一次主元(随后多数情况会回到快速路径),或直接一次性用堆选择完成。
k
很小(如找 Top-K 或第 1、2、3 小):堆选择常更快;找到阈值后再一次性分区即可。
数据可能被对手构造(如在线评测、服务器端):务必保留 MoM 退避,确保没有 O(N^2)
。
如果你愿意,我可以基于你的编译器/平台,给出一版可编译的 nth_element
内省式实现(含基准测试脚本),对比不同退避策略在多种分布(均匀、重复值多、几乎有序、锯齿分布等)下的性能曲线。
请输入需要翻译的文本。