日々drdrする人のメモ

今日も明日もdrdr

Segment Tree Beatsの実装メモ (基本まわり)

Segment Tree Beats(SGT Beats)の基本的なところの自分の理解をまとめておく。
今回は主に考え方や実装の説明メインで、Historic Informationや計算量解析周りの説明は含まない。(ここらへんも理解できて書けたら書く)

この記事に出てくる説明のためのコードはC++で書いてある。
以下を参考にしている。


※2019/05/28 計算量が間違ってる箇所があったため修正

目次

Segment Tree Beats とは

遅延評価セグメント木を改良したもの。

以下のような区間chmin/chmaxの更新クエリを、区間を加算操作で更新しながら処理できる。

  • 区間chminクエリ:  {i \in [l, r)} について  {a_i} {\min(a_i, x)} に更新
  • 区間chmaxクエリ:  {i \in [l, r)} について  {a_i} {\max(a_i, x)} に更新

f:id:smijake3:20190417085346p:plain
区間chminクエリのイメージ

ここからは、区間chminクエリのみに絞って説明する。

従来の遅延セグ木で区間chminクエリを扱うのが難しい点

f:id:smijake3:20190424212518p:plain
7, 1, 9, 7, 6, 2, 3, 5 を含む区間ノードを  min(a_i, 4) で更新するイメージ
(赤は即更新、黄は遅延させながら更新)
各ノードには区間総和が書かれており、これを一度に更新するのは難しい

区間chminクエリでは複数の異なる値が同時に更新されるケースが発生し、区間の値を更新する際にこれらの値を一度に考慮する必要があるが、これが従来の遅延セグ木では難しくなるケースがある。(単一要素参照や区間最大値の計算を扱う場合は従来の遅延セグ木で可能)

例えば、区間総和を管理する場合、変更された値全ての差を考慮して更新する必要があり、これを遅延セグ木上で扱うのは難しい。
具体的な区間総和の例を考える。上の図のように7, 1, 9, 7, 6, 2, 3, 5 を含む区間ノードを  \min(a_i, 4) で更新する場合、ノードが持つ総和に関して、更新後の 4, 1, 4, 4, 4, 2, 3, 4 の総和に更新する必要がある。
しかし、1つのノード内で5,6,7,9 の4種類の値に関する差を一度に考慮する必要があり、これを高速に処理するのは難しい。

Segment Tree Beatsの考え方

f:id:smijake3:20190428001323p:plain
7,1,9,7,6,2,3,5 の区間に対し min(a_i, 4)で更新する際の更新対象ノード
(赤は即更新、黄は遅延させながら更新)
ノードの中の"x/y"は最大値がxで(厳密な)二番目の最大値がyであることを示す

Segment Tree Beatsでは区間chminクエリを処理する際、更新対象となる値を1種類ずつ更新できるように処理する。

具体的には、区間chminクエリ  \min(a_i, x) を処理する際、更新対象とするノードを変更する。
従来の遅延セグ木では、更新対象となる区間に含まれる全てのノードを(遅延させながら)更新していたが、これを変更し、上図のように更新対象の区間の中で ((厳密な)二番目の最大値) <  x < (最大値) を満たす全てのノードを(遅延させながら)更新することで区間chminクエリの更新を行う。
(区間chmaxであれば(最小値) <  x < ((厳密な)二番目の最小値)を満たす全てのノードを更新する)

((厳密な)二番目の最大値) <  x < (最大値) を満たすことにより、更新のために考慮する値が最大値のみの1種類になる。
これにより、例えば1つのノードの区間総和を更新する場合は (最大値の個数)×( (新しい最大値 x) - (古い最大値) ) を一度加算するだけになり、各ノードを高速( O(1))で更新できる。

区間chminクエリ、RMQ、RSQを処理する実装

例えば、以下の問題を考える。

Gorgeous Sequence - HDOJ

長さ {N}の数列 {A}が与えられる。以下のクエリを合計 {M}回処理せよ。

  •  {l \le i < r} について、 {a_i} の値を  {\min(a_i, x)} に更新
  •  {\displaystyle \max_{l \le i < r} a_i} を出力
  •  {\displaystyle \sum_{l \le i < r} a_i} を出力

制約:  {N, M \le 10^6}

ここでは、区間chminクエリ、RMQ、RSQに対応できるセグ木の実装を説明する。

実装全体は以下

tjkendev.github.io

データの持ち方

Segment Tree Beatsでは区間chminを処理するために、各ノードにその区間 {[l, r)}最大値(厳密な)二番目の最大値最大値の個数をもたせる。

そして、今回は区間総和を管理する必要があるため、これをもたせる。

using ll = long long;

