add luajit::Thread to support lua coroutine
This commit is contained in:
parent
58ba3071c2
commit
580b42c3fd
123
common/luajit_thread.hh
Normal file
123
common/luajit_thread.hh
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include <lua.hpp>
|
||||||
|
|
||||||
|
#include "nf7.hh"
|
||||||
|
|
||||||
|
#include "common/luajit_ref.hh"
|
||||||
|
|
||||||
|
|
||||||
|
namespace nf7::luajit {
|
||||||
|
|
||||||
|
class Thread final {
|
||||||
|
public:
|
||||||
|
static constexpr size_t kInstructionLimit = 10000000;
|
||||||
|
static constexpr const char* kInstanceName = "nf7::luajit::Thread::instance_";
|
||||||
|
|
||||||
|
enum State { kInitial, kRunning, kPaused, kFinished, kAborted, };
|
||||||
|
using Handler = std::function<void(State, lua_State*)>;
|
||||||
|
|
||||||
|
class Exception final : public nf7::Exception {
|
||||||
|
public:
|
||||||
|
using nf7::Exception::Exception;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static Thread CreateForPromise(nf7::Future<T>::Promise& pro, std::function<T(lua_State*)>&& f) noexcept {
|
||||||
|
return Thread([&pro, f = std::move(f)](auto state, auto L) {
|
||||||
|
switch (state) {
|
||||||
|
case kPaused:
|
||||||
|
pro.Throw(std::make_exception_ptr<nf7::Exception>({"unexpected yield"}));
|
||||||
|
break;
|
||||||
|
case kFinished:
|
||||||
|
pro.Wrap([&]() { return f(L); });
|
||||||
|
break;
|
||||||
|
case kAborted:
|
||||||
|
pro.Throw(std::make_exception_ptr<nf7::Exception>({lua_tostring(L, -1)}));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert(false);
|
||||||
|
throw 0;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Thread() = delete;
|
||||||
|
Thread(Handler&& handler) noexcept : handler_(std::move(handler)) {
|
||||||
|
}
|
||||||
|
Thread(const Thread&) = delete;
|
||||||
|
Thread(Thread&&) = delete;
|
||||||
|
Thread& operator=(const Thread&) = delete;
|
||||||
|
Thread& operator=(Thread&&) = delete;
|
||||||
|
|
||||||
|
// must be called on luajit thread
|
||||||
|
// be carefully on recursive reference, ctx is held until *this is destructed.
|
||||||
|
lua_State* Init(const std::shared_ptr<nf7::Context>& ctx,
|
||||||
|
const std::shared_ptr<nf7::luajit::Queue>& q,
|
||||||
|
lua_State* L) noexcept {
|
||||||
|
assert(state_ == kInitial);
|
||||||
|
|
||||||
|
th_ = lua_newthread(L);
|
||||||
|
PushImmEnv(L);
|
||||||
|
lua_setfenv(L, -2);
|
||||||
|
th_ref_.emplace(ctx, q, luaL_ref(L, LUA_REGISTRYINDEX));
|
||||||
|
|
||||||
|
state_ = kPaused;
|
||||||
|
return th_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// must be called on luajit thread
|
||||||
|
void Resume(lua_State* L, int narg) noexcept {
|
||||||
|
assert(L == th_);
|
||||||
|
assert(state_ == kPaused);
|
||||||
|
(void) L;
|
||||||
|
|
||||||
|
static const auto kHook = [](auto L, auto) {
|
||||||
|
luaL_error(L, "reached instruction limit (<=1e7)");
|
||||||
|
};
|
||||||
|
lua_sethook(th_, kHook, LUA_MASKCOUNT, kInstructionLimit);
|
||||||
|
|
||||||
|
lua_pushstring(th_, kInstanceName);
|
||||||
|
lua_pushlightuserdata(th_, this);
|
||||||
|
lua_rawset(th_, LUA_REGISTRYINDEX);
|
||||||
|
|
||||||
|
state_ = kRunning;
|
||||||
|
switch (lua_resume(th_, narg)) {
|
||||||
|
case 0:
|
||||||
|
state_ = kFinished;
|
||||||
|
break;
|
||||||
|
case LUA_YIELD:
|
||||||
|
state_ = kPaused;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
state_ = kAborted;
|
||||||
|
}
|
||||||
|
if (!std::exchange(skip_handle_, false)) {
|
||||||
|
handler_(state_, th_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handler_ won't be called on next yielding
|
||||||
|
void ExpectYield() noexcept {
|
||||||
|
skip_handle_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
State state() const noexcept { return state_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Handler handler_;
|
||||||
|
std::atomic<State> state_ = kInitial;
|
||||||
|
|
||||||
|
lua_State* th_ = nullptr;
|
||||||
|
std::optional<nf7::luajit::Ref> th_ref_;
|
||||||
|
|
||||||
|
bool skip_handle_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nf7::luajit
|
@ -21,6 +21,7 @@
|
|||||||
#include "common/luajit.hh"
|
#include "common/luajit.hh"
|
||||||
#include "common/luajit_obj.hh"
|
#include "common/luajit_obj.hh"
|
||||||
#include "common/luajit_queue.hh"
|
#include "common/luajit_queue.hh"
|
||||||
|
#include "common/luajit_thread.hh"
|
||||||
#include "common/logger_ref.hh"
|
#include "common/logger_ref.hh"
|
||||||
#include "common/ptr_selector.hh"
|
#include "common/ptr_selector.hh"
|
||||||
#include "common/yas_nf7.hh"
|
#include "common/yas_nf7.hh"
|
||||||
@ -148,50 +149,82 @@ class Obj::ExecTask final : public nf7::Context, public std::enable_shared_from_
|
|||||||
bool buf_consumed_ = false;
|
bool buf_consumed_ = false;
|
||||||
|
|
||||||
|
|
||||||
nf7::Future<std::shared_ptr<nf7::luajit::Ref>>::Coro Proc() noexcept
|
nf7::Future<std::shared_ptr<nf7::luajit::Ref>>::Coro Proc() noexcept {
|
||||||
try {
|
try {
|
||||||
auto self = shared_from_this();
|
auto self = shared_from_this();
|
||||||
|
|
||||||
auto& srcf = *target_->src_;
|
auto& srcf = *target_->src_;
|
||||||
chunkname_ = srcf.abspath().Stringify();
|
chunkname_ = srcf.abspath().Stringify();
|
||||||
|
|
||||||
auto src = srcf.interfaceOrThrow<nf7::AsyncBuffer>().self();
|
// acquire lock of source
|
||||||
auto srclock = co_await src->AcquireLock(false).awaiter(self);
|
auto src = srcf.interfaceOrThrow<nf7::AsyncBuffer>().self();
|
||||||
log_->Trace("source file lock acquired");
|
auto srclock = co_await src->AcquireLock(false).awaiter(self);
|
||||||
|
log_->Trace("source file lock acquired");
|
||||||
|
|
||||||
buf_size_ = co_await src->size().awaiter(self);
|
// get size of source
|
||||||
if (buf_size_ == 0) {
|
buf_size_ = co_await src->size().awaiter(self);
|
||||||
throw nf7::Exception("source is empty");
|
if (buf_size_ == 0) {
|
||||||
|
throw nf7::Exception("source is empty");
|
||||||
|
}
|
||||||
|
if (buf_size_ > kMaxSize) {
|
||||||
|
throw nf7::Exception("source is too huge");
|
||||||
|
}
|
||||||
|
|
||||||
|
// read source
|
||||||
|
buf_.resize(buf_size_);
|
||||||
|
const size_t read = co_await src->Read(0, buf_.data(), buf_size_).awaiter(self);
|
||||||
|
if (read != buf_size_) {
|
||||||
|
throw nf7::Exception("failed to read all bytes from source");
|
||||||
|
}
|
||||||
|
|
||||||
|
// create thread to compile lua script
|
||||||
|
nf7::Future<int>::Promise lua_pro;
|
||||||
|
auto th = nf7::luajit::Thread::CreateForPromise<int>(lua_pro, [&](auto L) {
|
||||||
|
if (lua_gettop(L) != 1) {
|
||||||
|
throw nf7::Exception("expected one object to be returned");
|
||||||
|
}
|
||||||
|
log_->Info("got '"s+lua_tostring(L, -1)+"'");
|
||||||
|
return luaL_ref(L, LUA_REGISTRYINDEX);
|
||||||
|
});
|
||||||
|
|
||||||
|
// context for luajit script running
|
||||||
|
auto lua_ctx = std::make_shared<nf7::GenericContext>(env(), initiator());
|
||||||
|
lua_ctx->description() = "luajit object build script runner";
|
||||||
|
|
||||||
|
// queue task to trigger the thread
|
||||||
|
auto ljq = target_->
|
||||||
|
ResolveUpwardOrThrow("_luajit").
|
||||||
|
interfaceOrThrow<nf7::luajit::Queue>().self();
|
||||||
|
ljq->Push(self, [&](auto L) {
|
||||||
|
try {
|
||||||
|
auto thL = th.Init(lua_ctx, ljq, L);
|
||||||
|
Compile(thL);
|
||||||
|
th.Resume(thL, 0);
|
||||||
|
} catch (Exception&) {
|
||||||
|
lua_pro.Throw(std::current_exception());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// wait for end of execution and return built object's index
|
||||||
|
const int idx = co_await lua_pro.future().awaiter(self);
|
||||||
|
log_->Trace("task finished");
|
||||||
|
|
||||||
|
// context for object cache
|
||||||
|
// TODO use specific Context type
|
||||||
|
auto ctx = std::make_shared<nf7::GenericContext>(env(), initiator());
|
||||||
|
ctx->description() = "luajit object cache";
|
||||||
|
|
||||||
|
// return the object and cache it
|
||||||
|
target_->cache_ = std::make_shared<nf7::luajit::Ref>(ctx, ljq, idx);
|
||||||
|
co_yield target_->cache_;
|
||||||
|
|
||||||
|
} catch (Exception& e) {
|
||||||
|
log_->Error(e.msg());
|
||||||
|
throw;
|
||||||
}
|
}
|
||||||
if (buf_size_ > kMaxSize) {
|
|
||||||
throw nf7::Exception("source is too huge");
|
|
||||||
}
|
|
||||||
|
|
||||||
buf_.resize(buf_size_);
|
|
||||||
const size_t read = co_await src->Read(0, buf_.data(), buf_size_).awaiter(self);
|
|
||||||
if (read != buf_size_) {
|
|
||||||
throw nf7::Exception("failed to read all bytes from source");
|
|
||||||
}
|
|
||||||
|
|
||||||
nf7::Future<int>::Promise lua_pro;
|
|
||||||
auto ljq = target_->
|
|
||||||
ResolveUpwardOrThrow("_luajit").
|
|
||||||
interfaceOrThrow<nf7::luajit::Queue>().self();
|
|
||||||
ljq->Push(self, [&](auto L) { lua_pro.Wrap([&]() { return ExecLua(L); }); });
|
|
||||||
const int idx = co_await lua_pro.future().awaiter(self);
|
|
||||||
log_->Trace("task finished");
|
|
||||||
|
|
||||||
auto ctx = std::make_shared<nf7::GenericContext>(env(), initiator());
|
|
||||||
ctx->description() = "luajit object cache";
|
|
||||||
target_->cache_ = std::make_shared<nf7::luajit::Ref>(ctx, ljq, idx);
|
|
||||||
co_yield target_->cache_;
|
|
||||||
|
|
||||||
} catch (Exception& e) {
|
|
||||||
log_->Error(e.msg());
|
|
||||||
throw;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int ExecLua(lua_State* L) {
|
void Compile(lua_State* L) {
|
||||||
static const auto kReader = [](lua_State*, void* selfptr, size_t* size) -> const char* {
|
static const auto kReader = [](lua_State*, void* selfptr, size_t* size) -> const char* {
|
||||||
auto self = reinterpret_cast<ExecTask*>(selfptr);
|
auto self = reinterpret_cast<ExecTask*>(selfptr);
|
||||||
if (std::exchange(self->buf_consumed_, true)) {
|
if (std::exchange(self->buf_consumed_, true)) {
|
||||||
@ -205,15 +238,6 @@ class Obj::ExecTask final : public nf7::Context, public std::enable_shared_from_
|
|||||||
if (0 != lua_load(L, kReader, this, chunkname_.c_str())) {
|
if (0 != lua_load(L, kReader, this, chunkname_.c_str())) {
|
||||||
throw nf7::Exception(lua_tostring(L, -1));
|
throw nf7::Exception(lua_tostring(L, -1));
|
||||||
}
|
}
|
||||||
if (0 != nf7::luajit::SandboxCall(L, 0, 1)) {
|
|
||||||
throw nf7::Exception(lua_tostring(L, -1));
|
|
||||||
}
|
|
||||||
log_->Trace("executed lua script and got "s+lua_typename(L, lua_type(L, -1)));
|
|
||||||
const auto ret = luaL_ref(L, LUA_REGISTRYINDEX);
|
|
||||||
if (ret == LUA_REFNIL) {
|
|
||||||
throw nf7::Exception("got nil object");
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user