[LOJ2538][PKUWC2018]Slay the Spire

被各种特判姿势卡爆…

题意

有两种牌(攻击牌,强化牌)各n张,每张上都有一个数字x。

攻击牌打出后造成伤害x。

强化牌打出后使所有剩余攻击牌点数乘x,强化牌的x > 1。

能等概率随机抽m张牌,从m张牌里照最优策略取k张。

求期望伤害乘情况数。

思路

首先有显然策略,尽量多取强化牌。

然后我们降序排序。设 $ F(i,j) $为有i张强化牌选最强的j张的倍数和 ,$ G(i,j) $ 为有i张攻击,选最强j张伤害和。

答案为 $ Ans = \sum_{i = 1}^{k - 1}{F(i,i) \times G(m - i,k - i)} + \sum_{i = k}^{m}{F(i,k - 1) \times G(m - i,1)} $ 。

我们设 $ f[i][j] $ 为用i张强化牌,最后一张是j时的倍数和, $ g[i][j] $ 为用i张攻击牌,最后一张是j时的伤害和。

有 $ f[i][j] = a[j] \sum_{x = 0}^{j - 1}{f(i - 1,x)} $,$ g[i][j] = b[j] \times C_{j - 1}^{i - 1} + \sum_{x = 0}^{j - 1}{g(i - 1,x)} $。

用前缀和推。

然后有 $ F(x,y) = \sum_{i = 0}^{n} {f[y][i] \times C_{x - y}^{n - i} } $ , $ G(x,y) = \sum_{i = 0}^{n} {g[y][i] \times C_{x - y}^{n - i} } $ 。

然后就好了。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cctype>
inline int getint()
{
int r = 0,s = 1;char c = getchar();for(;!isdigit(c);c = getchar()) if(c == '-') s = 0;
for(;isdigit(c);c = getchar()) r = (((r << 2) + r) << 1) + (c ^ '0');return s ? r : -r;
}
const int N = 3003,MOD = 998244353;
int a[N],b[N],f[N][N],g[N][N],s[N];int n,m,k;
int frc[N],inv[N];
inline void U(int &x,int y){if((x += y) >= MOD) x -= MOD;}
inline int qpow(int a,int b)
{
int r = 1;for(;b;b >>= 1)
{
if(b & 1) r = 1LL * r * a % MOD;
a = 1LL * a * a % MOD;
}
return r;
}
inline int C(int x,int y)
{
if(x > y || x < 0 || y < 0) return 0;
return 1LL * frc[y] * (1LL * inv[x] * inv[y - x] % MOD) % MOD;
}
inline int F(int x,int y)
{
int ret = 0;
for(int i = 0;i <= n;i++) U(ret,1LL * f[y][i] * C(x - y,n - i) % MOD);
return ret;
}
inline int G(int x,int y)
{
int ret = 0;
for(int i = 0;i <= n;i++) U(ret,1LL * g[y][i] * C(x - y,n - i) % MOD);
return ret;
}
inline bool cmp(int x,int y){return x > y;}
int main()
{
frc[0] = 1;for(int i = 1;i < N;i++) frc[i] = 1LL * frc[i - 1] * i % MOD;
inv[N - 1] = qpow(frc[N - 1],MOD - 2);for(int i = N - 1;i;i--) inv[i - 1] = 1LL * inv[i] * i % MOD;
int T = getint();while(T--)
{
n = getint(),m = getint(),k = getint();
for(int i = 1;i <= n;i++) a[i] = getint();
for(int i = 1;i <= n;i++) b[i] = getint();
std::sort(a + 1,a + n + 1,cmp);
std::sort(b + 1,b + n + 1,cmp);
memset(f,0,sizeof(f));memset(g,0,sizeof(g));
memset(s,0,sizeof(s));
f[0][0] = 1;s[0] = 1;for(int i = 1;i <= n;i++) s[i] = s[i - 1];
for(int i = 1;i <= n;i++)
{
for(int j = i;j <= n;j++)
{
f[i][j] = 1LL * a[j] * s[j - 1] % MOD;
if(f[i][j] >= MOD) f[i][j] -= MOD;
if(f[i][j] < 0) f[i][j] += MOD;
}
s[0] = 0;
for(int j = 1;j <= n;j++)
{
s[j] = f[i][j];
if((s[j] += s[j - 1]) >= MOD) s[j] -= MOD;
}
}
memset(s,0,sizeof(s));
g[0][0] = 0;
for(int i = 1;i <= n;i++)
{
for(int j = i;j <= n;j++) if((g[i][j] = (1LL * b[j] * C(i - 1,j - 1) % MOD) + s[j - 1]) >= MOD) g[i][j] -= MOD;
for(int j = 1;j <= n;j++)
{
s[j] = g[i][j];
if((s[j] += s[j - 1]) >= MOD) s[j] -= MOD;
}
}
int ans = 0;
for(int i = 0;i <= m;i++)
{
if(i < k) U(ans,1LL * F(i,i) * G(m - i,k - i) % MOD);
else U(ans,1LL * F(i,k - 1) * G(m - i,1) % MOD);
}
printf("%d\n",ans);
}
return 0;
}