题意概括:

给定n篇文章,每篇有价值ai和体力消耗bi。对所有区间[l, r],选择区间内价值最大的min(k, 区间长度)篇文章,求这些文章消耗的体力值之和的总和(对所有区间求和)。由于答案很大,需要对1e9+7取模 核心问题:计算每个文章被多少个区间选入前k大,然后乘以bi并累加。关键在于高效判断每个文章在哪些区间内会入选(区间内价值大于它的文章数小于k)。

### 题解 分析howdo: 首先暴力做一定会超时,但是我们可以单独考虑每个元素,对于单独的元素,假设它是按a数组原顺序的第p个元素,我们需要统计它对答案的贡献是多少,不难得到它的贡献为b[p]*count,count是它在多少个区间内是排k名之前的(即在区间内大于a[p]的数小于k个),那现在本题就转化为求a[p]在多少个区间内能够被选到

怎么用数据结构实现

我们按照a数组从小到大的排序,使用set,将a从大到小的加入set里面(加入的是a[p]对应的下标p),这里我们考虑使用pair来存储{a[p],p},使用sort可以a其排序,并且方便找到它的下标;每次加入一个数p,set里面的数对应的a都大于a[p],这样我们考虑a[p]的在那些区间内满足条件就很容易了,可以用c数组表示set里面p之前的k个数,d数组表示p之后的k个数 即找到ck ck-1 ck-2 ck-3..c1 p d1 d2 d3...dk-1;这样计算有多少区间满足的时候 就可以直接枚举 左端点l∈[l1=c[i]+1,l2=c[i-1]] 并可以求出r∈[r1:p,r2:d[k-i+1]-1]],那么满足田间的区间数就是(l2-l1+1)*(r2-r1+1) ,思路就是这样,难点就是c和d不一定满k个,所以我们枚举l的时候要注意边界:int l1=c[i]+1,l2=(c[i-1]?c[i-1]:v); 枚举c的时候d[k-i+1]可能不存在同样d也要注意边界:int r1=v;int r2=(d[k-i+1]?d[k-i+1]-1:n);

具体代码如下:

using namespace std;
const int N=1e5+10,mod=1e9+7;
using ll=long long;
using pll=pair<int,int>;
ll res;
int n,k,a[N],b[N];
pll p[N];
set<int>s;
int main()
{
    cin>>n>>k;
    for(int i=1;i<=n;i++) 
    {
        cin>>a[i];
        p[i]={a[i],i};
    }
    for(int i=1;i<=n;i++) cin>>b[i];
    s.insert(0),s.insert(n+1);
    sort(p+1,p+1+n,greater<pll>());
    ll last=0; 
    for(int m=1;m<=n;m++)
    {
        auto t=p[m];int v=t.second;
        s.insert(v);
        auto it=s.lower_bound(v);
        auto jt1=it,jt2=it;
        int cnt1=1,cnt2=1;
        vector<int>c(k+2,0); 
		vector<int>d(k+2,0); 
        for(int i=1;i<=k;i++)
        {
             jt1++;
             if(jt1==s.end()) break;
             d[cnt1++]=*jt1;
        }
        for(int i=1;i<=k;i++)
        {
             jt2--;
             if(jt2==s.begin())
             {
                 c[cnt2++]=*jt2;
                 break;
             }
             c[cnt2++]=*jt2;
       }
        last=res;
        for(int i=1;i<cnt2;i++)
        {
            int l1=c[i]+1,l2=(c[i-1]?c[i-1]:v);
            int r1=v;
            int r2=(d[k-i+1]?d[k-i+1]-1:n);
            res=(res+1ll*b[v]*(l2-l1+1)%mod*(r2-r1+1)%mod)%mod;
        }
    }
    cout<<res<<'\n';
    return 0;
}

其他优化版本:(借鉴他人的)

```#include <bits/stdc++.h>
 
#define x first
#define y second
 
using namespace std;
 
typedef long long ll;
typedef pair<ll, ll> PLL;
typedef unsigned long long ull;
 
mt19937_64 rng(time(0));
 
void chmax(ll &x, ll y){if (x < y) x = y;}
void chmin(ll &x, ll y){if (x > y) x = y;}
 
const ll P = 1e9 + 7;
 
void solve()
{
    ll n, k;
    cin >> n >> k;
 
    vector<ll> a(n + 1), b(n + 1), id(n + 1);
    iota(id.begin(), id.end(), 0ll);
 
    for (ll i = 1; i <= n; i ++ ) cin >> a[i];
    for (ll i = 1; i <= n; i ++ ) cin >> b[i];
 
    sort(id.begin(), id.end(), [&](auto x, auto y){
        return a[x] < a[y];
    });
 
    set<ll> st;
    st.insert(0), st.insert(n + 1);
    ll res = 0;
    for (ll i = n; i >= 1; i -- )
    {
        vector<ll> l(k), r(k);
        ll p = id[i];
        st.insert(p);
        auto it = st.lower_bound(p);
        for (ll j = 0; j < k && *it != n + 1; j ++ )
        {
            r[j] = *next(it) - *it;
            it ++;
        }
        it = st.lower_bound(p);
        for (ll j = 0; j < k && *it != 0; j ++ )
        {
            l[j] = *it - *prev(it);
            it --;
        }
 
        vector<ll> pre(k);
        for (ll i = 0, sum = 0; i < k; i ++ )
        {
            sum += l[i];
            pre[i] = sum;
        }
        for (ll i = 0; i < k; i ++ ) res = (res + r[i] * pre[k - i - 1] % P * b[p]) % P;
    }
 
    cout << res;
}
 
int main()
{ 
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
 
    ll _ = 1;
    //cin >> _;
 
    while (_ -- )
    solve();
 
    return 0;
}