add luajit::Thread to support lua coroutine

This commit is contained in:
falsycat 2022-06-09 15:55:26 +09:00
parent 58ba3071c2
commit 580b42c3fd
2 changed files with 194 additions and 47 deletions

123
common/luajit_thread.hh Normal file
View File

@ -0,0 +1,123 @@
#pragma once
#include <atomic>
#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <lua.hpp>
#include "nf7.hh"
#include "common/luajit_ref.hh"
namespace nf7::luajit {
class Thread final {
public:
static constexpr size_t kInstructionLimit = 10000000;
static constexpr const char* kInstanceName = "nf7::luajit::Thread::instance_";
enum State { kInitial, kRunning, kPaused, kFinished, kAborted, };
using Handler = std::function<void(State, lua_State*)>;
class Exception final : public nf7::Exception {
public:
using nf7::Exception::Exception;
};
template <typename T>
static Thread CreateForPromise(nf7::Future<T>::Promise& pro, std::function<T(lua_State*)>&& f) noexcept {
return Thread([&pro, f = std::move(f)](auto state, auto L) {
switch (state) {
case kPaused:
pro.Throw(std::make_exception_ptr<nf7::Exception>({"unexpected yield"}));
break;
case kFinished:
pro.Wrap([&]() { return f(L); });
break;
case kAborted:
pro.Throw(std::make_exception_ptr<nf7::Exception>({lua_tostring(L, -1)}));
break;
default:
assert(false);
throw 0;
}
});
}
Thread() = delete;
Thread(Handler&& handler) noexcept : handler_(std::move(handler)) {
}
Thread(const Thread&) = delete;
Thread(Thread&&) = delete;
Thread& operator=(const Thread&) = delete;
Thread& operator=(Thread&&) = delete;
// must be called on luajit thread
// be carefully on recursive reference, ctx is held until *this is destructed.
lua_State* Init(const std::shared_ptr<nf7::Context>& ctx,
const std::shared_ptr<nf7::luajit::Queue>& q,
lua_State* L) noexcept {
assert(state_ == kInitial);
th_ = lua_newthread(L);
PushImmEnv(L);
lua_setfenv(L, -2);
th_ref_.emplace(ctx, q, luaL_ref(L, LUA_REGISTRYINDEX));
state_ = kPaused;
return th_;
}
// must be called on luajit thread
void Resume(lua_State* L, int narg) noexcept {
assert(L == th_);
assert(state_ == kPaused);
(void) L;
static const auto kHook = [](auto L, auto) {
luaL_error(L, "reached instruction limit (<=1e7)");
};
lua_sethook(th_, kHook, LUA_MASKCOUNT, kInstructionLimit);
lua_pushstring(th_, kInstanceName);
lua_pushlightuserdata(th_, this);
lua_rawset(th_, LUA_REGISTRYINDEX);
state_ = kRunning;
switch (lua_resume(th_, narg)) {
case 0:
state_ = kFinished;
break;
case LUA_YIELD:
state_ = kPaused;
break;
default:
state_ = kAborted;
}
if (!std::exchange(skip_handle_, false)) {
handler_(state_, th_);
}
}
// handler_ won't be called on next yielding
void ExpectYield() noexcept {
skip_handle_ = true;
}
State state() const noexcept { return state_; }
private:
Handler handler_;
std::atomic<State> state_ = kInitial;
lua_State* th_ = nullptr;
std::optional<nf7::luajit::Ref> th_ref_;
bool skip_handle_ = false;
};
} // namespace nf7::luajit

View File

