ACM模版

描述

题解

这个题用线段树可解,奇思妙想啊~~~

首先我们可以很容易理解的是, S 序列的图像可以抽象为锯齿状,我们需要将注意力放在上齿,例如: S={1,2,3,2,3,4,1,2,5} ,这里的 { 1,2,3}{ 2,3,4}{ 1,2,5} 就是上齿。

然后我们需要考虑将这个序列分解为 N 个固定左端点的极大近似有序区间,例如上述这个序列,固定第 2 个为左端点的极大近似有序区间是 { 2,3,2,3,4} ,固定第 6 个为左端点的极大近似有序区间是 {4}

找到上述的所有极大近似有序区间后,我们也就只需要考虑固定左端点条件下,该区间有多少结点作为右端点可选,可选的条件自然是要大于等于左边一直到左端点的所有数,这也就构成了区间查询,不断查询并且缩小查询区间,一直到无法查询为止,也就找到了所有的合法右端点。

所以根据上述过程,我们可以很容易想到,这里需要倒着扫描序列,这样就很容易用线段树解决这个问题了。

代码

#include <cstdio>
#include <iostream>

#define ls rt << 1
#define rs rt << 1 | 1

using namespace std;

const int MAXN = 5e4 + 10;

int n;
int S[MAXN];
int vis[MAXN];
int sum[MAXN];
int tree[MAXN << 2];

void build(int rt, int l, int r)
{
    if (l == r)
    {
        tree[rt] = S[l];
        return ;
    }

    int m = (l + r) >> 1;
    build(ls, l, m);
    build(rs, m + 1, r);

    tree[rt] = max(tree[ls], tree[rs]);
}

int query(int rt, int l, int r, int x, int y, int mx)
{
    if (l == r)
    {
        if (S[l] >= mx)
        {
            return l;
        }

        return 0;
    }

    int m = (l + r) >> 1, k = 0;
    if (tree[rt] >= mx)
    {
        if (x <= m)
        {
            k = query(ls, l, m, x, y, mx);
        }
        if (k == 0 && y > m)
        {
            k = query(rs, m + 1, r, x, y, mx);
        }
    }

    return k;
}

int main()
{
    cin >> n;

    for (int i = 1; i <= n; i++)
    {
        cin >> S[i];
    }

    build(1, 1, n);

    for (int i = n; i >= 1; i--)
    {
        if (S[i] > S[i + 1])
        {
            sum[i] = 1;
            vis[i] = i;
        }
        else
        {
            vis[i] = vis[i + 1];
            sum[i] = sum[i + 1] + 1;
            int mx = S[vis[i]];
            int l = vis[i] + 1, r = l;

            while (S[r] >= S[i])
            {
                r = vis[r] + 1;
            }
            r -= 1;

            while (l <= r)
            {
                mx = query(1, 1, n, l, r, mx);
                if (mx != 0)
                {
                    sum[i]++;
                }
                else
                {
                    break;
                }

                l = mx + 1;
                vis[i] = mx;
                mx = S[mx];
            }
        }
    }

    long long ans = 0;
    for (int i = 1; i <= n; i++)
    {
        ans += sum[i];
    }
    cout << ans << '\n';

    return 0;
}