// ノードkが持つ情報
// max_v[k]: 最大値
// smax_v[k]: 二番目の最大値
// sum[k]: 区間和
// max_c[k]: 最大値の個数
ll max_v[4*N], smax_v[4*N];
ll sum[4*N], max_c[4*N];

区間chminクエリ

 {l \le i < r} について、 {a_i} の値を  {\min(a_i, x)} に更新

この更新処理は、従来の遅延セグ木の区間更新のように実装する。

従来の遅延セグ木との実装の大きな違いは、探索終了条件(break condition)と更新条件(tag condition)に、最大値との関係が追加される点である。
(区間chmaxであれば、最小値との関係が追加される)

区間chminクエリを処理する本体

break conditionは以下のいずれかを満たすことが条件となる。

  • 探索区間 {[l, r)}が更新区間 {[a, b)}の範囲外
  • (最大値) ≤  {x} を満たす

tag conditionは以下を全て満たすことが条件となる。

  • 探索区間 {[l, r)}が更新区間 {[a, b)}に含まれる
  • (二番目の最大値) <  {x} < (最大値) を満たす
// [a, b)内の a_i について min(a_i, x) に更新 ([l, r)内を探索する)
void _update_min(ll x, int a, int b, int k, int l, int r) {
  // break condition: 対象外の区間 もしくは 最大値より大きい場合、終了
  if(b <= l || r <= a || max_v[k] <= x) {
    return;
  }
  // tag condition: 最大値より小さいが、二番目の最大値より大きい場合、
  // このノードの情報を更新して終了
  if(a <= l && r <= b && smax_v[k] < x) {
    // 区間最大値の更新処理 (後述)
    update_node_max(k, x);
    return;
  }

  // 二番目の最大値よりも小さい場合、子ノードを更新し
  // 子ノードの情報からこのノードを更新する

  push(k);
  _update_min(x, a, b, 2*k+1, l, (l+r)/2);
  _update_min(x, a, b, 2*k+2, (l+r)/2, r);
  update(k);
}

// 区間[a, b) について、 a_i を min(a_i, x) に更新 (クエリ処理する際はこの形で呼び出す)
void update_min(int a, int b, ll x) {
  return _update_min(x, a, b, 0, 0, n0);
}

この問題における区間chmin/chmaxクエリの計算量はクエリ全体でみると  {O((N+M) \log N)} となる。

区間最大値の更新処理

最大値を更新する際の区間総和の更新は ( {x} - (更新前の最大値)) × (最大値の個数) を加算することで行う。
この更新処理は、区間を即更新する場合とpushdownで更新する場合とで基本同じになりそうなのでまとめといてよさそう。

// ノードkの最大値をx (< max_v[k]) に更新する
void update_node_max(int k, ll x) {
  // 最大値を更新する際に、差分で区間和も更新
  sum[k] += (x - max_v[k]) * max_c[k];
  max_v[k] = x;
}
親ノードから子ノードへの伝搬 (pushdown)

子ノードを探索する時、遅延されている最大値の値を伝搬し、更新を行う必要がある。
処理としては、 (親ノードの最大値) < (子ノードの(古い)最大値)を満たす場合に子ノードの最大値を更新する。

void push(int k) {
  // ノードkの子ノード2k+1, 2k+2の持つ情報を更新
  if(max_v[k] < max_v[2*k+1]) {
    update_node_max(2*k+1, max_v[k]);
  }
  if(max_v[k] < max_v[2*k+2]) {
    update_node_max(2*k+2, max_v[k]);
  }
}
子ノードから親ノードへの伝搬 (update)

子ノードの情報が更新した場合は、従来の遅延セグ木と同様に最大値等の情報を子ノードから更新する。
左右ノードの最大値の大小関係によって、更新する値を変更する。

void update(int k) {
  sum[k] = sum[2*k+1] + sum[2*k+2];

  if(max_v[2*k+1] < max_v[2*k+2]) {
    max_v[k] = max_v[2*k+2];
    max_c[k] = max_c[2*k+2];
    smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]);
  } else if(max_v[2*k+1] > max_v[2*k+2]) {
    max_v[k] = max_v[2*k+1];
    max_c[k] = max_c[2*k+1];
    smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]);
  } else {
    // max_v[2*k+1] == max_v[2*k+2]
    max_v[k] = max_v[2*k+1];
    max_c[k] = max_c[2*k+1] + max_c[2*k+2];
    smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]);
  }
}

区間最大値クエリ(RMQ) / 区間総和クエリ (RSQ)

 {\displaystyle \max_{l \le i < r} a_i} を出力

 {\displaystyle \sum_{l \le i < r} a_i} を出力

