2020 CCPC-Wannafly Winter Camp Day 1 A | Nanako

2020 CCPC-Wannafly Winter Camp Day 1 A

Problem A. 期望逆序对

Statement

给 $n$ 个随机变量 $x_1,x_2,\dots,x_n$,$x_i$ 的值是 $[l_i,r_i]$ 中随机选取的整数。你可以将这些随机变量排成任意的顺序,求逆序对数期望的最小值对 $998244353$ 取余后的结果。

$1 \leq n \leq 5 \times 10^3$, $1 \leq l_i \leq r_i \leq 10^9$

Solution

事实上,取值的期望(即 $\frac{l_i+r_i}{2}$)越小就应该排在越前面,这样得到的序列的逆序对数的期望就是最小的。

为什么呢?对于任意一个序列,我们选择相邻的两项,其他项不动,考虑这两项是否应该交换。如果前一项的期望比后一项大,那么显然交换之后逆序对数减小,所以答案的序列中相邻里两项一定是前一项的期望比较小,而要满足这个性质就只能令期望单增,于是我们确定了这个序列的顺序。

之后就只需要 $O(n^2)$ 枚举每一对变量,$O(1)$ 计算其产生的逆序对数就行了。把两个变量看成分别在 $x$ 轴和 $y$ 轴上的区间,两个变量之间能产生的逆序对数其实就等价于一个矩形在 $y=x$ 的一侧的面积吧,这个就各凭本事吧,现场写的是分类讨论。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
typedef unsigned long long ll;

const int N = 1e5 + 5;
const int M = 1e3;
const int mod = 998244353;

ll qpow (ll a, ll b) {
ll ret = 1;
while (b) {
if (b & 1) ret = ret * a % mod;
a = a * a % mod;
b >>= 1;
}
return ret;
}

ll inv (ll a) {
return qpow(a, mod - 2);
}

struct node {
int l, r;
} a[N];
ll invv[N];
bool operator < (const node& a, const node& b) {
return a.l + a.r < b.l + b.r;
}

ll calc (const node& a, const node& b) {
if (a.l <= b.l) {
if (a.r <= b.l) return 0;
if ((a.r > b.l) && (a.r <= b.r)) {
ll tmp = a.r - b.l + 1;
return tmp * (tmp - 1) / 2 % mod;
}
if (a.r > b.r) {
ll tmp = (b.r - b.l + 1);
ll sum = tmp * (tmp - 1) / 2 % mod;
sum += (a.r - b.r) * tmp;
return sum % mod;
}
} else {
ll tmp = (a.r - a.l + 1);
ll sum = tmp * (tmp - 1) / 2 % mod;
sum += (a.l - b.l) * tmp;
return sum % mod;
}
}

int main () {
int n;
scanf("%d", &n);
for (int i = 1;i <= n;i++) {
scanf("%d%d", &a[i].l, &a[i].r);
}
sort(a + 1, a + n + 1);
for (int i = 1;i <= n;i++) {
invv[i] = inv(a[i].r - a[i].l + 1);
}

ll q = 1;
for (int i = 1;i <= n;i++) {
q *= a[i].r - a[i].l + 1;
q %= mod;
}

ll p = 0;
for (int i = 1;i <= n;i++) {
for (int j = i + 1;j <= n;j++) {
ll tmp = calc(a[i], a[j]);
tmp = tmp * invv[i] % mod;
tmp = tmp * invv[j] % mod;
p += tmp;
if (p >= mod) p -= mod;
}
}

printf("%lld", p);
}

欢迎关注我的其它发布渠道