Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ void initConfigBindings(nb::module_& m)
c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(),
c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(),
c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(),
- c.getPromptTableOffloading(), c.getEnableTrtOverlap());
+ c.getPromptTableOffloading(), c.getEnableTrtOverlap(), c.getFailFastOnAttentionWindowTooLarge());
auto pickle_tuple = nb::make_tuple(cpp_states, nb::getattr(self, "__dict__"));
return pickle_tuple;
};
Expand All @@ -490,7 +490,7 @@ void initConfigBindings(nb::module_& m)
}

auto cpp_states = nb::cast<nb::tuple>(state[0]);
- if (cpp_states.size() != 28)
+ if (cpp_states.size() != 29)
{
throw std::runtime_error("Invalid cpp_states!");
}
Expand Down Expand Up @@ -525,7 +525,8 @@ void initConfigBindings(nb::module_& m)
nb::cast<std::optional<tle::CacheTransceiverConfig>>(cpp_states[24]), // CacheTransceiverConfig
nb::cast<bool>(cpp_states[25]), // GatherGenerationLogits
nb::cast<bool>(cpp_states[26]), // PromptTableOffloading
- nb::cast<bool>(cpp_states[27]) // EnableTrtOverlap
+ nb::cast<bool>(cpp_states[27]), // EnableTrtOverlap
+ nb::cast<bool>(cpp_states[28]) // FailFastOnAttentionWindowTooLarge
);

// Restore Python data
Expand Down Expand Up @@ -564,7 +565,8 @@ void initConfigBindings(nb::module_& m)
std::optional<tle::CacheTransceiverConfig>, // CacheTransceiverConfig
bool, // GatherGenerationLogits
bool, // PromptTableOffloading
- bool // EnableTrtOverlap
+ bool, // EnableTrtOverlap
+ bool // FailFastOnAttentionWindowTooLarge
>(),
nb::arg("max_beam_width") = 1, nb::arg("scheduler_config") = tle::SchedulerConfig(),
nb::arg("kv_cache_config") = tle::KvCacheConfig(), nb::arg("enable_chunked_context") = false,
Expand All @@ -582,7 +584,7 @@ void initConfigBindings(nb::module_& m)
nb::arg("spec_dec_config") = nb::none(), nb::arg("guided_decoding_config") = nb::none(),
nb::arg("additional_model_outputs") = nb::none(), nb::arg("cache_transceiver_config") = nb::none(),
nb::arg("gather_generation_logits") = false, nb::arg("mm_embedding_offloading") = false,
- nb::arg("enable_trt_overlap") = false)
+ nb::arg("enable_trt_overlap") = false, nb::arg("fail_fast_on_attention_window_too_large") = false)
.def_prop_rw("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth)
.def_prop_rw("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize)
.def_prop_rw("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens)
Expand Down Expand Up @@ -632,6 +634,9 @@ void initConfigBindings(nb::module_& m)
&tle::ExecutorConfig::setPromptTableOffloading)
.def_prop_rw(
"enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap)
+ .def_prop_rw("fail_fast_on_attention_window_too_large",
+ &tle::ExecutorConfig::getFailFastOnAttentionWindowTooLarge,
+ &tle::ExecutorConfig::setFailFastOnAttentionWindowTooLarge)
.def("__getstate__", executorConfigGetState)
.def("__setstate__", executorConfigSetState);
}
Expand Down