牛客小白月赛69 G (A+B)^N%P Problem (大模拟)

题目意思就是要让我们展开组合多项式,按二项式定理模拟就行。但是 $p$ 可能是非质数,所以不能预处理逆元;又因为 $n$ 比较小($n \le 1000$),可以用二维动规(杨辉三角)处理组合系数,Java 也可以无脑使用 BigInteger。
细节确实挺多。。

import java.util.*;
import java.io.*;
import java.math.*;

/**
 * Solution to NowCoder Beginner Contest 69, problem G: "(A+B)^N%P".
 *
 * <p>Reads one expression of the form {@code (Ax+By)^N%P} and prints the expression, {@code " = "},
 * and its binomial expansion with every coefficient reduced modulo P. A missing coefficient means
 * 1 and a missing {@code ^N} means N = 1. Because P need not be prime, binomial coefficients are
 * built with Pascal's rule rather than modular inverses.
 */
public class Main {
    static int n, m, mod = (int) 1e9 + 7, maxn = 100010;
    static long ans = 0, INF = (long) 1e18;
    static Scanner sc = new Scanner(System.in);
    static PrintWriter pw = new PrintWriter(System.out);

    public static void main(String[] args) throws IOException {
        int T = 1;
        while (T-- > 0) solve();
        pw.flush();
    }

    /**
     * Computes {@code a^b mod mod} by binary exponentiation, using the static {@link #mod}.
     *
     * @param a base; reduced modulo {@code mod} first
     * @param b non-negative exponent
     * @return {@code a^b % mod}, which is 0 whenever {@code mod == 1} (including {@code b == 0})
     */
    static long qpow(long a, long b) {
        long res = 1;
        a %= mod;
        while (b > 0) {
            if ((b & 1) == 1) res = res * a % mod;
            a = a * a % mod;
            b >>= 1;
        }
        // Without this reduction, b == 0 with mod == 1 would wrongly yield 1.
        return res % mod;
    }

    /** Coefficients of the two terms; 1 when no digits precede the variable letter. */
    static long a = 1, b = 1;
    /** The modulus P parsed from the input. */
    static int p = 0;
    /** Variable letters of the first and second term. */
    static char x, y;

    /**
     * Parses an expression {@code (Ax+By)^N%P} into the static fields {@link #a}, {@link #x},
     * {@link #b}, {@link #y}, {@link #n} and {@link #p}.
     *
     * @param s the expression, starting with {@code '('}
     */
    static void pre(String s) {
        // Reset defaults so state from a previous call never leaks into this one.
        a = 1;
        b = 1;
        int i = 1; // skip '('
        int j = skipDigits(s, i);
        if (j > i) a = Long.parseLong(s.substring(i, j));
        x = s.charAt(j);
        i = j + 2; // skip the variable and '+'
        j = skipDigits(s, i);
        if (j > i) b = Long.parseLong(s.substring(i, j));
        y = s.charAt(j);
        i = j + 2; // skip the variable and ')'
        if (s.charAt(i) == '^') {
            j = skipDigits(s, i + 1);
            n = Integer.parseInt(s.substring(i + 1, j));
            i = j;
        } else {
            n = 1;
        }
        i++; // skip '%'
        j = skipDigits(s, i);
        p = Integer.parseInt(s.substring(i, j));
    }

    /** Returns the index of the first non-ASCII-digit character at or after {@code from}. */
    private static int skipDigits(String s, int from) {
        int i = from;
        while (i < s.length() && s.charAt(i) >= '0' && s.charAt(i) <= '9') i++;
        return i;
    }

    /** Reads one expression from standard input and prints its expansion. */
    static void solve() throws IOException {
        pw.print(expand(sc.nextLine()));
    }

    /**
     * Builds the full output line for one input expression.
     *
     * <p>Side effect: updates the parsed static fields and sets {@link #mod} to P.
     *
     * @param line input such as {@code (2x+3y)^2%10}; surrounding whitespace is ignored
     * @return {@code "<expr> = <expansion>"}. The expansion is {@code 0} when every coefficient
     *     vanishes. It is wrapped in parentheses when it has more than one term.
     */
    static String expand(String line) {
        String s = line.trim();
        pre(s);
        mod = p;
        StringBuilder out = new StringBuilder(s).append(" = ");

        if (x == y) {
            // Same variable on both sides: (a+b)^n * x^n.
            long coef = qpow(a % mod + b % mod, n);
            if (coef == 0) {
                out.append(0);
            } else {
                appendTerm(out, coef, x, n, y, 0);
                out.append('%').append(p);
            }
            return out.toString();
        }

        long[] binom = binomialRow(n, mod);
        StringBuilder poly = new StringBuilder();
        int terms = 0;
        for (int i = 0; i <= n; i++) {
            long coef = binom[i] * qpow(a, n - i) % mod * qpow(b, i) % mod;
            if (coef == 0) continue;
            if (terms++ > 0) poly.append('+');
            appendTerm(poly, coef, x, n - i, y, i);
        }

        if (terms == 0) out.append(0);
        else if (terms == 1) out.append(poly).append('%').append(mod);
        else out.append('(').append(poly).append(")%").append(mod);
        return out.toString();
    }

    /**
     * Appends one monomial {@code coef*vx^ex*vy^ey}.
     *
     * <p>A coefficient of 1, zero powers, and exponents of 1 are omitted. A term left with no
     * factors is written as {@code 1}.
     */
    private static void appendTerm(StringBuilder sb, long coef, char vx, int ex, char vy, int ey) {
        List<String> factors = new ArrayList<>(3);
        if (coef != 1) factors.add(Long.toString(coef));
        if (ex > 0) factors.add(ex == 1 ? String.valueOf(vx) : vx + "^" + ex);
        if (ey > 0) factors.add(ey == 1 ? String.valueOf(vy) : vy + "^" + ey);
        sb.append(factors.isEmpty() ? "1" : String.join("*", factors));
    }

    /**
     * Returns row {@code power} of Pascal's triangle, i.e. C(power, i) for i in [0, power], with
     * every entry reduced modulo {@code modulus}.
     *
     * <p>Only additions are used, so this works for composite moduli.
     */
    private static long[] binomialRow(int power, int modulus) {
        long[] row = new long[power + 1];
        row[0] = 1 % modulus;
        for (int k = 1; k <= power; k++) {
            // Update right-to-left so row[j - 1] still holds the previous row's value.
            for (int j = k; j >= 1; j--) row[j] = (row[j] + row[j - 1]) % modulus;
        }
        return row;
    }
}