@ -21,6 +21,7 @@
#include "common/luajit.hh"
#include "common/luajit_obj.hh"
#include "common/luajit_queue.hh"
#include "common/luajit_thread.hh"
#include "common/logger_ref.hh"
#include "common/ptr_selector.hh"
#include "common/yas_nf7.hh"
@ -148,50 +149,82 @@ class Obj::ExecTask final : public nf7::Context, public std::enable_shared_from_
bool buf_consumed_ = false;
nf7::Future<std::shared_ptr<nf7::luajit::Ref>>::Coro Proc() noexcept
try {
auto self = shared_from_this();
nf7::Future<std::shared_ptr<nf7::luajit::Ref>>::Coro Proc() noexcept {
try {
auto self = shared_from_this();
auto& srcf = *target_->src_;
chunkname_ = srcf.abspath().Stringify();
auto& srcf = *target_->src_;
chunkname_ = srcf.abspath().Stringify();
auto src = srcf.interfaceOrThrow<nf7::AsyncBuffer>().self();
auto srclock = co_await src->AcquireLock(false).awaiter(self);
log_->Trace("source file lock acquired");
// acquire lock of source
auto src = srcf.interfaceOrThrow<nf7::AsyncBuffer>().self();
auto srclock = co_await src->AcquireLock(false).awaiter(self);
log_->Trace("source file lock acquired");
buf_size_ = co_await src->size().awaiter(self);
if (buf_size_ == 0) {
throw nf7::Exception("source is empty");
// get size of source
buf_size_ = co_await src->size().awaiter(self);
if (buf_size_ == 0) {
throw nf7::Exception("source is empty");
}
if (buf_size_ > kMaxSize) {
throw nf7::Exception("source is too huge");
}
// read source
buf_.resize(buf_size_);
const size_t read = co_await src->Read(0, buf_.data(), buf_size_).awaiter(self);
if (read != buf_size_) {
throw nf7::Exception("failed to read all bytes from source");
}
// create thread to compile lua script
nf7::Future<int>::Promise lua_pro;
auto th = nf7::luajit::Thread::CreateForPromise<int>(lua_pro, [&](auto L) {
if (lua_gettop(L) != 1) {
throw nf7::Exception("expected one object to be returned");
}
log_->Info("got '"s+lua_tostring(L, -1)+"'");
return luaL_ref(L, LUA_REGISTRYINDEX);
});
// context for luajit script running
auto lua_ctx = std::make_shared<nf7::GenericContext>(env(), initiator());
lua_ctx->description() = "luajit object build script runner";
// queue task to trigger the thread
auto ljq = target_->
ResolveUpwardOrThrow("_luajit").
interfaceOrThrow<nf7::luajit::Queue>().self();
ljq->Push(self, [&](auto L) {
try {
auto thL = th.Init(lua_ctx, ljq, L);
Compile(thL);
th.Resume(thL, 0);
} catch (Exception&) {
lua_pro.Throw(std::current_exception());
}
});
// wait for end of execution and return built object's index
const int idx = co_await lua_pro.future().awaiter(self);
log_->Trace("task finished");
// context for object cache
// TODO use specific Context type
auto ctx = std::make_shared<nf7::GenericContext>(env(), initiator());
ctx->description() = "luajit object cache";
// return the object and cache it
target_->cache_ = std::make_shared<nf7::luajit::Ref>(ctx, ljq, idx);
co_yield target_->cache_;
} catch (Exception& e) {
log_->Error(e.msg());
throw;
}
if (buf_size_ > kMaxSize) {
throw nf7::Exception("source is too huge");
}
buf_.resize(buf_size_);
const size_t read = co_await src->Read(0, buf_.data(), buf_size_).awaiter(self);
if (read != buf_size_) {
throw nf7::Exception("failed to read all bytes from source");
}
nf7::Future<int>::Promise lua_pro;
auto ljq = target_->
ResolveUpwardOrThrow("_luajit").
interfaceOrThrow<nf7::luajit::Queue>().self();
ljq->Push(self, [&](auto L) { lua_pro.Wrap([&]() { return ExecLua(L); }); });
const int idx = co_await lua_pro.future().awaiter(self);
log_->Trace("task finished");
auto ctx = std::make_shared<nf7::GenericContext>(env(), initiator());
ctx->description() = "luajit object cache";
target_->cache_ = std::make_shared<nf7::luajit::Ref>(ctx, ljq, idx);
co_yield target_->cache_;
} catch (Exception& e) {
log_->Error(e.msg());
throw;
}
int ExecLua(lua_State* L) {
void Compile(lua_State* L) {
static const auto kReader = [](lua_State*, void* selfptr, size_t* size) -> const char* {
auto self = reinterpret_cast<ExecTask*>(selfptr);
if (std::exchange(self->buf_consumed_, true)) {
@ -205,15 +238,6 @@ class Obj::ExecTask final : public nf7::Context, public std::enable_shared_from_
if (0 != lua_load(L, kReader, this, chunkname_.c_str())) {
throw nf7::Exception(lua_tostring(L, -1));
}
if (0 != nf7::luajit::SandboxCall(L, 0, 1)) {
throw nf7::Exception(lua_tostring(L, -1));
}
log_->Trace("executed lua script and got "s+lua_typename(L, lua_type(L, -1)));
const auto ret = luaL_ref(L, LUA_REGISTRYINDEX);
if (ret == LUA_REFNIL) {
throw nf7::Exception("got nil object");
}
return ret;
}
};