利用分块维护每个区间的信息:
- blocks: 每个块中的所有数
- blocksinfo: 每个块中所有数排好序后的结果
- blockslz: 如果某次更改将这个块内所有的数全都加上了一个固定的值,将这个值增加到blockslz数组中该块对应的下标上
每次更新 (l, r, x):
- 通过计算得到这次更新涉及那些块
- 对于未被完整更新的块(块中有的数没有被更新),将该块中每个被影响的数加上x,之后把这个块新的排好序后的结果存到blocksinfo对应的下标中
- 对于被完整更新的块(块中所有数都被更新了),将x加到blockslz数组中该块对应的下标上,因为这个块所有的数都被增加了x,数组中所有数之间的大小关系是不变的,因此不需要更新对应的blocksinfo
每次查询(l, r, x):
- 通过计算得到这次更新涉及那些块
- 对于未被完整查询的块(块中有的数没有被查询),判断每个数在blocks中存储的值加上该块在blocksinfo中存储的值是否小于x
- 对于被完整查询的块(块中所有数都被查询了),利用二分法,判断blocksinfo中有多少个小于x减去该块blockslz值的数
'''
Hala Madrid!
https://github.com/USYDDonghaoLi/Programming_Competition
'''
import sys
import os
from io import BytesIO, IOBase
BUFSIZE = 8192
class FastIO(IOBase):
newlines = 0
def __init__(self, file):
self._fd = file.fileno()
self.buffer = BytesIO()
self.writable = "x" in file.mode or "r" not in file.mode
self.write = self.buffer.write if self.writable else None
def read(self):
while True:
b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
if not b:
break
ptr = self.buffer.tell()
self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
self.newlines = 0
return self.buffer.read()
def readline(self):
while self.newlines == 0:
b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
self.newlines = b.count(b"\n") + (not b)
ptr = self.buffer.tell()
self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
self.newlines -= 1
return self.buffer.readline()
def flush(self):
if self.writable:
os.write(self._fd, self.buffer.getvalue())
self.buffer.truncate(0), self.buffer.seek(0)
class IOWrapper(IOBase):
def __init__(self, file):
self.buffer = FastIO(file)
self.flush = self.buffer.flush
self.writable = self.buffer.writable
self.write = lambda s: self.buffer.write(s.encode("ascii"))
self.read = lambda: self.buffer.read().decode("ascii")
self.readline = lambda: self.buffer.readline().decode("ascii")
sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
input = lambda: sys.stdin.readline().rstrip("\r\n")
def I():
return input()
def II():
return int(input())
def MI():
return map(int, input().split())
def LI():
return list(input().split())
def LII():
return list(map(int, input().split()))
def GMI():
return map(lambda x: int(x) - 1, input().split())
#------------------------------FastIO---------------------------------
from bisect import *
from heapq import *
from collections import *
from functools import *
from itertools import *
from time import *
from random import *
from math import log, gcd, sqrt, ceil
inf = float('inf')
fmin = lambda x, y: x if x < y else y
fmax = lambda x, y: x if x > y else y
# @TIME
def solve(testcase):
n, q = MI()
A = LII()
blocksz = int(sqrt(n))
blocks = [[] for _ in range(blocksz + 10)]
blockinfo = [[] for _ in range(blocksz + 10)]
blocklz = [0 for _ in range(blocksz + 10)]
for i, v in enumerate(A):
where, idx = divmod(i, blocksz)
blocks[where].append(v)
for i in range(blocksz + 10):
blockinfo[i] = sorted(blocks[i])
def update(l, r, x):
lwhere, lidx = divmod(l, blocksz)
rwhere, ridx = divmod(r, blocksz)
if lwhere == rwhere:
for i in range(lidx, ridx + 1):
blocks[lwhere][i] += x
blockinfo[lwhere] = sorted(blocks[lwhere])
else:
for j in range(lidx, len(blocks[lwhere])):
blocks[lwhere][j] += x
blockinfo[lwhere] = sorted(blocks[lwhere])
lwhere += 1
for j in range(ridx + 1):
blocks[rwhere][j] += x
blockinfo[rwhere] = sorted(blocks[rwhere])
rwhere -= 1
for i in range(lwhere, rwhere + 1):
blocklz[i] += x
def query(l, r, x):
res = 0
lwhere, lidx = divmod(l, blocksz)
rwhere, ridx = divmod(r, blocksz)
if lwhere == rwhere:
for i in range(lidx, ridx + 1):
val = blocks[lwhere][i] + blocklz[lwhere]
if val < x:
res += 1
else:
for j in range(lidx, len(blocks[lwhere])):
val = blocks[lwhere][j] + blocklz[lwhere]
if val < x:
res += 1
lwhere += 1
for j in range(ridx + 1):
val = blocks[rwhere][j] + blocklz[rwhere]
if val < x:
res += 1
rwhere -= 1
for i in range(lwhere, rwhere + 1):
idx = bisect_left(blockinfo[i], x - blocklz[i])
res += idx
print(res)
for _ in range(q):
ops = LII()
op = ops[0]
if op == 1:
l, r, x = ops[1] - 1, ops[2] - 1, ops[3]
update(l, r, x)
else:
l, r, x = ops[1] - 1, ops[2] - 1, ops[3]
query(l, r, x)
for testcase in range(1):
solve(testcase)

京公网安备 11010502036488号