]> git.lizzy.rs Git - lua-star.git/blob - src/lua-star.lua
7c337b065528dcb32ee1e69d339bfbea7c72c68b
[lua-star.git] / src / lua-star.lua
1 --[[
2     Lua star example - Run with love (https://love2d.org/)
3     Copyright 2018 Wesley Werner <wesley.werner@gmail.com>
4
5     Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
7     The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
9     THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10
11     References:
12     https://en.wikipedia.org/wiki/A*_search_algorithm
13     https://www.redblobgames.com/pathfinding/a-star/introduction.html
14     https://www.raywenderlich.com/4946/introduction-to-a-pathfinding
15 ]]--
16
17 --- Provides easy A* path finding.
18 -- @module lua-star
19
20 local module = {}
21
22 --- Clears all cached paths.
23 function module:clearCached()
24     module.cache = nil
25 end
26
27 -- (Internal) Returns a unique key for the start and end points.
28 local function keyOf(start, goal)
29     return string.format("%d,%d>%d,%d", start.x, start.y, goal.x, goal.y)
30 end
31
32 -- (Internal) Returns the cached path for start and end points.
33 local function getCached(start, goal)
34     if module.cache then
35         local key = keyOf(start, goal)
36         return module.cache[key]
37     end
38 end
39
40 -- (Internal) Saves a path to the cache.
41 local function saveCached(start, goal, path)
42     module.cache = module.cache or { }
43     local key = keyOf(start, goal)
44     module.cache[key] = path
45 end
46
47 -- (Internal) Return the distance between two points.
48 -- This method doesn't bother getting the square root of s, it is faster
49 -- and it still works for our use.
50 local function distance(x1, y1, x2, y2)
51   local dx = x1 - x2
52   local dy = y1 - y2
53   local s = dx * dx + dy * dy
54   return s
55 end
56
57 -- (Internal) Clamp a value to a range.
58 local function clamp(x, min, max)
59   return x < min and min or (x > max and max or x)
60 end
61
62 -- (Internal) Return the score of a node.
63 -- G is the cost from START to this node.
64 -- H is a heuristic cost, in this case the distance from this node to the goal.
65 -- Returns F, the sum of G and H.
66 local function calculateScore(previous, node, goal)
67
68     local G = previous.score + 1
69     local H = distance(node.x, node.y, goal.x, goal.y)
70     return G + H, G, H
71
72 end
73
74 -- (Internal) Returns true if the given list contains the specified item.
75 local function listContains(list, item)
76     for _, test in ipairs(list) do
77         if test.x == item.x and test.y == item.y then
78             return true
79         end
80     end
81     return false
82 end
83
84 -- (Internal) Returns the item in the given list.
85 local function listItem(list, item)
86     for _, test in ipairs(list) do
87         if test.x == item.x and test.y == item.y then
88             return test
89         end
90     end
91 end
92
93 -- (Internal) Requests adjacent map values around the given node.
94 local function getAdjacent(width, height, node, positionIsOpenFunc, includeDiagonals)
95
96     local result = { }
97
98     local positions = {
99         { x = 0, y = -1 },  -- top
100         { x = -1, y = 0 },  -- left
101         { x = 0, y = 1 },   -- bottom
102         { x = 1, y = 0 },   -- right
103     }
104
105     if includeDiagonals then
106         local diagonalMovements = {
107             { x = -1, y = -1 },   -- top left
108             { x = 1, y = -1 },   -- top right
109             { x = -1, y = 1 },   -- bot left
110             { x = 1, y = 1 },   -- bot right
111         }
112
113         for _, value in ipairs(diagonalMovements) do
114             table.insert(positions, value)
115         end
116     end
117
118     for _, point in ipairs(positions) do
119         local px = clamp(node.x + point.x, 1, width)
120         local py = clamp(node.y + point.y, 1, height)
121         local value = positionIsOpenFunc( px, py )
122         if value then
123             table.insert( result, { x = px, y = py  } )
124         end
125     end
126
127     return result
128
129 end
130
131 -- Returns the path from start to goal, or false if no path exists.
132 function module:find(width, height, start, goal, positionIsOpenFunc, useCache, excludeDiagonalMoving)
133
134     if useCache then
135         local cachedPath = getCached(start, goal)
136         if cachedPath then
137             return cachedPath
138         end
139     end
140
141     local success = false
142     local open = { }
143     local closed = { }
144
145     start.score = 0
146     start.G = 0
147     start.H = distance(start.x, start.y, goal.x, goal.y)
148     start.parent = { x = 0, y = 0 }
149     table.insert(open, start)
150
151     while not success and #open > 0 do
152
153         -- sort by score: high to low
154         table.sort(open, function(a, b) return a.score > b.score end)
155
156         local current = table.remove(open)
157
158         table.insert(closed, current)
159
160         success = listContains(closed, goal)
161
162         if not success then
163
164             local adjacentList = getAdjacent(width, height, current, positionIsOpenFunc, not excludeDiagonalMoving)
165
166             for _, adjacent in ipairs(adjacentList) do
167
168                 if not listContains(closed, adjacent) then
169
170                     if not listContains(open, adjacent) then
171
172                         adjacent.score = calculateScore(current, adjacent, goal)
173                         adjacent.parent = current
174                         table.insert(open, adjacent)
175
176                     end
177
178                 end
179
180             end
181
182         end
183
184     end
185
186     if not success then
187         return false
188     end
189
190     -- traverse the parents from the last point to get the path
191     local node = listItem(closed, closed[#closed])
192     local path = { }
193
194     while node do
195
196         table.insert(path, 1, { x = node.x, y = node.y } )
197         node = listItem(closed, node.parent)
198
199     end
200
201     saveCached(start, goal, path)
202
203     -- reverse the closed list to get the solution
204     return path
205
206 end
207
208 return module