library_for_cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub Kazun1998/library_for_cpp

:heavy_check_mark: 永続 Segment Tree
(Segment_Tree/Persistent_Segment_Tree.hpp)

Outline

Segment Tree を永続化する. 区間を表すノードをポインタで実装する. そのため, 区間のコピーなどに対する処理を得意としている.

Contents

Constructor

(1)
template<typename M>
Persistent_Segment_Tree(const vector<M> &data, const function<M(M, M)> op, const M identity)

(2)
template<typename M>
Persistent_Segment_Tree(const int n, const function<M(M, M)> op, const M identity)

update

(1) int update(const int s, const int t, const int k, const M x)
(2) int update(const int t, const int k, const M x)
(3) int update(const int k, const M x)

amend

int amend(const int t, const int k, const M x)

copy

(1) int copy(const int s, const int t, const int u, const int l, const int r)
(2) int copy(const int s, const int t, const int l, const int r)
(3) int copy(const int t, const int l, const int r)

increment

int increment(int t = -1)

clone

(1) int clone(const int s, const int t)
(2) int clone(const int t)

product

(1) M product(const int t, const int l, const int r) const
(2) M product(const int l, const int r) const

all_product

(1) M all_product(const int t) const
(2) M all_product() const

get

(1) M get(const int t, const int k) const
(2) M get(const int k) const
(3) M operator[](const int k) const

current_version

int current_version()

size

size_t size()

History

日付 内容
2026/04/26 Persistent_Segment_Tree 実装

Depends on

Verified with

Code

#pragma once

#include "../template/template.hpp"

template<typename M>
class Persistent_Segment_Tree {
    private:
    struct Node {
        M x;
        Node *left_child, *right_child;

        Node(M x) : x(x), left_child(nullptr), right_child(nullptr) {}
        Node(M x, Node* left, Node* right) : x(x), left_child(left), right_child(right) {}
    };

    int n;
    const function<M(M, M)> op;
    const M identity;
    vector<Node*> roots;
    vector<Node*> nodes_pool;
    int version;

    Node* new_node(M x, Node* l = nullptr, Node* r = nullptr) {
        Node* res = new Node(x, l, r);
        nodes_pool.push_back(res);
        return res;
    }

    Node* build(const int l, const int r, const vector<M> &data) {
        if (l >= r) return nullptr;
        // 1 要素区間を表す頂点 → 葉
        if (r - l == 1) return new_node(data[l]);

        // そうでない場合, 2 要素以上の区間なので, 左と右に分割できる.
        int m = (l + r) / 2;

        Node* left = build(l, m, data);
        Node* right = build(m, r, data);

        return new_node(op(left->x, right->x), left, right);
    }

    void build_up(const vector<M> &data) {
        if (n > 0) roots.emplace_back(build(0, n, data));
        else roots.emplace_back(nullptr);
    }

    Node* _update(const Node* node, const int l, const int r, const int k, const M x) {
        // 葉に到達した場合, 新しい値を保持するノードを作成して返す.
        if (r - l == 1) return new_node(x);

        int m = (l + r) / 2;
        Node *left = node->left_child, *right = node->right_child;

        // 更新対象のインデックスに応じて, 左または右の子を再帰的に新しく作成する
        if (k < m) left = _update(left, l, m, k, x);
        else right = _update(right, m, r, k, x);

        // 新しく作成した子(片方)と, 既存のもう片方の子を組み合わせて,現在の高さのノードを新しく作成する
        return new_node(op(left->x, right->x), left, right);
    }

    // 半開区間 [l, r) を計算する. 現在見ているノードは半開区間 [a, b) を表す.
    M _product(const Node* node, const int l, const int r, const int a, const int b) const {
        // [l, r) と [a, b) が互いに素ならば, 単位元を返す.
        if (b <= l || r <= a) return identity;
        
        // [a, b) が [l, r) に含まれているならば, ノードの値をそのまま返す.
        if (l <= a && b <= r) return node->x;

        int m = (a + b) / 2;

        M vl = _product(node->left_child, l, r, a, m);
        M vr = _product(node->right_child, l, r, m, b);

        return op(vl, vr);
    }

