implement pushing weak_ptr into Lua stack
This commit is contained in:
parent
0f131588b4
commit
64834f2b4c
@ -68,9 +68,12 @@ target_sources(nf7
|
||||
common/logger.hh
|
||||
common/logger_ref.hh
|
||||
common/luajit.hh
|
||||
common/luajit.cc
|
||||
common/luajit_obj.hh
|
||||
common/luajit_queue.hh
|
||||
common/luajit_ref.hh
|
||||
common/luajit_thread.hh
|
||||
common/luajit_thread.cc
|
||||
common/memento.hh
|
||||
common/native_file.hh
|
||||
common/node.hh
|
||||
|
23
common/luajit.cc
Normal file
23
common/luajit.cc
Normal file
@ -0,0 +1,23 @@
|
||||
#include "common/luajit.hh"
|
||||
|
||||
#include <lua.hpp>
|
||||
|
||||
|
||||
namespace nf7::luajit {
|
||||
|
||||
void PushGlobalTable(lua_State* L) noexcept {
|
||||
luaL_newmetatable(L, "nf7::luajit::PushGlobalTable");
|
||||
}
|
||||
void PushImmEnv(lua_State* L) noexcept {
|
||||
if (luaL_newmetatable(L, "nf7::luajit::PushImmEnv")) {
|
||||
lua_createtable(L, 0, 0);
|
||||
PushGlobalTable(L);
|
||||
lua_setfield(L, -2, "__index");
|
||||
|
||||
lua_pushcfunction(L, [](auto L) { return luaL_error(L, "global is immutable"); });
|
||||
lua_setfield(L, -2, "__newindex");
|
||||
lua_setmetatable(L, -2);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace nf7::luajit
|
@ -1,34 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <lua.hpp>
|
||||
|
||||
|
||||
namespace nf7::luajit {
|
||||
|
||||
static inline void PushImmEnv(lua_State* L) noexcept {
|
||||
if (luaL_newmetatable(L, "nf7::luajit::PushImmEnv")) {
|
||||
lua_createtable(L, 0, 0);
|
||||
lua_pushvalue(L, LUA_GLOBALSINDEX);
|
||||
lua_setfield(L, -2, "__index");
|
||||
void PushGlobalTable(lua_State*) noexcept;
|
||||
void PushImmEnv(lua_State*) noexcept;
|
||||
|
||||
lua_pushcfunction(L, [](auto L) { return luaL_error(L, "global is immutable"); });
|
||||
lua_setfield(L, -2, "__newindex");
|
||||
lua_setmetatable(L, -2);
|
||||
}
|
||||
template <typename T>
|
||||
inline void PushWeakPtr(lua_State* L, const std::weak_ptr<T>& wptr) noexcept {
|
||||
new (lua_newuserdata(L, sizeof(wptr))) std::weak_ptr<T>(wptr);
|
||||
}
|
||||
|
||||
static inline int SandboxCall(lua_State* L, int narg, int nret) noexcept {
|
||||
constexpr size_t kSandboxInstructionLimit = 10000000;
|
||||
|
||||
static const auto kHook = [](auto L, auto) {
|
||||
luaL_error(L, "reached instruction limit (<=1e7)");
|
||||
};
|
||||
lua_sethook(L, kHook, LUA_MASKCOUNT, kSandboxInstructionLimit);
|
||||
|
||||
PushImmEnv(L);
|
||||
lua_setfenv(L, -narg-2);
|
||||
|
||||
return lua_pcall(L, narg, nret, 0);
|
||||
template <typename T>
|
||||
inline std::weak_ptr<T>& ToWeakPtr(lua_State* L, int idx) noexcept {
|
||||
std::weak_ptr<T>* wptr = reinterpret_cast<decltype(wptr)>(lua_touserdata(L, idx));
|
||||
return *wptr;
|
||||
}
|
||||
template <typename T>
|
||||
inline std::shared_ptr<T> ToSharedPtr(lua_State* L, int idx) {
|
||||
if (auto ret = ToWeakPtr<T>(L, idx).lock()) {
|
||||
return ret;
|
||||
}
|
||||
luaL_error(L, "object expired: %s", typeid(T).name());
|
||||
return nullptr;
|
||||
}
|
||||
template <typename T>
|
||||
inline void PushWeakPtrDeleter(lua_State* L, const std::weak_ptr<T>& = {}) noexcept {
|
||||
lua_pushcfunction(L, [](auto L) {
|
||||
ToWeakPtr<T>(L, 1).~weak_ptr();
|
||||
return 0;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace nf7
|
||||
|
68
common/luajit_thread.cc
Normal file
68
common/luajit_thread.cc
Normal file
@ -0,0 +1,68 @@
|
||||
#include "common/luajit_thread.hh"
|
||||
|
||||
|
||||
namespace nf7::luajit {
|
||||
|
||||
constexpr size_t kInstructionLimit = 10000000;
|
||||
|
||||
|
||||
void Thread::PushMeta(lua_State* L) noexcept {
|
||||
if (luaL_newmetatable(L, "nf7::luajit::Thread")) {
|
||||
PushWeakPtrDeleter<Thread>(L);
|
||||
lua_setfield(L, -2, "__gc");
|
||||
|
||||
lua_createtable(L, 0, 0);
|
||||
{
|
||||
lua_pushcfunction(L, [](auto L) {
|
||||
auto th = ToSharedPtr<Thread>(L, 1);
|
||||
th->ExpectYield();
|
||||
th->ljq()->Push(th->ctx(), [th, L](auto) { th->Resume(L, 0); });
|
||||
return lua_yield(L, lua_gettop(L)-1);
|
||||
});
|
||||
lua_setfield(L, -2, "yield");
|
||||
}
|
||||
lua_setfield(L, -2, "__index");
|
||||
}
|
||||
}
|
||||
|
||||
void Thread::Resume(lua_State* L, int narg) noexcept {
|
||||
std::unique_lock<std::mutex> k(mtx_);
|
||||
|
||||
if (state_ == kAborted) return;
|
||||
assert(L == th_);
|
||||
assert(state_ == kPaused);
|
||||
(void) L;
|
||||
|
||||
static const auto kHook = [](auto L, auto) {
|
||||
luaL_error(L, "reached instruction limit (<=1e7)");
|
||||
};
|
||||
lua_sethook(th_, kHook, LUA_MASKCOUNT, kInstructionLimit);
|
||||
|
||||
PushGlobalTable(th_);
|
||||
PushWeakPtr(th_, weak_from_this());
|
||||
Thread::PushMeta(th_);
|
||||
lua_setmetatable(th_, -2);
|
||||
lua_setfield(th_, -2, "nf7");
|
||||
lua_pop(th_, 1);
|
||||
|
||||
state_ = kRunning;
|
||||
k.unlock();
|
||||
const auto ret = lua_resume(th_, narg);
|
||||
k.lock();
|
||||
if (state_ == kAborted) return;
|
||||
switch (ret) {
|
||||
case 0:
|
||||
state_ = kFinished;
|
||||
break;
|
||||
case LUA_YIELD:
|
||||
state_ = kPaused;
|
||||
break;
|
||||
default:
|
||||
state_ = kAborted;
|
||||
}
|
||||
if (!std::exchange(skip_handle_, false)) {
|
||||
handler_(*this, th_);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace nf7::luajit
|
@ -11,6 +11,7 @@
|
||||
|
||||
#include "nf7.hh"
|
||||
|
||||
#include "common/future.hh"
|
||||
#include "common/luajit.hh"
|
||||
#include "common/luajit_ref.hh"
|
||||
|
||||
@ -19,9 +20,6 @@ namespace nf7::luajit {
|
||||
|
||||
class Thread final : public std::enable_shared_from_this<Thread> {
|
||||
public:
|
||||
static constexpr size_t kInstructionLimit = 10000000;
|
||||
static constexpr const char* kInstanceName = "nf7::luajit::Thread::instance_";
|
||||
|
||||
enum State { kInitial, kRunning, kPaused, kFinished, kAborted, };
|
||||
using Handler = std::function<void(Thread&, lua_State*)>;
|
||||
|
||||
@ -30,9 +28,13 @@ class Thread final : public std::enable_shared_from_this<Thread> {
|
||||
using nf7::Exception::Exception;
|
||||
};
|
||||
|
||||
static std::shared_ptr<Thread> Create(Handler&& handler) noexcept {
|
||||
return std::shared_ptr<Thread>{new Thread{std::move(handler)}};
|
||||
}
|
||||
template <typename T>
|
||||
static Thread CreateForPromise(nf7::Future<T>::Promise& pro, std::function<T(lua_State*)>&& f) noexcept {
|
||||
return Thread([&pro, f = std::move(f)](auto& self, auto L) {
|
||||
static std::shared_ptr<Thread> CreateForPromise(
|
||||
nf7::Future<T>::Promise& pro, std::function<T(lua_State*)>&& f) noexcept {
|
||||
return std::shared_ptr<Thread>(new Thread{[&pro, f = std::move(f)](auto& self, auto L) {
|
||||
switch (self.state()) {
|
||||
case kPaused:
|
||||
pro.Throw(std::make_exception_ptr<nf7::Exception>({"unexpected yield"}));
|
||||
@ -47,71 +49,37 @@ class Thread final : public std::enable_shared_from_this<Thread> {
|
||||
assert(false);
|
||||
throw 0;
|
||||
}
|
||||
});
|
||||
}});
|
||||
}
|
||||
|
||||
static void PushMeta(lua_State*) noexcept;
|
||||
|
||||
Thread() = delete;
|
||||
Thread(Handler&& handler) noexcept : handler_(std::move(handler)) {
|
||||
}
|
||||
Thread(const Thread&) = delete;
|
||||
Thread(Thread&&) = delete;
|
||||
Thread& operator=(const Thread&) = delete;
|
||||
Thread& operator=(Thread&&) = delete;
|
||||
|
||||
// must be called on luajit thread
|
||||
// be carefully on recursive reference, ctx is held until *this is destructed.
|
||||
lua_State* Init(const std::shared_ptr<nf7::Context>& ctx,
|
||||
const std::shared_ptr<nf7::luajit::Queue>& q,
|
||||
lua_State* Init(const std::shared_ptr<nf7::Context>& ctx,
|
||||
const std::shared_ptr<nf7::luajit::Queue>& ljq,
|
||||
lua_State* L) noexcept {
|
||||
assert(state_ == kInitial);
|
||||
|
||||
ctx_ = ctx;
|
||||
ljq_ = ljq;
|
||||
|
||||
th_ = lua_newthread(L);
|
||||
PushImmEnv(L);
|
||||
lua_setfenv(L, -2);
|
||||
th_ref_.emplace(ctx, q, luaL_ref(L, LUA_REGISTRYINDEX));
|
||||
th_ref_.emplace(ctx, ljq, luaL_ref(L, LUA_REGISTRYINDEX));
|
||||
|
||||
state_ = kPaused;
|
||||
return th_;
|
||||
}
|
||||
|
||||
// must be called on luajit thread
|
||||
void Resume(lua_State* L, int narg) noexcept {
|
||||
std::unique_lock<std::mutex> k(mtx_);
|
||||
|
||||
if (state_ == kAborted) return;
|
||||
assert(L == th_);
|
||||
assert(state_ == kPaused);
|
||||
(void) L;
|
||||
|
||||
static const auto kHook = [](auto L, auto) {
|
||||
luaL_error(L, "reached instruction limit (<=1e7)");
|
||||
};
|
||||
lua_sethook(th_, kHook, LUA_MASKCOUNT, kInstructionLimit);
|
||||
|
||||
// TODO: push weak_ptr instead
|
||||
lua_pushstring(th_, kInstanceName);
|
||||
lua_pushlightuserdata(th_, this);
|
||||
lua_rawset(th_, LUA_REGISTRYINDEX);
|
||||
|
||||
state_ = kRunning;
|
||||
k.unlock();
|
||||
const auto ret = lua_resume(th_, narg);
|
||||
k.lock();
|
||||
if (state_ == kAborted) return;
|
||||
switch (ret) {
|
||||
case 0:
|
||||
state_ = kFinished;
|
||||
break;
|
||||
case LUA_YIELD:
|
||||
state_ = kPaused;
|
||||
break;
|
||||
default:
|
||||
state_ = kAborted;
|
||||
}
|
||||
if (!std::exchange(skip_handle_, false)) {
|
||||
handler_(*this, th_);
|
||||
}
|
||||
}
|
||||
void Resume(lua_State* L, int narg) noexcept;
|
||||
|
||||
void Abort() noexcept {
|
||||
std::unique_lock<std::mutex> k(mtx_);
|
||||
@ -123,6 +91,8 @@ class Thread final : public std::enable_shared_from_this<Thread> {
|
||||
skip_handle_ = true;
|
||||
}
|
||||
|
||||
const std::shared_ptr<nf7::Context>& ctx() const noexcept { return ctx_; }
|
||||
const std::shared_ptr<nf7::luajit::Queue>& ljq() const noexcept { return ljq_; }
|
||||
State state() const noexcept { return state_; }
|
||||
|
||||
private:
|
||||
@ -131,10 +101,17 @@ class Thread final : public std::enable_shared_from_this<Thread> {
|
||||
Handler handler_;
|
||||
std::atomic<State> state_ = kInitial;
|
||||
|
||||
std::shared_ptr<nf7::Context> ctx_;
|
||||
std::shared_ptr<nf7::luajit::Queue> ljq_;
|
||||
|
||||
lua_State* th_ = nullptr;
|
||||
std::optional<nf7::luajit::Ref> th_ref_;
|
||||
|
||||
bool skip_handle_ = false;
|
||||
|
||||
|
||||
Thread(Handler&& handler) noexcept : handler_(std::move(handler)) {
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace nf7::luajit
|
||||
|
@ -217,7 +217,7 @@ class Node::Lambda final : public nf7::Lambda,
|
||||
auto handler = handler_.value();
|
||||
ljq_ = handler->ljq();
|
||||
|
||||
auto th = std::make_shared<nf7::luajit::Thread>(
|
||||
auto th = nf7::luajit::Thread::Create(
|
||||
[self](auto& th, auto L) { self->HandleThread(th, L); });
|
||||
th_.emplace_back(th);
|
||||
|
||||
|
@ -181,9 +181,9 @@ class Obj::ExecTask final : public nf7::Task<std::shared_ptr<nf7::luajit::Ref>>
|
||||
interfaceOrThrow<nf7::luajit::Queue>().self();
|
||||
ljq->Push(self(), [&](auto L) {
|
||||
try {
|
||||
auto thL = th.Init(self(), ljq, L);
|
||||
auto thL = th->Init(self(), ljq, L);
|
||||
Compile(thL);
|
||||
th.Resume(thL, 0);
|
||||
th->Resume(thL, 0);
|
||||
} catch (Exception&) {
|
||||
lua_pro.Throw(std::current_exception());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user