Typical DP Contest - T フィボナッチ

T: フィボナッチ - Typical DP Contest | AtCoder

  • kitamasa 法
  • ほとんどきたまさ法メモ - よすぽの日記を見て理解した。
  • 下のコードは再帰で書いているが再帰でない書き方もあるっぽい。そっちはまだ実装できるレベルではない。(理解できてない。)
  • これで Typical DP Contest は全問解いた。長かったー。
//
// kitamasa 法
//
// a[n+k] = d[0]*a[n] + d[2]*a[n+1] + ... + d[k-1]*a[n+k-1] ... (*)
//
// という k 項間漸化式の n 項目を O(k^2*log(n)) で求める
//
// f(n) = {x[0], ..., x[k-1]} なる f を求める
//  ( a[n] = x[0]*a[0]+...+x[k-1]*a[k-1] )
//
// (*) より
// a[n+m] = x[0]*a[m] + x[1]*a[m+1] + ... + x[k-1]*a[m+k-1]  (∀m >= 0, {x} depend on n)
//
// となる。({x} はあくまで k 項間漸化式 (*) からのみ決定し、初項に依存しないため。 ただし n には依存 )
//
// a[n+1] = x[0]*a[1] + x[1]*a[2] + ... + x[k-1]*a[k]
//        = x[0]*a[1] + ...       + x[k-2]*a[k-1] + x[k-1]*(d[0]*a[0] + d[1]*a[1] + ... + d[k-1]*a[k-1])
//        = (x[k-1]*d[0])*a[0] + (x[0] + x[k-1]*d[1])*a[1] + ... + (x[k-2] + x[k-1]*d[k-1])*a[k-1]
//
// よって O(k) で f(n) を更新できる
//   (x[0], ..., x[k-1]) ->  (x[k-1]*d[0], x[0]+x[k-1]*d[1], ..., x[k-2]+x[k-1]*d[k-1])
//         f(n)                             f(n+1)
//
// f(n), f(n+1), ..., f(n+k-1) を O(k^2) で列挙できる。すると
//
// a[2*n] = x[0]*a[n] + x[1]*a[n+1] + ... + x[k-1]*a[n+k-1]
//        = x[0]*f(n)%*%A + x[1]*f(n+1)%*%A + ... + x[k-1]*f(n+k-1)%*%A
//
// (A=(a[0], ..., a[k-1]), %*% はベクトルの内積) と f(n) -> f(2*n) を O(k^2) で計算できる。
//
// 一般の N は 2 の冪上の和で計算できるので f(N) は O(k^2*log(N)) で計算できる
//
// f(N) = f(2*n+1) = A*f(2*n1) = A*B*f(n1) = ...
//
// O(k^2*log(n))
template <typename T>
T kitamasa(const std::vector<T> &a0,   // 初項
           const std::vector<T> &d,    // 係数
           int n)                      // 求めたい index
{
        int k = a0.size();
        // f(m) -> f(m+1)
        auto increment = [&](std::vector<T> &x) {
            T last = x[k-1];
            for (int j = k-1; j >= 1; --j)
                x[j] = x[j-1]+last*d[j];
            x[0] = last*d[0];
        };
        std::function<void(std::vector<T>&, int)> rec = [&](std::vector<T> &x, int n) {
            if (n == 0)
            {
                std::fill(x.begin(), x.end(), 0);
                x[0] = 1; // f(0) = (1,0,...,0)
                return;
            }
            if (n & 1)          // f(n) <- f(n-1)
            {
                rec(x, n-1);
                increment(x);
            }
            else                // f(2*n) <- f(n)
            {
                std::vector<T> tmp(k, 0);
                rec(tmp, n/2);
                std::vector<T> tmp2(tmp);
                for (int j = 0; j < k; ++j)
                    x[j] = tmp2[0]*tmp[j];
                for (int i = 1; i < k; ++i)
                {
                    // f(n+1) <- f(n)
                    increment(tmp);
                    for (int j = 0; j < k; ++j)
                        x[j] += tmp2[i]*tmp[j];
                }
            }
        };
        std::vector<T> xs(k, 0);
        rec(xs, n);
        T ret = 0;
        for (int i = 0; i < k; ++i)
            ret += a0[i]*xs[i];
        return ret;
}
class t_fibonacci {
public:
    void solve(void) {
            int K, N;
            cin>>K>>N;
            ModInt::set_mod((int)1E+9+7);
            vector<ModInt> d(K, 1);
            vector<ModInt> a(K, 1);
            cout<<kitamasa(a, d, N-1)<<endl;
    }
};