diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h index 375ac80d..ddad88e8 100644 --- a/include/proxy-wasm/context.h +++ b/include/proxy-wasm/context.h @@ -50,23 +50,20 @@ struct PluginBase { std::string_view runtime, std::string_view plugin_configuration, bool fail_open) : name_(std::string(name)), root_id_(std::string(root_id)), vm_id_(std::string(vm_id)), runtime_(std::string(runtime)), plugin_configuration_(plugin_configuration), - fail_open_(fail_open), key_(root_id_ + "||" + plugin_configuration_) {} + fail_open_(fail_open) {} const std::string name_; const std::string root_id_; const std::string vm_id_; const std::string runtime_; - const std::string plugin_configuration_; + std::string plugin_configuration_; const bool fail_open_; - - const std::string &key() const { return key_; } const std::string &log_prefix() const { return log_prefix_; } private: std::string makeLogPrefix() const; - const std::string key_; - const std::string log_prefix_; + std::string log_prefix_; }; struct BufferBase : public BufferInterface { @@ -376,16 +373,16 @@ class ContextBase : public RootInterface, protected: friend class WasmBase; + void initializeRootBase(WasmBase *wasm, std::shared_ptr plugin); std::string makeRootLogPrefix(std::string_view vm_id) const; WasmBase *wasm_{nullptr}; uint32_t id_{0}; - uint32_t parent_context_id_{0}; // 0 for roots and the general context. - ContextBase *parent_context_{nullptr}; // set in all contexts. - std::string root_id_; // set only in root context. - std::string root_log_prefix_; // set only in root context. - std::shared_ptr plugin_; // set in root and stream contexts. - std::shared_ptr temp_plugin_; // Remove once ABI v0.1.0 is gone. + uint32_t parent_context_id_{0}; // 0 for roots and the general context. + ContextBase *parent_context_{nullptr}; // set in all contexts. + std::string root_id_; // set only in root context. + std::string root_log_prefix_; // set only in root context. + std::shared_ptr plugin_; bool in_vm_context_created_ = false; bool destroyed_ = false; }; diff --git a/include/proxy-wasm/wasm.h b/include/proxy-wasm/wasm.h index a3896fc4..b2c694db 100644 --- a/include/proxy-wasm/wasm.h +++ b/include/proxy-wasm/wasm.h @@ -59,7 +59,8 @@ class WasmBase : public std::enable_shared_from_this { std::string_view vm_key() const { return vm_key_; } WasmVm *wasm_vm() const { return wasm_vm_.get(); } ContextBase *vm_context() const { return vm_context_.get(); } - ContextBase *getRootContext(const std::shared_ptr &plugin, bool allow_closed); + ContextBase *getRootContext(std::string_view root_id); + ContextBase *getOrCreateRootContext(const std::shared_ptr &plugin); ContextBase *getContext(uint32_t id) { auto it = contexts_.find(id); if (it != contexts_.end()) @@ -77,7 +78,6 @@ class WasmBase : public std::enable_shared_from_this { void timerReady(uint32_t root_context_id); void queueReady(uint32_t root_context_id, uint32_t token); - void startShutdown(const std::shared_ptr &plugin); void startShutdown(); WasmResult done(ContextBase *root_context); void finishShutdown(); @@ -164,12 +164,11 @@ class WasmBase : public std::enable_shared_from_this { uint32_t next_context_id_ = 1; // 0 is reserved for the VM context. std::shared_ptr vm_context_; // Context unrelated to any specific root or stream // (e.g. for global constructors). - std::unordered_map> root_contexts_; // Root contexts. - std::unordered_map> pending_done_; // Root contexts. - std::unordered_set> pending_delete_; // Root contexts. + std::unordered_map> root_contexts_; std::unordered_map contexts_; // Contains all contexts. std::unordered_map timer_period_; // per root_id. std::unique_ptr shutdown_handle_; + std::unordered_set pending_done_; // Root contexts not done during shutdown. WasmCallVoid<0> _initialize_; /* Emscripten v1.39.17+ */ WasmCallVoid<0> _start_; /* Emscripten v1.39.0+ */ diff --git a/src/context.cc b/src/context.cc index a90db849..c1929e14 100644 --- a/src/context.cc +++ b/src/context.cc @@ -269,10 +269,8 @@ ContextBase::ContextBase(WasmBase *wasm) : wasm_(wasm), parent_context_(this) { wasm_->contexts_[id_] = this; } -ContextBase::ContextBase(WasmBase *wasm, std::shared_ptr plugin) - : wasm_(wasm), id_(wasm->allocContextId()), parent_context_(this), root_id_(plugin->root_id_), - root_log_prefix_(makeRootLogPrefix(plugin->vm_id_)), plugin_(plugin) { - wasm_->contexts_[id_] = this; +ContextBase::ContextBase(WasmBase *wasm, std::shared_ptr plugin) { + initializeRootBase(wasm, plugin); } // NB: wasm can be nullptr if it failed to be created successfully. @@ -290,6 +288,15 @@ WasmVm *ContextBase::wasmVm() const { return wasm_->wasm_vm(); } bool ContextBase::isFailed() { return !wasm_ || wasm_->isFailed(); } +void ContextBase::initializeRootBase(WasmBase *wasm, std::shared_ptr plugin) { + wasm_ = wasm; + id_ = wasm->allocContextId(); + root_id_ = plugin->root_id_; + root_log_prefix_ = makeRootLogPrefix(plugin->vm_id_); + parent_context_ = this; + wasm_->contexts_[id_] = this; +} + std::string ContextBase::makeRootLogPrefix(std::string_view vm_id) const { std::string prefix; if (!root_id_.empty()) { @@ -308,10 +315,10 @@ bool ContextBase::onStart(std::shared_ptr plugin) { DeferAfterCallActions actions(this); bool result = true; if (wasm_->on_context_create_) { - temp_plugin_ = plugin; + plugin_ = plugin; wasm_->on_context_create_(this, id_, 0); in_vm_context_created_ = true; - temp_plugin_.reset(); + plugin_.reset(); } if (wasm_->on_vm_start_) { // Do not set plugin_ as the on_vm_start handler should be independent of the plugin since the @@ -343,11 +350,11 @@ bool ContextBase::onConfigure(std::shared_ptr plugin) { } DeferAfterCallActions actions(this); - temp_plugin_ = plugin; + plugin_ = plugin; auto result = wasm_->on_configure_(this, id_, static_cast(plugin->plugin_configuration_.size())) .u64_ != 0; - temp_plugin_.reset(); + plugin_.reset(); return result; } @@ -637,8 +644,8 @@ WasmResult ContextBase::setTimerPeriod(std::chrono::milliseconds period, } ContextBase::~ContextBase() { - // Do not remove vm context which has the same lifetime as wasm_. - if (id_) { + // Do not remove vm or root contexts which have the same lifetime as wasm_. + if (parent_context_id_) { wasm_->contexts_.erase(id_); } } diff --git a/src/wasm.cc b/src/wasm.cc index 939fd0a4..9472873d 100644 --- a/src/wasm.cc +++ b/src/wasm.cc @@ -280,11 +280,7 @@ WasmBase::WasmBase(std::unique_ptr wasm_vm, std::string_view vm_id, } } -WasmBase::~WasmBase() { - root_contexts_.clear(); - pending_done_.clear(); - pending_delete_.clear(); -} +WasmBase::~WasmBase() {} bool WasmBase::initialize(const std::string &code, bool allow_precompiled) { if (!wasm_vm_) { @@ -323,19 +319,22 @@ bool WasmBase::initialize(const std::string &code, bool allow_precompiled) { return !isFailed(); } -ContextBase *WasmBase::getRootContext(const std::shared_ptr &plugin, - bool allow_closed) { - auto it = root_contexts_.find(plugin->key()); - if (it != root_contexts_.end()) { - return it->second.get(); +ContextBase *WasmBase::getRootContext(std::string_view root_id) { + auto it = root_contexts_.find(std::string(root_id)); + if (it == root_contexts_.end()) { + return nullptr; } - if (allow_closed) { - it = pending_done_.find(plugin->key()); - if (it != pending_done_.end()) { - return it->second.get(); - } + return it->second.get(); +} + +ContextBase *WasmBase::getOrCreateRootContext(const std::shared_ptr &plugin) { + auto root_context = getRootContext(plugin->root_id_); + if (!root_context) { + auto context = std::unique_ptr(createRootContext(plugin)); + root_context = context.get(); + root_contexts_[plugin->root_id_] = std::move(context); } - return nullptr; + return root_context; } void WasmBase::startVm(ContextBase *root_context) { @@ -353,14 +352,15 @@ bool WasmBase::configure(ContextBase *root_context, std::shared_ptr } ContextBase *WasmBase::start(std::shared_ptr plugin) { - auto it = root_contexts_.find(plugin->key()); + auto root_id = plugin->root_id_; + auto it = root_contexts_.find(root_id); if (it != root_contexts_.end()) { it->second->onStart(plugin); return it->second.get(); } auto context = std::unique_ptr(createRootContext(plugin)); auto context_ptr = context.get(); - root_contexts_[plugin->key()] = std::move(context); + root_contexts_[root_id] = std::move(context); if (!context_ptr->onStart(plugin)) { return nullptr; } @@ -377,49 +377,38 @@ uint32_t WasmBase::allocContextId() { } } -void WasmBase::startShutdown(const std::shared_ptr &plugin) { - auto it = root_contexts_.find(plugin->key()); - if (it != root_contexts_.end()) { - if (it->second->onDone()) { - it->second->onDelete(); - } else { - pending_done_[it->first] = std::move(it->second); - } - root_contexts_.erase(it); - } -} - void WasmBase::startShutdown() { - auto it = root_contexts_.begin(); - while (it != root_contexts_.end()) { - if (it->second->onDone()) { - it->second->onDelete(); - } else { - pending_done_[it->first] = std::move(it->second); + bool all_done = true; + for (auto &p : root_contexts_) { + if (!p.second->onDone()) { + all_done = false; + pending_done_.insert(p.second.get()); } - it = root_contexts_.erase(it); + } + if (!all_done) { + shutdown_handle_ = std::make_unique(shared_from_this()); + } else { + finishShutdown(); } } WasmResult WasmBase::done(ContextBase *root_context) { - auto it = pending_done_.find(root_context->plugin_->key()); + auto it = pending_done_.find(root_context); if (it == pending_done_.end()) { return WasmResult::NotFound; } - pending_delete_.insert(std::move(it->second)); pending_done_.erase(it); - // Defer the delete so that onDelete is not called from within the done() handler. - shutdown_handle_ = std::make_unique(shared_from_this()); - addAfterVmCallAction( - [shutdown_handle = shutdown_handle_.release()]() { delete shutdown_handle; }); + if (pending_done_.empty() && shutdown_handle_) { + // Defer the delete so that onDelete is not called from within the done() handler. + addAfterVmCallAction( + [shutdown_handle = shutdown_handle_.release()]() { delete shutdown_handle; }); + } return WasmResult::Ok; } void WasmBase::finishShutdown() { - auto it = pending_delete_.begin(); - while (it != pending_delete_.end()) { - (*it)->onDelete(); - it = pending_delete_.erase(it); + for (auto &p : root_contexts_) { + p.second->onDelete(); } } @@ -531,18 +520,11 @@ getOrCreateThreadLocalWasm(std::shared_ptr base_wasm, WasmHandleCloneFactory clone_factory) { auto wasm_handle = getThreadLocalWasm(base_wasm->wasm()->vm_key()); if (wasm_handle) { - auto root_context = wasm_handle->wasm()->getRootContext(plugin, false); - if (!root_context) { - root_context = wasm_handle->wasm()->start(plugin); - if (!root_context) { - base_wasm->wasm()->fail(FailState::StartFailed, "Failed to start thread-local Wasm"); - return nullptr; - } - if (!wasm_handle->wasm()->configure(root_context, plugin)) { - base_wasm->wasm()->fail(FailState::ConfigureFailed, - "Failed to configure thread-local Wasm plugin"); - return nullptr; - } + auto root_context = wasm_handle->wasm()->getOrCreateRootContext(plugin); + if (!wasm_handle->wasm()->configure(root_context, plugin)) { + base_wasm->wasm()->fail(FailState::ConfigureFailed, + "Failed to configure thread-local Wasm code"); + return nullptr; } return wasm_handle; }