diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h index 2313d9cd..51d7dbbd 100644 --- a/include/proxy-wasm/context.h +++ b/include/proxy-wasm/context.h @@ -50,20 +50,23 @@ struct PluginBase { std::string_view runtime, std::string_view plugin_configuration, bool fail_open) : name_(std::string(name)), root_id_(std::string(root_id)), vm_id_(std::string(vm_id)), runtime_(std::string(runtime)), plugin_configuration_(plugin_configuration), - fail_open_(fail_open) {} + fail_open_(fail_open), key_(root_id_ + "||" + plugin_configuration_) {} const std::string name_; const std::string root_id_; const std::string vm_id_; const std::string runtime_; - std::string plugin_configuration_; + const std::string plugin_configuration_; const bool fail_open_; + + const std::string &key() const { return key_; } const std::string &log_prefix() const { return log_prefix_; } private: std::string makeLogPrefix() const; - std::string log_prefix_; + const std::string key_; + const std::string log_prefix_; }; struct BufferBase : public BufferInterface { @@ -373,16 +376,16 @@ class ContextBase : public RootInterface, protected: friend class WasmBase; - void initializeRootBase(WasmBase *wasm, std::shared_ptr plugin); std::string makeRootLogPrefix(std::string_view vm_id) const; WasmBase *wasm_{nullptr}; uint32_t id_{0}; - uint32_t parent_context_id_{0}; // 0 for roots and the general context. - ContextBase *parent_context_{nullptr}; // set in all contexts. - std::string root_id_; // set only in root context. - std::string root_log_prefix_; // set only in root context. - std::shared_ptr plugin_; + uint32_t parent_context_id_{0}; // 0 for roots and the general context. + ContextBase *parent_context_{nullptr}; // set in all contexts. + std::string root_id_; // set only in root context. + std::string root_log_prefix_; // set only in root context. + std::shared_ptr plugin_; // set in root and stream contexts. + std::shared_ptr temp_plugin_; // Remove once ABI v0.1.0 is gone. bool in_vm_context_created_ = false; bool destroyed_ = false; diff --git a/include/proxy-wasm/wasm.h b/include/proxy-wasm/wasm.h index c71c5848..53d0b638 100644 --- a/include/proxy-wasm/wasm.h +++ b/include/proxy-wasm/wasm.h @@ -59,8 +59,7 @@ class WasmBase : public std::enable_shared_from_this { std::string_view vm_key() const { return vm_key_; } WasmVm *wasm_vm() const { return wasm_vm_.get(); } ContextBase *vm_context() const { return vm_context_.get(); } - ContextBase *getRootContext(std::string_view root_id); - ContextBase *getOrCreateRootContext(const std::shared_ptr &plugin); + ContextBase *getRootContext(const std::shared_ptr &plugin, bool allow_closed); ContextBase *getContext(uint32_t id) { auto it = contexts_.find(id); if (it != contexts_.end()) @@ -78,6 +77,7 @@ class WasmBase : public std::enable_shared_from_this { void timerReady(uint32_t root_context_id); void queueReady(uint32_t root_context_id, uint32_t token); + void startShutdown(std::string_view plugin_key); void startShutdown(); WasmResult done(ContextBase *root_context); void finishShutdown(); @@ -170,11 +170,12 @@ class WasmBase : public std::enable_shared_from_this { uint32_t next_context_id_ = 1; // 0 is reserved for the VM context. std::shared_ptr vm_context_; // Context unrelated to any specific root or stream // (e.g. for global constructors). - std::unordered_map> root_contexts_; + std::unordered_map> root_contexts_; // Root contexts. + std::unordered_map> pending_done_; // Root contexts. + std::unordered_set> pending_delete_; // Root contexts. std::unordered_map contexts_; // Contains all contexts. std::unordered_map timer_period_; // per root_id. std::unique_ptr shutdown_handle_; - std::unordered_set pending_done_; // Root contexts not done during shutdown. WasmCallVoid<0> _initialize_; /* Emscripten v1.39.17+ */ WasmCallVoid<0> _start_; /* Emscripten v1.39.0+ */ @@ -275,11 +276,29 @@ createWasm(std::string vm_key, std::string code, std::shared_ptr plu WasmHandleFactory factory, WasmHandleCloneFactory clone_factory, bool allow_precompiled); // Get an existing ThreadLocal VM matching 'vm_id' or nullptr if there isn't one. std::shared_ptr getThreadLocalWasm(std::string_view vm_id); + +class PluginHandleBase : public std::enable_shared_from_this { +public: + explicit PluginHandleBase(std::shared_ptr wasm_handle, + std::string_view plugin_key) + : wasm_handle_(wasm_handle), plugin_key_(plugin_key) {} + ~PluginHandleBase() { wasm_handle_->wasm()->startShutdown(plugin_key_); } + + std::shared_ptr &wasm() { return wasm_handle_->wasm(); } + +protected: + std::shared_ptr wasm_handle_; + std::string plugin_key_; +}; + +using PluginHandleFactory = std::function( + std::shared_ptr base_wasm, std::string_view plugin_key)>; + // Get an existing ThreadLocal VM matching 'vm_id' or create one using 'base_wavm' by cloning or by // using it it as a template. -std::shared_ptr -getOrCreateThreadLocalWasm(std::shared_ptr base_wasm, - std::shared_ptr plugin, WasmHandleCloneFactory factory); +std::shared_ptr getOrCreateThreadLocalPlugin( + std::shared_ptr base_wasm, std::shared_ptr plugin, + WasmHandleCloneFactory clone_factory, PluginHandleFactory plugin_factory); // Clear Base Wasm cache and the thread-local Wasm sandbox cache for the calling thread. void clearWasmCachesForTesting(); diff --git a/src/context.cc b/src/context.cc index 89daa97a..33fab486 100644 --- a/src/context.cc +++ b/src/context.cc @@ -272,8 +272,10 @@ ContextBase::ContextBase(WasmBase *wasm) : wasm_(wasm), parent_context_(this) { wasm_->contexts_[id_] = this; } -ContextBase::ContextBase(WasmBase *wasm, std::shared_ptr plugin) { - initializeRootBase(wasm, plugin); +ContextBase::ContextBase(WasmBase *wasm, std::shared_ptr plugin) + : wasm_(wasm), id_(wasm->allocContextId()), parent_context_(this), root_id_(plugin->root_id_), + root_log_prefix_(makeRootLogPrefix(plugin->vm_id_)), plugin_(plugin) { + wasm_->contexts_[id_] = this; } // NB: wasm can be nullptr if it failed to be created successfully. @@ -291,15 +293,6 @@ WasmVm *ContextBase::wasmVm() const { return wasm_->wasm_vm(); } bool ContextBase::isFailed() { return !wasm_ || wasm_->isFailed(); } -void ContextBase::initializeRootBase(WasmBase *wasm, std::shared_ptr plugin) { - wasm_ = wasm; - id_ = wasm->allocContextId(); - root_id_ = plugin->root_id_; - root_log_prefix_ = makeRootLogPrefix(plugin->vm_id_); - parent_context_ = this; - wasm_->contexts_[id_] = this; -} - std::string ContextBase::makeRootLogPrefix(std::string_view vm_id) const { std::string prefix; if (!root_id_.empty()) { @@ -318,10 +311,10 @@ bool ContextBase::onStart(std::shared_ptr plugin) { DeferAfterCallActions actions(this); bool result = true; if (wasm_->on_context_create_) { - plugin_ = plugin; + temp_plugin_ = plugin; wasm_->on_context_create_(this, id_, 0); in_vm_context_created_ = true; - plugin_.reset(); + temp_plugin_.reset(); } if (wasm_->on_vm_start_) { // Do not set plugin_ as the on_vm_start handler should be independent of the plugin since the @@ -353,11 +346,11 @@ bool ContextBase::onConfigure(std::shared_ptr plugin) { } DeferAfterCallActions actions(this); - plugin_ = plugin; + temp_plugin_ = plugin; auto result = wasm_->on_configure_(this, id_, static_cast(plugin->plugin_configuration_.size())) .u64_ != 0; - plugin_.reset(); + temp_plugin_.reset(); return result; } @@ -656,8 +649,8 @@ FilterMetadataStatus ContextBase::convertVmCallResultToFilterMetadataStatus(uint } ContextBase::~ContextBase() { - // Do not remove vm or root contexts which have the same lifetime as wasm_. - if (parent_context_id_) { + // Do not remove vm context which has the same lifetime as wasm_. + if (id_) { wasm_->contexts_.erase(id_); } } diff --git a/src/wasm.cc b/src/wasm.cc index 9472873d..b83064a7 100644 --- a/src/wasm.cc +++ b/src/wasm.cc @@ -37,6 +37,7 @@ namespace { // Map from Wasm Key to the local Wasm instance. thread_local std::unordered_map> local_wasms; +thread_local std::unordered_map> local_plugins; // Map from Wasm Key to the base Wasm instance, using a pointer to avoid the initialization fiasco. std::mutex base_wasms_mutex; std::unordered_map> *base_wasms = nullptr; @@ -280,7 +281,11 @@ WasmBase::WasmBase(std::unique_ptr wasm_vm, std::string_view vm_id, } } -WasmBase::~WasmBase() {} +WasmBase::~WasmBase() { + root_contexts_.clear(); + pending_done_.clear(); + pending_delete_.clear(); +} bool WasmBase::initialize(const std::string &code, bool allow_precompiled) { if (!wasm_vm_) { @@ -319,22 +324,19 @@ bool WasmBase::initialize(const std::string &code, bool allow_precompiled) { return !isFailed(); } -ContextBase *WasmBase::getRootContext(std::string_view root_id) { - auto it = root_contexts_.find(std::string(root_id)); - if (it == root_contexts_.end()) { - return nullptr; +ContextBase *WasmBase::getRootContext(const std::shared_ptr &plugin, + bool allow_closed) { + auto it = root_contexts_.find(plugin->key()); + if (it != root_contexts_.end()) { + return it->second.get(); } - return it->second.get(); -} - -ContextBase *WasmBase::getOrCreateRootContext(const std::shared_ptr &plugin) { - auto root_context = getRootContext(plugin->root_id_); - if (!root_context) { - auto context = std::unique_ptr(createRootContext(plugin)); - root_context = context.get(); - root_contexts_[plugin->root_id_] = std::move(context); + if (allow_closed) { + it = pending_done_.find(plugin->key()); + if (it != pending_done_.end()) { + return it->second.get(); + } } - return root_context; + return nullptr; } void WasmBase::startVm(ContextBase *root_context) { @@ -352,15 +354,14 @@ bool WasmBase::configure(ContextBase *root_context, std::shared_ptr } ContextBase *WasmBase::start(std::shared_ptr plugin) { - auto root_id = plugin->root_id_; - auto it = root_contexts_.find(root_id); + auto it = root_contexts_.find(plugin->key()); if (it != root_contexts_.end()) { it->second->onStart(plugin); return it->second.get(); } auto context = std::unique_ptr(createRootContext(plugin)); auto context_ptr = context.get(); - root_contexts_[root_id] = std::move(context); + root_contexts_[plugin->key()] = std::move(context); if (!context_ptr->onStart(plugin)) { return nullptr; } @@ -377,38 +378,49 @@ uint32_t WasmBase::allocContextId() { } } -void WasmBase::startShutdown() { - bool all_done = true; - for (auto &p : root_contexts_) { - if (!p.second->onDone()) { - all_done = false; - pending_done_.insert(p.second.get()); +void WasmBase::startShutdown(std::string_view plugin_key) { + auto it = root_contexts_.find(std::string(plugin_key)); + if (it != root_contexts_.end()) { + if (it->second->onDone()) { + it->second->onDelete(); + } else { + pending_done_[it->first] = std::move(it->second); } + root_contexts_.erase(it); } - if (!all_done) { - shutdown_handle_ = std::make_unique(shared_from_this()); - } else { - finishShutdown(); +} + +void WasmBase::startShutdown() { + auto it = root_contexts_.begin(); + while (it != root_contexts_.end()) { + if (it->second->onDone()) { + it->second->onDelete(); + } else { + pending_done_[it->first] = std::move(it->second); + } + it = root_contexts_.erase(it); } } WasmResult WasmBase::done(ContextBase *root_context) { - auto it = pending_done_.find(root_context); + auto it = pending_done_.find(root_context->plugin_->key()); if (it == pending_done_.end()) { return WasmResult::NotFound; } + pending_delete_.insert(std::move(it->second)); pending_done_.erase(it); - if (pending_done_.empty() && shutdown_handle_) { - // Defer the delete so that onDelete is not called from within the done() handler. - addAfterVmCallAction( - [shutdown_handle = shutdown_handle_.release()]() { delete shutdown_handle; }); - } + // Defer the delete so that onDelete is not called from within the done() handler. + shutdown_handle_ = std::make_unique(shared_from_this()); + addAfterVmCallAction( + [shutdown_handle = shutdown_handle_.release()]() { delete shutdown_handle; }); return WasmResult::Ok; } void WasmBase::finishShutdown() { - for (auto &p : root_contexts_) { - p.second->onDelete(); + auto it = pending_delete_.begin(); + while (it != pending_delete_.end()) { + (*it)->onDelete(); + it = pending_delete_.erase(it); } } @@ -475,33 +487,6 @@ std::shared_ptr createWasm(std::string vm_key, std::string code, return wasm_handle; }; -static std::shared_ptr -createThreadLocalWasm(std::shared_ptr &base_wasm, - std::shared_ptr plugin, WasmHandleCloneFactory clone_factory) { - auto wasm_handle = clone_factory(base_wasm); - if (!wasm_handle) { - wasm_handle->wasm()->fail(FailState::UnableToCloneVM, "Failed to clone Base Wasm"); - return nullptr; - } - if (!wasm_handle->wasm()->initialize(base_wasm->wasm()->code(), - base_wasm->wasm()->allow_precompiled())) { - wasm_handle->wasm()->fail(FailState::UnableToInitializeCode, "Failed to initialize Wasm code"); - return nullptr; - } - ContextBase *root_context = wasm_handle->wasm()->start(plugin); - if (!root_context) { - base_wasm->wasm()->fail(FailState::StartFailed, "Failed to start thread-local Wasm"); - return nullptr; - } - if (!wasm_handle->wasm()->configure(root_context, plugin)) { - base_wasm->wasm()->fail(FailState::ConfigureFailed, - "Failed to configure thread-local Wasm plugin"); - return nullptr; - } - local_wasms[std::string(wasm_handle->wasm()->vm_key())] = wasm_handle; - return wasm_handle; -} - std::shared_ptr getThreadLocalWasm(std::string_view vm_key) { auto it = local_wasms.find(std::string(vm_key)); if (it == local_wasms.end()) { @@ -514,24 +499,72 @@ std::shared_ptr getThreadLocalWasm(std::string_view vm_key) { return wasm; } -std::shared_ptr -getOrCreateThreadLocalWasm(std::shared_ptr base_wasm, - std::shared_ptr plugin, +static std::shared_ptr +getOrCreateThreadLocalWasm(std::shared_ptr base_handle, WasmHandleCloneFactory clone_factory) { - auto wasm_handle = getThreadLocalWasm(base_wasm->wasm()->vm_key()); - if (wasm_handle) { - auto root_context = wasm_handle->wasm()->getOrCreateRootContext(plugin); - if (!wasm_handle->wasm()->configure(root_context, plugin)) { - base_wasm->wasm()->fail(FailState::ConfigureFailed, - "Failed to configure thread-local Wasm code"); - return nullptr; + std::string vm_key(base_handle->wasm()->vm_key()); + // Get existing thread-local WasmVM. + auto it = local_wasms.find(vm_key); + if (it != local_wasms.end()) { + auto wasm_handle = it->second.lock(); + if (wasm_handle) { + return wasm_handle; } - return wasm_handle; + // Remove stale entry. + local_wasms.erase(vm_key); + } + // Create and initialize new thread-local WasmVM. + auto wasm_handle = clone_factory(base_handle); + if (!wasm_handle) { + base_handle->wasm()->fail(FailState::UnableToCloneVM, "Failed to clone Base Wasm"); + return nullptr; + } + if (!wasm_handle->wasm()->initialize(base_handle->wasm()->code(), + base_handle->wasm()->allow_precompiled())) { + base_handle->wasm()->fail(FailState::UnableToInitializeCode, "Failed to initialize Wasm code"); + return nullptr; + } + local_wasms[vm_key] = wasm_handle; + return wasm_handle; +} + +std::shared_ptr getOrCreateThreadLocalPlugin( + std::shared_ptr base_handle, std::shared_ptr plugin, + WasmHandleCloneFactory clone_factory, PluginHandleFactory plugin_factory) { + std::string key(std::string(base_handle->wasm()->vm_key()) + "||" + plugin->key()); + // Get existing thread-local Plugin handle. + auto it = local_plugins.find(key); + if (it != local_plugins.end()) { + auto plugin_handle = it->second.lock(); + if (plugin_handle) { + return plugin_handle; + } + // Remove stale entry. + local_plugins.erase(key); + } + // Get thread-local WasmVM. + auto wasm_handle = getOrCreateThreadLocalWasm(base_handle, clone_factory); + if (!wasm_handle) { + return nullptr; + } + // Create and initialize new thread-local Plugin. + auto plugin_context = wasm_handle->wasm()->start(plugin); + if (!plugin_context) { + base_handle->wasm()->fail(FailState::StartFailed, "Failed to start thread-local Wasm"); + return nullptr; + } + if (!wasm_handle->wasm()->configure(plugin_context, plugin)) { + base_handle->wasm()->fail(FailState::ConfigureFailed, + "Failed to configure thread-local Wasm plugin"); + return nullptr; } - return createThreadLocalWasm(base_wasm, plugin, clone_factory); + auto plugin_handle = plugin_factory(wasm_handle, plugin->key()); + local_plugins[key] = plugin_handle; + return plugin_handle; } void clearWasmCachesForTesting() { + local_plugins.clear(); local_wasms.clear(); std::lock_guard guard(base_wasms_mutex); if (base_wasms) {