Typical DP Contest - N 木

N: 木 - Typical DP Contest | AtCoder

  • dp の問題かと思いきや dp じゃなくても解けた問題。c++ じゃなかったら dp じゃないときついのかな。
  • MOD が素数なのでフェルマーの小定理を使って割り算を表現する。
const int MOD = (int)(1e+9)+7;
static vector<ll> mfac_impl;
// mod factorial
void mfac_init(int maxN) {
        mfac_impl.resize(maxN+1,0);
        mfac_impl[0] = 1;
        REP(i, maxN)
            mfac_impl[i+1] = (mfac_impl[i]*(i+1))%MOD;
}
ll mfac(int n) { return mfac_impl[n]; }

// MOD での逆数計算
// フェルマーの小定理と MOD が素数であることから a^-1 = a^(MOD-2)
ll minv(ll a) { return mpow(a, MOD-2, MOD); }

class n_tree {
public:
    vector<vector<int>> tree_;
    pair<ll,int> dfs(int v, int par) {
            // 子ノードをどの順で選ぶか
            //  v
            // p q
            //
            // と2つノードがあるときを考える。
            // p 以下の辺の引き方と q 以下の辺の引き方はそれぞれ独立している。
            // p 以下(と v から p 自身)で辺を張る操作を P と表現すると v 以下の辺の引き方は
            //
            // PQQPQPP...
            // QQPQPPP...
            //
            // などと書ける。この PQ の並びは
            //  size(p)+size(q) C size(p) (size(v) := v 以下(自身を含む)のノードの個数)
            // と書けるので
            //
            // v 以下のノードの辺のはりかたは
            //
            //  (size(p)+size(q))! / (size(p)! * size(q)!) * dfs(p) * dfs(q)
            //
            // となる。子ノードが複数のときも同様。
            //
            int sz = 0;
            ll res = 1;
            ll deno = 1;
            for (auto u : tree_[v])
            {
                if (par == u)
                    continue;
                int csz;
                ll  cres;
                tie(cres, csz) = dfs(u, v);
                (deno *= mfac(csz)) %= MOD;
                sz += csz;
                (res *= cres) %= MOD;
            }
            // Combination 分の計算
            res = (mfac(sz)*minv(deno)) % MOD * res;
            res %= MOD;

            return make_pair(res, sz+1);
    }
    void solve(void) {
            const int MOD = (int)(1e+9)+7;
            int N;
            cin>>N;
            tree_.resize(N);
            REP(i, N-1)
            {
                int a, b;
                cin>>a>>b;
                --a;
                --b;
                tree_[a].push_back(b);
                tree_[b].push_back(a);
            }
            mfac_init(N);
            // 根を決めて dfs していけばよい
            // O(N^2) ... dp でキャッシュしなくても間に合う
            ll res = 0;
            REP(root, N)
                (res += dfs(root, -1).first) %= MOD;

            // a -> b と辺を引くのと b -> a と辺を引くのは同じなので 2 で割る
            cout<<(res*minv(2))%MOD<<endl;
    }
};