题目链接:https://codeforces.ml/contest/1312/problem/D
题目大意:问你一个满足下列条件的数组数量。
<mstyle displaystyle="false" scriptlevel="0"> 1. n </mstyle> <mstyle displaystyle="false" scriptlevel="0"> 2. [ 1 , m ] </mstyle> <mstyle displaystyle="false" scriptlevel="0"> 3. </mstyle> <mstyle displaystyle="false" scriptlevel="0"> 4. i : </mstyle> <mstyle displaystyle="false" scriptlevel="0"> a [ j ] < a [ j ] + 1 ( j < i ) </mstyle> <mstyle displaystyle="false" scriptlevel="0"> a [ j ] > a [ j ] + 1 ( j i ) </mstyle> \begin{array}{l} 1.有n个元素\\ 2.每个元素在[1, m]\\ 3.正好有两个元素相同\\ 4.存在一个i:\\a[j]<a[j]+1(j<i) \\ a[j]>a[j]+1(j≥i) \end{array} 1.n2.[1,m]3.4.i:a[j]<a[j]+1(j<i)a[j]>a[j]+1(ji)

思路:
我们可以根据条件推导:
<mstyle displaystyle="false" scriptlevel="0"> n 1 a [ i ] a [ i ] </mstyle> <mstyle displaystyle="false" scriptlevel="0"> </mstyle> <mstyle displaystyle="false" scriptlevel="0"> 1. m n 1 C ( m , n 1 ) </mstyle> <mstyle displaystyle="false" scriptlevel="0"> </mstyle> <mstyle displaystyle="false" scriptlevel="0"> 2. n 2 ( ) </mstyle> <mstyle displaystyle="false" scriptlevel="0"> </mstyle> <mstyle displaystyle="false" scriptlevel="0"> 3. : 2 n 3 </mstyle> <mstyle displaystyle="false" scriptlevel="0"> </mstyle> <mstyle displaystyle="false" scriptlevel="0"> C ( m , n 1 ) ( n 2 ) 2 n 3 </mstyle> \begin{array}{l} 有n-1个元素是完全不同的,a[i]是最大值,相等的两个元素只能在a[i]的两边。\\\\ 1.在m个元素中选择n-1个元素:C(m, n-1)\\ \\ 2.选择谁是相等的元素。有n-2种选法(最大值是唯一的)\\\\ 3.我们放好最大值和相等元素后。其余的元素可以放最大值的左边或者右边:2^{n-3}\\\\ 答案是:C(m, n-1)*(n-2)*2^{n-3} \end{array} n1a[i]a[i]1.mn1C(m,n1)2.n2()3.:2n3C(m,n1)(n2)2n3

#include <bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=998244353;

LL quick_pow(LL a, LL b){
    if(a==0){
        return 0;
    }
    LL ans=1;
    while(b){
        if(b&1)
            ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}


int main() {
    LL n, m;
    scanf("%lld%lld", &n, &m);
    if(m>=n-1&&n>2){
        LL ans=quick_pow(2, n-3)*(n-2)%mod;
        LL s1=1, s2=1;
        for(LL i=m; i>=m-n+2; i--){
            s1*=i; s1%=mod;
        }
        for(LL i=1; i<=n-1; i++){
            s2*=i; s2%=mod;
        }
        ans=(ans*s1%mod)*quick_pow(s2, mod-2)%mod;
        printf("%lld\n", ans);
    }
    else{
        printf("0\n");
    }

    return 0;
}