Tree Coloring Problem Solution Using Fast Fourier Transform

This problem involves counting the valid colorings of a tree under certain constraints. The solution combines the inclusion-exclusion principle with polynomial multiplication via the Number Theoretic Transform (NTT).

Basic Approach

We approach the problem by computing the complement: count arrangements where at least one node violates the coloring condition. Using inclusion-exclusion, we consider subsets of edges that create chains in the tree.

Each selected edge connects a parent to a child whose value equals the parent's value plus one. The selected edges form vertex-disjoint chains (segments). The number of ways to assign values to these segments equals the factorial of the number of segments.

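The key insight is that selecting $i$ edges creates $n-i$ segments. A node can have at most one selected child edge, because only one child can take the parent's value plus one. The number of ways to select $i$ edges is therefore the coefficient of $x^i$ in $\prod_{j=1}^{n} (1 + b_j x)$, where $b_j$ is the number of children of node $j$. The answer is $\sum_{i=0}^{n-1} (-1)^i \, (n-i)! \, [x^i] \prod_{j=1}^{n} (1 + b_j x)$.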

// Upper bound for every per-node / per-degree table below (1e6 + 10).
constexpr int MAXN = 1000010;

// Factorial table and its modular inverses. They are only declared here;
// presumably they are filled before any solver runs (not shown in this excerpt).
int fact[MAXN], inv_fact[MAXN];

class Polynomial : public vector<int> {
private:
    static const int PRIMITIVE_ROOT = 3;
    
public:
    using vector<int>::vector;
    
    void numberTheoreticTransform(bool inverse = false) {
        int size = this->size();
        int log_size = 0;
        while((1 << log_size) < size) log_size++;
        
        this->resize(1 << log_size, 0);
        
        static int bit_reverse[MAXN];
        for(int i = 0; i < size; i++) {
            bit_reverse[i] = (bit_reverse[i >> 1] >> 1) | ((i & 1) << (log_size - 1));
            if(i < bit_reverse[i]) swap((*this)[i], (*this)[bit_reverse[i]]);
        }
        
        for(int length = 1; length < size; length <<= 1) {
            long long primitive_root_power = power(PRIMITIVE_ROOT, (MOD - 1) / (length << 1));
            for(int start = 0; start < size; start += length << 1) {
                long long current_root = 1;
                for(int offset = 0; offset < length; offset++, current_root = (current_root * primitive_root_power) % MOD) {
                    int even = (*this)[start + offset];
                    int odd = ((*this)[start + offset + length] * current_root) % MOD;
                    (*this)[start + offset] = (even + odd) % MOD;
                    (*this)[start + offset + length] = (even - odd + MOD) % MOD;
                }
            }
        }
        
        if(inverse) {
            long long inverse_size = power(size, MOD - 2);
            reverse(this->begin() + 1, this->end());
            for(int& elem : *this) elem = (elem * inverse_size) % MOD;
        }
    }
    
    friend Polynomial operator*(Polynomial a, Polynomial b) {
        int result_size = a.size() + b.size() - 1;
        a.resize(result_size, 0); b.resize(result_size, 0);
        a.numberTheoreticTransform(true);
        b.numberTheoreticTransform(true);
        
        for(size_t i = 0; i < a.size(); i++)
            a[i] = (1LL * a[i] * b[i]) % MOD;
            
        a.numberTheoreticTransform(false);
        a.resize(result_size);
        return a;
    }
};

// Per-node counter used by the solvers: reset to -1 (0 for node 1) and then
// incremented once per incident edge, so it ends up as the child count of each
// node when the tree is rooted at node 1.
int degree_count[MAXN]{};

// Product of the linear factors (1 + degree_count[j] * x) for j in [left, right],
// built by recursive halving so the NTT multiplications stay balanced.
// Requires left <= right.
Polynomial divideAndConquer(int left, int right) {
    if(left == right) {
        // Single node: coefficients {1, b_j}.
        Polynomial leaf(2);
        leaf[0] = 1;
        leaf[1] = degree_count[left];
        return leaf;
    }
    const int middle = left + (right - left) / 2;
    Polynomial lower_half = divideAndConquer(left, middle);
    Polynomial upper_half = divideAndConquer(middle + 1, right);
    return lower_half * upper_half;
}

void solve() {
    int nodes = read_input();
    memset(degree_count, -1, sizeof(degree_count));
    degree_count[1] = 0;
    
    for(int i = 1; i < nodes; i++) {
        int u = read_input(), v = read_input();
        degree_count[u]++;
        degree_count[v]++;
    }
    
    Polynomial result = divideAndConquer(1, nodes);
    int answer = 0;
    
    for(int i = 0; i < min(nodes, (int)result.size()); i++) {
        long long term = ((i & 1) ? -1LL : 1LL) * fact[nodes - i] % MOD * result[i] % MOD;
        answer = (answer + term + MOD) % MOD;
    }
    
    print_output(answer);
}

Optimized Single Logarithmic Approach

A more efficient method uses the fact that the sum of all degrees (here, the children counts $b_j$) is $O(n)$. By grouping nodes with identical degrees and applying the binomial theorem, we avoid redundant computations.

Instead of processing each polynomial individually, we compute $(1 + jx)^{c_j}$ for each degree $j$ that appears $c_j$ times. By the binomial theorem its coefficients are $\binom{c_j}{k} j^k$. This requires fewer NTT operations and gives the same result.

// Histogram used by the grouped solvers: frequency[d] is the number of nodes
// with exactly d children.
int frequency[MAXN]{};

// Same answer as solve(), but nodes with the same child count d are grouped:
// their factors collapse to (1 + d x)^{frequency[d]}, expanded with the
// binomial theorem and multiplied into a running accumulator.
void optimized_solve() {
    int nodes = read_input();
    memset(degree_count, -1, sizeof(degree_count));
    degree_count[1] = 0;
    // frequency is a global histogram that is only ever incremented below.
    // Clear it so counts from an earlier call (another test case or another
    // solver) are not mixed into this tree.
    memset(frequency, 0, sizeof(frequency));
    
    for(int i = 1; i < nodes; i++) {
        int u = read_input(), v = read_input();
        degree_count[u]++;
        degree_count[v]++;
    }
    
    Polynomial accumulator = {1};
    
    for(int i = 1; i <= nodes; i++)
        frequency[degree_count[i]]++;
    
    // Degrees are processed from largest to smallest. The child counts sum to
    // nodes - 1, so fewer than nodes / d nodes have more than d children.
    // This keeps the accumulator short while the large-degree groups are
    // folded in. Degree 0 is skipped because (1 + 0x)^c == 1.
    for(int degree = nodes; degree >= 1; degree--) {
        if(!frequency[degree]) continue;
        
        // term[k] = C(frequency[degree], k) * degree^k
        Polynomial term(frequency[degree] + 1);
        int power_accum = 1;
        
        for(int coeff = 0; coeff <= frequency[degree]; coeff++) {
            term[coeff] = (1LL * power_accum * combination(frequency[degree], coeff)) % MOD;
            power_accum = (1LL * power_accum * degree) % MOD;
        }
        
        accumulator = accumulator * term;
    }
    
    // Inclusion-exclusion: sum_i (-1)^i * (nodes - i)! * accumulator[i].
    int answer = 0;
    for(int i = 0; i < min(nodes, (int)accumulator.size()); i++) {
        long long contribution = ((i & 1) ? -1LL : 1LL) * fact[nodes - i] % MOD * accumulator[i] % MOD;
        answer = (answer + contribution + MOD) % MOD;
    }
    
    print_output(answer);
}

Square Root Optimization Technique

Because the degrees sum to $O(n)$, only $O(\sqrt{n})$ distinct degrees exist. We can exploit this with a square-root-style strategy. Rather than multiplying polynomials sequentially, we evaluate each factor $(1 + jx)^{c_j}$ at the roots of unity using fast exponentiation and multiply the values pointwise.

This technique evaluates all the polynomials at the roots of unity simultaneously, then performs a single inverse transform to recover the coefficients. Its theoretical complexity is the same as brute force, but it offers significant constant-factor improvements in practice.

// Same answer as solve(), computed by evaluation/interpolation.
// The product prod_d (1 + d x)^{frequency[d]} is sampled at every power of a
// transform_size-th root of unity, using fast exponentiation per distinct
// degree. A single inverse NTT then recovers its coefficients.
void sqrt_optimized_solve() {
    int nodes = read_input();
    memset(degree_count, -1, sizeof(degree_count));
    degree_count[1] = 0;
    // frequency is a global histogram that is only ever incremented below;
    // clear it so counts from an earlier call do not leak into this tree.
    memset(frequency, 0, sizeof(frequency));
    
    for(int i = 1; i < nodes; i++) {
        int u = read_input(), v = read_input();
        degree_count[u]++;
        degree_count[v]++;
    }
    
    for(int i = 1; i <= nodes; i++)
        frequency[degree_count[i]]++;
    
    // Nodes with 0 children contribute the constant factor 1. The product
    // therefore has degree nodes - frequency[0] and needs that many + 1 samples.
    int transform_size = 1;
    while(transform_size < nodes - frequency[0] + 1)
        transform_size <<= 1;
    
    Polynomial evaluation_points(transform_size, 1);
    // Principal transform_size-th root of unity built from generator 3, the
    // same root numberTheoreticTransform uses for a transform of this length.
    // The exponent divides by the transform length itself, not by its log2.
    int root_of_unity = power(3, (MOD - 1) / transform_size);
    
    for(int degree = nodes; degree >= 1; degree--) {
        if(!frequency[degree]) continue;
        
        // Multiply (1 + degree * w^i)^{frequency[degree]} into every sample i.
        int generator = 1;
        for(int i = 0; i < transform_size; i++) {
            int base_value = (1LL * degree * generator + 1) % MOD;
            evaluation_points[i] = (1LL * evaluation_points[i] * power(base_value, frequency[degree])) % MOD;
            generator = (1LL * generator * root_of_unity) % MOD;
        }
    }
    
    // evaluation_points[i] == P(w^i). The inverse transform, including the
    // 1/transform_size scaling, maps these values back to coefficients.
    evaluation_points.numberTheoreticTransform(true);
    
    // Inclusion-exclusion: sum_i (-1)^i * (nodes - i)! * [x^i] P.
    int answer = 0;
    for(int i = 0; i < min(nodes, (int)evaluation_points.size()); i++) {
        long long term = ((i & 1) ? -1LL : 1LL) * fact[nodes - i] % MOD * evaluation_points[i] % MOD;
        answer = (answer + term + MOD) % MOD;
    }
    
    print_output(answer);
}

Tags: number-theoretic-transform polynomial-multiplication inclusion-exclusion tree-algorithms combinatorics

Posted on Thu, 14 May 2026 08:51:57 +0000 by Kitara