题解:BISHI127 区间根号与区间求和
题目链接
题目描述
维护长度为 $n$ 的数组 $a_1, a_2, \dots, a_n$,支持两类操作:
- 区间根号:把区间 $[l, r]$ 内每个 $a_i$ 替换为 $\lfloor \sqrt{a_i} \rfloor$;
- 区间和查询:输出 $\sum_{i=l}^{r} a_i$。
解题思路
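用线段树维护区间和与区间最大值:对“区间根号”操作,若当前节点区间的最大值 $\le 1$,则该区间再怎么开根号都不变(因为 $\lfloor\sqrt{0}\rfloor = 0$,$\lfloor\sqrt{1}\rfloor = 1$),直接剪枝;否则递归到叶子结点执行 $a_i \leftarrow \lfloor \sqrt{a_i} \rfloor$,
更新后回溯维护区间和与最大值。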
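正确性与复杂度:每个元素在经历若干次 $\lfloor\sqrt{\cdot}\rfloor$ 后很快收敛到 $0$ 或 $1$。设 $V$ 为元素的最大值,则对每个点的有效修改次数为 $O(\log\log V)$(例如 $V = 10^{12}$ 时依次变为 $10^6, 1000, 31, 5, 2, 1$,至多 $6$ 次)。结合线段树,整体复杂度约为 $O\big((n\log\log V + m)\log n\big)$,能通过本题。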
代码
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
// Segment tree over a[1..n] keeping, per node, the interval sum (sum) and the
// interval maximum (mx). range_sqrt prunes any node whose maximum is <= 1,
// because floor(sqrt(.)) leaves 0 and 1 unchanged.
// mx is i64 (it used to be int): casting values above INT_MAX to int
// truncated them (e.g. 2^32 became 0), so the pruning test skipped elements
// that still had to be updated.
struct SegTree {
    int n;
    vector<i64> sum;
    vector<i64> mx;

    SegTree(int n = 0) : n(n) {
        // 4n (+ slack) is a safe node count for a recursive segment tree.
        sum.assign(4 * n + 4, 0);
        mx.assign(4 * n + 4, 0);
    }

    // Exact floor(sqrt(x)) for x > 0 (returns 0 for x <= 0). The double
    // estimate can be off by one for large x, so it is corrected with
    // division-based comparisons that cannot overflow:
    //   v > x / v        <=>  v * v > x
    //   v + 1 <= x/(v+1) <=>  (v + 1)^2 <= x
    static i64 isqrt(i64 x) {
        if (x <= 0) return 0;
        i64 v = (i64)sqrt((double)x);
        while (v > 0 && v > x / v) --v;
        while (v + 1 <= x / (v + 1)) ++v;
        return v;
    }

    // Builds node idx covering a[l..r] (1-indexed). An empty range (n == 0)
    // is a no-op instead of unbounded recursion.
    void build(int idx, int l, int r, const vector<i64>& a) {
        if (l > r) return;
        if (l == r) {
            sum[idx] = a[l];
            mx[idx] = a[l];
            return;
        }
        int m = (l + r) >> 1;
        build(idx << 1, l, m, a);
        build(idx << 1 | 1, m + 1, r, a);
        pull(idx);
    }

    // Recomputes node idx from its two children.
    void pull(int idx) {
        sum[idx] = sum[idx << 1] + sum[idx << 1 | 1];
        mx[idx] = max(mx[idx << 1], mx[idx << 1 | 1]);
    }

    // Applies a[i] = floor(sqrt(a[i])) to every i in the intersection of
    // [ql, qr] and node idx's interval [l, r]. Subtrees whose max is <= 1
    // are skipped.
    void range_sqrt(int idx, int l, int r, int ql, int qr) {
        if (ql > r || qr < l || mx[idx] <= 1) return;
        if (l == r) {
            i64 v = isqrt(sum[idx]);
            sum[idx] = v;
            mx[idx] = v;
            return;
        }
        int m = (l + r) >> 1;
        range_sqrt(idx << 1, l, m, ql, qr);
        range_sqrt(idx << 1 | 1, m + 1, r, ql, qr);
        pull(idx);
    }

    // Returns the sum of a[i] over the intersection of [ql, qr] and [l, r].
    i64 range_sum(int idx, int l, int r, int ql, int qr) {
        if (ql > r || qr < l) return 0;
        if (ql <= l && r <= qr) return sum[idx];
        int m = (l + r) >> 1;
        return range_sum(idx << 1, l, m, ql, qr) + range_sum(idx << 1 | 1, m + 1, r, ql, qr);
    }
};
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m; if(!(cin>>n>>m)) return 0;
vector<i64> a(n+1); for(int i=1;i<=n;i++) cin>>a[i];
SegTree st(n); st.build(1,1,n,a);
while(m--){
int op,l,r; cin>>op>>l>>r;
if(op==1){ st.range_sqrt(1,1,n,l,r); }
else { cout<<st.range_sum(1,1,n,l,r)<<'\n'; }
}
return 0;
}
import java.io.*;
/**
 * Solution for "range floor-sqrt, range sum".
 *
 * <p>Reads {@code n m}, then {@code a[1..n]}, then {@code m} operations {@code op l r}. For
 * {@code op == 1} every element of {@code [l, r]} is replaced by floor(sqrt(.)); for any other
 * {@code op} the sum of {@code [l, r]} is printed on its own line.
 */
public class Main {
    /** Buffered reader of whitespace-separated (optionally negative) decimal integers. */
    static class FastScanner {
        /** Returned by {@link #read()} at end of stream; cannot collide with a (signed) byte. */
        private static final int EOF = Integer.MIN_VALUE;

        private final InputStream in;
        private final byte[] buf = new byte[1 << 16];
        private int p = 0, l = 0;

        FastScanner(InputStream is) {
            in = is;
        }

        /**
         * Returns the next byte as a signed value, or {@link #EOF}. Bytes >= 0x80 come back
         * negative and are therefore treated like blanks by the callers.
         */
        private int read() throws IOException {
            if (p >= l) {
                l = in.read(buf);
                p = 0;
                if (l <= 0) {
                    return EOF;
                }
            }
            return buf[p++];
        }

        /**
         * Skips blanks (any byte {@code <= 32}) and returns the first byte of the next token.
         *
         * @throws EOFException if the input ends first. The previous code looped forever here,
         *     because the end-of-stream marker -1 also satisfied {@code c <= 32}.
         */
        private int firstTokenByte() throws IOException {
            int c;
            do {
                c = read();
                if (c == EOF) {
                    throw new EOFException("unexpected end of input");
                }
            } while (c <= 32);
            return c;
        }

        /** Reads the next token as an {@code int} (no overflow checking). */
        int nextInt() throws IOException {
            return (int) nextLong();
        }

        /** Reads the next token as a {@code long} (no overflow checking). */
        long nextLong() throws IOException {
            int c = firstTokenByte();
            long sign = 1, x = 0;
            if (c == '-') {
                sign = -1;
                c = read();
            }
            while (c > 32) {
                x = x * 10 + (c - '0');
                c = read();
            }
            return x * sign;
        }
    }

    /**
     * Segment tree over {@code a[1..n]} storing, per node, the interval sum and the interval
     * maximum. Both are {@code long}. The maximum used to be an {@code int}, so values above
     * {@code Integer.MAX_VALUE} were truncated (e.g. 2^32 became 0), and the "max <= 1" pruning
     * in {@link #rangeSqrt} then skipped elements that still had to be updated.
     */
    static class SegTree {
        int n;
        long[] sum;
        long[] mx;

        SegTree(int n) {
            this.n = n;
            // 4n (+ slack) is a safe node count for a recursive segment tree.
            sum = new long[4 * n + 4];
            mx = new long[4 * n + 4];
        }

        /**
         * Exact floor(sqrt(x)) for {@code x > 0}; returns 0 for {@code x <= 0}. The double estimate
         * can be off by one for large x, so it is corrected with division-based comparisons that
         * cannot overflow ({@code v > x / v} iff {@code v * v > x}).
         */
        static long isqrt(long x) {
            if (x <= 0) {
                return 0;
            }
            long v = (long) Math.sqrt((double) x);
            while (v > 0 && v > x / v) {
                v--;
            }
            while (v + 1 <= x / (v + 1)) {
                v++;
            }
            return v;
        }

        /** Builds node {@code idx} covering {@code a[l..r]}; an empty range (n == 0) is a no-op. */
        void build(int idx, int l, int r, long[] a) {
            if (l > r) {
                return;
            }
            if (l == r) {
                sum[idx] = a[l];
                mx[idx] = a[l];
                return;
            }
            int m = (l + r) >>> 1;
            build(idx << 1, l, m, a);
            build(idx << 1 | 1, m + 1, r, a);
            pull(idx);
        }

        /** Recomputes node {@code idx} from its two children. */
        void pull(int idx) {
            sum[idx] = sum[idx << 1] + sum[idx << 1 | 1];
            mx[idx] = Math.max(mx[idx << 1], mx[idx << 1 | 1]);
        }

        /**
         * Applies {@code a[i] = floor(sqrt(a[i]))} for every i in the intersection of
         * {@code [ql, qr]} and this node's interval {@code [l, r]}, skipping subtrees whose
         * maximum is already {@code <= 1}.
         */
        void rangeSqrt(int idx, int l, int r, int ql, int qr) {
            if (ql > r || qr < l || mx[idx] <= 1) {
                return;
            }
            if (l == r) {
                long v = isqrt(sum[idx]);
                sum[idx] = v;
                mx[idx] = v;
                return;
            }
            int m = (l + r) >>> 1;
            rangeSqrt(idx << 1, l, m, ql, qr);
            rangeSqrt(idx << 1 | 1, m + 1, r, ql, qr);
            pull(idx);
        }

        /** Returns the sum of {@code a[i]} over the intersection of {@code [ql, qr]} and {@code [l, r]}. */
        long rangeSum(int idx, int l, int r, int ql, int qr) {
            if (ql > r || qr < l) {
                return 0;
            }
            if (ql <= l && r <= qr) {
                return sum[idx];
            }
            int m = (l + r) >>> 1;
            return rangeSum(idx << 1, l, m, ql, qr) + rangeSum(idx << 1 | 1, m + 1, r, ql, qr);
        }
    }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n;
        int m;
        try {
            n = fs.nextInt();
            m = fs.nextInt();
        } catch (EOFException e) {
            return; // empty input: exit quietly, as the C++ version does
        }
        long[] a = new long[n + 1];
        for (int i = 1; i <= n; i++) {
            a[i] = fs.nextLong();
        }
        SegTree st = new SegTree(n);
        st.build(1, 1, n, a);
        StringBuilder out = new StringBuilder();
        while (m-- > 0) {
            int op = fs.nextInt();
            int l = fs.nextInt();
            int r = fs.nextInt();
            if (op == 1) {
                st.rangeSqrt(1, 1, n, l, r);
            } else {
                out.append(st.rangeSum(1, 1, n, l, r)).append('\n');
            }
        }
        System.out.print(out.toString());
    }
}
import sys, math
# Read every whitespace-separated token from stdin at once.
data = sys.stdin.buffer.read().split()
if len(data) < 2:
    # Empty/truncated header: exit quietly (mirrors the C++ `return 0`)
    # instead of dying with an uncaught StopIteration traceback.
    sys.exit(0)
it = iter(data)
n = int(next(it)); m = int(next(it))
# 1-indexed copy of the array; a[0] is unused padding.
a = [0] * (n + 1)
for i in range(1, n + 1):
    a[i] = int(next(it))
# 4n (+ slack) is a safe node count for a recursive segment tree.
size = 4 * n + 5
sumv = [0] * size  # sumv[idx]: sum of node idx's interval
mxv = [0] * size   # mxv[idx]: max of node idx's interval (drives pruning)
def floor_isqrt(x: int) -> int:
    """Return floor(sqrt(x)) exactly for a non-negative integer ``x``.

    ``math.sqrt`` works in double precision and may be off by one for large
    ``x``; the estimate is nudged until ``r*r <= x < (r+1)*(r+1)``.
    """
    r = int(math.sqrt(x))
    # Estimate too small: climb while the next square still fits under x.
    while (r + 1) ** 2 <= x:
        r += 1
    # Estimate too large: descend until r*r no longer exceeds x.
    while r ** 2 > x:
        r -= 1
    return r
def build(idx, l, r):
    """Recursively fill node ``idx``, which covers ``a[l..r]`` (1-indexed, inclusive).

    Writes the interval sum into ``sumv[idx]`` and the interval maximum into
    ``mxv[idx]``, reading the module-level list ``a``.
    """
    if l > r:
        # Empty interval (n == 0): without this guard the recursion never
        # reaches a leaf and ends in RecursionError.
        return
    if l == r:
        sumv[idx] = a[l]
        mxv[idx] = a[l]
        return
    mid = (l + r) >> 1
    left, right = idx << 1, idx << 1 | 1
    build(left, l, mid)
    build(right, mid + 1, r)
    sumv[idx] = sumv[left] + sumv[right]
    mxv[idx] = max(mxv[left], mxv[right])
def range_sqrt(idx, l, r, ql, qr):
    """Replace every element of ``a[ql..qr]`` with ``floor(sqrt(.))``.

    Node ``idx`` covers ``[l, r]``. A node is skipped when it does not
    intersect the query, or when its maximum is already <= 1 (applying
    floor(sqrt) to 0 or 1 leaves it unchanged).
    """
    if r < ql or l > qr:
        return
    if mxv[idx] <= 1:
        return
    if l != r:
        mid = (l + r) >> 1
        lc = idx << 1
        rc = lc | 1
        range_sqrt(lc, l, mid, ql, qr)
        range_sqrt(rc, mid + 1, r, ql, qr)
        # Pull the updated children back into this node.
        sumv[idx] = sumv[lc] + sumv[rc]
        mxv[idx] = mxv[lc] if mxv[lc] > mxv[rc] else mxv[rc]
        return
    root = floor_isqrt(sumv[idx])
    sumv[idx] = mxv[idx] = root
def range_sum(idx, l, r, ql, qr):
    """Return the sum of ``a[ql..qr]`` restricted to node ``idx``'s interval ``[l, r]``."""
    if r < ql or l > qr:
        return 0  # no overlap with the query
    if l >= ql and r <= qr:
        return sumv[idx]  # node fully inside the query
    mid = (l + r) >> 1
    left_part = range_sum(2 * idx, l, mid, ql, qr)
    right_part = range_sum(2 * idx + 1, mid + 1, r, ql, qr)
    return left_part + right_part
# Build the tree once, then process the m operations in input order.
build(1, 1, n)
out = []
for _ in range(m):
    op = int(next(it))
    l = int(next(it)); r = int(next(it))
    if op == 1:
        range_sqrt(1, 1, n, l, r)
    else:
        out.append(str(range_sum(1, 1, n, l, r)))
# Terminate every answer line with '\n', matching the C++/Java versions
# (the old join left the final line without a newline).
if out:
    sys.stdout.write('\n'.join(out) + '\n')
算法及复杂度
- 算法:线段树维护区间和与区间最大值,区间根号时对最大值 $\le 1$ 的区间剪枝
- 时间复杂度:近似 $O\big((n\log\log V + m)\log n\big)$,其中 $V$ 为元素最大值
- 空间复杂度:$O(n)$