关于八数码问题的三种实现方式
#include<cstdio>
#include<cstring>
using namespace std;typedef int State[9];
const int maxstate = 1000000;
State st[maxstate], goal;
int dist[maxstate];const int dx[] = {1,-1,0,0};
const int dy[] = {0,0,-1,1};bool inside(int x, int y) {return 0 <= x && x < 3 && 0 <= y && y < 3;
}const int hashsize = 1000003;
int head[hashsize], next[maxstate];
void init_lookup_table() { memset(head,0,sizeof(head)); }
int hash(State & s) {int v = 0;for(int i = 0; i < 9; ++i) v = v*10 + s[i];return v % hashsize;
}
bool try_to_insert(int s) {int h = hash(st[s]);int u = head[h];while(u) {if(memcmp(st[u], st[s], sizeof(st[s])) == 0) return 0;u = next[u];}next[s] = head[h];head[h] = s; return 1;
}int bfs() {init_lookup_table();int front = 1, rear = 2;while(front < rear) {State & s = st[front];if(memcmp(goal,s,sizeof(s)) == 0) return front;int z;for(z = 0; z < 9; ++z) if(!s[z]) break;int x = z/3, y = z%3;for(int d = 0; d < 4; ++d) {int nx = x + dx[d];int ny = y + dy[d];int nz = nx*3 + ny;if(inside(nx,ny)) {State & t = st[rear];memcpy(&t,&s,sizeof(s));t[nz] = s[z];t[z] = s[nz];dist[rear] = dist[front] + 1;if(try_to_insert(rear)) ++rear; } }++front;}return 0;
}int main() {for(int i = 0; i < 9; ++i) scanf("%d",&st[1][i]);for(int i = 0; i < 9; ++i) scanf("%d",&goal[i]);int ans = bfs();if(ans > 0) printf("%d\n",dist[ans]);else printf("-1\n");return 0;
}
这段代码虽简短,但却有很多细节和小技巧。下面一一说明
1: 程序将八数码的状态单元int[9]直接用typedef替换为State,并定义状态队列st,最终目标goal,便于阅读和编码。
2: 在初始读入数据时,直接把初始状态读入队列的第一个'状态'中,将目标状态读入goal
3: bfs中采用内存控制函数memcpy赋值和memcmp判断,高效简洁。
4: 八数码属于隐式图搜索问题,因此要在普通bfs上增加一个"结点查找表"来判重。本程序采用自己设计的hash表来判重。下面还有两种实现方式,分别是"完美哈希"和STL中的set。
//完美hash
int vis[362880], fact[9];
void init_lookup_table() {fact[0] = 1;for(int i = 1; i < 9; ++i) fact[i] = fact[i-1] * i;
}
int try_to_insert(int s) {int code = 0;for(int i = 0; i < 9; ++i) {int cnt = 0;for(int j = i+1; j < 9; ++j) if(st[s][j] < st[s][i]) ++cnt;code += fact[8-i] * cnt;}if(vis[code]) return 0;return vis[code] = 1;
}
//STL
set<int> vis;
void init_lookup_table() { vis.clear(); }
int try_to_insert(int s) {int v = 0;for(int i = 0; i < 9; ++i) v = v*10 + st[s][i];if(vis.count(v)) return 0;vis.insert(v);return 1;
}