日々drdrする人のメモ

今日も明日もdrdr

DPまとめコンテスト - O問題: Matching

問題: O - Matching

気づき含めた解法メモ

問題概要

 {N}人の男性と {N}人の女性がいる。

男性 {i}と女性 {j}の相性の良し悪しが {a_{i,j}}によって与えられ、1なら相性が良く、0なら相性が悪い。

ここで、相性の良い男女同士のペア {N}組を作るとき、これは何通り存在するか。 {10^9 + 7}で割った余りで求めよ。

解法

bitDPで計算する。

 {dp[state]} = 状態値stateのiビット目が1の時、女性 {i}はまだペアを作っていない

と定義して、 {dp[2^N-1] = 1}から計算すればよい。

計算量は  {O(N 2^N)}

以下は異なる計算方法のメモ。他にも方法はありそう。

方法1: メモ化DFS

どのstateからどのstateへ伝搬するかが直感的に書ける。

提出コード(PyPy3, 1473ms): Submission #4021434 - Educational DP Contest / DP まとめコンテスト

N = int(input())
A = [list(map(int, input().split())) for i in range(N)]
MOD = 10**9 + 7
 
memo = [-1]*(1 << N)
memo[0] = 1
def dfs(state, c):
    if memo[state] != -1:
        return memo[state]
    r = 0
    for i in range(N):
        if state & (1 << i) and A[c][i]:
            r += dfs(state ^ (1 << i), c+1)
    memo[state] = r = r % MOD
    return r
print(dfs((1 << N)-1, 0))
方法2: BFS

再帰で解くやり方の1つ。

BFSをすると、状態値 {2^N - 1}から、0が1つずつ少なくなる状態値にうまく伝搬させることができる。

少し遅い。

提出コード(PyPy3, 1931ms): Submission #4021478 - Educational DP Contest / DP まとめコンテスト

from collections import deque
N = int(input())
A = [list(map(int, input().split())) for i in range(N)]
MOD = 10**9 + 7

cs = [0]*(1 << N)
used = [0]*(1 << N)
s = (1 << N) - 1
que = deque([(s, 0)])
cs[s] = 1
used[s] = 1

while que:
    state, c = que.popleft()
    v = cs[state] % MOD

    for i in range(N):
        if state & (1 << i) and A[c][i]:
            n_state = state ^ (1 << i)
            cs[n_state] += v

            # 同じstateを複数回queueに入れないようにする
            if not used[n_state] and c+1 < N:
                que.append((n_state, c+1))
                used[n_state] = 1
print(cs[0] % MOD)
方法3: 1がkビットある状態値を全列挙する

蟻本第二版p.144に載っているテクニックを使う。

 {k = N, N-1, ..., 1}について、 {\{0, 1, ..., N-1\}}に対し、サイズkの部分集合をビット演算により列挙できる。

これにより、

  • 1がNビットある状態値stateを列挙して伝搬
  • 1がN-1ビットある状態値stateを列挙して伝搬

...

  • 1が1ビットある状態値stateを列挙して伝搬

という処理をシンプルに処理できる。

この解法はPyPy3で少し高速に通るっぽい。

提出コード(PyPy3, 947ms): Submission #3973864 - Educational DP Contest / DP まとめコンテスト

N = int(input())
A = [list(map(int, input().split())) for i in range(N)]

ALL = 1 << N
S = [0]*ALL
MOD = 10**9 + 7
S[ALL-1] = 1
# 1がNビット存在する状態 -> 1が{N-1}ビット存在する状態 -> ... -> 1が1ビット存在する状態
# という感じで伝搬
for k in range(N, 0, -1):
    I = [i for i in range(N) if A[k-1][i]]
    # サイズkの部分集合を列挙していく
    v = (1 << k) - 1
    while v < ALL:
        ## このvの値がサイズkの部分集合になる
        w = S[v] % MOD
        for i in I:
            if v & (1 << i):
                S[v ^ (1 << i)] += w

        x = v & -v; y = v + x
        v = ((v & ~y) // x >> 1) | y
print(S[0] % MOD)

この部分集合列挙の計算、蟻本で見た時からいつ使うんだろうと思ってたけど、こういう使い方ができるという意味で気づきだった。