思路
正难则反
我们定义一条路径是 “好路径”,当且仅当满足两个条件:
- 条件 A:路径上点权的最小值 ≤ a
- 条件 B:路径上点权的最大值 ≥ b
直接求同时满足 A 和 B 的路径数比较困难。我们可以通过求它的反面来解决。
根据容斥原理:
满足 A 且满足 B 的路径数 = 总路径数 - 不满足 A 的路径数 - 不满足 B 的路径数 + 同时不满足 A 和 B 的路径数
我们来翻译一下这三个排除条件:
- 不满足 A:路径上点权的最小值 > a,也就是说路径上的所有点权都 > a。
- 不满足 B:路径上点权的最大值 < b,也就是说路径上的所有点权都 < b。
- 同时不满足 A 和 B:路径上的所有点权都满足 a < wᵢ < b。
如何计算满足某种条件(例如所有点权 > a)的路径数?
这时候并查集就派上用场了:
- 我们只保留点权 > a 的节点,将原本无向图中的边(如果两端点权都 > a)用并查集连起来。
- 此时树会被分成若干个连通块。对于一个包含 k 个节点的连通块,内部能构成的路径数量就是k*(k+1)/2。
- 将所有连通块的路径数求和即可。
对于 “所有点权 < b” 和 “所有点权在 (a,b) 之间”,我们可以复用完全相同的逻辑。
// 并查集
int fa[M];
int sz[M];
void init()
{
for (int i = 1; i <= n; i++)
{
fa[i] = i;
sz[i] = 1;
}
}
int find(int a)
{
return a == fa[a] ? a : fa[a] = find(fa[a]);
}
bool same(int a, int b)
{
return find(a) == find(b);
}
void join(int a, int b)
{
a = find(a);
b = find(b);
if (a != b)
{
fa[a] = b;
sz[b] += sz[a]; // 合并时累加连通块大小
}
}
// 统计所有点权在 [l, r] 闭区间内的路径总数
int jisuan(int l, int r, vi &w, vector<pii> &edge)
{
init();// 初始化当前 n 个节点的并查集
for (auto i : edge)
{
int u = i.first;
int v = i.second;
// 只有当边的两端点权都在允许的范围内时,才将它们连通
if (w[u] >= l && w[u] <= r && w[v] >= l && w[v] <= r)
{
join(u, v);
}
}
int res = 0;
for (int i = 1; i <= n; i++)
{
// 如果点 i 的权值满足条件,并且它是所在连通块的代表元
if (w[i] >= l && w[i] <= r && fa[i] == i)
{
int num = sz[i];
res += num * (num + 1) / 2;
}
}
return res;
}
void solve()
{
int a, b;
cin >> n >> a >> b;
vi w(n + 1);
for (int i = 1; i <= n; i++)
{
cin >> w[i];
}
vector<pii> edge(n - 1);
for (int i = 0; i < n - 1; i++)
{
cin >> edge[i].first >> edge[i].second;
}
int sum = n * (n + 1) / 2; // 所有可能的路径总数
int ra = jisuan(a + 1, 2e18, w, edge);
// 不满足A条件:所有点权 > a (导致最小值不会 <= a)
// 权值范围:[a + 1, 2e18]
int rb = jisuan(-2e18, b - 1, w, edge);
// 不满足B条件:所有点权 < b (导致最大值不会 >= b)
// 权值范围:[-2e18, b - 1]
int rab = jisuan(a + 1, b - 1, w, edge);
// 同时不满足A和B:所有点权 > a 且 < b
// 权值范围:[a + 1, b - 1]
int ans = sum - ra - rb + rab;
// 容斥原理:
// 满足 A 和 B = 总路径数 - 不满足A - 不满足B + (A和B都不满足)
cout << ans << endl;
}
火车头
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
using ll = long long;
#define int long long
#define vi vector<int>
#define vs vector<string>
#define vvi vector<vector<int>>
#define vb vector<bool>
#define pii pair<int, int>
#define all(a) a.begin(), a.end()
#define ull unsigned long long
#define pb push_back
const int M = 5e5 + 7;
const int mod = 998244353;
const int MOD = 1e9 + 7;
const int INF = 1e9;
const double pi = acos(-1.0);
int inf = -1e18;
int dir[4][2] = {{0, 1}, {1, 0}, {-1, 0}, {0, -1}};
int n, m, k = 0, x, y, l, r;
string s;

京公网安备 11010502036488号