ここは変わらず、従来の遅延セグ木と同じように計算すればよい。
1回のクエリあたりの計算量は  O(\log N)

ll _query_max(int a, int b, int k, int l, int r) {
  if(b <= l || r <= a) {
    return 0;
  }
  if(a <= l && r <= b) {
    return max_v[k];
  }

  push(k);
  ll lv = _query_max(a, b, 2*k+1, l, (l+r)/2);
  ll rv = _query_max(a, b, 2*k+2, (l+r)/2, r);
  return max(lv, rv);
}
// 区間[a, b)の最大値を求める
ll query_max(int a, int b) {
  return _query_max(a, b, 0, 0, n0);
}
ll _query_sum(int a, int b, int k, int l, int r) {
  if(b <= l || r <= a) {
    return 0;
  }
  if(a <= l && r <= b) {
    return sum[k];
  }

  push(k);
  ll lv = _query_sum(a, b, 2*k+1, l, (l+r)/2);
  ll rv = _query_sum(a, b, 2*k+2, (l+r)/2, r);
  return lv + rv;
}

// 区間[a, b)の総和を求める
ll query_sum(int a, int b) {
  return _query_sum(a, b, 0, 0, n0);
}

区間chmin/chmaxクエリ両方を扱う実装

区間chmin/chmaxクエリを一度に扱う問題も解くことができる。
以下のような問題を考える。(一つ前の問題に、区間chmaxクエリと区間最小値クエリが加わったもの)

長さ {N}の数列 {A}が与えられる。以下のクエリを合計 {M}回処理せよ。

  •  {l \le i < r} について、  {a_i} の値を  {\min(a_i, t)} に更新
  •  {l \le i < r} について、  {a_i} の値を  {\max(a_i, t)} に更新
  •  {\displaystyle \max_{l \le i < r} a_i} を出力
  •  {\displaystyle \min_{l \le i < r} a_i} を出力
  •  {\displaystyle \sum_{l \le i < r} a_i} を出力

制約:  {N, M \le 10^5}

区間chminクエリのみの場合と異なる箇所について説明する。

データの持ち方

各ノードでは、以下の情報をもたせる

  • 最大値・(厳密な)二番目の最大値・最大値の個数
  • 最小値・(厳密な)二番目の最小値・最小値の個数
  • 区間
// 最大値・(厳密な)二番目の最大値・最大値の個数
ll max_v[4*N], smax_v[4*N], max_c[4*N];
// 最小値・(厳密な)二番目の最小値・最小値の個数
ll min_v[4*N], smin_v[4*N], min_c[4*N];
// 区間和
ll sum[4*N];

区間chmin/chmaxクエリ

処理の実装は、基本的に区間chminクエリのみの場合と基本変わらない。
一つ実装で気をつけるべき所は、最大値と最小値は更新する際に互いに影響しあう点である。
(例えば、あるノードで最大値7, 最小値3だった時に  {\min(x, 2)} で更新すると、最大値も最小値も2に更新される)

最大値・最小値の更新処理まわりの実装

最大値と最小値の関係に応じて、以下のケースで場合分けして更新する。

  • 最大値 = 最小値 (1種類の値を含む場合)
    • max_v[k] = min_v[k] (smax_v[k] =  -\infty, smin_v[k] =  \infty)
  • 最大値と最小値の2種類の値しか含まない場合
    • max_v[k] = smin_v[k], smax_v[k] = min_v[k]
  • それ以外 (3種類以上の値を含む場合)
    • min_v[k] < smin_v[k] ≤ smax_v[k] < max_v[k]
// max_v[k]: 最大値, smax_v[k]: (厳密な)二番目の最大値
// min_v[k]: 最小値, smin_v[k]: (厳密な)二番目の最小値

// x < max_v[k] となるノードkについて最大値を更新
// 区間chmin更新時とpushdown時に使う
void update_node_max(int k, ll x) {
  sum[k] += (x - max_v[k]) * max_c[k];

  // 最小値も必要に応じて更新
  if(max_v[k] == min_v[k]) {
    max_v[k] = min_v[k] = x;
  } else if(max_v[k] == smin_v[k]) {
    max_v[k] = smin_v[k] = x;
  } else {
    max_v[k] = x;
  }
}

// min_v[k] < x となるノードkについて最小値を更新
// 区間chmax更新時とpushdown時に使う
void update_node_min(int k, ll x) {
  sum[k] += (x - min_v[k]) * min_c[k];

  // 最大値も必要に応じて更新
  if(max_v[k] == min_v[k]) {
    max_v[k] = min_v[k] = x;
  } else if(smax_v[k] == min_v[k]) {
    min_v[k] = smax_v[k] = x;
  } else {
    min_v[k] = x;
  }
}
区間chmin/chmaxクエリを処理する本体

