题意:
给定一棵 n 个节点的树,每条边有边权,求出树上两点距离小于等于 k 的点对数量。
题解:
点分治的模板题是求等于K的路径条数
本题是求小于等于K的路径条数,我们只需要改变统计答案即可
原本统计答案是对一个路劲长度len,判断K-len在之前的子树中出现多少次,用f数组来记录,直接查看f[k-len]等于多少即可
现在统计答案是对一个路径长度len,判断小于等于K-len的数在之前的子树中出现多少次,统计区间数量,我们可以用树状数组来实现,对于每个len我们插入数组f[]其中,然后求1~K-len的区间查询
代码:
详细看代码
#include<bits/stdc++.h> #define MAXN 40005 #define MAXK 20005 using namespace std; int N,K; struct edge{ int v,w; edge(int v=0, int w=0):v(v), w(w){} }; vector<edge> adj[MAXN]; // int fw[MAXK]; int lbt(int x){ return x & (-x); } int getsum(int x){ int ans = 0; for(;x>0;x-=lbt(x)){ ans += fw[x]; } return ans; } void change(int x, int dv){ for(;x<=K;x+=lbt(x)){ fw[x] += dv; } } // int sz[MAXN]; bool vis[MAXN]; int rt; void dfs_rt(int u, int fa, int tot){//O(tot)求根 sz[u] = 1; int v, n = 0; for(int k=0;k<adj[u].size();k++){ v = adj[u][k].v; if(v==fa || vis[v]) continue; dfs_rt(v, u, tot); sz[u] += sz[v]; n = max(n, sz[v]); } n = max(n, tot-sz[u]); if(n*2 <= tot) rt = u; } void dfs_sz(int u, int fa){//O(tot)求子树 sz[u] = 1; int v, n = 0; for(int k=0;k<adj[u].size();k++){ v = adj[u][k].v; if(v==fa || vis[v]) continue; dfs_sz(v, u); sz[u] += sz[v]; } } int d[MAXN], cnt = 0; void dfs_dis(int u, int fa, int dis){//O(tot) d[++cnt] = dis;//记录距离 int v,w; for(int k=0;k<adj[u].size();k++){ v = adj[u][k].v; w = adj[u][k].w; if(v==fa || vis[v]) continue; dfs_dis(v, u, dis + w); } } void dfs_clear(int u, int fa, int dis){//O(tot)清零 if(dis) change(dis, -1); int v,w; for(int k=0;k<adj[u].size();k++){ v = adj[u][k].v; w = adj[u][k].w; if(v==fa || vis[v]) continue; dfs_clear(v, u, dis + w); } } int work(int u, int tot){ dfs_rt(u, 0, tot); u = rt; dfs_sz(u, 0); vis[u] = 1; int v,w; //solve /* 求出每个点到根节点的距离,然后记录到数组d 然后对于d[i],查询0~K-d[i]区间大小,用树状数组实现 如果d[i]<=K,那么d[i]本身也符合条件 将数组d插入到树状数组 当以一个点为根的一轮结束时记得情况数组树状数组 */ int ans = 0; for(int k=0;k<adj[u].size();k++){ v = adj[u][k].v; w = adj[u][k].w; if(vis[v]) continue; cnt = 0; dfs_dis(v, u, w); for(int i=1;i<=cnt;i++){ if(d[i] <= K) ++ans; ans += getsum(max(0,K-d[i])); } for(int i=1;i<=cnt;i++){ change(d[i], +1); } } dfs_clear(u,0,0);//手动清空 // for(int k=0;k<adj[u].size();k++){ v = adj[u][k].v; //w = adj[u][k].w; if(vis[v]) continue; ans += work(v,sz[v]); } return ans; } int main(){ cin>>N; int u,v,w; for(int i=1;i<N;i++){ cin>>u>>v>>w; adj[u].push_back(edge(v,w)); adj[v].push_back(edge(u,w)); } cin>>K; cout<<work(1, N)<<endl; return 0; }