C++ STL 中那些实用高效的算法

最近在刷 LeetCode 的时候发现, STL 里除了 sort、swap之类常用的算法之外,还有很多高效的算法可以大大简便我们的代码,虽然有些在题目中用到的不多,但都挺实用的,因此这在里记录一下那些 STL 中的实用方便的算法。

我们都知道 STL 里的容器因为迭代器的存在,我们可以以一个统一方式来使用 STL 中的算法,因此后面的算法作用在容器上时,我默认用的都是 vector<int> v,其他容器基本也是相同的写法。

排序算法

排序算法可以说是平时写题目用的最多的几种算法之一了,一个 sort 基本可以解决大部分需要排序的场合,相比手写快排可是节省了不少时间,而且一般来说 STL 里的时间复杂度都会比我们自己手写的略低一些,除非特殊情况,一般情况下能用 STL 肯定是最好的。

sort(v.begin(), v.end()); // 最基本的使用方法,默认升序排序
  
bool cmp(int x, int y) {
    return x > y;
}
sort(v.begin(), v.end(), cmp); // 也可以使用自定义的比较函数
sort(v.begin(), v.end(), greater<int>()) // 如果只是简单的降序排序可以使用 greater,相对应的升序的比较函数就是 less,这两个相当于是 STL 帮你封装好的简易比较函数(对象)
sort(v.begin(), v.end(), [](int a, int b){return a>b;}) // 更高阶的用法就是使用 lambda 表达式了,配合其他 STL 算法可以优雅简洁的实现复杂规则的排序

int a[10];
sort(a, a+10, cmp); // 用在普通数组上也是可以的

以上基本就是一个 STL 算法的几种常用的使用方式,很多 STL 算法都是类似 sort 的这些用法,后面就不多展开了。

一种 sort 显然不能实现所有的排序需求,比如 sort 的底层是使用了多种排序算法,不过其中快排和对排都是不稳定的排序,因此 sort 本身也是个不稳定的排序,如果想要一个稳定的排序我们可以使用 stable_sort

stable_sort(v.bengin(), v.end()) // stable_sort 底层大致是归并排序的思路,用法基本同 sort

划分的算法其实也可以算作是排序算法中的一部分,比如快速排序本质上就是就是进行不断的划分来实现的,partition 就是个比较常用的划分算法,就是类似快排划分的思路,后面传递一个布尔函数,如果判断是 ture,就放前面,是 false 就放后面

auto bound = partition(v.begin(), v.end(), [](int i){return i%2 == 0;}) 
// 这里的例子执行后,v 就变成了前部分是偶数,后部分是奇数,返回值 bound 就是指向后部分的第一个数的迭代器,也就是划分的分界位置

同样的因为是快排划分的思路,所以是不稳定的划分,稳定的可以使用 stable_partition , 用法相同不多赘述,不过值得注意的一点是 partition 返回的是一个正向的迭代器,也就是只能往后走,是找不到前一个元素的,而 stable_partition 返回的是一个双向的迭代器,可以往前也可以往后。不过我们一般来说我们划分是为将一组数据分成两类,拆开来,这个时候能就会发现 bound 迭代器指向的刚好就是前一组数据的末尾元素的后一个,也同时是第二组数据的第一个,相当于就是前一组的 end() 和 后一组的 begin(),因此后面可以很方便的进行其他操作

STL 里还有一种平时不太常用的排序算法,叫做 partial_sort, 是局部排序,需要传递三个随机访问迭代器,比如下面例子中,它会将数组中最小的前 n 个元素移到前 n 个位置,思路就是 topk 问题中最大堆的方法,然后对前 n 个元素进行升序排序。

partial_sort(v.begin(), v.begin() + n, v.end());// 同样的也可以再传递一个自定义的比较函数

partial_sort 还有个进阶版本 叫 partial_sort_copy,它不会修改原数组,而是会复制到另一个数组

下面这个例子中如果 m > n 的话,那么结果就相当于把 v 中所有数复制到 v2 的前 n 个位置后对前 n 个位置进行排序,如果 n > m 的话,就是将 v 中最小的 n 个数,移到 v2 中,再对整个 v2 进行排序。

