1--SPDX-License-Identifier: MIT
2--[[
3--*****************************************************************************
4--* Copyright (C) 1994-2016 Lua.org, PUC-Rio.
5--*
6--* Permission is hereby granted, free of charge, to any person obtaining
7--* a copy of this software and associated documentation files (the
8--* "Software"), to deal in the Software without restriction, including
9--* without limitation the rights to use, copy, modify, merge, publish,
10--* distribute, sublicense, and/or sell copies of the Software, and to
11--* permit persons to whom the Software is furnished to do so, subject to
12--* the following conditions:
13--*
14--* The above copyright notice and this permission notice shall be
15--* included in all copies or substantial portions of the Software.
16--*
17--* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18--* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19--* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
20--* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
21--* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
22--* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
23--* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24--*****************************************************************************
25--]]
26
-- forward declaration: 'f' is assigned further down, and functions defined
-- before that assignment (e.g. 'foo') refer to this local
local f

-- the main thread is reported as the main thread and cannot be resumed
local main, is_main = coroutine.running()
assert(type(main) == "thread")
assert(is_main)
assert(coroutine.resume(main) == false)
32
33
-- tests for multiple yield/resume arguments

-- Assert that sequences t1 and t2 have the same length and equal elements.
local function eqtab (t1, t2)
  assert(#t1 == #t2)
  for idx = 1, #t1 do
    local want = t1[idx]
    assert(t2[idx] == want)
  end
end

_G.x = nil   -- declare x

-- Coroutine body. 'a' is returned (unpacked) when the body finishes; every
-- extra argument is a table whose elements are yielded, one table per
-- yield, and the values passed to the following resume are stored in the
-- global 'x'.
function foo (a, ...)
  local co, ismain = coroutine.running()
  assert(co == f and ismain == false)
  -- resuming the running coroutine must fail, and must not corrupt it
  assert(coroutine.resume(f) == false)
  assert(coroutine.status(f) == "running")
  local batches = {...}
  for idx = 1, #batches do
    _G.x = {coroutine.yield(table.unpack(batches[idx]))}
  end
  return table.unpack(a)
end

f = coroutine.create(foo)
assert(type(f) == "thread" and coroutine.status(f) == "suspended")
assert(string.find(tostring(f), "thread"))
local s, a, b, c, d

-- start: 'a' = {1,2,3}; the first batch ({}) yields no values
s, a, b, c, d = coroutine.resume(f, {1,2,3}, {}, {1}, {'a', 'b', 'c'})
assert(s and a == nil and coroutine.status(f) == "suspended")

-- resumed with no values: x == {}; the next batch yields 1
s, a, b, c, d = coroutine.resume(f)
eqtab(_G.x, {})
assert(s and a == 1 and b == nil)

-- resumed with three values: x == {1,2,3}; the last batch yields 'a','b','c'
s, a, b, c, d = coroutine.resume(f, 1, 2, 3)
eqtab(_G.x, {1, 2, 3})
assert(s and a == 'a' and b == 'b' and c == 'c' and d == nil)

-- batches exhausted: foo returns the unpacked 'a' and the coroutine dies
s, a, b, c, d = coroutine.resume(f, "xuxu")
eqtab(_G.x, {"xuxu"})
assert(s and a == 1 and b == 2 and c == 3 and d == nil)
assert(coroutine.status(f) == "dead")

-- resuming a dead coroutine fails with a "dead" message
s, a = coroutine.resume(f, "xuxu")
assert(not s and string.find(a, "dead") and coroutine.status(f) == "dead")
77
78
-- yields in tail calls: 'foo' yields through a tail call, so the first
-- value passed to the next resume becomes foo's result
local function foo (i)
  return coroutine.yield(i)
end

f = coroutine.wrap(function ()
  for k = 1, 10 do
    local got = foo(k)
    assert(got == _G.x)
  end
  return 'a'
end)

for k = 1, 10 do
  _G.x = k
  assert(f(k) == k)
end
_G.x = 'xuxu'
assert(f('xuxu') == 'a')
89
-- recursive: each activation of 'pf' yields n and then recurses with
-- (n*i, i+1), so starting from (1, 1) the k-th value produced is (k-1)!
function pf (n, i)
  coroutine.yield(n)
  pf(n*i, i+1)
end

f = coroutine.wrap(pf)
local expected = 1
for k = 1, 10 do
  assert(f(1, 1) == expected)
  expected = expected * k
end
102
-- sieve implemented with co-routines: every prime found adds one more
-- filtering coroutine in front of the number generator

-- generate all the numbers from 2 to n
function gen (n)
  return coroutine.wrap(function ()
    for i = 2, n do coroutine.yield(i) end
  end)
end

-- filter the numbers generated by 'g', removing multiples of 'p'
function filter (p, g)
  return coroutine.wrap(function ()
    for n in g do
      if n % p ~= 0 then coroutine.yield(n) end
    end
  end)
end

-- generate primes up to 20
local x = gen(20)
local a = {}
while true do
  local n = x()
  if n == nil then break end   -- generator chain exhausted
  a[#a + 1] = n
  x = filter(n, x)
end

-- expect 8 primes and last one is 19
assert(#a == 8 and a[#a] == 19)
-- release the generator chain explicitly (both names are reused later)
x, a = nil, nil
134
135
-- basic coroutine.wrap round trip: the first call returns the yielded
-- value (20), the second call returns the body's result (30).
-- NOTE(review): this section was titled "yielding across C boundaries",
-- but no C function sits between the yield and the resume here; if that
-- coverage is wanted it needs a yield through e.g. pcall — confirm.
co = coroutine.wrap(function()
       coroutine.yield(20)
       return 30
     end)

assert(co() == 20)
assert(co() == 30)


-- NOTE(review): 'f' and 'g' are not referenced anywhere else in this file;
-- they look like leftovers from a removed test. 'f' used to be bound to a
-- closure that was immediately overwritten by this definition (a dead
-- store), so only this version was ever observable. 'g' stays a global.
-- f: yields 'a'; the value it is resumed with plus 'b' is raised as {sum}
local function f (a, b) a = coroutine.yield(a);  error{a + b} end
-- g: doubles the first element of table 'x'
function g(x) return x[1]*2 end
150
151
-- unyieldable C call: a Lua callback is invoked by a C function
-- (string.gsub) while running inside a coroutine
do
  local function double (ch)
    return ch .. ch
  end

  local runner = coroutine.wrap(function (c)
    return (string.gsub("a", ".", double))
  end)
  assert(runner() == "aa")
end
164
165
-- errors in coroutines
--
-- 'foo' yields once and then raises itself (a function value) as the error
-- object, so the checks below can confirm the error value reaches resume
-- unchanged.
function foo ()
  coroutine.yield(3)
  error(foo)
end

function goo () foo() end

-- through wrap: the first call simply returns the yielded value
x = coroutine.wrap(goo)
assert(x() == 3)

-- through create/resume: the yield, then the error, then a "dead" failure
x = coroutine.create(goo)
a, b = coroutine.resume(x)
assert(a and b == 3)

a, b = coroutine.resume(x)
assert(not a)
assert(b == foo)
assert(coroutine.status(x) == "dead")

a, b = coroutine.resume(x)
assert(not a)
assert(string.find(b, "dead"))
assert(coroutine.status(x) == "dead")
182
183
-- co-routines x for loop
--
-- 'all' enumerates every assignment of the values 1..n to slots 1..k of
-- 'a', yielding the same (mutated) table once per assignment: n^k yields.
function all (a, n, k)
  if k ~= 0 then
    for i = 1, n do
      a[k] = i
      all(a, n, k - 1)
    end
  else
    coroutine.yield(a)
  end
end

local count = 0
for _ in coroutine.wrap(function () all({}, 5, 4) end) do
  count = count + 1
end
assert(count == 5^4)
200
201
-- access to locals of collected coroutines
--
-- Once 'x' is cleared, the wrapper function is referenced only from a weak
-- table, so a full collection removes it; the closure 'f' it handed out
-- must still read and update the coroutine's local counter.
local weak = setmetatable({}, {__mode = "kv"})
local x = coroutine.wrap(function ()
  local count = 10
  local function add_ten () count = count + 10; return count end
  while true do
    count = count + 1
    coroutine.yield(add_ten)
  end
end)

weak[1] = x

local f = x()            -- count is now 11
assert(f() == 21)
assert(x()() == 32)      -- 22, then +10
assert(x() == f)         -- 33; the same closure is yielded every time
x = nil
collectgarbage()
assert(weak[1] == nil)
assert(f() == 43)
assert(f() == 53)
221
222
-- old bug: attempt to resume itself
--
-- A coroutine that tries to resume itself must get a failure from resume
-- and keep running normally afterwards.
function co_func (current_co)
  local me = coroutine.running()
  assert(me == current_co)
  assert(coroutine.resume(current_co) == false)
  coroutine.yield(10, 20)
  assert(coroutine.resume(current_co) == false)
  coroutine.yield(23)
  return 10
end

local co = coroutine.create(co_func)
local a, b, c = coroutine.resume(co, co)
assert(a == true and b == 10 and c == 20)
a, b = coroutine.resume(co, co)
assert(a == true and b == 23)
a, b = coroutine.resume(co, co)
assert(a == true and b == 10)
-- co_func has returned: every further resume fails
for _ = 1, 2 do
  assert(coroutine.resume(co, co) == false)
end
243
244
-- attempt to resume 'normal' coroutine
--
-- While 'outer' is waiting on 'inner', its status is 'normal' and it cannot
-- be resumed; once 'inner' yields, 'outer' returns that value and dies.
local outer, inner
outer = coroutine.create(function () return inner() end)
inner = coroutine.wrap(function ()
  assert(coroutine.status(outer) == 'normal')
  assert(not coroutine.resume(outer))
  coroutine.yield(3)
end)

a, b = coroutine.resume(outer)
assert(a and b == 3)
assert(coroutine.status(outer) == 'dead')
257
258
-- access to locals of erroneous coroutines
--
-- The coroutine publishes a closure over its local 'a' as the global 'f'
-- and then fails; the closure must keep working after the coroutine died.
local broken = coroutine.create(function ()
  local a = 10
  _G.f = function () a = a + 1; return a end
  error('x')
end)

assert(not coroutine.resume(broken))
-- resume the dead coroutine with extra arguments, which could overwrite the
-- old stack position of local 'a'; the closure must not observe them
assert(not coroutine.resume(broken, 1, 1, 1, 1, 1, 1, 1))
assert(_G.f() == 11)
assert(_G.f() == 12)
271
272
-- leaving a pending coroutine open: _X stays suspended at its yield forever
-- (its local 'a' is captured by a closure), and afterwards the main thread
-- must still be the running one
_X = coroutine.wrap(function ()
  local a = 10
  local _bump = function () a = a + 1 end
  coroutine.yield()
end)

_X()

assert(coroutine.running() == main)
283
284
285
-- testing yields inside metamethods
--
-- Every metamethod first yields (nil, <tag>) so that the driver can check
-- which metamethods ran and in which order. It then computes its result
-- from the operands' 'x' fields; __index/__newindex go through the
-- object's 'k' table instead.
local mt = {
  __eq = function (a, b)
    coroutine.yield(nil, "eq")
    return a.x == b.x
  end,
  __lt = function (a, b)
    coroutine.yield(nil, "lt")
    return a.x < b.x
  end,
  -- written with subtraction so a single comparison triggers two yielding
  -- metamethods (__le, then __sub)
  __le = function (a, b)
    coroutine.yield(nil, "le")
    return a - b <= 0
  end,
  __add = function (a, b)
    coroutine.yield(nil, "add")
    return a.x + b.x
  end,
  __sub = function (a, b)
    coroutine.yield(nil, "sub")
    return a.x - b.x
  end,
  __mod = function (a, b)
    coroutine.yield(nil, "mod")
    return a.x % b.x
  end,
  __unm = function (a, b)
    coroutine.yield(nil, "unm")
    return -a.x
  end,

  -- either operand may be a plain string/number or one of our tables
  __concat = function (a, b)
    coroutine.yield(nil, "concat")
    local lhs = type(a) == "table" and a.x or a
    local rhs = type(b) == "table" and b.x or b
    return lhs .. rhs
  end,
  __index = function (t, k)
    coroutine.yield(nil, "idx")
    return t.k[k]
  end,
  __newindex = function (t, k, v)
    coroutine.yield(nil, "nidx")
    t.k[k] = v
  end,
}
306
307
-- Build an object holding 'x' plus the plain table 'k' that backs its
-- __index/__newindex metamethods, with 'mt' as its metatable.
local function new (x)
  local obj = {x = x, k = {}}
  return setmetatable(obj, mt)
end


local a, b, c = new(10), new(12), new"hello"
316
-- Run 'f' in a coroutine until it produces a truthy result. Each yield
-- must be of the form (nil, tag), and the tags must match t[1], t[2], ...
-- in order, with no expected tags left over. Returns the result and 't'.
local function run (f, t)
  local wrapped = coroutine.wrap(f)
  local step = 0
  repeat
    step = step + 1
    local res, stat = wrapped()
    if res then
      assert(t[step] == nil)
      return res, t
    end
    assert(stat == t[step])
  until false
end
327
328
-- '>=' is evaluated through __le, whose body ('a - b') also yields via __sub
assert(run(function ()
             if a >= b then return '>=' else return '<' end
           end, {"le", "sub"}) == "<")

-- '<=' using '<': with __le removed, 'a <= b' falls back to 'not (b < a)'.
-- NOTE(review): that fallback exists in Lua 5.3; Lua 5.4 removed it, so
-- this check is version-specific.
mt.__le = nil
assert(run(function ()
             if a <= b then return '<=' else return '>' end
           end, {"lt"}) == "<=")
assert(run(function ()
             if a == b then return '==' else return '~=' end
           end, {"eq"}) == "~=")

assert(run(function () return a % b end, {"mod"}) == 10)

-- each expected tag list records one yield per __concat invocation
assert(run(function () return a .. b end, {"concat"}) == "1012")

assert(run(function () return a .. b .. c .. a end,
           {"concat", "concat", "concat"}) == "1012hello10")

assert(run(function () return "a" .. "b" .. a .. "c" .. c .. b .. "x" end,
           {"concat", "concat", "concat"}) == "ab10chello12x")
347
348
-- testing yields inside 'for' iterators

-- iterator: advances the control value i to i + 1 while i < s, yielding
-- (nil, "for") to the driver whenever the incoming control value is even
local function step (s, i)
  if i % 2 == 0 then coroutine.yield(nil, "for") end
  if i < s then return i + 1 end
end

-- the loop body sees 1, 2, 3, 4; the iterator yields for control values
-- 0, 2 and 4
assert(run(function ()
             local total = 0
             for i in step, 4, 0 do total = total + i end
             return total
           end, {"for", "for", "for"}) == 10)


return "OK"
364