ACM模版

描述

题解

感觉这个题好难啊,虽然知道是要求每条边的贡献,但是完全不知道具体怎么搞,花了 5 盾看了题解……

虽说是思路上理解了,但是后边的 启发式合并 + 数据结构来维护子树 还是一脸懵逼,于是一狠心,又花了 60 盾看了大牛们的代码……好吧,我必须承认,没有看懂,大体上算是理解了一丢丢,但是具体的还是十分头疼……感觉好麻烦的说,树归部分倒是十分容易看懂,但是还是无法顿悟 启发式合并 + 数据结构来维护子树 部分的高超技艺……

Mark 一下,如果有大神有时间写更加详细的题解,烦请留言区留地址……就此谢过~~~

是时候看看 启发式合并 是啥东西了~~~害怕.jpg

代码

#include <cstdio>
#include <iostream>

#define ll long long

using namespace std;

const int MAXN = 2e5 + 5;
const int MAXM = MAXN << 4;

int n;
int cnt = 0, pos = 0;
ll ans = 0, tmp;
int root[MAXN];
int l1[MAXM], r1[MAXM];
int ls[MAXM], rs[MAXM];
int lb[MAXM], rb[MAXM];
int nt[MAXN], head[MAXN], v[MAXN];
ll s[MAXM];

template <class T>
inline void scan_d(T &ret)
{
    char c;
    ret = 0;
    while ((c = getchar()) < '0' || c > '9');
    while (c >= '0' && c <= '9')
    {
        ret = ret * 10 + (c - '0'), c = getchar();
    }
}

ll get_n(int n)
{
    return (ll)(n + 1) * n / 2;
}

void add(int x, int y)
{
    nt[++pos] = head[x];
    head[x] = pos;
    v[pos] = y;
}

void ins(int &x, int l, int r, int a)
{
    if (!x)
    {
        x = ++cnt;
    }
    if (l == r)
    {
        ls[x] = rs[x] = 1;
        lb[x] = rb[x] = 0;
        s[x] = 1;
        return ;
    }

    int m = (l + r) >> 1;
    if (a <= m)
    {
        ins(l1[x], l, m, a);
    }
    else
    {
        ins(r1[x], m + 1, r, a);
    }
    if (!l1[x])
    {
        l1[x] = ++cnt;
        lb[cnt] = rb[cnt] = m - l + 1;
        s[cnt] = get_n(m - l + 1);
    }
    if (!r1[x])
    {
        r1[x] = ++cnt;
        lb[cnt] = rb[cnt] = r - m;
        s[cnt] = get_n(r - m);
    }
    int k1 = l1[x], k2 = r1[x];
    s[x] = s[k1] + s[k2];
    if (ls[k2] && rs[k1])
    {
        s[x] -= get_n(rs[k1]) + get_n(ls[k2]);
        s[x] += get_n(rs[k1] + ls[k2]);
    }
    if (lb[k2] && rb[k1])
    {
        s[x] -= get_n(rb[k1]) + get_n(lb[k2]);
        s[x] += get_n(rb[k1] + lb[k2]);
    }
    ls[x] = ls[k1];
    rs[x] = rs[k2];
    lb[x] = lb[k1];
    rb[x] = rb[k2];
    if (ls[k1] == m - l + 1)
    {
        ls[x] += ls[k2];
    }
    if (lb[k1] == m - l + 1)
    {
        lb[x] += lb[k2];
    }
    if (rs[k2] == r - m)
    {
        rs[x] += rs[k1];
    }
    if (rb[k2] == r - m)
    {
        rb[x] += rb[k1];
    }
}

int merge(int x, int y, int l, int r)
{
    if (!x || !y)
    {
        return x + y;
    }
    if (lb[x] == r - l + 1)
    {
        return y;
    }
    if (lb[y] == r - l + 1)
    {
        return x;
    }

    int m = (l + r) >> 1;
    l1[x] = merge(l1[x], l1[y], l, m);
    r1[x] = merge(r1[x], r1[y], m + 1, r);
    if (!l1[x])
    {
        l1[x] = ++cnt;
        lb[cnt] = rb[cnt] = m - l + 1;
        s[cnt] = get_n(m - l + 1);
    }
    if (!r1[x])
    {
        r1[x] = ++cnt;
        lb[cnt] = rb[cnt] = r - m;
        s[cnt] = get_n(r - m);
    }
    int k1 = l1[x], k2 = r1[x];
    s[x] = s[k1] + s[k2];
    if (ls[k2] && rs[k1])
    {
        s[x] -= get_n(rs[k1]) + get_n(ls[k2]);
        s[x] += get_n(rs[k1] + ls[k2]);
    }
    if (lb[k2] && rb[k1])
    {
        s[x] -= get_n(rb[k1]) + get_n(lb[k2]);
        s[x] += get_n(rb[k1] + lb[k2]);
    }
    ls[x] = ls[k1];
    rs[x] = rs[k2];
    lb[x] = lb[k1];
    rb[x] = rb[k2];
    if (ls[k1] == m - l + 1)
    {
        ls[x] += ls[k2];
    }
    if (lb[k1] == m - l + 1)
    {
        lb[x] += lb[k2];
    }
    if (rs[k2] == r - m)
    {
        rs[x] += rs[k1];
    }
    if (rb[k2] == r - m)
    {
        rb[x] += rb[k1];
    }

    return x;
}

void dfs(int rt, int pre)
{
    for (int i = head[rt]; i; i = nt[i])
    {
        int v_ = v[i];
        if (v_ != pre)
        {
            dfs(v_, rt);
            root[rt] = merge(root[rt], root[v_], 1, n);
        }
    }

    ins(root[rt], 1, n, rt);
    ans += tmp - s[root[rt]];
}

int main()
{
    scan_d(n);
    tmp = get_n(n);

    int x, y;
    for (int i = 1; i < n; i++)
    {
        scan_d(x), scan_d(y);
        add(x, y), add(y, x);
    }

    dfs(1, 0);

    printf("%lld\n", ans);

    return 0;
}