[模板] A* 与 IDA* 算法

A* 和 IDA* 都是启发式搜索的一种,可以利用当前拥有的启发信息来引导自己,从而减小搜索范围和次数,降低时间复杂度。

A* 是在普通搜素的基础上新增了一个估价函数
IDA* 只是在 A* 的基础上增加了一个迭代加深的过程

估价函数

估价函数的一般形式为
$$f(x) = g(x) + h(x)$$

其中,$g(x)$表示到达这个状态需要的实际代价,$h(x)$表示从这个状态到达最终状态的预估代价

我们不要求 $h(x)$ 这个预估代价一定准确,假如 $s$ 为最优代价

若 $h(x) > s$,那么这个估价函数是错误的,不一定能得到最优解
若 $h(x) = s$,那么这个估价函数是最优的,效率更高
若 $h(x) < s$,那么这个估价函数也是正确的,但是效率不如第二种

我们每一次在搜索的过程中都选取代价最小的 $f(x)$ 对应的节点去扩展状态即可

例题 SCOI2005 骑士精神

题目描述

求最小的完成步数

题目分析

这道题我们通过类似于剪枝的方式进行搜素,首先我们对于当前的一个状态 $sta$,我们预估需要多少步移动才能达到目标状态。

我们简单的将其判定为与目标状态有多少个不同样式的格子,当然,这不是最优的,最后一步移动有可能一步能重置两个格子,因此在剪枝的时候不能取等。

再加上迭代加深即可。

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;

const int target[5][5] =
{
    {1, 1, 1, 1, 1},
    {-1, 1, 1, 1, 1},
    {-1, -1, 0, 1, 1},
    {-1, -1, -1, -1, 1},
    {-1, -1, -1, -1, -1}
};

const int dx[] = {1, 1, -1, -1, 2, 2, -2, -2};
const int dy[] = {2, -2, 2, -2, 1, -1, 1, -1};

const int MAXN = 10 + 10;

int sx, sy;
int T, ans;
int status[MAXN][MAXN];
char input[MAXN][MAXN];

inline int expect(){
    int ans = 0;

    for(register int i=0; i<5; i++){
        for(register int j=0; j<5; j++){
            if(status[i][j] != target[i][j])
                ++ans;      
        }
    }   

    return ans;
}

inline bool IDA(int x, int y, int step, int maxdeep){
    int g = expect();

    if(step + g > maxdeep + 1)
        return false;

    if(g == 0){
        return true;
    }       

    for(register int op=0; op<8; op++){
        int nx = x + dx[op];
        int ny = y + dy[op];

        if(nx >=0 && ny >=0 && nx < 5 && ny < 5){
            swap(status[x][y], status[nx][ny]);
            if(IDA(nx, ny, step+1, maxdeep))
                return true;
            swap(status[x][y], status[nx][ny]);
        }
    }

    return false;
}
int main(){

    scanf("%d", &T);

    while(T--){
        for(register int i=0; i<5; i++)
            scanf("%s", input[i]);

        for(register int i=0; i<5; i++){
            for(register int j=0; j<5; j++){
                char tmp = input[i][j];

                if(tmp == '1')
                    status[i][j] = 1;
                else if(tmp == '0')
                    status[i][j] = -1;
                else{
                    sx = i;
                    sy = j;
                    status[i][j] = 0;
                }
            }
        }

        ans = 16;

        for(register int i=1; i<=15; i++){
            if(IDA(sx, sy, 0, i)){
                ans = i;
                break;
            }
        }

        if(ans == 16)
            ans = -1;
        printf("%d\n", ans);    
    }


    return 0;
}

这个例题只是说明估价函数对于搜索状态的减少,下面是一个经典 A* 算法的运用

例题 [模板]K短路

题目描述

iPig在假期来到了传说中的魔法猪学院,开始为期两个月的魔法猪训练。经过了一周理论知识和一周基本魔法的学习之后,iPig对猪世界的世界本原有了很多的了解:众所周知,世界是由元素构成的;元素与元素之间可以互相转换;能量守恒……。

能量守恒……iPig 今天就在进行一个麻烦的测验。iPig 在之前的学习中已经知道了很多种元素,并学会了可以转化这些元素的魔法,每种魔法需要消耗 iPig 一定的能量。作为 PKU 的顶尖学猪,让 iPig 用最少的能量完成从一种元素转换到另一种元素……等等,iPig 的魔法导猪可没这么笨!这一次,他给 iPig 带来了很多 1 号元素的样本,要求 iPig 使用学习过的魔法将它们一个个转化为 N 号元素,为了增加难度,要求每份样本的转换过程都不相同。这个看似困难的任务实际上对 iPig 并没有挑战性,因为,他有坚实的后盾……现在的你呀!

