No.95 Alice and Graph
No.95 Alice and Graph - yukicoder
- coin の大きいノードを順にみていけばよいのはわかったが、なぜ bit DP が必要なのか(というか解説コードでやっている bit DP)がわからなかった。
- bit DP はどっちかというと二分探索で可能か判定しながら値を求めるやつに似た感じ。
- 最悪計算量的には O(max(N*2^K*K^2, N^3)) なんだけど通る。
class AliceAndGraph { public: void solve(void) { const int inf = (1<<25); int N,M,K; cin>>N>>M>>K; vector<vector<int>> dist(N, vector<int>(N, inf)); REP(_,M) { int u,v; cin>>u>>v; --u,--v; dist[u][v] = dist[v][u] = 1; } // // S2(k) = 1+2+2^2+...+2^(k-1) = 2^k-1 とすると // S2(1) + S2(2) + ... + S2(k-1) = S2(k) - k < S2(k) // // なので S2(k) の大きい物から優先的に回るとよい // // 最短経路を求めておく // O(N^3) REP(k,N) REP(i,N) REP(j,N) dist[i][j] = min(dist[i][j], dist[i][k]+dist[k][j]); // // パス上のノードのコインを獲得するが v から u へのパスが一つではないので // どのパスを選ぶのがよいか探索しなくてはいけない。 // (基本的には中間ノードで出現する番号が最も大きいパスを選べばよいが、 // それはパスを選んで通過してみないとわからない。) // // そこで以下のような bit DP を考える。 // // 最大到達可能なノード数は K 個なので // N-1,N-2,... の順で到達候補ノード集合 S (size(S) <= K+1 0 ノードを含む分 +1) // にノードを追加して実際に cost K 以内で到達が可能か試す。 // // S = {N-1,0} ... ok // S = {N-1,N-2,0} ... NG // S = {N-1,N-3,0} ... ok // S = {N-1,N-3,N-4,0} ... ok // : // S = {N-1,N-3,N-4,...,0} ... answer // // dp[bit][i] := 訪問ノードが bit で現在地が i のときの最小コスト vector<vector<int>> dp(1<<(K+1), vector<int>(K+1)); // O(2^K*K^2) auto calcCost = [&](const vector<int> &s) { int sz = s.size(); REP(i,1<<sz) fill(RANGE(dp[i]), inf); // 初期化 dp[1][0] = 0; // 0 をスタート地点にして全て回る for (int bit = 1; bit < (1<<sz); ++bit) { REP(i,sz) // i -> j への到達コストを計算 { // s[i] に未到達なら飛ばす if (dp[bit][i] == inf) continue; if ( !(bit & (1<<i)) ) continue; // 未到達なもの毎にチェック REP(j,sz) { // 到達済みは飛ばす if (bit & (1<<j)) continue; int next = bit | (1<<j); dp[next][j] = min(dp[next][j], dp[bit][i] + dist[s[i]][s[j]]); } } } int ans = inf; REP(i,sz) // 終端位置毎にチェック ans = min(dp[(1<<sz)-1][i], ans); return ans; }; vector<int> S; S.push_back(0); for (int u = N-1; u > 0; --u) { // 候補ノード u を追加してみて K 以内に到達可能か試す。 S.push_back(u); if ( calcCost(S) > K ) S.pop_back(); // だめだったら候補から外す。 // 次の追加で S の容量を超えたら終了 if (S.size()+1 > (size_t)K+1) break; } // 獲得 coin の枚数を計算 ll ans = 0; for (auto i : S) ans += (1LL<<i)-1; cout<<ans<<endl; } };