Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions engine/common/api_server_configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ApiServerConfiguration {
const std::string& proxy_url = "", const std::string& proxy_username = "",
const std::string& proxy_password = "", const std::string& no_proxy = "",
bool verify_peer_ssl = true, bool verify_host_ssl = true,
const std::string& hf_token = "")
const std::string& hf_token = "", std::vector<std::string> api_keys = {})
: cors{cors},
allowed_origins{allowed_origins},
verify_proxy_ssl{verify_proxy_ssl},
Expand All @@ -118,7 +118,8 @@ class ApiServerConfiguration {
no_proxy{no_proxy},
verify_peer_ssl{verify_peer_ssl},
verify_host_ssl{verify_host_ssl},
hf_token{hf_token} {}
hf_token{hf_token},
api_keys{api_keys} {}

// cors
bool cors{true};
Expand All @@ -139,6 +140,9 @@ class ApiServerConfiguration {
// token
std::string hf_token{""};

// authentication
std::vector<std::string> api_keys;

Json::Value ToJson() const {
Json::Value root;
root["cors"] = cors;
Expand All @@ -155,6 +159,10 @@ class ApiServerConfiguration {
root["verify_peer_ssl"] = verify_peer_ssl;
root["verify_host_ssl"] = verify_host_ssl;
root["huggingface_token"] = hf_token;
root["api_keys"] = Json::Value(Json::arrayValue);
for (const auto& api_key : api_keys) {
root["api_keys"].append(api_key);
}

return root;
}
Expand Down Expand Up @@ -256,7 +264,8 @@ class ApiServerConfiguration {
return true;
}},

{"allowed_origins", [this](const Json::Value& value) -> bool {
{"allowed_origins",
[this](const Json::Value& value) -> bool {
if (!value.isArray()) {
return false;
}
Expand All @@ -271,7 +280,26 @@ class ApiServerConfiguration {
this->allowed_origins.push_back(origin.asString());
}
return true;
}}};
}},

{"api_keys",
[this](const Json::Value& value) -> bool {
if (!value.isArray()) {
return false;
}
for (const auto& key : value) {
if (!key.isString()) {
return false;
}
}

this->api_keys.clear();
for (const auto& key : value) {
this->api_keys.push_back(key.asString());
}
return true;
}},
};

for (const auto& key : json.getMemberNames()) {
auto updater = field_updater.find(key);
Expand Down
49 changes: 49 additions & 0 deletions engine/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,55 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
.setClientMaxBodySize(256 * 1024 * 1024) // Max 256MiB body size
.setClientMaxMemoryBodySize(1024 * 1024); // 1MiB before writing to disk

auto validate_api_key = [config_service](const drogon::HttpRequestPtr& req) {
auto const& api_keys =
config_service->GetApiServerConfiguration()->api_keys;
static const std::unordered_set<std::string> public_endpoints = {
"/healthz", "/processManager/destroy", "/v1/configs"};

// If API key is not set, skip validation
if (api_keys.empty()) {
return true;
}

// If path is public or is static file, skip validation
if (public_endpoints.find(req->path()) != public_endpoints.end() ||
req->path() == "/") {
return true;
}

// Check for API key in the header
auto auth_header = req->getHeader("Authorization");

std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
if (std::find(api_keys.begin(), api_keys.end(), received_api_key) !=
api_keys.end()) {
return true; // API key is valid
}
}

CTL_WRN("Unauthorized: Invalid API Key\n");
return false;
};

drogon::app().registerPreRoutingAdvice(
[&validate_api_key](
const drogon::HttpRequestPtr& req,
std::function<void(const drogon::HttpResponsePtr&)>&& cb,
drogon::AdviceChainCallback&& ccb) {
if (!validate_api_key(req)) {
Json::Value ret;
ret["message"] = "Invalid API Key";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k401Unauthorized);
cb(resp);
return;
}
ccb();
});

