Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

fix: add model author for model source #2038

Merged
merged 1 commit into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 63 additions & 30 deletions engine/services/model_source_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
namespace hu = huggingface_utils;

namespace {
constexpr const int kModeSourceCacheSecs = 600;

std::string GenSourceId(const std::string& author_hub,
const std::string& model_name) {
return author_hub + "/" + model_name;
}

std::vector<ModelInfo> ParseJsonString(const std::string& json_str) {
std::vector<ModelInfo> models;

Expand Down Expand Up @@ -79,19 +86,34 @@ cpp::result<bool, std::string> ModelSourceService::AddModelSource(
}

if (auto is_org = r.pathParams.size() == 1; is_org) {
auto& author = r.pathParams[0];
if (author == "cortexso") {
return AddCortexsoOrg(model_source);
} else {
return AddHfOrg(model_source, author);
}
return cpp::fail("Only support repository model source, url: " +
model_source);
// TODO(sang)
// auto& hub_author = r.pathParams[0];
// if (hub_author == "cortexso") {
// return AddCortexsoOrg(model_source);
// } else {
// return AddHfOrg(model_source, hub_author);
// }
} else { // Repo
auto const& author = r.pathParams[0];
auto const& hub_author = r.pathParams[0];
auto const& model_name = r.pathParams[1];
// Return cache value
if (auto key = GenSourceId(hub_author, model_name);
src_cache_.find(key) != src_cache_.end()) {
auto now = std::chrono::system_clock::now();
if (std::chrono::duration_cast<std::chrono::seconds>(now -
src_cache_.at(key))
.count() < kModeSourceCacheSecs) {
CTL_DBG("Return cache value for model source: " << model_source);
return true;
}
}

if (r.pathParams[0] == "cortexso") {
return AddCortexsoRepo(model_source, author, model_name);
return AddCortexsoRepo(model_source, hub_author, model_name);
} else {
return AddHfRepo(model_source, author, model_name);
return AddHfRepo(model_source, hub_author, model_name);
}
}
}
Expand Down Expand Up @@ -190,9 +212,9 @@ cpp::result<ModelSource, std::string> ModelSourceService::GetModelSource(
}

cpp::result<std::vector<std::string>, std::string>
ModelSourceService::GetRepositoryList(std::string_view author,
ModelSourceService::GetRepositoryList(std::string_view hub_author,
std::string_view tag_filter) {
std::string as(author);
std::string as(hub_author);
auto get_repo_list = [this, &as, &tag_filter] {
std::vector<std::string> repo_list;
auto const& mis = cortexso_repos_.at(as);
Expand Down Expand Up @@ -227,9 +249,9 @@ ModelSourceService::GetRepositoryList(std::string_view author,
}

cpp::result<bool, std::string> ModelSourceService::AddHfOrg(
const std::string& model_source, const std::string& author) {
const std::string& model_source, const std::string& hub_author) {
auto res = curl_utils::SimpleGet("https://huggingface.co/api/models?author=" +
author);
hub_author);
if (res.has_value()) {
auto models = ParseJsonString(res.value());
// Add new models
Expand All @@ -238,9 +260,10 @@ cpp::result<bool, std::string> ModelSourceService::AddHfOrg(

auto author_model = string_utils::SplitBy(m.id, "/");
if (author_model.size() == 2) {
auto const& author = author_model[0];
auto const& hub_author = author_model[0];
auto const& model_name = author_model[1];
auto r = AddHfRepo(model_source + "/" + model_name, author, model_name);
auto r =
AddHfRepo(model_source + "/" + model_name, hub_author, model_name);
if (r.has_error()) {
CTL_WRN(r.error());
}
Expand All @@ -253,14 +276,14 @@ cpp::result<bool, std::string> ModelSourceService::AddHfOrg(
}

cpp::result<bool, std::string> ModelSourceService::AddHfRepo(
const std::string& model_source, const std::string& author,
const std::string& model_source, const std::string& hub_author,
const std::string& model_name) {
// Get models from db

auto model_list_before = db_service_->GetModels(model_source)
.value_or(std::vector<cortex::db::ModelEntry>{});
std::unordered_set<std::string> updated_model_list;
auto add_res = AddRepoSiblings(model_source, author, model_name);
auto add_res = AddRepoSiblings(model_source, hub_author, model_name);
if (add_res.has_error()) {
return cpp::fail(add_res.error());
} else {
Expand All @@ -274,15 +297,17 @@ cpp::result<bool, std::string> ModelSourceService::AddHfRepo(
}
}
}
src_cache_[GenSourceId(hub_author, model_name)] =
std::chrono::system_clock::now();
return true;
}

cpp::result<std::unordered_set<std::string>, std::string>
ModelSourceService::AddRepoSiblings(const std::string& model_source,
const std::string& author,
const std::string& hub_author,
const std::string& model_name) {
std::unordered_set<std::string> res;
auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name);
auto repo_info = hu::GetHuggingFaceModelRepoInfo(hub_author, model_name);
if (repo_info.has_error()) {
return cpp::fail(repo_info.error());
}
Expand All @@ -293,14 +318,14 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source,
"supported.");
}

auto siblings_fs = hu::GetSiblingsFileSize(author, model_name);
auto siblings_fs = hu::GetSiblingsFileSize(hub_author, model_name);

if (siblings_fs.has_error()) {
return cpp::fail("Could not get siblings file size: " + author + "/" +
model_name);
return cpp::fail("Could not get siblings file size: " +
GenSourceId(hub_author, model_name));
}

auto readme = hu::GetReadMe(author, model_name);
auto readme = hu::GetReadMe(hub_author, model_name);
std::string desc;
if (!readme.has_error()) {
desc = readme.value();
Expand All @@ -326,10 +351,10 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source,
siblings_fs_v.file_sizes.at(sibling.rfilename).size_in_bytes;
}
std::string model_id =
author + ":" + model_name + ":" + sibling.rfilename;
hub_author + ":" + model_name + ":" + sibling.rfilename;
cortex::db::ModelEntry e = {
.model = model_id,
.author_repo_id = author,
.author_repo_id = hub_author,
.branch_name = "main",
.path_to_model_yaml = "",
.model_alias = "",
Expand Down Expand Up @@ -369,9 +394,9 @@ cpp::result<bool, std::string> ModelSourceService::AddCortexsoOrg(
CTL_INF(m.id);
auto author_model = string_utils::SplitBy(m.id, "/");
if (author_model.size() == 2) {
auto const& author = author_model[0];
auto const& hub_author = author_model[0];
auto const& model_name = author_model[1];
auto r = AddCortexsoRepo(model_source + "/" + model_name, author,
auto r = AddCortexsoRepo(model_source + "/" + model_name, hub_author,
model_name);
if (r.has_error()) {
CTL_WRN(r.error());
Expand All @@ -386,7 +411,7 @@ cpp::result<bool, std::string> ModelSourceService::AddCortexsoOrg(
}

cpp::result<bool, std::string> ModelSourceService::AddCortexsoRepo(
const std::string& model_source, const std::string& author,
const std::string& model_source, const std::string& hub_author,
const std::string& model_name) {
auto begin = std::chrono::system_clock::now();
auto branches =
Expand All @@ -395,17 +420,23 @@ cpp::result<bool, std::string> ModelSourceService::AddCortexsoRepo(
return cpp::fail(branches.error());
}

auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name);
auto repo_info = hu::GetHuggingFaceModelRepoInfo(hub_author, model_name);
if (repo_info.has_error()) {
return cpp::fail(repo_info.error());
}

auto readme = hu::GetReadMe(author, model_name);
auto readme = hu::GetReadMe(hub_author, model_name);
std::string desc;
if (!readme.has_error()) {
desc = readme.value();
}

auto author = hub_author;
if (auto model_author = hu::GetModelAuthorCortexsoHub(model_name);
model_author.has_value() && !model_author->empty()) {
author = *model_author;
}

// Get models from db
auto model_list_before = db_service_->GetModels(model_source)
.value_or(std::vector<cortex::db::ModelEntry>{});
Expand Down Expand Up @@ -442,6 +473,8 @@ cpp::result<bool, std::string> ModelSourceService::AddCortexsoRepo(
"Duration ms: " << std::chrono::duration_cast<std::chrono::milliseconds>(
end - begin)
.count());
src_cache_[GenSourceId(hub_author, model_name)] =
std::chrono::system_clock::now();
return true;
}

Expand Down
12 changes: 7 additions & 5 deletions engine/services/model_source_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,25 @@ class ModelSourceService {
cpp::result<ModelSource, std::string> GetModelSource(const std::string& src);

cpp::result<std::vector<std::string>, std::string> GetRepositoryList(
std::string_view author, std::string_view tag_filter);
std::string_view hub_author, std::string_view tag_filter);

private:
cpp::result<bool, std::string> AddHfOrg(const std::string& model_source,
const std::string& author);
const std::string& hub_author);

cpp::result<bool, std::string> AddHfRepo(const std::string& model_source,
const std::string& author,
const std::string& hub_author,
const std::string& model_name);

cpp::result<std::unordered_set<std::string>, std::string> AddRepoSiblings(
const std::string& model_source, const std::string& author,
const std::string& model_source, const std::string& hub_author,
const std::string& model_name);

cpp::result<bool, std::string> AddCortexsoOrg(
const std::string& model_source);

cpp::result<bool, std::string> AddCortexsoRepo(
const std::string& model_source, const std::string& author,
const std::string& model_source, const std::string& hub_author,
const std::string& model_name);

cpp::result<std::string, std::string> AddCortexsoRepoBranch(
Expand All @@ -99,4 +99,6 @@ class ModelSourceService {
std::atomic<bool> running_;

std::unordered_map<std::string, std::vector<ModelInfo>> cortexso_repos_;
using TimePoint = std::chrono::time_point<std::chrono::system_clock>;
std::unordered_map<std::string, TimePoint> src_cache_;
};
20 changes: 20 additions & 0 deletions engine/utils/huggingface_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,4 +311,24 @@ inline std::optional<std::string> GetDefaultBranch(
return std::nullopt;
}
}

inline std::optional<std::string> GetModelAuthorCortexsoHub(
const std::string& model_name) {
try {
auto remote_yml = curl_utils::ReadRemoteYaml(GetMetadataUrl(model_name));

if (remote_yml.has_error()) {
return std::nullopt;
}

auto metadata = remote_yml.value();
auto author = metadata["author"];
if (author.IsDefined()) {
return author.as<std::string>();
}
return std::nullopt;
} catch (const std::exception& e) {
return std::nullopt;
}
}
} // namespace huggingface_utils
10 changes: 8 additions & 2 deletions engine/utils/url_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,11 @@ const std::regex url_regex(
R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)",
std::regex::extended);

inline void SplitPathParams(const std::string& input,
inline bool SplitPathParams(const std::string& input,
std::vector<std::string>& pathList) {
if (input.find("//") != std::string::npos) {
return false;
}
// split the path by '/'
std::string token;
std::istringstream tokenStream(input);
Expand All @@ -80,6 +83,7 @@ inline void SplitPathParams(const std::string& input,
}
pathList.push_back(token);
}
return true;
}

inline cpp::result<Url, std::string> FromUrlString(
Expand All @@ -105,7 +109,9 @@ inline cpp::result<Url, std::string> FromUrlString(
} else if (counter == hostAndPortIndex) {
url.host = res; // TODO: split the port for completeness
} else if (counter == pathIndex) {
SplitPathParams(res, url.pathParams);
if (!SplitPathParams(res, url.pathParams)) {
return cpp::fail("Malformed URL: " + urlString);
}
} else if (counter == queryIndex) {
// TODO: implement
}
Expand Down
Loading