算法介绍

点分治,顾名思义,是一种对点进行分治的数据结构。(树上的点)

多用于在树上进行有限制的路径计数。

比如:求树上长度小于$ k$ 的简单路径条数。\((n \leq 10000)\)

直接做肯定是补星的。所以就需要点分治这种东西了。

需要统计的路径肯定有这么两种:

  • 1.经过根节点$ root $的路径
  • 2.不经过根节点\(root\)的路径

显然第二种路径对于某个节点\(u\),属于第一种路径。所以分治解决即可。

我们来考虑第一种情况如何解决。

处理出一个数组\(d\),表示从当前根节点\(u\),到各个子节点的距离。

那么我们要统计的显然就是\(d[u]+d[v]\leq k\)的路径\((u,v)\)的个数。

这个东西可以通过在dfs求这个数组时顺便把所有的\(d\)值记录下来,排序之后让他们具有单调性。

然后双指针扫一下就好(合法状态就是\(d[l]+d[r]\leq k​\))那么当指针在\(l​\)时,对答案的贡献就是\(r-l​\)(不能重复选自己,所以不+1)

然后现在考虑一种情况。当\(u,v​\)都在当前根节点的同一个子树里面。这样子的话,路径\((u,v)​\)如果经过根节点就不是一条简单路径了(重边)。如何解决呢?

容斥的思想!

对于每个子树,分别处理它其中的子节点的d值,给答案减去就行了!

代码大概就长这个样子

void dfs(int u) {
    vis[u] = 1;
    ans += solve(u, 0); //所有情况
    for(int i = head[u]; i; i = e[i].nxt) {
        if(vis[e[i].to]) continue;
        int v = e[i].to;
        ans -= solve(v, e[i].v); //减掉不合法情况
        //下面是找重心的代码,后面会解释为什么要找重心
        now_sz = inf, root = 0; sz = siz[v];
        find_root(v, 0);
        dfs(root);
    }
}

先不管为什么要找重心。我们总结一下算法流程:

  • 1.找一个根节点root
  • 2.对root计算出d数组并计算答案
  • 3.把root删了,对root的各个子树执行流程1,2

复杂度是多少呢?粗略估计一下是\(O(Tnlogn) ​\)\(T​\)是树的层数。(这里有个\(log\)是因为用了排序

显然我们要让这个树优美一点,身材圆润一点,不能瘦成一条链,不然复杂度就变成\(O(n^2logn) ​\)了。

那这个根节点怎么找呢?树的重心

将重心当做根节点,可以保证树是\(log​\)层的!

那么复杂度就是$O(nlog^2n) \(了!(如果不使用排序的话(比如一些题是用到的桶),那么复杂度是\)O(nlogn)$)

还有就是关于点分治这里的重心有两种找法。一种就是上面那样的,另外一种就是改了一句

sz = siz[v];->sz = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];

实际上第二种才是对的,因为v可能在上次处理siz数组时是u的父亲(这是一棵无根树!)

但是复杂度并不会退化qwq,有神仙证明了。链接

例题:

POJ1741 tree

真正的模板题。就是我上面提到的那个问题。

直接点分一下就好了。每次将距离排序一下,然后双指针扫一扫,每次合法答案就是r-l,容斥一下将不合法情况减去即可。注意找重心不要写错不然复杂度就炸了。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010

inline void in(int &x) {
    x = 0; int f = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    x *= f;
}

int n, k, d[N], cnt, head[N], ans;
int vis[N], siz[N];
struct edge {
    int to, nxt, v;
}e[N<<1];

void ins(int u, int v, int w) {
    e[++cnt] = (edge) {v, head[u], w};
    head[u] = cnt;
}

int now_sz = inf, root = 0, sz;

void find_root(int u, int fa) {
    siz[u] = 1;
    int res = 0;
    for(int i = head[u]; i; i = e[i].nxt) {
        if(vis[e[i].to] || e[i].to == fa) continue;
        int v = e[i].to;
        find_root(v, u);
        siz[u] += siz[v];
        res = max(res, siz[v]);
    }
    res = max(res, sz - siz[u]);
    if(res < now_sz) now_sz = res, root = u;
}

int a[N], tot;
void get_dis(int u, int fa) {
    a[++tot] = d[u];
    for(int i = head[u]; i; i = e[i].nxt) {
        if(vis[e[i].to] || e[i].to == fa) continue;
        int v = e[i].to;
        d[v] = d[u] + e[i].v;
        get_dis(v, u);
    }
}

int solve(int u, int dis) {
    d[u] = dis; tot = 0;
    get_dis(u, u);
    sort(a + 1, a + tot + 1);
    int l = 1, r = tot, res = 0;
    for(; l < r; ++l) {
        while(l < r && a[l] + a[r] > k) --r;
        if(l < r) res += r - l;
    }
    return res;
}

