题目 给定一个长度为 n 的数组 a1,a2,…,an。
现在,要将该数组从中间截断,得到三个非空 子数组。
要求,三个子数组内各元素之和都相等。
请问,共有多少种不同的截断方法?
输入格式 第一行包含整数 n。
第二行包含 n 个整数 a1,a2,…,an。
输出格式 输出一个整数,表示截断方法数量。
数据范围 前六个测试点满足 1≤n≤10。 所有测试点满足 1≤n≤105,−10000≤ai≤10000。
样例 输入样例1:
4 1 2 3 3
输出样例1:
1
输入样例2:
5 1 2 3 4 5
输出样例2:
0
输入样例3:
2 0 0
输出样例3:
0
题解 一开始就想着单纯的前缀和,求出来前缀和之后在暴力枚举i
和j
,结果不负众望,超时了。超时代码 如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 #include <bits/stdc++.h> using namespace std;typedef long long i64;const int N = 1e5 + 5 ;int n, ans = 0 ;i64 nums[N], sum[N]; i64 sum1, sum2, sum3; int main () { ios::sync_with_stdio (false ); cin.tie (0 ); cin >> n; for (int i = 0 ; i < n; ++i) { cin >> nums[i]; if (i != 0 ) sum[i] = sum[i-1 ] + nums[i]; else sum[i] = nums[i]; } per (i, 0 , n) cout << sum[i] << " " ; cout << endl; for (int i = 0 ; i < n-1 ; ++i) { for (int j = i + 1 ; j < n; ++j) { sum1 = sum[i]; sum2 = sum[j] - sum[i]; sum3 = sum[n-1 ] - sum[j]; if (sum1 == sum2 && sum2 == sum3) ans++; } } cout << ans << endl; return 0 ; }
之后考虑三等分,那也就是说第一段就是和的三分之一。第二段也是。就没必要这么比了。
首先先判断和除以三能否除尽。不能除尽输出0结束。
除以三得到num,然后遍历得到i
的列表。列表内容是满足sum[i] == num
的i
值。这样第一部分就找好了。注意,要至少留两个数。
理论上,可以如此找第二部分,再把第二部分的也存上,要至少留一个数。假设这么做的话就得到vj
,上面得到的是vi
.结果就是:
1 2 3 4 5 for (int i = 0 ; i < vi.size (); ++i) { for (int j = 0 ; j < vj.size (); ++j) { if (vi[i] < vj[j]) ++ans; } }
但是这样还是要枚举,且复杂度可能不低。怎么优化呢。
优化办法就是再存一个数组,大小和nums
一样大,然后在得到第一部分的列表时,顺便就得到了sumi
的中每个i
的值。要求其实就是如果sum[i] == 2*num
也就是可以从i
位置截取为第二下,就让sumi[i] = 1
或者直接++
。
最后再求一下sumi
的前缀和。
用一个例子来举例吧:
nums = [1, 9, -9, 2, 2, 0, 2, 3, 4, -1, 2]
sum = [1, 10, 1, 3, 5, 5, 7, 10, 14, 13, 15]
nums的前缀和
vi = [4, 5]
这里就是满足sum = 5的索引
sumi = [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0]
如果 sum[i] == 10
那么sumi[i] = 1
sumi = [0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2]
sumi的前缀和
答案比较明显,就是2.即(4, 7)和(5,7)。
接下来解的个数为:
1 2 3 for (int i = 0 ; i < vi.size (); ++i) { ans += sumi[n-1 ] - sumi[vi[i]]; }
解释:
总共遍历两次,因为vi.size() == 2
, sumi[n-1] = 2
vi[i] = 4
sumi[4] = 1
ans += 2 - 1
vi[i] = 5
sumi[5] = 1
ans += 2 - 1
最终结果为2.
整体代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 #include <bits/stdc++.h> using namespace std;typedef long long i64;const int N = 1e5 +5 ;int n;i64 nums[N], sum[N], sumi[N], num, ans = 0 ; vector<int > vi; int main () { cin >> n; if (n < 3 ) { cout << 0 ; return 0 ; } for (int i = 0 ; i < n; ++i) { cin >> nums[i]; if (i != 0 ) sum[i] = sum[i-1 ] + nums[i]; else sum[i] = nums[i]; } if (sum[n-1 ] % 3 ) { cout << 0 ; return 0 ; } num = sum[n-1 ] / 3 ; for (int i = 0 ; i < n - 1 ; ++i) { if (sum[i] == num) vi.push_back (i); if (sum[i] == 2 * num) ++sumi[i]; } for (int i = 1 ; i < n; ++i) sumi[i] += sumi[i-1 ]; for (int i = 0 ; i < vi.size (); ++i) { ans += sumi[n-1 ] - sumi[vi[i]]; } cout << ans << endl; }
上面这样做就已经可以AC了,但是这个复杂度还是比较高。下面继续优化。
分析内涵。还是从这段代码说起:
1 2 3 4 5 for (int i = 0 ; i < vi.size (); ++i) { for (int j = 0 ; j < vj.size (); ++j) { if (vi[i] < vj[j]) ++ans; } }
其实即使要求每一段的i
后面有多少个j
。
换个想法,如果切到了第二刀,我们记录它前面有多少个第一刀不就好了。
那么就有这种想法:
1 2 3 4 for (int i = 0 ; i < n - 1 ; ++i) { if (sum[i] == 2 * num) ans += cnt; if (sum[i] == num) cnt++; }
即cnt
记录到是遍历到当前位置,有多少个第一刀。如果遍历到了第二刀,就直接累计它前面的第一刀即可。
写出完整代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 #include <bits/stdc++.h> using namespace std;typedef long long i64;const int N = 1e5 +5 ;int n;i64 nums[N], sum[N], num, ans = 0 , cnt; vector<int > vi; int main () { cin >> n; if (n < 3 ) { cout << 0 ; return 0 ; } for (int i = 0 ; i < n; ++i) { cin >> nums[i]; if (i != 0 ) sum[i] = sum[i-1 ] + nums[i]; else sum[i] = nums[i]; } if (sum[n-1 ] % 3 ) { cout << 0 ; return 0 ; } num = sum[n-1 ] / 3 ; for (int i = 0 ; i < n - 1 ; ++i) { if (sum[i] == 2 *num) ans += cnt; if (sum[i] == num) ++cnt; } cout << ans << endl; }
最后还可以优化一下空间。
因为发现前缀和数组和传统用法不一样。所以用一个变量记录sum[i]
即可,不需要完整的前缀和数组。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 include <bits/stdc++.h> using namespace std;typedef long long i64;const int N = 1e5 +5 ;int n;i64 nums[N], sum, num, ans = 0 , cnt, tot; vector<int > vi; int main () { cin >> n; for (int i = 0 ; i < n; ++i) { cin >> nums[i]; sum += nums[i]; } if (n < 3 ) cout << 0 ; else if (sum % 3 ) cout << 0 ; else { num = sum / 3 ; for (int i = 0 ; i < n - 1 ; ++i) { tot += nums[i]; if (tot == 2 *num) ans += cnt; if (tot == num) ++cnt; } } cout << ans << endl; }