题意


题目描述:
有一座雪山,这里有N个山头和M条轨道。滑雪者从a山头滑到b山头要求,a山比b山高或相等。
滑雪者想要从1号山头开始滑尽量多的山头。滑雪者有回溯的能力(返回上一个节点),并且可以连续回溯。
得到以最短滑行距离滑到尽量多的景点的方案。
求出最短距离和最多可以景点数。

输入描述:
输入的第一行是两个整数N,M。
接下来1行有N个整数Hi,分别表示每个景点的高度。
接下来M行,表示各个景点之间轨道分布的情况。每行3个整数,Ui,Vi,Ki。表示编号为Ui的景点和编号为Vi的景点之间有一条长度为Ki的轨道。

输出描述:
输出一行,表示滑雪者最多能到达多少个景点,以及此时最短的滑行距离总和。


Solution


图片说明 算法构建最小生成树,因为用堆优化了图片说明 ,分析的思路和最短路有点像。
1.因为不能从低点到高点,所以存边时如果图片说明 ,就不存图片说明 这条边,如果图片说明 ,要存两条边图片说明
2.当有环的情况,存在一个点是多个找到最短路的点的邻居,也就是有多种长度,这个可以通过先取出长度最小的边然后标记来排除其它多余的长度。
3.因为1步骤中我把边看成有向边存入,我开始以为从1开始建生成树然后取出已经找到最短路的点的邻居时取距离最小的一个,但是这样可能会存在如下情况,假设目前只走了1、2、3,然后由2、3更新的值,若此时就将图片说明 出队是不准确的,如果我们要选一个点a出队,显然比a更高的点都应该已经出队了,这样才能确保与 a 相连的边都已经比较过。所以出队的应该按照出点的高度从大到小排序,高度相同再按照长度排序,这样相当于我们在一层一层扩展,先把最高层的点加入最小树形图然后次高层然后第三高层……这样所有的边就是按照从高的低的方向走的了。(好吧就是雨巨的原话,只是prim也是这样的)
图片说明


Code:


#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e6+7,maxm=1e5+7;
template <class T>
inline void read(T &res) {
    char c; T flag = 1;
    while ((c = getchar()) < '0' || c > '9')
        if (c == '-')
            flag = -1;
    res = c - '0';
    while ((c = getchar()) >= '0' && c <= '9')
        res = res * 10 + c - '0';
    res *= flag;
}
struct node{
    int v,h;
    ll dis;
    bool operator<(const node&a)const{
        if(h!=a.h)    return h<a.h;
        return dis>a.dis;
    }
};
int head[maxm],Next[maxn<<1],to[maxn<<1],tot;
ll w[maxn<<1],dis[maxm],ans;
void add(int x,int y,int c) {
    to[++tot]=y;
    Next[tot]=head[x];
    w[tot]=c;
    head[x]=tot;
}
int vis[maxm],val[maxm],cnt,n,m;
void Prim() {
    for(int i=0; i<=n; ++i) dis[i]=0x3f3f3f3f3f3f3f3f;
    priority_queue<node> q;
    q.push(node{1,val[1],0});
    dis[1]=0;
    while(!q.empty()) {
        node tmp=q.top();
        q.pop();
        int u=tmp.v;
        if(vis[u]) continue;
        vis[u]=1;
        ++cnt,ans+=dis[u];
        for(int i=head[u];i;i=Next[i]) {
            int v=to[i];
            if(!vis[v]&&dis[v]>w[i]) {
                dis[v]=w[i];
                q.push(node{v,val[v],dis[v]});
            }
        }
    }
}
int main() {
    read(n),read(m);
    for(int i=1;i<=n;++i)    read(val[i]);
    for(int i=1;i<=m;++i) {
        int u,v;ll c;
        read(u),read(v),read(c);
        if(val[u]>=val[v]) add(u,v,c);
        if(val[v]>=val[u]) add(v,u,c);
    }
    Prim();
    printf("%d %lld\n",cnt,ans);
}