Description

在2016年,佳媛姐姐刚刚学习了树,非常开心。现在他想解决这样一个问题:给定一颗有根树(根为1),有以下

两种操作:1. 标记操作:对某个结点打上标记(在最开始,只有结点1有标记,其他结点均无标记,而且对于某个

结点,可以打多次标记。)2. 询问操作:询问某个结点最近的一个打了标记的祖先(这个结点本身也算自己的祖

先)你能帮帮他吗?

Input

输入第一行两个正整数N和Q分别表示节点个数和操作次数接下来N-1行,每行两个正整数u,v(1≤u,v≤n)表示u到v

有一条有向边接下来Q行,形如“oper num”oper为“C”时表示这是一个标记操作,oper为“Q”时表示这是一个询

问操作对于每次询问操作,1 ≤ N, Q ≤ 100000。

Output

输出一个正整数,表示结果

Sample Input

5 5
1 2
1 3
2 4
2 5
Q 2
C 2
Q 2
Q 5
Q 3

Sample Output

1
2
2
1

Solution

看到这题就没有脑子的写了个树剖...

挺显然的一个做法。树剖+线段树维护是否有被标记,然后询问的时候,每次跳链时看一下这条链有没有点被标记过,有就在这条链上二分(重链\(dfs\)序是连续的)。

复杂度是\(O(n \log n \log n)\)

然而这题有\(O(n)\)做法...

有一道经典的奶牛题,就是那道牛涂防晒霜的。这题其实就是那题上了树的版本。

考虑把操作离线了,统计出每个点被染色的次数,如果被标记了就父指针指向自己,否则指向父亲。

倒着处理操作,每次遇到染色就把对应点的次数\(-1\)就好,变成\(0\)了就把父指针指向父亲。

对于询问就直接\(find\)一遍就是答案了。

树剖的代码:

#include <bits/stdc++.h>

#define ll long long
#define inf 0x3f3f3f3f
#define il inline

namespace io {

    #define in(a) a=read()
    #define out(a) write(a)
    #define outn(a) out(a),putchar('\n')

    #define I_int long long
    inline I_int read() {
        I_int x = 0 , f = 1 ; char c = getchar() ;
        while( c < '0' || c > '9' ) { if( c == '-' ) f = -1 ; c = getchar() ; }
        while( c >= '0' && c <= '9' ) { x = x * 10 + c - '0' ; c = getchar() ; }
        return x * f ;
    }
    char F[ 200 ] ;
    inline void write( I_int x ) {
        if( x == 0 ) { putchar( '0' ) ; return ; }
        I_int tmp = x > 0 ? x : -x ;
        if( x < 0 ) putchar( '-' ) ;
        int cnt = 0 ;
        while( tmp > 0 ) {
            F[ cnt ++ ] = tmp % 10 + '0' ;
            tmp /= 10 ;
        }
        while( cnt > 0 ) putchar( F[ -- cnt ] ) ;
    }
    #undef I_int

}
using namespace io ;

using namespace std ;

#define N 100010

int n, m;
int top[N], fa[N], dep[N], siz[N], id[N], a[N];
int tim;

struct Tree {
    int l, r, sum;
}t[N<<2];

int head[N], cnt;
struct edge {
    int to, nxt;
}e[N<<1]; 

void ins(int u, int v) {
    e[++cnt] = (edge) {v, head[u]};
    head[u] = cnt;
}

void dfs1(int u) {
    siz[u] = 1;
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        siz[u] += siz[v];
    }
}

void dfs2(int u, int topf) {
    top[u] = topf;
    id[u] = ++tim;
    a[tim] = u;
    int k = 0;
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa[u]) continue;
        if(siz[v] > siz[k]) k = v;
    }
    if(!k) return;
    dfs2(k, topf);
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa[u] || v == k) continue;
        dfs2(v, v);
    }
}

#define lc (rt << 1)
#define rc (rt << 1 | 1)

void build(int l, int r, int rt) {
    t[rt].l = l; t[rt].r = r;
    int mid = (l + r) >> 1;
    if(l == r) return;
    build(l, mid, lc);
    build(mid + 1, r, rc);
}

#define l (t[rt].l)
#define r (t[rt].r)

void up(int rt) {
    t[rt].sum = t[lc].sum + t[rc].sum;
}

void upd(int pos, int rt) {
    if(l == r) {
        t[rt].sum = 1;
        return;
    }
    int mid = (l + r) >> 1;
    if(pos <= mid) upd(pos, lc);
    else upd(pos, rc);
    up(rt);
}

