Minimum Spanning Tree Construction with Modular Arithmetic

Minimum Spanning Tree with Modular Edge Weights

Prim's Algorithm Adaptation

The classic Prim's algorithm builds a minimum spanning tree by starting from an initial vertex and repeatedly adding the minimum-weight edge connecting the tree to vertices outside it. For this problem, we adapt Prim's algorithm to handle edge weights computed as (a[i] + a[j]) % k.

The key optimization involves maintaining a sorted set of the vertex weights and using binary search to efficiently find the optimal connection. For each vertex u, we locate the vertex v with the smallest weight that is at least k - a[u]; the resulting edge costs a[u] + a[v] - k, which is always cheaper than any non-wrapping edge from u. If no such vertex exists, we select the vertex with the smallest weight overall, which minimizes the non-wrapping cost a[u] + a[v].

We maintain a priority queue of candidate edges and ensure we don't add edges that would create cycles by checking if vertices belong to different connected components.


#include <iostream>
#include <set>
#include <queue>
#include <vector>
#include <algorithm>
using namespace std;
using ll = long long;

// Builds a minimum spanning tree over the complete graph on `vertices`
// nodes where edge (i, j) costs (a[i] + a[j]) % mod, via a Prim-style
// sweep: an ordered set of the weights of vertices outside the tree lets
// us binary-search each tree vertex's cheapest outgoing edge.
// Reads one test case from cin and prints the total MST weight.
void buildMST() {
    int vertices, mod;
    cin >> vertices >> mod;

    // {weight % mod, vertex index} for every vertex not yet in the tree.
    set<array<int, 2>> weightSet;
    vector<int> weights(vertices + 1);
    for (int i = 1; i <= vertices; i++) {
        cin >> weights[i];
        weights[i] %= mod;
        weightSet.insert({weights[i], i});
    }

    // Cheapest remaining partner for a vertex of weight `val`: the
    // smallest weight w >= mod - val costs val + w - mod, which is < val
    // and therefore always beats the no-wrap alternative val + min-weight;
    // when no weight reaches the wrap-around point, fall back to the
    // overall minimum. Precondition: weightSet is non-empty.
    auto findOptimal = [&](int val) -> array<int, 2> {
        auto candidate = weightSet.lower_bound({mod - val, 0});
        return candidate == weightSet.end() ? *weightSet.begin() : *candidate;
    };

    // {edge cost, tree vertex, outside vertex}, cheapest first.
    priority_queue<array<int, 3>, vector<array<int, 3>>, greater<>> edgeQueue;

    auto [w1, start] = *weightSet.begin();
    weightSet.erase(weightSet.begin());

    ll total = 0;
    // Guard the single-vertex case: the tree is just the start vertex and
    // the answer is 0. (The original dereferenced begin() of an empty set.)
    if (!weightSet.empty()) {
        auto [w2, next] = findOptimal(w1);
        edgeQueue.push({(w1 + w2) % mod, start, next});
    }

    while (!edgeQueue.empty()) {
        auto [cost, u, v] = edgeQueue.top();
        edgeQueue.pop();

        if (!weightSet.count({weights[v], v})) {
            // v already joined the tree through a cheaper edge, so this
            // entry is stale. Re-offer u's current best partner so u keeps
            // a live candidate edge among the remaining vertices (the
            // original dropped u here, which can starve the queue of the
            // cheapest crossing edge).
            auto [w, x] = findOptimal(weights[u]);
            edgeQueue.push({(w + weights[u]) % mod, u, x});
            continue;
        }

        total += cost;
        weightSet.erase({weights[v], v});
        if (weightSet.empty()) break;

        // Both endpoints of the accepted edge offer their next-best partner.
        auto [w3, x] = findOptimal(weights[v]);
        edgeQueue.push({(w3 + weights[v]) % mod, v, x});
        auto [w4, y] = findOptimal(weights[u]);
        edgeQueue.push({(w4 + weights[u]) % mod, u, y});
    }

    cout << total << "\n";
}

