题目描述

幼儿园里有 NN 个小朋友,lxhgww老师现在想要给这些小朋友们分配糖果,要求每个小朋友都要分到糖果。但是小朋友们 也有嫉妒心,总是会提出一些要求,比如小明不希望小红分到的糖果比他的多,于是在分配糖果的时候,lxhgww需要满足小朋友们的 KK 个要求。幼儿园的糖果总是有限的,lxhgww想知道他至少需要准备多少个糖果,才能使得每个小朋友都能够分到糖果,并且满足小朋友们所有的要求。

如果X=1, 表示第A个小朋友分到的糖果必须和第B个小朋友分到的糖果一样多;

如果X=2, 表示第A个小朋友分到的糖果必须少于第B个小朋友分到的糖果;

如果X=3, 表示第A个小朋友分到的糖果必须不少于第B个小朋友分到的糖果;

如果X=4, 表示第A个小朋友分到的糖果必须多于第B个小朋友分到的糖果;

如果X=5, 表示第A个小朋友分到的糖果必须不多于第B个小朋友分到的糖果;


思路

差分约束详细讲解

对于差分约束这类问题,主要就是寻找不等式关系,形如 xixj+ckx_i \geqslant x_j + c_k。 对于本题而言的所有不等式关系如下:

  1. X=1X=1时,就有 A=BAB,BAA = B \Leftrightarrow A \geqslant B, B \geqslant A。就连一条 BA0AB0B \rightarrow A, 边权为0;A \rightarrow B,边权为0
  2. X=2X=2时,就有 A<BBA+1A < B \Leftrightarrow B \geqslant A + 1。就连一条 AB1A \rightarrow B, 边权为1
  3. X=3X=3时,就有 ABABA \geqslant B \Leftrightarrow A \geqslant B。就连一条 BA0B \rightarrow A,边权为0
  4. X=4X=4时,就有 A>BAB+1A > B \Leftrightarrow A \geqslant B + 1。就连一条 BA1B \geq A,边权为1
  5. X=5X=5时,就有 ABBAA \leq B \Leftrightarrow B \geqslant A。就连一条 AB0A \rightarrow B, 边权为0
  6. 要求每个小朋友都要分到糖果,也就形如 xi>=1x_i >= 1,可以假想一个虚拟原点,即 xix0+1x_i \geq x_0 + 1。就连一条0i10 \rightarrow i,边权为1

AC代码

#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long LL;

const int N = 100010, M = N * 3;

int n, m;
int h[N], e[M], w[M], ne[M], idx;
LL dist[N];
int q[N], cnt[N];
bool st[N];

void add(int a, int b, int c)
{
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}

bool spfa()
{
    memset(dist, -0x3f, sizeof dist);
    dist[0] = 0;
    int tt = 1;
    q[0] = 0, st[0] = true;

    while (tt)
    {
        int t = q[ -- tt];
        st[t] = false;

        for (int i = h[t]; ~i; i = ne[i])
        {
            int j = e[i];
            if (dist[j] < dist[t] + w[i])
            {
                dist[j] = dist[t] + w[i];
                cnt[j] = cnt[t] + 1;
                if (cnt[j] >= n + 1) return false;
                if (!st[j])
                {
                    q[tt ++ ] = j;
                    st[j] = true;
                }
            }
        }
    }

    return true;
}

int main()
{
    scanf("%d%d", &n, &m);
    memset(h, -1, sizeof h);
    while (m -- )
    {
        int x, a, b;
        scanf("%d%d%d", &x, &a, &b);
        if (x == 1) add(a, b, 0), add(b, a, 0);
        else if (x == 2) add(a, b, 1);
        else if (x == 3) add(b, a, 0);
        else if (x == 4) add(b, a, 1);
        else add(a, b, 0);
    }

    for (int i = 1; i <= n; i ++ ) add(0, i, 1);

    if (!spfa()) puts("-1");
    else 
    {
        LL res = 0;
        for (int i = 1; i <= n; i ++ ) res += dist[i];
        printf("%lld\n", res);
    }

    return 0;
}