题意:
给定一棵 n 个节点的树,每条边有边权,求出树上两点距离小于等于 k 的点对数量。
题解:
点分治的模板题是求等于K的路径条数
本题是求小于等于K的路径条数,我们只需要改变统计答案即可
原本统计答案是对一个路劲长度len,判断K-len在之前的子树中出现多少次,用f数组来记录,直接查看f[k-len]等于多少即可
现在统计答案是对一个路径长度len,判断小于等于K-len的数在之前的子树中出现多少次,统计区间数量,我们可以用树状数组来实现,对于每个len我们插入数组f[]其中,然后求1~K-len的区间查询
代码:
详细看代码
#include<bits/stdc++.h>
#define MAXN 40005
#define MAXK 20005
using namespace std;
int N,K;
struct edge{
int v,w;
edge(int v=0, int w=0):v(v), w(w){}
};
vector<edge> adj[MAXN];
//
int fw[MAXK];
int lbt(int x){
return x & (-x);
}
int getsum(int x){
int ans = 0;
for(;x>0;x-=lbt(x)){
ans += fw[x];
}
return ans;
}
void change(int x, int dv){
for(;x<=K;x+=lbt(x)){
fw[x] += dv;
}
}
//
int sz[MAXN];
bool vis[MAXN];
int rt;
void dfs_rt(int u, int fa, int tot){//O(tot)求根
sz[u] = 1;
int v, n = 0;
for(int k=0;k<adj[u].size();k++){
v = adj[u][k].v;
if(v==fa || vis[v]) continue;
dfs_rt(v, u, tot);
sz[u] += sz[v];
n = max(n, sz[v]);
}
n = max(n, tot-sz[u]);
if(n*2 <= tot) rt = u;
}
void dfs_sz(int u, int fa){//O(tot)求子树
sz[u] = 1;
int v, n = 0;
for(int k=0;k<adj[u].size();k++){
v = adj[u][k].v;
if(v==fa || vis[v]) continue;
dfs_sz(v, u);
sz[u] += sz[v];
}
}
int d[MAXN], cnt = 0;
void dfs_dis(int u, int fa, int dis){//O(tot)
d[++cnt] = dis;//记录距离
int v,w;
for(int k=0;k<adj[u].size();k++){
v = adj[u][k].v;
w = adj[u][k].w;
if(v==fa || vis[v]) continue;
dfs_dis(v, u, dis + w);
}
}
void dfs_clear(int u, int fa, int dis){//O(tot)清零
if(dis) change(dis, -1);
int v,w;
for(int k=0;k<adj[u].size();k++){
v = adj[u][k].v;
w = adj[u][k].w;
if(v==fa || vis[v]) continue;
dfs_clear(v, u, dis + w);
}
}
int work(int u, int tot){
dfs_rt(u, 0, tot);
u = rt;
dfs_sz(u, 0);
vis[u] = 1;
int v,w;
//solve
/*
求出每个点到根节点的距离,然后记录到数组d
然后对于d[i],查询0~K-d[i]区间大小,用树状数组实现
如果d[i]<=K,那么d[i]本身也符合条件
将数组d插入到树状数组
当以一个点为根的一轮结束时记得情况数组树状数组
*/
int ans = 0;
for(int k=0;k<adj[u].size();k++){
v = adj[u][k].v;
w = adj[u][k].w;
if(vis[v]) continue;
cnt = 0;
dfs_dis(v, u, w);
for(int i=1;i<=cnt;i++){
if(d[i] <= K) ++ans;
ans += getsum(max(0,K-d[i]));
}
for(int i=1;i<=cnt;i++){
change(d[i], +1);
}
}
dfs_clear(u,0,0);//手动清空
//
for(int k=0;k<adj[u].size();k++){
v = adj[u][k].v;
//w = adj[u][k].w;
if(vis[v]) continue;
ans += work(v,sz[v]);
}
return ans;
}
int main(){
cin>>N;
int u,v,w;
for(int i=1;i<N;i++){
cin>>u>>v>>w;
adj[u].push_back(edge(v,w));
adj[v].push_back(edge(u,w));
}
cin>>K;
cout<<work(1, N)<<endl;
return 0;
} 
京公网安备 11010502036488号