题目要求是给一颗树,从任意一个节点开始向下走,统计有多个和为0的路径

首先思考暴力怎么做

既然题目说的是向下走路径,我们思考的就是枚举每个节点,然后把这个节点到根节点这段路径拿出来,针对这一条路径中,有多少个子路径满足(1、子路径的起点是当前枚举到节点 && 2、这个子路径的和为0),这样能保证每次添加答案的路径都不重复

这个算法的时间复杂度不太确定,主要看这颗树是怎么构建的,如果这棵树够成了一条链,那么最坏情况就是(n^2)的复杂度,然后这个做法在该数据下能过18/20,大多数情况感觉应该在(nlogn)

#include<bits/stdc++.h>
using namespace std;
int read()
{
    char ch=getchar();
    int ans=0,f=1;
    if(ch=='-')
    {f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')
    {
        ans=(ans<<1)+(ans<<3)+(ch-'0');
        ch=getchar();
    }
    return ans*f;
}

int n;
string s[200001];
int ls[200001],rs[200001];
long long int sum[200001];
int ans;
int val;
int work(string t)
{
    int f=1,ANS=0;
    if(t[0]=='-')
    {
        f=-1;
        int len=t.length();
        for(int i=1;i<len;i++)
        ANS=(ANS*10)+(t[i]-'0');
        return ANS*f;
    }
    else
    {
        int len=t.length();
        for(int i=0;i<len;i++)
        {
            ANS=(ANS*10)+(t[i]-'0');
        }
        return ANS;
    }
}
void dfs(int pos,int depth,int SUM)
{
    val=work(s[pos]);
    if(depth-2>=0)
    {
        for(int i=0;i<=depth-2;i++)
        if(SUM+val-sum[i]==0)
        {
           ans++;
        }
    }
    sum[depth]=sum[depth-1]+val;
    if(ls[pos]!=-1&&ls[pos]!=0)
        dfs(ls[pos],depth+1,sum[depth]);
    if(rs[pos]!=-1&&rs[pos]!=0)
        dfs(rs[pos],depth+1,sum[depth]);
    return;
}
queue<int>Q;

int main()
{
    n=read();
    for(int i=1;i<=n;i++)
        cin>>s[i];
//建树
    Q.push(1);
    for(int i=2;i<=n;i+=2)
    {
        int fa=Q.front();
        if(s[i]=="None")
        {ls[fa]=-1;}
        else 
        {
            ls[fa]=i;
            Q.push(i);
        }
        if(i+1<=n)
        if(s[i+1]=="None"){rs[fa]=-1;}
        else{ rs[fa]=i+1;Q.push(i+1);}
        Q.pop();
    }
    dfs(1,1,0);
    cout<<ans;
    return 0;
}

解释一下上面的代码,建树就是把树结构建出来,保留所有合法节点,ls是左儿子,rs是右儿子,如果儿子节点是None,那么就标注为-1意思为不合法,而且全局变量初始化为0,如果儿子是0,则表示没有儿子

下来的话是搜索dfs,参数SUM其实多余了,可以优化掉,后面看我第二版代码

主要用sum[depth]来表示深度为depth的前缀和是多少,然后用

for(int i=0;i<=depth-2;i++)

if(SUM+val-sum[i]==0)

ans++;

这个语句来枚举搜索到当前节点时,向上看,有多少个合法子路径和为0,注意有一个小细节时从depth>=2开始计算,因为depth=1的时候,相当于根节点,根节点只有一个节点不能构成路径,至少要出第二层开始才能有路径

最后能统计出答案

以上是暴力,下来思考怎么优化

优化点其实还在于ans++这个操作加起来太慢了,一次只能加一个答案,在这里看怎么能加快答案计数效率

答案增加的条件是SUM+val-sum[i]==0

这个条件翻译下来就是,如果当前节点到根节点的前缀和等于当前节点的任意一个祖先节点到根节点的前缀和,那么ans+=1

这里面其实SUM+val对于每个枚举到的节点是固定值,我们想要知道就是SUM+val=sum[i]中sum[i]的数量

这个数量可以用map来维护(unordered_map一个道理),下来的任务就是就转化成了,当前节点到根节点,有多少个前缀和刚好和当前节点到根节点的前缀和,前缀和用sum来维护,前缀和的数量用map来维护

但这里面还有一个细节,在于如果这个节点0,怎么处理

举一个例子

-1

0 0

按照上面的说法,算出来的结果应该是2,但是实际上真正答案为0,意味着如果当前节点为0需要特殊处理,就是ans在这时加的是+mp[]-1,而不是平时的mp[]

#include<bits/stdc++.h>
using namespace std;
int read()
{
    char ch=getchar();
    int ans=0,f=1;
    if(ch=='-')
    {f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')
    {
        ans=(ans<<1)+(ans<<3)+(ch-'0');
        ch=getchar();
    }
    return ans*f;
}
int n;
string s[200001];
int ls[200001],rs[200001];
long long int sum[200001];
long long int ans;
int val;
int work(string t)
{
    int f=1,ANS=0;
    if(t[0]=='-')
    {
        f=-1;
        int len=t.length();
        for(int i=1;i<len;i++)
        ANS=(ANS*10)+(t[i]-'0');
        return ANS*f;
    }
    else
    {
        int len=t.length();
        for(int i=0;i<len;i++)
        {
            ANS=(ANS*10)+(t[i]-'0');
        }
        return ANS;
    }
}
unordered_map<long long int,int>mp;
void dfs(int pos,int depth)
{
    val=work(s[pos]);
    if(depth>=2)
    {
        if(val==0)
        {
            if(mp[sum[depth-1]])
            ans+=mp[sum[depth-1]]-1;
        }
        else
        ans+=mp[val+sum[depth-1]];
    }
    sum[depth]=sum[depth-1]+val;
    mp[sum[depth]]++;
    if(ls[pos]!=0&&ls[pos]!=-1)
    dfs(ls[pos],depth+1);
    if(rs[pos]!=0&&rs[pos]!=-1)
    dfs(rs[pos],depth+1);
    mp[sum[depth]]--;
    return;
}
queue<int>Q;
int main()
{
    n=read();
    for(int i=1;i<=n;i++)
        cin>>s[i];
    Q.push(1);
    for(int i=2;i<=n;i+=2)
    {
        int fa=Q.front();
        if(s[i]=="None")
        {ls[fa]=-1;}
        else 
        {
            ls[fa]=i;
            Q.push(i);
        }
        if(i+1<=n)
        if(s[i+1]=="None"){rs[fa]=-1;}
        else{ rs[fa]=i+1;Q.push(i+1);}
        Q.pop();
    }
    mp[0]=1;
    dfs(1,1);
    cout<<ans;
    return 0;
}

最后注意开longlong,不然1e9和2e5在一起会爆int