Challenge: Parity Party

Subdomeniu: Combinatorics (combinatorics)

Scor cont: 100.0 / 100

Submission status: Accepted

Submission score: 1.0

Submission ID: 464814485

Limbaj: cpp14

Link challenge: https://www.hackerrank.com/challenges/parity-party/problem

Cerinta

We need your help to divide candies at a very unusual party!  
There are $n$ distinct candies in total. There are three kinds of people at the party:
  
* $a$ of them want to get an *odd* number of candies,  
* $b$ of them want to get an *even* number of candies,
* $c$ of them don't care about the parity of the number of candies they get.   
    
Find the number of ways to distribute all $n$ candies among everybody ($a + b + c$ people) such that everyone is satisfied. Some people may receive no candies at all (zero counts as an even number).

Input Format

A single line of input contains four space-separated integers $n$, $a$, $b$, $c$.

Output Format

Print a single line containing the answer modulo $7340033$.

Constraints

+ $1 \leq n \leq 10^9$,  
+ $0 \leq a, b, c \leq 50000$,  
+ $1 \leq a + b + c$.

Cod sursa

#include <bits/stdc++.h>
using namespace std;

// NTT-friendly prime: MOD - 1 = 7 * 2^20, so Z/MOD contains roots of unity of
// every power-of-two order up to 2^20.
static constexpr int MOD = 7340033;
static constexpr int G = 3;              // base from which ROOT is derived (ROOT = G^7)
static constexpr int ROOT_PW = 1 << 20;  // largest supported power-of-two transform length
static constexpr int ROOT = 2187;        // G^((MOD - 1) / ROOT_PW) = 3^7, a primitive 2^20-th root of unity
static_assert(MOD == 7 * ROOT_PW + 1, "MOD must equal 7 * 2^20 + 1");
static_assert(ROOT == G * G * G * G * G * G * G, "ROOT must equal G^7");

// Binary exponentiation: returns a^e modulo MOD.
// Returns 1 when e <= 0 (the loop body never runs).
// The base is only reduced with %, so callers are expected to pass a >= 0;
// a negative base would leave negative intermediates and a negative result.
static int mod_pow(long long a, long long e) {
    long long base = a % MOD;
    long long result = 1;
    for (; e > 0; e >>= 1) {
        if (e & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return static_cast<int>(result);
}

// Modular addition. Assumes both operands lie in [0, MOD): only a single
// conditional subtraction is applied to bring the sum back into range.
static inline int addm(int a, int b) {
    const int sum = a + b;
    return sum >= MOD ? sum - MOD : sum;
}

// Modular subtraction. Assumes both operands lie in [0, MOD): only a single
// conditional addition is applied to bring the difference back into range.
static inline int subm(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + MOD : diff;
}

// In-place iterative radix-2 number-theoretic transform over Z/MOD.
//   a      : coefficient vector, transformed in place. Its size must be a
//            power of two (or zero) and must not exceed ROOT_PW.
//   invert : false -> forward transform using powers of ROOT;
//            true  -> inverse transform, including the final division by n.
static void ntt(vector<int>& a, bool invert) {
    const int n = (int)a.size();
    // For len > ROOT_PW, ROOT_PW / len truncates to 0 and wlen becomes 1.
    // The butterflies also assume a power-of-two length. Either violation
    // would silently produce a wrong transform, so reject it up front.
    assert((n & (n - 1)) == 0 && n <= ROOT_PW);

    // Bit-reversal permutation so the butterflies can run bottom-up.
    for (int i = 1, j = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }

    // Primitive ROOT_PW-th root of unity, or its inverse for the inverse
    // transform. Inverting once here equals inverting each level's wlen,
    // since (ROOT^-1)^k == (ROOT^k)^-1.
    const int root = invert ? mod_pow(ROOT, MOD - 2) : ROOT;
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        const int wlen = mod_pow(root, ROOT_PW / len);  // primitive len-th root
        for (int i = 0; i < n; i += len) {
            int w = 1;
            for (int j = 0; j < half; ++j) {
                const int u = a[i + j];
                const int v = (int)(1LL * a[i + j + half] * w % MOD);
                a[i + j] = addm(u, v);
                a[i + j + half] = subm(u, v);
                w = (int)(1LL * w * wlen % MOD);
            }
        }
    }

    if (invert) {
        const int n_inv = mod_pow(n, MOD - 2);
        for (int& x : a) x = (int)(1LL * x * n_inv % MOD);
    }
}

// Linear convolution modulo MOD: result[k] = sum_{i + j = k} a[i] * b[j],
// with result size a.size() + b.size() - 1.
// Returns an empty vector if either input is empty. The transform length
// (the next power of two >= the result size) must not exceed ROOT_PW.
static vector<int> convolution(vector<int> a, vector<int> b) {
    // Without this guard, two empty inputs give need == -1 and a.resize(need)
    // throws. A single empty input would yield a spurious all-zero result.
    if (a.empty() || b.empty()) return {};

    const int need = (int)a.size() + (int)b.size() - 1;
    int n = 1;
    while (n < need) n <<= 1;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (int i = 0; i < n; ++i) a[i] = (int)(1LL * a[i] * b[i] % MOD);
    ntt(a, true);
    a.resize(need);
    return a;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    long long n;
    int a, b, c;
    if (!(cin >> n >> a >> b >> c)) return 0;

    int m = a + b + c;
    int d = a + b;
    int lim = max(a, b);

    vector<int> fact(lim + 1), invfact(lim + 1);
    fact[0] = 1;
    for (int i = 1; i <= lim; i++) fact[i] = (int)(1LL * fact[i - 1] * i % MOD);
    invfact[lim] = mod_pow(fact[lim], MOD - 2);
    for (int i = lim; i >= 1; i--) invfact[i - 1] = (int)(1LL * invfact[i] * i % MOD);

    auto C = [&](int N, int R) -> int {
        if (R < 0 || R > N) return 0;
        return (int)(1LL * fact[N] * invfact[R] % MOD * invfact[N - R] % MOD);
    };

    vector<int> A(a + 1), B(b + 1);
    for (int i = 0; i <= a; i++) {
        int v = C(a, i);
        if (i & 1) v = (v == 0 ? 0 : MOD - v);
        A[i] = v;
    }
    for (int j = 0; j <= b; j++) B[j] = C(b, j);

    vector<int> conv = convolution(A, B);

    long long S = 0;
    for (int t = 0; t <= d; t++) {
        int base = m - 2 * t;
        base %= MOD;
        if (base < 0) base += MOD;
        int pw = mod_pow(base, n);
        S += 1LL * conv[t] * pw % MOD;
        if (S >= MOD) S -= MOD;
    }

    int inv2 = (MOD + 1) / 2;
    int invPow2 = mod_pow(inv2, d);
    int ans = (int)(S % MOD * 1LL * invPow2 % MOD);
    cout << ans << '\n';
    return 0;
}
HackerRank Combinatorics – Parity Party