Description
在2016年,佳媛姐姐刚刚学习了树,非常开心。现在他想解决这样一个问题:给定一颗有根树(根为1),有以下
两种操作:1. 标记操作:对某个结点打上标记(在最开始,只有结点1有标记,其他结点均无标记,而且对于某个
结点,可以打多次标记。)2. 询问操作:询问某个结点最近的一个打了标记的祖先(这个结点本身也算自己的祖
先)你能帮帮他吗?
Input
输入第一行两个正整数N和Q分别表示节点个数和操作次数接下来N-1行,每行两个正整数u,v(1≤u,v≤n)表示u到v
有一条有向边接下来Q行,形如“oper num”oper为“C”时表示这是一个标记操作,oper为“Q”时表示这是一个询
问操作对于每次询问操作,1 ≤ N, Q ≤ 100000。
Output
输出一个正整数,表示结果
Sample Input
5 5
1 2
1 3
2 4
2 5
Q 2
C 2
Q 2
Q 5
Q 3
Sample Output
1
2
2
1
Solution
看到这题就没有脑子的写了个树剖...
挺显然的一个做法。树剖+线段树维护是否有被标记,然后询问的时候,每次跳链时看一下这条链有没有点被标记过,有就在这条链上二分(重链\(dfs\)序是连续的)。
复杂度是\(O(n \log n \log n)\)
然而这题有\(O(n)\)做法...
有一道经典的奶牛题,就是那道牛涂防晒霜的。这题其实就是那题上了树的版本。
考虑把操作离线了,统计出每个点被染色的次数,如果被标记了就父指针指向自己,否则指向父亲。
倒着处理操作,每次遇到染色就把对应点的次数\(-1\)就好,变成\(0\)了就把父指针指向父亲。
对于询问就直接\(find\)一遍就是答案了。
树剖的代码:
#include <bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define il inline
namespace io {
#define in(a) a=read()
#define out(a) write(a)
#define outn(a) out(a),putchar('\n')
#define I_int long long
inline I_int read() {
I_int x = 0 , f = 1 ; char c = getchar() ;
while( c < '0' || c > '9' ) { if( c == '-' ) f = -1 ; c = getchar() ; }
while( c >= '0' && c <= '9' ) { x = x * 10 + c - '0' ; c = getchar() ; }
return x * f ;
}
char F[ 200 ] ;
inline void write( I_int x ) {
if( x == 0 ) { putchar( '0' ) ; return ; }
I_int tmp = x > 0 ? x : -x ;
if( x < 0 ) putchar( '-' ) ;
int cnt = 0 ;
while( tmp > 0 ) {
F[ cnt ++ ] = tmp % 10 + '0' ;
tmp /= 10 ;
}
while( cnt > 0 ) putchar( F[ -- cnt ] ) ;
}
#undef I_int
}
using namespace io ;
using namespace std ;
#define N 100010
int n, m;
int top[N], fa[N], dep[N], siz[N], id[N], a[N];
int tim;
struct Tree {
int l, r, sum;
}t[N<<2];
int head[N], cnt;
struct edge {
int to, nxt;
}e[N<<1];
void ins(int u, int v) {
e[++cnt] = (edge) {v, head[u]};
head[u] = cnt;
}
void dfs1(int u) {
siz[u] = 1;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa[u]) continue;
fa[v] = u;
dep[v] = dep[u] + 1;
dfs1(v);
siz[u] += siz[v];
}
}
void dfs2(int u, int topf) {
top[u] = topf;
id[u] = ++tim;
a[tim] = u;
int k = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa[u]) continue;
if(siz[v] > siz[k]) k = v;
}
if(!k) return;
dfs2(k, topf);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa[u] || v == k) continue;
dfs2(v, v);
}
}
#define lc (rt << 1)
#define rc (rt << 1 | 1)
void build(int l, int r, int rt) {
t[rt].l = l; t[rt].r = r;
int mid = (l + r) >> 1;
if(l == r) return;
build(l, mid, lc);
build(mid + 1, r, rc);
}
#define l (t[rt].l)
#define r (t[rt].r)
void up(int rt) {
t[rt].sum = t[lc].sum + t[rc].sum;
}
void upd(int pos, int rt) {
if(l == r) {
t[rt].sum = 1;
return;
}
int mid = (l + r) >> 1;
if(pos <= mid) upd(pos, lc);
else upd(pos, rc);
up(rt);
}
int query(int L, int R, int rt) {
if(L <= l && r <= R) return t[rt].sum;
int res = 0, mid = (l + r) >> 1;
if(L <= mid) res += query(L, R, lc);
if(R > mid) res += query(L, R, rc);
return res;
}
#undef lc
#undef rc
#undef l
#undef r
int solve(int x) {
while(top[x] != 1) {
// printf("%d %d %d\n", x, top[x], fa[top[x]]);
int sum = query(id[top[x]], id[x], 1);
if(sum) {
int l = id[top[x]], r = id[x];
while(l + 1 < r) {
int mid = (l + r) >> 1;
sum = query(mid + 1, r, 1);
if(sum) l = mid + 1;
else r = mid;
}
if(query(r, r, 1)) return a[r];
return a[l];
}
x = fa[top[x]];
}
int l = 1, r = id[x], sum = 0;
while(l + 1 < r) {
int mid = (l + r) >> 1;
sum = query(mid + 1, r, 1);
if(sum) l = mid + 1;
else r = mid;
}
if(query(r, r, 1)) return a[r];
return a[l];
}
int main() {
in(n); in(m);
for(int i = 1, u, v; i < n; ++i) {
in(u), in(v);
ins(u, v), ins(v, u);
}
dfs1(1);
dfs2(1, 1);
build(1, n, 1);
upd(id[1], 1);
for(int i = 1; i <= m; ++i) {
char op[2];
int x;
scanf("%s", op);
in(x);
if(op[0] == 'C') upd(id[x], 1);
else printf("%d\n", solve(x));
}
return 0;
}
并查集代码:
#include <bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define il inline
namespace io {
#define in(a) a=read()
#define out(a) write(a)
#define outn(a) out(a),putchar('\n')
#define I_int long long
inline I_int read() {
I_int x = 0 , f = 1 ; char c = getchar() ;
while( c < '0' || c > '9' ) { if( c == '-' ) f = -1 ; c = getchar() ; }
while( c >= '0' && c <= '9' ) { x = x * 10 + c - '0' ; c = getchar() ; }
return x * f ;
}
char F[ 200 ] ;
inline void write( I_int x ) {
if( x == 0 ) { putchar( '0' ) ; return ; }
I_int tmp = x > 0 ? x : -x ;
if( x < 0 ) putchar( '-' ) ;
int cnt = 0 ;
while( tmp > 0 ) {
F[ cnt ++ ] = tmp % 10 + '0' ;
tmp /= 10 ;
}
while( cnt > 0 ) putchar( F[ -- cnt ] ) ;
}
#undef I_int
}
using namespace io ;
using namespace std ;
#define N 100010
int n = read(), m = read();
int head[N], cnt;
struct edge {
int to, nxt;
}e[N<<1];
int a[N], f[N], fa[N];
int op[N], c[N];
int ans[N], tot;
void ins(int u, int v) {
e[++cnt] = (edge) {v, head[u]};
head[u] = cnt;
}
void dfs(int u) {
if(a[u]) f[u] = u;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa[u]) continue;
if(!a[v]) f[v] = u;
fa[v] = u;
dfs(v);
}
}
int find(int u) {
if(f[u] == u) return u;
return f[u] = find(f[u]);
}
int main() {
for(int i = 1; i < n; ++i) {
int u = read(), v = read();
ins(u, v), ins(v, u);
}
char s[2];
a[1] = 1;
for(int i = 1; i <= m; ++i) {
scanf("%s", s);
in(c[i]);
op[i] = s[0] == 'C';
if(s[0] == 'C') ++a[c[i]];
}
dfs(1);
for(int i = 1; i <= n; ++i) f[i] = find(f[i]);
for(int i = m; i; --i) {
if(op[i]) {
--a[c[i]];
if(!a[c[i]]) f[c[i]] = fa[c[i]];
} else {
ans[++tot] = find(f[c[i]]);
}
}
reverse(ans + 1, ans + tot + 1);
for(int i = 1; i <= tot; ++i) printf("%d\n", ans[i]);
}