题号 NC19987
名称 [HAOI2012]ROAD
来源 [HAOI2012]

题目描述

C国有n座城市,城市之间通过m条单向道路连接。一条路径被称为最短路,当且仅当不存在从它的起点到终点的另外一条路径总长度比它小。两条最短路不同,当且仅当它们包含的道路序列不同。我们需要对每条道路的重要性进行评估,评估方式为计算有多少条不同的最短路经过该道路。现在,这个任务交给了你。

样例

输入
4 4
1 2 5
2 3 5
3 4 5
1 4 8
输出
2
3
2
1

算法

(dp + 最短路 + 拓扑排序)

一个直接的想法,判断一条边经过多少条路径就看,

的两个端点,有多少条路径的终点是有多少条路径的起点是

然后根据乘法原理,这两种路径个数相乘就是这条边经过的路径的数量

同理看一条边经过多少条最短的路径,就看分别以为终点,为起点的最短路径有多少再相乘

我们观察数据,只有1500,边只有5000,用堆优化的是可以将个点为起点的最短路跑出来的

  1. 首先我们看以为终点的最短路径个数:

    我们定义一个数组,表示经过第个节点的路径个数,

    初始化起点,

    每一次我们用出队的节点值去更新其所有后继节点

    ,

    (说明已经存在一条到达v的最短路径,累加起来)

    这样我们就能得到所有节点的

  1. 接着我们看以为起点的最短路径的个数

    我们考虑将最短路图的反向图建出来,跑一个bfs来求

    由于每一次我们跑最短路的时候都是以一个节点为起点,求到其他节点的最短路径

    所以跑完最短路后得到的最短路径图一定是一个拓扑图

    我们可以一边求最短路最短路一边建反向边

    , 我们从v向u建一条有向边

    (如果v已经向外连了一条边,应该把已经连的所有边删掉后再连向u,对应到代码中是h1[v] = -1)

    ,我们从v向u建一条有向边

    接着我们在反向图中跑拓扑排序

    定义一个数组表示从i的前驱节点的个数

    (由于我们是反向建边,在反向图中的前驱节点就是原图的后继节点)

    每一个原图的后继节点都是最短路径的终点,

    所以我们计算这些节点的数量就能求出原图中第i个节点出发的最短路径个数

    就是的贡献

时间复杂度

参考文献:https://blog.nowcoder.net/n/6498f3d56eb54735be458ff24dc1ad7e

C++ 代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <unordered_map>
#include <map>
#include <vector>
#include <queue>
#include <set>
#include <bitset>
#include <cmath>

#define P 131

#define lc u << 1
#define rc u << 1 | 1

using namespace std;
typedef pair<int,int> PII;
typedef long long LL;
const int N = 2010,M = 50010 * 2,INF = 0x3f3f3f3f,mod = 1e9 + 7;
int h[N],ne[M],e[M],w[M],id[M],idx;
int h1[N],ne1[M],e1[M],id1[M],idx1;
int dist[N];
LL cnt1[N],cnt2[N];
LL ans[M];
int q[N];
struct edge
{
    int a,b,w;
}edges[M];
bool st[N];
int d[N];
int n,m;

void add(int a,int b,int c,int d)
{
    e[idx] = b,w[idx] = c,id[idx] = d,ne[idx] = h[a],h[a] = idx ++;
}

void add1(int a,int b,int d)
{
    e1[idx1] = b,id1[idx1] = d,ne1[idx1] = h1[a],h1[a] = idx1 ++;
}

void build()
{
    memset(h1,-1,sizeof h1);
    memset(dist,0x3f,sizeof dist);
    memset(st,0,sizeof st);
    memset(cnt1,0,sizeof cnt1);
    idx1 = 0;
}

void dijkstra(int s)
{
    priority_queue<PII,vector<PII>,greater<PII>> heap;
    dist[s] = 0;
    cnt1[s] = 1;
    heap.push(make_pair(dist[s],s));
    while(heap.size())
    {
        PII t = heap.top();
        heap.pop();
        int ver = t.second,distance = t.first;
        if(st[ver]) continue;
        st[ver] = true;
        for(int i = h[ver];~i;i = ne[i])
        {
            int j = e[i];
            if(dist[j] == distance + w[i])
            {
                add1(j,ver,id[i]);
                cnt1[j] = (cnt1[j] + cnt1[ver]) % mod;
            }else if(dist[j] > distance + w[i])
            {
                h1[j] = -1;
                add1(j,ver,id[i]);
                dist[j] = distance + w[i];
                cnt1[j] = cnt1[ver];
                heap.push(make_pair(dist[j],j));
            }
        }
    }
}

void topsort()
{
    int hh = 0,tt = -1;
    for(int i = 1;i <= n;i ++)
        for(int j = h1[i];~j;j = ne1[j])
        {
            int son = e1[j];
            ++ d[son];
        }
    for(int i = 1;i <= n;i ++)  
    {
        if(!d[i])
            q[++ tt] = i;
        cnt2[i] = 1;
    }
    while(hh <= tt)
    {
        int t = q[hh ++];
        for(int i = h1[t];~i;i = ne1[i])
        {
            int j = e1[i];
            cnt2[j] = (cnt2[j] + cnt2[t]) % mod;
            ans[id1[i]] = (ans[id1[i]] + cnt1[j] * cnt2[t] % mod) % mod;
            d[j] --;
            if(!d[j]) q[++ tt] = j;
        }
    }
}

void solve()
{
    scanf("%d%d",&n,&m);
    memset(h,-1,sizeof h);
    for(int i = 1;i <= m;i ++)
    {
        scanf("%d%d%d",&edges[i].a,&edges[i].b,&edges[i].w);
        add(edges[i].a,edges[i].b,edges[i].w,i);
    }
    for(int i = 1;i <= n;i ++)
    {
        build();
        dijkstra(i);
        topsort();
    }
    for(int i = 1;i <= m;i ++) printf("%lld\n",ans[i]);
}

int main()
{
    /*#ifdef LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #else
    #endif // LOCAL*/
    int T = 1;
    // init();
    // scanf("%d",&T);
    while(T --)
    {
        // scanf("%lld%lld",&n,&m);
        solve();
        // test();
    }
    return 0;
}