Typical DP Contest - S マス目

S: マス目 - Typical DP Contest | AtCoder

  • dp の状態が vector で表現される dp
  • 接続判定は union find でやる。
class s_squares {
public:
    //
    // 色の全てのぬり方のうち左上と右下が連結になるようなものの個数を数える
    //
    // 列を進めながら DP する?
    //
    // BWBBB
    // BWBWB
    // BBBWB
    // WWWWB
    //
    // みたいなのをどうやって数える?
    //
    // i列目と i+1 列目を考える
    //
    //   i  i+1 i+2
    //   B   B   B
    //   B   W   B
    //.. W   W   B ...
    //   W   W   B
    //   B   B   B
    //
    // i,i+1 で連結でなくても i+2 以降で連結になるケースもある。
    // そこで各列である白黒の塗り方に連結成分のラベルを張って DP する。
    //
    //   i     i     i
    //   1     1     2
    //   1     1     2
    //   0  or 0  or 0
    //   0     0     0
    //   1     2     1
    //
    // dp[i][combination] = i 列目が combination となるような組み合わせの総数
    //
    // として DP を更新していけばよい。
    void solve(void) {
            int H,W;
            cin>>H>>W;

            map<vector<int>, ll> dp[2];
            // 初期化
            REP(bit, (1<<H))
            {
                // 左上が黒色でないものは飛ばす
                if (!(bit & 1))
                    continue;

                // 初期状態の組み合わせ作る
                // label 1 は左上につながる島のラベルとして固定する。。
                vector<int> conn(H, 0);
                int c = 1;
                REP(h, H)
                {
                    // 上のタイルが黒で自身も黒なら同じ連結成分。そうでなければ別の連結成分とする。
                    // 初期状態なので 100011 みたいなのを考える必要はない。 100022 とみなす
                    // (2 列目以降でマージされるので同じ)
                    if (bit & 1<<h)
                    {
                        if (h > 0 && bit & (1<<(h-1)))
                            conn[h] = conn[h-1];
                        else
                            conn[h] = c++;
                    }
                }
                dp[0][conn] = 1;
            }

            int cur = 0;
            int next = 1;
            // W を左から右へ見ていく
            FOR(_, 1, W)
            {
                dp[next].clear();
                for (auto iter : dp[cur])
                {
                    // conn と bit を比較して接続関係を更新する
                    const auto &conn = iter.first;
                    // union find で接続関係をチェックする。このとき bit と cur conn
                    // の島を分けて考えるため以下のよう使える label を割り振っておく
                    // empty      : 0
                    // cur conn   : 1...H
                    // bit        : H+1...2*H
                    REP(bit, 1<<H)
                    {
                        UFT uft(H*2+1);
                        vector<int> conn2(H, 0);
                        int c = H;

                        // bit の島に番号をつけておく
                        REP(j, H)
                            // 黒色の場所にインデックスをふっておく
                            if (bit & (1<<j))
                                conn2[j] = c++;

                        // 連結部分をマージ
                        REP(j, H)
                        {
                            // 島になっている部分はマージ
                            if (j > 0 && conn2[j] && conn2[j-1])
                                uft.merge(conn2[j], conn2[j-1]);
                            // cur -> bit でつながっているならマージ
                            if (conn[j] && conn2[j])
                                uft.merge(conn[j], conn2[j]);
                        }

                        // 接続結果をまとめる
                        vector<int> to(H*2+1, -1); // UFT の結果から(0...2*H) label(0...H) への map
                        to[0] = 0; // これらは固定
                        to[1] = 1;
                        c = 2;
                        REP(j, H)
                        {
                            if (!conn2[j])
                                continue;
                            // 連結成分ごとに label を貼り直す
                            // 0,1 以外は出現順に label を張っている
                            int root = uft.root(conn2[j]);
                            if (to[root] < 0)
                                to[root] = c++;
                            conn2[j] = to[root];
                        }
                        // label 1 の島(左上につながるもの)がなければ飛ばす
                        if ( to[uft.root(1)] < 0 )
                            continue;
                        (dp[next][conn2] += iter.second) %= MOD;
                    }
                }
                swap(cur, next);
            }
            ll ans = 0;
            for (auto itr : dp[cur])
            {
                if (itr.first[H-1] == 1) // 左上から接続されているやつのみ加算
                    (ans += itr.second) %= MOD;
            }
            cout<<ans % MOD <<endl;
    }
};