日々drdrする人のメモ

今日も明日もdrdr

ARC096 - E問題: Everything on It

問題: E - Everything on It

N種類のトッピングが乗せれるラーメンで、トッピングの組み合わせを異なる {2^N}通りの組み合わせのラーメンのうち何杯か注文し、N種類のトッピングが、注文したラーメンらのうち2杯以上に乗っている組み合わせ数を計算する問題。

公式や他の提出コードを参考にした解法メモ。


この問題では、包除原理を用いて、
(0種類を1杯以下にした時の通り数) - (1種類を1杯以下にした時の通り数) + ...
と計算する。

この中で、(K種類を1杯以下にした時の通り数)を考える。
この時、K種類に含まれるトッピングは、どのラーメンにも乗ってないか、1杯のラーメンに乗るのみである。
ここで、K種類のトッピングのいずれかが乗っているラーメンがL杯 (L = 1, ..., K)の場合について考え、各杯数Lにおける通り数を合わせることでK種類の時の通り数を計算することを考える。

この時のL杯の時の通り数は第2種スターリング数(ここでは {S(*, *)}と表現する)を用いて、 {S(K+1, L+1) 2^{(N-K)L}}と計算できる。考え方的には、K種類のトッピングから何種類か選んで、選ばれたトッピングの1個をL杯のラーメンのいずれかに乗せた時の通り数を計算し、それに加えてL杯のラーメンについて残りの {N - K}種類のトッピングの乗せ方を考慮している。

この時、なぜ {S(K+1, L+1)}になるのかに悩んだ。
 {S(K, L)}ちょうどK個の要素をL個のグループに分けた時の通り数であるが、ここで計算したいのはK個のうちのいくつかをL個のグループに分けた時の通り数である。そこで、選択されない要素を入れるためのグループを1つ追加し、選択されないグループを表すために1つの要素(この要素が属するグループは選択されないグループになる)を追加することで、今回求めたい数を計算する。

あとは、N種類からK種類のトッピングを選ぶ通り数 {{}_NC_K}や、残りの {N - K}種類のトッピングだけが乗ったラーメンの選び方の通り数 {2^{2^{N-K}}}を掛けることでK種類の時の通り数が求まる。
この時 {2^{2^{N-K}}}をmod  {M}で計算する際、指数の {2^{N-K}}はmod  {M-1}で計算する。(∵フェルマーの小定理より)

最終的に組み合わせ数は、

 {\displaystyle \sum_{K = 0}^N (-1)^K 2^{2^{N-K}} {}_NC_K \sum_{L=0}^K S(K+1, L+1) 2^{(N-K)L}} (mod  {M})

を計算することで求まる。

計算量は {O(N^2)}

提出コード (Python3): Submission #2419047 - AtCoder Regular Contest 096

N, M = map(int, input().split())
fact = [1]*(N+1)
rfact = [1]*(N+1)
for i in range(1, N+1):
    fact[i] = r = (i * fact[i-1]) % M
    rfact[i] = pow(r, M-2, M)

S = [1]

rev2 = pow(2, M-2, M)
base = pow(2, N, M) # 2^(N - K)
ans = 0
S = [1]
for K in range(N+1):
    # nCk
    res = (fact[N] * rfact[K] * rfact[N-K]) % M
    # 2^{2^{N - K}}
    res = (res * pow(2, pow(2, N - K, M-1), M)) % M
    b = 1
    v = 0
    # S[i] = 第2種スターリング数 S(K, i)
    # T[i] = 第2種スターリング数 S(K+1, i)
    T = [0]*(K+2)
    for L in range(K):
        T[L+1] = s = (S[L] + (L+1)*S[L+1]) % M
        v += s * b
        b = (b * base) % M
    v += b
    T[K+1] = 1
    S = T
    res = (res * v) % M
    if K % 2:
        ans -= res
    else:
        ans += res
    ans %= M

    base = (base * rev2) % M
print(ans)