diff --git a/engine/services/model_source_service.cc b/engine/services/model_source_service.cc index 59275e8db..3314fd53e 100644 --- a/engine/services/model_source_service.cc +++ b/engine/services/model_source_service.cc @@ -14,6 +14,13 @@ namespace hu = huggingface_utils; namespace { +constexpr const int kModeSourceCacheSecs = 600; + +std::string GenSourceId(const std::string& author_hub, + const std::string& model_name) { + return author_hub + "/" + model_name; +} + std::vector ParseJsonString(const std::string& json_str) { std::vector models; @@ -79,19 +86,34 @@ cpp::result ModelSourceService::AddModelSource( } if (auto is_org = r.pathParams.size() == 1; is_org) { - auto& author = r.pathParams[0]; - if (author == "cortexso") { - return AddCortexsoOrg(model_source); - } else { - return AddHfOrg(model_source, author); - } + return cpp::fail("Only support repository model source, url: " + + model_source); + // TODO(sang) + // auto& hub_author = r.pathParams[0]; + // if (hub_author == "cortexso") { + // return AddCortexsoOrg(model_source); + // } else { + // return AddHfOrg(model_source, hub_author); + // } } else { // Repo - auto const& author = r.pathParams[0]; + auto const& hub_author = r.pathParams[0]; auto const& model_name = r.pathParams[1]; + // Return cache value + if (auto key = GenSourceId(hub_author, model_name); + src_cache_.find(key) != src_cache_.end()) { + auto now = std::chrono::system_clock::now(); + if (std::chrono::duration_cast(now - + src_cache_.at(key)) + .count() < kModeSourceCacheSecs) { + CTL_DBG("Return cache value for model source: " << model_source); + return true; + } + } + if (r.pathParams[0] == "cortexso") { - return AddCortexsoRepo(model_source, author, model_name); + return AddCortexsoRepo(model_source, hub_author, model_name); } else { - return AddHfRepo(model_source, author, model_name); + return AddHfRepo(model_source, hub_author, model_name); } } } @@ -190,9 +212,9 @@ cpp::result ModelSourceService::GetModelSource( } cpp::result, std::string> -ModelSourceService::GetRepositoryList(std::string_view author, +ModelSourceService::GetRepositoryList(std::string_view hub_author, std::string_view tag_filter) { - std::string as(author); + std::string as(hub_author); auto get_repo_list = [this, &as, &tag_filter] { std::vector repo_list; auto const& mis = cortexso_repos_.at(as); @@ -227,9 +249,9 @@ ModelSourceService::GetRepositoryList(std::string_view author, } cpp::result ModelSourceService::AddHfOrg( - const std::string& model_source, const std::string& author) { + const std::string& model_source, const std::string& hub_author) { auto res = curl_utils::SimpleGet("https://huggingface.co/api/models?author=" + - author); + hub_author); if (res.has_value()) { auto models = ParseJsonString(res.value()); // Add new models @@ -238,9 +260,10 @@ cpp::result ModelSourceService::AddHfOrg( auto author_model = string_utils::SplitBy(m.id, "/"); if (author_model.size() == 2) { - auto const& author = author_model[0]; + auto const& hub_author = author_model[0]; auto const& model_name = author_model[1]; - auto r = AddHfRepo(model_source + "/" + model_name, author, model_name); + auto r = + AddHfRepo(model_source + "/" + model_name, hub_author, model_name); if (r.has_error()) { CTL_WRN(r.error()); } @@ -253,14 +276,14 @@ cpp::result ModelSourceService::AddHfOrg( } cpp::result ModelSourceService::AddHfRepo( - const std::string& model_source, const std::string& author, + const std::string& model_source, const std::string& hub_author, const std::string& model_name) { // Get models from db auto model_list_before = db_service_->GetModels(model_source) .value_or(std::vector{}); std::unordered_set updated_model_list; - auto add_res = AddRepoSiblings(model_source, author, model_name); + auto add_res = AddRepoSiblings(model_source, hub_author, model_name); if (add_res.has_error()) { return cpp::fail(add_res.error()); } else { @@ -274,15 +297,17 @@ cpp::result ModelSourceService::AddHfRepo( } } } + src_cache_[GenSourceId(hub_author, model_name)] = + std::chrono::system_clock::now(); return true; } cpp::result, std::string> ModelSourceService::AddRepoSiblings(const std::string& model_source, - const std::string& author, + const std::string& hub_author, const std::string& model_name) { std::unordered_set res; - auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name); + auto repo_info = hu::GetHuggingFaceModelRepoInfo(hub_author, model_name); if (repo_info.has_error()) { return cpp::fail(repo_info.error()); } @@ -293,14 +318,14 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source, "supported."); } - auto siblings_fs = hu::GetSiblingsFileSize(author, model_name); + auto siblings_fs = hu::GetSiblingsFileSize(hub_author, model_name); if (siblings_fs.has_error()) { - return cpp::fail("Could not get siblings file size: " + author + "/" + - model_name); + return cpp::fail("Could not get siblings file size: " + + GenSourceId(hub_author, model_name)); } - auto readme = hu::GetReadMe(author, model_name); + auto readme = hu::GetReadMe(hub_author, model_name); std::string desc; if (!readme.has_error()) { desc = readme.value(); @@ -326,10 +351,10 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source, siblings_fs_v.file_sizes.at(sibling.rfilename).size_in_bytes; } std::string model_id = - author + ":" + model_name + ":" + sibling.rfilename; + hub_author + ":" + model_name + ":" + sibling.rfilename; cortex::db::ModelEntry e = { .model = model_id, - .author_repo_id = author, + .author_repo_id = hub_author, .branch_name = "main", .path_to_model_yaml = "", .model_alias = "", @@ -369,9 +394,9 @@ cpp::result ModelSourceService::AddCortexsoOrg( CTL_INF(m.id); auto author_model = string_utils::SplitBy(m.id, "/"); if (author_model.size() == 2) { - auto const& author = author_model[0]; + auto const& hub_author = author_model[0]; auto const& model_name = author_model[1]; - auto r = AddCortexsoRepo(model_source + "/" + model_name, author, + auto r = AddCortexsoRepo(model_source + "/" + model_name, hub_author, model_name); if (r.has_error()) { CTL_WRN(r.error()); @@ -386,7 +411,7 @@ cpp::result ModelSourceService::AddCortexsoOrg( } cpp::result ModelSourceService::AddCortexsoRepo( - const std::string& model_source, const std::string& author, + const std::string& model_source, const std::string& hub_author, const std::string& model_name) { auto begin = std::chrono::system_clock::now(); auto branches = @@ -395,17 +420,23 @@ cpp::result ModelSourceService::AddCortexsoRepo( return cpp::fail(branches.error()); } - auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name); + auto repo_info = hu::GetHuggingFaceModelRepoInfo(hub_author, model_name); if (repo_info.has_error()) { return cpp::fail(repo_info.error()); } - auto readme = hu::GetReadMe(author, model_name); + auto readme = hu::GetReadMe(hub_author, model_name); std::string desc; if (!readme.has_error()) { desc = readme.value(); } + auto author = hub_author; + if (auto model_author = hu::GetModelAuthorCortexsoHub(model_name); + model_author.has_value() && !model_author->empty()) { + author = *model_author; + } + // Get models from db auto model_list_before = db_service_->GetModels(model_source) .value_or(std::vector{}); @@ -442,6 +473,8 @@ cpp::result ModelSourceService::AddCortexsoRepo( "Duration ms: " << std::chrono::duration_cast( end - begin) .count()); + src_cache_[GenSourceId(hub_author, model_name)] = + std::chrono::system_clock::now(); return true; } diff --git a/engine/services/model_source_service.h b/engine/services/model_source_service.h index cffe93bb9..54acae380 100644 --- a/engine/services/model_source_service.h +++ b/engine/services/model_source_service.h @@ -65,25 +65,25 @@ class ModelSourceService { cpp::result GetModelSource(const std::string& src); cpp::result, std::string> GetRepositoryList( - std::string_view author, std::string_view tag_filter); + std::string_view hub_author, std::string_view tag_filter); private: cpp::result AddHfOrg(const std::string& model_source, - const std::string& author); + const std::string& hub_author); cpp::result AddHfRepo(const std::string& model_source, - const std::string& author, + const std::string& hub_author, const std::string& model_name); cpp::result, std::string> AddRepoSiblings( - const std::string& model_source, const std::string& author, + const std::string& model_source, const std::string& hub_author, const std::string& model_name); cpp::result AddCortexsoOrg( const std::string& model_source); cpp::result AddCortexsoRepo( - const std::string& model_source, const std::string& author, + const std::string& model_source, const std::string& hub_author, const std::string& model_name); cpp::result AddCortexsoRepoBranch( @@ -99,4 +99,6 @@ class ModelSourceService { std::atomic running_; std::unordered_map> cortexso_repos_; + using TimePoint = std::chrono::time_point; + std::unordered_map src_cache_; }; \ No newline at end of file diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h index e5c74a6e1..fde5d11b2 100644 --- a/engine/utils/huggingface_utils.h +++ b/engine/utils/huggingface_utils.h @@ -311,4 +311,24 @@ inline std::optional GetDefaultBranch( return std::nullopt; } } + +inline std::optional GetModelAuthorCortexsoHub( + const std::string& model_name) { + try { + auto remote_yml = curl_utils::ReadRemoteYaml(GetMetadataUrl(model_name)); + + if (remote_yml.has_error()) { + return std::nullopt; + } + + auto metadata = remote_yml.value(); + auto author = metadata["author"]; + if (author.IsDefined()) { + return author.as(); + } + return std::nullopt; + } catch (const std::exception& e) { + return std::nullopt; + } +} } // namespace huggingface_utils diff --git a/engine/utils/url_parser.h b/engine/utils/url_parser.h index 244b13719..4802ba1a1 100644 --- a/engine/utils/url_parser.h +++ b/engine/utils/url_parser.h @@ -69,8 +69,11 @@ const std::regex url_regex( R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)", std::regex::extended); -inline void SplitPathParams(const std::string& input, +inline bool SplitPathParams(const std::string& input, std::vector& pathList) { + if (input.find("//") != std::string::npos) { + return false; + } // split the path by '/' std::string token; std::istringstream tokenStream(input); @@ -80,6 +83,7 @@ inline void SplitPathParams(const std::string& input, } pathList.push_back(token); } + return true; } inline cpp::result FromUrlString( @@ -105,7 +109,9 @@ inline cpp::result FromUrlString( } else if (counter == hostAndPortIndex) { url.host = res; // TODO: split the port for completeness } else if (counter == pathIndex) { - SplitPathParams(res, url.pathParams); + if (!SplitPathParams(res, url.pathParams)) { + return cpp::fail("Malformed URL: " + urlString); + } } else if (counter == queryIndex) { // TODO: implement }