题解:BISHI127 区间根号与区间求和

题目链接

区间根号与区间求和

题目描述

维护长度为 $n$ 的数组 $a_1, a_2, \dots, a_n$,支持两类操作:

  • 区间根号:把区间 $[l, r]$ 内每个 $a_i$ 替换为 $\lfloor \sqrt{a_i} \rfloor$
  • 区间和查询:输出 $\sum_{i=l}^{r} a_i$

解题思路

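用线段树维护区间和与区间最大值:对“区间根号”操作,若当前节点区间的最大值 $\le 1$,则该区间再怎么开根号都不变(因为 $\lfloor\sqrt{0}\rfloor = 0$、$\lfloor\sqrt{1}\rfloor = 1$),直接剪枝;否则递归到底层点进行 $a_i \leftarrow \lfloor \sqrt{a_i} \rfloor$ 更新后回溯维护区间和与最大值。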

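正确性与复杂度:每个元素在经历若干次 $\lfloor\sqrt{\cdot}\rfloor$ 后很快收敛到 $0$ 或 $1$,对每个点的修改次数为 $O(\log \log V)$(其中 $V$ 为元素初始值的上界)。结合线段树,整体复杂度约为 $O((n \log \log V + m) \log n)$,能通过本题。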

代码

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

// Segment tree over a 1-indexed array that stores, per node, the sum and the
// maximum of its interval. It supports "replace every element of [ql, qr] by
// floor(sqrt(element))" and range-sum queries.
//
// Maxima are kept as 64-bit values: the input is read as i64, and narrowing a
// large element to int could make it look <= 1 and be skipped by the pruning
// test in range_sqrt.
struct SegTree {
    int n; vector<i64> sum; vector<i64> mx;

    // Allocates 4*n+4 zeroed slots for both per-node arrays.
    SegTree(int n=0): n(n) { sum.assign(4*n+4, 0); mx.assign(4*n+4, 0); }

    // Exact floor(sqrt(x)) for x >= 0 (returns 0 for x <= 0).
    // The floating-point result is only a seed; the two loops correct it using
    // integer division, so no product can overflow 64 bits.
    static i64 isqrt_floor(i64 x){
        if(x<=0) return 0;
        i64 r=(i64)std::sqrt((long double)x);
        while(r>x/r) --r;            // r*r > x      -> too large
        while(r+1<=x/(r+1)) ++r;     // (r+1)^2 <= x -> too small
        return r;
    }

    // Builds the node idx covering [l, r] from a[l..r] (a is 1-indexed).
    void build(int idx, int l, int r, const vector<i64>& a){
        if(l==r){ sum[idx]=a[l]; mx[idx]=a[l]; return; }
        int m=(l+r)>>1; build(idx<<1,l,m,a); build(idx<<1|1,m+1,r,a);
        pull(idx);
    }

    // Recomputes sum and max of idx from its two children.
    void pull(int idx){ sum[idx]=sum[idx<<1]+sum[idx<<1|1]; mx[idx]=max(mx[idx<<1], mx[idx<<1|1]); }

    // Applies floor(sqrt(.)) to every element of [ql, qr] inside node idx = [l, r].
    void range_sqrt(int idx,int l,int r,int ql,int qr){
        // Disjoint intervals need no work; an interval whose maximum is <= 1
        // only holds fixed points of floor(sqrt(.)), so it is skipped as well.
        if(ql>r || qr<l || mx[idx]<=1) return;
        if(l==r){ i64 v=isqrt_floor(sum[idx]); sum[idx]=v; mx[idx]=v; return; }
        int m=(l+r)>>1; range_sqrt(idx<<1,l,m,ql,qr); range_sqrt(idx<<1|1,m+1,r,ql,qr); pull(idx);
    }

    // Returns the sum of the elements of [ql, qr] that lie inside node idx = [l, r].
    i64 range_sum(int idx,int l,int r,int ql,int qr){
        if(ql>r || qr<l) return 0;
        if(ql<=l && r<=qr) return sum[idx];
        int m=(l+r)>>1; return range_sum(idx<<1,l,m,ql,qr)+range_sum(idx<<1|1,m+1,r,ql,qr);
    }
};

// Reads n, m and the 1-indexed array, builds the segment tree, then processes
// m operations "op l r": op == 1 applies the range square root, any other op
// prints the range sum on its own line.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    if(!(cin >> n >> m)) return 0;   // no header available: nothing to do

    vector<i64> a(n+1);              // slot 0 is padding for 1-based indexing
    for(int i=1; i<=n; ++i) cin >> a[i];

    SegTree st(n);
    st.build(1, 1, n, a);

    for(; m != 0; --m){
        int op, l, r;
        cin >> op >> l >> r;
        if(op != 1){
            cout << st.range_sum(1, 1, n, l, r) << '\n';
            continue;
        }
        st.range_sqrt(1, 1, n, l, r);
    }
    return 0;
}
import java.io.*;

public class Main {
    static class FastScanner {
        private final InputStream in; private final byte[] buf = new byte[1<<16];
        private int p=0,l=0; FastScanner(InputStream is){in=is;}
        private int read() throws IOException { if(p>=l){ l=in.read(buf); p=0; if(l<=0) return -1; } return buf[p++]; }
        int nextInt() throws IOException { int c; int s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
        long nextLong() throws IOException { int c; long s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
    }

    static class SegTree {
        int n; long[] sum; int[] mx;
        SegTree(int n){ this.n=n; sum=new long[4*n+4]; mx=new int[4*n+4]; }
        void build(int idx,int l,int r,long[] a){ if(l==r){ sum[idx]=a[l]; mx[idx]=(int)a[l]; return; } int m=(l+r)>>1; build(idx<<1,l,m,a); build(idx<<1|1,m+1,r,a); pull(idx); }
        void pull(int idx){ sum[idx]=sum[idx<<1]+sum[idx<<1|1]; mx[idx]=Math.max(mx[idx<<1], mx[idx<<1|1]); }
        void rangeSqrt(int idx,int l,int r,int ql,int qr){
            if(ql>r || qr<l || mx[idx]<=1) return;
            if(l==r){ int v=(int)Math.floor(Math.sqrt((double)sum[idx])); sum[idx]=v; mx[idx]=v; return; }
            int m=(l+r)>>1; rangeSqrt(idx<<1,l,m,ql,qr); rangeSqrt(idx<<1|1,m+1,r,ql,qr); pull(idx);
        }
        long rangeSum(int idx,int l,int r,int ql,int qr){ if(ql>r||qr<l) return 0; if(ql<=l && r<=qr) return sum[idx]; int m=(l+r)>>1; return rangeSum(idx<<1,l,m,ql,qr)+rangeSum(idx<<1|1,m+1,r,ql,qr); }
    }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n = fs.nextInt();
        int m = fs.nextInt();
        long[] a = new long[n+1];
        for(int i=1;i<=n;i++) a[i]=fs.nextLong();
        SegTree st = new SegTree(n);
        st.build(1,1,n,a);
        StringBuilder out = new StringBuilder();
        while(m-- > 0){
            int op = fs.nextInt();
            int l = fs.nextInt();
            int r = fs.nextInt();
            if(op==1){ st.rangeSqrt(1,1,n,l,r); }
            else { out.append(st.rangeSum(1,1,n,l,r)).append('\n'); }
        }
        System.out.print(out.toString());
    }
}
import sys, math

# Read all of stdin in one call and consume its whitespace-separated tokens lazily.
data = sys.stdin.buffer.read().split()
it = iter(data)
n, m = int(next(it)), int(next(it))

# The array is 1-indexed; slot 0 is padding.
a = [0] * (n + 1)
for i in range(1, n + 1):
    a[i] = int(next(it))

# Flat per-node storage for the segment tree: interval sums and interval maxima.
size = 4 * n + 5
sumv, mxv = [0] * size, [0] * size

def floor_isqrt(x: int) -> int:
    """Return floor(sqrt(x)) for a non-negative integer x.

    Uses exact integer arithmetic, so the result is correct for arbitrarily
    large ints (a float-based seed overflows or loses precision there).

    Raises:
        ValueError: if x is negative.
    """
    return math.isqrt(x)

def build(idx, l, r):
    """Fill sumv/mxv for node idx, which covers positions l..r of the array a."""
    if l != r:
        mid = (l + r) >> 1
        left = idx << 1
        right = left | 1
        build(left, l, mid)
        build(right, mid + 1, r)
        # An internal node aggregates its two children.
        sumv[idx] = sumv[left] + sumv[right]
        mxv[idx] = max(mxv[left], mxv[right])
        return
    # Leaf: both aggregates are the element itself.
    sumv[idx] = mxv[idx] = a[l]

def range_sqrt(idx, l, r, ql, qr):
    """Replace each element of [ql, qr] inside node idx = [l, r] by its floor square root."""
    outside = r < ql or qr < l
    # Nothing to do when the node misses the query, or when its maximum is
    # already <= 1 (0 and 1 are unchanged by floor(sqrt(.))).
    if outside or mxv[idx] <= 1:
        return
    if l != r:
        mid = (l + r) >> 1
        left = idx << 1
        right = left | 1
        range_sqrt(left, l, mid, ql, qr)
        range_sqrt(right, mid + 1, r, ql, qr)
        sumv[idx] = sumv[left] + sumv[right]
        mxv[idx] = max(mxv[left], mxv[right])
        return
    # Leaf: update the single element in place.
    root = floor_isqrt(sumv[idx])
    sumv[idx] = mxv[idx] = root

def range_sum(idx, l, r, ql, qr):
    """Return the sum of the elements of [ql, qr] that lie inside node idx = [l, r]."""
    if r < ql or qr < l:
        # No overlap with the query.
        return 0
    if ql <= l and r <= qr:
        # Node fully covered: its stored sum is the answer.
        return sumv[idx]
    mid = (l + r) >> 1
    left = idx << 1
    left_part = range_sum(left, l, mid, ql, qr)
    right_part = range_sum(left | 1, mid + 1, r, ql, qr)
    return left_part + right_part

# Build the tree over a[1..n] before serving any operation.
build(1, 1, n)

# Each operation is "op l r": op == 1 applies the range square root,
# anything else is answered with the range sum. Answers are written once at the end.
answers = []
for _ in range(m):
    op, l, r = int(next(it)), int(next(it)), int(next(it))
    if op != 1:
        answers.append(str(range_sum(1, 1, n, l, r)))
        continue
    range_sqrt(1, 1, n, l, r)

sys.stdout.write('\n'.join(answers))

算法及复杂度

  • 算法:线段树维护区间和与区间最大值,区间根号时对最大值 $\le 1$ 的区间剪枝
  • 时间复杂度:$O((n \log \log V + m) \log n)$($V$ 为元素初始值的上界),近似 $O((n + m) \log n)$
  • 空间复杂度:$O(n)$