链接:https://ac.nowcoder.com/acm/contest/894/C
来源:牛客网

华华跟奕奕玩游戏
时间限制:C/C++ 1秒,其他语言2秒
空间限制:C/C++ 524288K,其他语言1048576K
64bit IO Format: %lld
题目描述
有一个箱子,开始时有n个黑球,m个蓝球。每一轮游戏规则如下:
第一步:奕奕有p的概率往箱子里添加一个黑球,有(1-p)的概率往箱子里添加一个蓝球。
第二步:华华随机从箱子里取出一个球。
华华喜欢黑球,他想知道k轮游戏之后箱子里黑球个数的期望。
输入描述:
输入五个整数n,m,k,a,b。
1<=n,m<=1e6,1<=k<=1e9
其中p=\frac{a}{b}
b
a

,且a<=b,0<=a<1e9+7,0<b<1e9+7
输出描述:
输出一个数表示k轮游戏后箱子里黑球个数的期望。
输出一个整数,为答案对1e9+7取模的结果。即设答案化为最简分式后的形式为\frac{a}{b}
b
a

,其中a和b互质。输出整数 x 使得bx≡a(mod 1e9+7)且0≤x<1e9+7。可以证明这样的整数x是唯一的。
示例1
输入
复制
2 2 1 1 2
输出
复制
2
示例2
输入
复制
2 2 2 3 10
输出
复制
184000003

思路:

这位哥哥写的非常好。
推荐!

https://blog.csdn.net/weixin_43702895/article/details/90343536#commentBox

细节见代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#define ALL(x) (x).begin(), (x).end()
#define sz(a) int(a.size())
#define all(a) a.begin(), a.end()
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define gg(x) getInt(&x)
#define chu(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
using namespace std;
typedef long long ll;
ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
ll lcm(ll a,ll b){return a/gcd(a,b)*b;}
ll powmod(ll a,ll b,ll MOD){ll ans=1;while(b){if(b%2)ans=ans*a%MOD;a=a*a%MOD;b/=2;}return ans;}
inline void getInt(int* p);
const int maxn=1000010;
const int inf=0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/

ll n,m,k,a,b;
const ll mod=1e9+7ll;

ll inv(ll a)
{
    return powmod(a,mod-2ll,mod);
}
int main()
{
    //freopen("D:\\common_text\\code_stream\\in.txt","r",stdin);
    //freopen("D:\\common_text\\code_stream\\out.txt","w",stdout);
    
    cin>>n>>m>>k>>a>>b;
    ll ans=(n-a*(n+m)%mod*inv(b)%mod+mod)%mod*(powmod(n+m,k,mod))%mod*inv(powmod(n+m+1ll,k,mod))%mod+a*(n+m)%mod*inv(b)%mod;
    ans%=mod;
    cout<<ans<<endl;
    
    
    
    return 0;
}

inline void getInt(int* p) {
    char ch;
    do {
        ch = getchar();
    } while (ch == ' ' || ch == '\n');
    if (ch == '-') {
        *p = -(getchar() - '0');
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 - ch + '0';
        }
    }
    else {
        *p = ch - '0';
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 + ch - '0';
        }
    }
}