日々drdrする人のメモ

今日もdrdr、明日もdrdr

CodeChef August Challenge 2018: Safe Partition

CodeChef August Challenge 2018の問題: Safe Partition (Code: SAFPAR)
問題ページ: https://www.codechef.com/AUG18A/problems/SAFPAR

問題概要

 {N}個の要素を含む数列 {A}がある。

この数列を、各要素がいずれか1つの列に属するように、連続したいくつかの部分列 {S_1, S_2, ..., S_K}に分割する。
この時、全ての部分列について {\min(S_i) \le |S_i| \le \max(S_i)}を満たすようにしたい。

このような分割の仕方はいくつあるか?MOD  {10^9 + 7}で求めなさい。

制約

 {1 \le N \le 5 \times 10^5}
 {1 \le A_i \le N}

解法

満点解法はSegment Treeを使って頑張る。


まず、dpを使えば {O(N^2)}で解が求まる。
 {A(i, j)}をiからjまでの連続した部分列とすると
 {D[0] = 1}
 {\displaystyle D[i] = \sum_{\min(A(i, j)) \le |A(i, j)| \le \max(A(i, j))} D[j-1]}  {(i > 0)}
で計算した、 {D[N]}が解となる。(これは部分点10点の解法)


満点解法では、各iについて

  •  {\displaystyle P = \sum_{|A(i, j)| \le \max(A(i, j))} D[j-1]}
  •  {\displaystyle Q = \sum_{|A(i, j)| < \min(A(i, j))} D[j-1]}

を計算した上で  {D[i] = P - Q} を計算する。

 {P = \sum_{|A(i, j)| \le \max(A(i, j))} D[j-1]} の計算

ここで扱う例は以下の入力である。

15
1 2 4 10 2 1 4 2 1 3 6 2 1 8 2

各iについて、 {|A(i, j)| \le \max(A(i, j))}となる箇所を可視化した図を示す。図の左の値が {i}、図の上の値が {j-1}である。マスの色は {\max(A(i, j))}の値が大きいほど明るい。

f:id:smijake3:20180818120036p:plain

この図から、三角形の形の範囲を {N}個組み合わせた範囲の上で計算していることがわかる。

この図をよく見ると、 {D[i-1]}の時に参照していたが {D[i]}では参照されなくなった {D[j-1]}の箇所が全体に比べて少ないように見える。この例では、下の図のように9箇所存在する。
f:id:smijake3:20180818120129p:plain

これを元に、 {P = 0}から

  1. 新たな三角形の範囲が追加された時に、追加された範囲の和 {\sum_i D[j-1]} {P}に加える
  2. そこから不必要になった {D[j-1]}の値を {P}から引く

という処理を繰り返すことで計算量を抑えながら各 {P}の値が計算できそうだと予想できる。

この計算は

  • (1)はSegment Treeを使うことで {O(N \log N)}
  • (2)は(厳密に証明できてないが恐らく) {O(N \log N)}

になるため、全体で {O(N \log N)}になる。

 {Q = \sum_{|A(i, j)| < \min(A(i, j))} D[j-1]} の計算

 {Q}の値は、

  1.  {A_j \le |A(i, j)|}を満たす {j}のうち、最小の {|A(i, j)|}
  2.  {A_j > |A(i, j)|}を満たす {A_j}のうち、最小の {A_j}

を求め、2つの内の最小の値を {k}とすると、 {Q = \sum_{i-k+1 < j \le i} D[j-1]}で計算できる。

この計算は

  • (1)は全体でminをとりながら処理すると {O(N)}
  • (2)はSegment Treeを使うことで {O(N \log N)}

になるため、全体で {O(N \log N)}になる。

実装

実装は、

  • 各三角形の範囲がかぶらないように大きい三角形を優先した上で、各三角形の範囲の和を管理
  • 不必要な {D[j-1]}が出てきたら引く
  • 不必要になった三角形があったら破棄

みたいなことを難しく実装してる。これよりよい実装はいくらでもできそう。

提出コード(PyPy2): Solution: 19580651 | CodeChef

from heapq import heappush, heappop
INF = 10**18
N = int(raw_input())
A = map(int, raw_input().split())
N0 = 2**(N-1).bit_length()
# 区間内の最小値を計算するためのSegment Tree
data = [INF]*(2*N0+1)
def update_min(k, x):
    k += N0-1
    data[k] = x
    while k:
        k = (k - 1) // 2
        data[k] = min(data[2*k+1], data[2*k+2])
def query_min(l, r):
    res = INF
    l += N0; r += N0
    while l<r:
        if r & 1:
            r -= 1
            res = min(res, data[r-1])
        if l & 1:
            res = min(res, data[l-1])
            l += 1
        l >>= 1; r >>= 1
    return res
 
# 区間内の和を求めるためのSegment Tree
data1 = [0]*(2*N0+1)
def update_sum(k, x):
    k += N0-1
    data1[k] = x
    while k:
        k = (k - 1) // 2
        data1[k] = (data1[2*k+1] + data1[2*k+2]) % MOD
def query_sum(l, r):
    res = 0
    l += N0; r += N0
    while l<r:
        if r & 1:
            r -= 1
            res += data1[r-1] % MOD
        if l & 1:
            res += data1[l-1] % MOD
            l += 1
        l >>= 1; r >>= 1
    return res % MOD
 
MOD = 10**9 + 7
C = [0]*(N+1)
D = [0]*(N+1) # D[i]: 解説中のD[i]の値を保持
D[0] = C[0] = 1
update_sum(0, 1)
st = [[0, 10**9, 1]]
que = []
# r: Pの値を保持する
r = 0
 
S = {}
T = [[] for i in range(N)]
U = [0]*N
 
r0 = -1
for i, a in enumerate(A):
    # |A(i, j)| = A_jを満たすA_jを処理
    while que and que[0][0]==i:
        _, j, b = heappop(que)
        r0 = max(j, r0)
        update_min(j, INF)
    heappush(que, (i+a, i, a))
    update_min(i, a)
    # 三角形の破棄
    while st and st[-1][1] <= a:
        j, b, u = st.pop()
        r -= u % MOD
        U[j] = 1
        if j in S:
            del S[j]
    r %= MOD
 
    # 三角形の追加
    j, b, u = st[-1]
    v = query_sum(max(j, i-a), i+1)
    r += v % MOD
    e = [i+1, a, v]
    st.append(e)
    if j+a <= i:
        S[i+1] = e
    elif j+a < N:
        T[j+a].append(e)
 
    if T[i]:
        for e in T[i]:
            if U[e[0]] == 0:
                S[e[0]] = e
 
    R = []
    # 各三角形内の不必要なD[j-1]を引く
    for e in S.values():
        j, b, u = e
        if j < i-b+1:
            # 三角形が存在しなくなったので破棄
            R.append(j)
            U[j] = 1
        else:
            e[2] -= D[i-b]
            r -= D[i-b]
    for j in R:
        del S[j]
    r %= MOD
 
    r1 = max(i-query_min(i-a+1, i+1)+1, r0)
    # D[i+1] = P-Qの値を計算
    if i-a+1 <= r1:
        D[i+1] = v = (r - query_sum(max(r1+1, 0), i+1)) % MOD
    else:
        D[i+1] = v = (r - query_sum(max(i-a+2, 0), i+1)) % MOD
    update_sum(i+1, v)
print(D[-1])