    M _get(const Node* node, const int l, const int r, const int k) const {
        if (r - l == 1) return node->x;
        int m = (l + r) / 2;
        if (k < m) return _get(node->left_child, l, m, k);
        else return _get(node->right_child, m, r, k);
    }

    Node* _copy(const Node* node_curr, const Node* node_src, const int l, const int r, const int a, const int b) {
        if (b <= l || r <= a) return const_cast<Node*>(node_curr);
        if (l <= a && b <= r) return const_cast<Node*>(node_src);

        int m = (a + b) / 2;
        Node *left = _copy(node_curr->left_child, node_src->left_child, l, r, a, m);
        Node *right = _copy(node_curr->right_child, node_src->right_child, l, r, m, b);

        return new_node(op(left->x, right->x), left, right);
    }

    public:
    /// @brief コンストラクタ. 配列 data で初期化する.
    /// @param data 初期データ
    /// @param op 二項演算子
    /// @param identity 単位元
    Persistent_Segment_Tree(const vector<M> &data, const function<M(M, M)> op, const M identity): n(data.size()), op(op), identity(identity), version(0) {
        build_up(data);
    }

    /// @brief コンストラクタ. サイズ n, 全要素 identity で初期化する.
    /// @param n 配列サイズ
    /// @param op 二項演算子
    /// @param identity 単位元
    Persistent_Segment_Tree(const int n, const function<M(M, M)> op, const M identity): n(n), op(op), identity(identity), version(0) {
        build_up(vector<M>(n, identity));
    }

    ~Persistent_Segment_Tree() {
        for (Node* node : nodes_pool) delete node;
    }

    /// @brief バージョン t をコピーして新しいバージョンを作成し, そのインデックスを返す.
    /// @param t ベースとするバージョン (デフォルトは最新バージョン)
    /// @return 新しいバージョン番号
    int increment(int t = -1) {
        if (t == -1) t = version;
        assert(t <= version);
        roots.emplace_back(roots[t]);
        return ++version;
    }

    /// @brief バージョン s をベースに第 k 要素を x に更新した状態を作成し, バージョン t に代入する.
    /// @param s 元のバージョン
    /// @param t 保存先のバージョン
    /// @param k 更新するインデックス (0-indexed)
    /// @param x 更新後の値
    /// @return バージョン t
    int update(const int s, const int t, const int k, const M x) {
        assert(s <= version);
        assert(t <= version);
        assert(0 <= k && k < n);

        roots[t] = _update(roots[s], 0, n, k, x);
        return t;
    }

    /// @brief バージョン t をベースに第 k 要素を x に更新した新しい状態を作成する. 現在の最新バージョンを上書きする.
    /// @param t ベースとするバージョン
    /// @param k 更新するインデックス (0-indexed)
    /// @param x 更新後の値
    /// @return 最新バージョン番号
    int update(const int t, const int k, const M x) { return update(t, version, k, x); }

    /// @brief 最新バージョンをベースに第 k 要素を x に更新した新しい状態を作成し、最新バージョンを上書きする.
    /// @param k 更新するインデックス (0-indexed)
    /// @param x 更新後の値
    /// @return 最新バージョン番号
    int update(const int k, const M x) { return update(version, k, x); }

    /// @brief バージョン t をその場で更新(上書き)する (update(t, t, k, x) のシノニム).
    /// @param t 更新対象のバージョン
    /// @param k インデックス
    /// @param x 値
    /// @return バージョン番号 t
    int amend(const int t, const int k, const M x) { return update(t, t, k, x); }

    /// @brief バージョン s の [l, r] の範囲をバージョン t にコピー(マージ)したものをバージョン u に保存する.
    /// @param s コピー元のバージョン
    /// @param t コピー先のベースとなるバージョン
    /// @param u 結果の保存先バージョン
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return バージョン u
    int copy(const int s, const int t, const int u, const int l, const int r) {
        assert(s <= version);
        assert(t <= version);
        assert(u <= version);
        if (n == 0 || l > r) return u;
        assert(0 <= l && r < n);

        roots[u] = _copy(roots[t], roots[s], l, r + 1, 0, n);
        return u;
    }

