结论题

其实没那么难,只不过真的被附加说明给恶心到了。

首先我们将相乘的数连边,能看出来会形成几个环,显然环与环之间互不影响。

然后易证一共有 \(gcd(n,k)\) 个环,每个环里就有 \(\frac{n}{gcd(n,k)}\) 个数,显然我们给每个环分配的数连续更优。

关于一个环里怎么分配:我们 看样例/手玩可知 将所有的数从小到大排序依次塞入,左一个右一个最优,其实也很好理解,这样我们能将较大的数尽可能的挨在一起(考虑它是一个环)。

然后发现时间复杂度并不对,我们可以记忆化一下答案,又因为我们用到的k只与 \(gcd(n,k)\) 有关,所以我们让 \(k=gcd(n,k)\) 。如果 \(n\)\(\sqrt n\)个约数的话,时间复杂度就成了\(O(n\sqrt n)\),但实际上达不到。

注意$k \times 2=n $的情况.

下面是考场代码,可能有点丑.

#include<algorithm>
#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
int n, m, k, t;
const int N = 200010;
int a[N];
LL ans[N];
inline int read()
{
	int res = 0; char ch = getchar(); bool XX = false;
	for (; !isdigit(ch); ch = getchar())(ch == '-') && (XX = true);
	for (; isdigit(ch); ch = getchar())res = (res << 3) + (res << 1) + (ch ^ 48);
	return XX ? -res : res;
}
int GCD(int a, int b) {return b ? GCD(b, a % b) : a;}
void solve1()
{
	while (m--)
	{
		k = read();
		if (ans[k]) {printf("%lld\n", ans[k]); continue;}
		if (k == 0)
		{
			for (int i = 1; i <= n; ++i)ans[k] += (LL)a[i] * a[i];
			printf("%lld\n", ans[k]);
			continue;
		}
		if (2 * k == n)
		{
			for (int i = 1; i <= n; i += 2)ans[k] += (LL)a[i] * a[i + 1];
			ans[k] *= 2;
			printf("%lld\n", ans[k]);
			continue;
		}
		t = n / GCD(n, k); //gcd 个环
		for (int i = 1; i <= n; i += t)
			for (int j = i, to = i + t - 1; j <= to - 1; ++j)
			{
				if (j == i)ans[k] += (LL)a[j] * a[j + 1] + (LL)a[j] * a[j + 2];
				else if (j == to - 1)ans[k] += (LL)a[j] * a[j + 1];
				else ans[k] += (LL)a[j] * a[j + 2];
			}
		printf("%lld\n", ans[k]);
	}
}
void solve2()
{
	while (m--)
	{
		k = read();
		if (k == 0)
		{
			if (!ans[k])for (int i = 1; i <= n; ++i)ans[k] += (LL)a[i] * a[i];
			printf("%lld\n", ans[k]);
			continue;
		}
		k = GCD(n, k);
		if (ans[k]) {printf("%lld\n", ans[k]); continue;}
		if (2 * k == n)
		{
			for (int i = 1; i <= n; i += 2)ans[k] += (LL)a[i] * a[i + 1];
			ans[k] *= 2;
			printf("%lld\n", ans[k]);
			continue;
		}
		t = n / GCD(n, k); //gcd 个环
		for (int i = 1; i <= n; i += t)
			for (int j = i, to = i + t - 1; j <= to - 1; ++j)
			{
				if (j == i)ans[k] += (LL)a[j] * a[j + 1] + (LL)a[j] * a[j + 2];
				else if (j == to - 1)ans[k] += (LL)a[j] * a[j + 1];
				else ans[k] += (LL)a[j] * a[j + 2];
			}
		printf("%lld\n", ans[k]);
	}
}
int main()
{
// 	freopen("ring.in","r",stdin);
// 	freopen("ring.out","w",stdout);
	cin >> n >> m;
	for (int i = 1; i <= n; ++i)a[i] = read();
	sort(a + 1, a + 1 + n);
	if (n <= 3000)solve1();
	else solve2();
	fclose(stdin); fclose(stdout);
	return 0;
}