题目

给定一个长度为 n 的数组 a1,a2,…,an。

现在,要将该数组从中间截断,得到三个非空子数组。

要求,三个子数组内各元素之和都相等。

请问,共有多少种不同的截断方法?

输入格式

第一行包含整数 n。

第二行包含 n 个整数 a1,a2,…,an。

输出格式

输出一个整数,表示截断方法数量。

数据范围

前六个测试点满足 1≤n≤10。
所有测试点满足 1≤n≤105,−10000≤ai≤10000。

样例

输入样例1:

4
1 2 3 3

输出样例1:

1

输入样例2:

5
1 2 3 4 5

输出样例2:

0

输入样例3:

2
0 0

输出样例3:

0

题解

一开始就想着单纯的前缀和,求出来前缀和之后在暴力枚举ij,结果不负众望,超时了。超时代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <bits/stdc++.h>

using namespace std;

typedef long long i64;

const int N = 1e5 + 5;
int n, ans = 0;
i64 nums[N], sum[N];
i64 sum1, sum2, sum3;

int main() {
ios::sync_with_stdio(false);
cin.tie(0);

cin >> n;
for (int i = 0; i < n; ++i) {
cin >> nums[i];
if (i != 0) sum[i] = sum[i-1] + nums[i];
else sum[i] = nums[i];
}
per (i, 0, n) cout << sum[i] << " ";
cout << endl;
for (int i = 0; i < n-1; ++i) {
for (int j = i + 1; j < n; ++j) {
sum1 = sum[i];
sum2 = sum[j] - sum[i];
sum3 = sum[n-1] - sum[j];
if (sum1 == sum2 && sum2 == sum3) ans++;
}
}
cout << ans << endl;

return 0;
}

之后考虑三等分,那也就是说第一段就是和的三分之一。第二段也是。就没必要这么比了。

首先先判断和除以三能否除尽。不能除尽输出0结束。

除以三得到num,然后遍历得到i的列表。列表内容是满足sum[i] == numi值。这样第一部分就找好了。注意,要至少留两个数。

理论上,可以如此找第二部分,再把第二部分的也存上,要至少留一个数。假设这么做的话就得到vj,上面得到的是vi.结果就是:

1
2
3
4
5
for (int i = 0; i < vi.size(); ++i) {
for (int j = 0; j < vj.size(); ++j) {
if (vi[i] < vj[j]) ++ans;
}
}

但是这样还是要枚举,且复杂度可能不低。怎么优化呢。

优化办法就是再存一个数组,大小和nums一样大,然后在得到第一部分的列表时,顺便就得到了sumi的中每个i的值。要求其实就是如果sum[i] == 2*num也就是可以从i位置截取为第二下,就让sumi[i] = 1或者直接++

最后再求一下sumi的前缀和。

用一个例子来举例吧:

nums = [1, 9, -9, 2, 2, 0, 2, 3, 4, -1, 2]

sum = [1, 10, 1, 3, 5, 5, 7, 10, 14, 13, 15] nums的前缀和

vi = [4, 5] 这里就是满足sum = 5的索引

sumi = [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0] 如果 sum[i] == 10 那么sumi[i] = 1

sumi = [0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2] sumi的前缀和

答案比较明显,就是2.即(4, 7)和(5,7)。

接下来解的个数为:

1
2
3
for (int i = 0; i < vi.size(); ++i) {
ans += sumi[n-1] - sumi[vi[i]];
}

解释:

总共遍历两次,因为vi.size() == 2, sumi[n-1] = 2

  1. vi[i] = 4

    sumi[4] = 1

    ans += 2 - 1

  2. vi[i] = 5

    sumi[5] = 1

    ans += 2 - 1

最终结果为2.

整体代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <bits/stdc++.h>

using namespace std;

typedef long long i64;

const int N = 1e5+5;
int n;
i64 nums[N], sum[N], sumi[N], num, ans = 0;
vector<int> vi;

int main() {
cin >> n;
if (n < 3) {
cout << 0;
return 0;
}
for (int i = 0; i < n; ++i) {
cin >> nums[i];
if (i != 0) sum[i] = sum[i-1] + nums[i];
else sum[i] = nums[i];
}
if (sum[n-1] % 3) {
cout << 0;
return 0;
}
num = sum[n-1] / 3;
for (int i = 0; i < n - 1; ++i) {
if (sum[i] == num) vi.push_back(i);
if (sum[i] == 2 * num) ++sumi[i];
}
for (int i = 1; i < n; ++i)
sumi[i] += sumi[i-1];

for (int i = 0; i < vi.size(); ++i) {
ans += sumi[n-1] - sumi[vi[i]];
}
cout << ans << endl;
}

上面这样做就已经可以AC了,但是这个复杂度还是比较高。下面继续优化。

分析内涵。还是从这段代码说起:

1
2
3
4
5
for (int i = 0; i < vi.size(); ++i) {
for (int j = 0; j < vj.size(); ++j) {
if (vi[i] < vj[j]) ++ans;
}
}

其实即使要求每一段的i后面有多少个j

换个想法,如果切到了第二刀,我们记录它前面有多少个第一刀不就好了。

那么就有这种想法:

1
2
3
4
for (int i = 0; i < n - 1; ++i) { // 注意:要留个第三段
if (sum[i] == 2 * num) ans += cnt;
if (sum[i] == num) cnt++;
}

cnt记录到是遍历到当前位置,有多少个第一刀。如果遍历到了第二刀,就直接累计它前面的第一刀即可。

写出完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <bits/stdc++.h>

using namespace std;

typedef long long i64;

const int N = 1e5+5;
int n;
i64 nums[N], sum[N], num, ans = 0, cnt;
vector<int> vi;

int main() {
cin >> n;
if (n < 3) {
cout << 0;
return 0;
}
for (int i = 0; i < n; ++i) {
cin >> nums[i];
if (i != 0) sum[i] = sum[i-1] + nums[i];
else sum[i] = nums[i];
}
if (sum[n-1] % 3) {
cout << 0;
return 0;
}
num = sum[n-1] / 3;
for (int i = 0; i < n - 1; ++i) {
if (sum[i] == 2*num) ans += cnt;
if (sum[i] == num) ++cnt;
}
cout << ans << endl;
}

最后还可以优化一下空间。

因为发现前缀和数组和传统用法不一样。所以用一个变量记录sum[i]即可,不需要完整的前缀和数组。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
include <bits/stdc++.h>

using namespace std;

typedef long long i64;

const int N = 1e5+5;
int n;
i64 nums[N], sum, num, ans = 0, cnt, tot;
vector<int> vi;

int main() {
cin >> n;

for (int i = 0; i < n; ++i) {
cin >> nums[i];
sum += nums[i];
}
if (n < 3) cout << 0;
else if (sum % 3) cout << 0;
else {
num = sum / 3;
for (int i = 0; i < n - 1; ++i) {
tot += nums[i];
if (tot == 2*num) ans += cnt;
if (tot == num) ++cnt;
}
}
cout << ans << endl;
}