题意:

给定一张N个点的完全图,可以从任何一个点出发,同一个点可以经过多次。询问总路径长度不超过M的情况下,最多能够经过多少个点。

思路:

首先我们能够想到一个最简单的模拟算法。

建立数组dist[][]dist[i][j]表示经过i个点后,最后停留在j所以经过的最短路径长度。

那么有如下递推公式:

dist[i][j] = Min{ dist[i - 1][k] + edge[j][k] | 1 <= k <= N, k != j }

其中edge[j][k]表示从点j到点k的边长。

初始化为dist[0][1 .. N] = 0

算法结束的条件是我们模拟到一个i值,dist[i][1 .. N]都大于M。并得到最后的结果i-1

该算法的时间复杂度为 O(AnsN^2) 。若所有的边长均为1,则该算法的时间复杂度会达到O(MN^2),对于题目给定的 MN 的范围,显然无法接受。

但通过这个模拟的算法,可以发现一个这样一个性质:

i超过Ans之前,dist[i][]中总是至少存在一个小于等于M的值。

我们构造一个check(i)函数,表示dist[i][]中是否存在一个小于等于M的值。

则可以发现i值和check(i)的值的变化关系如下图:

这恰恰是二分答案类型题目的特点,因此我们不妨考虑使用二分答案做。

最小可能经过的点数是0,最大可能经过的点数是M,可以确定下界0和上界M+1

唯一需要解决的问题是如果快速的计算出dist[i][]来方便计算check(i)

前面我们给了一个edge[][]数组,一般的理解为任意两个点之间的边距离,但其实这个数组还有另一种理解方式:edge[j][k]表示从j出发,经过1条边到达k的路径距离

我们将edge[][]记为edge_1。那么edge_1有什么用呢?

很显然,我们可以通过edge_1来计算edge_2edge_2[j][k]即表示从j出发,经过2条边到达k的路径距离。

并且在此基础上可以计算出经过3条边的edge_3,经过4条边的edge_4...一直到edge_i

其计算方法为:

add(edge_x, edge_y):
    edge_z[][] = Infinite
    For i = 1 .. N
        For j = 1 .. N
            For k = 1 .. N
                If (i not equal k and j not equal k) Then
                    If (edge_z[i][j] > edge_x[i][k] + edge_y[k][j]) Then
                        edge_z[i][j] = edge_x[i][k] + edge_y[k][j]
                    End If
                End If
            End For
        End For
    End For
    Return edge_z

即是在已知edge_xedge_y的情况下,我们可以通过O(N^3)的算法来求得edge_x+y

但是一个一个递推显然也不合适,我们仍然需要快速计算edge_i

这里采用和快速幂相似的分治算法来处理:

对于每一个edge_ii,可以分解成若干个2的幂之和,即:

i = 2^p1 + 2^p2 + 2^p3 + ... + 2^pk

如果我们事先已经计算出这些edge_2^pk,那么就可以直接计算出edge_i

则得到我们的算法为:

  1. 根据输入得到edge_1
  2. edge_1edge_1进行add操作得到edge_2,将edge_2edge_2进行add操作得到edge_4...将edge_2^tedge_2^t进行add操作得到edge_2^(t+1),直至2^t>M
  3. 计算edge_i,将i分解为若干个2的幂之和,将对应的edge_2^pt进行add操作
  4. dist[0][]edge_i进行计算,即可得到dist[i][]

该算法预处理的时间为O(N^3logM),每次计算edge_i的时间为O(N^3logi),最后计算dist[i][]的时间为O(N^2),判断dist[i][]的时间复杂度为O(N)。因此进行一次check(i)操作总的时间复杂度为O(N^3logM)

再加上前面二分答案的过程,总算法的时间复杂度为O(N^3 (logM)^2)。对于这道题目来说基本能够通过全部的数据了。

但是,这并不是最优的算法。在得到edge_2^pt的情况下,还可以把这题做的更简单,将时间复杂度降低至O(N^3logM)

首先同样要计算出所有的edge_2^pt,并初始化dist[] = 0(注意此时没有i,只是一个长度为N的数组)。

接下来从edge_2^t开始枚举,若dist[]edge_2^t进行add操作后,在dist数组存在一个不超过M的值,将ans加上2^t,继续枚举下一个edge_2^(t-1);否则先将dist[]数组还原,再枚举下一个edge_2^p(t-1)

伪代码为:

dist[] = 0
ans = 0
For i = t .. 0
    newDist[] = Infinite
    For j = 1 .. N
        For k = 1 .. N
            If (j not equal k and newDist[j] > dist[k] + edge_2^i[k][j]) Then
                newDist[j] = dist[k] + edge_2^i[k][j]
            End If
        End For
    End For
    If (check(newDist)) Then
        dist = newDist
        ans = ans + 2^i
    End If
End For

最后得到的ans,也就是最大能过经过的点数。

这个算法有点类似于完全背包问题,将经过的点数按照二进制拆解,并放入背包,看是否满足要求,若不满足,就将其移除。

关于为什么要从大到小枚举,是因为有可能先放入小物品后会导致无法放入一个大的物品,而假如不放入这个小物品,就可以将后面的一个大物品放入。反过来则不会出现这样的问题。举个例子来说明:

假设有一个大小为5个背包,我们有3个大小分别为1、2、4的物品。

按照从小到大的顺序枚举,得到的答案是1+2;

而按照从大到小的顺序枚举,得到的答案是4+1。

显然后者才是正确的答案。

/* ***********************************************
Author        :devil
************************************************ */
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <string>
#include <cmath>
#include <stdlib.h>
#define inf 0x3f3f3f3f
#define LL long long
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define dec(i,a,b) for(int i=a;i>=b;i--)
#define ou(a) printf("%d\n",a)
#define pb push_back
#define mkp make_pair
template<class T>inline void rd(T &x){char c=getchar();x=0;while(!isdigit(c))c=getchar();while(isdigit(c)){x=x*10+c-'0';c=getchar();}}
#define IN freopen("in.txt","r",stdin);
#define OUT freopen("out.txt","w",stdout);
using namespace std;
const int mod=1e9+7;
const int N=102;
int mp[40][N][N],dis[2][N],mai,mak,n,m,ans;
int main()
{
    rd(n),rd(m);
    memset(mp,inf,sizeof(mp));
    rep(i,0,n-1) rep(j,0,n-1) rd(mp[0][i][j]);
    rep(i,0,n-1) mp[0][i][i]=inf;
    for(int i=2,k=1;i<=m;i<<=1,k++)
    {
        mai=i,mak=k;
        rep(x,0,n-1) rep(y,0,n-1) rep(z,0,n-1) mp[k][x][y]=min(mp[k][x][y],mp[k-1][x][z]+mp[k-1][z][y]);
    }
    for(int i=mai,k=mak;i;i>>=1,k--)
    {
        bool flag=0;
        rep(j,0,n-1) dis[1][j]=inf;
        rep(j,0,n-1) rep(x,0,n-1)
        if(dis[1][j]>mp[k][x][j]+dis[0][x])
        {
            dis[1][j]=mp[k][x][j]+dis[0][x];
            if(dis[1][j]<=m) flag=1;
        }
        if(flag)
        {
            ans+=i;
            memcpy(dis[0],dis[1],(n+10)*sizeof(int));
        }
    }
    ou(ans);
    return 0;
}
View Code