    /// @brief バージョン s の [l, r] の範囲をバージョン t にコピー(マージ)する.
    /// @param s コピー元のバージョン
    /// @param t コピー先のバージョン
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return バージョン t
    int copy(const int s, const int t, const int l, const int r) { return copy(s, t, t, l, r); }

    /// @brief バージョン t の [l, r] の範囲を最新バージョンにコピーする.
    /// @param t コピー元のバージョン
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return 最新バージョン番号
    int copy(const int t, const int l, const int r) { return copy(t, version, version, l, r); }

    /// @brief バージョン s の内容をバージョン t にそのままコピーする.
    /// @param s コピー元のバージョン
    /// @param t コピー先のバージョン
    /// @return バージョン t
    int clone(const int s, const int t) {
        assert(s <= version);
        assert(t <= version);
        roots[t] = roots[s];
        return t;
    }

    /// @brief バージョン t の内容を現在の最新バージョンにそのままコピーする.
    /// @param t コピー元のバージョン
    /// @return 最新バージョン番号
    int clone(const int t) { return clone(t, version); }

    /// @brief バージョン t における [l, r] の範囲の総積を求める.
    /// @param t 取得対象のバージョン
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return 区間の総積
    M product(const int t, const int l, const int r) const {
        assert(t <= version);
        if (l > r || n == 0) return identity;
        return _product(roots[t], l, r + 1, 0, n);
    }

    /// @brief 最新バージョンにおける [l, r] の範囲の総積を求める.
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return 区間の総積
    M product(const int l, const int r) const {
        return product(version, l, r);
    }

    /// @brief バージョン t における全区間の総積を求める.
    /// @param t 取得対象のバージョン
    /// @return 全区間の総積
    M all_product(const int t) const {
        assert(t <= version);
        return (n == 0 || !roots[t]) ? identity : roots[t]->x;
    }

    /// @brief 最新バージョンにおける全区間の総積を求める.
    /// @return 全区間の総積
    M all_product() const { return all_product(version); }

    /// @brief バージョン t における第 k 要素の値を取得する.
    /// @param t 取得対象のバージョン
    /// @param k インデックス (0-indexed)
    /// @return 要素の値
    M get(const int t, const int k) const {
        assert(t <= version);
        if (n == 0) return identity;
        assert(0 <= k && k < n);

        return _get(roots[t], 0, n, k);
    }

    /// @brief 最新バージョンにおける第 k 要素の値を取得する.
    /// @param k インデックス (0-indexed)
    /// @return 要素の値
    M get(const int k) const { return get(version, k); }

    /// @brief 最新バージョンにおける第 k 要素の値を取得する.
    M operator[](const int k) const { return get(version, k); }

    /// @brief 現在の最新バージョン番号を取得する.
    int current_version() const { return version; }

    /// @brief セグメント木における要素数を取得する.
    size_t size() const { return n; }
};
#line 2 "Segment_Tree/Persistent_Segment_Tree.hpp"

#line 2 "template/template.hpp"

using namespace std;

// intrinstic
#include <immintrin.h>

#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <cctype>
#include <cfenv>
#include <cfloat>
#include <chrono>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <complex>
#include <concepts>
#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <initializer_list>
#include <iomanip>
#include <ios>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ostream>
#include <optional>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <stack>
#include <streambuf>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// utility
#line 2 "template/utility.hpp"

using ll = long long;

// a ← max(a, b) を実行する. a が更新されたら, 返り値が true.
template<typename T, typename U>
inline bool chmax(T &a, const U b){
    return (a < b ? a = b, 1: 0);
}

// a ← min(a, b) を実行する. a が更新されたら, 返り値が true.
template<typename T, typename U>
inline bool chmin(T &a, const U b){
    return (a > b ? a = b, 1: 0);
}

// a の最大値を取得する.
template<typename T>
inline T max(const vector<T> &a){
    if (a.empty()) throw invalid_argument("vector is empty.");

    return *max_element(a.begin(), a.end());
}

