J.杰哥的直角三角恋

题解

首先注意到任意一组勾股数 a2+b2=c2a^2+b^2=c^2 可以表示为 a=x2y2a=x^2-y^2b=2xyb=2xyc=x2+y2c=x^2+y^2

那么问题可以转化成满足以下条件的 (x,y)(x,y) 点对的个数:

  1. x2+y2n,x>y1x^2+y^2\leq n,x> y\geq 1

  2. gcd(x,y)=1\gcd(x,y)=1 (用反证法,若 gcd(x,y)=k>1\gcd(x,y)=k>1 ,那么设 x=pkx=pky=qky=qk 代入得 a=(p2q2)k2a=(p^2-q^2)k^2b=2pqk2b=2pqk^2c=(p2+q2)k2c=(p^2+q^2)k^2 ,显然不互素)

  3. x,yx,y 奇偶性不同。(假如同为奇数则 aabbcc 都为偶数,同为偶数则亦然)

2x2n2x^2\leq n 时,此时 y<xy<x。若 xx 是偶数,则此部分的贡献是 φ(x)\varphi(x);若 xx 是奇数,则此部分的贡献是 φ(x)2\frac{\varphi(x)}{2}。这一部分复杂度为 Θ(n2)\Theta( \sqrt{\frac{n}{2}})

2x2n2x^2\leq n 时,此时 ynx2y\leq \sqrt{n-x^2} ,所以就是求 x=n/2+1ny=1nx2[gcd(x,y)=1]\sum_{x=\sqrt{n/2}+1}^{\sqrt{n}}\sum_{y=1}^{\sqrt{n-x^2}}[\gcd(x,y)=1]

  1. xx是偶数
ans1=x=n/2+1,2xndxμ(d)nx2d=d=1nμ(d)k=(n/2+1)/dndnk2d2d[2kd]ans1=\sum_{x=\sqrt{n/2}+1,2\mid{x}}^{\sqrt{n}}\sum_{d\mid{x}}\mu(d)\frac{\sqrt{n-x^2}}{d}\\ =\sum_{d=1}^{\sqrt{n}}\mu(d)\sum_{k=(\sqrt{n/2}+1)/d}^{\frac{\sqrt{n}}{d}}\frac{\sqrt{n-k^2d^2}}{d}[2\mid{kd}]\\
  1. xx 是奇数
ans2=x=n/2+1,2xndxμ(d)nx22dd=1nμ(d)k=(n/2+1)/dndnk2d22d[2(kd)]ans2=\sum_{x=\sqrt{n/2}+1,2\nmid{x}}^{\sqrt{n}}\sum_{d\mid{x}}\mu(d)\frac{\sqrt{n-x^2}}{2d}\\ \sum_{d=1}^{\sqrt{n}}\mu(d)\sum_{k=(\sqrt{n/2}+1)/d}^{\frac{\sqrt{n}}{d}}\frac{\sqrt{n-k^2d^2}}{2d}[2\nmid{(kd)}]\\

复杂度大概是 O(nlnn)\mathcal O(\sqrt{n}\ln\sqrt{n})

std

#include<bits/stdc++.h>
using namespace std;
#define maxn 10000005
#define int long long
int mu[maxn];
int p[maxn];
int tot;
int v[maxn];
int phi[maxn];
void pre() {
    mu[1] = 1;
    phi[1]=1;
    for (int i = 2; i <= 1e7; ++i) {
        if (!v[i]) mu[i] = -1, p[++tot] = i,phi[i]=i-1;
        for (int j = 1; j <= tot && i <= 1e7 / p[j]; ++j) {
            v[i * p[j]] = 1;
            if (i % p[j] == 0) {
                mu[i * p[j]] = 0;
                phi[i*p[j]]=phi[i]*p[j];
                break;
            }
            else {
                phi[i*p[j]]=phi[i]*phi[p[j]];
                mu[i * p[j]] = -mu[i];
            }
        }
    }
}
signed main() {
    pre();
    int n;
    cin >> n;
    int m = (sqrt(n / 2));
    int mm = sqrt(n / 2) + 1;
    int mmm = sqrt(n);
    int ans = 0;
    for (int i = 2; i <= m; i++) {
        if (i % 2 == 0)ans += phi[i];
        else ans += phi[i] / 2;
    }
    for (int d = 1; d <= mmm; d++) {
        for (int k = mm / d; k <= mmm / d; k++) {
            if ((k * d) >= mm && (k * d) <= mmm && (k * d) % 2 == 0) {
                ans += mu[d] * ((int) sqrt(n - d * d * k * k) / (d));
            }
        }
    }
    for (int d = 1; d <= mmm; d++) {
        for (int k = mm / d; k <= mmm / d; k++) {
            if ((k * d) >= mm && (k * d) <= mmm && (k * d) % 2) {
                ans += mu[d] * ((int) sqrt(n - d * d * k * k) / 2 / d);
            }
        }
    }
    cout << ans << '\n';
}