本体の実装は、片方のみ処理する場合と比べて変化しない。

// [a, b)内の a_i について min(a_i, x) に更新
void _update_min(ll x, int a, int b, int k, int l, int r) {
  if(b <= l || r <= a || max_v[k] <= x) {
    return;
  }
  if(a <= l && r <= b && smax_v[k] < x) {
    update_node_max(k, x);
    return;
  }

  push(k);
  _update_min(x, a, b, 2*k+1, l, (l+r)/2);
  _update_min(x, a, b, 2*k+2, (l+r)/2, r);
  update(k);
}

// [a, b)内の a_i について max(a_i, x) に更新
void _update_max(ll x, int a, int b, int k, int l, int r) {
  if(b <= l || r <= a || x <= min_v[k]) {
    return;
  }
  if(a <= l && r <= b && x < smin_v[k]) {
    update_node_min(k, x);
    return;
  }

  push(k);
  _update_max(x, a, b, 2*k+1, l, (l+r)/2);
  _update_max(x, a, b, 2*k+2, (l+r)/2, r);
  update(k);
}

この問題における区間chmin/chmaxクエリの計算量はクエリ全体でみると {O((N+M) \log N)}となる。
(基本的に RAQ, RUQ等を含まなければ  O((N+M) \log N) になる)

親ノードから子ノードへの伝搬 (pushdown)

最大値・最小値をそれぞれpushdownすればよい。

void push(int k) {
  // ノードk -> ノード2*k+1 への最大値・最小値の伝搬
  if(max_v[k] < max_v[2*k+1]) {
    update_node_max(2*k+1, max_v[k]);
  }
  if(min_v[2*k+1] < min_v[k]) {
    update_node_min(2*k+1, min_v[k]);
  }

  // ノードk -> ノード2*k+2 への最大値・最小値の伝搬
  if(max_v[k] < max_v[2*k+2]) {
    update_node_max(2*k+2, max_v[k]);
  }
  if(min_v[2*k+2] < min_v[k]) {
    update_node_min(2*k+2, min_v[k]);
  }
}
子ノードから親ノードへの伝搬 (update)

最大値・最小値の情報をそれぞれ親ノードへ伝搬する。

void update(int k) {
  sum[k] = sum[2*k+1] + sum[2*k+2];

  // 最大値の方の情報更新
  // 左右の子ノードの最大値の大小関係によって更新内容を変更
  if(max_v[2*k+1] < max_v[2*k+2]) {
    max_v[k] = max_v[2*k+2];
    max_c[k] = max_c[2*k+2];
    smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]);
  } else if(max_v[2*k+1] > max_v[2*k+2]) {
    max_v[k] = max_v[2*k+1];
    max_c[k] = max_c[2*k+1];
    smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]);
  } else {
    max_v[k] = max_v[2*k+1];
    max_c[k] = max_c[2*k+1] + max_c[2*k+2];
    smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]);
  }

  // 最小値の方の情報更新
  // 左右の子ノードの最小値の大小関係によって更新内容を変更
  if(min_v[2*k+1] < min_v[2*k+2]) {
    min_v[k] = min_v[2*k+1];
    min_c[k] = min_c[2*k+1];
    smin_v[k] = min(smin_v[2*k+1], min_v[2*k+2]);
  } else if(min_v[2*k+1] > min_v[2*k+2]) {
    min_v[k] = min_v[2*k+2];
    min_c[k] = min_c[2*k+2];
    smin_v[k] = min(min_v[2*k+1], smin_v[2*k+2]);
  } else {
    min_v[k] = min_v[2*k+1];
    min_c[k] = min_c[2*k+1] + min_c[2*k+2];
    smin_v[k] = min(smin_v[2*k+1], smin_v[2*k+2]);
  }
}

区間chmin/chmaxクエリに加えてRAQ、RUQを扱う

区間chmin/chmaxクエリに合わせてRAQ(区間加算クエリ)、RUQ(区間更新クエリ)も扱うことができる。

  •  {l \le i < r} について、  {a_i} の値を  {a_i + x} に更新
  •  {l \le i < r} について、  {a_i} の値を  {x} に更新

この場合の実装は基本的には前記の区間chmin/chmaxクエリを処理する場合の実装とあまり変わらない。
それらの実装に加えて、従来の遅延セグ木と同様に各ノードで加算された値と更新された値を遅延値として持ち、pushdownされる際にこれらの値を伝搬しながら最大値・最小値を更新すればよい。

RAQ、RUQを含むと、計算量は全体で  O(N \log N + M \log^2 N) になる。

全体実装は以下。
tjkendev.github.io