// vector<T> a の最小値を取得する.
template<typename T>
inline T min(const vector<T> &a){
    if (a.empty()) throw invalid_argument("vector is empty.");

    return *min_element(a.begin(), a.end());
}

// vector<T> a の最大値のインデックスを取得する.
template<typename T>
inline size_t argmax(const vector<T> &a){
    if (a.empty()) throw std::invalid_argument("vector is empty.");

    return distance(a.begin(), max_element(a.begin(), a.end()));
}

// vector<T> a の最小値のインデックスを取得する.
template<typename T>
inline size_t argmin(const vector<T> &a){
    if (a.empty()) throw invalid_argument("vector is empty.");

    return distance(a.begin(), min_element(a.begin(), a.end()));
}
#line 61 "template/template.hpp"

// math
#line 2 "template/math.hpp"

// 演算子
template<typename T>
T add(const T &x, const T &y) { return x + y; }

template<typename T>
T sub(const T &x, const T &y) { return x - y; }

template<typename T>
T mul(const T &x, const T &y) { return x * y; }

template<typename T>
T neg(const T &x) { return -x; }

template<integral T>
T bitwise_and(const T &x, const T &y) { return x & y; }

template<integral T>
T bitwise_or(const T &x, const T &y) { return x | y; }

template<integral T>
T bitwise_xor(const T &x, const T &y) { return x ^ y; }

// 除算に関する関数

// floor(x / y) を求める.
template<integral T, integral U>
auto div_floor(T x, U y){
    return x / y - ((x % y != 0) && ((x < 0) != (y < 0)));
}

// ceil(x / y) を求める.
template<integral T, integral U>
auto div_ceil(T x, U y){
    return x / y + ((x % y != 0) && ((x < 0) == (y < 0)));
}

// x を y で割った余りを求める.
template<integral T, integral U>
auto safe_mod(T x, U y){
    auto q = div_floor(x, y);
    return x - q * y ;
}

// x を y で割った商と余りを求める.
template<integral T, integral U>
auto divmod(T x, U y){
    auto q = div_floor(x, y);
    return make_pair(q, x - q * y);
}

// 四捨五入を求める.
template<integral T, integral U>
auto round(T x, U y){
    auto [q, r] = divmod(x, y);
    if (y < 0) return (r <= div_floor(y, 2)) ? q + 1 : q;
    return (r >= div_ceil(y, 2)) ? q + 1 : q;
}

// 奇数かどうか判定する.
template<integral T>
bool is_odd(const T &x) { return x % 2 != 0; }

// 偶数かどうか判定する.
template<integral T>
bool is_even(const T &x) { return x % 2 == 0; }

// m の倍数かどうか判定する.
template<integral T, integral U>
bool is_multiple(const T &x, const U &m) { return x % m == 0; }

// 正かどうか判定する.
template<typename T>
bool is_positive(const T &x) { return x > 0; }

// 負かどうか判定する.
template<typename T>
bool is_negative(const T &x) { return x < 0; }

// ゼロかどうか判定する.
template<typename T>
bool is_zero(const T &x) { return x == 0; }

// 非負かどうか判定する.
template<typename T>
bool is_non_negative(const T &x) { return x >= 0; }

// 非正かどうか判定する.
template<typename T>
bool is_non_positive(const T &x) { return x <= 0; }

// 指数に関する関数

// x の y 乗を求める.
ll intpow(ll x, ll y){
    ll a = 1;
    while (y){
        if (y & 1) { a *= x; }
        x *= x;
        y >>= 1;
    }
    return a;
}

ll pow(ll x, ll y) { return intpow(x, y); }

// x の y 乗を z で割った余りを求める.
template<typename T, integral U>
T modpow(T x, U y, T z) {
    T a = 1;
    while (y) {
        if (y & 1) { (a *= x) %= z; }

        (x *= x) %= z;
        y >>= 1;
    }

    return a;
}

template<typename T>
T sum(const vector<T> &X) {
    T y = T(0);
    for (auto &&x: X) { y += x; }
    return y;
}

template<typename T>
T gcd(const T x, const T y) {
    return y == 0 ? x : gcd(y, x % y);
}