// Entry point: fast I/O setup, then one MST computation per test case.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int caseCount;
    std::cin >> caseCount;
    for (int c = 0; c < caseCount; c++) {
        buildMST();
    }
    return 0;
}

Boruvka's Algorithm Approach

Boruvka's algorithm provides an alternative approach where we process the graph in phases, finding the minimum-weight edge from each connected component to other components in each phase.

After sorting vertex weights and computing modulo k, we identify a boundary point for each vertex where edge weights transition from a[i] + a[j] to a[i] + a[j] - k. This creates a staircase pattern in the adjacency matrix.

We maintain next pointers that track the optimal connection candidates and use a union-find data structure to manage connected components.


#include <algorithm>
#include <climits>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
using namespace std;
using ll = long long;

// Disjoint-set union with union by size and path compression.
struct UnionFind {
    vector<int> parent, size;

    // Creates n singleton sets, each element its own representative.
    UnionFind(int n) : parent(n), size(n, 1) {
        for (int v = 0; v < n; v++) {
            parent[v] = v;
        }
    }

    // Representative of x's set; compresses the path on the way up.
    int find(int x) {
        if (parent[x] == x) return x;
        return parent[x] = find(parent[x]);
    }

    // True when x and y belong to the same set.
    bool connected(int x, int y) {
        return find(x) == find(y);
    }

    // Merges the sets of x and y, hanging the smaller root under the
    // larger. Returns false when they were already in the same set.
    bool unite(int x, int y) {
        int rootX = find(x);
        int rootY = find(y);
        if (rootX == rootY) return false;
        if (size[rootX] < size[rootY]) {
            int tmp = rootX;
            rootX = rootY;
            rootY = tmp;
        }
        parent[rootY] = rootX;
        size[rootX] += size[rootY];
        return true;
    }
};

// Boruvka-phase MST for the complete graph on n vertices where edge
// (i, j) costs (arr[i] + arr[j]) % k. Reads one test case from cin and
// prints the total MST weight.
//
// After sorting the mod-k weights, each vertex i gets a boundary
// nextPtr[i]: the first sorted position j with arr[i] + arr[j] >= k,
// i.e. where the edge cost wraps to arr[i] + arr[j] - k. A vertex's
// cheapest partner is either that boundary vertex or the globally
// smallest weight (vertex 1 after sorting).
// NOTE(review): INT_MAX requires <climits>; make sure this program's
// include list provides it, directly or transitively.
void solveBoruvka() {
    int n, k;
    cin >> n >> k;
    
    vector<int> arr(n + 1);
    for (int i = 1; i <= n; i++) {
        cin >> arr[i];
        arr[i] %= k;  // reduce all weights mod k up front
    }
    
    // Sort weights ascending; vertex identity is irrelevant to the total.
    sort(arr.begin() + 1, arr.end());
    
    // Two-pointer sweep: as arr[i] grows, the wrap boundary only moves left.
    int boundary = n + 1;
    vector<int> nextPtr(n + 1);
    for (int i = 1; i <= n; i++) {
        while (boundary - 1 >= 1 && arr[boundary - 1] + arr[i] >= k) {
            boundary--;
        }
        nextPtr[i] = boundary;
    }
    
    ll result = 0;
    UnionFind uf(n + 1);
    
    // Boruvka phases: each phase proposes at most one candidate edge per
    // vertex, then unites components along the cheapest proposals.
    while (true) {
        vector<tuple<int, int, int>> edges;  // (cost, from, to)
        
        for (int i = 1; i <= n; i++) {
            int minCost = INT_MAX, target = -1, current = i;
            
            // Skip partners already in our component. nextPtr is never
            // reset, so across all phases each pointer advances at most
            // n times.
            while (nextPtr[current] <= n && uf.connected(current, nextPtr[current])) {
                nextPtr[current]++;
            }
            
            if (nextPtr[current] <= n && !uf.connected(current, nextPtr[current])) {
                // Wrapping partner: (arr[i] + arr[j]) % k == arr[i] + arr[j] - k.
                int cost = (arr[nextPtr[current]] + arr[current]) % k;
                if (cost < minCost) {
                    minCost = cost;
                    target = nextPtr[current];
                }
            } else if (nextPtr[current] > n && !uf.connected(current, 1)) {
                // Pointer ran off the end: fall back to vertex 1, the
                // smallest weight. If vertex 1 is already in this
                // component, the vertex proposes no edge this phase —
                // presumably another component's proposal still makes
                // progress; TODO confirm the phase count stays bounded.
                int cost = (arr[1] + arr[current]) % k;
                if (cost < minCost) {
                    minCost = cost;
                    target = 1;
                }
            }
            
            if (target != -1) {
                edges.emplace_back(minCost, current, target);
            }
        }
        
        // Apply proposals cheapest-first; unite() rejects edges inside a
        // component, so each accepted edge joins two distinct components.
        sort(edges.begin(), edges.end());
        for (auto &[cost, u, v] : edges) {
            if (uf.unite(u, v)) {
                result += cost;
            }
        }
        
        // Done once the component containing vertex 1 spans all n vertices.
        if (uf.size[uf.find(1)] == n) break;
    }
    
    cout << result << "\n";
}

