题目链接:

http://poj.org/problem?id=2785

题面:

Description

The SUM problem can be formulated as follows: given four lists A, B, C, D of integer values, compute how many quadruplet (a, b, c, d ) ∈ A x B x C x D are such that a + b + c + d = 0 . In the following, we assume that all lists have the same size n .

Input

The first line of the input file contains the size of the lists n (this value can be as large as 4000). We then have n lines containing four integer values (with absolute value as large as 228 ) that belong respectively to A, B, C and D .

Output

For each input file, your program has to write the number quadruplets whose sum is zero.

Sample Input
6
-45 22 42 -16
-41 -27 56 30
-36 53 -37 77
-36 30 -75 -46
26 -38 -10 62
-32 -54 -6 45

Sample Output
5

Hint
Sample Explanation: 
Indeed, the sum of the five following quadruplets is zero:
(-45, -27, 42, 30), (26, 30, -10, -46),
(-32, 22, 56, -46),(-32, 30, -75, 77), (-32, -54, 56, 30).

题意:

给你 N 行 4 列的数,从每一列选取一个数,问使它们的和为0的情况有多少种
(N \leq 4000)

分析过程+代码:

四分为二 + 双指针

  • 把四列的数组手工分为前两列和与后两列,计算前两列和后两列的和(需要 nn2n*n*2 次运算),然后对后两列的和所有可能情况进行统计,并分别排序(快排时间复杂度 nlog(n)n * log(n))。

  • 再用双指针判断前两列和与后两列和的和是否为0,若为0, 则把 前两列中该计算值出现次数 乘上 后两列中该计算值出现次数。(双指针时间复杂度 nn

基于上面的思想写出如下代码:

poj——TLE代码:
#include<iostream>
#include<map>
#include<vector>
#include<algorithm>

using namespace std;

const int maxn = 4004;

int n;
int row1[maxn] = {};
int row2[maxn] = {};
int row3[maxn] = {};
int row4[maxn] = {};

map<int, int> cnt12;
map<int, int> cnt34;
vector<int> v12;
vector<int> v34;


long long ans = 0;

int main () {
	cin >> n;
	for (int i = 1; i <= n; i++) {
		cin >> row1[i] >> row2[i] >> row3[i] >> row4[i];
		
	}
	for (int i = 1; i <= n; i++) {
		for (int j = 1; j <= n; j++) {
			cnt12[row1[i] + row2[j]]++;
			if (!count(v12.begin(), v12.end(), row1[i] + row2[j])) { // 应该是这里超时
				v12.push_back(row1[i] + row2[j]);
			}
			cnt34[row3[i] + row4[j]]++;	
			if (!count(v34.begin(), v34.end(), row3[i] + row4[j])) { // 应该是这里超时
				v34.push_back(row3[i] + row4[j]);
			}
		}
	}
	sort(v12.begin(), v12.end());
	sort(v34.begin(), v34.end());
	
//	cout << endl;
//	for(auto a : v12) {
//		cout << a << " ";
//	}
//	cout << endl;
//
//	for (auto a : v34) {
//		cout << a << " ";
//	}
//	cout << endl;
	
	
	vector<int>::iterator it12 = v12.begin();
	vector<int>::reverse_iterator it34 = v34.rbegin();
	
	while (it12 != v12.end() && it34 != v34.rend()) {
		
		if ((*it12) < -(*it34)) {
//			cout << (*it12) << "(it12)  ";
			it12++;
		}
		else if ((*it12) > -(*it34)) {
//			cout << (*it34) << "(it34) ";
			it34++;
		}
		else {
			ans += (cnt12[*it12] * cnt34[*it34]);
			it12++;
//			cout << endl << ans << endl;
		}
		
	}
	
	cout << ans;
	
		
}

简单懒惰的修改

poj——CE代码:
#include<iostream>
#include<map>
#include<vector>
#include<algorithm>

using namespace std;

const int maxn = 4004;

int n;
int row1[maxn] = {};
int row2[maxn] = {};
int row3[maxn] = {};
int row4[maxn] = {};

map<int, int> cnt12;
map<int, int> cnt34;
vector<int> v12;
vector<int> v34;

long long ans = 0;

int main () {
	cin >> n;
	for (int i = 1; i <= n; i++) {
		cin >> row1[i] >> row2[i] >> row3[i] >> row4[i];
		
	}
	for (int i = 1; i <= n; i++) {
		for (int j = 1; j <= n; j++) {
			cnt12[row1[i] + row2[j]]++;
			if (!count(v12.begin(), v12.end(), row1[i] + row2[j])) {  // 应该是这里超时
				v12.push_back(row1[i] + row2[j]);
			}
			cnt34[row3[i] + row4[j]]++;	
			if (!count(v34.begin(), v34.end(), row3[i] + row4[j])) {
				v34.push_back(row3[i] + row4[j]);
			}
		}
	}

	
	for (auto a : v12) {
		if (count(v34.begin(), v34.end(), -a)) {
			ans += (cnt12[a] * cnt34[-a]);
		}
	}
	
	cout << ans;
	
	return 0;
	
}

经过观察和分析发现,超时主要在于自己懒惰的模拟照搬的思想,总是把问题做拆分做的很复杂,其实大可不必死板地按照最初思路按步实现,不必每次插入数据都排个序(必超时),思维要灵活!!!
利用好STL中的Lower_bound()函数和upper_bound()函数,对已排序的数组中的具体数值进行定位获得该数值的数量即可!

// AC代码
#include<iostream>
#include<algorithm>
#include<math.h>
using namespace std;

const int maxn = 4004;
typedef long long ll;
int n;
int row1[maxn] = {};
int row2[maxn] = {};
int row3[maxn] = {};
int row4[maxn] = {};

int row12[maxn*maxn] = {};
int row34[maxn*maxn] = {};

ll ans = 0;

int main () {
	
	cin >> n;
	
	for (int i = 0; i < n; i++) {
		cin >> row1[i] >> row2[i] >> row3[i] >> row4[i];
		
	}
	
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			row12[i*n + j] = row1[i] + row2[j];
			row34[i*n + j] = row3[i] + row4[j];
		}
	}
	
	sort(row12, row12 + (n*n));
		
	for (int i = 0; i < n*n; i++) {
		ans += upper_bound(row12, row12+(n*n), -row34[i]) - lower_bound(row12, row12+(n*n), -row34[i]);
	}
	
	cout << ans << endl;
}