这个题目首先思路是参考书上的,就是一个树上差分的思想,通过观察题目可以发现,一定要先砍主要边,然后才能砍附加边,这样的话,我们可以发现,在一棵生成树里面,如果加上一条附加边的话,会形成一个环,如果首先砍这个环上面的主要边,那么第二次一定要砍这条附加边,才能把树分成两半,这是唯一的,但是如果你砍的不是在某一个环里面的主要边的话,那么你一次就已经把树分成了两半,第二次就可以砍任意的附加边即可。所以这就是我们的答案。然后,我们观察可以发现,我们可以这样,我们初始化每一个节点的权值为0,在添加附加边的时候,附加边的两个端点的权值都加上1,然后这两个节点的最近公共祖先-2。这样操作下来的话,我们最后用一个树上DP的方法求出以节点i为根的子树中各节点的权值和,我们假设为f[i],如果f[i]为0,那么答案加上m,如果为1,答案累加1,最后就可以求出我们的答案了。
我们可以发现,主要的难点就是求出附件边的两个端点的最近公共祖先,我本来是用tarjan算法求的,但是莫名会出现错误,所以还是转回用树上倍增的方法进行求解了。搞不懂。其实用tarjan算法的话时间复杂度更低的。
#include<iostream>
#include<algorithm>
#include<cstring>
#include<set>
#include<map>
#include<vector>
#include<queue>
#include<utility>
#include<cmath>
#include<stack>
#include<fstream>
#define pii pair<int,int>
#define mk make_pair
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
template <typename T> void read(T &t) {
t=0; char ch=getchar(); int f=1;
while (ch<'0'||ch>'9') { if (ch=='-') f=-1; ch=getchar(); }
do { (t*=10)+=ch-'0'; ch=getchar(); } while ('0'<=ch&&ch<='9'); t*=f;
}
ll gcd(ll x1,ll y1){
if(x1==0){
return y1;
}
else{
return gcd(y1, x1 % y1);
}
}
ll qk(ll x1,ll y1){
if(y1==0)
return 1;
if(y1&1)
return x1 * qk(x1, y1 - 1) % mod;
else{
ll mul = qk(x1, y1 / 2);
return mul * mul % mod;
}
}
const int Size = 100010;
int ver[2 * Size], nxt[2 * Size], head[Size];
int fa[Size], v[Size], lc[Size], ans[Size], num[2 * Size];
int f[Size][20], d[Size];
int n, m, tot, t;
queue<int> q;
void add(int x, int y)
{
ver[++tot] = y, nxt[tot] = head[x], head[x] = tot;
}
void bfs(){
q.push(1);
d[1] = 1;
while(q.size()){
int x = q.front();
q.pop();
for (int i = head[x]; i;i=nxt[i]){
int y = ver[i];
if(d[y])
continue;
d[y] = d[x] + 1;
f[y][0] = x;
for(int j = 1; j <= t;j++){
f[y][j] = f[f[y][j - 1]][j - 1];
}
q.push(y);
}
}
}
int lca(int x,int y){
if(d[x]>d[y])
swap(x, y);
for (int i = t; i >= 0;i--){
if(d[f[y][i]]>=d[x])
y = f[y][i];
}
if(x==y)
return x;
for (int i = t; i >= 0;i--){
if(f[x][i]!=f[y][i])
x = f[x][i], y = f[y][i];
}
return f[x][0];
}
void dp(int x){
v[x] = 1;
for (int i = head[x]; i;i=nxt[i]){
int y = ver[i];
if(v[y])
continue;
dp(y);
num[x] += num[y];
}
}
void solve()
{
cin >> n >> m;
t = (int)(log(n) / log(2)) + 1;
for (int i = 0; i < n - 1; i++)
{
int a, b;
cin >> a >> b;
add(a, b);
add(b, a);
}
bfs();
tot = 0;
for (int i = 0; i < m; i++)
{
int x, y;
cin >> x >> y;
num[x]++, num[y]++;
num[lca(x, y)] -= 2;
}
int ans = 0;
dp(1);
for (int i = 2; i <= n;i++){
if(num[i]==1)
ans++;
else if (num[i] == 0)
{
ans += m;
}
}
cout << ans << endl;
}
int main()
{
//freopen("test.in","r",stdin);
int t;
//cin >> t;
t = 1;
while (t--)
{
solve();
}
fclose(stdin);
return 0;
}
京公网安备 11010502036488号