// CORS
drogon::app().registerPostHandlingAdvice(
[config_service](const drogon::HttpRequestPtr& req,
Expand Down
17 changes: 9 additions & 8 deletions engine/services/config_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ cpp::result<ApiServerConfiguration, std::string>
ConfigService::UpdateApiServerConfiguration(const Json::Value& json) {
auto config = file_manager_utils::GetCortexConfig();
ApiServerConfiguration api_server_config{
config.enableCors, config.allowedOrigins, config.verifyProxySsl,
config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername,
config.proxyPassword, config.noProxy, config.verifyPeerSsl,
config.verifyHostSsl, config.huggingFaceToken};
config.enableCors, config.allowedOrigins, config.verifyProxySsl,
config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername,
config.proxyPassword, config.noProxy, config.verifyPeerSsl,
config.verifyHostSsl, config.huggingFaceToken, config.apiKeys};

std::vector<std::string> updated_fields;
std::vector<std::string> invalid_fields;
Expand All @@ -36,6 +36,7 @@ ConfigService::UpdateApiServerConfiguration(const Json::Value& json) {
config.verifyHostSsl = api_server_config.verify_host_ssl;

config.huggingFaceToken = api_server_config.hf_token;
config.apiKeys = api_server_config.api_keys;

auto result = file_manager_utils::UpdateCortexConfig(config);
return api_server_config;
Expand All @@ -45,8 +46,8 @@ cpp::result<ApiServerConfiguration, std::string>
ConfigService::GetApiServerConfiguration() {
auto config = file_manager_utils::GetCortexConfig();
return ApiServerConfiguration{
config.enableCors, config.allowedOrigins, config.verifyProxySsl,
config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername,
config.proxyPassword, config.noProxy, config.verifyPeerSsl,
config.verifyHostSsl, config.huggingFaceToken};
config.enableCors, config.allowedOrigins, config.verifyProxySsl,
config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername,
config.proxyPassword, config.noProxy, config.verifyPeerSsl,
config.verifyHostSsl, config.huggingFaceToken, config.apiKeys};
}
8 changes: 7 additions & 1 deletion engine/utils/config_yaml_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ cpp::result<void, std::string> CortexConfigMgr::DumpYamlConfig(
node["sslKeyPath"] = config.sslKeyPath;
node["supportedEngines"] = config.supportedEngines;
node["checkedForSyncHubAt"] = config.checkedForSyncHubAt;
node["apiKeys"] = config.apiKeys;

out_file << node;
out_file.close();
Expand Down Expand Up @@ -87,7 +88,7 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path,
!node["verifyProxySsl"] || !node["verifyProxyHostSsl"] ||
!node["supportedEngines"] || !node["sslCertPath"] ||
!node["sslKeyPath"] || !node["noProxy"] ||
!node["checkedForSyncHubAt"]);
!node["checkedForSyncHubAt"] || !node["apiKeys"]);

CortexConfig config = {
.logFolderPath = node["logFolderPath"]
Expand Down Expand Up @@ -182,6 +183,11 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path,
.checkedForSyncHubAt = node["checkedForSyncHubAt"]
? node["checkedForSyncHubAt"].as<uint64_t>()
: default_cfg.checkedForSyncHubAt,
.apiKeys =
node["apiKeys"]
? node["apiKeys"].as<std::vector<std::string>>()
: default_cfg.apiKeys,

};
if (should_update_config) {
l.unlock();
Expand Down
1 change: 1 addition & 0 deletions engine/utils/config_yaml_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ struct CortexConfig {
std::string sslKeyPath;
std::vector<std::string> supportedEngines;
uint64_t checkedForSyncHubAt;
std::vector<std::string> apiKeys;
};

class CortexConfigMgr {
Expand Down
1 change: 1 addition & 0 deletions engine/utils/file_manager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ config_yaml_utils::CortexConfig GetDefaultConfig() {
.sslKeyPath = "",
.supportedEngines = config_yaml_utils::kDefaultSupportedEngines,
.checkedForSyncHubAt = 0u,
.apiKeys = {},
};
}

Expand Down
Loading