// a x + b y = gcd(a, b) を満たす整数の組 (a, b) に対して, (x, y, gcd(a, b)) を求める.
template<integral T>
tuple<T, T, T> Extended_Euclid(T a, T b) {
    T s = 1, t = 0, u = 0, v = 1;
    while (b) {
        auto [q, r] = divmod(a, b);
        a = b;
        b = r;
        tie(s, t) = make_pair(t, s - q * t);
        tie(u, v) = make_pair(v, u - q * v);
    }

    return make_tuple(s, u, a);
}

// floor(sqrt(N)) を求める (N < 0 のときは, 0 とする).
ll isqrt(const ll &N) { 
    if (N <= 0) { return 0; }

    ll x = sqrtl(N);
    while ((x + 1) * (x + 1) <= N) { x++; }
    while (x * x > N) { x--; }

    return x;
}

// floor(sqrt(N)) を求める (N < 0 のときは, 0 とする).
ll floor_sqrt(const ll &N) { return isqrt(N); }

// ceil(sqrt(N)) を求める (N < 0 のときは, 0 とする).
ll ceil_sqrt(const ll &N) {
    ll x = isqrt(N);
    return x * x == N ? x : x + 1;
}
#line 64 "template/template.hpp"

// inout
#line 1 "template/inout.hpp"
// 入出力
template<class... T>
void input(T&... a){ (cin >> ... >> a); }

void print(){ cout << "\n"; }

template<class T, class... Ts>
void print(const T& a, const Ts&... b){
    cout << a;
    (cout << ... << (cout << " ", b));
    cout << "\n";
}

template<typename T, typename U>
istream &operator>>(istream &is, pair<T, U> &P){
    is >> P.first >> P.second;
    return is;
}

template<typename T, typename U>
ostream &operator<<(ostream &os, const pair<T, U> &P){
    os << P.first << " " << P.second;
    return os;
}

template<typename T>
vector<T> vector_input(int N, int index){
    vector<T> X(N+index);
    for (int i=index; i<index+N; i++) cin >> X[i];
    return X;
}

template<typename T>
istream &operator>>(istream &is, vector<T> &X){
    for (auto &x: X) { is >> x; }
    return is;
}

template<typename T>
ostream &operator<<(ostream &os, const vector<T> &X){
    int s = (int)X.size();
    for (int i = 0; i < s; i++) { os << (i ? " " : "") << X[i]; }
    return os;
}

template<typename T>
ostream &operator<<(ostream &os, const unordered_set<T> &S){
    int i = 0;
    for (T a: S) {os << (i ? " ": "") << a; i++;}
    return os;
}

template<typename T>
ostream &operator<<(ostream &os, const set<T> &S){
    int i = 0;
    for (T a: S) { os << (i ? " ": "") << a; i++; }
    return os;
}

template<typename T>
ostream &operator<<(ostream &os, const unordered_multiset<T> &S){
    int i = 0;
    for (T a: S) { os << (i ? " ": "") << a; i++; }
    return os;
}

template<typename T>
ostream &operator<<(ostream &os, const multiset<T> &S){
    int i = 0;
    for (T a: S) { os << (i ? " ": "") << a; i++; }
    return os;
}

template<typename T>
std::vector<T> input_vector(size_t n, size_t offset = 0) {
    std::vector<T> res;
    // 最初に必要な全容量を確保(再確保を防ぐ)
    res.reserve(n + offset);
    // offset 分をデフォルト値で埋める(特別 indexed 用)
    res.assign(offset, T());
    
    for (size_t i = 0; i < n; ++i) {
        T el;
        if (!(std::cin >> el)) break;
        res.push_back(std::move(el));
    }
    return res;
}
#line 67 "template/template.hpp"

// macro
#line 2 "template/macro.hpp"

// マクロの定義
#define all(x) x.begin(), x.end()
#define len(x) ll(x.size())
#define elif else if
#define unless(cond) if (!(cond))
#define until(cond) while (!(cond))
#define loop while (true)

// オーバーロードマクロ
#define overload2(_1, _2, name, ...) name
#define overload3(_1, _2, _3, name, ...) name
#define overload4(_1, _2, _3, _4, name, ...) name
#define overload5(_1, _2, _3, _4, _5, name, ...) name

