牛客练习赛42 C 出题的诀窍
链接:https://ac.nowcoder.com/acm/contest/393/C来源:牛客网
题目描述
给定m个长为n的序列a1,a2,…,ama_1 , a_2 , \dots , a_ma1,a2,…,am。
小Z想问你:
其中SUM(一个序列)\texttt{SUM}(\text{一个序列})SUM(一个序列)表示这个序列中所有不同的数的和,相当于先sort,unique\tt sort,uniquesort,unique再求和。
输入描述:
第一行两个整数n,m。接下来m行,每行n个整数,第i行第j个表示ai,ja_{i,j}ai,j
输出描述:
一行一个整数,表示答案。
示例1
输入
[复制](javascript:void(0)😉
2 3
1 2
2 3
1 3
输出
[复制](javascript:void(0)😉
36
题意
就是求题面中给定的公式。
思路:
计算贡献的题目。
把所有的数放入一个集合S(去重)
那么集合S中的每一个元素x,对答案的贡献就是x*num,num为含有x的一组数的个数
那么如何求num呢?
\(num=n^m-cnt\)
cnt为不含有x的一组数的个数
那么只需要m行,每一行中(n-x的个数)乘起来即可。
对于那些不含有x的行。我们用预处理n的幂次来解决。
并且这题比较卡常,
需要用快速读入+pbds的hash来离散化。
能用int的地方不要用longlong
代码:
#include <bits/stdc++.h>
#include <cstdio>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#define ALL(x) (x).begin(), (x).end()
#define sz(a) int(a.size())
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define gg(x) getInt(&x)
#define chu(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
#define du3(a,b,c) scanf("%d %d %d",&(a),&(b),&(c))
#define du2(a,b) scanf("%d %d",&(a),&(b))
#define du1(a) scanf("%d",&(a));
using namespace std;
typedef long long ll;
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
ll powmod(ll a, ll b, ll MOD) {a %= MOD; if (a == 0ll) {return 0ll;} ll ans = 1; while (b) {if (b & 1) {ans = ans * a % MOD;} a = a * a % MOD; b >>= 1;} return ans;}
void Pv(const vector<int> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%d", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}}
void Pvl(const vector<ll> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%lld", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}}
inline void getInt(int* p);
const int maxn = 4000010;
const int inf = 0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/
namespace IO {
#define BUF_SIZE 100000
#define OUT_SIZE 100000
#define ll long long
//fread->read
bool IOerror = 0;
inline char nc() {
static char buf[BUF_SIZE], *p1 = buf + BUF_SIZE, *pend = buf + BUF_SIZE;
if (p1 == pend) {
p1 = buf; pend = buf + fread(buf, 1, BUF_SIZE, stdin);
if (pend == p1) {IOerror = 1; return -1;}
//{printf("IO error!\n");system("pause");for (;;);exit(0);}
}
return *p1++;
}
inline bool blank(char ch) {return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';}
inline void read(int &x) {
bool sign = 0; char ch = nc(); x = 0;
for (; blank(ch); ch = nc());
if (IOerror)return;
if (ch == '-')sign = 1, ch = nc();
for (; ch >= '0' && ch <= '9'; ch = nc())x = x * 10 + ch - '0';
if (sign)x = -x;
}
//fwrite->write
struct Ostream_fwrite {
char *buf, *p1, *pend;
Ostream_fwrite() {buf = new char[BUF_SIZE]; p1 = buf; pend = buf + BUF_SIZE;}
void out(char ch) {
if (p1 == pend) {
fwrite(buf, 1, BUF_SIZE, stdout); p1 = buf;
}
*p1++ = ch;
}
void print(int x) {
static char s[15], *s1; s1 = s;
if (!x)*s1++ = '0'; if (x < 0)out('-'), x = -x;
while (x)*s1++ = x % 10 + '0', x /= 10;
while (s1-- != s)out(*s1);
}
void print(char *s) {while (*s)out(*s++);}
} Ostream;
inline void print(int x) {Ostream.print(x);}
inline void print(char *s) {Ostream.print(s);}
};
using namespace IO;
int a[2005][2005];
int b[2005][2005];
int n, m;
ll base;
const ll mod = 1000000007ll;
ll ans;
int vis[maxn];
__gnu_pbds::gp_hash_table<int, int> w;
bool wvis[maxn];
bool solved[maxn];
int cnt[maxn];
int p[5000];
int id = 0;
__gnu_pbds::gp_hash_table<int, int> lsh;
int main()
{
read(n);
read(m);
repd(i, 1, m)
{
repd(j, 1, n)
{
read(a[i][j]);
}
}
repd(i, 1, m)
{
repd(j, 1, n)
{
int q = lsh[a[i][j]];
if (q == 0)
{
lsh[a[i][j]] = ++id;
b[i][j] = id;
} else
{
b[i][j] = q;
}
vis[b[i][j]] = 1ll;
}
}
repd(i, 1, m)
{
repd(j, 1, n)
{
w[b[i][j]] += 1;
}
repd(j, 1, n)
{
if (wvis[b[i][j]] == 0)
{
wvis[b[i][j]] = 1;
vis[b[i][j]] = 1ll * vis[b[i][j]] * (n - w[b[i][j]]) % mod;
cnt[b[i][j]]++;
}
}
repd(j, 1, n)
{
wvis[b[i][j]] = 0;
w[b[i][j]] -= 1;
}
}
base = powmod(n, m, mod);
p[0] = 1ll;
repd(i, 1, n)
{
p[i] = (1ll * p[i - 1] * n) % mod;
}
repd(i, 1, m)
{
repd(j, 1, n)
{
if (solved[b[i][j]] == 0)
{
solved[b[i][j]] = 1;
vis[b[i][j]] = (1ll * vis[b[i][j]] * p[ m - cnt[b[i][j]]]) % mod;
ans = (ans + ( base - vis[b[i][j]] + mod) % mod * a[i][j] % mod) % mod;
}
}
}
printf("%lld\n", ans);
return 0;
}
inline void getInt(int* p) {
char ch;
do {
ch = getchar();
} while (ch == ' ' || ch == '\n');
if (ch == '-') {
*p = -(getchar() - '0');
while ((ch = getchar()) >= '0' && ch <= '9') {
*p = *p * 10 - ch + '0';
}
}
else {
*p = ch - '0';
while ((ch = getchar()) >= '0' && ch <= '9') {
*p = *p * 10 + ch - '0';
}
}
}