# J题.py (Problem J)
import sys
from sys import stdin

# Allow very deep Python recursion. A limit this high can still exhaust the
# C stack and crash the interpreter, so it is not a guarantee.
sys.setrecursionlimit(1000000)


def input():
    """Read one line from stdin with surrounding whitespace stripped.

    Deliberately shadows the builtin ``input`` for the rest of the script.
    """
    return stdin.readline().strip()


# First line: array length n and number of queries q.
n, q = map(int, input().split())
# Values are stored 1-indexed; A[0] is an unused placeholder.
A = [0] + [int(tok) for tok in input().split()]
# Disjoint-set parent array over indices 0..n+1 (every index starts as its own
# root). The extra slot n+1 keeps node[k + 1] valid when k == n.
node = list(range(n + 2))
# Fenwick-tree storage; indices 1..n+1 are used by add/get.
SUM = [0] * (n + 2)
def F(x):
    """Return ``2 * floor(|x**3 - 3*x| / (3*x**2 + 1))``.

    For integer ``x`` the result is non-negative, and ``F(0) == 0``.
    """
    numerator = abs(x ** 3 - 3 * x)
    denominator = 3 * x ** 2 + 1
    return 2 * (numerator // denominator)
def find(x):
    """Return the representative (root) of ``x`` in the disjoint-set array ``node``.

    An index ``k`` is a root when ``node[k] == k``. On return, every index on
    the path from ``x`` to its root points directly at the root (path
    compression), exactly as the recursive formulation would leave it.

    Implemented iteratively. A recursive version recurses once per parent
    link, and merging consecutive positions left-to-right in one update query
    builds chains ``node[k] = k + 1`` that can be as long as the array. That
    could raise RecursionError or crash the interpreter regardless of
    ``sys.setrecursionlimit``.
    """
    # First pass: walk up the parent links to locate the root.
    root = x
    while node[root] != root:
        root = node[root]
    # Second pass: re-point every index on the path straight at the root.
    while x != root:
        parent = node[x]
        node[x] = root
        x = parent
    return root
def low_bit(x):
    """Return the lowest set bit of ``x`` (e.g. 12 -> 4); 0 when ``x`` is 0."""
    # For Python ints, ~x + 1 == -x (two's-complement identity).
    return x & (~x + 1)
def add(index, x):
    """Add ``x`` at position ``index`` of the Fenwick tree stored in ``SUM``.

    Climbs from ``index`` through every tree node whose range covers it,
    stopping once the position reaches ``n + 2``. ``index`` is expected to be
    >= 1; with 0 the position would never advance (NOTE(review): callers in
    this file appear to pass >= 1 — confirm against the input constraints).
    """
    limit = n + 2
    pos = index
    while pos < limit:
        SUM[pos] += x
        pos += low_bit(pos)
def get(index):
    """Return the Fenwick prefix sum over positions 1..``index`` of ``SUM``.

    Returns 0 for ``index <= 0``.
    """
    total = 0
    pos = index
    while pos > 0:
        total += SUM[pos]
        pos -= low_bit(pos)  # drop the lowest set bit to move to the parent range
    return total
# Seed the Fenwick tree with the initial values at positions 1..n.
for pos in range(1, n + 1):
    value = A[pos]
    add(pos, value)
# Process the q queries, one per line as "op l r":
#   op == 1   -> replace A[k] with F(A[k]) for k in [l, r]; positions already
#                spliced out of the DSU (their value was a fixed point of F)
#                are skipped;
#   otherwise -> print A[l] + ... + A[r].
for _ in range(q):
    query = list(map(int, input().split()))
    is_update = query[0] == 1
    l, r = query[1:]
    if not is_update:
        print(get(r) - get(l - 1))
        continue
    # Start at the first position >= l that has not been spliced out.
    pos = find(node[l])
    while pos <= r:
        old = A[pos]
        new = F(old)
        if new == old:
            # A[pos] is a fixed point of F, so later updates cannot change it:
            # link it to the next live position so future scans skip it.
            node[pos] = find(node[pos + 1])
        add(pos, new - old)  # keep the Fenwick tree in sync with A
        A[pos] = new
        pos = find(node[pos + 1])