# Problem J (Python solution)

import sys
from sys import stdin

# find() below may follow long parent chains recursively; raise the limit so
# deep chains do not trip RecursionError.
sys.setrecursionlimit(1000000)


def input():
    """Read one line from stdin with surrounding whitespace stripped.

    Deliberately shadows the builtin ``input`` for faster line reading.
    """
    return stdin.readline().strip()


# First line: n (number of array values) and q (number of queries).
n, q = map(int, input().split())
# A is 1-indexed: A[0] is a dummy so A[i] is the i-th input value.
A = [0] + list(map(int, input().split()))
# node: skip pointers. node[i] == i means index i may still change under F;
# slot n + 1 is a sentinel root just past the last array index.
node = list(range(n + 2))
# SUM: 1-indexed Fenwick tree (binary indexed tree) over A; slot 0 unused.
SUM = [0] * (n + 2)
def F(x):
    """Return 2 * floor(|x**3 - 3*x| / (3*x**2 + 1)).

    For integer ``x`` this is exact integer arithmetic. For every nonzero
    integer x the result satisfies 0 <= F(x) < |x|, so repeatedly applying F
    drives any integer to 0, its only fixed point.
    """
    numerator = abs(x**3 - 3 * x)
    denominator = 3 * x**2 + 1
    return 2 * (numerator // denominator)
def find(x, parent=None):
    """Return the root index reached from ``x`` by following parent links.

    An index is a root when it points to itself. After locating the root,
    every index on the walked path is rewritten to point straight at it
    (path compression). This is done iteratively, so arbitrarily long
    chains never hit Python's recursion limit.

    Args:
        x: Index to look up.
        parent: Parent-pointer list to operate on. It is modified in place.
            Defaults to the module-level ``node`` list.

    Returns:
        The root index for ``x``.
    """
    if parent is None:
        parent = node
    root = x
    while parent[root] != root:
        root = parent[root]
    # Second pass: redirect each index on the path directly to the root so
    # later lookups from these indices take a single step.
    while x != root:
        nxt = parent[x]
        parent[x] = root
        x = nxt
    return root
def low_bit(x):
    """Return the value of the lowest set bit of ``x`` (e.g. 12 -> 4).

    ANDing ``x`` with its two's-complement negation isolates the
    least-significant 1 bit. Returns 0 when ``x`` is 0.
    """
    return x & -x
def add(index, x):
    """Add ``x`` at 1-based position ``index`` of the Fenwick tree ``SUM``.

    Each step adds the cursor's lowest set bit to reach the next tree node
    whose range covers ``index``. It stops once the cursor passes the last
    slot (n + 1).

    ``index`` must be >= 1: a non-positive index drives the cursor to 0,
    where the step size is 0 and the loop never terminates.
    """
    limit = n + 2
    pos = index
    while pos < limit:
        SUM[pos] += x
        pos += pos & -pos
def get(index):
    """Return the Fenwick-tree prefix sum of ``SUM`` over positions 1..index.

    Each step clears the cursor's lowest set bit with ``i & (i - 1)``, which
    equals ``i - low_bit(i)`` for positive ``i``, until the cursor hits 0.
    Returns 0 when ``index <= 0``.
    """
    total = 0
    i = index
    while i > 0:
        total += SUM[i]
        i &= i - 1
    return total
# Build the Fenwick tree over the initial values A[1..n].
for i in range(1, n + 1):
    add(i, A[i])

# Collect answers and write them once at the end; per-query print() calls
# are a major cost when q is large.
out = []
for _ in range(q):
    t, l, r = map(int, input().split())
    if t == 1:
        # Apply A[i] = F(A[i]) to every index in [l, r] whose value can still
        # change. find() skips indices already at a fixed point of F, so each
        # index is revisited only while its value keeps changing.
        left = find(l)
        while left <= r:
            cur = A[left]
            nxt = F(cur)
            if nxt == cur:
                # Fixed point reached: splice `left` out of the skip list so
                # later range updates jump over it.
                node[left] = find(left + 1)
            else:
                add(left, nxt - cur)
                A[left] = nxt
            left = find(left + 1)
    else:
        # Range sum A[l] + ... + A[r] via two Fenwick prefix sums.
        out.append(get(r) - get(l - 1))

if out:
    sys.stdout.write("\n".join(map(str, out)) + "\n")