[模板] 最短路计数

考虑这样一个问题

给出一个 $n$ 个点 $m$ 条边的无向带正权图,询问从顶点 $1$ 开始,到其他每个点的最短路有多少条

思路

我们知道不带负权的最短路可以使用 Dijkistra 算法求解,考虑如何在 Dijkistra 过程中如何计算最短路条数

我们用 $num[x]$ 表示到达第 $x$ 个节点的最短路条数

在 Dijkistra 过程中,每一个节点只会被访问一次。在通过点 $x$ 对点 $v$ 松弛的过程中,如果

$$dis[v] > dis[x] + w $$

说明点 $v$ 的当前最短路必然通过点 $x$,那么

$$num[x] = num[v] $$

并且进行正常的松弛操作

如果

$$dis[v] = dis[x] + w $$

说明点 $v$ 的最短路也可以通过点 $x$,那么

$$num[x] \leftarrow num[x] + num[v] $$

不需要进行松弛操作

最后依次输出 $num$ 即可

Luogu-1144 最短路计数

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#define P pair<int, int>

using namespace std;
const int MAXN = 1000000 + 10;
const int MAXM = 4000000 + 10;
const int MOD = 100003;

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
}

int Head[MAXN], to[MAXM], Next[MAXM], w[MAXM], tot = 1;
int dis[MAXN], num[MAXN];
int n, m;
bool vis[MAXN];

priority_queue<P, vector<P>, greater<P> >que;

inline void dij(int start){
    memset(dis, 0x3f, sizeof(dis));

    dis[start] = 0;
    num[start] = 1;
    que.push(make_pair(0, start));

    while(!que.empty()){
        P front = que.top();
        que.pop();

        if(vis[front.second])
            continue;

        int x = front.second;
        vis[x] = true;

        for(register int i=Head[x]; i; i=Next[i]){
            int v = to[i];

            if(dis[v] > dis[x] + w[i]){
                dis[v] = dis[x] + w[i];
                num[v] = num[x];

                que.push(make_pair(dis[v], v));
            }else if(dis[v] == dis[x] + w[i]){
                num[v] += num[x];
                num[v] %= MOD;
            }
        }
    }
}

inline void add(int a, int b, int c){
    to[tot] = b;
    w[tot] = c;
    Next[tot] = Head[a];
    Head[a] = tot++;
}

int main(){
    n = read();
    m = read();

    for(register int i=1; i<=m; i++){
        int a = read();
        int b = read();

        add(a, b, 1);
        add(b, a, 1);
    }

    dij(1);

    for(register int i=1; i<=n; i++)
        printf("%d\n", num[i]);
    return 0;
}