Y型树

[题目链接](https://www.nowcoder.com/practice/444ca6d8ab1c416e9729392d377b4fa8)

思路

先来理解什么是 Y 型树。题目说"恰好只有三个分叉"——也就是说这棵树里有且只有一个度为 3 的节点(中心节点),其余节点度都是 1 或 2,整体形状像字母 Y:从中心节点出发,伸出三条"手臂",每条手臂是一条链。

设三条手臂的长度分别为 (都 ),加上中心节点本身,总共有 个顶点。因为题目中顶点没有标号(无标号树),所以不同的 Y 型树本质上由三条手臂的长度决定。为了避免重复计数,我们规定

于是问题就变成了:满足 的整数解有多少组?这就是把 拆分成恰好 3 个正整数部分的有序(非降)拆分数。

手算验证一下:

  • ,只有 ,答案
  • ,有 ,答案

和样例吻合。

接下来推公式。对于固定的 的范围是 (因为 意味着 ), 的范围是 。把这个求和展开化简后,可以得到一个非常简洁的结论:

$$

也就是 四舍五入到最近整数。之所以成立,是因为 只可能是 四种值,对应的余项分别让四舍五入恰好给出正确结果。

用整数运算表达就是 ,不需要浮点数。

由于 可能很大,答案对 取模,所以我们用模逆元来计算除以 12 的操作:先算出 ,加上调整量,再乘以 的模逆元即可。调整量根据 的值确定:余数为 时加 ,余数为 时减 ,余数为 时减 ,余数为 时加 ——保证被 整除后再取模。

代码

#include <bits/stdc++.h>
using namespace std;
int main() {
    long long n;
    scanf("%lld", &n);
    long long m = n - 1;
    if (m < 3) { printf("0\n"); return 0; }
    const long long MOD = 1000000007;
    long long r = m % 12;
    long long rsq = r * r % 12;
    long long adjust;
    if (rsq == 0) adjust = 0;
    else if (rsq == 1) adjust = -1;
    else if (rsq == 4) adjust = -4;
    else adjust = 3; // rsq == 9
    auto power = [&](long long base, long long exp, long long mod) {
        long long res = 1; base %= mod;
        while (exp > 0) {
            if (exp & 1) res = res * base % mod;
            base = base * base % mod;
            exp >>= 1;
        }
        return res;
    };
    long long inv12 = power(12, MOD - 2, MOD);
    long long mMod = m % MOD;
    long long msq = mMod * mMod % MOD;
    long long ans = (msq + adjust % MOD + MOD) % MOD * inv12 % MOD;
    printf("%lld\n", ans);
    return 0;
}
import java.util.Scanner;
public class Main {
    static final long MOD = 1000000007L;
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long n = sc.nextLong();
        long m = n - 1;
        if (m < 3) { System.out.println(0); return; }
        long r = m % 12;
        long rsq = r * r % 12;
        long adjust;
        if (rsq == 0) adjust = 0;
        else if (rsq == 1) adjust = -1;
        else if (rsq == 4) adjust = -4;
        else adjust = 3;
        long inv12 = power(12, MOD - 2, MOD);
        long mMod = m % MOD;
        long msq = mMod * mMod % MOD;
        long ans = (msq + adjust % MOD + MOD) % MOD * inv12 % MOD;
        System.out.println(ans);
    }
    static long power(long base, long exp, long mod) {
        long res = 1; base %= mod;
        while (exp > 0) {
            if ((exp & 1) == 1) res = res * base % mod;
            base = base * base % mod;
            exp >>= 1;
        }
        return res;
    }
}
import sys
input = sys.stdin.readline

n = int(input())
m = n - 1
MOD = 10**9 + 7
if m < 3:
    print(0)
else:
    r = m % 12
    rsq = (r * r) % 12
    if rsq == 0: adjust = 0
    elif rsq == 1: adjust = -1
    elif rsq == 4: adjust = -4
    else: adjust = 3  # rsq == 9
    msq = pow(m, 2, MOD)
    ans = (msq + adjust) * pow(12, MOD - 2, MOD) % MOD
    print(ans)
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
rl.on('line', (line) => {
    const n = BigInt(line.trim());
    const MOD = 1000000007n;
    const m = n - 1n;
    if (m < 3n) { console.log("0"); return; }
    const r = m % 12n;
    const rsq = r * r % 12n;
    let adjust;
    if (rsq === 0n) adjust = 0n;
    else if (rsq === 1n) adjust = -1n;
    else if (rsq === 4n) adjust = -4n;
    else adjust = 3n;
    function power(base, exp, mod) {
        let res = 1n; base %= mod;
        while (exp > 0n) {
            if (exp & 1n) res = res * base % mod;
            base = base * base % mod;
            exp >>= 1n;
        }
        return res;
    }
    const inv12 = power(12n, MOD - 2n, MOD);
    const mMod = m % MOD;
    const msq = mMod * mMod % MOD;
    const ans = ((msq + adjust % MOD + MOD) % MOD * inv12) % MOD;
    console.log(ans.toString());
});

复杂度分析

  • 时间复杂度:,瓶颈在于计算 的模逆元(快速幂)。
  • 空间复杂度: