[题解] JXOI2018 排序问题

题目描述

九条可怜是一个热爱思考的女孩子。

九条可怜最近正在研究各种排序的性质,她发现了一种很有趣的排序方法: $\operatorname{Gobo sort}​$

$\operatorname{Gobo sort}​$ 的算法描述大致如下:

  • 假设我们要对一个大小为 $n$ 的数列 $a$ 排序。
  • 等概率随机生成一个大小为 $n$ 的排列 $p$ 。
  • 构造一个大小为 $n$ 的数列 $b$ 满足 $b_i = a_{p_i}$ ,检查 $b$ 是否有序,如果 $b$ 已经有序了就结束算法,并返回 $b$ ,不然返回步骤 $2$。

显然这个算法的期望时间复杂度是 $O(n \times n!)​$ 的,但是九条可怜惊奇的发现,利用量子的神奇性质,在量子系统中,可以把这个算法的时间复杂度优化到线性。

九条可怜对这个排序算法进行了进一步研究,她发现如果一个序列满足一些性质,那么 $\operatorname{Gobo sort}$ 会很快计算出正确的结果。为了量化这个速度,她定义 $\operatorname{Gobo sort}$ 的执行轮数是步骤 2 的执行次数。

于是她就想到了这么一个问题:

现在有一个长度为 $n$ 的序列 $x$,九条可怜会在这个序列后面加入 $m$ 个元素,每个元素是 $[l,r]$内的正整数。 她希望新的长度为 $n+m$ 的序列执行 $\operatorname{Gobo sort}$ 的期望执行轮数尽量的多。她希望得到这个最多的期望轮数。

九条可怜很聪明,她很快就算出了答案,她希望和你核对一下,由于这个期望轮数实在是太大了,于是她只要求你输出对 $998244353$ 取模的结果。

输入格式

第一行输入一个整数 $T$,表示数据组数。

接下来 $2 \times T$ 行描述了 $T$ 组数据。

每组数据分成两行,第 $1$ 行有四个正整数 $n,m,l,r$ 表示数列的长度和加入数字的个数和加入数字的范围。 第 $2$ 行有 $n$ 个正整数,第 $i$ 个表示 $x_i$ 。

输出格式

输出 $T$ 个整数,表示答案。

样例输入

2
3 3 1 2
1 3 4
3 3 5 7
1 3 4

样例输出

180
720

数据范围及提示

对于第一组数据,我们可以添加${1,2,2}$ 到序列的最末尾,使得这个序列变成 1 3 4 1 2 2 ,那么进行一轮的成功概率是 $\frac{1}{180}$ ,因此期望需要 $180$ 轮。

对于第二组数据,我们可以添加 ${5,6,7}$ 到序列的最末尾,使得这个序列变成 1 3 4 5 6 7 ,那么进行一轮的成功概率是 $\frac{1}{720}$ ,因此期望需要 $720$ 轮。

对于 $30\%$ 的数据, $T\leq 10 , n,m,l,r$
对于 $50\%$ 的数据, $T\leq 300,n,m,l,r,a_i\leq 300$ 。
对于 $60\%$ 的数据,$\sum{r-l+1}\leq 10^7$ 。
对于 $70\%$ 的数据, $\sum{n} \leq 2\times 10^5$ 。
对于 $90\%$ 的数据, $m\leq 2\times 10^5*$。
对于 $100\%$ 的数据, $T\leq 10^5,n\leq 2\times 10^5,m\leq 10^7,1\leq l\leq r\leq 10^9$,$1\leq a_i\leq 10^9,\sum{n}\leq 2\times 10^6$ 。

解题思路

总的期望轮数等于一次成功的概率的倒数

考虑如何计算一次成功的概率,假设序列中有 $n$ 个数,其中总共有 $m$ 种,每种出现次数为 $a_i$
$$
\frac{\prod_{i=1}^n a_i}{n!}
$$
就是答案,我们要使期望轮数尽可能多,就是要使分子尽可能小,对于在 $l$ 到 $r$ 区间内的 $a_i$ ,要尽可能平均

也就是每一次加入的数是加入之前出现次数最少的数

考虑将 $l$ 到 $r$ 区间内的 $a_i$ 从小到大排序,将区间内不存在的数看做 $a_i$ 为 $0$

从前往后将第 $1$ 到 第 $i$ 个数一直增加,直至和 $i+1$ 数出现次数一样多,如果不足则只增加部分

注意题目有多组输入,数据清空时只清空已使用的部分数组,避免超时

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 200000 + 10;
const int MAXM= 10000000 + 200000 + 10;

const int MOD = 998244353;

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
}

int T;
int cnt[MAXN];
int num[MAXN], que[MAXN];

int n, m, l, r;

long long fac[MAXM], inv[MAXM];

inline long long powx(long long a, long long b){
    long long ans = 1;
    for(; b; b >>= 1){
        if(b & 1) ans = (ans * a) % MOD;
        a = (a * a) % MOD;
    }
    return ans;
}

inline void init(){
    fac[0] = 1;
    for(register int i=1; i<=10200000; i++){
        fac[i] = fac[i-1] * i % MOD;
    }

    inv[10200000] = powx(fac[10200000], MOD - 2);

    for(register int i=10199999; i>=0; i--){
        inv[i] = inv[i+1] * (i + 1) % MOD;
    }
}

inline void write(int x){
    if(x < 10){
        putchar('0' + x);
        return;
    }

    write(x / 10);
    write(x % 10);
}


int main(){
    init();

    T = read();

    while(T--){
        n = read();
        m = read();
        l = read();
        r = read();

        int all = n + m;

        for(register int i=1; i<=n; i++){
            num[i] = que[i] = read();
        }

        sort(que+1, que+n+1);
        int tot = unique(que+1, que+n+1) - que - 1;

        memset(cnt, 0, sizeof(int) * (n + 10));

        long long ans = 1;

        int start = 0x3f3f3f3f;
        int end = 0;

        for(register int i=1; i<=n; i++){
            int tmp = num[i];
            int id = lower_bound(que+1, que+tot+1, tmp) - que;

            if(l <= tmp && tmp <= r){
                start = min(start, id);
                end = max(end, id);
            }

            cnt[id]++;

            ans = (ans * cnt[id]) % MOD;
        }

        if(start == 0x3f3f3f3f){
            start = 5;
            end = 4;
        }

        int non = (r - l + 1) - (end - start + 1);

        sort(cnt+start, cnt+end+1);

        cnt[start - 1] = 0;
        cnt[end + 1] = 0x3f3f3f3f;

        int len = non;

        for(register int i=start; i<=end+1; i++){
            if(cnt[i] == cnt[i-1]){
                len++;
                continue;
            }

            if(1LL * len * (cnt[i] - cnt[i-1]) <= m){
                m -= len * (cnt[i] - cnt[i-1]);
                long long tmp = (inv[cnt[i-1]] * fac[cnt[i]]) % MOD;
                ans = ans * powx(tmp, len) % MOD;
            }else{
                int times = m / len;
                int rest = m % len;

                long long tmp = (inv[cnt[i-1]] * fac[cnt[i-1] + times]) % MOD;
                ans = ans * powx(tmp, len) % MOD;
                ans = ans * powx(cnt[i-1] + times + 1, rest) % MOD;

                m = 0;
                break;
            }

            len++;
        }

        write(fac[all] * powx(ans, MOD - 2) % MOD);
        putchar('\n');
    }
    return 0;
}