T1
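题意:给定序列 \(a_1,\dots,a_n\),求 \(\sum_{1\le l\le r\le n}\min_{l\le t\le r}a_t\cdot\max_{l\le t\le r}a_t\cdot(r-l+1) \bmod 10^9\)。数据范围:\(n \le 5 \times 10^5,\ a_i \le 10^8\)。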
今天就临时测了这么一道题。
考场上用单调栈水了70分,结果那30分还是因为少取模(捂脸)
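正解是分治:先递归计算左半部分,再递归计算右半部分;对于跨过中点的区间,在左半部分用指针 i 从右向左枚举左端点,右半部分用两个只会右移的指针 j、k,分别表示区间最小值、最大值仍取在左半部分时右端点能到达的最远位置。然后按右端点落在 \((mid,\min(j,k)]\)、\((\min(j,k),\max(j,k)]\)、\((\max(j,k),r]\) 三段大力分类讨论,维护一大堆前缀和数组就 OK 了。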
挂一下sunny_r的链接
单调栈 \(O(n \cdot 玄学)\)(序列单调时最坏为 \(O(n^2)\))
#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
int n;          // sequence length
int top;        // top index of the monotonic stack `zhan`
LL ans;         // running answer, reduced modulo `mod` after every addition
const int N = 500010;
const int mod = 1e9;
int da[N];      // da[i]: first j > i with a[j] >= a[i], or n + 1 if none
int xiao[N];    // xiao[i]: first j > i with a[j] <= a[i], or n + 1 if none
int zhan[N];    // index stack used to build da / xiao
LL a[N];        // input sequence, 1-indexed
LL s[N];        // NOTE(review): not referenced by this program
// Reads the next integer from stdin. Non-digit characters before the number are
// skipped; a '-' among them makes the result negative. Decimal digits are then
// accumulated until the first non-digit. Returns 0 if input ends before a digit.
// `ch` holds getchar()'s int result (not a char), so EOF is distinguishable and
// isdigit() only ever sees unsigned-char values or EOF. Stopping at EOF keeps
// truncated input from spinning the skip loop forever.
inline long long read()
{
    long long res = 0;
    bool negative = false;
    int ch = getchar();
    while (ch != EOF && !isdigit(ch))
    {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch))
    {
        res = res * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -res : res;
}
// Sum of the consecutive integers a, a+1, ..., b (0 when b == a - 1), reduced modulo `mod`.
inline LL S(LL a, LL b)
{
    const LL count = b - a + 1;
    const LL sum = (a + b) * count / 2;   // exact: one of the two factors is always even
    return sum % mod;
}
void solve3()
{
top = 0;
for (int i = 1; i <= n; ++i)
{
while (top && a[i] <= a[zhan[top]])xiao[zhan[top--]] = i;
zhan[++top] = i;
}
top = 0;
for (int i = 1; i <= n; ++i)
{
while (top && a[i] >= a[zhan[top]]) da[zhan[top--]] = i;
zhan[++top] = i;
}
for (int i = 1; i <= n; ++i)
{
if (!da[i])da[i] = n + 1;
if (!xiao[i])xiao[i] = n + 1;
}
LL ma, mi;
for (int i = 1; i <= n; ++i)
{
ma = mi = i;
while (ma <= n || mi <= n)
{
if (da[ma] <= xiao[mi] && ma != n + 1)(ans += (LL)a[ma] * a[mi] % mod * S(max(ma, mi) - i + 1, da[ma] - i) % mod) %= mod, ma = da[ma];
else (ans += (LL)a[ma] * a[mi] % mod * S(max(ma, mi) - i + 1, xiao[mi] - i) % mod) %= mod, mi = xiao[mi];
}
}
cout << ans;
}
// Entry point of the monotonic-stack solution: reads n and a[1..n], then
// solve3() computes and prints the answer.
signed main()
{
    cin >> n;
    int filled = 0;
    while (filled < n)
        a[++filled] = read();
    solve3();
    return 0;
}
正解 \(O(n\log n)\)
#include<iostream>
#include<cstdio>
#define int long long
#define LL long long
using namespace std;
int n;             // sequence length (`int` is long long here via the macro above)
LL ans;            // running answer; can be negative between reductions
const int N = 500010;
const int mod = 1e9;
int a[N];          // input sequence, 1-indexed
int c[N][2];       // c[i][0] / c[i][1]: min / max of a[i..mid] in the current solve() call
int p[N][2];       // prefix sums of right-half running minima: [0] plain, [1] weighted by (y - mid)
int q[N][2];       // prefix sums of right-half running maxima: [0] plain, [1] weighted by (y - mid)
int f[N];          // prefix sums of (running min * running max) * (y - mid)
int g[N];          // prefix sums of (running min * running max)
// Reads the next integer from stdin. Non-digit characters before the number are
// skipped; a '-' among them makes the result negative. Decimal digits are then
// accumulated until the first non-digit. Returns 0 if input ends before a digit.
// `ch` holds getchar()'s int result (not a char), so EOF is distinguishable and
// isdigit() only ever sees unsigned-char values or EOF. Stopping at EOF keeps
// truncated input from spinning the skip loop forever.
inline long long read()
{
    long long res = 0;
    bool negative = false;
    int ch = getchar();
    while (ch != EOF && !isdigit(ch))
    {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch))
    {
        res = res * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -res : res;
}
// Modular add in place: x = x + y, minus `mod` once if that reaches `mod`.
// Gives (x + y) mod `mod` when x and y are already in [0, mod).
void ad(int &x, int y)
{
    const int sum = x + y;
    x = (sum >= mod) ? sum - mod : sum;
}
// Modular subtract in place: x = x - y, plus `mod` once if that went negative.
// Gives (x - y) mod `mod` when x and y are already in [0, mod).
void dl(int &x, int y)
{
    const int diff = x - y;
    x = (diff < 0) ? diff + mod : diff;
}
int getsum(int x, int y) { return ((LL)(x + y) * (y - x + 1) >> 1) % mod; }
// Sum of the consecutive integers x, x+1, ..., y (0 when y == x - 1).
// The result is NOT reduced modulo `mod`.
LL S(LL x, LL y)
{
    const LL terms = y - x + 1;
    const LL doubled = (LL)(x + y) * terms;
    return doubled >> 1;
}
// Divide and conquer on [l, r] (requires l <= r). Adds to `ans` the sum, over
// every subarray [x, y] with l <= x <= y <= r, of
// min(a[x..y]) * max(a[x..y]) * (y - x + 1), modulo `mod`. `ans` may be left
// negative; main() normalizes it. Subarrays inside one half are handled by the
// recursive calls; the loop at the end handles those crossing the midpoint
// (x <= mid < y).
void solve(int l, int r)
{
    if (l == r) return (void)((ans += a[l] * a[l] % mod) %= mod);
    int mid = (l + r) >> 1;
    solve(l, mid); solve(mid + 1, r);
    // c[i][0] / c[i][1]: minimum / maximum of the left suffix a[i..mid].
    c[mid][0] = c[mid][1] = a[mid];
    for (int i = mid - 1; i >= l; --i)
    {
        c[i][0] = min(c[i + 1][0], a[i]);
        c[i][1] = max(c[i + 1][1], a[i]);
    }
    // Let mn / mx be the minimum / maximum of a[mid+1..y] for y in (mid, r].
    // Build these prefix sums mod `mod` (index mid holds the empty sum 0):
    //   g[y]    = sum mn*mx        f[y]    = sum mn*mx*(y-mid)
    //   p[y][0] = sum mn           p[y][1] = sum mn*(y-mid)
    //   q[y][0] = sum mx           q[y][1] = sum mx*(y-mid)
    // mn / mx are seeded from a[mid + 1], which exists because l < r here. This
    // avoids sentinels that assume |a_i| < mod.
    int mn = a[mid + 1], mx = a[mid + 1];
    f[mid] = g[mid] = p[mid][0] = p[mid][1] = q[mid][0] = q[mid][1] = 0;
    for (int i = mid + 1; i <= r; ++i)
    {
        mn = min(mn, a[i]); mx = max(mx, a[i]);
        f[i] = (LL)mn * mx % mod * (i - mid) % mod; (f[i] += f[i - 1]) %= mod;
        g[i] = (LL)mn * mx % mod; (g[i] += g[i - 1]) %= mod;
        p[i][0] = (p[i - 1][0] + mn) % mod;
        q[i][0] = (q[i - 1][0] + mx) % mod;
        p[i][1] = ((LL)mn * (i - mid) % mod + p[i - 1][1]) % mod;
        q[i][1] = ((LL)mx * (i - mid) % mod + q[i - 1][1]) % mod;
    }
    // For the current left end i:
    //   j = largest y with every a[mid+1..y] > c[i][0]  (minimum still from the left)
    //   k = largest y with every a[mid+1..y] < c[i][1]  (maximum still from the left)
    // As i moves left, the suffix min/max only widen, so j and k only move right.
    int j = mid, k = mid;
    for (int i = mid; i >= l; --i)
    {
        while (j < r && c[i][0] < a[j + 1])++j;
        while (k < r && c[i][1] > a[k + 1])++k;
        // y in (mid, min(j,k)]: min and max both come from the left. The lengths
        // y-i+1 run over mid-i+2 .. min(j,k)-i+1. S() is not reduced and reaches
        // ~1e11 for n = 5e5; multiplying that by a residue below mod would
        // overflow long long, so reduce it first.
        (ans += (LL)c[i][0] * c[i][1] % mod * (S(mid - i + 2, min(j, k) - i + 1) % mod) % mod) %= mod;
        // y in (max(j,k), r]: min and max both come from the right.
        // sum mn*mx*(y-i+1) = (mid-i+1) * sum mn*mx + sum mn*mx*(y-mid).
        (ans += ((LL)g[r] * (mid - i + 1) + f[r]) % mod) %= mod;
        (ans -= ((LL)g[max(j, k)] * (mid - i + 1) + f[max(j, k)]) % mod) %= mod;
        // y in (min(j,k), max(j,k)]: one extreme from each side (empty when j == k).
        if (j < k)
        {
            // Minimum from the right (prefix minima p); maximum c[i][1] from the left.
            (ans += ((LL)p[k][0] * (mid - i + 1) % mod + p[k][1]) % mod * c[i][1]) %= mod;
            (ans -= ((LL)p[j][0] * (mid - i + 1) % mod + p[j][1]) % mod * c[i][1]) %= mod;
        }
        else
        {
            // Maximum from the right (prefix maxima q); minimum c[i][0] from the left.
            (ans += ((LL)q[j][0] * (mid - i + 1) % mod + q[j][1]) % mod * c[i][0]) %= mod;
            (ans -= ((LL)q[k][0] * (mid - i + 1) % mod + q[k][1]) % mod * c[i][0]) %= mod;
        }
    }
}
// Entry point of the divide-and-conquer solution: reads n and a[1..n], then
// prints the answer normalized into [0, mod).
signed main()
{
    cin >> n;
    for (int i = 1; i <= n; ++i)a[i] = read();
    // solve() requires l <= r; solve(1, 0) would recurse forever. For an empty
    // sequence (or a failed read of n) the sum over subarrays is simply 0.
    if (n >= 1) solve(1, n);
    // ans can be negative after solve()'s subtractions, so normalize it.
    cout << (ans % mod + mod) % mod;
    return 0;
}