
Commit 877a04c

Use internal endpoint for DMR model configuration

Also makes endpoint selection a bit smarter, so cagent can try multiple well-known URLs to best support the environment it's deployed in (local binary, container on Linux, container on Docker Desktop on macOS/Windows).

Signed-off-by: krissetto <chrisjpetito@gmail.com>

1 parent d05fd9a · commit 877a04c

File tree

6 files changed: +576 −95 lines changed

docs/PROVIDERS.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -134,6 +134,6 @@ The DMR provider supports speculative decoding for faster inference. Configure i
 - `speculative_num_tokens` (int): Number of tokens to generate speculatively
 - `speculative_acceptance_rate` (float): Acceptance rate threshold for speculative tokens
 
-All three options are passed to `docker model configure` as command-line flags.
+All three options are sent to Model Runner via its internal `POST /engines/_configure` API endpoint.
 
 You can also pass any flag of the underlying model runtime (llama.cpp or vllm) using the `runtime_flags` option
```

docs/USAGE.md

Lines changed: 5 additions & 5 deletions

````diff
@@ -722,12 +722,12 @@ models:
     speculative_acceptance_rate: 0.8 # Acceptance rate threshold
 ```
 
-All three speculative decoding options are passed to `docker model configure` as flags:
-- `speculative_draft_model` → `--speculative-draft-model`
-- `speculative_num_tokens` → `--speculative-num-tokens`
-- `speculative_acceptance_rate` → `--speculative-acceptance-rate`
+All three speculative decoding options are sent to Model Runner via its internal `POST /engines/_configure` API endpoint:
+- `speculative_draft_model` → `speculative.draft_model`
+- `speculative_num_tokens` → `speculative.num_tokens`
+- `speculative_acceptance_rate` → `speculative.min_acceptance_rate`
 
-These options work alongside `max_tokens` (which sets `--context-size`) and `runtime_flags`.
+These options work alongside `max_tokens` (which sets `context-size`) and `runtime_flags`.
 
 ##### Troubleshooting:
 
````
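For readers who want to exercise the endpoint directly, here is a minimal sketch of the resulting call. The path `/engines/_configure`, the JSON field names, the default port 12434, and the 202 Accepted success status all come from this commit (see `pkg/model/provider/dmr/client.go` below); the base URL and the model/option values are illustrative assumptions.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Field names match the configureRequest / speculativeDecodingRequest
	// structs added in pkg/model/provider/dmr/client.go.
	body := map[string]any{
		"model":        "ai/example-model", // illustrative model reference
		"context-size": 8192,               // what max_tokens maps to
		"speculative": map[string]any{
			"draft_model":         "ai/example-draft-model", // illustrative
			"num_tokens":          5,
			"min_acceptance_rate": 0.8,
		},
	}
	payload, _ := json.Marshal(body)

	// 12434 is dmrDefaultPort in this commit; adjust for your setup.
	resp, err := http.Post("http://127.0.0.1:12434/engines/_configure",
		"application/json", bytes.NewReader(payload))
	if err != nil {
		fmt.Println("configure request failed:", err)
		return
	}
	defer resp.Body.Close()

	// Model Runner replies 202 Accepted when the configuration is applied.
	fmt.Println("status:", resp.StatusCode)
}
```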

pkg/model/provider/clone.go

Lines changed: 4 additions & 2 deletions

```diff
@@ -19,12 +19,14 @@ func CloneWithOptions(ctx context.Context, base Provider, opts ...options.Opt) P
 
 	// Apply max_tokens override if present in options
 	// We need to apply it to the ModelConfig itself since that's what providers use
+	// Only update MaxTokens if an option explicitly sets it (non-zero value)
 	modelConfig := config.ModelConfig
 	for _, opt := range mergedOpts {
 		tempOpts := &options.ModelOptions{}
 		opt(tempOpts)
-		mt := tempOpts.MaxTokens()
-		modelConfig.MaxTokens = &mt
+		if mt := tempOpts.MaxTokens(); mt != 0 {
+			modelConfig.MaxTokens = &mt
+		}
 	}
 
 	clone, err := New(ctx, &modelConfig, config.Env, mergedOpts...)
```
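The guard matters because the merged options are applied in order: before this change, any option that never set `max_tokens` reported a zero value and clobbered an earlier override. Below is a minimal sketch of the fixed behavior, using a simplified stand-in for `options.ModelOptions`:

```go
package main

import "fmt"

// modelOptions is a simplified stand-in for options.ModelOptions,
// for illustration only.
type modelOptions struct{ maxTokens int64 }

func (o *modelOptions) MaxTokens() int64 { return o.maxTokens }

func main() {
	var maxTokens *int64

	merged := []*modelOptions{
		{maxTokens: 4096}, // an option that explicitly overrides max_tokens
		{},                // an option that never set it (reports 0)
	}
	for _, o := range merged {
		// Mirror the fix: only a non-zero value updates the config.
		if mt := o.MaxTokens(); mt != 0 {
			maxTokens = &mt
		}
	}

	fmt.Println(*maxTokens) // 4096; before the fix the zero value would win
}
```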

pkg/model/provider/dmr/client.go

Lines changed: 233 additions & 42 deletions

```diff
@@ -34,11 +34,24 @@ import (
 	"github.com/docker/cagent/pkg/tools"
 )
 
+const (
+	// configureTimeout is the timeout for the model configure HTTP request.
+	// This is kept short to avoid stalling client creation.
+	configureTimeout = 10 * time.Second
+
+	// connectivityTimeout is the timeout for testing DMR endpoint connectivity.
+	// This is kept short to quickly detect unreachable endpoints and try fallbacks.
+	connectivityTimeout = 2 * time.Second
+)
+
 const (
 	// dmrInferencePrefix mirrors github.com/docker/model-runner/pkg/inference.InferencePrefix.
 	dmrInferencePrefix = "/engines"
 	// dmrExperimentalEndpointsPrefix mirrors github.com/docker/model-runner/pkg/inference.ExperimentalEndpointsPrefix.
 	dmrExperimentalEndpointsPrefix = "/exp/vDD4.40"
+
+	// dmrDefaultPort is the default port for Docker Model Runner.
+	dmrDefaultPort = "12434"
 )
 
 // Client represents an DMR client wrapper
@@ -84,15 +97,18 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
 		}
 	}
 
-	baseURL, clientOptions, httpClient := resolveDMRBaseURL(cfg, endpoint)
-
-	clientOptions = append(clientOptions, option.WithBaseURL(baseURL), option.WithAPIKey("")) // DMR doesn't need auth
+	baseURL, clientOptions, httpClient, err := resolveDMRBaseURL(ctx, cfg, endpoint)
+	if err != nil {
+		return nil, err
+	}
 
 	// Ensure we always have a non-nil HTTP client for both OpenAI adapter and direct HTTP calls (rerank).
 	if httpClient == nil {
 		httpClient = &http.Client{}
 	}
 
+	clientOptions = append(clientOptions, option.WithBaseURL(baseURL), option.WithAPIKey("")) // DMR doesn't need auth
+
 	// Build runtime flags from ModelConfig and engine
 	contextSize, providerRuntimeFlags, specOpts := parseDMRProviderOpts(cfg)
 	configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg)
@@ -104,8 +120,8 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
 	// Skip model configuration when generating titles to avoid reconfiguring the model
 	// with different settings (e.g., smaller max_tokens) that would affect the main agent.
 	if !globalOptions.GeneratingTitle() {
-		if err := configureDockerModel(ctx, cfg.Model, contextSize, finalFlags, specOpts); err != nil {
-			slog.Debug("docker model configure skipped or failed", "error", err)
+		if err := configureModel(ctx, httpClient, baseURL, cfg.Model, contextSize, finalFlags, specOpts); err != nil {
+			slog.Debug("model configure via API skipped or failed", "error", err)
 		}
 	}
 
@@ -127,31 +143,130 @@ func inContainer() bool {
 	return err == nil && finfo.Mode().IsRegular()
 }
 
+// testDMRConnectivity performs a quick health check against a DMR endpoint.
+// It returns true if the endpoint is reachable and responds within the timeout.
+func testDMRConnectivity(ctx context.Context, httpClient *http.Client, baseURL string) bool {
+	// Build a simple health check URL - try the models endpoint which should always exist
+	healthURL := strings.TrimSuffix(baseURL, "/") + "/models"
+
+	ctx, cancel := context.WithTimeout(ctx, connectivityTimeout)
+	defer cancel()
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, http.NoBody)
+	if err != nil {
+		slog.Debug("DMR connectivity check: failed to create request", "url", healthURL, "error", err)
+		return false
+	}
+
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		slog.Debug("DMR connectivity check: request failed", "url", healthURL, "error", err)
+		return false
+	}
+	defer resp.Body.Close()
+
+	// Any response (even 4xx/5xx) means the server is reachable
+	slog.Debug("DMR connectivity check: success", "url", healthURL, "status", resp.StatusCode)
+	return true
+}
+
+// getDMRFallbackURLs returns a list of fallback URLs to try for DMR connectivity.
+// The order is chosen to maximize compatibility across platforms:
+// 1. model-runner.docker.internal - Docker Desktop's integrated model-runner
+// 2. host.docker.internal - Docker Desktop's host access (works on macOS/Windows/Linux Desktop)
+// 3. 172.17.0.1 - Default Docker bridge gateway (Linux Docker CE)
+// 4. 127.0.0.1 - Localhost (when running directly on host)
+func getDMRFallbackURLs(containerized bool) []string {
+	// Docker Desktop internal hostnames and fallback IPs for cross-platform support.
+	// These are tried in order when the primary endpoint is unreachable.
+	// The fallback URLs differ based on whether we're running inside a container or on the host.
+	const dmrModelRunnerInternal = "model-runner.docker.internal" // Docker Desktop's model-runner service (container only)
+	const dmrHostDockerInternal = "host.docker.internal"          // Docker Desktop's host access (container only)
+	const dmrDockerBridgeGateway = "172.17.0.1"                   // Default Docker bridge gateway (container on Linux Docker CE only)
+	const dmrLocalhost = "127.0.0.1"                              // Localhost fallback (host only)
+
+	if containerized {
+		// Inside a container: try Docker internal hostnames and bridge gateway
+		return []string{
+			fmt.Sprintf("http://%s%s/v1/", dmrModelRunnerInternal, dmrInferencePrefix),
+			fmt.Sprintf("http://%s:%s%s/v1/", dmrHostDockerInternal, dmrDefaultPort, dmrInferencePrefix),
+			fmt.Sprintf("http://%s:%s%s/v1/", dmrDockerBridgeGateway, dmrDefaultPort, dmrInferencePrefix),
+		}
+	}
+	// On the host: only localhost makes sense as a fallback
+	return []string{
+		fmt.Sprintf("http://%s:%s%s/v1/", dmrLocalhost, dmrDefaultPort, dmrInferencePrefix),
+	}
+}
+
 // resolveDMRBaseURL determines the correct base URL and HTTP options to talk to
 // Docker Model Runner, mirroring the behavior of the `docker model` CLI as
 // closely as possible.
 //
 // High‑level rules:
-// - If the user explicitly configured a BaseURL or MODEL_RUNNER_HOST, use that.
+// - If the user explicitly configured a BaseURL or MODEL_RUNNER_HOST, use that (no fallbacks).
 // - For Desktop endpoints (model-runner.docker.internal) on the host, route
 //   through the Docker Engine experimental endpoints prefix over the Unix socket.
 // - For standalone / offload endpoints like http://172.17.0.1:12435/engines/v1/,
 //   use localhost:<port>/engines/v1/ on the host, and the gateway IP:port inside containers.
 // - Keep a small compatibility workaround for the legacy http://:0/engines/v1/ endpoint.
+// - Test connectivity and try fallback URLs if the primary endpoint is unreachable.
 //
 // It also returns an *http.Client when a custom transport (e.g., Docker Unix socket) is needed.
-func resolveDMRBaseURL(cfg *latest.ModelConfig, endpoint string) (string, []option.RequestOption, *http.Client) {
-	var clientOptions []option.RequestOption
-	var httpClient *http.Client
-
+func resolveDMRBaseURL(ctx context.Context, cfg *latest.ModelConfig, endpoint string) (string, []option.RequestOption, *http.Client, error) {
+	// Explicit configuration - return immediately without fallback testing
 	if cfg != nil && cfg.BaseURL != "" {
-		return cfg.BaseURL, clientOptions, httpClient
+		slog.Debug("DMR using explicitly configured BaseURL", "url", cfg.BaseURL)
+		return cfg.BaseURL, nil, nil, nil
 	}
 	if host := os.Getenv("MODEL_RUNNER_HOST"); host != "" {
 		trimmed := strings.TrimRight(host, "/")
-		return trimmed + dmrInferencePrefix + "/v1/", clientOptions, httpClient
+		baseURL := trimmed + dmrInferencePrefix + "/v1/"
+		slog.Debug("DMR using MODEL_RUNNER_HOST", "url", baseURL)
+		return baseURL, nil, nil, nil
 	}
 
+	// Resolve primary URL based on endpoint
+	baseURL, clientOptions, httpClient := resolvePrimaryDMRURL(endpoint)
+
+	// Test connectivity and try fallbacks if needed
+	testClient := httpClient
+	if testClient == nil {
+		testClient = &http.Client{}
+	}
+
+	containerized := inContainer()
+
+	if !testDMRConnectivity(ctx, testClient, baseURL) {
+		slog.Debug("DMR primary endpoint unreachable, trying fallbacks", "primary_url", baseURL, "in_container", containerized)
+
+		for _, fallbackURL := range getDMRFallbackURLs(containerized) {
+			if fallbackURL == baseURL {
+				continue // Skip if same as primary
+			}
+			slog.Debug("DMR trying fallback endpoint", "url", fallbackURL)
+			if testDMRConnectivity(ctx, &http.Client{}, fallbackURL) {
+				slog.Info("DMR using fallback endpoint", "fallback_url", fallbackURL, "original_url", baseURL)
+				// Reset client options since we're using a different URL (no Unix socket needed for HTTP endpoints)
+				return fallbackURL, nil, nil, nil
+			}
+		}
+		slog.Error("DMR all endpoints unreachable", "primary_url", baseURL, "in_container", containerized)
+		return "", nil, nil, fmt.Errorf("docker Model Runner is not reachable: tried %s and all fallback endpoints; "+
+			"please ensure Docker Desktop is running with Model Runner enabled, or set MODEL_RUNNER_HOST environment variable", baseURL)
+	}
+
+	slog.Debug("DMR primary endpoint reachable", "url", baseURL)
+	return baseURL, clientOptions, httpClient, nil
+}
+
+// resolvePrimaryDMRURL resolves the primary DMR URL based on the endpoint string.
+// This handles the various endpoint formats and platform-specific routing without
+// connectivity testing or fallbacks.
+func resolvePrimaryDMRURL(endpoint string) (string, []option.RequestOption, *http.Client) {
+	var clientOptions []option.RequestOption
+	var httpClient *http.Client
+
 	ep := strings.TrimSpace(endpoint)
 
 	// Legacy bug workaround: old DMR versions <= 0.1.44 could report http://:0/engines/v1/.
```
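Substituting the constants into `getDMRFallbackURLs` above, the concrete candidates tried after a failed primary connectivity check are:

- In a container: `http://model-runner.docker.internal/engines/v1/`, then `http://host.docker.internal:12434/engines/v1/`, then `http://172.17.0.1:12434/engines/v1/`
- On the host: `http://127.0.0.1:12434/engines/v1/`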
```diff
@@ -920,45 +1035,121 @@ func modelExists(ctx context.Context, model string) bool {
 	return true
 }
 
-func configureDockerModel(ctx context.Context, model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) error {
-	args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags, specOpts)
+// configureRequest mirrors the model-runner's scheduling.ConfigureRequest structure.
+// It specifies per-model runtime configuration options sent via POST /engines/_configure.
+type configureRequest struct {
+	Model        string                      `json:"model"`
+	ContextSize  *int32                      `json:"context-size,omitempty"`
+	RuntimeFlags []string                    `json:"runtime-flags,omitempty"`
+	Speculative  *speculativeDecodingRequest `json:"speculative,omitempty"`
+}
 
-	cmd := exec.CommandContext(ctx, "docker", args...)
-	slog.Debug("Running docker model configure", "model", model, "args", args)
-	var stdout, stderr bytes.Buffer
-	cmd.Stdout = &stdout
-	cmd.Stderr = &stderr
-	if err := cmd.Run(); err != nil {
-		return errors.New(strings.TrimSpace(stderr.String()))
+// speculativeDecodingRequest mirrors model-runner's inference.SpeculativeDecodingConfig.
+type speculativeDecodingRequest struct {
+	DraftModel        string  `json:"draft_model,omitempty"`
+	NumTokens         int     `json:"num_tokens,omitempty"`
+	MinAcceptanceRate float64 `json:"min_acceptance_rate,omitempty"`
+}
+
+// configureModel sends model configuration to Model Runner via POST /engines/_configure.
+// This replaces the previous approach of shelling out to `docker model configure`.
+func configureModel(ctx context.Context, httpClient *http.Client, baseURL, model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) error {
+	if httpClient == nil {
+		httpClient = &http.Client{}
+	}
+
+	configureURL := buildConfigureURL(baseURL)
+	reqBody := buildConfigureRequest(model, contextSize, runtimeFlags, specOpts)
+
+	reqData, err := json.Marshal(reqBody)
+	if err != nil {
+		return fmt.Errorf("failed to marshal configure request: %w", err)
+	}
+
+	// Use a timeout context to avoid blocking client creation indefinitely
+	ctx, cancel := context.WithTimeout(ctx, configureTimeout)
+	defer cancel()
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, configureURL, bytes.NewReader(reqData))
+	if err != nil {
+		return fmt.Errorf("failed to create configure request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	slog.Debug("Sending model configure request via API",
+		"model", model,
+		"url", configureURL,
+		"context_size", contextSize,
+		"runtime_flags", runtimeFlags,
+		"speculative_opts", specOpts)
+
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		return fmt.Errorf("configure request failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	// Model Runner returns 202 Accepted on success
+	if resp.StatusCode != http.StatusAccepted {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("configure request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
 	}
-	slog.Debug("docker model configure completed", "model", model)
+
+	slog.Debug("Model configure via API completed", "model", model)
 	return nil
 }
 
-// buildDockerModelConfigureArgs returns the argument vector passed to `docker` for model configuration.
-// It formats context size, speculative decoding options, and runtime flags consistently with the CLI contract.
-func buildDockerModelConfigureArgs(model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) []string {
-	args := []string{"model", "configure"}
+// buildConfigureURL derives the /engines/_configure endpoint URL from the OpenAI base URL.
+// It handles various URL formats:
+// - http://host:port/engines/v1/ → http://host:port/engines/_configure
+// - http://_/exp/vDD4.40/engines/v1 → http://_/exp/vDD4.40/engines/_configure
+// - http://host:port/engines/llama.cpp/v1/ → http://host:port/engines/llama.cpp/_configure
+func buildConfigureURL(baseURL string) string {
+	u, err := url.Parse(baseURL)
+	if err != nil {
+		// Fallback: just strip /v1/ suffix and append /_configure
+		baseURL = strings.TrimSuffix(baseURL, "/")
+		baseURL = strings.TrimSuffix(baseURL, "/v1")
+		return baseURL + "/_configure"
+	}
+
+	path := u.Path
+
+	// Remove trailing slash for consistent handling
+	path = strings.TrimSuffix(path, "/")
+
+	// Remove /v1 suffix to get to the engines path
+	path = strings.TrimSuffix(path, "/v1")
+
+	// Append /_configure
+	path += "/_configure"
+
+	u.Path = path
+	return u.String()
+}
+
+// buildConfigureRequest constructs the JSON request body for POST /engines/_configure.
+func buildConfigureRequest(model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) configureRequest {
+	req := configureRequest{
+		Model:        model,
+		RuntimeFlags: runtimeFlags,
+	}
+
+	// Convert int64 context size to int32 as expected by model-runner
 	if contextSize != nil {
-		args = append(args, "--context-size="+strconv.FormatInt(*contextSize, 10))
+		cs := int32(*contextSize)
+		req.ContextSize = &cs
 	}
+
 	if specOpts != nil {
-		if specOpts.draftModel != "" {
-			args = append(args, "--speculative-draft-model="+specOpts.draftModel)
-		}
-		if specOpts.numTokens > 0 {
-			args = append(args, "--speculative-num-tokens="+strconv.Itoa(specOpts.numTokens))
-		}
-		if specOpts.acceptanceRate > 0 {
-			args = append(args, "--speculative-min-acceptance-rate="+strconv.FormatFloat(specOpts.acceptanceRate, 'f', -1, 64))
+		req.Speculative = &speculativeDecodingRequest{
+			DraftModel:        specOpts.draftModel,
+			NumTokens:         specOpts.numTokens,
+			MinAcceptanceRate: specOpts.acceptanceRate,
 		}
 	}
-	args = append(args, model)
-	if len(runtimeFlags) > 0 {
-		args = append(args, "--")
-		args = append(args, runtimeFlags...)
-	}
-	return args
+
+	return req
 }
 
 func getDockerModelEndpointAndEngine(ctx context.Context) (endpoint, engine string, err error) {
```
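A quick way to sanity-check the URL derivation is a table test over the mappings listed in `buildConfigureURL`'s doc comment. The sketch below is hypothetical (it is not part of this commit), with concrete hosts and the default port substituted for the doc comment's placeholders:

```go
package dmr

import "testing"

// TestBuildConfigureURL is a hypothetical sketch exercising the example
// mappings documented on buildConfigureURL; it is not part of this commit.
func TestBuildConfigureURL(t *testing.T) {
	cases := map[string]string{
		"http://localhost:12434/engines/v1/":           "http://localhost:12434/engines/_configure",
		"http://_/exp/vDD4.40/engines/v1":              "http://_/exp/vDD4.40/engines/_configure",
		"http://localhost:12434/engines/llama.cpp/v1/": "http://localhost:12434/engines/llama.cpp/_configure",
	}
	for in, want := range cases {
		if got := buildConfigureURL(in); got != want {
			t.Errorf("buildConfigureURL(%q) = %q, want %q", in, got, want)
		}
	}
}
```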
```diff
@@ -1025,7 +1216,7 @@ func buildRuntimeFlagsFromModelConfig(engine string, cfg *latest.ModelConfig) []
 		if cfg.PresencePenalty != nil {
 			flags = append(flags, "--presence-penalty", strconv.FormatFloat(*cfg.PresencePenalty, 'f', -1, 64))
 		}
-		// Note: Context size already handled via --context-size during `docker model configure`
+		// Note: Context size already handled via context-size field in the configure API request
 	default:
 		// Unknown engine: no flags
 	}
```
