思路

正难则反

我们定义一条路径是 “好路径”,当且仅当满足两个条件:

  1. 条件 A:路径上点权的最小值 ≤ a
  2. 条件 B:路径上点权的最大值 ≥ b

直接求同时满足 A 和 B 的路径数比较困难。我们可以通过求它的反面来解决。

根据容斥原理:

满足 A 且满足 B 的路径数 = 总路径数 - 不满足 A 的路径数 - 不满足 B 的路径数 + 同时不满足 A 和 B 的路径数

我们来翻译一下这三个排除条件:

  • 不满足 A:路径上点权的最小值 > a,也就是说路径上的所有点权都 > a。
  • 不满足 B:路径上点权的最大值 < b,也就是说路径上的所有点权都 < b。
  • 同时不满足 A 和 B:路径上的所有点权都满足 a < wᵢ < b。


如何计算满足某种条件(例如所有点权 > a)的路径数?

这时候并查集就派上用场了:

  1. 我们只保留点权 > a 的节点,将原本无向图中的边(如果两端点权都 > a)用并查集连起来。
  2. 此时树会被分成若干个连通块。对于一个包含 k 个节点的连通块,内部能构成的路径数量就是 k*(k+1)/2。
  3. 将所有连通块的路径数求和即可。


对于 “所有点权 < b” 和 “所有点权在 (a,b) 之间”,我们可以复用完全相同的逻辑。


code

// 并查集
int fa[M];
int sz[M]; // sz 数组维护连通块的大小
void init()
{
    for(int i = 1; i <= n; i++)
    {
        fa[i] = i;
        sz[i] = 1;
    }
}
int find(int a)
{
    return a == fa[a] ? a : fa[a] = find(fa[a]);
}
bool same(int a, int b)
{
    return find(a) == find(b);
}
void join(int a, int b)
{
    a = find(a);
    b = find(b);
    if(a != b)
    {
        fa[a] = b;
        sz[b] += sz[a]; // ★合并时累加连通块大小
    }
}
//统计所有点权在 [l,r] 闭区间内的路径总数
int jisuan(int l, int r, vi &w, vector<pii> &edge)
{
    init();// 初始化当前 n 个节点的并查集
    for (auto i : edge)
    {
        int u = i.first;
        int v = i.second;
		// 只有当边的两端点权都在允许的范围内时,才将它们连通
        if (w[u] >= l && w[u] <= r && w[v] >= l && w[v] <= r)
        {
            join(u, v);
        }
    }
    int res = 0;
    for (int i = 1; i <= n; i++)
    {
        if (w[i] >= l && w[i] <= r && fa[i] == i)
        {
			// 如果点 i 的权值满足条件,并且它是所在连通块的代表元
            int num = sz[i];
            res += num * (num + 1) / 2;
        }
    }
    return res;
}
void solve()
{
    int a, b;
    cin >> n >> a >> b;
    vi w(n + 1);
    for (int i = 1; i <= n; i++)
    {
        cin >> w[i];
    }
    vector<pii> edge(n - 1);
    for (int i = 0; i < n - 1; i++)
    {
        cin >> edge[i].first >> edge[i].second;
    }
    int sum = n * (n + 1) / 2; // 所有可能的路径总数
	 
    int ra = jisuan(a + 1, 2e18, w, edge);
	// 不满足A条件:所有点权 > a (导致最小值不会 <= a)
    // 权值范围:[a + 1, 2e18]
    int rb = jisuan(-2e18, b - 1, w, edge);
	// 不满足B条件:所有点权 < b (导致最大值不会 >= b)
    // 权值范围:[-2e18, b - 1]
    int rab = jisuan(a + 1, b - 1, w, edge);
	// 同时不满足A和B:所有点权 > a 且 < b
    // 权值范围:[a + 1, b - 1]
    int ans = sum - ra - rb + rab;
	// 容斥原理:
    // 满足 A 和 B = 总路径数 - 不满足A - 不满足B + (A和B都不满足)
    cout << ans << endl;
}