题目
解析
莫比乌斯反演。
给定\(n\),\(m\),\(d\),求\[\sum_{i=1}^{n}\sum_{j=1}^{m}[gcd(i,j)=d]\]
那我们设\[f(x)=\sum_{i=1}^{n}\sum_{j=1}^{m}[gcd(i,j)=x]\]
设
\[\begin{aligned} F(x)=& \sum_{x\mid i}f(k) \\Q =&\sum_{x\mid k}\sum_{i=1}^{n}\sum_{j=1}^{m}[gcd(i,j)=k]\\ &当gcd(i,j)=k且x\mid i时,会对答案做一次贡献;\\ &所以只要枚举gcd(i,j),当x\mid (k=gcd(i,j))时,会对答案做一次贡献\\ =&\sum_{i=1}^{n}\sum_{j=1}^{m}[x\mid gcd(i,j)]\\ &\because x\mid gcd(i,j)\\ &\therefore x\mid i且x\mid j\\ =&\sum_{i=1}^{\frac{n}{x}}\sum_{j=1}^{\frac{m}{x}}[1\mid gcd(i,j)]\\ =&\lfloor \frac{n}{x}\rfloor \lfloor\frac{m}{x}\rfloor \end{aligned} \]
实在看不懂的话其实也可以这么理解,根据\(gcd\)的性质,发现\(F\)实际上就是在求有多少个\(ij\)都是\(x\)的倍数,\(1\)到\(n\)里有\(\lfloor\dfrac{n}{x}\rfloor\)个,\(1\)到\(m\)里有\(\lfloor\dfrac{m}{x}\rfloor\)个,根据乘法原理,就是\(\lfloor\dfrac{n}{x}\rfloor\lfloor\dfrac{m}{x}\rfloor\)。
然后直接反演
\[\begin{aligned} F(x)=& \sum_{x\mid i}f(i) \\ f(x)=&\sum_{x\mid i}\mu(\frac{i}{x})F(i)\\ =&\sum_{x\mid i}\mu(\frac{i}{x})\lfloor \frac{n}{i}\rfloor \lfloor\frac{m}{i}\rfloor \end{aligned} \]
将\(d\)带入
\[\begin{aligned} f(x)=\sum_{d\mid i}\mu(\frac{i}{d})\lfloor \frac{n}{i}\rfloor \lfloor\frac{m}{i}\rfloor \end{aligned} \]
令\(\dfrac{i}{d}=t\),得到
\[\begin{aligned} f(d)=\sum_{t=1}^{min(a,b)}\mu(t)\lfloor \frac{n}{td}\rfloor \lfloor\frac{m}{td}\rfloor \end{aligned} \]
这样就做到了\(O(n)\)的做,最后套一个数论分块,就可以\(O(m\sqrt n)\)的做这道题
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
int t, n, m, d, num, mx, ans;
int mu[N], p[N], sum[N];
bool vis[N];
template<class T>inline void read(T &x) {
x = 0; int f = 0; char ch = getchar();
while (!isdigit(ch)) f |= (ch == '-'), ch = getchar();
while (isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
x = f ? -x : x;
return;
}
void get_mu(int n) {
mu[1] = 1;
for (int i = 2; i <= n; ++i) {
if (!vis[i]) p[++num] = i, mu[i] = -1;
for (int j = 1; j <= num; ++j) {
if (i * p[j] > n) break;
vis[i * p[j]] = 1;
if (i % p[j] == 0) {
mu[i * p[j]] = 0;
break;
} else mu[i * p[j]] = -mu[i];
}
}
}
int main() {
get_mu(N);
for (int i = 1; i <= N; ++i) sum[i] = sum[i - 1] + mu[i];
read(t);
while (t--) {
ans = 0;
read(n), read(m), read(d);
mx = min(n, m);
for (int l = 1, r; l <= mx; l = r + 1) {
r = min(n / (n / l), m / (m / l));
ans += ((n / (l * d)) * (m / (l * d)) * (sum[r] - sum[l - 1]));
}
printf("%d\n", ans);
}
return 0;
}