エイシングプログラミングコンテスト2021(AtCoder Beginner Contest 202)- Count Descendants
オンラインとオフラインどっちが好みの人が多いんでしょうか?
解法
問題文を整理すると,各クエリ において,「深さが でかつ, の部分木に属しているものの個数を求めよ」となります. ある頂点 が部分木に属するかどうかは,オイラーツアー (参考 : Euler Tour のお勉強 | maspyのHP) などをすると, かどうかを確認すればわかります.各クエリは「深さが でかつ, である頂点の個数を求めよ」という問題になります.
オンラインクエリで解く方法
BinaryTrie と呼ばれるデータ構造 ( 参考 : 非負整数値を扱う Trie について - kazuma8128’s blog ) を各深さ毎に持ちます.各深さの BinaryTrie にそれぞれの, を管理すると,BinaryTrie は区間にある頂点の個数を求めるのは でできることから,全体で で解くことができました.なお,座標圧縮などをすることにより,BinaryIndexedTree を用いてもオンラインクエリで解くことができます.
オフラインクエリで解く方法
区間和を求めたくなりますが,各深さごとに区間 の BinaryIndexedTree を持つのは,メモリが厳しいです.ここで,1 つの BinaryIndexedTree を全ての深さで使いまわします.予め,クエリを先読みし,深さごとにクエリを管理します.そして,各深さごとに順にクエリを処理していきます.ある深さ のクエリを処理する場合,まず,深さ の を BinaryIndexedTree で管理します. そして,深さ に関するクエリについて処理します.深さ のクエリを全て処理し終えたら,深さ の を BinaryIndexedTree から取り除きます.計算量は です.
実装例 ( オンライン)
#include <iostream> #include <string> #include <algorithm> #include <vector> #include <queue> #include <utility> #include <tuple> #include <cmath> #include <numeric> #include <set> #include <map> #include <array> #include <complex> #include <iomanip> #include <cassert> #include <random> #include <chrono> #include <valarray> #include <bitset> using ll = long long; using std::cin; using std::cout; using std::endl; std::mt19937 rnd(std::chrono::steady_clock::now().time_since_epoch().count()); template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; } template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; } const int inf = (int)1e9 + 7; const long long INF = 1LL << 60; namespace KKT89 { template< typename T, int MAX_LOG, typename D = int > struct BinaryTrie { public: struct Node { Node *nxt[2]; D exist; std::vector< int > accept; Node() : nxt{nullptr, nullptr}, exist(0) {} }; Node *root; explicit BinaryTrie() : root(new Node()) {} explicit BinaryTrie(Node *root) : root(root) {} void add(const T &bit, int idx = -1, D delta = 1, T xor_val = 0) { root = add(root, bit, idx, MAX_LOG, delta, xor_val); } void erase(const T &bit, T xor_val = 0) { add(bit, -1, -1, xor_val); } Node *find(const T &bit, T xor_val = 0) { return find(root, bit, MAX_LOG, xor_val); } D count(const T &bit, T xor_val = 0) { auto node = find(bit, xor_val); return node ? node->exist : 0; } std::pair< T, Node * > min_element(T xor_val = 0) { assert(root->exist > 0); return kth_element(0, xor_val); } std::pair< T, Node * > max_element(T xor_val = 0) { assert(root->exist > 0); return kth_element(root->exist - 1, xor_val); } std::pair< T, Node * > kth_element(D k, T xor_val = 0) { // 0-indexed assert(0 <= k && k < root->exist); return kth_element(root, k, MAX_LOG, xor_val); } D count_less(const T &bit, T xor_val = 0) { // < bit return count_less(root, bit, MAX_LOG, xor_val); } private: virtual Node *clone(Node *t) { return t; } Node *add(Node *t, T bit, int idx, int depth, D x, T xor_val, bool need = true) { if(need) t = clone(t); if(depth == -1) { t->exist += x; if(idx >= 0) t->accept.emplace_back(idx); } else { bool f = (xor_val >> depth) & 1; auto &to = t->nxt[f ^ ((bit >> depth) & 1)]; if(!to) to = new Node(), need = false; to = add(to, bit, idx, depth - 1, x, xor_val, need); t->exist += x; } return t; } Node *find(Node *t, T bit, int depth, T xor_val) { if(depth == -1) { return t; } else { bool f = (xor_val >> depth) & 1; auto &to = t->nxt[f ^ ((bit >> depth) & 1)]; return to ? find(to, bit, depth - 1, xor_val) : nullptr; } } std::pair< T, Node * > kth_element(Node *t, D k, int bit_index, T xor_val) { // 0-indexed if(bit_index == -1) { return {0, t}; } else { bool f = (xor_val >> bit_index) & 1; if((t->nxt[f] ? t->nxt[f]->exist : 0) <= k) { auto ret = kth_element(t->nxt[f ^ 1], k - (t->nxt[f] ? t->nxt[f]->exist : 0), bit_index - 1, xor_val); ret.first |= T(1) << bit_index; return ret; } else { return kth_element(t->nxt[f], k, bit_index - 1, xor_val); } } } D count_less(Node *t, const T &bit, int bit_index, T xor_val) { if(bit_index == -1) return 0; D ret = 0; bool f = (xor_val >> bit_index) & 1; if((bit >> bit_index & 1) and t->nxt[f]) ret += t->nxt[f]->exist; if(t->nxt[f ^ (bit >> bit_index & 1)]) ret += count_less(t->nxt[f ^ (bit >> bit_index & 1)], bit, bit_index - 1, xor_val); return ret; } }; } void solve() { int n; cin >> n; std::vector<std::vector<int>> g(n); for (int i = 1; i < n; ++i) { int p; cin >> p; p -= 1; g[i].emplace_back(p); g[p].emplace_back(i); } std::vector<int> dep(n), in(n), out(n); std::vector<KKT89::BinaryTrie<unsigned, 20, int>> vb(n); { int idx = 0; auto dfs = [&](auto &&self, int cur, int pre)->void { if(pre >= 0) dep[cur] = dep[pre] + 1; in[cur] = idx++; vb[dep[cur]].add(in[cur]); for(const auto &nxt : g[cur]) { if(nxt == pre) continue; self(self, nxt, cur); } out[cur] = idx; }; dfs(dfs, 0, -1); } int kkt; cin >> kkt; while(kkt--) { int u, d; cin >> u >> d; u -= 1; cout << vb[d].count_less(out[u]) - vb[d].count_less(in[u]) << "\n"; } } int main() { std::cin.tie(nullptr); std::ios::sync_with_stdio(false); int kkt = 1; // cin >> kkt; while(kkt--) solve(); return 0; }
実装例 ( オフライン )
#include <iostream> #include <string> #include <algorithm> #include <vector> #include <queue> #include <utility> #include <tuple> #include <cmath> #include <numeric> #include <set> #include <map> #include <array> #include <complex> #include <iomanip> #include <cassert> #include <random> #include <chrono> #include <valarray> #include <bitset> using ll = long long; using std::cin; using std::cout; using std::endl; std::mt19937 rnd(std::chrono::steady_clock::now().time_since_epoch().count()); template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; } template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; } const int inf = (int)1e9 + 7; const long long INF = 1LL << 60; namespace KKT89 { template<typename T> struct BinaryIndexedTree { int n; std::vector<T> bit; BinaryIndexedTree() :n(0) {} BinaryIndexedTree(int _n) :n(_n) { bit = std::vector<T>(n + 1); } void add1(int idx, T val) { for (int i = idx; i <= n; i += i & -i) bit[i] += val; } T sum1(int idx) { T res = 0; for (int i = idx; i > 0; i -= i & -i) res += bit[i]; return res; } //0-indexed void add(int idx, T val) { add1(idx + 1, val); } //0-indexed [left, right) T sum(int left, int right) { return sum1(right) - sum1(left); } int lower_bound(T x) { int res = 0; int k = 1; while (2 * k <= n) k <<= 1; for (; k > 0; k >>= 1) { if (res + k <= n and bit[res + k] < x) { x -= bit[res + k]; res += k; } } return res; } }; } void solve() { int n; cin >> n; std::vector<std::vector<int>> g(n); for (int i = 1; i < n; ++i) { int p; cin >> p; p -= 1; g[i].emplace_back(p); g[p].emplace_back(i); } KKT89::BinaryIndexedTree<ll> bit(n); std::vector<int> dep(n), in(n), out(n); std::vector<std::vector<int>> vdep(n); { int idx = 0; auto dfs = [&](auto &&self, int cur, int pre)->void { if(pre >= 0) dep[cur] = dep[pre] + 1; in[cur] = idx++; for(const auto &nxt : g[cur]) { if(nxt == pre) continue; self(self, nxt, cur); } out[cur] = idx; }; dfs(dfs, 0, -1); for (int i = 0; i < n; ++i) { vdep[dep[i]].emplace_back(in[i]); } } int kkt; cin >> kkt; std::vector<std::vector<std::pair<int, int>>> query(n); std::vector<int> res(kkt); for (int i = 0; i < kkt; ++i) { int u, d; cin >> u >> d; u -= 1; query[d].emplace_back(u, i); } for (int i = n - 1; i >= 0; --i) { for(const auto &idx : vdep[i]) { bit.add(idx, 1); } for(const auto &[u, idx] : query[i]) { //cout << in[u] << " " << out[u] << endl; res[idx] = bit.sum(in[u], out[u]); } for(const auto &idx : vdep[i]) { bit.add(idx, -1); } } for (int i = 0; i < kkt; ++i) { cout << res[i] << "\n"; } } int main() { std::cin.tie(nullptr); std::ios::sync_with_stdio(false); int kkt = 1; // cin >> kkt; while(kkt--) solve(); return 0; }