[AtCoder Beginner Contest 151] E - Max-Min Sums(预处理组合数学,贡献)

Problem Statement

For a finite set of integers \(X\), let \(f(X)=\max X-\min X\).

Given are \(N\) integers \(A_1,\dots,A_N\).

We will choose \(K\) of them and let \(S\) be the set of the integers chosen. If we distinguish elements with different indices even when their values are the same, there are \(\binom{N}{K}\) ways to make this choice. Find the sum of \(f(S)\) over all those ways.

Since the answer can be enormous, print it modulo \(10^9+7\).

Constraints

  • \(1 \le N \le 10^5\)
  • \(1 \le K \le N\)
  • \(|A_i| \le 10^9\)

Input

Input is given from Standard Input in the following format:

N K
A_1 A_2 ... A_N

Output

Print the answer modulo \(10^9+7\).

题意:

给定一个含有n个整数(可能为负数)的数组,从中任意选k个数组成一个集合S,集合S的价值为\(f(S)=\max S - \min S\)。

显然共有\(C(n,k)\)种集合,要求所有集合的\(f(S)\)之和。

思路:

我们知道此类问题解决方法通常为计算每一个数组中的元素\(a_i\)对答案的贡献。

先将数组按升序排序。\(a_i\)对答案的贡献有2种方式:1、作为集合\(S\)中的最大值,贡献\(+a_i\);2、作为集合\(S\)中的最小值,贡献\(-a_i\)。

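分析可得:排序后(下标从1开始)的元素\(a_i\)在\(C(i-1,k-1)\)个集合中作为最大值(其余\(k-1\)个元素从它前面的\(i-1\)个元素中选取),在\(C(n-i,k-1)\)个集合中作为最小值(其余\(k-1\)个元素从它后面的\(n-i\)个元素中选取)。因此答案为\(\sum_{i=1}^{n}\big(C(i-1,k-1)-C(n-i,k-1)\big)\,a_i\)。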

直接计算答案即可,记得取模。由于\(a_i\)可能为负数,取模后要把结果调整到\([0,10^9+7)\)范围内再输出。

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#include <sstream>
#include <bitset>
// ---- Contest shorthand macros (plain text substitution) ----
// Whole-container iterator range, e.g. sort(ALL(v)).
#define ALL(x) (x).begin(), (x).end()
// Container size as a signed int (avoids signed/unsigned comparisons).
#define sz(a) int(a.size())
// Counting loops: rep walks [x, n), repd walks [x, n] (inclusive upper bound).
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
// Pair shorthands (macros, not typedefs).
#define pii pair<int,int>
#define pll pair<long long ,long long>
// Unsync iostreams from C stdio and untie cin/cout for faster stream I/O.
// NOTE(review): after this, mixing scanf/printf with cin/cout is unsafe.
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
// Zero-fill; sizeof((X)) only covers the whole object when X is an array,
// not a pointer.
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
// Presumably a floating-point comparison tolerance; not referenced in the code shown.
#define eps 1e-6
// Debug print: shows the expression text and its value.
#define chu(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
// scanf readers for one to three ints.
// NOTE(review): du1 already ends with ';', so `du1(x);` expands to two
// statements and breaks an unbraced `if (...) du1(x); else ...`.
#define du3(a,b,c) scanf("%d %d %d",&(a),&(b),&(c))
#define du2(a,b) scanf("%d %d",&(a),&(b))
#define du1(a) scanf("%d",&(a));
using namespace std;
typedef long long ll;
/// Greatest common divisor by the iterative Euclidean algorithm.
/// Operands are replaced by their magnitudes first, so the result is always
/// non-negative (gcd(4, -6) == 2). gcd(x, 0) == |x| and gcd(0, 0) == 0.
/// LLONG_MIN is not supported (its magnitude is not representable).
long long gcd(long long a, long long b)
{
    if (a < 0) a = -a;
    if (b < 0) b = -b;
    while (b != 0)
    {
        const long long rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
/// Least common multiple of a and b, computed as a / gcd(a, b) * b.
/// Returns 0 when either operand is 0. This also avoids dividing by
/// gcd(0, 0) == 0. Dividing before multiplying keeps the intermediate small;
/// the caller must still ensure the final result fits in ll.
ll lcm(ll a, ll b)
{
    if (a == 0 || b == 0) return 0;
    return a / gcd(a, b) * b;
}
/// Computes a^b mod MOD by binary exponentiation.
/// - The base is reduced into [0, MOD) first, so negative bases give a
///   non-negative residue.
/// - a^0 == 1 % MOD for every base, including 0 (and 0 when MOD == 1).
/// - b < 0 is not supported and is treated like b == 0.
/// Requires MOD >= 1 and (MOD - 1)^2 to fit in long long.
long long powmod(long long a, long long b, long long MOD)
{
    a %= MOD;
    if (a < 0) a += MOD;              // C++ '%' keeps the dividend's sign
    long long result = 1 % MOD;       // 1 % MOD makes MOD == 1 yield 0
    while (b > 0)                     // b > 0 (not b != 0): no endless loop for b < 0
    {
        if (b & 1) result = result * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return result;
}
/// Exact integer power a^b by binary exponentiation (no modulus).
/// The running square is only formed while more exponent bits remain. It
/// therefore never exceeds |a|^b, and no signed overflow occurs whenever the
/// true result fits in long long.
/// a^0 == 1 for every base, including 0. b < 0 is not supported and returns 1.
long long poww(long long a, long long b)
{
    long long result = 1;
    while (b > 0)
    {
        if (b & 1) result *= a;
        b >>= 1;
        if (b > 0) a *= a;            // skip the final, unused (possibly overflowing) square
    }
    return result;
}
/// Prints the ints of V on one line, separated by single spaces, with a
/// newline after the last one. An empty vector prints nothing.
void Pv(const vector<int> &V)
{
    bool first = true;
    for (const int value : V)
    {
        if (!first) printf(" ");
        printf("%d", value);
        first = false;
    }
    if (!first) printf("\n");
}
void Pvl(const vector<ll> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%lld", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}}
/// Reads the next integer token from stdin with a getchar-based parser.
/// Non-digit characters before the number are skipped. A '-' counts as the
/// sign only when it immediately precedes the first digit. The character
/// right after the last digit is consumed.
/// Returns 0 if EOF is reached before any digit. Overflow is not checked.
inline long long readll()
{
    int c = getchar();                // int, not char: keeps EOF distinguishable
    bool negative = false;
    while (c != EOF && (c < '0' || c > '9'))
    {
        negative = (c == '-');        // only the character just before the digits matters
        c = getchar();
    }
    long long value = 0;
    while (c >= '0' && c <= '9')
    {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return negative ? -value : value;
}
/// Reads the next integer token from stdin as an int (getchar-based parser).
/// Non-digit characters before the number are skipped. A '-' counts as the
/// sign only when it immediately precedes the first digit. The character
/// right after the last digit is consumed.
/// Returns 0 if EOF is reached before any digit. Overflow is not checked.
inline int readint()
{
    int c = getchar();                // int, not char: keeps EOF distinguishable
    bool negative = false;
    while (c != EOF && (c < '0' || c > '9'))
    {
        negative = (c == '-');        // only the character just before the digits matters
        c = getchar();
    }
    int value = 0;
    while (c >= '0' && c <= '9')
    {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return negative ? -value : value;
}
// Capacity of the input array and of the factorial tables.
constexpr int maxn = 1000010;
// "Infinity" sentinel 0x3f3f3f3f (1061109567); not referenced in the code shown.
constexpr int inf = 0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/
int n;          // number of input values
ll a[maxn];     // input values, stored 1-based in a[1..n]
int k;          // size of each chosen subset
// Prime modulus 10^9 + 7.
constexpr ll mod = 1000000007LL;
// fac[i] = i! mod p and inv[i] = (i!)^{-1} mod p, both filled by pre().
ll fac[maxn], inv[maxn];
/// Fills fac[0..maxn-1] with factorials mod p and inv[0..maxn-1] with their
/// modular inverses.
void pre()
{
    fac[0] = 1;
    for (int x = 1; x < maxn; ++x)
        fac[x] = fac[x - 1] * x % mod;
    // Invert only the largest factorial (Fermat: y^(p-2) == y^{-1} mod p),
    // then walk downward using 1/(x-1)! = x * (1/x!).
    inv[maxn - 1] = powmod(fac[maxn - 1], mod - 2, mod);
    for (int x = maxn - 1; x > 0; --x)
        inv[x - 1] = inv[x] * x % mod;
}
/// Binomial coefficient "a choose b" mod p, read from the tables built by
/// pre(). Returns 0 when b < 0 or b > a, i.e. when no such choice exists.
ll C(int a, int b)
{
    if (b < 0 || a < b) return 0;
    const ll denomInverse = inv[b] * inv[a - b] % mod;
    return fac[a] * denomInverse % mod;
}
int main()
{
    //freopen("D:\\code\\text\\input.txt","r",stdin);
    //freopen("D:\\code\\text\\output.txt","w",stdout);
    pre();
    n = readint();
    k = readint();
    repd(i, 1, n)
    {
        a[i] = readll();
    }
    sort(a + 1, a + 1 + n);
    ll ans = 0ll;
    repd(i, k, n)
    {
        ans += C(i - 1, k - 1) * a[i] % mod;
        ans %= mod;
    }
    repd(i, 1, n - k + 1)
    {
        ans += -C(n - i, k - 1) * a[i] % mod + mod;
        ans %= mod;
    }
    printf("%lld\n", ans );
    return 0;
}