换句话说就是要求一号点的出度为k的最短路

容易发现:当连在1号节点的几个边同时加一个数的时候,1号店的度会减少,反之,当同时增大一个k的时候,1号点的度会增大

因此,我们考虑二分一个数作为连在1号节点的那几条边修改的数值,进行二分查找出答案即可

#include <iostream> #include <stdio.h> #include <string.h> #include <algorithm> #include <vector> #define maxn 5500 #define maxm 1100000 #define rep(x,y,z) for(int x = y ; x <= z ; x ++) using namespace std ; int n , m , kk ; struct dy{ int u , v , w , id ; int operator < (const dy &x) const { return w == x.w ? u < x.u : w < x.w ; } }a[maxm] ; vector<int>ans ; int fa[maxn] ; int find(int x) { return x == fa[x] ? x : fa[x] = find(fa[x]) ; } void unionn(int x ,int y){ fa[find(x)] = find(y) ; } int read() ; int ksm(int k) { ans.clear() ; rep(i,1,n) fa[i] = i ; int cnt = 0 , in = 0 ; rep(i,1,m) { if(a[i].u == 1 || a[i].v == 1) { a[i].w += k ; } }sort(a+1,a+1+m) ; rep(i,1,m) { int u = a[i].u , v = a[i].v , w = a[i].w ; int fu = find(u) , fv = find(v) ; if(fu == fv) continue ; fa[fv] = fu ; if(u == 1 || v == 1) in ++ ; ans.push_back(a[i].id) ; cnt ++ ; if(cnt == n-1) break ; } rep(i,1,m) { if(a[i].u == 1 || a[i].v == 1) a[i].w -= k ; }return in ; } int deal(int k) { ans.clear() ; rep(i,1,n) fa[i] = i ; int cnt = 0 , in = 0 ; rep(i,1,m) { if(a[i].u == 1 || a[i].v == 1) { a[i].w += k ; } }sort(a+1,a+1+m) ; rep(i,1,m) { int u = a[i].u , v = a[i].v , w = a[i].w ; int fu = find(u) , fv = find(v) ; if(fu == fv) continue ; if(a[i].u == 1 && in == kk) continue ; fa[fv] = fu ; if(u == 1 || v == 1) in ++ ; cnt ++ ; ans.push_back(a[i].id) ; } rep(i,1,m) { if(a[i].u == 1 || a[i].v == 1) a[i].w -= k ; } if(cnt != n-1 || in != kk) { puts("-1") ; exit(0) ; } cout << n-1 << endl ; for(int i = 0 ; i < ans.size() ; i ++) cout << ans[i] << " " ; puts("") ;exit(0) ; return in ; } int main () { n = read() , m = read() , kk = read() ; rep(i,1,m) { a[i].u = read() , a[i].v = read() , a[i].w = read() , a[i].id = i ; if(a[i].u > a[i].v) swap(a[i].u,a[i].v) ; }int l = -1e5 , r = 1e5 , Ans = -1e9 ; while(l <= r) { int mid = (l+r) / 2 ; if(ksm(mid) < kk) r = mid - 1 ; else { l = mid + 1 , Ans = mid ; } } deal(Ans) ; return 0 ; } int read() { int x = 0 , f = 1 ; char s = getchar() ; while(s > '9' || s < '0') {if(s == '-') f = -1 ; s = getchar() ;} while(s <='9' && s >='0') {x = x * 10 + (s-'0'); s = getchar() ;} return x*f ; }