void dfs(int u) {
    vis[u] = 1;
    ans += solve(u, 0);
    for(int i = head[u]; i; i = e[i].nxt) {
        if(vis[e[i].to]) continue;
        int v = e[i].to;
        ans -= solve(v, e[i].v);
        now_sz = inf, root = 0; sz = siz[v];
        find_root(v, 0);
        dfs(root);
    }
}

int main() {
    while(~scanf("%d%d", &n, &k) && n && k) {
        ans = 0; cnt = 0; 
        memset(head, 0, sizeof(head));
        memset(vis, 0, sizeof(vis));
        for(int i = 1; i < n; ++i) {
            int u, v, w; in(u), in(v), in(w);
            ins(u, v, w), ins(v, u, w);
        }
        dfs(1);
        printf("%d\n", ans);
    }
} 

BZOJ2152: 聪聪可可

求倍数为3的路径数。

考虑\(mod\ 3\)意义下的路径,为0显然可以互相拼起来,贡献是\(sum[0]^2\)。1和2可以互相拼,而且起点终点互换,所以贡献是\(sum[1]*sum[2]*2\),点分治计算这两个即可。总方案数是\(n^2\),所以答案就是\(\frac{sum}{n^2}\)

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010

inline void in(int &x) {
    x = 0; int f = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    x *= f;
}

int n, k, d[N], cnt, head[N], ans;
int vis[N], siz[N], sum[3];
struct edge {
    int to, nxt, v;
}e[N<<1];

void ins(int u, int v, int w) {
    e[++cnt] = (edge) {v, head[u], w};
    head[u] = cnt;
}

int now_siz, sz, root;
void find_root(int u, int fa) {
    siz[u] = 1; int res = 0;
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa || vis[v]) continue;
        find_root(v, u);
        siz[u] += siz[v];
        res = max(res, siz[v]);
    }
    res = max(res, sz - siz[u]);
    if(res < now_siz) now_siz = res, root = u;
}

void get_dis(int u, int fa) {
    sum[d[u]%3]++;
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(vis[v] || v == fa) continue;
        d[v] = d[u] + e[i].v;
        get_dis(v, u);
    }
}

int solve(int u, int dis) {
    d[u] = dis; sum[0] = sum[1] = sum[2] = 0;
    get_dis(u, u);
    return sum[0] * sum[0] + sum[1] * sum[2] * 2;
}

void dfs(int u) {
    ans += solve(u, 0);
    vis[u] = 1;
    for(int i = head[u]; i; i = e[i].nxt) { 
        int v = e[i].to;
        if(vis[v]) continue;
        ans -= solve(v, e[i].v);
        now_siz = inf; sz = siz[v]; root = 0;
        find_root(v, u);
        dfs(root);
    }
}

int main() {
    in(n);
    for(int i = 1; i < n; ++i) {
        int u, v, w; in(u), in(v), in(w);
        ins(u, v, w), ins(v, u, w); 
    }
    now_siz = inf; root = 0; sz = n;
    find_root(1, 1); 
    dfs(root);
    int now = n * n, g = __gcd(now, ans);
    printf("%d/%d\n", ans / g, now / g);
}

LuoguP3806 【模板】点分治1

注意这题数据很水...

求长度为k的路径是否存在。多次询问(询问数\(\leq 100​\)

这题效率有点奇怪...

自己估算了一下是\(O(mnlog^2n)​\)

对长度正好k的话,其实用个桶标记就好了,实际上和小于k没多大区别的。

考虑先将询问离线,然后在点分治过程中对所有答案进行判定。处理出d[]表示到节点i到当前根的距离。那么照例是拼路径,但是现在不是求方案总数而是求有没有这个方案,看起来不能容斥了。但是实际上可以的:考虑先对根u solve一遍,给所有询问加上这次的结果,然后对每个子节点计算一遍,给所有询问减掉这次的结果就好了。

具体的话看看代码吧

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010
#define lim 10000000

inline void in(int &x) {
    x = 0; int f = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    x *= f;
}

int top, n, m, d[N], cnt, head[N], ans[110];
int vis[N], siz[N], q[110], st[N], s[10000010];
struct edge {
    int to, nxt, v;
}e[N<<1];

void ins(int u, int v, int w) {
    e[++cnt] = (edge) {v, head[u], w};
    head[u] = cnt;
}

int now_sz = inf, root, sz;
void find_root(int u, int fa) {
    siz[u] = 1; int res = 0;
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa || vis[v]) continue;
        find_root(v, u);
        res = max(res, siz[v]);
        siz[u] += siz[v];
    }
    res = max(res, sz - siz[u]);
    if(res < now_sz) now_sz = res, root = u;
}