// Entry point: reads the number of test cases and runs the Boruvka
// solver once per case.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int caseCount;
    std::cin >> caseCount;
    for (int i = 0; i < caseCount; i++) {
        solveBoruvka();
    }
    return 0;
}

Counting Pairs with Square Root Decomposition

Frequency-Based Optimization

For counting pairs of elements in sequences, we employ a square root decomposition strategy. We classify elements based on their frequency, using a threshold B ≈ √(n²/q).

For elements with frequency below B, we perform brute-force enumeration combined with binary search. For pairs where both elements have high frequency, we use preprocessing to optimize queries.


#include <iostream>
#include <vector>
#include <map>
#include <algorithm>
using namespace std;
using ll = long long;

// Counts pairs (i, j) with i a position of value x, j a position of
// value y, and i < j, given the ascending position lists of x and y
// (x != y, so the lists are disjoint). Iterates the shorter list and
// binary-searches the longer one: O(min * log(max)).
static ll countCrossPairs(const vector<int>& xs, const vector<int>& ys) {
    ll pairs = 0;
    if (xs.size() <= ys.size()) {
        // For every position of x, count positions of y to its right.
        for (int pos : xs) {
            pairs += ys.end() - lower_bound(ys.begin(), ys.end(), pos);
        }
    } else {
        // For every position of y, count positions of x to its left.
        for (int pos : ys) {
            pairs += lower_bound(xs.begin(), xs.end(), pos) - xs.begin();
        }
    }
    return pairs;
}

// Reads a sequence of n values and q queries (x, y); for each query
// prints the number of index pairs i < j with a[i] == x and a[j] == y.
// Repeated queries are answered from a cache.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, q;
    cin >> n >> q;

    vector<int> sequence(n + 1);
    map<int, vector<int>> positions;  // value -> ascending positions
    for (int i = 1; i <= n; i++) {
        cin >> sequence[i];
        positions[sequence[i]].push_back(i);
    }

    // Memoized answers per (x, y). The original computed a frequency
    // threshold sqrt(n*n/q) here, which divided by zero when q == 0 and
    // used sqrt without including <cmath>; the threshold only chose
    // which position list to iterate, and both choices yield identical
    // counts, so iterating the shorter list preserves every answer while
    // removing both defects.
    map<array<int, 2>, ll> cachedResults;

    while (q-- > 0) {
        int x, y;
        cin >> x >> y;

        auto cached = cachedResults.find({x, y});
        if (cached != cachedResults.end()) {
            cout << cached->second << "\n";
            continue;
        }

        ll count;
        if (x == y) {
            // Unordered pairs among the occurrences of x: C(freq, 2).
            ll freq = positions[x].size();
            count = freq * (freq - 1) / 2;
        } else {
            count = countCrossPairs(positions[x], positions[y]);
        }

        cachedResults[{x, y}] = count;
        cout << count << "\n";
    }

    return 0;
}

