Atcoder abc 234 G

题目链接

题意

对于长度为 n \ n\ 的序列 ai \ a_i\ ,有 2n1 \ 2^{n-1}\ 种方案将其划分为非空连续的子区间B1,B2,...,Bk B_1, B_2, ... ,B_k\ ,对于所有划分方法,求出其关于以下式子对 998244353 \ 998244353\ 取模后的和:

i=1k(max(Bi)min(Bi))\prod_{i = 1}^{k}(max(B_i) - min(B_i))

数据范围

  • 1n3×1051 \leq n \leq 3\times 10^5
  • 1ai1091 \leq a_i \leq 10^9

样例输入

3
1 2 3

4
1 10 1 10

10
699498050 759726383 769395239 707559733 72435093 537050110 880264078 699299140 418322627 134917794

样例输出

2

90

877646588

题解

 dp[i] \ dp[i]\ 表示前 i \ i\ 项的结果,那么可以得出状态转移方程为:

dp[i]=j=1i1dp[j]×max(Bj+1,Bj+2,...,Bi)j=1i1dp[j]×min(Bj+1,Bj+2,...,Bi)dp[i] = \sum_{j = 1}^{i - 1} dp[j] \times max(B_{j+1}, B_{j+2}, ... , B_i) - \sum_{j = 1}^{i - 1} dp[j] \times min(B_{j+1}, B_{j+2}, ... , B_i)

 max \ max\ 即前半部分转移为例子,为了能够线性转移方程,我们就需要维护两样东西:

  1. max(Bj+1,Bj+2,...,Bi)max(B_{j+1}, B_{j+2}, ... , B_i)
  2.  max \ max\ 值所对应的dp值的和

我们可以利用单调队列维护这两样东西实现线性的转移。

Code

#pragma GCC optimize(2)
#pragma GCC optimize(3)

#include <bits/stdc++.h>

using namespace std;

#define ll long long

constexpr int N = 3e5 + 100;
constexpr int mod = 998244353;

ll a[N];
ll dp[N];

int main() {
    ios::sync_with_stdio(false);
    int n; cin >> n;
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
    }

    stack<ll> max_, min_, max_v, min_v;

    dp[0] = 1;

    ll sumv = 0, sumx = 0;

    for (int i = 1; i <= n; ++i) {
        ll valv = 0;
        while (not max_.empty() and max_.top() < a[i]) {
            sumv = ((sumv - max_.top() * max_v.top() % mod) % mod + mod) % mod;
            sumv = (sumv + max_v.top() * a[i] % mod) % mod;
            valv = (valv + max_v.top()) % mod;
            max_v.pop(); max_.pop();
        }
        sumv = (sumv + dp[i - 1] * a[i] % mod) % mod;
        valv = (valv + dp[i - 1]) % mod;
        max_.push(a[i]); max_v.push(valv);

        ll valx = 0;
        while (not min_.empty() and min_.top() > a[i]) {
            sumx = ((sumx - min_.top() * min_v.top() % mod) % mod + mod) % mod;
            sumx = (sumx + min_v.top() * a[i] % mod) % mod;
            valx = (valx + min_v.top()) % mod;
            min_.pop(); min_v.pop();
        }

        sumx = (sumx + dp[i - 1] * a[i] % mod) % mod;
        valx = (valx + dp[i - 1]) % mod;
        min_.push(a[i]); min_v.push(valx);

        dp[i] = (sumv - sumx + mod) % mod;

//        cout << i << ' ' << sumv << ' ' << sumx << ' ' << dp[i] << '\n';
    }

    cout << dp[n] << '\n';
}