From 528d181cbf548aac4f53566ce8f892d03629aebe Mon Sep 17 00:00:00 2001 From: falsycat Date: Sat, 5 Aug 2023 21:44:20 +0900 Subject: [PATCH] set immutable environment for lua threads --- core/CMakeLists.txt | 1 + core/luajit/context.cc | 21 ++++++++++++++++++++- core/luajit/context.hh | 33 +++++++++++++++++++++++++++++++++ core/luajit/thread.cc | 30 ++++++++++++++++++++++++++++++ core/luajit/thread.hh | 3 +++ core/luajit/thread_test.cc | 21 +++++++++++++++++++++ 6 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 core/luajit/thread.cc diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index cfa3707..303ce7e 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -9,6 +9,7 @@ target_link_libraries(nf7_core target_sources(nf7_core PRIVATE luajit/context.cc + luajit/thread.cc version.cc PUBLIC luajit/context.hh diff --git a/core/luajit/context.cc b/core/luajit/context.cc index bfcef24..986e354 100644 --- a/core/luajit/context.cc +++ b/core/luajit/context.cc @@ -29,7 +29,26 @@ template class ContextImpl final : public Context { public: ContextImpl(const char* name, Kind kind, Env& env) - : Context(name, kind), tasq_(env.Get()) { } + : Context(name, kind), tasq_(env.Get()) { + auto L = state(); + + lua_pushthread(L); + if (luaL_newmetatable(L, "nf7::Context::ImmutableEnv")) { + lua_createtable(L, 0, 0); + { + luaL_newmetatable(L, kGlobalTableName); + lua_setfield(L, -2, "__index"); + + lua_pushcfunction(L, [](auto L) { + return luaL_error(L, "global is immutable"); + }); + lua_setfield(L, -2, "__newindex"); + } + lua_setmetatable(L, -2); + } + lua_setfenv(L, -2); + lua_pop(L, 1); + } void Push(Task&& task) noexcept override { auto self = std::dynamic_pointer_cast>(shared_from_this()); diff --git a/core/luajit/context.hh b/core/luajit/context.hh index faddaef..b3c1c67 100644 --- a/core/luajit/context.hh +++ b/core/luajit/context.hh @@ -2,12 +2,15 @@ #pragma once #include +#include #include +#include #include #include #include "iface/common/task.hh" +#include "iface/common/value.hh" #include "iface/subsys/interface.hh" #include "iface/env.hh" @@ -98,6 +101,34 @@ class TaskContext final { Query(v); } + template + T& NewUserData(T&& v) { + return *(new (lua_newuserdata(state_, sizeof(T))) T {std::move(v)}); + } + template + T& NewUserData(T&& v) { + return *(new (lua_newuserdata(state_, sizeof(T))) T {v}); + } + + template + T& CheckUserData(int index, const char* name) { + return CheckUserData(state_, index, name); + } + template + static T& CheckUserData(lua_State* L, int index, const char* name) { + return *reinterpret_cast(luaL_checkudata(L, index, name)); + } + + void Push(const nf7::Value&) noexcept { + lua_pushstring(state_, "hello"); + } + const nf7::Value& CheckValue(int index) noexcept { + return CheckValue(state_, index); + } + static const nf7::Value& CheckValue(lua_State* L, int index) { + return CheckUserData(L, index, "nf7::Value"); + } + const std::shared_ptr& context() const noexcept { return ctx_; } lua_State* state() const noexcept { return state_; } @@ -110,6 +141,8 @@ class Context : public subsys::Interface, public TaskQueue { public: + static constexpr auto kGlobalTableName = "nf7::Context::GlobalTable"; + using Item = Task; enum Kind { diff --git a/core/luajit/thread.cc b/core/luajit/thread.cc new file mode 100644 index 0000000..67b33ee --- /dev/null +++ b/core/luajit/thread.cc @@ -0,0 +1,30 @@ +// No copyright +#include "core/luajit/thread.hh" + +#include "core/luajit/context.hh" + + +namespace nf7::core::luajit { + +void Thread::SetUpThread() noexcept { + luaL_newmetatable(th_, Context::kGlobalTableName); + { + new (lua_newuserdata(th_, sizeof(this))) Thread* {this}; + if (luaL_newmetatable(th_, "nf7::Thread")) { + lua_createtable(th_, 0, 0); + { + lua_pushcfunction(th_, [](auto L) { + luaL_checkudata(L, 1, "nf7::Thread"); + return luaL_error(L, lua_tostring(L, 2)); + }); + lua_setfield(th_, -2, "throw"); + } + lua_setfield(th_, -2, "__index"); + } + lua_setmetatable(th_, -2); + lua_setfield(th_, -2, "nf7"); + } + lua_pop(th_, 1); +} + +} // namespace nf7::core::luajit diff --git a/core/luajit/thread.hh b/core/luajit/thread.hh index c36df30..4543266 100644 --- a/core/luajit/thread.hh +++ b/core/luajit/thread.hh @@ -50,6 +50,7 @@ class Thread : public std::enable_shared_from_this { return; } assert(kPaused == state_); + SetUpThread(); auto thlua = taskContext(lua); const auto narg = thlua.PushAll(std::forward(args)...); @@ -79,6 +80,8 @@ class Thread : public std::enable_shared_from_this { virtual void onAborted(TaskContext&) noexcept { } private: + void SetUpThread() noexcept; + TaskContext taskContext(const TaskContext& t) const noexcept { assert(t.context() == context_); return TaskContext {context_, th_}; diff --git a/core/luajit/thread_test.cc b/core/luajit/thread_test.cc index 945e7fc..04737ff 100644 --- a/core/luajit/thread_test.cc +++ b/core/luajit/thread_test.cc @@ -71,6 +71,27 @@ TEST_P(LuaJIT_Thread, RunAndError) { "return foo()"); } +TEST_P(LuaJIT_Thread, ForbidGlobalVariable) { + TestThread([](auto& sut) { + EXPECT_CALL(sut, onAborted) + .WillOnce([](auto& lua) { + EXPECT_THAT(lua_tostring(*lua, -1), ::testing::HasSubstr("immutable")); + }); + }, + "x = 1"); +} + +TEST_P(LuaJIT_Thread, StdThrow) { + TestThread([](auto& sut) { + EXPECT_CALL(sut, onAborted) + .WillOnce([](auto& lua) { + EXPECT_THAT(lua_tostring(*lua, -1), + ::testing::HasSubstr("hello world")); + }); + }, + "nf7:throw(\"hello world\")"); +} + INSTANTIATE_TEST_SUITE_P( SyncOrAsync, LuaJIT_Thread,