No.14 最小公倍数ソート

No.14 最小公倍数ソート - yukicoder

  • 印象深かったのでメモ
  • 最初方針はあっていたのだけど、O(N^2) で TLE してしまった。
  • 以下は他の人のコードを見て修正して通したやつ。
  • swap で pivot の位置をずらしてループの回数を減らすのはうまいと思った。(solve1)
  • int の set じゃなくて pair の set に (a[i],i) を突っ込んでソートさせるのはうまいと思った。(solve_optim)
  • こういうデータ構造の選定がうまくできるようになりたい。
class LCMsort {
public:
    // O(N*N*log(A))
    // 速度的にぎりぎりの解法
    void solve1(void) {
            int N;
            cin>>N;
            vector<int> a(N,0);
            REP(i,N)
                cin>>a[i];

            // sort を進めるにつれソート対象列が短くなるので、実際にソートするのでなく、lcm が最も小さい
            // ものを取り出していくだけでよい。
            // O(N^2*log(A))
            REP(pivot,N)
            {
                cout<<a[pivot]<<" ";
                int mx = (1<<30);
                int mi = -1;
                // O(N*log(A)) pivot が更新されるにつれこのループは高速化されていく。
                // (N-1) -> (N-2) -> ... -> (1)
                FOR(i, pivot+1, N)
                {
                    int k = lcm(a[pivot], a[i]);
                    if (k < mx || (k == mx && a[i] < a[mi]))
                    {
                        mx = k;
                        mi = i;
                    }
                }
                if (mi < 0)
                    break;
                // pivot を入れ替えることで FOR(i, pivot+1, N) のループ回数を減らせる
                swap(a[pivot+1], a[mi]);
            }
            cout<<endl;
    }
    // O(√N*N*log(A))
    void solve_optim(void) {
            int N;
            cin>>N;
            vector<int> a(N,0);
            int maxA = 0;
            REP(i,N)
            {
                cin>>a[i];
                maxA = max(a[i], maxA);
            }

            // 前計算として約数 -> index のマップを作っておく
            vector<vector<int>>        div(N+1);  // div[i] := a[i] の約数リスト
            vector<set<pair<int,int>>> S(maxA+1); // S[d] := 約数 d を持つ a[i] のリスト
                                                  // pair にしているのは a[i] が小さい順にソートしておきたいから
                                                  // second に i を入れることで a[i]==a[j] なら
                                                  // i<j なるものを選択されるようにしている。
            // O(sqrt(maxA)*N*log(maxA))
            REP(i,N)
            {
                int x = a[i];
                // O(sqrt(maxA)) loop
                for (int d = 1; d*d <= x; ++d)
                {
                    if (x%d == 0)
                    {
                        div[i].push_back(d);
                        if (d != x/d)
                            div[i].push_back(x/d);
                        // 各約数ごとに元になった数字と index を格納
                        // O(log(maxA))
                        S[d].emplace(x,i);
                        S[x/d].emplace(x,i);
                    }
                }
            }

            int pivot = 0;
            REP(_,N)
            {
                cout<<a[pivot]<<" ";
                int mx = (1<<30);
                int mi = -1;
                // O(sart(maxA)*log(maxA))
                // a[pivot] の約数ごとに捜査する
                for (auto d : div[pivot])
                {
                    // 自分自身を参照しないように削除しておく(今後も参照されないはず)
                    S[d].erase(make_pair(a[pivot], pivot));
                    if (S[d].empty())
                        continue;
                    // 先頭の組み合わせを取得
                    // 約数 d を持つ a[i] のうち最小のものをとる
                    // pivot として出現しているものは上の erase で消えているので見つからないはず。
                    // (S[d]に入っている数列は d 以外に約数を持たないとして考えてよいので)
                    // d*b1 <= d*b2 < ... なので lcms = x*b1 <= x*b2 <= ... となる。(x=a[pivot])
                    // よって最小の (aa,ii) だけとりだせばよい。
                    int aa, ii;
                    tie(aa,ii) = *S[d].begin();
                    int k = lcm(a[pivot], aa);
                    if (k < mx || (k == mx && a[ii] < a[mi]))
                    {
                        mx = k;
                        mi = ii;
                    }
                }
                if (mi < 0)
                    break;
                pivot = mi;
            }
            cout<<endl;
    }
    void solve(void) {
            solve_optim();
    }
};