日々drdrする人のメモ

今日も明日もdrdr

CodeChef May Challenge 2018: S-T Mincut

CodeChef の May Challenge 2018 の問題: S-T Mincut (Code: STMINCUT)
問題: https://www.codechef.com/MAY18A/problems/STMINCUT

問題

 {N \times N}の行列 {A}が与えられる。
この行列の各要素 {A_{ij}}の値を増やすことで次の条件を満たす必要がある。

  •  {N}個の頂点のグラフの中で、頂点 {i}と頂点 {j}の間の最小カットのコストが {A_{ij}}となるグラフ {G}が存在する

この条件を満たすように行列 {A}の各要素の値を増加させた時に、その増加させる値の合計を最小いくらにできるかを計算せよ。

制約

1つのテストケースに含まれるケース数:  {1 \le T \le 100}
 {1 \le N \le 1000}
 {1 \le A_{ij} \le 10^9},  {A_{ii} = 0}
1つのテストケースの {N}の合計は2000を超えない

解法

行列 {A}から構築できるグラフ {G}の候補はいろいろありそうだと思ったけど、条件を満たす増加の値が最小となる行列 {A}から構築できるグラフ {G}は木にできる、と予想して考えたら解けた。


この問題は、Union-Findとsetを用いたKruskal法で解ける。

方針は、 {\max(A_{ij}, A_{ji})} {i-j}間の辺コストとみなし、Kruskal法で辺コストが高い順に頂点同士を繋いでいき、木を構築しながら解を計算することである。


木を構築する際、頂点同士の連結をUnion-Findで管理することに加え、同じグループ(1つの木)に含まれる頂点をsetで管理する。
そして、ある木同士をある辺eでつなぎ合わせるときに、一方の木に含まれる頂点iともう一方の木に含まれる頂点jの最小カットのコストは辺 {e}のコストと一致するため、この辺 {e}のコストと {A_{ij} (A_{ji})}との差を計算していき、その和を解として出力する。


具体的に、木同士を繋ぎ合わせる時のイメージは以下の様になる。
頂点1,2,3と頂点4,5,6をそれぞれ1つの木にしたあとに、辺3-4で2つの木を繋ぎ合わせる例である。
f:id:smijake3:20180514210302p:plain
辺コストが高い順に繋ぐため、新たに繋ぎ合わせる辺(例では辺3-4)のコストは2つの木に含まれる全ての辺のコスト以下になる。
そのため、一方の木に含まれる頂点1,2,3ともう一方の木に含まれる頂点4,5,6同士の最小カットのコストは新たにつなぎ合わせた辺(辺3-4)のコストと等しくなる。


これで構築されるグラフ {G}で計算される各頂点間の最小カットのコストがより小さいものは存在しないはずである。
もし、ある頂点のペアの最小カットのコストをより小さくする場合、別の辺のコストを、そのコストに合わせる、もしくは切る必要が出てくるため、最小カットのコストが元の {A_{ij}}より小さくなってしまうものが出てくる、ことから存在しないことが考えられそうである。(きちんと証明はしてない)

例えば、上の図において、頂点2と頂点5の最小カットのコストを2にしたい場合、上の図の時点で最小カットのコストが5であるため、頂点2と頂点5の間のパスのいずれかの辺を切ってコスト2の辺2-5を追加する、もしくは頂点2と頂点5の間のパスに含まれる辺のコストを2に変える、のいずれかを行う必要があり、他の最小カットのコストが小さくなってしまう。

実装

1ケースの計算量は {O(N^2\alpha(N))}。 ( {\alpha}アッカーマン関数逆関数)

提出コード (Python3): https://www.codechef.com/viewsolution/18434025

T = int(input())

# Union-Find
def root(x):
    if p[x] == x:
        return x
    p[x] = y = root(p[x])
    return y
def unite(x, y):
    px = p[x]; py = p[y]
    if px == py:
        return -1
    if px < py:
        p[py] = px
        return px
    else:
        p[px] = py
        return py

for _ in range(T):
    N = int(input())
    A = [list(map(int, input().split())) for i in range(N)]

    ans = 0
    Q = []
    for i in range(N):
        Ai = A[i]
        for j in range(i):
            v = max(A[i][j], A[j][i])
            # A_{ij} と A_{ji} の値を max(A_{ij}, A_{ji}) に合わせる
            ans += 2*v - A[i][j] - A[j][i]
            A[i][j] = A[j][i] = v
            Q.append((v, i, j))
    ids = [-1]*N
    # Kruskal法
    Q.sort(reverse=1)
    *p, = range(N)
    S = [{i} for i in range(N)]
    for v, i, j in Q:
        pi = root(i); pj = root(j)
        r = unite(i, j)
        if r == -1:
            continue
        # uniteする2つの木に含まれる頂点集合si, sj
        si = S[pi]; sj = S[pj]
        for x in si:
            for y in sj:
                # 頂点xと頂点yの最小カットコストはv
                ans += 2*(v - A[x][y])
        # 2つの木に含まれる頂点を合わせて1つの集合にする
        if r == pi:
            si |= sj
            S[pj] = None
        else:
            sj |= si
            S[pi] = None
    print(ans)