Offline Processing with Discretization

This approach involves preprocessing elements with frequency above √n and handling low-frequency pairs directly. We discretize values and precompute prefix sums for efficient range queries.


#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;
using ll = long long;

// Offline pair counting with sqrt decomposition. Reads a sequence of n
// values and q queries (x, y); for each query prints the number of index
// pairs i < j with a[i] == x and a[j] == y.
//
// Values occurring >= sqrt(n) times are "heavy". Heavy combinations are
// precomputed with one linear sweep per heavy value; light/light pairs
// are answered directly by merging the two position lists.
void processQueries() {
    int n, q;
    cin >> n >> q;
    int blockSize = sqrt(n);  // heavy/light frequency threshold
    
    vector<int> arr(n + 1);
    vector<int> allValues;  // sequence values + query values, for discretization
    
    for (int i = 1; i <= n; i++) {
        cin >> arr[i];
        allValues.push_back(arr[i]);
    }
    
    vector<array<int, 2>> queries;
    for (int i = 0; i < q; i++) {
        int x, y;
        cin >> x >> y;
        queries.push_back({x, y});
        allValues.push_back(x);
        allValues.push_back(y);
    }
    
    // Coordinate-compress every value that can appear in a query.
    sort(allValues.begin(), allValues.end());
    allValues.erase(unique(allValues.begin(), allValues.end()), allValues.end());
    
    int totalValues = allValues.size();
    vector<int> frequency(totalValues + 3);               // occurrences per compressed value
    vector<vector<int>> valuePositions(totalValues + 3);  // ascending positions per value
    
    for (int i = 1; i <= n; i++) {
        int compressed = lower_bound(allValues.begin(), allValues.end(), arr[i]) - allValues.begin() + 1;
        arr[i] = compressed;
        valuePositions[compressed].push_back(i);
        frequency[compressed]++;
    }
    
    // forwardQueries[x] lists every y queried as (x, y);
    // backwardQueries[y] lists every x queried as (x, y).
    vector<vector<int>> forwardQueries(totalValues + 3);
    vector<vector<int>> backwardQueries(totalValues + 3);
    
    for (auto [x, y] : queries) {
        int compX = lower_bound(allValues.begin(), allValues.end(), x) - allValues.begin() + 1;
        int compY = lower_bound(allValues.begin(), allValues.end(), y) - allValues.begin() + 1;
        forwardQueries[compX].push_back(compY);
        backwardQueries[compY].push_back(compX);
    }
    
    vector<vector<ll>> forwardResults(totalValues + 3);
    vector<vector<ll>> backwardResults(totalValues + 3);
    
    // Sort and deduplicate the partner lists so answers can be located by
    // binary search later; results are stored positionally alongside.
    for (int i = 1; i <= totalValues; i++) {
        sort(forwardQueries[i].begin(), forwardQueries[i].end());
        forwardQueries[i].erase(unique(forwardQueries[i].begin(), forwardQueries[i].end()), forwardQueries[i].end());
        forwardResults[i].resize(forwardQueries[i].size());
        
        sort(backwardQueries[i].begin(), backwardQueries[i].end());
        backwardQueries[i].erase(unique(backwardQueries[i].begin(), backwardQueries[i].end()), backwardQueries[i].end());
        backwardResults[i].resize(backwardQueries[i].size());
    }
    
    // Scratch lookup: compressed value -> slot in the current result row.
    vector<int> indexMap(totalValues + 3, -1);
    
    // Forward sweep: for each heavy value i, prefixSum counts occurrences
    // of i seen so far, so each occurrence of a *light* partner value
    // accumulates the number of i-positions before it. This answers
    // queries (x heavy, y light); heavy partners are skipped here and
    // handled by the backward sweep instead.
    for (int i = 1; i <= totalValues; i++) {
        if (forwardQueries[i].empty()) continue;
        if (frequency[i] >= blockSize) {
            int idx = 0;
            for (auto target : forwardQueries[i]) {
                indexMap[target] = idx++;
            }
            
            ll prefixSum = 0;
            for (int j = 1; j <= n; j++) {
                int current = arr[j];
                if (indexMap[current] != -1 && frequency[current] < blockSize) {
                    forwardResults[i][indexMap[current]] += prefixSum;
                }
                if (current == i) prefixSum++;
            }
            
            for (auto target : forwardQueries[i]) {
                indexMap[target] = -1;  // reset scratch slots for the next value
            }
        }
    }
    
    // Backward sweep: for each heavy value i (appearing as the *second*
    // query element y), suffixSum counts occurrences of i after position
    // j, so each occurrence of any queried partner x accumulates the
    // number of y-positions after it. This answers every query whose y is
    // heavy, including x == y (the add-before-increment order makes that
    // case sum to C(freq, 2)).
    for (int i = 1; i <= totalValues; i++) {
        if (backwardQueries[i].empty()) continue;
        if (frequency[i] >= blockSize) {
            int idx = 0;
            for (auto source : backwardQueries[i]) {
                indexMap[source] = idx++;
            }
            
            ll suffixSum = 0;
            for (int j = n; j >= 1; j--) {
                int current = arr[j];
                if (indexMap[current] != -1) {
                    backwardResults[i][indexMap[current]] += suffixSum;
                }
                if (current == i) suffixSum++;
            }
            
            for (auto source : backwardQueries[i]) {
                indexMap[source] = -1;
            }
        }
    }
    
    // Answer each query by routing it to whichever precomputation (or
    // direct scan) covers its heavy/light combination.
    for (auto [x, y] : queries) {
        int compX = lower_bound(allValues.begin(), allValues.end(), x) - allValues.begin() + 1;
        int compY = lower_bound(allValues.begin(), allValues.end(), y) - allValues.begin() + 1;
        
        if (frequency[compX] >= blockSize && frequency[compY] < blockSize) {
            // x heavy, y light: filled by the forward sweep.
            int pos = lower_bound(forwardQueries[compX].begin(), forwardQueries[compX].end(), compY) - forwardQueries[compX].begin();
            cout << forwardResults[compX][pos] << "\n";
        } else if (frequency[compY] >= blockSize) {
            // y heavy (x of either weight): filled by the backward sweep.
            int pos = lower_bound(backwardQueries[compY].begin(), backwardQueries[compY].end(), compX) - backwardQueries[compY].begin();
            cout << backwardResults[compY][pos] << "\n";
        } else {
            // Both light: merge the two ascending position lists. `left`
            // counts x-positions strictly before the current y-position
            // (positions are distinct when x != y; for x == y this sums
            // 0 + 1 + ... + freq-1 = C(freq, 2)).
            ll total = 0;
            int left = 0;
            for (int i = 0; i < valuePositions[compY].size(); i++) {
                int position = valuePositions[compY][i];
                while (left < valuePositions[compX].size() && valuePositions[compX][left] < position) {
                    left++;
                }
                total += left;
            }
            cout << total << "\n";
        }
    }
}

// Entry point: fast I/O setup, then a single run of the query processor.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    processQueries();
    return 0;
}

Tags: minimum-spanning-tree prim-algorithm boruvka-algorithm modular-arithmetic square-root-decomposition

Posted on Sat, 16 May 2026 04:00:26 +0000 by dbchip2000