No.121 傾向と対策:門松列(その2)

No.121 傾向と対策:門松列(その2) - yukicoder

  • 他の人のコードを参考にした。
  • 最後に a==c なる組み合わせを削除する。 ここのアルゴリズムは式を書いてみるまでわからなかった。勉強になるなー。
template <typename T>
class BIT
{
public:
    std::vector<T> data; // [1,n]
    BIT() {}
    BIT(int n) { init(n); }
    void init(int n) { data.resize((1LL<<(int)ceil(log2(n))), 0); }
    // a[i] += x                   O(log(n))
    void add(int i, int x)
    {
            int maxN = data.size()+1; // 2 冪
            for (int k = i+1; k <= maxN; k += (k & -k))
                data[k-1] += x;
    }
    // a[0]+...+a[i-1]             O(log(n))
    T sum(int i)
    {
            T s = 0;
            for (int k = i+1; k > 0; k -= (k & -k))
                s += data[k-1];
            return s;
    }
};

using namespace std;

//
// 愚直なやり方だと O(N^3) で TLE
// bit を使って O(n*log(n)) にする
//
// (i,j,k) をかんがえるとき j を動かしたとき左右の i,k の組み合わせを高速で計算できればよい
// 以下の工夫をする
//  * 左右の累積を bit で管理する
//  * bit の管理は A[i] の個数なので A[i] <= 10^9 を N 以下にマッピング
//  * 重複するものは後で削除
//  * 逐次更新・逐次評価
//
class TrendAndCountermeasures_PineDecorationSequence2 {
public:
    void solve(void) {
            int N;
            cin>>N;

            map<int,int> conv; // mapping 用
            vector<int> A(N);
            REP(i,N)
            {
                cin>>A[i];
                conv[A[i]] = 0;
            }
            // A は一致しているかどうかが問題になるだけなので
            // 0...N-1 をマッピングして max(A) < N にする
            int k = 0;
            for (auto kv : conv)
                conv[kv.first] = k++;
            REP(i,N)
                A[i] = conv[A[i]];

            vector<int> L(N,0); // L[i] := 現在のインデックスの右側にある A[i] の個数(インデックス自身を含む)
            vector<int> R(N,0); // R[i] := 現在のインデックスの左側にある A[i] の個数

            BIT<ll> bitR(N); // 現在のインデックスの右側にある A[i] の累積管理用(インデックス自身を含む)
            BIT<ll> bitL(N); // 現在のインデックスの左側にある A[i] の累積管理用

            // 現在のインデックスを 0 として R,bitR を初期化
            REP(i,N)
            {
                ++R[A[i]];
                bitR.add(A[i],1);
            }
            ll ret = 0;
            // 現在のインデックス i を動かしながら計算をする
            REP(i,N)
            {
                bitR.add(A[i],-1);
                //
                // 左右の A[i]-1 以下のものの組み合わせ
                // (a,b,c) a < b > c
                //
                ret += bitL.sum(A[i]-1) * bitR.sum(A[i]-1);
                //
                // 左右の A[i]+1 以上のものの組み合わせ
                // (a,b,c) a > b < c
                //
                // 左に i 個、右に (N-i-1) 個 A[k] があるので A[i]+1 以上のものの個数は
                // そこから A[i] 以下のものの数を引けばよい
                //
                ret += (i - bitL.sum(A[i])) * ((N-i-1) - bitR.sum(A[i]));
                bitL.add(A[i],1); // 右側に所属していた A[i] を左へ移す
            }
            // a,b,c の a==c なるケースを削除する。
            // 愚直に計算すると O(N^2) なので以下の O(N) アルゴリズムを使う
            //
            //  diff[k] = ∑ L[a]*R[a] とおく
            //           a!=A[k]
            //
            // これを各 i ごとに計算して ret から引けばよい
            //
            // diff[k+1] = ∑ L'[a]*R'[a]
            //           a!=A[k+1]
            //
            // L'[a] = L[a], L'[A[k]] = L[A[k]]+1
            // R'[a] = R[a], R'[A[k]] = R[A[k]]-1
            //
            //  diff[k+1] = ∑ L[A[i]]*R[A[i]] + (L[A[k]]+1)*(R[A[k]]-1)
            //             a!=A[k]
            //             a!=A[k+1]
            //
            //            = diff[k] - L[A[k+1]]*R[A[k+1]] + (L[A[k]]+1)*(R[A[k]]-1)
            //
            // となるので diff は O(1) で更新可能
            //
            ll diff = 0;
            REP(i,N)
            {
                diff -= (ll)L[A[i]] * R[A[i]];
                ret -= diff;    // diff[i] を引く
                // A[i] を右から左へ移動
                --R[A[i]];
                ++L[A[i]];
                diff += (ll)L[A[i]] * R[A[i]];
            }
            cout<<ret-diff<<endl;
    }
};