// 繰り返し系
#define rep1(n) for (ll i = 0; i < n; i++)
#define rep2(i, n) for (ll i = 0; i < n; i++)
#define rep3(i, a, b) for (ll i = a; i < b; i++)
#define rep4(i, a, b, c) for (ll i = a; i < b; i += c)
#define rep(...) overload4(__VA_ARGS__, rep4, rep3, rep2, rep1)(__VA_ARGS__)

#define foreach1(x, a) for (auto &&x: a)
#define foreach2(x, y, a) for (auto &&[x, y]: a)
#define foreach3(x, y, z, a) for (auto &&[x, y, z]: a)
#define foreach4(x, y, z, w, a) for (auto &&[x, y, z, w]: a)
#define foreach(...) overload5(__VA_ARGS__, foreach4, foreach3, foreach2, foreach1)(__VA_ARGS__)
#line 70 "template/template.hpp"

// bitop
#line 2 "template/bitop.hpp"

// 非負整数 x の bit legnth を求める.
ll bit_length(ll x) {
    if (x == 0) { return 0; }
    return (sizeof(long) * CHAR_BIT) - __builtin_clzll(x);
}

// 非負整数 x の popcount を求める.
ll popcount(ll x) { return __builtin_popcountll(x); }

// 正の整数 x に対して, floor(log2(x)) を求める.
ll floor_log2(ll x) { return bit_length(x) - 1; }

// 正の整数 x に対して, ceil(log2(x)) を求める.
ll ceil_log2(ll x) { return bit_length(x - 1); }

// x の第 k ビットを取得する
int get_bit(ll x, int k) { return (x >> k) & 1; }

// x のビット列を取得する.
// k はビット列の長さとする.
vector<int> get_bits(ll x, int k) {
    vector<int> bits(k);
    rep(i, k) {
        bits[i] = x & 1;
        x >>= 1;
    }

    return bits;
}

// x のビット列を取得する.
vector<int> get_bits(ll x) { return get_bits(x, bit_length(x)); }

// x に立っているなんかしらのビットの番号を出力する.
ll lowest_bit(const ll x) { return floor_log2(x & (-x)); }
#line 73 "template/template.hpp"

// exception
#line 2 "template/exception.hpp"

class NotExist: public exception {
    private:
    string message;

    public:
    NotExist() : message("求めようとしていたものは存在しません.") {}

    const char* what() const noexcept override {
        return message.c_str();
    }
};
#line 4 "Segment_Tree/Persistent_Segment_Tree.hpp"

template<typename M>
class Persistent_Segment_Tree {
    private:
    struct Node {
        M x;
        Node *left_child, *right_child;

        Node(M x) : x(x), left_child(nullptr), right_child(nullptr) {}
        Node(M x, Node* left, Node* right) : x(x), left_child(left), right_child(right) {}
    };

    int n;
    const function<M(M, M)> op;
    const M identity;
    vector<Node*> roots;
    vector<Node*> nodes_pool;
    int version;

    Node* new_node(M x, Node* l = nullptr, Node* r = nullptr) {
        Node* res = new Node(x, l, r);
        nodes_pool.push_back(res);
        return res;
    }

    Node* build(const int l, const int r, const vector<M> &data) {
        if (l >= r) return nullptr;
        // 1 要素区間を表す頂点 → 葉
        if (r - l == 1) return new_node(data[l]);

        // そうでない場合, 2 要素以上の区間なので, 左と右に分割できる.
        int m = (l + r) / 2;

        Node* left = build(l, m, data);
        Node* right = build(m, r, data);

        return new_node(op(left->x, right->x), left, right);
    }

    void build_up(const vector<M> &data) {
        if (n > 0) roots.emplace_back(build(0, n, data));
        else roots.emplace_back(nullptr);
    }

