日々drdrする人のメモ

今日も明日もdrdr

AtCoder: Mujin Programming Challenge 2018 - F問題: チーム分け

実際には少なく見積もれる計算量がなかなか慣れない。

atcoder.jp

問題概要

 {N}人をいくつかのチームに分ける。
この時、 {i}番目の人は {a_i}人以下のチームのみに所属できる。

この条件の元、 {N}人のチーム分けとして何通り考えられるかをMOD  {998244353}で求めよ。
(ここでは、人は区別するがチームは区別しない)

制約

  •  {1 \le N \le 1000}
  •  {1 \le a_i \le N}

解法

dpの方針まではわかったけど、計算量がうまく落とせず解けなかった...。
公式解説ベース。


まず、 {X}人が属するチームを作る場合を考える。
この時、以下のことが分かる。

  •  {X \le a_i}となる人は {a_i}の値に関係なく所属可能
  •  {a_i < X}となる人は所属不可能


そこで、以下のようなdpを考える。
 {dp[x][y] :=}  {x \le a_i} となる人の中から、 {x}人以上を含むいくつかのチームを作った時点で {y}人余った時の通り数

このdpの初期値を {dp[N+1][0] = 1}としてから、 {dp[0][0]}を問題の解として求めればよい。


そして、値の伝搬であるが、 {A}人存在した時に {i}人チームを {k}チーム作る時の通り数は
 {\displaystyle {}_AC_{k \cdot i} \cdot {}_{k \cdot i}C_{i} \cdot {}_{(k-1) \cdot i}C_{i} \cdot ... \cdot {}_{2i}C_i \cdot {}_iC_i = \frac{A!}{(A - k \cdot i)! \cdot (i!)^k \cdot k!}}
となる。

そのため、 {A = j + c_i}とした時、( {c_i} {i = a_x}となる人数)
 {dp[i][A - k \cdot i] \leftarrow dp[i+1][j] \cdot \frac{A!}{(A - k \cdot i)! \cdot (i!)^k \cdot k!}}
と伝搬すればよい。


あとは実際に計算するだけだが、一見するとこの計算は {O(N^3)}っぽく見える。
ここで、 {dp[i+1][j]}から {dp[i][*]}へ伝搬させる計算の回数は {\lfloor \frac{j + c_i}{i} \rfloor}回( {\le \frac{N}{i} })であることから、

 {1 \le i, j \le N}における伝搬の合計回数の上限は
 {\displaystyle \sum_{i=1}^N \sum_{j=1}^N \frac{N}{i} = \displaystyle N^2 \sum_{i=1}^N \frac{1}{i} \le N^2 (\log_e N + 1)}
と評価できるため、計算量は  {O(N^2 \log N)} となる。


 {\displaystyle \sum_{i=1}^N \frac{1}{i} \le \log_e N + 1} と評価できるのよく見落とすので慣れたい。

実装

提出コード(PyPy3): Submission #4511779 - Mujin Programming Challenge 2018

import sys
readline = sys.stdin.readline

N = int(readline())
*A, = map(int, readline().split())
B = [0]*(N+1)
for a in A:
    B[a] += 1

MOD = 998244353

fact = [1]*(N+1)
rfact = [1]*(N+1)
r = 1
for i in range(1, N+1):
    fact[i] = r = r * i % MOD
    rfact[i] = pow(r, MOD-2, MOD)

dp = [0]*(N+1)
dp2 = [0]*(N+1)
zeros = [0]*(N+1)
dp[0] = 1
for i in range(N, 0, -1):
    c = B[i]
    dp2[:] = zeros
    for j in range(N-c+1):
        num = j + c
        # v == pow(rfact[i], k, MOD)
        ri = rfact[i]; v = 1
        for k in range(num//i+1):
            dp2[num - k*i] += dp[j] * fact[num] * rfact[num - k*i] * rfact[k] * v % MOD
            v = v * ri % MOD
    dp, dp2 = dp2, dp
sys.stdout.write("%d\n" % (dp[0] % MOD))