当然实际上没这么麻烦,就是原本的思路,只不过建最大堆的时候建在了 v2 上,其他和 partial_sort 其实是类似的

vector<int>v(n);
vector<int>v2(m);
// 这里省略初始化 v 中的值
auto pos = partial_sort_copy(v.begin(), v.end(), v2.begin(), v2.end());
// 同理可以再传递自定义比较函数

最后还有一个 nth_element,就是经典的 topk 问题,思想就是上面的 partition 的思想,这种类似快速排序的算法其实有个名字叫快速选择算法,执行完下面的例子后,第 k 个元素前面的元素都比它小,后面的元素都比它大,然后取前 k 个元素,就是 topk 问题的解法了。

nth_element(v.begin(), v.begin() + k, v.end()); // 同理可以再传递自定义比较函数

最大最小

max 和 min 就不用说了吧,应该没有比这两个更常用的了,不过这两个函数只能比较两个数,如果需要比较三个就需要嵌套使用

max(a, b);
min(a, b, cmp);// 其实也可以传递自定义比较函数
auto it = max_element(v.begin(), v.end()); // 寻找数组中的最大最小元素
auto it = min_element(v.begin(), v.end(), cmp); // 可以传递自定义比较函数
// 返回最大元素的迭代器

反转和旋转

reverse 也算是比较常用的一个算法了,就是简单的反序

reverse(v.begin(), v.end());
reverse_copy(v.begin(), v.end(), v2.begin()); // 不改变原数组,而是反转复制到 v2 中

还有一个不大常用的 rotate 旋转算法,不过 leetcode 上其实有几道旋转算法的题,主要是旋转链表,这个算法的作用是将前 k 个元素移动到后面,比如 1 2 3 4 5 6 7, 旋转前 3 个就是 4 5 6 7 1 2 3

rotate(v.begin(), v.begin() + k, v.end());
rotate_copy(v.begin(), v.begin() + k, v.end(), v2.begin());// 同样有 copy 版

这个算法其实底层实现很巧妙,针对前向迭代器、双向和随机迭代器有三种不同的实现方法,感兴趣的可以自行搜索相关资料查询,其中双向迭代器版本的底层实现就是我们旋转链表时常用的方法,就是前部分反转,后部分反转,再整个反转

查找算法

查找可以说也是其中比较常用的算法了,首先是最通用的查找 find,查找某个元素是否存在,存在则返回第一个的位置,如果不存在返回的是结束位置,比如下面就是 v.end()。不过这个算法其实平时用的不多,因为像 map 之类的经常需要查找的容器都有自己的 find 函数,这个通用版的也用不了,而 vector 之类类似数组的容器,很少需要在无序的情况下查找指定元素,一般都是有序的情况下进行二分查找

auto it = find(v.begin(), v.end(), x);
auto it = find_if(v.begin(), v.end(), is_prime);
// 还有个版本可以传递一个布尔函数,查找满足某个条件的元素,比如是否有质数之类的

// 在 v 数组中寻找是否存在 v2 数组中的元素, 可以再传递一个比较函数,代替原本的 ==
auto it = find_first_of(v.begin(), v.end(), v2.begin(), v2.end());

// 这个和前面的不太一样,相当于是在 v1 中寻找和 v2 一样的连续子序列
// 返回的是最后一个匹配上的连续子序列的第一个元素的迭代器
vector<int>v{1,2,3,4,1,2,1};
vector<int>v2{1,2};
auto it = find_end(v.begin(), v.end(), v2.begin(), v2.end());
// 比如上面这个例子,返回的v.begin()+4 这个位置,也就是第二个 1、2,中 1 的位置
// 必须和 v2 中顺序都一样才能匹配的上
// 同样可以传递一个自定义比较函数


// 最后是一个查找是否有连续相等元素的,反复的是连续相等元素的第一个元素位置
// 其实就是找到第一个 v[i+1] == v[i],返回 i 这个位置的迭代器
auto it = adjacent_find(v.begin(), v.end()); // 同样可以再传递一个自定义比较函数

这种无序元素的查找相对用的还是比较少,用的比较多的还是有序元素的查找,STL 提供了三种二分查找函数,分别是 lower_boundupper_boundequal_range,用法基本相同

