链接:https://ac.nowcoder.com/acm/contest/160/B
来源:牛客网

题目描述
给一个含n个点m条边的有向无环图(允许重边,点用1到n的整数表示),每条边上有一个字符,问图上有几条路径满足路径上经过的边上的字符组成的的字符串去掉空格后以大写字母开头,句号 '.' 结尾,中间都是小写字母,小写字母可以为0个。
输入描述:
第一行两个整数n,m
接下来m行,每行两个整数a,b和一个字符c,表示一条起点为a,终点为b的边,边上的字符是c
1 ≤ n, m ≤ 50000
1 ≤ a < b ≤ n
c可以是大小写字母、句号 '.' 或空格(方便起见用 '_' 表示空格)
输出描述:
输出一个整数,表示答案对232取模的结果
示例1
输入
复制
6 11
1 2 A
1 2 _
3 4 _
2 4 B
2 3 a
2 3 _
2 4 b
4 5 .
3 5 .
2 5 .
5 6 _
输出
复制
16

思路:
首先进行拓扑排序,

为什么要进行拓扑排序呢?

我们知道因为这是一个有向图,拓扑排序后不存在两个节点a,b 拓扑序中b在a的后面,而b有一条边指向a,这是不存在的。

因为在dp的过程中,我们的后一个状态是根据前一个状态转移过来的,这就要求上一个状态一定是不能再有改变的了。

即动态规划的无后效性:

    当前的值只和当前的状态有关,和之前怎么来到这个状态和之后怎么去其他状态都无关。

那么我们再拓扑排序之后,就可以根据有向边的字符类对状态进行转移。

我们定义dp状态如下:

dp[i][0] : 以i节点为结尾的路径中,只包括空格的路径个数。

dp[i][1] : 以i节点为结尾的路径中,去掉空格后,第一个字符为大写字符,后面均为小写字符串的路径个数。

dp[i][0] : 以i节点为结尾的路径中,以'.' 为结尾的合法路径个数 。

转移方程见代码。

注意:本题有一个坑:所谓的以大写字母意思是去掉空格后,第一个字符为大写字符,不能有多个大写字母。

细节见代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#define ALL(x) (x).begin(), (x).end()
#define rt return
#define dll(x) scanf("%I64d",&x)
#define xll(x) printf("%I64d\n",x)
#define sz(a) int(a.size())
#define all(a) a.begin(), a.end()
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define gg(x) getInt(&x)
#define chu(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
using namespace std;
typedef long long ll;
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
ll powmod(ll a, ll b, ll MOD) {ll ans = 1; while (b) {if (b % 2) { ans = ans * a % MOD; } a = a * a % MOD; b /= 2;} return ans;}
inline void getInt(int *p);
const int maxn = 1000010;
const int inf = 0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/
struct node {
    int from;
    int to;
    char t;
    node () {}
    node(int ff, int tt, int ty)
    {
        from = ff;
        to = tt;
        t = ty;
    }
};
int n, m;
std::vector<node> son[maxn];
queue<node> q;
node a[maxn];
int in[maxn];
ll dp[maxn][3];
const ll mod = (1ll << 32);
int main()
{
    //freopen("D:\\code\\text\\input.txt","r",stdin);
    //freopen("D:\\code\\text\\output.txt","w",stdout);
    gbtb;
    cin >> n >> m;
    int u, v;
    char t;
    repd(i, 1, m) {
        cin >> u >> v >> t;
        son[u].pb(node(u, v, t));
        in[v]++;
    }
    repd(i, 1, n) {
        if (!in[i]) {
            q.push(node(0, i, '_'));
        }
    }
    int cnt = 0;
    node temp;
    while (!q.empty()) {
        temp = q.front();
        q.pop();
        for (auto x : son[temp.to]) {
            a[++cnt] = x;
            in[x.to]--;
            if (!in[x.to]) {
                q.push(x);
            }
        }
    }

    repd(i, 1, cnt) {
        if (a[i].t == '_') {
            dp[a[i].to][0] += dp[a[i].from][0] + 1;
            dp[a[i].to][1] += dp[a[i].from][1];
            dp[a[i].to][2] += dp[a[i].from][2];
        }
        if (a[i].t == '.') {
            dp[a[i].to][2] += dp[a[i].from][1];
        }
        if (a[i].t <= 'Z' && a[i].t >= 'A') {
            dp[a[i].to][1] += dp[a[i].from][0] + 1;
        }
        if (a[i].t <= 'z' && a[i].t >= 'a') {
            dp[a[i].to][1] += dp[a[i].from][1];
        }
        repd(j, 0, 2) {
            dp[a[i].to][j] %= mod;
        }
    }
    ll ans = 0ll;
    repd(i, 1, n) {
        ans = (ans + dp[i][2]) % mod;
    }
    cout << ans << endl;
    return 0;
}

inline void getInt(int *p)
{
    char ch;
    do {
        ch = getchar();
    } while (ch == ' ' || ch == '\n');
    if (ch == '-') {
        *p = -(getchar() - '0');
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 - ch + '0';
        }
    } else {
        *p = ch - '0';
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 + ch - '0';
        }
    }
}