set immutable environment for lua threads

This commit is contained in:
falsycat 2023-08-05 21:44:20 +09:00
parent 8ebb5871e6
commit 528d181cbf
6 changed files with 108 additions and 1 deletions

View File

@ -9,6 +9,7 @@ target_link_libraries(nf7_core
target_sources(nf7_core
PRIVATE
luajit/context.cc
luajit/thread.cc
version.cc
PUBLIC
luajit/context.hh

View File

@ -29,7 +29,26 @@ template <typename T>
class ContextImpl final : public Context {
public:
ContextImpl(const char* name, Kind kind, Env& env)
: Context(name, kind), tasq_(env.Get<T>()) { }
: Context(name, kind), tasq_(env.Get<T>()) {
auto L = state();
lua_pushthread(L);
if (luaL_newmetatable(L, "nf7::Context::ImmutableEnv")) {
lua_createtable(L, 0, 0);
{
luaL_newmetatable(L, kGlobalTableName);
lua_setfield(L, -2, "__index");
lua_pushcfunction(L, [](auto L) {
return luaL_error(L, "global is immutable");
});
lua_setfield(L, -2, "__newindex");
}
lua_setmetatable(L, -2);
}
lua_setfenv(L, -2);
lua_pop(L, 1);
}
void Push(Task&& task) noexcept override {
auto self = std::dynamic_pointer_cast<ContextImpl<T>>(shared_from_this());

View File

@ -2,12 +2,15 @@
#pragma once
#include <cassert>
#include <concepts>
#include <memory>
#include <string>
#include <utility>
#include <lua.hpp>
#include "iface/common/task.hh"
#include "iface/common/value.hh"
#include "iface/subsys/interface.hh"
#include "iface/env.hh"
@ -98,6 +101,34 @@ class TaskContext final {
Query(v);
}
template <std::move_constructible T>
T& NewUserData(T&& v) {
return *(new (lua_newuserdata(state_, sizeof(T))) T {std::move(v)});
}
template <std::copy_constructible T>
T& NewUserData(T&& v) {
return *(new (lua_newuserdata(state_, sizeof(T))) T {v});
}
template <typename T>
T& CheckUserData(int index, const char* name) {
return CheckUserData<T>(state_, index, name);
}
template <typename T>
static T& CheckUserData(lua_State* L, int index, const char* name) {
return *reinterpret_cast<T*>(luaL_checkudata(L, index, name));
}
void Push(const nf7::Value&) noexcept {
lua_pushstring(state_, "hello");
}
const nf7::Value& CheckValue(int index) noexcept {
return CheckValue(state_, index);
}
static const nf7::Value& CheckValue(lua_State* L, int index) {
return CheckUserData<nf7::Value>(L, index, "nf7::Value");
}
const std::shared_ptr<Context>& context() const noexcept { return ctx_; }
lua_State* state() const noexcept { return state_; }
@ -110,6 +141,8 @@ class Context :
public subsys::Interface,
public TaskQueue {
public:
static constexpr auto kGlobalTableName = "nf7::Context::GlobalTable";
using Item = Task;
enum Kind {

30
core/luajit/thread.cc Normal file
View File

@ -0,0 +1,30 @@
// No copyright
#include "core/luajit/thread.hh"
#include "core/luajit/context.hh"
namespace nf7::core::luajit {
void Thread::SetUpThread() noexcept {
luaL_newmetatable(th_, Context::kGlobalTableName);
{
new (lua_newuserdata(th_, sizeof(this))) Thread* {this};
if (luaL_newmetatable(th_, "nf7::Thread")) {
lua_createtable(th_, 0, 0);
{
lua_pushcfunction(th_, [](auto L) {
luaL_checkudata(L, 1, "nf7::Thread");
return luaL_error(L, lua_tostring(L, 2));
});
lua_setfield(th_, -2, "throw");
}
lua_setfield(th_, -2, "__index");
}
lua_setmetatable(th_, -2);
lua_setfield(th_, -2, "nf7");
}
lua_pop(th_, 1);
}
} // namespace nf7::core::luajit

View File

@ -50,6 +50,7 @@ class Thread : public std::enable_shared_from_this<Thread> {
return;
}
assert(kPaused == state_);
SetUpThread();
auto thlua = taskContext(lua);
const auto narg = thlua.PushAll(std::forward<Args>(args)...);
@ -79,6 +80,8 @@ class Thread : public std::enable_shared_from_this<Thread> {
virtual void onAborted(TaskContext&) noexcept { }
private:
void SetUpThread() noexcept;
TaskContext taskContext(const TaskContext& t) const noexcept {
assert(t.context() == context_);
return TaskContext {context_, th_};

View File

@ -71,6 +71,27 @@ TEST_P(LuaJIT_Thread, RunAndError) {
"return foo()");
}
TEST_P(LuaJIT_Thread, ForbidGlobalVariable) {
TestThread([](auto& sut) {
EXPECT_CALL(sut, onAborted)
.WillOnce([](auto& lua) {
EXPECT_THAT(lua_tostring(*lua, -1), ::testing::HasSubstr("immutable"));
});
},
"x = 1");
}
TEST_P(LuaJIT_Thread, StdThrow) {
TestThread([](auto& sut) {
EXPECT_CALL(sut, onAborted)
.WillOnce([](auto& lua) {
EXPECT_THAT(lua_tostring(*lua, -1),
::testing::HasSubstr("hello world"));
});
},
"nf7:throw(\"hello world\")");
}
INSTANTIATE_TEST_SUITE_P(
SyncOrAsync, LuaJIT_Thread,