牛客练习赛42 C 出题的诀窍

链接:https://ac.nowcoder.com/acm/contest/393/C来源:牛客网

题目描述

给定m个长为n的序列a1,a2,…,ama_1 , a_2 , \dots , a_ma1,a2,…,am。

小Z想问你:

其中SUM(一个序列)\texttt{SUM}(\text{一个序列})SUM(一个序列)表示这个序列中所有不同的数的和,相当于先sort,unique\tt sort,uniquesort,unique再求和。

输入描述:

第一行两个整数n,m。接下来m行,每行n个整数,第i行第j个表示ai,ja_{i,j}ai,j

输出描述:

一行一个整数,表示答案。

示例1

输入

[复制](javascript:void(0)😉

2 3
1 2
2 3
1 3

输出

[复制](javascript:void(0)😉

36

题意

就是求题面中给定的公式。

思路:

计算贡献的题目。

把所有的数放入一个集合S(去重)

那么集合S中的每一个元素x,对答案的贡献就是x*num,num为含有x的一组数的个数

那么如何求num呢?

\(num=n^m-cnt\)

cnt为不含有x的一组数的个数

那么只需要m行,每一行中(n-x的个数)乘起来即可。

对于那些不含有x的行。我们用预处理n的幂次来解决。

并且这题比较卡常,

需要用快速读入+pbds的hash来离散化。

能用int的地方不要用longlong

代码:

#include <bits/stdc++.h>
#include <cstdio>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#define ALL(x) (x).begin(), (x).end()
#define sz(a) int(a.size())
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define gg(x) getInt(&x)
#define chu(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
#define du3(a,b,c) scanf("%d %d %d",&(a),&(b),&(c))
#define du2(a,b) scanf("%d %d",&(a),&(b))
#define du1(a) scanf("%d",&(a));
using namespace std;
typedef long long ll;
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
ll powmod(ll a, ll b, ll MOD) {a %= MOD; if (a == 0ll) {return 0ll;} ll ans = 1; while (b) {if (b & 1) {ans = ans * a % MOD;} a = a * a % MOD; b >>= 1;} return ans;}

void Pv(const vector<int> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%d", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}}
void Pvl(const vector<ll> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%lld", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}}

inline void getInt(int* p);
const int maxn = 4000010;
const int inf = 0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/
namespace IO {
#define BUF_SIZE 100000
#define OUT_SIZE 100000
#define ll long long
//fread->read
bool IOerror = 0;
inline char nc() {
    static char buf[BUF_SIZE], *p1 = buf + BUF_SIZE, *pend = buf + BUF_SIZE;
    if (p1 == pend) {
        p1 = buf; pend = buf + fread(buf, 1, BUF_SIZE, stdin);
        if (pend == p1) {IOerror = 1; return -1;}
        //{printf("IO error!\n");system("pause");for (;;);exit(0);}
    }
    return *p1++;
}
inline bool blank(char ch) {return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';}
inline void read(int &x) {
    bool sign = 0; char ch = nc(); x = 0;
    for (; blank(ch); ch = nc());
    if (IOerror)return;
    if (ch == '-')sign = 1, ch = nc();
    for (; ch >= '0' && ch <= '9'; ch = nc())x = x * 10 + ch - '0';
    if (sign)x = -x;
}
//fwrite->write
struct Ostream_fwrite {
    char *buf, *p1, *pend;
    Ostream_fwrite() {buf = new char[BUF_SIZE]; p1 = buf; pend = buf + BUF_SIZE;}
    void out(char ch) {
        if (p1 == pend) {
            fwrite(buf, 1, BUF_SIZE, stdout); p1 = buf;
        }
        *p1++ = ch;
    }
    void print(int x) {
        static char s[15], *s1; s1 = s;
        if (!x)*s1++ = '0'; if (x < 0)out('-'), x = -x;
        while (x)*s1++ = x % 10 + '0', x /= 10;
        while (s1-- != s)out(*s1);
    }
    void print(char *s) {while (*s)out(*s++);}
} Ostream;
inline void print(int x) {Ostream.print(x);}
inline void print(char *s) {Ostream.print(s);}
};
using namespace IO;
int a[2005][2005];
int b[2005][2005];
int n, m;
ll base;
const ll mod = 1000000007ll;
ll ans;
int vis[maxn];
__gnu_pbds::gp_hash_table<int, int> w;
bool wvis[maxn];
bool solved[maxn];
int cnt[maxn];
int p[5000];
int id = 0;
__gnu_pbds::gp_hash_table<int, int> lsh;
int main()
{
    read(n);
    read(m);
    repd(i, 1, m)
    {
        repd(j, 1, n)
        {
            read(a[i][j]);
        }
    }
    repd(i, 1, m)
    {
        repd(j, 1, n)
        {
            int q = lsh[a[i][j]];
            if (q == 0)
            {
                lsh[a[i][j]] = ++id;
                b[i][j] = id;
            } else
            {
                b[i][j] = q;
            }
            vis[b[i][j]] = 1ll;
        }
    }
    repd(i, 1, m)
    {
        repd(j, 1, n)
        {
            w[b[i][j]] += 1;
        }
        repd(j, 1, n)
        {
            if (wvis[b[i][j]] == 0)
            {
                wvis[b[i][j]] = 1;
                vis[b[i][j]] = 1ll * vis[b[i][j]] * (n - w[b[i][j]]) % mod;
                cnt[b[i][j]]++;
            }
        }
        repd(j, 1, n)
        {
            wvis[b[i][j]] = 0;
            w[b[i][j]] -= 1;
        }
    }
    base = powmod(n, m, mod);
    p[0] = 1ll;
    repd(i, 1, n)
    {
        p[i] = (1ll * p[i - 1] * n) % mod;
    }
    repd(i, 1, m)
    {
        repd(j, 1, n)
        {
            if (solved[b[i][j]] == 0)
            {
                solved[b[i][j]] = 1;
                vis[b[i][j]] = (1ll * vis[b[i][j]] * p[ m - cnt[b[i][j]]]) % mod;
                ans = (ans + ( base - vis[b[i][j]] + mod) % mod * a[i][j]  % mod) % mod;
            }
        }
    }
    printf("%lld\n", ans);
    return 0;
}

inline void getInt(int* p) {
    char ch;
    do {
        ch = getchar();
    } while (ch == ' ' || ch == '\n');
    if (ch == '-') {
        *p = -(getchar() - '0');
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 - ch + '0';
        }
    }
    else {
        *p = ch - '0';
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 + ch - '0';
        }
    }
}