木DPを使って解ける問題には、以下のものがあります。
- ある頂点から最も遠い頂点までの距離はいくつか
- ある頂点を根としたときの木の深さはいくつか
これらの問題に共通していることは、「ある頂点」という文言が含まれていることです。頂点数を\(N\)とおいたとき、ひとつの頂点に対して上の問題を解くのに\(O(N)\)かかるため、すべての頂点に対して上の問題を解くと\(O(N ^ 2)\)かかることになります。
全方位木DPを使うことで、この計算量を\(O(N)\)まで落とすことができます。すべての頂点に対して木DPを行いたいときは、全方位木DPを使うといいかもしれません。
実装と使用例
template <typename T, typename D> struct ReRooting { using L = function<T(T, D)>; using M = function<T(T, T)>; struct Edge { int to; D data; Edge(int to, D data) : to(to), data(data) {} }; int n; vector<vector<Edge>> edges; L lift; M merge; T e; vector<vector<T>> dp; vector<int> par; vector<T> ans; ReRooting(int n, L lift, M merge, T e) : n(n), lift(lift), merge(merge), e(e) { edges = vector<vector<Edge>>(n); } void add_edge(int v, int u, D d) { edges[v].emplace_back(u, d); } T dfs(int v, int p = -1) { int l = edges[v].size(); dp[v] = vector<T>(l); T ret = e; for (int i = 0; i < l; i++) { auto e = edges[v][i]; if (e.to == p) { par[v] = i; continue; } dp[v][i] = lift(dfs(e.to, v), e.data); ret = merge(ret, dp[v][i]); } return ret; } void bfs(int v, T t) { if (par[v] != -1) { dp[v][par[v]] = lift(t, edges[v][par[v]].data); } int l = edges[v].size(); vector<T> dpl(l + 1, e); vector<T> dpr(l + 1, e); for (int i = 0; i < l; i++) { dpl[i + 1] = merge(dpl[i], dp[v][i]); } for (int i = l - 1; i >= 0; i--) { dpr[i] = merge(dpr[i + 1], dp[v][i]); } ans[v] = dpr[0]; for (int i = 0; i < l; i++) { if (i == par[v]) continue; int u = edges[v][i].to; bfs(u, merge(dpl[i], dpr[i + 1])); } } vector<T> solve() { dp = vector<vector<T>>(n); par = vector<int>(n, -1); ans = vector<T>(n); dfs(0); bfs(0, e); return ans; } };
全方位木DPの抽象部分を抜き出した構造体です。solve()
は、すべての頂点のマージし終えた結果を返します。
欲しい値が「マージし終えた結果」ではなく「マージ前の情報」のときがあります。そういったときは、少々複雑なマージの仕方を考えるか、構造体の中身を修正する必要があります。後者の場合、修正箇所はbfs()
のans[v] = dpr[0];
の行になります。それ以外を修正することはおそらくありません。
使用例
木の直径 | グラフ | Aizu Online Judge
2パターン載せます。最初の実装はReRooting
構造体を修正しない方針で、2つ目は構造体を修正する方針です。
ひとつめ
構造体を修正しない方針です。構造体の定義は省略しています。マージの仕方を少し考える必要があります。
int main() { using P = pair<int, int>; auto lift = [](P a, int w) { return P(max(0, a.first) + w, -1e9); }; auto merge = [](P a, P b) { for (auto t : { b.first, b.second }) { if (a.first < t) { a.second = a.first; a.first = t; } else if (a.second < t) { a.second = t; } } return a; }; int N; cin >> N; auto g = ReRooting<P, int>(N, lift, merge, P(-1e9, -1e9)); for (int i = 1; i < N; i++) { int s, t, w; cin >> s >> t >> w; g.add_edge(s, t, w); } int ans = 0; for (auto r : g.solve()) { ans = max(ans, r.first); ans = max(ans, r.first + r.second); } cout << ans << '\n'; return 0; }
ふたつめ
構造体の中身を修正する方針です。木DPで求めるものは「最も遠い頂点までの距離」とシンプルです。ans
周りの実装も素直です。
template <typename T, typename D> struct ReRooting { using L = function<T(T, D)>; using M = function<T(T, T)>; struct Edge { int to; D data; Edge(int to, D data) : to(to), data(data) {} }; int n; vector<vector<Edge>> edges; L lift; M merge; T e; vector<vector<T>> dp; vector<int> par; int ans; ReRooting(int n, L lift, M merge, T e) : n(n), lift(lift), merge(merge), e(e) { edges = vector<vector<Edge>>(n); } void add_edge(int v, int u, D d) { edges[v].emplace_back(u, d); } T dfs(int v = 0, int p = -1) { int l = edges[v].size(); dp[v] = vector<T>(l); T ret = e; for (int i = 0; i < l; i++) { auto e = edges[v][i]; if (e.to == p) { par[v] = i; continue; } dp[v][i] = lift(dfs(e.to, v), e.data); ret = merge(ret, dp[v][i]); } return ret; } void bfs(int v, T t) { if (par[v] != -1) { dp[v][par[v]] = lift(t, edges[v][par[v]].data); } int l = edges[v].size(); vector<T> dpl(l + 1, e); vector<T> dpr(l + 1, e); for (int i = 0; i < l; i++) { dpl[i + 1] = merge(dpl[i], dp[v][i]); } for (int i = l - 1; i >= 0; i--) { dpr[i] = merge(dpr[i + 1], dp[v][i]); } int fst = 0, snd = 0; for (int i = 0; i < l; i++) { if (fst <= dp[v][i]) { snd = fst, fst = dp[v][i]; } else if (snd <= dp[v][i]) { snd = dp[v][i]; } } ans = max(ans, fst + snd); for (int i = 0; i < l; i++) { if (i == par[v]) continue; int u = edges[v][i].to; bfs(u, merge(dpl[i], dpr[i + 1])); } } int solve() { dp = vector<vector<T>>(n); par = vector<int>(n, -1); ans = 0; dfs(); bfs(0, e); return ans; } }; int main() { auto lift = [](int a, int w) { return a + w; }; auto merge = [](int a, int b) { return max(a, b); }; int N; cin >> N; auto g = ReRooting<int, int>(N, lift, merge, 0); for (int i = 1; i < N; i++) { int s, t, w; cin >> s >> t >> w; g.add_edge(s, t, w); } cout << g.solve() << '\n'; return 0; }