「AGC037B」RGB Balls

Description

有 $3n$ 个球摆成一列,其中有 $\text{RGB}$ 每种颜色各 $n$ 个。现在第 $i$ 个人有三个颜色互不相同的球,并且在序列中的位置从小到大是 $a_i <b_i < c_i$ 。求在保证 $\sum c_i - a_i$ 最小的情况下有多少种分球方案。答案对 $998244353$ 取模。

数据范围:$n \leq 10^5$

Solution

自然地第一步要想的是如何满足 $\sum c_i - a_i$ 最小。

我们换一个角度来分析它。把问题看成从前往后把每一个球给一个人。那么对于每一个人,他对答案产生的贡献是他手中已经分了球但是还没有凑全三个球的时刻数。

于是便可以贪心的最小化每一个时刻手中拿了球但是没有拿全三个的人数。

对于一个时刻,不妨设目前球数最多的是 $\text{R}$ ,其次是 $\text{G}$ ,然后是 $\text{B}$ 。那么此时最优的情况一定是若干组 $\text{RGB}$ 加上若干组 $\text{RG}$ 再加上若干组 $\text{R}$ 。于是便可以解决 $\sum c_i - a_i$ 最小的问题。

方案数也能够求出。每次新加一个球的时候看他是原来的第几大和现在的第几大。方案数相应的乘上能够放进的组的个数即可。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <bits/stdc++.h>

using namespace std;

#define ll long long
#define ull unsigned long long
#define db double
#define ldb long double

#define fi first
#define se second
#define MP make_pair
#define pii pair <int, int>
#define pil pair <int, ll>
#define pli pair <ll, int>
#define pll pair <ll, ll>

#define All(x) x.begin(), x.end()
#define pb push_back
#define pf push_front

#define ms0(x) memset(x, 0, sizeof(x))
#define ms1(x) memset(x, -1, sizeof(x))

#define oye cerr << "Yes!" << endl;
#define O(x) cerr << #x << ": " << x << endl;

template <typename T> void printarr(T a[], int b, int e) {
if(b > e) return ;
for(int i = b; i < e; i++) cout << a[i] << " ";
cout << a[e] << endl;
}

template <typename T> int chkmax(T &x, const T &y) {
return x < y ? x = y, 1 : 0;
}

template <typename T> int chkmin(T &x, const T &y) {
return x > y ? x = y, 1 : 0;
}

const int N = 1001000;
const int mod = 998244353;

int addp(int x, int y) {
return (x += y) >= mod ? x - mod : x;
}

char s[N];
int n;
int ans = 1;
int cnt[3];

int main() {
scanf("%d", &n); n *= 3;
scanf("%s", s + 1);
for(int i = 1; i <= n; i++) {
int x[3], y[3];
x[0] = cnt[0]; x[1] = cnt[1], x[2] = cnt[2];
sort(x, x + 3);
int op = 0;
if(s[i] == 'R') op = 0;
if(s[i] == 'G') op = 1;
if(s[i] == 'B') op = 2;
cnt[op]++;
y[0] = cnt[0]; y[1] = cnt[1], y[2] = cnt[2];
sort(y, y + 3);
if(y[2] != x[2]) {
; // ans = 1ll * ans * x[2] % mod;
}
else if(y[0] != x[0]) {
ans = 1ll * ans * (x[1] - x[0]) % mod;
}
else {
ans = 1ll * ans * (x[2] - x[1]) % mod;
}
}
for(int i = 1; i <= n / 3; i++) ans = 1ll * ans * i % mod;
printf("%d\n", ans);
return 0;
}