题号 NC24093
名称 Modern Art
来源 USACOhttps://ac.nowcoder.com/acm/contest/3781/#question)

有一个N * N的网格,给你N * N中颜色,要求你用这些颜色个网格染色(每种颜色必用且只能用一次)每次可以将一个不超过网格大小的子矩阵染成同一种颜色

后来的颜色会将先前的颜色覆盖,问有多少种颜色可能是第一个染的?

样例

输入
4
2 2 3 0
2 7 3 7
2 7 7 7
0 0 0 0
输出
14

算法

(二维差分(覆盖) + 思维)

我们首先分析,如果某一个格子被上色的次数已知

  1. 这个格子被上色了1次那么这个格子当前的的颜色就有可能是第一个染色的颜色

    (只要我们第一次就使用这种颜色,接下来的染色方法都不覆盖这个颜色就能构造出这种情况)

  2. 这个格子被上色的次数大于一次那么这个格子当前的颜色就不可能是第一个染色的颜色

    (因为我们无法构造出这种情况,每个颜色只能用一次,被覆盖了还能显示在当前格子上时不可能的)

当然还有一种情况就是第一次染色的颜色被后面的颜色完全覆盖了

所以我们反过来计算有多少种颜色不可能是第一次染色的颜色,然后用N * N减去这些颜色就是答案


接着我们思考如何知道每个格子被上色的次数

我们找到每种颜色可能覆盖的最小矩形范围(找到最左上角,以及最右下角可能出现的位置)

然后用二维差分处理这块区域

最后做一遍前缀和就能得到每个格子被不同颜色覆盖的次数

图片说明

细节:
当网格中只有一种颜色且n > 1那么结果应该是N * N - 1,需要特判

时间复杂度

C++ 代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <unordered_map>
#include <map>
#include <vector>
#include <queue>
#include <set>
#include <bitset>
#include <cmath>

#define P 131

#define lc u << 1
#define rc u << 1 | 1

using namespace std;
const int N = 1010;
int sum[N][N];
int g[N][N];
int l[N * N],r[N * N],u[N * N],d[N * N];
bool st[N * N];
int n;

void solve()
{
    scanf("%d",&n);
    memset(u,0x3f,sizeof u);
    memset(l,0x3f,sizeof l);
    set<int> S;
    for(int i = 1;i <= n;i ++)
        for(int j = 1;j <= n;j ++)
        {
            scanf("%d",&g[i][j]);
            if(g[i][j] == 0) continue;
            l[g[i][j]] = min(l[g[i][j]],j);
            r[g[i][j]] = max(r[g[i][j]],j);
            d[g[i][j]] = max(d[g[i][j]],i);
            u[g[i][j]] = min(u[g[i][j]],i);
            S.insert(g[i][j]);
        }
    if((int)S.size() == 1 && n > 1)
    {
        printf("%d\n",n * n - 1);
        return;
    }
    for(int i = 1;i <= n * n;i ++)
        if(r[i])
        {
            sum[u[i]][l[i]] ++;
            sum[u[i]][r[i] + 1] --;
            sum[d[i] + 1][l[i]] --;
            sum[d[i] + 1][r[i] + 1] ++;
        }
    for(int i = 1;i <= n;i ++)
        for(int j = 1;j <= n;j ++)
        {
            sum[i][j] += sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1];
            if(sum[i][j] >= 2) st[g[i][j]] = true;
        }
    int res = n * n;
    for(int i = 1;i <= n * n;i ++)
        res -= st[i];
    printf("%d\n",res);
}

int main()
{
    #ifdef LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #else
    #endif // LOCAL
    int T = 1;
    // init(500);
    // scanf("%d",&T);
    while(T --)
    {
        // scanf("%lld%lld",&n,&m);
        solve();
        // test();
    }
    return 0;
}