auto it = lower_bound(v.begin(), v.end(), val); // 二分查找第一个大于等于 val 的元素位置
auto it = upper_bound(v.begin(), v.end(), val); // 二分查找第一个大于 val 的元素位置
auto pair_it = equal_range(v.begin(), v.end(), val) // 相当于同时返回上面两个的结果,返回值是一个 pair,分别指向 lower_bound 和 upper_bound 的结果,这个区间就是和 val 相等的元素的区间
// 相减就是与 val 元素相等的元素数量
// 同样的以上三个函数都可以通过,再传递一个自定义的比较函数

除了这些查找的还有一些搜索相关的函数,类似于 find 系列

auto it = search(v.begin(), v.end(), v2.begin(), v2.end()); 
// 用法和 find_end 完全一样,区别就是 find_end 是找的最后一个匹配上的,而 search 是找的第一个匹配上的,不知道为什么不叫 find_first

auto it = search_n(v.begin(), v.end(), n, val);
// 就是把 v2 换成了 n 个连续的 val,同样查找到一个匹配上的

bool has_val = binary_search(v.begin(), v.end(), val);
// 二分查找是否存在 val ,存在则返回 true,反之 false

// 还是一样都能添加自定义比较函数

最后还有一个 count 系列的函数用来统计数量

auto cnt = count(v.begin(), v.end(), val);
// 返回和 val 相等的元素个数

// 还有个 if 版本,将 val 换成一个布尔函数,就可以统计满足条件的元素个数
auto cnt = count_if(v.begin(), v.end(), [](int x){return x%2 == 0;});
// 这里以统计偶数为例

填充

填充有用来填充相同元素的 fill 和 不同元素的 generate 两种,fill 相对来说因为各种容器一般都有自己的初始化方法,如果需要完全赋值成 0,-1之类的值的话,memset 其实效率更高,所以实际情况下需要用到 fill 的时候不多,只有局部赋值一些一般的数值可能才用的到

fiil(v.begin(), v.end(), val);
// 将 [begin, end) 区域填充满 val
fill_n(v.begin, n, val);
// 等价于 fill(v.begin(), v.begin()+n, val);
// 似乎就是一个更简洁的写法
  
  
generate(v.begin(), v.end(), random);
// 相当于将每个元素填充为函数的返回值,一般可能用随机函数多一点
// 其实也可以通过局部静态变量来实现一些递增赋值的效果
// 传递的这个函数没有形参,也就是个纯粹的生成用的函数
generate_n(v.begin(), n, random);
// 和 fill_n 同理

遍历

这里的所谓遍历其实是依次对每个元素调用一次指定的函数,只不过这里是不修改原本元素罢了

for_each(v.begin(), v.end(), [](int i) { cout << i << " "; });
// 这个就是纯粹的遍历了,如果不通过全局遍历或者 lambda 捕获外部的变量或者局部静态变量的方法
// 基本就只能输出元素结果了吧
// 这个函数其实有返回值,返回了传进去的 func 的副本,如果普通函数的话没什么意义,但如果是函数对象的话,那么其实可以从对象中获取一些执行后所需的信息

transform(v.begin(), v.end(), v2.begin(), [](int i){return i*2; });
// transform 顾名思义是个变换函数,第三个参数是结果存放位置,比如上面的例子就是将 v 中所有元素 *2 后放入 v2 中,如果把 v2 换成 v 就相当于是原地修改了
// 这个函数也可以传递一些比如 toupper、tolower 之类的 c++ 自带的大小写转换函数等等,甚至是其他 STL 的算法,可以有很多玩法

transform(v.begin(), v.end(), v2.begin(), v3.begin(), [](int i, int j){return i + j;});
// transform 还有另一个用法,再传递一个数组,可以实现两个数组对应位置元素进行计算
// 上面的例子等价于下面的代码
for(auto it1 = v.begin(), it2 = v2.begin(), it3 = v3.begin(); it1 != v.end();it1++,it2++, it3++) {
    *it3  = *it1 + *it2;
}
// 可以看到这里都是以第一个数组为主的,如果 v2 和 v3 数组长度比 v1 小就可能出现越界
// 这种二元运算的函数如果是一般的运算可以使用 STL 自带的一些算术仿函数,比如 std::plus<int>()

未完待续...