    Node* _update(const Node* node, const int l, const int r, const int k, const M x) {
        // 葉に到達した場合, 新しい値を保持するノードを作成して返す.
        if (r - l == 1) return new_node(x);

        int m = (l + r) / 2;
        Node *left = node->left_child, *right = node->right_child;

        // 更新対象のインデックスに応じて, 左または右の子を再帰的に新しく作成する
        if (k < m) left = _update(left, l, m, k, x);
        else right = _update(right, m, r, k, x);

        // 新しく作成した子(片方)と, 既存のもう片方の子を組み合わせて,現在の高さのノードを新しく作成する
        return new_node(op(left->x, right->x), left, right);
    }

    // 半開区間 [l, r) を計算する. 現在見ているノードは半開区間 [a, b) を表す.
    M _product(const Node* node, const int l, const int r, const int a, const int b) const {
        // [l, r) と [a, b) が互いに素ならば, 単位元を返す.
        if (b <= l || r <= a) return identity;
        
        // [a, b) が [l, r) に含まれているならば, ノードの値をそのまま返す.
        if (l <= a && b <= r) return node->x;

        int m = (a + b) / 2;

        M vl = _product(node->left_child, l, r, a, m);
        M vr = _product(node->right_child, l, r, m, b);

        return op(vl, vr);
    }

    M _get(const Node* node, const int l, const int r, const int k) const {
        if (r - l == 1) return node->x;
        int m = (l + r) / 2;
        if (k < m) return _get(node->left_child, l, m, k);
        else return _get(node->right_child, m, r, k);
    }

    Node* _copy(const Node* node_curr, const Node* node_src, const int l, const int r, const int a, const int b) {
        if (b <= l || r <= a) return const_cast<Node*>(node_curr);
        if (l <= a && b <= r) return const_cast<Node*>(node_src);

        int m = (a + b) / 2;
        Node *left = _copy(node_curr->left_child, node_src->left_child, l, r, a, m);
        Node *right = _copy(node_curr->right_child, node_src->right_child, l, r, m, b);

        return new_node(op(left->x, right->x), left, right);
    }

    public:
    /// @brief コンストラクタ. 配列 data で初期化する.
    /// @param data 初期データ
    /// @param op 二項演算子
    /// @param identity 単位元
    Persistent_Segment_Tree(const vector<M> &data, const function<M(M, M)> op, const M identity): n(data.size()), op(op), identity(identity), version(0) {
        build_up(data);
    }

    /// @brief コンストラクタ. サイズ n, 全要素 identity で初期化する.
    /// @param n 配列サイズ
    /// @param op 二項演算子
    /// @param identity 単位元
    Persistent_Segment_Tree(const int n, const function<M(M, M)> op, const M identity): n(n), op(op), identity(identity), version(0) {
        build_up(vector<M>(n, identity));
    }

    ~Persistent_Segment_Tree() {
        for (Node* node : nodes_pool) delete node;
    }

    /// @brief バージョン t をコピーして新しいバージョンを作成し, そのインデックスを返す.
    /// @param t ベースとするバージョン (デフォルトは最新バージョン)
    /// @return 新しいバージョン番号
    int increment(int t = -1) {
        if (t == -1) t = version;
        assert(t <= version);
        roots.emplace_back(roots[t]);
        return ++version;
    }

    /// @brief バージョン s をベースに第 k 要素を x に更新した状態を作成し, バージョン t に代入する.
    /// @param s 元のバージョン
    /// @param t 保存先のバージョン
    /// @param k 更新するインデックス (0-indexed)
    /// @param x 更新後の値
    /// @return バージョン t
    int update(const int s, const int t, const int k, const M x) {
        assert(s <= version);
        assert(t <= version);
        assert(0 <= k && k < n);

        roots[t] = _update(roots[s], 0, n, k, x);
        return t;
    }

    /// @brief バージョン t をベースに第 k 要素を x に更新した新しい状態を作成する. 現在の最新バージョンを上書きする.
    /// @param t ベースとするバージョン
    /// @param k 更新するインデックス (0-indexed)
    /// @param x 更新後の値
    /// @return 最新バージョン番号
    int update(const int t, const int k, const M x) { return update(t, version, k, x); }

