这个题目首先思路是参考书上的,就是一个树上差分的思想,通过观察题目可以发现,一定要先砍主要边,然后才能砍附加边,这样的话,我们可以发现,在一棵生成树里面,如果加上一条附加边的话,会形成一个环,如果首先砍这个环上面的主要边,那么第二次一定要砍这条附加边,才能把树分成两半,这是唯一的,但是如果你砍的不是在某一个环里面的主要边的话,那么你一次就已经把树分成了两半,第二次就可以砍任意的附加边即可。所以这就是我们的答案。然后,我们观察可以发现,我们可以这样,我们初始化每一个节点的权值为0,在添加附加边的时候,附加边的两个端点的权值都加上1,然后这两个节点的最近公共祖先-2。这样操作下来的话,我们最后用一个树上DP的方法求出以节点i为根的子树中各节点的权值和,我们假设为f[i],如果f[i]为0,那么答案加上m,如果为1,答案累加1,最后就可以求出我们的答案了。
我们可以发现,主要的难点就是求出附件边的两个端点的最近公共祖先,我本来是用tarjan算法求的,但是莫名会出现错误,所以还是转回用树上倍增的方法进行求解了。搞不懂。其实用tarjan算法的话时间复杂度更低的。
#include<iostream> #include<algorithm> #include<cstring> #include<set> #include<map> #include<vector> #include<queue> #include<utility> #include<cmath> #include<stack> #include<fstream> #define pii pair<int,int> #define mk make_pair using namespace std; typedef long long ll; typedef unsigned long long ull; const int inf=0x3f3f3f3f; const int mod=1e9+7; template <typename T> void read(T &t) { t=0; char ch=getchar(); int f=1; while (ch<'0'||ch>'9') { if (ch=='-') f=-1; ch=getchar(); } do { (t*=10)+=ch-'0'; ch=getchar(); } while ('0'<=ch&&ch<='9'); t*=f; } ll gcd(ll x1,ll y1){ if(x1==0){ return y1; } else{ return gcd(y1, x1 % y1); } } ll qk(ll x1,ll y1){ if(y1==0) return 1; if(y1&1) return x1 * qk(x1, y1 - 1) % mod; else{ ll mul = qk(x1, y1 / 2); return mul * mul % mod; } } const int Size = 100010; int ver[2 * Size], nxt[2 * Size], head[Size]; int fa[Size], v[Size], lc[Size], ans[Size], num[2 * Size]; int f[Size][20], d[Size]; int n, m, tot, t; queue<int> q; void add(int x, int y) { ver[++tot] = y, nxt[tot] = head[x], head[x] = tot; } void bfs(){ q.push(1); d[1] = 1; while(q.size()){ int x = q.front(); q.pop(); for (int i = head[x]; i;i=nxt[i]){ int y = ver[i]; if(d[y]) continue; d[y] = d[x] + 1; f[y][0] = x; for(int j = 1; j <= t;j++){ f[y][j] = f[f[y][j - 1]][j - 1]; } q.push(y); } } } int lca(int x,int y){ if(d[x]>d[y]) swap(x, y); for (int i = t; i >= 0;i--){ if(d[f[y][i]]>=d[x]) y = f[y][i]; } if(x==y) return x; for (int i = t; i >= 0;i--){ if(f[x][i]!=f[y][i]) x = f[x][i], y = f[y][i]; } return f[x][0]; } void dp(int x){ v[x] = 1; for (int i = head[x]; i;i=nxt[i]){ int y = ver[i]; if(v[y]) continue; dp(y); num[x] += num[y]; } } void solve() { cin >> n >> m; t = (int)(log(n) / log(2)) + 1; for (int i = 0; i < n - 1; i++) { int a, b; cin >> a >> b; add(a, b); add(b, a); } bfs(); tot = 0; for (int i = 0; i < m; i++) { int x, y; cin >> x >> y; num[x]++, num[y]++; num[lca(x, y)] -= 2; } int ans = 0; dp(1); for (int i = 2; i <= n;i++){ if(num[i]==1) ans++; else if (num[i] == 0) { ans += m; } } cout << ans << endl; } int main() { //freopen("test.in","r",stdin); int t; //cin >> t; t = 1; while (t--) { solve(); } fclose(stdin); return 0; }