void get_dis(int u, int fa) {
    st[++top] = d[u];
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa || vis[v]) continue;
        d[v] = d[u] + e[i].v;
        get_dis(v, u);
    }
}

void solve(int u, int dis, int op) {
    top = 0; d[u] = dis; get_dis(u, 0);
    for(int i = 1; i <= top; ++i) if(st[i] <= lim) s[st[i]]++;
    for(int i = 1; i <= m; ++i) {
        for(int j = 1; j <= top; ++j) if(q[i] >= st[j]) ans[i] += s[q[i] - st[j]] * op;
    }
    for(int i = 1; i <= top; ++i) if(st[i] <= lim) s[st[i]]--;
}

void dfs(int u) {
    vis[u] = 1;
    solve(u, 0, 1);
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(vis[v]) continue;
        top = 0; d[v] = e[i].v;
        solve(v, e[i].v, -1);
        now_sz = inf, root = 0, sz = siz[v];
        find_root(v, u);
        dfs(root);
    }
}

int main() {
    in(n), in(m);
    for(int i = 1; i < n; ++i) {
        int u, v, w; in(u), in(v), in(w);
        ins(u, v, w), ins(v, u, w);
    }
    for(int i = 1; i <= m; ++i) in(q[i]);
    sz = n; now_sz = inf; root = 0;
    find_root(1, 1); dfs(root);
    for(int i = 1; i <= m; ++i) puts(ans[i] ? "AYE" : "NAY");
}

CF161D Distance in Tree

求长度等于k的路径数...就很烦....这种一般都要分类讨论

需要分类讨论一下,同样是套路点分然后开个桶,然后分\(k-v[i]=v[i]\)和不等两种情况,显然相等的话答案就是\(cnt[v[i]]*(cnt[v[i]]-1)/2\).不相等的话用乘法原理考虑一下,\(cnt[v[i]]*cnt[k-v[i]]​\),注意每次统计完之后就要把cnt清空。

#include <bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define il inline

namespace io {

#define in(a) a = read()
#define out(a) write(a)
#define outn(a) out(a), putchar('\n')

#define I_int ll
inline I_int read() {
    I_int x = 0, f = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return x * f;
}
char F[200];
inline void write(I_int x) {
    if (x == 0) return (void) (putchar('0'));
    I_int tmp = x > 0 ? x : -x;
    if (x < 0) putchar('-');
    int cnt = 0;
    while (tmp > 0) {
        F[cnt++] = tmp % 10 + '0';
        tmp /= 10;
    }
    while (cnt > 0) putchar(F[--cnt]);
}
#undef I_int

}
using namespace io;

using namespace std;

#define N 100010

int n, k;
int cnt, head[N], vis[N], d[N];
struct edge {
    int to, nxt;
}e[N<<1];

void ins(int u, int v) {
    e[++cnt] = (edge) {v, head[u]};
    head[u] = cnt;
}

int siz[N], now_sz = inf, root, sz;
void find_root(int u, int fa) {
    siz[u] = 1; int res = 0;
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa || vis[v]) continue;
        find_root(v, u);
        siz[u] += siz[v];
        res = max(res, siz[v]);
    }
    res = max(res, sz - siz[u]);
    if(res < now_sz) now_sz = res, root = u;
}

int top, st[N], s[N];
void get_dis(int u, int fa) {
    st[++top] = d[u]; if(d[u] <= k) ++s[d[u]];
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa || vis[v]) continue;
        d[v] = d[u] + 1;
        get_dis(v, u);
    }
}

ll solve(int u, int dis) {
    d[u] = dis; top = 0; get_dis(u, 0); 
    ll ans = 0;
    for(int i = 1; i <= top; ++i) 
        if(st[i] <= k) {
            if(st[i] * 2 == k) ans += 1ll * s[st[i]] * (s[st[i]] - 1) / 2ll;
            else ans += 1ll * s[k - st[i]] * s[st[i]];
            s[st[i]] = s[k - st[i]] = 0;
        }
    return ans;
}

ll ans = 0;
void dfs(int u) {
    vis[u] = 1; ans += solve(u, 0);
    int totsiz = sz;
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(vis[v]) continue;
        ans -= solve(v, 1);
        sz = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];
        now_sz = inf; root = 0;
        find_root(v, 0);
        dfs(root);
    }
}

int main() {
    in(n), in(k);
    for(int i = 1; i < n; ++i) {
        int u = read(), v = read();
        ins(u, v), ins(v, u);
    }
    now_sz = inf; sz = n; root = inf;
    find_root(1, 0); 
    dfs(root);
    outn(ans);
}