diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index ce13b6c..cfa3707 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources(nf7_core version.cc PUBLIC luajit/context.hh + luajit/thread.hh version.hh ) @@ -20,6 +21,8 @@ target_sources(nf7_core_test PRIVATE luajit/context_test.cc luajit/context_test.hh + luajit/thread_test.cc + luajit/thread_test.hh ) target_link_libraries(nf7_core_test PRIVATE diff --git a/core/luajit/thread.hh b/core/luajit/thread.hh new file mode 100644 index 0000000..b326bf9 --- /dev/null +++ b/core/luajit/thread.hh @@ -0,0 +1,93 @@ +// No copyright +#pragma once + +#include +#include +#include +#include + +#include + +#include "core/luajit/context.hh" + +namespace nf7::core::luajit { + +class Thread : public std::enable_shared_from_this { + public: + struct DoNotCallConstructorDirectly { }; + + enum State : uint8_t { + kPaused, + kRunning, + kFinished, + }; + + public: + template + static std::shared_ptr Make( + TaskContext& lua, const std::shared_ptr& func) { + DoNotCallConstructorDirectly key; + auto th = std::make_shared(lua, key); + th->taskContext(lua).Query(*func); + return th; + } + + public: + Thread(TaskContext& t, DoNotCallConstructorDirectly&) noexcept + : context_(t.context()), th_(lua_newthread(*t)) { + assert(th_); + } + + public: + // if this finished with state_ kPaused, + // a responsibility to resume is on one who yielded + template + void Resume(TaskContext& lua, Args&&... args) noexcept { + assert(lua.context() == context_); + + if (kFinished == state_) { + return; + } + assert(kPaused == state_); + + auto thlua = taskContext(lua); + const auto narg = thlua.PushAll(std::forward(args)...); + + state_ = kRunning; + const auto ret = lua_resume(*thlua, narg); + switch (ret) { + case 0: + state_ = kFinished; + onExited(thlua); + return; + case LUA_YIELD: + state_ = kPaused; + return; + default: + state_ = kFinished; + onAborted(thlua); + return; + } + } + + const std::shared_ptr& context() const noexcept { return context_; } + State state() const noexcept { return state_; } + + protected: + virtual void onExited(TaskContext&) noexcept { } + virtual void onAborted(TaskContext&) noexcept { } + + private: + TaskContext taskContext(const TaskContext& t) const noexcept { + assert(t.context() == context_); + return TaskContext {context_, th_}; + } + + private: + const std::shared_ptr context_; + lua_State* const th_; + + State state_ = kPaused; +}; + +} // namespace nf7::core::luajit diff --git a/core/luajit/thread_test.cc b/core/luajit/thread_test.cc new file mode 100644 index 0000000..3f604e0 --- /dev/null +++ b/core/luajit/thread_test.cc @@ -0,0 +1,79 @@ +// No copyright +#include "core/luajit/thread.hh" +#include "core/luajit/thread_test.hh" + +#include +#include + +#include "core/luajit/context_test.hh" + + +class LuaJIT_Thread : public nf7::core::luajit::test::ContextFixture { + public: + using ContextFixture::ContextFixture; + + template + void TestThread( + const auto& setup, const char* script, Args&&... args) { + auto lua = nf7::core::luajit::Context::Create(*env_, GetParam()); + auto called = uint32_t {0}; + lua->Exec([&](auto& lua) { + const auto compile = luaL_loadstring(*lua, script); + ASSERT_EQ(compile, LUA_OK); + + auto sut = nf7::core::luajit::Thread::Make< + nf7::core::luajit::test::ThreadMock>(lua, lua.Register()); + setup(*sut); + + sut->Resume(lua, std::forward(args)...); + ++called; + }); + ConsumeTasks(); + EXPECT_EQ(called, 1); + } +}; + + +TEST_P(LuaJIT_Thread, ResumeWithSingleReturn) { + TestThread([](auto& sut) { + EXPECT_CALL(sut, onExited) + .WillOnce([](auto& lua) { EXPECT_EQ(lua_tointeger(*lua, 1), 6); }); + }, + "return 1+2+3"); +} + +TEST_P(LuaJIT_Thread, ResumeWithArgs) { + TestThread([](auto& sut) { + EXPECT_CALL(sut, onExited) + .WillOnce([](auto& lua) { EXPECT_EQ(lua_tointeger(*lua, 1), 60); }); + }, + "local x,y,z = ...\nreturn x+y+z", + lua_Integer {10}, lua_Integer {20}, lua_Integer {30}); +} + +TEST_P(LuaJIT_Thread, RunWithMultipleReturn) { + TestThread([](auto& sut) { + EXPECT_CALL(sut, onExited) + .WillOnce([](auto& lua) { + EXPECT_EQ(lua_gettop(*lua), 3); + EXPECT_EQ(lua_tointeger(*lua, 1), 1); + EXPECT_EQ(lua_tointeger(*lua, 2), 2); + EXPECT_EQ(lua_tointeger(*lua, 3), 3); + }); + }, + "return 1, 2, 3"); +} + +TEST_P(LuaJIT_Thread, RunAndError) { + TestThread([](auto& sut) { + EXPECT_CALL(sut, onAborted); + }, + "return foo()"); +} + + +INSTANTIATE_TEST_SUITE_P( + SyncOrAsync, LuaJIT_Thread, + testing::Values( + nf7::core::luajit::Context::kSync, + nf7::core::luajit::Context::kAsync)); diff --git a/core/luajit/thread_test.hh b/core/luajit/thread_test.hh new file mode 100644 index 0000000..bc39687 --- /dev/null +++ b/core/luajit/thread_test.hh @@ -0,0 +1,21 @@ +// No copyright +#pragma once + +#include "core/luajit/thread.hh" + +#include + +#include "core/luajit/context.hh" + + +namespace nf7::core::luajit::test { + +class ThreadMock : public Thread { + public: + using Thread::Thread; + + MOCK_METHOD(void, onExited, (TaskContext&), (noexcept, override)); + MOCK_METHOD(void, onAborted, (TaskContext&), (noexcept, override)); +}; + +} // namespace nf7::core::luajit::test