"""Read ``n x`` and ``n`` integers from stdin; print the sum of the ``x`` largest."""


def top_x_sum(n, x, a):
    """Return the sum of the ``x`` largest values of *a*.

    Args:
        n: Number of values expected in *a* (the first line of input).
        x: How many of the largest values to add up. When ``x >= n`` every
            value of *a* is summed.
        a: The values. The list is not modified.

    Returns:
        The sum of the ``x`` largest values (``0`` when ``x == 0``).

    Raises:
        ValueError: If ``x`` is negative.
    """
    if x < 0:
        raise ValueError(f"x must be non-negative, got {x}")
    if x >= n:
        return sum(a)
    # After an ascending sort the x largest values occupy indices n-x .. n-1.
    ordered = sorted(a)
    return sum(ordered[n - x:n])


def main():
    """Parse the two input lines and print the answer."""
    n, x = map(int, input().split())
    a = list(map(int, input().split()))
    print(top_x_sum(n, x, a))


if __name__ == "__main__":
    main()