SRM 679 Div2 Hard ForbiddenStreets

TopCoder Statistics - Problem Statement

  • 本番では開いただけの問題。最短経路が複数あるとき、指定の経路を含むもの以外の最短経路を見つけるのがミソな問題。
  • どうやって見つけるのかは思いつず、以下のブログを参考にした。最短経路のパターン数を持っておいて、パターン数を比較すればよいのか。
  • ワーシャルフロイドでも最短パターン数を計算できるのかもだけど、うまく書けなかった。
  • 以下コードではとりあえず long long にパターン数を格納して system test が通ったけど、下記ブログにあるように mod で計算しないとテストケースが厳しいと long long オーバーフローして落ちてしまうのかも。

TopCoder SRM 679 Div2 Hard ForbiddenStreets - kmjp's blog

class ForbiddenStreets {
public:
vector <int> find(int N, vector <int> A, vector <int> B, vector <int> time) {
        int M = A.size();
        const int inf = (1<<30);

        vector<vector<ll>> dist(N,vector<ll>(N,inf)); // 最短経路長
        vector<vector<ll>> pat(N,vector<ll>(N,0));    // 最短経路数
        vector<vector<pair<int,int>>> tree(N);
        REP(i,M)
        {
            tree[A[i]].emplace_back(B[i],time[i]);
            tree[B[i]].emplace_back(A[i],time[i]);
        }
        // O(N*M*log(N))
        REP(start,N)
        {
            typedef pair<int, int> P;
            priority_queue<P,vector<P>,greater<P>> pque;
            auto &d = dist[start];

            d[start] = 0;
            pat[start][start] = 1;
            pque.emplace(0,start);
            while ( !pque.empty() )
            {
                int u,t;
                tie(t,u) = pque.top();
                pque.pop();

                if ( d[u] < t )
                    continue;

                for (const auto &edge : tree[u])
                {
                    int to,cost;
                    tie(to,cost) = edge;

                    if ( d[u] + cost < d[to] )
                    {
                        d[to] = d[u] + cost;
                        // 最短経路長が更新されたのでパターン数も初期化
                        pat[start][to] = 0;

                        pque.emplace(d[to], to);
                    }
                    // 最短経路長が一致するときはパターン数を加算
                    if ( d[u] + cost == d[to] )
                    {
                        pat[start][to] += pat[start][u];
                    }
                }
            }
        }
        vector<vector<int>> cnt(N,vector<int>(N,0));
        vector<int> ret(M,0);
        REP(i,M)
        {
            // A[i]~B[i] 間を止めるとき
            REP(j,N)
            REP(k,N)
            {
                // 最短経路に A[i]~B[i] が含まれるときをチェック
                // pattern 数が一致しない(他に最短経路がある)ときはスキップ
                if ( dist[j][k] == dist[j][A[i]] + time[i] + dist[B[i]][k] &&
                     pat[j][k] == pat[j][A[i]] * pat[B[i]][k] )
                {
                    ++ret[i];
                }
            }
        }
        return ret;
}
};