·题意

·给定一颗n格结点的无向树,每经过一条边或第一次进入新的结点都会减少体力,在树上移动的过程中可以使用一次传送,消耗一点体力从i转移到a[i],输出从节点1开始在体力消耗<=x的情况下回到节点1所能通过的最多的结点数;

·分析

·不传送的情况

·首先我们考虑不用传送的情况,可发现我们的轨迹一定是如下的:

·因为我们最后一定要回到节点1,同时这个图又是没有环的,所以每到达一个新的节点,只能原路返回

·所以最后的轨迹也是一颗,其中的结点都只计算一次,树的边都是要正反走两次的,故若节点数为X,则会消耗 X + (X - 1) * 2 = 3 * X - 2 的体力;

·所以在 x的体力下,最多能取到 min(n, (x + 2) / 3) 个结点; (因为最多只有n个结点嘛)

·考虑传送的情况

·首先我们可以发现,我们传送走后需要考虑回到节点1的问题,所以我们使用每个节点的传送都有至少要用的体力;

·这个至少要用的体力示意图如下:

·假设使用了图中的蓝虚线传送,可发现一定需要通过,cost = 图中的红边 + 图中的紫边 + 节点数cur + 1(传送消耗).

·因此若x < cost, 则该传送必然不可能实现;

·但是, 若x>cost,还是有可能多走节点的;

·此时可发现,要多走一个结点,可以在已有的图上扩展,并且每扩展一个节点,必然多耗费三点体力;

·因为同不传送的情况,每个新的结点必须原路返回才能回到节点1

·所以能多扩展的节点数为(x - cost) / 3

·该情况结点数为 cur + (x - cost) / 3

// 可以发现不传送的情况本质是无传送费用,从1传送到1的传送的情况;

·代码实现

·不传送的情况直接计算即可

ans = min(n, (x + 2) / 3);

·对传送的情况

·我们发现若令节点1为深度为0的点

·则对使用节点i(传送至a[i])传送的:

cur = i的深度 + a[i] 的深度 - i 和 a[i] 最近公共祖先的深度 + 1; (深度也可以理解为从根到当前节点的链的大小除去根节点,所以最后需要加上根节点)

cost = (i的深度 + a[i] 的深度) * 2 - i 和 a[i] 最近公共祖先的深度 + 1 + 1; (比要多算上边)

·要考虑最近公共祖先, 这启发我们使用lca;

int lca(int x, int y)
{
    if (depth[x] > depth[y])
        swap(x, y);

    int sub = depth[y] - depth[x];
    int bit = 0;
    while (sub)
    {
        if (sub & 1)
            y = fa[y][bit];
        bit++;
        sub >>= 1;
    }

    if (x == y)
        return x;
    for (int i = 19; i >= 0; i--)
    {
        if (fa[x][i] == fa[y][i])
            continue;
        x = fa[x][i];
        y = fa[y][i];
    }
    return fa[x][0];
}

·传送的情况

for (int i = 1; i <= n; i++)
    {
        int y = a[i];
        int the_lca = lca(i, y);
        int cost = (depth[i] + depth[y]) * 2 - depth[the_lca] + 2;
        if (cost > x)
            continue;
        int cur = depth[i] + depth[y] - depth[the_lca] + 1;
        ans = max(ans, min(cur + (x - cost) / 3, n));
    }

·这样就能解决该问题了;

·复杂度分析

·不传送的情况O(1);

·传送的情况;

·lca 预处理O(nlog);

·lca查询过程O(nlogn);

·即时间复杂度为O(nlogn);

·空间复杂度O(20*n); // (lca倍增思路所需的空间

·完整代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;
const int N = 200010, M = N * 2;

int n, x;
int a[N];
int h[N], e[M], ne[M], idx;
int depth[N], fa[N][20];
int ans;
int que[N];
bool st[N];

void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

int lca(int x, int y)
{
    if (depth[x] > depth[y])
        swap(x, y);

    int sub = depth[y] - depth[x];
    int bit = 0;
    while (sub)
    {
        if (sub & 1)
            y = fa[y][bit];
        bit++;
        sub >>= 1;
    }

    if (x == y)
        return x;
    for (int i = 19; i >= 0; i--)
    {
        if (fa[x][i] == fa[y][i])
            continue;
        x = fa[x][i];
        y = fa[y][i];
    }
    return fa[x][0];
}

void bfs(int u)
{
    int hh = 0, tt = -1;
    que[++tt] = u;
    depth[u] = 0;
    st[u] = true;
    while (hh <= tt)
    {
        int t = que[hh++];
        for (int i = h[t]; ~i; i = ne[i])
        {
            int j = e[i];
            if (st[j])
                continue;
            que[++tt] = j;
            depth[j] = depth[t] + 1;
            fa[j][0] = t;
            st[j] = true;
        }
    }
}

int main()
{
    scanf("%d%d", &n, &x);
    memset(h, -1, sizeof h);
    ans = min(n, (x + 2) / 3);

    for (int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
    for (int i = 1; i < n; i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v), add(v, u);
    }

    bfs(1);

    for (int height = 1; height < 20; height++)
    {
        for (int i = 1; i <= n; i++)
        {
            fa[i][height] = fa[fa[i][height - 1]][height - 1];
        }
    }

    for (int i = 1; i <= n; i++)
    {
        int y = a[i];
        int the_lca = lca(i, y);
        int cost = (depth[i] + depth[y]) * 2 - depth[the_lca] + 2;
        if (cost > x)
            continue;
        int cur = depth[i] + depth[y] - depth[the_lca] + 1;
        ans = max(ans, min(cur + (x - cost) / 3, n));
    }
    printf("%d\n", ans);

    return 0;
}

·希望对大家有所帮助qwq。