结论题
其实没那么难,只不过真的被附加说明给恶心到了。
首先我们将相乘的数连边,能看出来会形成几个环,显然环与环之间互不影响。
然后易证一共有 \(gcd(n,k)\) 个环,每个环里就有 \(\frac{n}{gcd(n,k)}\) 个数,显然我们给每个环分配的数连续更优。
关于一个环里怎么分配:我们 看样例/手玩可知 将所有的数从小到大排序依次塞入,左一个右一个最优,其实也很好理解,这样我们能将较大的数尽可能的挨在一起(考虑它是一个环)。
然后发现时间复杂度并不对,我们可以记忆化一下答案,又因为我们用到的k只与 \(gcd(n,k)\) 有关,所以我们让 \(k=gcd(n,k)\) 。如果 \(n\) 有\(\sqrt n\)个约数的话,时间复杂度就成了\(O(n\sqrt n)\),但实际上达不到。
注意$k \times 2=n $的情况.
下面是考场代码,可能有点丑.
#include<algorithm>
#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
int n, m, k, t;
const int N = 200010;
int a[N];
LL ans[N];
inline int read()
{
int res = 0; char ch = getchar(); bool XX = false;
for (; !isdigit(ch); ch = getchar())(ch == '-') && (XX = true);
for (; isdigit(ch); ch = getchar())res = (res << 3) + (res << 1) + (ch ^ 48);
return XX ? -res : res;
}
int GCD(int a, int b) {return b ? GCD(b, a % b) : a;}
void solve1()
{
while (m--)
{
k = read();
if (ans[k]) {printf("%lld\n", ans[k]); continue;}
if (k == 0)
{
for (int i = 1; i <= n; ++i)ans[k] += (LL)a[i] * a[i];
printf("%lld\n", ans[k]);
continue;
}
if (2 * k == n)
{
for (int i = 1; i <= n; i += 2)ans[k] += (LL)a[i] * a[i + 1];
ans[k] *= 2;
printf("%lld\n", ans[k]);
continue;
}
t = n / GCD(n, k); //gcd 个环
for (int i = 1; i <= n; i += t)
for (int j = i, to = i + t - 1; j <= to - 1; ++j)
{
if (j == i)ans[k] += (LL)a[j] * a[j + 1] + (LL)a[j] * a[j + 2];
else if (j == to - 1)ans[k] += (LL)a[j] * a[j + 1];
else ans[k] += (LL)a[j] * a[j + 2];
}
printf("%lld\n", ans[k]);
}
}
void solve2()
{
while (m--)
{
k = read();
if (k == 0)
{
if (!ans[k])for (int i = 1; i <= n; ++i)ans[k] += (LL)a[i] * a[i];
printf("%lld\n", ans[k]);
continue;
}
k = GCD(n, k);
if (ans[k]) {printf("%lld\n", ans[k]); continue;}
if (2 * k == n)
{
for (int i = 1; i <= n; i += 2)ans[k] += (LL)a[i] * a[i + 1];
ans[k] *= 2;
printf("%lld\n", ans[k]);
continue;
}
t = n / GCD(n, k); //gcd 个环
for (int i = 1; i <= n; i += t)
for (int j = i, to = i + t - 1; j <= to - 1; ++j)
{
if (j == i)ans[k] += (LL)a[j] * a[j + 1] + (LL)a[j] * a[j + 2];
else if (j == to - 1)ans[k] += (LL)a[j] * a[j + 1];
else ans[k] += (LL)a[j] * a[j + 2];
}
printf("%lld\n", ans[k]);
}
}
int main()
{
// freopen("ring.in","r",stdin);
// freopen("ring.out","w",stdout);
cin >> n >> m;
for (int i = 1; i <= n; ++i)a[i] = read();
sort(a + 1, a + 1 + n);
if (n <= 3000)solve1();
else solve2();
fclose(stdin); fclose(stdout);
return 0;
}