luoguP1850 换教室

链接

https://www.luogu.org/problemnew/show/P1850

思路

状态很显然就是f[n][k][0/1]
前i次,用了k次机会,当前是在哪个教室
转移就很、、了。
每次转移i的时候,一定是从i的两个可能的教室来的。
那就有四种可能的方式了。

先考虑不换的情况,即\(k = 0\)时的情况
\(C1 = c[i - 1], C2 = d[i - 1], C3 = c[i], C4 = d[i]\)
\(mp[i][j]\)表示\(i,j\)间的最短路
\(dp[i][j][0] =min\begin{cases} dp[i - 1][j][0] + mp[C1][C3]\\dp[i - 1][j][1] + mp[C1][C3] * (1 - k[i - 1]) + mp[C2][C3] * k[i - 1]\end{cases}\)
显然如果\(i-1\)时没有换教室那么,\(i - 1\)\(i\)只有一种情况就是都不换教室,如果\(i - 1\)时换了教室那么就有两种情况\(i - 1\)换成功了,或者没换成功所以就是对应的路径长乘上对应的概率
$dp[i][j][1] =min\begin{cases} dp[i - 1][j - 1][0] + mp[C1][C3] * (1 - k[i]) + mp[C1][C4] * k[i]\dp[i - 1][j - 1][1] + mp[C2][C4] * k[i] * k[i - 1] + mp[C2][C3] * k[i - 1] * (1 - k[i]) + mp[C1][C4] * (1 - k[i - 1]) * k[i] + mp[C1][C3] * (1 - k[i - 1]) * (1 - k[i])\end{cases} $
ORZ 大佬

错误

数组开小了
floyd的k写到外边了

代码

#include <bits/stdc++.h>
using namespace std;
const int N=2e3+7,M=307;
const double inf=10000000000000000000.0;
int read() {
    int x=0,f=1;char s=getchar();
    for(;s>'9'||s<'0';s=getchar()) if(s=='-') f=-1;
    for(;s<='9'&&s>='0';s=getchar()) x=x*10+s-'0';
    return x*f;
}
int n,m,v,e,c[N],d[N],dis[M][M];
double k[N],f[N][N][2];
int main() {
    n=read(),m=read(),v=read(),e=read();
    for(int i=1;i<=n;++i) c[i]=read();
    for(int i=1;i<=n;++i) d[i]=read();
    for(int i=1;i<=n;++i) scanf("%lf",&k[i]);
    for(int i=1;i<=v;++i) {
        for(int j=1;j<=v;++j) dis[i][j]=0x3f3f3f3f;
        dis[i][i]=0;
    }
    for(int i=1;i<=e;++i) {
        int a=read(),b=read(),val=read();
        dis[a][b]=dis[b][a]=min(dis[a][b],val);
    }
    for(int mmp=1;mmp<=v;++mmp)
        for(int i=1;i<=v;++i)
            for(int j=1;j<=v;++j)
                dis[i][j]=min(dis[i][j],dis[i][mmp]+dis[mmp][j]);
    for(int i=1;i<=n;++i)
        for(int j=0;j<=m;++j)
            f[i][j][0]=f[i][j][1]=inf;
    f[1][0][0]=f[1][1][1]=0;
    for(int i=2;i<=n;++i) {
        for(int j=0;j<=m;++j) {
            f[i][j][0]=min(f[i-1][j][0]+dis[c[i-1]][c[i]],
                           f[i-1][j][1]+k[i-1]*dis[d[i-1]][c[i]]+(1.0-k[i-1])*dis[c[i-1]][c[i]]);
            if(j)
            f[i][j][1]=min(f[i-1][j-1][0]+k[i]*dis[c[i-1]][d[i]]+(1.0-k[i])*dis[c[i-1]][c[i]],
                           f[i-1][j-1][1]+k[i]*k[i-1]*dis[d[i-1]][d[i]]+
                                          k[i]*(1.0-k[i-1])*dis[c[i-1]][d[i]]+
                                          (1.0-k[i])*k[i-1]*dis[d[i-1]][c[i]]+
                                          (1.0-k[i])*(1.0-k[i-1])*dis[c[i-1]][c[i]]);
        }
    }
    double ans=inf;
    for(int i=0;i<=m;++i)  ans=min(ans,min(f[n][i][0],f[n][i][1]));
    printf("%.2lf\n",ans);
    return 0;
}