add luajit::Thread to support lua coroutine
This commit is contained in:
parent
58ba3071c2
commit
580b42c3fd
123
common/luajit_thread.hh
Normal file
123
common/luajit_thread.hh
Normal file
@ -0,0 +1,123 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include <lua.hpp>
|
||||
|
||||
#include "nf7.hh"
|
||||
|
||||
#include "common/luajit_ref.hh"
|
||||
|
||||
|
||||
namespace nf7::luajit {
|
||||
|
||||
class Thread final {
|
||||
public:
|
||||
static constexpr size_t kInstructionLimit = 10000000;
|
||||
static constexpr const char* kInstanceName = "nf7::luajit::Thread::instance_";
|
||||
|
||||
enum State { kInitial, kRunning, kPaused, kFinished, kAborted, };
|
||||
using Handler = std::function<void(State, lua_State*)>;
|
||||
|
||||
class Exception final : public nf7::Exception {
|
||||
public:
|
||||
using nf7::Exception::Exception;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static Thread CreateForPromise(nf7::Future<T>::Promise& pro, std::function<T(lua_State*)>&& f) noexcept {
|
||||
return Thread([&pro, f = std::move(f)](auto state, auto L) {
|
||||
switch (state) {
|
||||
case kPaused:
|
||||
pro.Throw(std::make_exception_ptr<nf7::Exception>({"unexpected yield"}));
|
||||
break;
|
||||
case kFinished:
|
||||
pro.Wrap([&]() { return f(L); });
|
||||
break;
|
||||
case kAborted:
|
||||
pro.Throw(std::make_exception_ptr<nf7::Exception>({lua_tostring(L, -1)}));
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
throw 0;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Thread() = delete;
|
||||
Thread(Handler&& handler) noexcept : handler_(std::move(handler)) {
|
||||
}
|
||||
Thread(const Thread&) = delete;
|
||||
Thread(Thread&&) = delete;
|
||||
Thread& operator=(const Thread&) = delete;
|
||||
Thread& operator=(Thread&&) = delete;
|
||||
|
||||
// must be called on luajit thread
|
||||
// be carefully on recursive reference, ctx is held until *this is destructed.
|
||||
lua_State* Init(const std::shared_ptr<nf7::Context>& ctx,
|
||||
const std::shared_ptr<nf7::luajit::Queue>& q,
|
||||
lua_State* L) noexcept {
|
||||
assert(state_ == kInitial);
|
||||
|
||||
th_ = lua_newthread(L);
|
||||
PushImmEnv(L);
|
||||
lua_setfenv(L, -2);
|
||||
th_ref_.emplace(ctx, q, luaL_ref(L, LUA_REGISTRYINDEX));
|
||||
|
||||
state_ = kPaused;
|
||||
return th_;
|
||||
}
|
||||
|
||||
// must be called on luajit thread
|
||||
void Resume(lua_State* L, int narg) noexcept {
|
||||
assert(L == th_);
|
||||
assert(state_ == kPaused);
|
||||
(void) L;
|
||||
|
||||
static const auto kHook = [](auto L, auto) {
|
||||
luaL_error(L, "reached instruction limit (<=1e7)");
|
||||
};
|
||||
lua_sethook(th_, kHook, LUA_MASKCOUNT, kInstructionLimit);
|
||||
|
||||
lua_pushstring(th_, kInstanceName);
|
||||
lua_pushlightuserdata(th_, this);
|
||||
lua_rawset(th_, LUA_REGISTRYINDEX);
|
||||
|
||||
state_ = kRunning;
|
||||
switch (lua_resume(th_, narg)) {
|
||||
case 0:
|
||||
state_ = kFinished;
|
||||
break;
|
||||
case LUA_YIELD:
|
||||
state_ = kPaused;
|
||||
break;
|
||||
default:
|
||||
state_ = kAborted;
|
||||
}
|
||||
if (!std::exchange(skip_handle_, false)) {
|
||||
handler_(state_, th_);
|
||||
}
|
||||
}
|
||||
|
||||
// handler_ won't be called on next yielding
|
||||
void ExpectYield() noexcept {
|
||||
skip_handle_ = true;
|
||||
}
|
||||
|
||||
State state() const noexcept { return state_; }
|
||||
|
||||
private:
|
||||
Handler handler_;
|
||||
std::atomic<State> state_ = kInitial;
|
||||
|
||||
lua_State* th_ = nullptr;
|
||||
std::optional<nf7::luajit::Ref> th_ref_;
|
||||
|
||||
bool skip_handle_ = false;
|
||||
};
|
||||
|
||||
} // namespace nf7::luajit
|
@ -21,6 +21,7 @@
|
||||
#include "common/luajit.hh"
|
||||
#include "common/luajit_obj.hh"
|
||||
#include "common/luajit_queue.hh"
|
||||
#include "common/luajit_thread.hh"
|
||||
#include "common/logger_ref.hh"
|
||||
#include "common/ptr_selector.hh"
|
||||
#include "common/yas_nf7.hh"
|
||||
@ -148,50 +149,82 @@ class Obj::ExecTask final : public nf7::Context, public std::enable_shared_from_
|
||||
bool buf_consumed_ = false;
|
||||
|
||||
|
||||
nf7::Future<std::shared_ptr<nf7::luajit::Ref>>::Coro Proc() noexcept
|
||||
try {
|
||||
auto self = shared_from_this();
|
||||
nf7::Future<std::shared_ptr<nf7::luajit::Ref>>::Coro Proc() noexcept {
|
||||
try {
|
||||
auto self = shared_from_this();
|
||||
|
||||
auto& srcf = *target_->src_;
|
||||
chunkname_ = srcf.abspath().Stringify();
|
||||
auto& srcf = *target_->src_;
|
||||
chunkname_ = srcf.abspath().Stringify();
|
||||
|
||||
auto src = srcf.interfaceOrThrow<nf7::AsyncBuffer>().self();
|
||||
auto srclock = co_await src->AcquireLock(false).awaiter(self);
|
||||
log_->Trace("source file lock acquired");
|
||||
// acquire lock of source
|
||||
auto src = srcf.interfaceOrThrow<nf7::AsyncBuffer>().self();
|
||||
auto srclock = co_await src->AcquireLock(false).awaiter(self);
|
||||
log_->Trace("source file lock acquired");
|
||||
|
||||
buf_size_ = co_await src->size().awaiter(self);
|
||||
if (buf_size_ == 0) {
|
||||
throw nf7::Exception("source is empty");
|
||||
// get size of source
|
||||
buf_size_ = co_await src->size().awaiter(self);
|
||||
if (buf_size_ == 0) {
|
||||
throw nf7::Exception("source is empty");
|
||||
}
|
||||
if (buf_size_ > kMaxSize) {
|
||||
throw nf7::Exception("source is too huge");
|
||||
}
|
||||
|
||||
// read source
|
||||
buf_.resize(buf_size_);
|
||||
const size_t read = co_await src->Read(0, buf_.data(), buf_size_).awaiter(self);
|
||||
if (read != buf_size_) {
|
||||
throw nf7::Exception("failed to read all bytes from source");
|
||||
}
|
||||
|
||||
// create thread to compile lua script
|
||||
nf7::Future<int>::Promise lua_pro;
|
||||
auto th = nf7::luajit::Thread::CreateForPromise<int>(lua_pro, [&](auto L) {
|
||||
if (lua_gettop(L) != 1) {
|
||||
throw nf7::Exception("expected one object to be returned");
|
||||
}
|
||||
log_->Info("got '"s+lua_tostring(L, -1)+"'");
|
||||
return luaL_ref(L, LUA_REGISTRYINDEX);
|
||||
});
|
||||
|
||||
// context for luajit script running
|
||||
auto lua_ctx = std::make_shared<nf7::GenericContext>(env(), initiator());
|
||||
lua_ctx->description() = "luajit object build script runner";
|
||||
|
||||
// queue task to trigger the thread
|
||||
auto ljq = target_->
|
||||
ResolveUpwardOrThrow("_luajit").
|
||||
interfaceOrThrow<nf7::luajit::Queue>().self();
|
||||
ljq->Push(self, [&](auto L) {
|
||||
try {
|
||||
auto thL = th.Init(lua_ctx, ljq, L);
|
||||
Compile(thL);
|
||||
th.Resume(thL, 0);
|
||||
} catch (Exception&) {
|
||||
lua_pro.Throw(std::current_exception());
|
||||
}
|
||||
});
|
||||
|
||||
// wait for end of execution and return built object's index
|
||||
const int idx = co_await lua_pro.future().awaiter(self);
|
||||
log_->Trace("task finished");
|
||||
|
||||
// context for object cache
|
||||
// TODO use specific Context type
|
||||
auto ctx = std::make_shared<nf7::GenericContext>(env(), initiator());
|
||||
ctx->description() = "luajit object cache";
|
||||
|
||||
// return the object and cache it
|
||||
target_->cache_ = std::make_shared<nf7::luajit::Ref>(ctx, ljq, idx);
|
||||
co_yield target_->cache_;
|
||||
|
||||
} catch (Exception& e) {
|
||||
log_->Error(e.msg());
|
||||
throw;
|
||||
}
|
||||
if (buf_size_ > kMaxSize) {
|
||||
throw nf7::Exception("source is too huge");
|
||||
}
|
||||
|
||||
buf_.resize(buf_size_);
|
||||
const size_t read = co_await src->Read(0, buf_.data(), buf_size_).awaiter(self);
|
||||
if (read != buf_size_) {
|
||||
throw nf7::Exception("failed to read all bytes from source");
|
||||
}
|
||||
|
||||
nf7::Future<int>::Promise lua_pro;
|
||||
auto ljq = target_->
|
||||
ResolveUpwardOrThrow("_luajit").
|
||||
interfaceOrThrow<nf7::luajit::Queue>().self();
|
||||
ljq->Push(self, [&](auto L) { lua_pro.Wrap([&]() { return ExecLua(L); }); });
|
||||
const int idx = co_await lua_pro.future().awaiter(self);
|
||||
log_->Trace("task finished");
|
||||
|
||||
auto ctx = std::make_shared<nf7::GenericContext>(env(), initiator());
|
||||
ctx->description() = "luajit object cache";
|
||||
target_->cache_ = std::make_shared<nf7::luajit::Ref>(ctx, ljq, idx);
|
||||
co_yield target_->cache_;
|
||||
|
||||
} catch (Exception& e) {
|
||||
log_->Error(e.msg());
|
||||
throw;
|
||||
}
|
||||
|
||||
int ExecLua(lua_State* L) {
|
||||
void Compile(lua_State* L) {
|
||||
static const auto kReader = [](lua_State*, void* selfptr, size_t* size) -> const char* {
|
||||
auto self = reinterpret_cast<ExecTask*>(selfptr);
|
||||
if (std::exchange(self->buf_consumed_, true)) {
|
||||
@ -205,15 +238,6 @@ class Obj::ExecTask final : public nf7::Context, public std::enable_shared_from_
|
||||
if (0 != lua_load(L, kReader, this, chunkname_.c_str())) {
|
||||
throw nf7::Exception(lua_tostring(L, -1));
|
||||
}
|
||||
if (0 != nf7::luajit::SandboxCall(L, 0, 1)) {
|
||||
throw nf7::Exception(lua_tostring(L, -1));
|
||||
}
|
||||
log_->Trace("executed lua script and got "s+lua_typename(L, lua_type(L, -1)));
|
||||
const auto ret = luaL_ref(L, LUA_REGISTRYINDEX);
|
||||
if (ret == LUA_REFNIL) {
|
||||
throw nf7::Exception("got nil object");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user