int query(int L, int R, int rt) {
    if(L <= l && r <= R) return t[rt].sum;
    int res = 0, mid = (l + r) >> 1;
    if(L <= mid) res += query(L, R, lc);
    if(R > mid) res += query(L, R, rc);
    return res;
}

#undef lc
#undef rc
#undef l
#undef r

int solve(int x) {
    while(top[x] != 1) {
//      printf("%d %d %d\n", x, top[x], fa[top[x]]);
        int sum = query(id[top[x]], id[x], 1);
        if(sum) {
            int l = id[top[x]], r = id[x];
            while(l + 1 < r) {
                int mid = (l + r) >> 1;
                sum = query(mid + 1, r, 1);
                if(sum) l = mid + 1;
                else r = mid;
            }
            if(query(r, r, 1)) return a[r];
            return a[l];
        }
        x = fa[top[x]];
    }
    int l = 1, r = id[x], sum = 0;
    while(l + 1 < r) {
        int mid = (l + r) >> 1;
        sum = query(mid + 1, r, 1);
        if(sum) l = mid + 1;
        else r = mid;
    }
    if(query(r, r, 1)) return a[r];
    return a[l];
}

int main() {
    in(n); in(m);
    for(int i = 1, u, v; i < n; ++i) {
        in(u), in(v);
        ins(u, v), ins(v, u);
    }
    
    dfs1(1);
    dfs2(1, 1);
    build(1, n, 1);
    upd(id[1], 1);
    
    for(int i = 1; i <= m; ++i) {
        char op[2];
        int x;
        scanf("%s", op);
        in(x);
        if(op[0] == 'C') upd(id[x], 1);
        else printf("%d\n", solve(x));
    }
    
    return 0;
}

并查集代码:

#include <bits/stdc++.h>

#define ll long long
#define inf 0x3f3f3f3f
#define il inline

namespace io {

    #define in(a) a=read()
    #define out(a) write(a)
    #define outn(a) out(a),putchar('\n')

    #define I_int long long
    inline I_int read() {
        I_int x = 0 , f = 1 ; char c = getchar() ;
        while( c < '0' || c > '9' ) { if( c == '-' ) f = -1 ; c = getchar() ; }
        while( c >= '0' && c <= '9' ) { x = x * 10 + c - '0' ; c = getchar() ; }
        return x * f ;
    }
    char F[ 200 ] ;
    inline void write( I_int x ) {
        if( x == 0 ) { putchar( '0' ) ; return ; }
        I_int tmp = x > 0 ? x : -x ;
        if( x < 0 ) putchar( '-' ) ;
        int cnt = 0 ;
        while( tmp > 0 ) {
            F[ cnt ++ ] = tmp % 10 + '0' ;
            tmp /= 10 ;
        }
        while( cnt > 0 ) putchar( F[ -- cnt ] ) ;
    }
    #undef I_int

}
using namespace io ;

using namespace std ;

#define N 100010

int n = read(), m = read();

int head[N], cnt;
struct edge {
    int to, nxt;
}e[N<<1];

int a[N], f[N], fa[N];
int op[N], c[N];
int ans[N], tot;

void ins(int u, int v) {
    e[++cnt] = (edge) {v, head[u]};
    head[u] = cnt;
}

void dfs(int u) {
    if(a[u]) f[u] = u;
    for(int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if(v == fa[u]) continue;
        if(!a[v]) f[v] = u;
        fa[v] = u;
        dfs(v);
    }
}

int find(int u) {
    if(f[u] == u) return u;
    return f[u] = find(f[u]);
}

int main() {
    for(int i = 1; i < n; ++i) {
        int u = read(), v = read();
        ins(u, v), ins(v, u);
    }
    char s[2];
    a[1] = 1;
    for(int i = 1; i <= m; ++i) {
        scanf("%s", s);
        in(c[i]);
        op[i] = s[0] == 'C';
        if(s[0] == 'C') ++a[c[i]];
    }
    dfs(1);
    for(int i = 1; i <= n; ++i) f[i] = find(f[i]);
    for(int i = m; i; --i) {
        if(op[i]) {
            --a[c[i]];
            if(!a[c[i]]) f[c[i]] = fa[c[i]];
        } else {
            ans[++tot] = find(f[c[i]]);
        }
    }
    reverse(ans + 1, ans + tot + 1);
    for(int i = 1; i <= tot; ++i) printf("%d\n", ans[i]);
}