题目

2773. 函数调用
在这里插入图片描述

算法标签: 拓扑排序, 组合计数

思路

算每个函数最终对答案的贡献是多少, 计算当期加法但是要计算后面乘法的贡献, 用乘法原理加法原理计算每个函数等效的执行次数, 时间复杂度是线性的 O ( n ) O(n) O(n)

为什么 m u l mul mul需要按照拓扑序的逆序计算?
假设函数 A A A调用了函数 B B B和函数 C C C, 那么 A A A的最终的 m u l mul mul是取决于 B B B C C C m u l mul mul的, 因此需要按照拓扑序的逆序倒着推, 类似于动态规划的过程, 确保当前状态已经被计算

代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

typedef long long LL;
const int N = 1e5 + 10, M = 1e6 + 10, MOD = 998244353;

int n, m, q_num;
vector<int> head[N];
int w[N], g[N];
int q[N], deg[N];
struct Func {
   
	int t, pos, val;
	int mul;
	int sum;
} f[N];

void add(int u, int v) {
   
	head[u].push_back(v);
}

void top_sort() {
   
	int h = 0, t = -1;
	for (int i = 1; i <= m; ++i) {
   
		if (deg[i] == 0) q[++t] = i;
	}

	while (h <= t) {
   
		int u = q[h++];
		for (int v: head[u]) {
   
			if (--deg[v] == 0) q[++t] = v;
		}
	}
}

void calc_mul() {
   
	for (int i = m - 1; i >= 0; --i) {
   
		int u = q[i];
		for (int v : head[u]) {
   
			f[u].mul = (LL) f[u].mul * f[v].mul % MOD;
		}
	}
}

void calc_sum() {
   
	for (int i = 0; i < m; ++i) {
   
		int u = q[i];
		int sum = f[u].sum;
		for (int v : head[u]) {
   
			f[v].sum = (f[v].sum + sum) % MOD;
			sum = (LL) sum * f[v].mul % MOD;
		}
	}
}

int main() {
   
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);

	cin >> n;
	for (int i = 1; i <= n; ++i) cin >> w[i];
	cin >> m;

	for (int i = 1; i <= m; ++i) {
   
		cin >> f[i].t;
		if (f[i].t == 1) cin >> f[i].pos >> f[i].val;
		else if (f[i].t == 2) cin >> f[i].val;
		else {
   
			int cnt;
			cin >> cnt;
			vector<int> tmp;
			while (cnt--) {
   
				int x;
				cin >> x;
				tmp.push_back(x);
			}
			// 为了保证结果是正确的, 需要从最后一个子函数向前计算
			reverse(tmp.begin(), tmp.end());
			for (int x: tmp) {
   
				add(i, x);
				deg[x]++;
			}
		}
	}

	// 初始化乘法因子
	for (int i = 1; i <= m; ++i) {
   
		if (f[i].t == 2) {
   
			f[i].mul = f[i].val;
		}
		else {
   
			f[i].mul = 1;
		}
	}

	top_sort();
	calc_mul();

	cin >> q_num;
	for (int i = 1; i <= q_num; ++i) cin >> g[i];

	// 计算全局的乘法标记和加法标记
	int sum = 1;
	for (int i = q_num; i >= 1; --i) {
   
		int k = g[i];
		f[k].sum = (f[k].sum + sum) % MOD;
		sum = (LL) sum * f[k].mul % MOD;
	}

	calc_sum();

	// 应用全局乘法
	for (int i = 1; i <= n; ++i) {
   
		w[i] = (LL) w[i] * sum % MOD;
	}

	// 处理类型1的加法操作
	for (int i = 1; i <= m; ++i) {
   
		if (f[i].t == 1) {
   
			w[f[i].pos] = (w[f[i].pos] + (LL) f[i].val * f[i].sum) % MOD;
		}
	}

	for (int i = 1; i <= n; ++i) cout << w[i] << " ";
	cout << "\n";

	return 0;
}