原题链接https://ac.nowcoder.com/acm/contest/5666/B


题目描述

定义mindiv(n)表示n的大于1的最小因数。Bob用所有正整数构建了一棵无限大的树,每个正整数对应一个节点,对于所有大于1的n,在点n和点图片说明 间连边。
定义δ(u,v)表示连接点u和点v的路径上的边的数量。给定m和图片说明 ,Bob想知道图片说明


输入描述

输入包含几个以文件结尾终止的测试样例。
每个样例的第一行输入一个整数m;
接下来的一行输入m个整数。

输出描述

每个样例输出一行整数作为答案。


样例

  • 输入

    3
    1 1 1
    4
    3 1 2 4
    4
    0 0 0 0
  • 输出

    3
    17
    0

题解思路

这道题目毛看看似乎在考你生成函数,但仔细一看,体中明显表明是树,所以肯定有问题。
但因为本题树的节点过多,所以推荐使用虚树来优化树形DP。

虚树介绍1:https://blog.csdn.net/weixin_37517391/article/details/82744605
虚树介绍2:https://blog.csdn.net/zhouyuheng2003/article/details/79110326

奆佬思路:https://blog.nowcoder.net/n/df889adfaf824d50ad2291f4d2eb04a2

何为虚树?
所谓虚树,就是只包含关键点和关键lca的树,而整棵虚树的规模不会超过关键点的两倍。如下图所示,绿色覆盖的点为关键点,其余为非关键点但可能是关键lca。
图片说明
则建立的虚树可能为(因为本人虚树学得不是很好如果有错请原谅):
图片说明

好那话不多说(说不下去了,思路有点忘了,啊这),直接上代码。

参考代码

#include<bits/stdc++.h>
#define lowbit(x) x&-x
using namespace std;
const int MAXN=2e5+10;
int n,w[MAXN];
long long ans;
int c[MAXN];
int lcadep[MAXN],dep[MAXN];
int st[MAXN],top,tot;
vector<int>g[MAXN];
void upd(int p,int k)
{
    for(;p<=n;p+=lowbit(p))
        c[p]+=k;
}
int query(int p)
{
    int res=0;
    for(;p;p-=lowbit(p))
        res+=c[p];
    return res;
}
int mindiv[MAXN];
void sieve(int siz)
{
    for (int i=2;i<=siz;i++)
        if (!mindiv[i])
            for (int j=i;j<=siz;j+=i)
                if (!mindiv[j])
                    mindiv[j]=i;
}
void build()
{
    tot=n;
    st[top=1]=1;
    for (int i=2;i<=n;i++)
    {
        int j=i;
        dep[i]=dep[i-1]+1;
        for (;j!=mindiv[j];j/=mindiv[j])
            dep[i]++;
        lcadep[i]=query(n)-query(j-1);
        for (j=i;j!=1;j/=mindiv[j])
            upd(mindiv[j],1);
    }
    for (int i=2;i<=n;i++)
    {
        while (top>1&&dep[st[top-1]]>=lcadep[i])
        {
            g[st[top-1]].push_back(st[top]);
            g[st[top]].push_back(st[top-1]);
            top--;
        }
        if (dep[st[top]]!=lcadep[i])
        {
            dep[++tot]=lcadep[i];
            g[st[top]].push_back(tot);
            g[tot].push_back(st[top]);
            st[top]=tot;
        }
        st[++top]=i;
    }
    while(top>1)
    {
        g[st[top-1]].push_back(st[top]);
        g[st[top]].push_back(st[top-1]);
        top--;
    }
}
void dfs(int u,int fa)
{
    ans+=1ll*w[u]*dep[u];
    for (auto &v:g[u])
        if (v!=fa)
        {
            dfs(v,u);
            w[u]+=w[v];
        }
}
void dfs2(int u, int fa)
{
    for (auto &v:g[u])
        if (v!=fa)
            if (w[1]-2*w[v]<0)
            {
                ans+=1ll*(w[1]-2*w[v])*(dep[v]-dep[u]);
                dfs2(v, u);
            }
}
int main()
{
    sieve(1e5);
    while (~scanf("%d",&n))
    {
        ans=top=0;
        for (int i=1;i<=tot;i++)
        {
            g[i].clear();
            c[i]=w[i]=lcadep[i]=dep[i]=0;
        }
        for (int i=1;i<=n;i++)
            scanf("%d", &w[i]);
        build();
        int rt=1;
        dfs(rt,0);
        dfs2(rt,0);
        printf("%lld\n",ans);
    }
}