]> git.lizzy.rs Git - lua-star.git/blob - src/lua-star.lua
Add lua star, example and documentation.
[lua-star.git] / src / lua-star.lua
1 --[[
2    lua-star.lua
3
4    Copyright 2017 wesley werner <wesley.werner@gmail.com>
5
6    This program is free software: you can redistribute it and/or modify
7    it under the terms of the GNU General Public License as published by
8    the Free Software Foundation, either version 3 of the License, or
9    any later version.
10
11    This program is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14    GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program. If not, see http://www.gnu.org/licenses/.
18
19    References:
20    https://en.wikipedia.org/wiki/A*_search_algorithm
21    https://www.redblobgames.com/pathfinding/a-star/introduction.html
22    https://www.raywenderlich.com/4946/introduction-to-a-pathfinding
23 ]]--
24
25 --- Provides easy A* path finding.
26 -- @module lua-star
27
28 local module = {}
29
30 --- Clears all cached paths.
31 function module:clearCached()
32     module.cache = nil
33 end
34
35 -- (Internal) Returns a unique key for the start and end points.
36 local function keyOf(start, goal)
37     return string.format("%d,%d>%d,%d", start.x, start.y, goal.x, goal.y)
38 end
39
40 -- (Internal) Returns the cached path for start and end points.
41 local function getCached(start, goal)
42     if module.cache then
43         local key = keyOf(start, goal)
44         return module.cache[key]
45     end
46 end
47
48 -- (Internal) Saves a path to the cache.
49 local function saveCached(start, goal, path)
50     module.cache = module.cache or { }
51     local key = keyOf(start, goal)
52     module.cache[key] = path
53 end
54
55 -- (Internal) Return the distance between two points.
56 -- This method doesn't bother getting the square root of s, it is faster
57 -- and it still works for our use.
58 local function distance(x1, y1, x2, y2)
59   local dx = x1 - x2
60   local dy = y1 - y2
61   local s = dx * dx + dy * dy
62   return s
63 end
64
65 -- (Internal) Clamp a value to a range.
66 local function clamp(x, min, max)
67   return x < min and min or (x > max and max or x)
68 end
69
70 -- (Internal) Return the score of a node.
71 -- G is the cost from START to this node.
72 -- H is a heuristic cost, in this case the distance from this node to the goal.
73 -- Returns F, the sum of G and H.
74 local function calculateScore(previous, node, goal)
75
76     local G = previous.score + 1
77     local H = distance(node.x, node.y, goal.x, goal.y)
78     return G + H, G, H
79
80 end
81
82 -- (Internal) Returns true if the given list contains the specified item.
83 local function listContains(list, item)
84     for _, test in ipairs(list) do
85         if test.x == item.x and test.y == item.y then
86             return true
87         end
88     end
89     return false
90 end
91
92 -- (Internal) Returns the item in the given list.
93 local function listItem(list, item)
94     for _, test in ipairs(list) do
95         if test.x == item.x and test.y == item.y then
96             return test
97         end
98     end
99 end
100
101 -- (Internal) Requests adjacent map values around the given node.
102 local function getAdjacent(width, height, node, positionIsOpenFunc)
103
104     local result = { }
105
106     local positions = {
107         { x = 0, y = -1 },  -- top
108         { x = -1, y = 0 },  -- left
109         { x = 0, y = 1 },   -- bottom
110         { x = 1, y = 0 },   -- right
111         -- include diagonal movements
112         { x = -1, y = -1 },   -- top left
113         { x = 1, y = -1 },   -- top right
114         { x = -1, y = 1 },   -- bot left
115         { x = 1, y = 1 },   -- bot right
116     }
117
118     for _, point in ipairs(positions) do
119         local px = clamp(node.x + point.x, 1, width)
120         local py = clamp(node.y + point.y, 1, height)
121         local value = positionIsOpenFunc( px, py )
122         if value then
123             table.insert( result, { x = px, y = py  } )
124         end
125     end
126
127     return result
128
129 end
130
131 -- Returns the path from start to goal, or false if no path exists.
132 function module:find(width, height, start, goal, positionIsOpenFunc, useCache)
133
134     if useCache then
135         local cachedPath = getCached(start, goal)
136         if cachedPath then
137             return cachedPath
138         end
139     end
140
141     local success = false
142     local open = { }
143     local closed = { }
144
145     start.score = 0
146     start.G = 0
147     start.H = distance(start.x, start.y, goal.x, goal.y)
148     start.parent = { x = 0, y = 0 }
149     table.insert(open, start)
150
151     while not success and #open > 0 do
152
153         -- sort by score: high to low
154         table.sort(open, function(a, b) return a.score > b.score end)
155
156         local current = table.remove(open)
157
158         table.insert(closed, current)
159
160         success = listContains(closed, goal)
161
162         if not success then
163
164             local adjacentList = getAdjacent(width, height, current, positionIsOpenFunc)
165
166             for _, adjacent in ipairs(adjacentList) do
167
168                 if not listContains(closed, adjacent) then
169
170                     if not listContains(open, adjacent) then
171
172                         adjacent.score = calculateScore(current, adjacent, goal)
173                         adjacent.parent = current
174                         table.insert(open, adjacent)
175
176                     end
177
178                 end
179
180             end
181
182         end
183
184     end
185
186     if not success then
187         return false
188     end
189
190     -- traverse the parents from the last point to get the path
191     local node = listItem(closed, closed[#closed])
192     local path = { }
193
194     while node do
195
196         table.insert(path, 1, { x = node.x, y = node.y } )
197         node = listItem(closed, node.parent)
198
199     end
200
201     saveCached(start, goal, path)
202
203     -- reverse the closed list to get the solution
204     return path
205
206 end
207
208 return module