题号 NC200195
名称 区区区间
来源 牛客小白月赛20

题目描述

给定一个序列,有两种操作:

1 l r k:将区间[l,r]变成[k,k + 1,...,k + r - l] 的序列

2 l r:将区间求区间[l,r]的区间和

样例

输入
5 5
1 1 1 1 1
2 1 5
1 1 5 1
2 1 5
1 1 3 3
2 1 3
输出
5
15
12

算法1

(线段树维护区间和 和 覆盖问题)

很容易观察到操作1是将一段区间变成一段等差数列,
因为等差数列的通项为:a[i] = a[i - 1] + d(1),也可以写成:a[i] = a[1] + (i - 1)d(2),
如果将等差数列放在数轴[l,r]上,我们要求位置为r的数是多少,通过上面的公式(2)我们可以得到:a[r] = a[l] + (r - l)d。
所以对于区间[l,r],我们设i (l <= i <= r), 为第i个数所在等差数列中的起点的位置,为第i个数所在等差数列中的首项的大小,
则这一段区间和就是,
我们对公式整理一下:,
其中可以用 直接求;
对于,我们观察一下下标可以发现这是一段连续区间的区间和,所以想到用线段树维护维护区间和,分别是首项之和和起点下标之和,但是用线段树如何更新呢?我们模拟操作1,可以发现每次修改一段区间[l,r]这个区间中的所有数所在等差数列中的首项的大小都变成了k,并且这些数所在等差数列中的起点的位置统一变成了l,所以我们可以用线段树让区间的所有数变成同一个数的操作完成更新,线段树维护两个懒标记一个是这个区间数的首项,一个是这个区间数的等差数列的起点位置,同时维护两个区间和:区间的首项和以及区间起点下标之和

时间复杂度 O(NlogN)

C++ 代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <unordered_map>
#include <vector>
#include <queue>
#include <set>
#include <bitset>
#include <cmath>

#define P 131

#define lc u << 1
#define rc u << 1 | 1

using namespace std;
typedef long long LL;
const int N = 200100;
struct Node
{
    int l,r;
    int same1,same2;
    LL sum1,sum2;
    int len;
}tr[N * 4];
int a[N];
int n,m;

inline void pushup(Node &rt,Node &left,Node &right)
{
    rt.len = left.len + right.len;
    rt.sum2 = left.sum2 + right.sum2;
    rt.sum1 = left.sum1 + right.sum1;
}

inline void pushup(int u)
{
    pushup(tr[u],tr[lc],tr[rc]);
}

inline void pushdown(int u)
{
    if(tr[u].same1 != 0 && tr[u].same2 != 0)
    {
        tr[lc].same1 = tr[rc].same1 = tr[u].same1;
        tr[lc].same2 = tr[rc].same2 = tr[u].same2;
        tr[lc].sum2 = 1ll * tr[u].same2 * tr[lc].len;
        tr[rc].sum2 = 1ll * tr[u].same2 * tr[rc].len;
        tr[lc].sum1 = 1ll * tr[lc].len * tr[u].same1;
        tr[rc].sum1 = 1ll * tr[rc].len * tr[u].same1;
        tr[u].same1 = 0;
        tr[u].same2 = 0;
    }
}

inline void build(int u,int l,int r)
{
    if(l == r)
    {
        tr[u] = Node({l,r,0,0,a[l],l,1});
        return;
    }
    tr[u] = Node({l,r,0,0,0,0,0});
    int mid = l + r >> 1;
    build(lc,l,mid);
    build(rc,mid + 1,r);
    pushup(u);
}

inline void modify(int u,int l,int r,int k,int s)
{
    if(tr[u].l >= l && tr[u].r <= r)
    {
        tr[u].same1 = k;
        tr[u].same2 = s;
        tr[u].sum2 = 1ll * s * tr[u].len;
        tr[u].sum1 = 1ll * k * tr[u].len;
        return;
    }
    pushdown(u);
    int mid = (tr[u].l + tr[u].r) >> 1;
    if(l <= mid) modify(lc,l,r,k,s);
    if(r > mid) modify(rc,l,r,k,s);
    pushup(u);
}

inline Node query(int u,int l,int r)
{
    if(tr[u].l >= l && tr[u].r <= r) return tr[u];
    pushdown(u);
    int mid = (tr[u].l + tr[u].r) >> 1;
    if(r <= mid) return query(lc,l,r);
    else if(l > mid) return query(rc,l,r);
    else
    {
        Node res,left,right;
        left = query(lc,l,mid);
        right = query(rc,mid + 1,r);
        pushup(res,left,right);
        return res;
    }
}

void solve()
{
    scanf("%d%d",&n,&m);
    for(int i = 1;i <= n;i ++) scanf("%d",&a[i]);
    build(1,1,n);
    while(m --)
    {
        int t,l,r,k;
        scanf("%d",&t);
        if(t == 1)
        {
            scanf("%d%d%d",&l,&r,&k);
            modify(1,l,r,k,l);
        }else
        {
            scanf("%d%d",&l,&r);
            auto res = query(1,l,r);
            printf("%lld\n",res.sum1 + 1ll * (l + r) * (r - l + 1) / 2 - res.sum2);
        }
    }
}

int main()
{
    #ifdef LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #else
    #endif // LOCAL
    int T = 1;
    // init(100000);
    // scanf("%d",&T);
    while(T --)
    {
        // scanf("%lld%lld",&n,&m);
        solve();
        // test();
    }
    return 0;
}