为什么会有人想到这种 数据啊!!!

一眼过去,树上联通子集数量,上点分治,对于当前重心算经过它的数量,O(nlogn)O(n\log n) 非常优秀。发现还要断边分别算左右两个联通块的答案,然后我开始想怎么在点分治的时候乱搞出来,于是成功 yy 没出来。

之后突然发现,算答案好像有点就可以转移,于是就发现可以树形 DP。

fuf_u 树上以 uu 为根的子树的答案,也就是断开 uu 到它父亲的边后 uu 这边的子树答案。

gug_u 树上以 uu 为根的子树内经过 uu 的答案。

不难发现 DP 转移是:

gu=vsubtree(u)(gv+1)fu=vsubtree(u)fv+gug_u = \sum_{v \in subtree(u)}(g_v + 1) \\ f_u = \sum_{v \in subtree(u)} f_v + g_u

rootroot 为根进行一次树形 DP 即可求出所有节点断掉父边的答案。

我就分析啊,怎样才能求出所有边的答案,于是我构造了这个猜想:

alt

alt

SSTT 是树上直径的两个端点,如果分别对 SSTT 进行一次树形 DP,感觉好像可以覆盖所有的边。于是被样例 hack 了。。。事实上直径没有卵有,任意两个度数为一的点都是如此。

被卡了好一会,如果有一个答案未被算到就对另一个点进行一次树形 DP,正确性显然,但时间复杂度未知,但总感觉上界不会很大。

然后就突发奇想,如果对每一个根都树形 DP 一次不就知道所有的答案了吗。自然而然就会想到换根 DP,而这个树形 DP 恰好也只和儿子有关,换根式子也比较好推,于是这道题就完了。

u(root)vfu=fufvgugu=gu÷(gv+1)fu=fu+gufv=fvgvgv=gv×(gu+1)fv=fv+gvu(root) \Rightarrow v \\ f_u = f_u - f_v - g_u \\ g_u = g_u \div (g_v + 1) \\ f_u = f_u + g_u \\ f_v = f_v - g_v \\ g_v = g_v \times (g_u + 1)\\ f_v = f_v + g_v

真的完了,发现 wa 了。

直到赛后才知道有这种数据!!!

gv+1g_v + 1 可能等于 998244353998244353, 这样换根时就没有逆元了,题解给出了什么前缀积后缀积的玩意,看不懂,其实不用这么麻烦,只需要记录非那种儿子的积,再记录那种特殊儿子的数量,这样就可以转移了。

code:

#include <bits/stdc++.h>

using namespace std;
using i64 = long long;

i64 read() {
	i64 x = 0, f = 0; char ch = getchar();
	while (!isdigit(ch)) f |= (ch == '-'), ch = getchar();
	while (isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
	return f ? -x : x;
}
int __stk[100], __tp;
void put(i64 x) {
	if (x < 0) putchar('-'), x = -x;
	do { __stk[++__tp] = x % 10, x /= 10; } while (x);
	while (__tp) putchar(__stk[__tp--] + '0');
}
const int N = 5e5 + 10, mod = 998244353;
int n, num;
int head[N], nxt[N << 1], to[N << 1], cnt = 1;
struct edge {
	int u, v, l, r;
}e[N];
i64 f[N], g[N], c[N], dfn[N << 1];
bool vis[N];
void add(int u, int v) {
	to[++cnt] = v, nxt[cnt] = head[u], head[u] = cnt;
	to[++cnt] = u, nxt[cnt] = head[v], head[v] = cnt;
}
i64 Pow(i64 x, i64 y) {
	i64 ans = 1;
	for ( ; y; y >>= 1, x = x * x % mod) 
		y & 1 ? ans = ans * x % mod : 0ll;
	return ans;
}
void dfs(int u, int fa) {
	dfn[++num] = u, f[u] = c[u] = 0, g[u] = 1;
	for (int i = head[u], v = 0; i; i = nxt[i]) {
		if ((v = to[i])== fa) continue;
		dfs(v, u), dfn[++num] = u;
		f[u] = (f[u] + f[v]) % mod;
		if (g[v] + 1 == mod) ++c[u];
		else g[u] = g[u] * (g[v] + 1) % mod;
	}
	f[u] = (f[u] + (c[u] ? 0 : g[u])) % mod;
}
void calc(int x) {
	if (vis[x]) return;
	vis[x] = 1;
	for (int i = head[x]; i; i = nxt[i]) {
		int v = to[i];
		if (e[i >> 1].u == v) e[i >> 1].l = f[v];
		if (e[i >> 1].v == v) e[i >> 1].r = f[v];
	}
}
void move(int u, int v) {
	f[u] = (f[u] - f[v] - (c[u] ? 0 : g[u]) + mod + mod) % mod;
	if (g[v] + 1 == mod) --c[u];
	else g[u] = g[u] * Pow(g[v] + 1, mod - 2) % mod;
	f[u] = (f[u] + (c[u] ? 0 : g[u])) % mod;
	f[v] = (f[v] - (c[v] ? 0 : g[v]) + mod) % mod;
	if (g[u] + 1 == mod) ++c[v];
	else g[v] = g[v] * (g[u] + 1) % mod;
	f[v] = (f[v] + f[u] + (c[v] ? 0 : g[v])) % mod;
}

int main() {
	//freopen(".in", "r", stdin);
	//freopen(".out", "w", stdout);
	n = read();
	for (int i = 1; i < n; ++i) e[i].u = read(), e[i].v = read(), add(e[i].u, e[i].v);
	dfs(1, 0), calc(1);
	for (int i = 2; i <= num; ++i) move(dfn[i - 1], dfn[i]), calc(dfn[i]);
	for (int i = 1; i < n; ++i) put(e[i].l), putchar(' '), put(e[i].r), putchar('\n');
	return 0;
}