@@ -34,11 +34,24 @@ import (
3434 "github.com/docker/cagent/pkg/tools"
3535)
3636
37+ const (
38+ // configureTimeout is the timeout for the model configure HTTP request.
39+ // This is kept short to avoid stalling client creation.
40+ configureTimeout = 10 * time .Second
41+
42+ // connectivityTimeout is the timeout for testing DMR endpoint connectivity.
43+ // This is kept short to quickly detect unreachable endpoints and try fallbacks.
44+ connectivityTimeout = 2 * time .Second
45+ )
46+
3747const (
3848 // dmrInferencePrefix mirrors github.com/docker/model-runner/pkg/inference.InferencePrefix.
3949 dmrInferencePrefix = "/engines"
4050 // dmrExperimentalEndpointsPrefix mirrors github.com/docker/model-runner/pkg/inference.ExperimentalEndpointsPrefix.
4151 dmrExperimentalEndpointsPrefix = "/exp/vDD4.40"
52+
53+ // dmrDefaultPort is the default port for Docker Model Runner.
54+ dmrDefaultPort = "12434"
4255)
4356
4457// Client represents an DMR client wrapper
@@ -84,15 +97,18 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
8497 }
8598 }
8699
87- baseURL , clientOptions , httpClient := resolveDMRBaseURL (cfg , endpoint )
88-
89- clientOptions = append (clientOptions , option .WithBaseURL (baseURL ), option .WithAPIKey ("" )) // DMR doesn't need auth
100+ baseURL , clientOptions , httpClient , err := resolveDMRBaseURL (ctx , cfg , endpoint )
101+ if err != nil {
102+ return nil , err
103+ }
90104
91105 // Ensure we always have a non-nil HTTP client for both OpenAI adapter and direct HTTP calls (rerank).
92106 if httpClient == nil {
93107 httpClient = & http.Client {}
94108 }
95109
110+ clientOptions = append (clientOptions , option .WithBaseURL (baseURL ), option .WithAPIKey ("" )) // DMR doesn't need auth
111+
96112 // Build runtime flags from ModelConfig and engine
97113 contextSize , providerRuntimeFlags , specOpts := parseDMRProviderOpts (cfg )
98114 configFlags := buildRuntimeFlagsFromModelConfig (engine , cfg )
@@ -104,8 +120,8 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
104120 // Skip model configuration when generating titles to avoid reconfiguring the model
105121 // with different settings (e.g., smaller max_tokens) that would affect the main agent.
106122 if ! globalOptions .GeneratingTitle () {
107- if err := configureDockerModel (ctx , cfg .Model , contextSize , finalFlags , specOpts ); err != nil {
108- slog .Debug ("docker model configure skipped or failed" , "error" , err )
123+ if err := configureModel (ctx , httpClient , baseURL , cfg .Model , contextSize , finalFlags , specOpts ); err != nil {
124+ slog .Debug ("model configure via API skipped or failed" , "error" , err )
109125 }
110126 }
111127
@@ -127,31 +143,130 @@ func inContainer() bool {
127143 return err == nil && finfo .Mode ().IsRegular ()
128144}
129145
146+ // testDMRConnectivity performs a quick health check against a DMR endpoint.
147+ // It returns true if the endpoint is reachable and responds within the timeout.
148+ func testDMRConnectivity (ctx context.Context , httpClient * http.Client , baseURL string ) bool {
149+ // Build a simple health check URL - try the models endpoint which should always exist
150+ healthURL := strings .TrimSuffix (baseURL , "/" ) + "/models"
151+
152+ ctx , cancel := context .WithTimeout (ctx , connectivityTimeout )
153+ defer cancel ()
154+
155+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , healthURL , http .NoBody )
156+ if err != nil {
157+ slog .Debug ("DMR connectivity check: failed to create request" , "url" , healthURL , "error" , err )
158+ return false
159+ }
160+
161+ resp , err := httpClient .Do (req )
162+ if err != nil {
163+ slog .Debug ("DMR connectivity check: request failed" , "url" , healthURL , "error" , err )
164+ return false
165+ }
166+ defer resp .Body .Close ()
167+
168+ // Any response (even 4xx/5xx) means the server is reachable
169+ slog .Debug ("DMR connectivity check: success" , "url" , healthURL , "status" , resp .StatusCode )
170+ return true
171+ }
172+
173+ // getDMRFallbackURLs returns a list of fallback URLs to try for DMR connectivity.
174+ // The order is chosen to maximize compatibility across platforms:
175+ // 1. model-runner.docker.internal - Docker Desktop's integrated model-runner
176+ // 2. host.docker.internal - Docker Desktop's host access (works on macOS/Windows/Linux Desktop)
177+ // 3. 172.17.0.1 - Default Docker bridge gateway (Linux Docker CE)
178+ // 4. 127.0.0.1 - Localhost (when running directly on host)
179+ func getDMRFallbackURLs (containerized bool ) []string {
180+ // Docker Desktop internal hostnames and fallback IPs for cross-platform support.
181+ // These are tried in order when the primary endpoint is unreachable.
182+ // The fallback URLs differ based on whether we're running inside a container or on the host.
183+ const dmrModelRunnerInternal = "model-runner.docker.internal" // Docker Desktop's model-runner service (container only)
184+ const dmrHostDockerInternal = "host.docker.internal" // Docker Desktop's host access (container only)
185+ const dmrDockerBridgeGateway = "172.17.0.1" // Default Docker bridge gateway (container on Linux Docker CE only)
186+ const dmrLocalhost = "127.0.0.1" // Localhost fallback (host only)
187+
188+ if containerized {
189+ // Inside a container: try Docker internal hostnames and bridge gateway
190+ return []string {
191+ fmt .Sprintf ("http://%s%s/v1/" , dmrModelRunnerInternal , dmrInferencePrefix ),
192+ fmt .Sprintf ("http://%s:%s%s/v1/" , dmrHostDockerInternal , dmrDefaultPort , dmrInferencePrefix ),
193+ fmt .Sprintf ("http://%s:%s%s/v1/" , dmrDockerBridgeGateway , dmrDefaultPort , dmrInferencePrefix ),
194+ }
195+ }
196+ // On the host: only localhost makes sense as a fallback
197+ return []string {
198+ fmt .Sprintf ("http://%s:%s%s/v1/" , dmrLocalhost , dmrDefaultPort , dmrInferencePrefix ),
199+ }
200+ }
201+
130202// resolveDMRBaseURL determines the correct base URL and HTTP options to talk to
131203// Docker Model Runner, mirroring the behavior of the `docker model` CLI as
132204// closely as possible.
133205//
134206// High‑level rules:
135- // - If the user explicitly configured a BaseURL or MODEL_RUNNER_HOST, use that.
207+ // - If the user explicitly configured a BaseURL or MODEL_RUNNER_HOST, use that (no fallbacks) .
136208// - For Desktop endpoints (model-runner.docker.internal) on the host, route
137209// through the Docker Engine experimental endpoints prefix over the Unix socket.
138210// - For standalone / offload endpoints like http://172.17.0.1:12435/engines/v1/,
139211// use localhost:<port>/engines/v1/ on the host, and the gateway IP:port inside containers.
140212// - Keep a small compatibility workaround for the legacy http://:0/engines/v1/ endpoint.
213+ // - Test connectivity and try fallback URLs if the primary endpoint is unreachable.
141214//
142215// It also returns an *http.Client when a custom transport (e.g., Docker Unix socket) is needed.
143- func resolveDMRBaseURL (cfg * latest.ModelConfig , endpoint string ) (string , []option.RequestOption , * http.Client ) {
144- var clientOptions []option.RequestOption
145- var httpClient * http.Client
146-
216+ func resolveDMRBaseURL (ctx context.Context , cfg * latest.ModelConfig , endpoint string ) (string , []option.RequestOption , * http.Client , error ) {
217+ // Explicit configuration - return immediately without fallback testing
147218 if cfg != nil && cfg .BaseURL != "" {
148- return cfg .BaseURL , clientOptions , httpClient
219+ slog .Debug ("DMR using explicitly configured BaseURL" , "url" , cfg .BaseURL )
220+ return cfg .BaseURL , nil , nil , nil
149221 }
150222 if host := os .Getenv ("MODEL_RUNNER_HOST" ); host != "" {
151223 trimmed := strings .TrimRight (host , "/" )
152- return trimmed + dmrInferencePrefix + "/v1/" , clientOptions , httpClient
224+ baseURL := trimmed + dmrInferencePrefix + "/v1/"
225+ slog .Debug ("DMR using MODEL_RUNNER_HOST" , "url" , baseURL )
226+ return baseURL , nil , nil , nil
153227 }
154228
229+ // Resolve primary URL based on endpoint
230+ baseURL , clientOptions , httpClient := resolvePrimaryDMRURL (endpoint )
231+
232+ // Test connectivity and try fallbacks if needed
233+ testClient := httpClient
234+ if testClient == nil {
235+ testClient = & http.Client {}
236+ }
237+
238+ containerized := inContainer ()
239+
240+ if ! testDMRConnectivity (ctx , testClient , baseURL ) {
241+ slog .Debug ("DMR primary endpoint unreachable, trying fallbacks" , "primary_url" , baseURL , "in_container" , containerized )
242+
243+ for _ , fallbackURL := range getDMRFallbackURLs (containerized ) {
244+ if fallbackURL == baseURL {
245+ continue // Skip if same as primary
246+ }
247+ slog .Debug ("DMR trying fallback endpoint" , "url" , fallbackURL )
248+ if testDMRConnectivity (ctx , & http.Client {}, fallbackURL ) {
249+ slog .Info ("DMR using fallback endpoint" , "fallback_url" , fallbackURL , "original_url" , baseURL )
250+ // Reset client options since we're using a different URL (no Unix socket needed for HTTP endpoints)
251+ return fallbackURL , nil , nil , nil
252+ }
253+ }
254+ slog .Error ("DMR all endpoints unreachable" , "primary_url" , baseURL , "in_container" , containerized )
255+ return "" , nil , nil , fmt .Errorf ("docker Model Runner is not reachable: tried %s and all fallback endpoints; " +
256+ "please ensure Docker Desktop is running with Model Runner enabled, or set MODEL_RUNNER_HOST environment variable" , baseURL )
257+ }
258+
259+ slog .Debug ("DMR primary endpoint reachable" , "url" , baseURL )
260+ return baseURL , clientOptions , httpClient , nil
261+ }
262+
263+ // resolvePrimaryDMRURL resolves the primary DMR URL based on the endpoint string.
264+ // This handles the various endpoint formats and platform-specific routing without
265+ // connectivity testing or fallbacks.
266+ func resolvePrimaryDMRURL (endpoint string ) (string , []option.RequestOption , * http.Client ) {
267+ var clientOptions []option.RequestOption
268+ var httpClient * http.Client
269+
155270 ep := strings .TrimSpace (endpoint )
156271
157272 // Legacy bug workaround: old DMR versions <= 0.1.44 could report http://:0/engines/v1/.
@@ -920,45 +1035,121 @@ func modelExists(ctx context.Context, model string) bool {
9201035 return true
9211036}
9221037
// configureRequest mirrors the model-runner's scheduling.ConfigureRequest structure.
// It specifies per-model runtime configuration options sent via POST /engines/_configure.
type configureRequest struct {
	// Model is the identifier of the model to configure.
	Model string `json:"model"`
	// ContextSize, when set, is the context size to apply (model-runner expects int32).
	ContextSize *int32 `json:"context-size,omitempty"`
	// RuntimeFlags are engine runtime flags passed through to the backend.
	RuntimeFlags []string `json:"runtime-flags,omitempty"`
	// Speculative, when set, carries speculative-decoding configuration.
	Speculative *speculativeDecodingRequest `json:"speculative,omitempty"`
}

// speculativeDecodingRequest mirrors model-runner's inference.SpeculativeDecodingConfig.
type speculativeDecodingRequest struct {
	// DraftModel is the draft model name used for speculative decoding.
	DraftModel string `json:"draft_model,omitempty"`
	// NumTokens is the speculative token count per step.
	NumTokens int `json:"num_tokens,omitempty"`
	// MinAcceptanceRate is the minimum acceptance rate for speculation.
	MinAcceptanceRate float64 `json:"min_acceptance_rate,omitempty"`
}
1053+
1054+ // configureModel sends model configuration to Model Runner via POST /engines/_configure.
1055+ // This replaces the previous approach of shelling out to `docker model configure`.
1056+ func configureModel (ctx context.Context , httpClient * http.Client , baseURL , model string , contextSize * int64 , runtimeFlags []string , specOpts * speculativeDecodingOpts ) error {
1057+ if httpClient == nil {
1058+ httpClient = & http.Client {}
1059+ }
1060+
1061+ configureURL := buildConfigureURL (baseURL )
1062+ reqBody := buildConfigureRequest (model , contextSize , runtimeFlags , specOpts )
1063+
1064+ reqData , err := json .Marshal (reqBody )
1065+ if err != nil {
1066+ return fmt .Errorf ("failed to marshal configure request: %w" , err )
1067+ }
1068+
1069+ // Use a timeout context to avoid blocking client creation indefinitely
1070+ ctx , cancel := context .WithTimeout (ctx , configureTimeout )
1071+ defer cancel ()
1072+
1073+ req , err := http .NewRequestWithContext (ctx , http .MethodPost , configureURL , bytes .NewReader (reqData ))
1074+ if err != nil {
1075+ return fmt .Errorf ("failed to create configure request: %w" , err )
1076+ }
1077+ req .Header .Set ("Content-Type" , "application/json" )
1078+
1079+ slog .Debug ("Sending model configure request via API" ,
1080+ "model" , model ,
1081+ "url" , configureURL ,
1082+ "context_size" , contextSize ,
1083+ "runtime_flags" , runtimeFlags ,
1084+ "speculative_opts" , specOpts )
1085+
1086+ resp , err := httpClient .Do (req )
1087+ if err != nil {
1088+ return fmt .Errorf ("configure request failed: %w" , err )
1089+ }
1090+ defer resp .Body .Close ()
1091+
1092+ // Model Runner returns 202 Accepted on success
1093+ if resp .StatusCode != http .StatusAccepted {
1094+ body , _ := io .ReadAll (resp .Body )
1095+ return fmt .Errorf ("configure request failed with status %d: %s" , resp .StatusCode , strings .TrimSpace (string (body )))
9331096 }
934- slog .Debug ("docker model configure completed" , "model" , model )
1097+
1098+ slog .Debug ("Model configure via API completed" , "model" , model )
9351099 return nil
9361100}
9371101
// buildConfigureURL derives the /engines/_configure endpoint URL from the OpenAI base URL.
// It handles various URL formats:
//   - http://host:port/engines/v1/ → http://host:port/engines/_configure
//   - http://_/exp/vDD4.40/engines/v1 → http://_/exp/vDD4.40/engines/_configure
//   - http://host:port/engines/llama.cpp/v1/ → http://host:port/engines/llama.cpp/_configure
func buildConfigureURL(baseURL string) string {
	// trimmed strips a trailing slash and the /v1 segment, then appends
	// /_configure — shared by both the parsed and the fallback path.
	trimmed := func(p string) string {
		p = strings.TrimSuffix(p, "/")
		p = strings.TrimSuffix(p, "/v1")
		return p + "/_configure"
	}

	parsed, parseErr := url.Parse(baseURL)
	if parseErr != nil {
		// Unparseable input: operate on the raw string as a best effort.
		return trimmed(baseURL)
	}

	parsed.Path = trimmed(parsed.Path)
	return parsed.String()
}
1130+
1131+ // buildConfigureRequest constructs the JSON request body for POST /engines/_configure.
1132+ func buildConfigureRequest (model string , contextSize * int64 , runtimeFlags []string , specOpts * speculativeDecodingOpts ) configureRequest {
1133+ req := configureRequest {
1134+ Model : model ,
1135+ RuntimeFlags : runtimeFlags ,
1136+ }
1137+
1138+ // Convert int64 context size to int32 as expected by model-runner
9421139 if contextSize != nil {
943- args = append (args , "--context-size=" + strconv .FormatInt (* contextSize , 10 ))
1140+ cs := int32 (* contextSize )
1141+ req .ContextSize = & cs
9441142 }
1143+
9451144 if specOpts != nil {
946- if specOpts .draftModel != "" {
947- args = append (args , "--speculative-draft-model=" + specOpts .draftModel )
948- }
949- if specOpts .numTokens > 0 {
950- args = append (args , "--speculative-num-tokens=" + strconv .Itoa (specOpts .numTokens ))
951- }
952- if specOpts .acceptanceRate > 0 {
953- args = append (args , "--speculative-min-acceptance-rate=" + strconv .FormatFloat (specOpts .acceptanceRate , 'f' , - 1 , 64 ))
1145+ req .Speculative = & speculativeDecodingRequest {
1146+ DraftModel : specOpts .draftModel ,
1147+ NumTokens : specOpts .numTokens ,
1148+ MinAcceptanceRate : specOpts .acceptanceRate ,
9541149 }
9551150 }
956- args = append (args , model )
957- if len (runtimeFlags ) > 0 {
958- args = append (args , "--" )
959- args = append (args , runtimeFlags ... )
960- }
961- return args
1151+
1152+ return req
9621153}
9631154
9641155func getDockerModelEndpointAndEngine (ctx context.Context ) (endpoint , engine string , err error ) {
@@ -1025,7 +1216,7 @@ func buildRuntimeFlagsFromModelConfig(engine string, cfg *latest.ModelConfig) []
10251216 if cfg .PresencePenalty != nil {
10261217 flags = append (flags , "--presence-penalty" , strconv .FormatFloat (* cfg .PresencePenalty , 'f' , - 1 , 64 ))
10271218 }
1028- // Note: Context size already handled via -- context-size during `docker model configure`
1219+ // Note: Context size already handled via context-size field in the configure API request
10291220 default :
10301221 // Unknown engine: no flags
10311222 }
0 commit comments