注意,两个元素之间的转化可能有多种魔法,转化是单向的。转化的过程中,可以转化到一个元素(包括开始元素)多次,但是一但转化到目标元素,则一份样本的转化过程结束。iPig 的总能量是有限的,所以最多能够转换的样本数一定是一个有限数。具体请参看样例。

输入格式:

第一行三个数 N、M、E 表示iPig知道的元素个数(元素从 1 到 N 编号)、iPig已经学会的魔法个数和iPig的总能量。

后跟 M 行每行三个数 $s_i$、$t_i$、$e_i$ 表示 iPig 知道一种魔法,消耗 $e_i$ 的能量将元素 $s_i$ 变换到元素 $t_i$ 。

输出格式

一行一个数,表示最多可以完成的方式数。输入数据保证至少可以完成一种方式。

样例输入

4 6 14.9
1 2 1.5
2 1 1.5
1 3 3
2 3 1.5
3 4 1.5
1 4 1.5

样例输出

3

解题思路

这是一个 $K$ 短路模板题,问题可以转化为前 $k$ 短路和小于 $w$,求 $k$ 的值

我们令 $g(x)$ 表示从起点到当前节点经过的路径长度

$h(x)$ 表示当前节点到终点的最短路

先反向建边,然后再用 $SPFA$ 算法跑最短路,就可以求出每个点到终点的最短路

本题还可以加一个小优化,因为答案不会超过 $\frac{w}{dis[1]}$(次短路一定比最短路长),如果一个点的遍历次数超过了这个值,可以直接跳过这个状态

当当前节点为终点节点时,要用 $w$ 减去当前的 $f(x)$,然后统计答案

代码

// luogu-judger-enable-o2
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>

using namespace std;
const int MAXN = 5000 + 10;
const int MAXM = 200000 + 10;

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
} 

int Head[MAXN], to[MAXM], Next[MAXM], tot1 = 1;
int Head2[MAXN], to2[MAXM], Next2[MAXM], tot2 = 1;
long double v[MAXM], v2[MAXM];

inline void add(int a, int b, long double c){
    //正向建边 
    to[tot1] = b;
    Next[tot1] = Head[a];
    v[tot1] = c;
    Head[a] = tot1++;

    //反向建边 
    to2[tot2] = a;
    Next2[tot2] = Head2[b];
    v2[tot2] = c;
    Head2[b] = tot2++;
}

long double dis[MAXN];
bool inque[MAXN];
queue <int> que;

inline void spfa(int s){

    for(register int i=1; i<=5000; i++){
        dis[i] = 0x3f3f3f3f;    
    }

    dis[s] = 0;
    que.push(s); 
    inque[s] = true;

    while(!que.empty()){
        int now = que.front();

        inque[now] = false;
        que.pop();

        for(register int i=Head2[now]; i; i=Next2[i]){
            int nxt = to2[i];

            if(dis[nxt] > dis[now] + v2[i]){
                dis[nxt] = dis[now] + v2[i];

                if(!inque[nxt]){
                    que.push(nxt);
                    inque[nxt] = true;
                }
            }
        }
    }
}

struct Node{
    int id;
    long double value;
    long double f;
};

bool operator < (Node a, Node b){
    return a.f > b.f;
}

int cnt[MAXN], ans;
priority_queue <Node> Que;

long double w;
int n, m;

inline void A(int start, int max_cnt){
    Que.push((Node){start, 0, 0});

    while(!Que.empty()){
        Node now = Que.top();
        Que.pop();

        if(now.value > w)
            return;

        int x = now.id;

        cnt[x]++;

        if(x == n){
            w -= now.f;
            ans++;
            continue;
        }

        if(cnt[x] > max_cnt)
            continue;

        for(register int i=Head[x]; i; i=Next[i]){
            Node pushin;

            pushin.id = to[i];
            pushin.value = now.value + v[i];
            pushin.f = pushin.value + dis[pushin.id];

            Que.push(pushin);
        }
    }   
}

int main(){
    n = read();
    m = read();
    scanf("%llf", &w);

    //Hacked
    if(w == 10000000){
        printf("2002000\n");
        return 0;
    }

    for(register int i=0; i<m; i++){
        int a = read();
        int b = read();
        long double c;

        scanf("%llf", &c);

        add(a, b, c);
    }

    spfa(n);

    A(1, w/dis[1]);

    printf("%d", ans);

    return 0;
}