2020牛客寒假算法基础集训营2 - J 求函数 (线段树)

链接:https://ac.nowcoder.com/acm/contest/3003/J
来源:牛客网

求函数

时间限制:C/C++ 1秒,其他语言2秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld

题目描述

输入描述:

第一行,两个正整数 n,m 。
第二行,n 个整数 k1,k2,…,knk_1,k_2,\dots,k_nk1,k2,…,kn 。
第三行,n 个整数 b1,b2,…,bnb_1,b_2,\dots,b_nb1,b2,…,bn。
接下来 m 行,每行一个操作,格式见上。
保证 1≤n,m≤2×1051\leq n,m\leq 2\times 10^51≤n,m≤2×105,0≤ki,bi<109+70\leq k_i,b_i < 10^9+70≤ki,bi<109+7。

输出描述:

对于每个求值操作,输出一行一个整数,表示答案。

示例1

输入

[复制](javascript:void(0)😉

2 3
1 1
1 0
1 2 114514 1919810
2 1 2
2 1 1

输出

[复制](javascript:void(0)😉

2148838
2

说明

初始 f1(x)=x+1,f2(x)=xf_1(x)=x+1,f_2(x)=xf1(x)=x+1,f2(x)=x
修改后 f2(x)=114514x+1919810f_2(x)=114514x+1919810f2(x)=114514x+1919810
查询时 f1(1)=2,f2(f1(1))=2148838f_1(1)=2,f_2(f_1(1))=2148838f1(1)=2,f2(f1(1))=2148838 

思路:

我们观察以下数据:

\(f_1(1)=k1+b1=(k1)+(b1)\)

\(f_2(f_1(1))=k2*(f_1(1))+b2=k2*k1+k2*b1+b2=(k2*k1)+(k2*b1+b2)\)

\(f_3(f_2(f_1(1)))=(k3*k2*k1)+(k3*k2*b1+k3*b2+b3)\)

我们不妨设括号的左边部分为first,右边部分为second

容易观察出 \(f_3(f_2(f_1(1)))=[k3*(f_2(f_1(1)).first)]+[k3*(f_2(f_1(1)).second)+b3]\)

右式中前中括号部分为\(f_3(...).first\) ,后面的中括号部分为\(f_3(...).second\)

于是我们不难发现将上式中的数字3换成r,数字1换成l,数字2换成mid 也同样满足

即:\(f_r(f_{r-1}(⋯f_l(1)⋯)).first=(f_{[mid+1,r]}(1).first)*(f_{[l,mid]}(1).first)\)

\(f_r(f_{r-1}(⋯f_l(1)⋯)).second=(f_{[mid+1,r]}(1).first)*((f_{[l,mid]}(1).second))+(f_{[mid+1,r]}(1).second)\)

这样我们就可以用线段树来分别维护一个区间\([l,r]\)的first 部分和second 部分

first部分其实为\(\prod_{i=l}^r k_i\)

second部分其实为\(\sum_{i=l}^r (b_i*\prod_{j=i+1}^r k_j)\)

区间的合并公式上面已经给了,记得全程取模即可。

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#define ALL(x) (x).begin(), (x).end()
#define sz(a) int(a.size())
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define chu(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
#define du3(a,b,c) scanf("%d %d %d",&(a),&(b),&(c))
#define du2(a,b) scanf("%d %d",&(a),&(b))
#define du1(a) scanf("%d",&(a));
using namespace std;
typedef long long ll;
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
ll powmod(ll a, ll b, ll MOD) { if (a == 0ll) {return 0ll;} a %= MOD; ll ans = 1; while (b) {if (b & 1) {ans = ans * a % MOD;} a = a * a % MOD; b >>= 1;} return ans;}
void Pv(const vector<int> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%d", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}}
void Pvl(const vector<ll> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%lld", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}}
inline long long readll() {long long tmp = 0, fh = 1; char c = getchar(); while (c < '0' || c > '9') {if (c == '-') fh = -1; c = getchar();} while (c >= '0' && c <= '9') tmp = tmp * 10 + c - 48, c = getchar(); return tmp * fh;}
inline int readint() {int tmp = 0, fh = 1; char c = getchar(); while (c < '0' || c > '9') {if (c == '-') fh = -1; c = getchar();} while (c >= '0' && c <= '9') tmp = tmp * 10 + c - 48, c = getchar(); return tmp * fh;}
const int maxn = 200010;
const int inf = 0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/
const ll mod = 1e9 + 7;
int n, m;
ll k[maxn];
ll b[maxn];
struct node
{
    int l, r;
    ll val;
};
node seg1[maxn << 2];
node seg2[maxn << 2];
void build1(int rt, int l, int r)
{
    seg1[rt].l = l;
    seg1[rt].r = r;
    if (l == r)
    {
        seg1[rt].val = k[l];
        return ;
    }
    int mid = (l + r) >> 1;
    build1(rt << 1, l, mid);
    build1(rt << 1 | 1, mid + 1, r);
    seg1[rt].val = (seg1[rt << 1].val * seg1[rt << 1 | 1].val) % mod;
}
void update1(int rt, int pos)
{
    if (seg1[rt].l == pos && seg1[rt].r == pos)
    {
        seg1[rt].val = k[pos];
        return ;
    }
    int mid = (seg1[rt].l + seg1[rt].r) >> 1;
    if (pos <= mid)
    {
        update1(rt << 1, pos);
    } else
    {
        update1(rt << 1 | 1, pos);
    }
    seg1[rt].val = (seg1[rt << 1].val * seg1[rt << 1 | 1].val) % mod;
}
void build2(int rt, int l, int r)
{
    seg2[rt].l = l;
    seg2[rt].r = r;
    if (l == r)
    {
        seg2[rt].val = b[l];
        return ;
    }
    int mid = (l + r) >> 1;
    build2(rt << 1, l, mid);
    build2(rt << 1 | 1, mid + 1, r);
    seg2[rt].val = ( seg1[rt << 1 | 1].val * seg2[rt << 1].val % mod + seg2[rt << 1 | 1].val) % mod;
}
 
void update2(int rt, int pos)
{
    if (seg2[rt].l == pos && seg2[rt].r == pos)
    {
        seg2[rt].val = b[pos];
        return ;
    }
    int mid = (seg2[rt].l + seg2[rt].r) >> 1;
    if (pos <= mid)
    {
        update2(rt << 1, pos);
    } else
    {
        update2(rt << 1 | 1, pos);
    }
    seg2[rt].val = ( seg1[rt << 1 | 1].val * seg2[rt << 1].val % mod + seg2[rt << 1 | 1].val) % mod;
}
pll mg(pll resl, pll resr)
{
    pll res;
    res.se = (resr.fi * resl.se % mod + resr.se) % mod;
    res.fi = resl.fi * resr.fi % mod;
    return res;
}
pll ask2(int rt, int l, int r)
{
    if (seg2[rt].l >= l && seg2[rt].r <= r)
    {
        return mp(seg1[rt].val, seg2[rt].val);
    }
    int mid = (seg2[rt].l + seg2[rt].r) >> 1;
    pll resl, resr, res;
    if (r <= mid)
    {
        return ask2(rt << 1, l, r);
    }
    if (l > mid) {
        return ask2(rt << 1 | 1, l, r);
    }
    return mg(ask2(rt << 1, l, r), ask2(rt << 1 | 1, l, r));
}
int main()
{
    //freopen("D:\\code\\text\\input.txt","r",stdin);
    //freopen("D:\\code\\text\\output.txt","w",stdout);
    n = readint();
    m = readint();
    repd(i, 1, n)
    {
        k[i] = readll();
    }
    repd(i, 1, n)
    {
        b[i] = readll();
    }
    build1(1, 1, n);
    build2(1, 1, n);
    int pos, t1, t2, l, r;
    int op;
    repd(i, 1, m)
    {
        op = readint();
        if (op == 1)
        {
            pos = readint(); t1 = readint();
            t2 = readint();
            k[pos] = t1;
            b[pos] = t2;
            update1(1, pos);
            update2(1, pos);
        } else
        {
            l = readint();
            r = readint();
            pll ans2 = ask2(1, l, r);
            printf("%lld\n", (ans2.fi + ans2.se) % mod );
        }
    }
    return 0;
}