    /// @brief 最新バージョンをベースに第 k 要素を x に更新した新しい状態を作成し、最新バージョンを上書きする.
    /// @param k 更新するインデックス (0-indexed)
    /// @param x 更新後の値
    /// @return 最新バージョン番号
    int update(const int k, const M x) { return update(version, k, x); }

    /// @brief バージョン t をその場で更新(上書き)する (update(t, t, k, x) のシノニム).
    /// @param t 更新対象のバージョン
    /// @param k インデックス
    /// @param x 値
    /// @return バージョン番号 t
    int amend(const int t, const int k, const M x) { return update(t, t, k, x); }

    /// @brief バージョン s の [l, r] の範囲をバージョン t にコピー(マージ)したものをバージョン u に保存する.
    /// @param s コピー元のバージョン
    /// @param t コピー先のベースとなるバージョン
    /// @param u 結果の保存先バージョン
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return バージョン u
    int copy(const int s, const int t, const int u, const int l, const int r) {
        assert(s <= version);
        assert(t <= version);
        assert(u <= version);
        if (n == 0 || l > r) return u;
        assert(0 <= l && r < n);

        roots[u] = _copy(roots[t], roots[s], l, r + 1, 0, n);
        return u;
    }

    /// @brief バージョン s の [l, r] の範囲をバージョン t にコピー(マージ)する.
    /// @param s コピー元のバージョン
    /// @param t コピー先のバージョン
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return バージョン t
    int copy(const int s, const int t, const int l, const int r) { return copy(s, t, t, l, r); }

    /// @brief バージョン t の [l, r] の範囲を最新バージョンにコピーする.
    /// @param t コピー元のバージョン
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return 最新バージョン番号
    int copy(const int t, const int l, const int r) { return copy(t, version, version, l, r); }

    /// @brief バージョン s の内容をバージョン t にそのままコピーする.
    /// @param s コピー元のバージョン
    /// @param t コピー先のバージョン
    /// @return バージョン t
    int clone(const int s, const int t) {
        assert(s <= version);
        assert(t <= version);
        roots[t] = roots[s];
        return t;
    }

    /// @brief バージョン t の内容を現在の最新バージョンにそのままコピーする.
    /// @param t コピー元のバージョン
    /// @return 最新バージョン番号
    int clone(const int t) { return clone(t, version); }

    /// @brief バージョン t における [l, r] の範囲の総積を求める.
    /// @param t 取得対象のバージョン
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return 区間の総積
    M product(const int t, const int l, const int r) const {
        assert(t <= version);
        if (l > r || n == 0) return identity;
        return _product(roots[t], l, r + 1, 0, n);
    }

    /// @brief 最新バージョンにおける [l, r] の範囲の総積を求める.
    /// @param l 左端 (閉区間)
    /// @param r 右端 (閉区間)
    /// @return 区間の総積
    M product(const int l, const int r) const {
        return product(version, l, r);
    }

    /// @brief バージョン t における全区間の総積を求める.
    /// @param t 取得対象のバージョン
    /// @return 全区間の総積
    M all_product(const int t) const {
        assert(t <= version);
        return (n == 0 || !roots[t]) ? identity : roots[t]->x;
    }

    /// @brief 最新バージョンにおける全区間の総積を求める.
    /// @return 全区間の総積
    M all_product() const { return all_product(version); }

    /// @brief バージョン t における第 k 要素の値を取得する.
    /// @param t 取得対象のバージョン
    /// @param k インデックス (0-indexed)
    /// @return 要素の値
    M get(const int t, const int k) const {
        assert(t <= version);
        if (n == 0) return identity;
        assert(0 <= k && k < n);

        return _get(roots[t], 0, n, k);
    }

    /// @brief 最新バージョンにおける第 k 要素の値を取得する.
    /// @param k インデックス (0-indexed)
    /// @return 要素の値
    M get(const int k) const { return get(version, k); }

    /// @brief 最新バージョンにおける第 k 要素の値を取得する.
    M operator[](const int k) const { return get(version, k); }

    /// @brief 現在の最新バージョン番号を取得する.
    int current_version() const { return version; }

    /// @brief セグメント木における要素数を取得する.
    size_t size() const { return n; }
};
Back to top page