implement pushing weak_ptr into Lua stack

This commit is contained in:
falsycat 2022-06-18 22:19:00 +09:00
parent 0f131588b4
commit 64834f2b4c
7 changed files with 149 additions and 74 deletions

View File

@ -68,9 +68,12 @@ target_sources(nf7
common/logger.hh
common/logger_ref.hh
common/luajit.hh
common/luajit.cc
common/luajit_obj.hh
common/luajit_queue.hh
common/luajit_ref.hh
common/luajit_thread.hh
common/luajit_thread.cc
common/memento.hh
common/native_file.hh
common/node.hh

23
common/luajit.cc Normal file
View File

@ -0,0 +1,23 @@
#include "common/luajit.hh"
#include <lua.hpp>
namespace nf7::luajit {
void PushGlobalTable(lua_State* L) noexcept {
luaL_newmetatable(L, "nf7::luajit::PushGlobalTable");
}
void PushImmEnv(lua_State* L) noexcept {
if (luaL_newmetatable(L, "nf7::luajit::PushImmEnv")) {
lua_createtable(L, 0, 0);
PushGlobalTable(L);
lua_setfield(L, -2, "__index");
lua_pushcfunction(L, [](auto L) { return luaL_error(L, "global is immutable"); });
lua_setfield(L, -2, "__newindex");
lua_setmetatable(L, -2);
}
}
} // namespace nf7::luajit

View File

@ -1,34 +1,38 @@
#pragma once
#include <memory>
#include <lua.hpp>
namespace nf7::luajit {
static inline void PushImmEnv(lua_State* L) noexcept {
if (luaL_newmetatable(L, "nf7::luajit::PushImmEnv")) {
lua_createtable(L, 0, 0);
lua_pushvalue(L, LUA_GLOBALSINDEX);
lua_setfield(L, -2, "__index");
void PushGlobalTable(lua_State*) noexcept;
void PushImmEnv(lua_State*) noexcept;
lua_pushcfunction(L, [](auto L) { return luaL_error(L, "global is immutable"); });
lua_setfield(L, -2, "__newindex");
lua_setmetatable(L, -2);
}
template <typename T>
inline void PushWeakPtr(lua_State* L, const std::weak_ptr<T>& wptr) noexcept {
new (lua_newuserdata(L, sizeof(wptr))) std::weak_ptr<T>(wptr);
}
static inline int SandboxCall(lua_State* L, int narg, int nret) noexcept {
constexpr size_t kSandboxInstructionLimit = 10000000;
static const auto kHook = [](auto L, auto) {
luaL_error(L, "reached instruction limit (<=1e7)");
};
lua_sethook(L, kHook, LUA_MASKCOUNT, kSandboxInstructionLimit);
PushImmEnv(L);
lua_setfenv(L, -narg-2);
return lua_pcall(L, narg, nret, 0);
template <typename T>
inline std::weak_ptr<T>& ToWeakPtr(lua_State* L, int idx) noexcept {
std::weak_ptr<T>* wptr = reinterpret_cast<decltype(wptr)>(lua_touserdata(L, idx));
return *wptr;
}
template <typename T>
inline std::shared_ptr<T> ToSharedPtr(lua_State* L, int idx) {
if (auto ret = ToWeakPtr<T>(L, idx).lock()) {
return ret;
}
luaL_error(L, "object expired: %s", typeid(T).name());
return nullptr;
}
template <typename T>
inline void PushWeakPtrDeleter(lua_State* L, const std::weak_ptr<T>& = {}) noexcept {
lua_pushcfunction(L, [](auto L) {
ToWeakPtr<T>(L, 1).~weak_ptr();
return 0;
});
}
} // namespace nf7

68
common/luajit_thread.cc Normal file
View File

@ -0,0 +1,68 @@
#include "common/luajit_thread.hh"
namespace nf7::luajit {
constexpr size_t kInstructionLimit = 10000000;
void Thread::PushMeta(lua_State* L) noexcept {
if (luaL_newmetatable(L, "nf7::luajit::Thread")) {
PushWeakPtrDeleter<Thread>(L);
lua_setfield(L, -2, "__gc");
lua_createtable(L, 0, 0);
{
lua_pushcfunction(L, [](auto L) {
auto th = ToSharedPtr<Thread>(L, 1);
th->ExpectYield();
th->ljq()->Push(th->ctx(), [th, L](auto) { th->Resume(L, 0); });
return lua_yield(L, lua_gettop(L)-1);
});
lua_setfield(L, -2, "yield");
}
lua_setfield(L, -2, "__index");
}
}
void Thread::Resume(lua_State* L, int narg) noexcept {
std::unique_lock<std::mutex> k(mtx_);
if (state_ == kAborted) return;
assert(L == th_);
assert(state_ == kPaused);
(void) L;
static const auto kHook = [](auto L, auto) {
luaL_error(L, "reached instruction limit (<=1e7)");
};
lua_sethook(th_, kHook, LUA_MASKCOUNT, kInstructionLimit);
PushGlobalTable(th_);
PushWeakPtr(th_, weak_from_this());
Thread::PushMeta(th_);
lua_setmetatable(th_, -2);
lua_setfield(th_, -2, "nf7");
lua_pop(th_, 1);
state_ = kRunning;
k.unlock();
const auto ret = lua_resume(th_, narg);
k.lock();
if (state_ == kAborted) return;
switch (ret) {
case 0:
state_ = kFinished;
break;
case LUA_YIELD:
state_ = kPaused;
break;
default:
state_ = kAborted;
}
if (!std::exchange(skip_handle_, false)) {
handler_(*this, th_);
}
}
} // namespace nf7::luajit

View File

@ -11,6 +11,7 @@
#include "nf7.hh"
#include "common/future.hh"
#include "common/luajit.hh"
#include "common/luajit_ref.hh"
@ -19,9 +20,6 @@ namespace nf7::luajit {
class Thread final : public std::enable_shared_from_this<Thread> {
public:
static constexpr size_t kInstructionLimit = 10000000;
static constexpr const char* kInstanceName = "nf7::luajit::Thread::instance_";
enum State { kInitial, kRunning, kPaused, kFinished, kAborted, };
using Handler = std::function<void(Thread&, lua_State*)>;
@ -30,9 +28,13 @@ class Thread final : public std::enable_shared_from_this<Thread> {
using nf7::Exception::Exception;
};
static std::shared_ptr<Thread> Create(Handler&& handler) noexcept {
return std::shared_ptr<Thread>{new Thread{std::move(handler)}};
}
template <typename T>
static Thread CreateForPromise(nf7::Future<T>::Promise& pro, std::function<T(lua_State*)>&& f) noexcept {
return Thread([&pro, f = std::move(f)](auto& self, auto L) {
static std::shared_ptr<Thread> CreateForPromise(
nf7::Future<T>::Promise& pro, std::function<T(lua_State*)>&& f) noexcept {
return std::shared_ptr<Thread>(new Thread{[&pro, f = std::move(f)](auto& self, auto L) {
switch (self.state()) {
case kPaused:
pro.Throw(std::make_exception_ptr<nf7::Exception>({"unexpected yield"}));
@ -47,71 +49,37 @@ class Thread final : public std::enable_shared_from_this<Thread> {
assert(false);
throw 0;
}
});
}});
}
static void PushMeta(lua_State*) noexcept;
Thread() = delete;
Thread(Handler&& handler) noexcept : handler_(std::move(handler)) {
}
Thread(const Thread&) = delete;
Thread(Thread&&) = delete;
Thread& operator=(const Thread&) = delete;
Thread& operator=(Thread&&) = delete;
// must be called on luajit thread
// be carefully on recursive reference, ctx is held until *this is destructed.
lua_State* Init(const std::shared_ptr<nf7::Context>& ctx,
const std::shared_ptr<nf7::luajit::Queue>& q,
lua_State* Init(const std::shared_ptr<nf7::Context>& ctx,
const std::shared_ptr<nf7::luajit::Queue>& ljq,
lua_State* L) noexcept {
assert(state_ == kInitial);
ctx_ = ctx;
ljq_ = ljq;
th_ = lua_newthread(L);
PushImmEnv(L);
lua_setfenv(L, -2);
th_ref_.emplace(ctx, q, luaL_ref(L, LUA_REGISTRYINDEX));
th_ref_.emplace(ctx, ljq, luaL_ref(L, LUA_REGISTRYINDEX));
state_ = kPaused;
return th_;
}
// must be called on luajit thread
void Resume(lua_State* L, int narg) noexcept {
std::unique_lock<std::mutex> k(mtx_);
if (state_ == kAborted) return;
assert(L == th_);
assert(state_ == kPaused);
(void) L;
static const auto kHook = [](auto L, auto) {
luaL_error(L, "reached instruction limit (<=1e7)");
};
lua_sethook(th_, kHook, LUA_MASKCOUNT, kInstructionLimit);
// TODO: push weak_ptr instead
lua_pushstring(th_, kInstanceName);
lua_pushlightuserdata(th_, this);
lua_rawset(th_, LUA_REGISTRYINDEX);
state_ = kRunning;
k.unlock();
const auto ret = lua_resume(th_, narg);
k.lock();
if (state_ == kAborted) return;
switch (ret) {
case 0:
state_ = kFinished;
break;
case LUA_YIELD:
state_ = kPaused;
break;
default:
state_ = kAborted;
}
if (!std::exchange(skip_handle_, false)) {
handler_(*this, th_);
}
}
void Resume(lua_State* L, int narg) noexcept;
void Abort() noexcept {
std::unique_lock<std::mutex> k(mtx_);
@ -123,6 +91,8 @@ class Thread final : public std::enable_shared_from_this<Thread> {
skip_handle_ = true;
}
const std::shared_ptr<nf7::Context>& ctx() const noexcept { return ctx_; }
const std::shared_ptr<nf7::luajit::Queue>& ljq() const noexcept { return ljq_; }
State state() const noexcept { return state_; }
private:
@ -131,10 +101,17 @@ class Thread final : public std::enable_shared_from_this<Thread> {
Handler handler_;
std::atomic<State> state_ = kInitial;
std::shared_ptr<nf7::Context> ctx_;
std::shared_ptr<nf7::luajit::Queue> ljq_;
lua_State* th_ = nullptr;
std::optional<nf7::luajit::Ref> th_ref_;
bool skip_handle_ = false;
Thread(Handler&& handler) noexcept : handler_(std::move(handler)) {
}
};
} // namespace nf7::luajit

View File

@ -217,7 +217,7 @@ class Node::Lambda final : public nf7::Lambda,
auto handler = handler_.value();
ljq_ = handler->ljq();
auto th = std::make_shared<nf7::luajit::Thread>(
auto th = nf7::luajit::Thread::Create(
[self](auto& th, auto L) { self->HandleThread(th, L); });
th_.emplace_back(th);

View File

@ -181,9 +181,9 @@ class Obj::ExecTask final : public nf7::Task<std::shared_ptr<nf7::luajit::Ref>>
interfaceOrThrow<nf7::luajit::Queue>().self();
ljq->Push(self(), [&](auto L) {
try {
auto thL = th.Init(self(), ljq, L);
auto thL = th->Init(self(), ljq, L);
Compile(thL);
th.Resume(thL, 0);
th->Resume(thL, 0);
} catch (Exception&) {
lua_pro.Throw(std::current_exception());
}