Single Round Match 797 - RollMe

MP法を用いて、goalの前  i 文字に文字  j を加えた文字列のsuffixがgoalのprefixと何文字一致するかを前計算します。

後のDPは簡単で、計算量は  O(|goal||die|) です

追記(DPパート)

簡単でで説明を省略しない!ということで簡単に

 dp _ {i} := スタートから  i に行くまでの期待値

とします。

 i \rightarrow i + 1 における変化を考えます。

 i + 1 文字目の目が  d であったとしましょう。

出た目が  d ならば  dp _ {i} + 1 回で  i + 1 で到達できます。

それ以外の場合、 dp _ {i} + 1 回で戻り、その後に  i + 1 まで行くのは(スタートから  i +1 まで行く) - (戻る位置まで行く)と考えることができます。

これを立式すると、

 dp _ {i + 1} = p _ {d} (dp _ {i} + 1) + \displaystyle \sum _ {j \neq d} p _ {j} (dp _ {i} + 1 + dp _ {i + 1} - dp _ {戻る位置} )

となり、整理するとソースコードの式になります

提出コード


struct mint {
    long long x;
    mint(long long x = 0) :x((x% mod + mod) % mod) {}
    mint& operator+=(const mint a) {
        if ((x += a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint a) {
        if ((x += mod - a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator*=(const mint a) {
        (x *= a.x) %= mod;
        return *this;
    }
    mint operator+(const mint a) const {
        mint res(*this);
        return res += a;
    }
    mint operator-(const mint a) const {
        mint res(*this);
        return res -= a;
    }
    mint operator*(const mint a) const {
        mint res(*this);
        return res *= a;
    }
    mint pow(ll t) const {
        if (!t) return 1;
        mint a = pow(t >> 1);
        a *= a;
        if (t & 1) a *= *this;
        return a;
    }

    // for prime mod
    mint inv() const {
        return pow(mod - 2);
    }
    mint& operator/=(const mint a) {
        return (*this) *= a.inv();
    }
    mint operator/(const mint a) const {
        mint res(*this);
        return res /= a;
    }
};


class  RollMe {
public:
    int solve(vector <int> die, string goal) {
        int n = goal.size();
        vector<vector<int>> pre(n + 1, vector<int>(die.size()));

        vector<mint> dp(n + 1);
        int sm = accumulate(die.begin(), die.end(), 0);
        mint inv = mint((ll)sm).inv();
        string s;
        vector<int> A(n + 1, -1);
        for (int i = 0; i < n; i++) {
            int k = A[i];
            for (int j = 0; j < die.size(); j++) {
                int k = A[i];
                s += (char)('0' + j);
                while (k >= 0 and s[i] != s[k]) k = A[k];
                k++;
                pre[i][j] = k;
                if (goal[i] == s.back()) A[i + 1] = k;
                s.pop_back();
            }
            s += goal[i];
        }
        for (int i = 0; i < n; i++) {
            mint t = 0;
            mint fp = 0;
            for (int j = 0; j < die.size(); j++) {
                mint pr = mint(die[j]) * inv;
                if (goal[i] == (char)('0' + j)) {
                    t += pr * (dp[i] + 1);
                }
                else {
                    fp += pr;
                    t += pr * (dp[i] + 1 - dp[pre[i][j]]);
                }
                dp[i + 1] = t * (mint(1) - mint(fp)).inv();
            }
        }
        return dp[n].x;
    }
};