Challenge: Black Hole

Subdomeniu: Algebra (algebra)

Scor cont: 100.0 / 100

Submission status: Accepted

Submission score: 1.0

Submission ID: 464765002

Limbaj: cpp14

Link challenge: https://www.hackerrank.com/challenges/demidenko-black-hole/problem

Cerinta

Given integers $n$, $a$, $b$ and $M$, calculate the value $\sum_{k=0}^{n} k^a b^k$ modulo $M$.

**Input Format**  
The first line contains the number of test cases $T$.   
Each of the next $T$ lines contains four space-separated integers $n$, $a$, $b$ and $M$. 

**Output Format**  
For each test case output one integer: the value of the sum.

**Note:** In this problem, $0^0$ is defined to be $1$.  

**Constraints**  
$1 \le T \le 6^6+6$  
$0 \le n \le 10^{18}$  
$0 \le a \le 777$  
$0 \le |b| \le 10^{18}$  
$1 \le M \le 10^9$  
The sum of all values of $a$ in one test file does not exceed $1000$.

**Sample input**  

    5
    3 1 1 100
    3 0 1 100
    3 1 0 100
    44 44 4 444
    77 7 47 747

**Sample Output**  

    6
    4
    0
    288
    288

**Explanation** (first three sample cases)  

$0^1\times 1^0~ + ~ 1^1 \times 1^1~ + ~2^1\times 1^2 ~ + ~3^1\times 1^3 = 0 + 1 + 2 + 3 = 6$  
$0^0 \times 1^0 ~ + ~ 1^0 \times 1^1 ~ + ~ 2^0\times 1^2 ~ + ~ 3^0\times 1^3 = 1 + 1 + 1 + 1 = 4$  
$0^1\times 0^0 ~ + ~ 1^1\times 0^1 ~ + ~ 2^1\times 0^2 ~ + ~ 3^1\times 0^3 = 0 + 0 + 0 + 0 = 0$

Cod sursa

#include <bits/stdc++.h>
using namespace std;

using i128 = __int128_t;

long long modpow(long long a, long long e, long long mod) {
    if (mod == 1) return 0;
    long long r = 1 % mod;
    a %= mod;
    if (a < 0) a += mod;
    while (e > 0) {
        if (e & 1LL) r = (long long)((i128)r * a % mod);
        a = (long long)((i128)a * a % mod);
        e >>= 1LL;
    }
    return r;
}

// Multiplies two residues A(x) and B(x) and reduces the product modulo the
// characteristic relation
//     x^m = rec[0]*x^(m-1) + rec[1]*x^(m-2) + ... + rec[m-1],
// where m = rec.size(). A and B are coefficient vectors (lowest degree first)
// that must hold at least m entries; their values are assumed to already lie
// in [0, mod). Returns the m coefficients of (A*B) mod the relation, each in
// [0, mod).
vector<long long> combine(const vector<long long>& A, const vector<long long>& B,
                          const vector<long long>& rec, long long mod) {
    const int m = static_cast<int>(rec.size());
    vector<long long> prod(2 * m - 1, 0);

    auto mulmod = [mod](long long x, long long y) {
        return static_cast<long long>(static_cast<i128>(x) * y % mod);
    };

    // Schoolbook product; zero coefficients contribute nothing and are skipped.
    for (int i = 0; i < m; ++i) {
        if (A[i] == 0) continue;
        for (int j = 0; j < m; ++j) {
            if (B[j] == 0) continue;
            prod[i + j] = (prod[i + j] + mulmod(A[i], B[j])) % mod;
        }
    }

    // Fold degrees >= m back down, highest first, so any contribution landing
    // on another degree >= m is itself folded by a later iteration.
    for (int deg = 2 * m - 2; deg >= m; --deg) {
        const long long c = prod[deg];
        if (c == 0) continue;
        // x^deg = x^(deg-m) * x^m = sum_j rec[j] * x^(deg-1-j)
        for (int j = 0; j < m; ++j) {
            long long& slot = prod[deg - 1 - j];
            slot = (slot + mulmod(c, rec[j])) % mod;
        }
    }

    prod.resize(m);
    return prod;
}

long long linear_nth(const vector<long long>& init, const vector<long long>& rec,
                     long long n, long long mod) {
    int m = (int)rec.size();
    if (n < (long long)init.size()) return init[(size_t)n] % mod;

    vector<long long> res(m, 0), base(m, 0);
    res[0] = 1 % mod;
    if (m == 1) {
        base[0] = rec[0] % mod;
    } else {
        base[1] = 1 % mod; // x
    }

    while (n > 0) {
        if (n & 1LL) res = combine(res, base, rec, mod);
        base = combine(base, base, rec, mod);
        n >>= 1LL;
    }

    long long ans = 0;
    for (int i = 0; i < m; ++i) {
        ans = (ans + (long long)((i128)res[i] * init[i] % mod)) % mod;
    }
    return ans;
}

long long solve_one(long long n, int a, long long b, long long mod) {
    if (mod == 1) return 0;

    long long bb = b % mod;
    if (bb < 0) bb += mod;

    int m = a + 2;

    // P(x) = (x - b)^(a+1), ascending.
    vector<long long> P(1, 1 % mod);
    long long negb = (mod - bb) % mod;
    for (int it = 0; it < a + 1; ++it) {
        vector<long long> NP(P.size() + 1, 0);
        for (int i = 0; i < (int)P.size(); ++i) {
            NP[i] = (NP[i] + (long long)((i128)P[i] * negb % mod)) % mod;
            NP[i + 1] = (NP[i + 1] + P[i]) % mod;
        }
        P.swap(NP);
    }

    // Q(x) = (x - 1) * P(x), ascending q[0..m]
    vector<long long> q(m + 1, 0);
    for (int i = 0; i < (int)P.size(); ++i) {
        q[i] = (q[i] - P[i]) % mod;
        if (q[i] < 0) q[i] += mod;
        q[i + 1] = (q[i + 1] + P[i]) % mod;
    }

    vector<long long> rec(m, 0);
    for (int j = 0; j < m; ++j) {
        rec[j] = (mod - q[m - 1 - j]) % mod;
    }

    vector<long long> init(m, 0);
    long long bp = 1 % mod; // b^k
    for (int k = 0; k < m; ++k) {
        long long kp;
        if (k == 0) kp = (a == 0 ? 1 % mod : 0);
        else kp = modpow(k, a, mod);

        long long term = (long long)((i128)kp * bp % mod);
        if (k == 0) init[k] = term;
        else init[k] = (init[k - 1] + term) % mod;

        bp = (long long)((i128)bp * bb % mod);
    }

    return linear_nth(init, rec, n, mod);
}

// Reads T, then T lines of "n a b M", and prints one answer per line.
// Stops at the first missing or malformed line. Without this check, a failed
// extraction leaves the per-case variables unread (originally uninitialized),
// and solve_one would be called with garbage. M == 0 would even divide by zero.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int T = 0;
    if (!(cin >> T)) return 0;
    while (T-- > 0) {  // `> 0` also guards against a negative T
        long long n = 0;
        long long b = 0;
        long long mod = 0;
        int a = 0;
        if (!(cin >> n >> a >> b >> mod)) break;  // truncated input: stop cleanly
        cout << solve_one(n, a, b, mod) << '\n';
    }
    return 0;
}
HackerRank Algebra – Black Hole