题意概括:
给定n篇文章,每篇有价值ai和体力消耗bi。对所有区间[l, r],选择区间内价值最大的min(k, 区间长度)篇文章,求这些文章消耗的体力值之和的总和(对所有区间求和)。由于答案很大,需要对1e9+7取模 核心问题:计算每个文章被多少个区间选入前k大,然后乘以bi并累加。关键在于高效判断每个文章在哪些区间内会入选(区间内价值大于它的文章数小于k)。
### 题解 分析howdo: 首先暴力做一定会超时,但是我们可以单独考虑每个元素,对于单独的元素,假设它是按a数组原顺序的第p个元素,我们需要统计它对答案的贡献是多少,不难得到它的贡献为b[p]*count,count是它在多少个区间内是排k名之前的(即在区间内大于a[p]的数小于k个),那现在本题就转化为求a[p]在多少个区间内能够被选到
怎么用数据结构实现:
我们按照a数组从小到大的排序,使用set,将a从大到小的加入set里面(加入的是a[p]对应的下标p),这里我们考虑使用pair来存储{a[p],p},使用sort可以a其排序,并且方便找到它的下标;每次加入一个数p,set里面的数对应的a都大于a[p],这样我们考虑a[p]的在那些区间内满足条件就很容易了,可以用c数组表示set里面p之前的k个数,d数组表示p之后的k个数 即找到ck ck-1 ck-2 ck-3..c1 p d1 d2 d3...dk-1;这样计算有多少区间满足的时候 就可以直接枚举 左端点l∈[l1=c[i]+1,l2=c[i-1]] 并可以求出r∈[r1:p,r2:d[k-i+1]-1]],那么满足田间的区间数就是(l2-l1+1)*(r2-r1+1) ,思路就是这样,难点就是c和d不一定满k个,所以我们枚举l的时候要注意边界:int l1=c[i]+1,l2=(c[i-1]?c[i-1]:v); 枚举c的时候d[k-i+1]可能不存在同样d也要注意边界:int r1=v;int r2=(d[k-i+1]?d[k-i+1]-1:n);
具体代码如下:
using namespace std;
const int N=1e5+10,mod=1e9+7;
using ll=long long;
using pll=pair<int,int>;
ll res;
int n,k,a[N],b[N];
pll p[N];
set<int>s;
int main()
{
cin>>n>>k;
for(int i=1;i<=n;i++)
{
cin>>a[i];
p[i]={a[i],i};
}
for(int i=1;i<=n;i++) cin>>b[i];
s.insert(0),s.insert(n+1);
sort(p+1,p+1+n,greater<pll>());
ll last=0;
for(int m=1;m<=n;m++)
{
auto t=p[m];int v=t.second;
s.insert(v);
auto it=s.lower_bound(v);
auto jt1=it,jt2=it;
int cnt1=1,cnt2=1;
vector<int>c(k+2,0);
vector<int>d(k+2,0);
for(int i=1;i<=k;i++)
{
jt1++;
if(jt1==s.end()) break;
d[cnt1++]=*jt1;
}
for(int i=1;i<=k;i++)
{
jt2--;
if(jt2==s.begin())
{
c[cnt2++]=*jt2;
break;
}
c[cnt2++]=*jt2;
}
last=res;
for(int i=1;i<cnt2;i++)
{
int l1=c[i]+1,l2=(c[i-1]?c[i-1]:v);
int r1=v;
int r2=(d[k-i+1]?d[k-i+1]-1:n);
res=(res+1ll*b[v]*(l2-l1+1)%mod*(r2-r1+1)%mod)%mod;
}
}
cout<<res<<'\n';
return 0;
}
其他优化版本:(借鉴他人的)
```#include <bits/stdc++.h>
#define x first
#define y second
using namespace std;
typedef long long ll;
typedef pair<ll, ll> PLL;
typedef unsigned long long ull;
mt19937_64 rng(time(0));
void chmax(ll &x, ll y){if (x < y) x = y;}
void chmin(ll &x, ll y){if (x > y) x = y;}
const ll P = 1e9 + 7;
void solve()
{
ll n, k;
cin >> n >> k;
vector<ll> a(n + 1), b(n + 1), id(n + 1);
iota(id.begin(), id.end(), 0ll);
for (ll i = 1; i <= n; i ++ ) cin >> a[i];
for (ll i = 1; i <= n; i ++ ) cin >> b[i];
sort(id.begin(), id.end(), [&](auto x, auto y){
return a[x] < a[y];
});
set<ll> st;
st.insert(0), st.insert(n + 1);
ll res = 0;
for (ll i = n; i >= 1; i -- )
{
vector<ll> l(k), r(k);
ll p = id[i];
st.insert(p);
auto it = st.lower_bound(p);
for (ll j = 0; j < k && *it != n + 1; j ++ )
{
r[j] = *next(it) - *it;
it ++;
}
it = st.lower_bound(p);
for (ll j = 0; j < k && *it != 0; j ++ )
{
l[j] = *it - *prev(it);
it --;
}
vector<ll> pre(k);
for (ll i = 0, sum = 0; i < k; i ++ )
{
sum += l[i];
pre[i] = sum;
}
for (ll i = 0; i < k; i ++ ) res = (res + r[i] * pre[k - i - 1] % P * b[p]) % P;
}
cout << res;
}
int main()
{
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
ll _ = 1;
//cin >> _;
while (_ -- )
solve();
return 0;
}