小月的炼金术

解法:矩阵树定理+多项式乘法

注意到如果没有那什么mod3的限制,就是朴素的矩阵树定理,求所有生成树权值之和(生成树权值定义为边之积)。

而 mod3 的限制可以通过把模P域换成循环取模的多项式域:

边类型是普通管道,系数就是普通 ; 变类型是冰或者火,一个加系数 ,一个加系数,两个组合起来就变成了(因为要模),就能作为满足条件的方案了。

其他加减乘除在多项式域里成立,直接矩阵树定理+套多项式解决此问题。

关于求逆元,这里是下的逆元,手动推导:

解三元三次方程组得到逆元的三项系数。

代码

#include <bits/stdc++.h>

#ifdef YJL

#include "include/debug.h"

#else
#define debug(...)0
#define debugN(...)0
#endif
using namespace std;

constexpr int N = 110;
constexpr int P = 998244353;
using Poly = array<int, 3>;
Poly a[N][N];

Poly operator+=(Poly &a, Poly b) {
    for (int i = 0; i < 3; ++i) {
        a[i] += b[i];
        a[i] %= P;
    }
    return a;
}

Poly operator-=(Poly &a, Poly b) {
    for (int i = 0; i < 3; ++i) {
        a[i] -= b[i];
        a[i] %= P;
    }
    return a;
}


Poly operator*(Poly a, Poly b) {
    Poly c = {};
    for (int i = 0; i < 3; ++i) {
        for (int j = 0; j < 3; ++j) {
            c[(i + j) % 3] += 1LL * a[i] * b[j] % P;
            c[(i + j) % 3] %= P;
        }
    }
    return c;
}

long long power(long long a, long long b) {
    long long ans = 1;
    for (; b; b /= 2, a = a * a % P) {
        if (b & 1) {
            ans = ans * a % P;
        }
    }
    return ans;
}

Poly inverse(Poly p) {
    // (a+bx+cx^2) * (d+ex+fx^2) = ad + (bd+ae)x + (cd+af+be)x^2 + (ce+bf)x^3 + cfx^4
    // = (ad+ce+bf) + (bd+ae+cf)x + (cd+af+be)x^2 = 1
    // ad+ce+bf=1
    // bd+ae+cf=0
    // cd+af+be=0
    auto [a, b, c] = p;
    long long fm = (1LL * a * a % P * a + 1LL * b * b % P * b + 1LL * c * c % P * c - 3LL * a * b % P * c) % P;
    assert(fm != 0);
    fm = power(fm, P - 2);
    int d = (1LL * a * a % P - 1LL * b * c % P) * fm % P;
    int e = (1LL * c * c % P - 1LL * a * b % P) * fm % P;
    int f = (1LL * b * b % P - 1LL * a * c % P) * fm % P;
    return {d, e, f};
}

Poly det(int n) {
    Poly ans = {1, 0, 0};
    int f = 1;
    for (int i = 1; i <= n; ++i) {
        int j = i;
        for (; j <= n; ++j)
            if (!(a[j][i][0] == 0 && a[j][i][1] == 0 && a[j][i][2] == 0))
                break;
        if (j > n) return {};
        if (j != i) {
            f = -f;
            swap(a[i], a[j]);
        }
        ans = ans * a[i][i];
        auto inv = inverse(a[i][i]);
        for (int j = i + 1; j <= n; ++j)
            if (!(a[j][i][0] == 0 && a[j][i][1] == 0 && a[j][i][2] == 0)) {
                auto t = a[j][i] * inv;
                for (int k = i; k <= n; ++k) a[j][k] -= a[i][k] * t;
            }
    }
    if (f == -1) {
        ans[0] *= -1;
        ans[1] *= -1;
        ans[2] *= -1;
    }
    return ans;
}

int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    int n, m;
    cin >> n >> m;
    assert(2 <= n && n <= 100);
    assert(n - 1 <= m && m <= 5000);
    for (int i = 0; i < m; ++i) {
        int u, v, w, t;
        cin >> u >> v >> w >> t;
        assert(1 <= u && u <= n);
        assert(1 <= v && v <= n);
        assert(u != v);
        assert(1 <= w && w < 998244353);
        assert(t == 0 || t == 1 || t == 2);
        Poly cur{};
        cur[(t + 1) % 3] += w;
        a[u][u] += cur;
        a[v][v] += cur;
        a[u][v] -= cur;
        a[v][u] -= cur;
    }
    auto res = det(n - 1);
    if (res[0] < 0) {
        res[0] += P;
    }
    cout << res[0];
    return 0;
}
// 1+x+x^2 :2,0,1