图片说明
这个题目首先思路是参考书上的,就是一个树上差分的思想,通过观察题目可以发现,一定要先砍主要边,然后才能砍附加边,这样的话,我们可以发现,在一棵生成树里面,如果加上一条附加边的话,会形成一个环,如果首先砍这个环上面的主要边,那么第二次一定要砍这条附加边,才能把树分成两半,这是唯一的,但是如果你砍的不是在某一个环里面的主要边的话,那么你一次就已经把树分成了两半,第二次就可以砍任意的附加边即可。所以这就是我们的答案。然后,我们观察可以发现,我们可以这样,我们初始化每一个节点的权值为0,在添加附加边的时候,附加边的两个端点的权值都加上1,然后这两个节点的最近公共祖先-2。这样操作下来的话,我们最后用一个树上DP的方法求出以节点i为根的子树中各节点的权值和,我们假设为f[i],如果f[i]为0,那么答案加上m,如果为1,答案累加1,最后就可以求出我们的答案了。
我们可以发现,主要的难点就是求出附件边的两个端点的最近公共祖先,我本来是用tarjan算法求的,但是莫名会出现错误,所以还是转回用树上倍增的方法进行求解了。搞不懂。其实用tarjan算法的话时间复杂度更低的。

#include<iostream>
#include<algorithm>
#include<cstring>
#include<set>
#include<map>
#include<vector>
#include<queue>
#include<utility>
#include<cmath>
#include<stack>
#include<fstream>
#define pii pair<int,int>
#define mk make_pair
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
template <typename T> void read(T &t) {
    t=0; char ch=getchar(); int f=1;
    while (ch<'0'||ch>'9') { if (ch=='-') f=-1; ch=getchar(); }
    do { (t*=10)+=ch-'0'; ch=getchar(); } while ('0'<=ch&&ch<='9'); t*=f;
}
ll gcd(ll x1,ll y1){
    if(x1==0){
        return y1;
    }
    else{
        return gcd(y1, x1 % y1);
    }
}

ll qk(ll x1,ll y1){
    if(y1==0)
        return 1;
    if(y1&1)
        return x1 * qk(x1, y1 - 1) % mod;
    else{
        ll mul = qk(x1, y1 / 2);
        return mul * mul % mod;
    }
}
const int Size = 100010;
int ver[2 * Size], nxt[2 * Size], head[Size];
int fa[Size], v[Size], lc[Size], ans[Size], num[2 * Size];
int f[Size][20], d[Size];
int n, m, tot, t;
queue<int> q;
void add(int x, int y)
{
    ver[++tot] = y, nxt[tot] = head[x], head[x] = tot;
}
void bfs(){
    q.push(1);
    d[1] = 1;
    while(q.size()){
        int x = q.front();
        q.pop();
        for (int i = head[x]; i;i=nxt[i]){
            int y = ver[i];
            if(d[y])
                continue;
            d[y] = d[x] + 1;
            f[y][0] = x;
            for(int j = 1; j <= t;j++){
                f[y][j] = f[f[y][j - 1]][j - 1];
            }
            q.push(y);
        }
    }
}
int lca(int x,int y){
    if(d[x]>d[y])
        swap(x, y);
    for (int i = t; i >= 0;i--){
        if(d[f[y][i]]>=d[x])
            y = f[y][i];
    }
    if(x==y)
        return x;
    for (int i = t; i >= 0;i--){
        if(f[x][i]!=f[y][i])
            x = f[x][i], y = f[y][i];
    }
    return f[x][0];
}
void dp(int x){
    v[x] = 1;
    for (int i = head[x]; i;i=nxt[i]){
        int y = ver[i];
        if(v[y])
            continue;
        dp(y);
        num[x] += num[y];
    }
}
void solve()
{
    cin >> n >> m;
    t = (int)(log(n) / log(2)) + 1;
    for (int i = 0; i < n - 1; i++)
    {
        int a, b;
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }
    bfs();
    tot = 0;
    for (int i = 0; i < m; i++)
    {
        int x, y;
        cin >> x >> y;
        num[x]++, num[y]++;
        num[lca(x, y)] -= 2;
    }
    int ans = 0;
    dp(1);
    for (int i = 2; i <= n;i++){
        if(num[i]==1)
            ans++;
        else if (num[i] == 0)
        {
            ans += m;
        }
    }
    cout << ans << endl;
}
int main()
{
    //freopen("test.in","r",stdin);
    int t;
    //cin >> t;
    t = 1;
    while (t--)
    {
        solve();
    }
    fclose(stdin);
    return 0;
}