Skip to content

Commit 89c6d98

Browse files
authored
Merge pull request docker#1393 from krissetto/model-switching-catalog-and-autocomplete
Enable selecting models in TUI based on available credentials
2 parents 2bce99d + 61a295f commit 89c6d98

File tree

7 files changed

+1211
-66
lines changed

7 files changed

+1211
-66
lines changed

pkg/model/provider/provider.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,50 @@ type Alias struct {
2828
TokenEnvVar string // Environment variable name for the API token
2929
}
3030

31+
// CoreProviders lists all natively implemented provider types.
32+
// These are the provider types that have direct implementations (not aliases).
33+
var CoreProviders = []string{
34+
"openai",
35+
"anthropic",
36+
"google",
37+
"dmr",
38+
"amazon-bedrock",
39+
}
40+
41+
// CatalogProviders returns the list of provider names that should be shown in the model catalog.
42+
// This includes core providers and aliases that have a defined BaseURL (self-contained endpoints).
43+
// Aliases without a BaseURL (like azure) require user configuration and are excluded.
44+
func CatalogProviders() []string {
45+
providers := make([]string, 0, len(CoreProviders)+len(Aliases))
46+
47+
// Add all core providers
48+
providers = append(providers, CoreProviders...)
49+
50+
// Add aliases that have a defined BaseURL (they work out of the box)
51+
for name, alias := range Aliases {
52+
if alias.BaseURL != "" {
53+
providers = append(providers, name)
54+
}
55+
}
56+
57+
return providers
58+
}
59+
60+
// IsCatalogProvider returns true if the provider name is valid for the model catalog.
61+
func IsCatalogProvider(name string) bool {
62+
// Check core providers
63+
for _, p := range CoreProviders {
64+
if p == name {
65+
return true
66+
}
67+
}
68+
// Check aliases with BaseURL
69+
if alias, exists := Aliases[name]; exists && alias.BaseURL != "" {
70+
return true
71+
}
72+
return false
73+
}
74+
3175
// Aliases maps provider names to their corresponding configurations
3276
var Aliases = map[string]Alias{
3377
"requesty": {
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package provider
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestCatalogProviders(t *testing.T) {
10+
t.Parallel()
11+
12+
providers := CatalogProviders()
13+
14+
// Should include all core providers
15+
for _, core := range CoreProviders {
16+
assert.Contains(t, providers, core, "should include core provider %s", core)
17+
}
18+
19+
// Should include aliases with BaseURL
20+
for name, alias := range Aliases {
21+
if alias.BaseURL != "" {
22+
assert.Contains(t, providers, name, "should include alias %s with BaseURL", name)
23+
} else {
24+
assert.NotContains(t, providers, name, "should NOT include alias %s without BaseURL", name)
25+
}
26+
}
27+
}
28+
29+
func TestIsCatalogProvider(t *testing.T) {
30+
t.Parallel()
31+
32+
tests := []struct {
33+
name string
34+
provider string
35+
want bool
36+
}{
37+
// Core providers
38+
{"openai is core", "openai", true},
39+
{"anthropic is core", "anthropic", true},
40+
{"google is core", "google", true},
41+
{"dmr is core", "dmr", true},
42+
{"amazon-bedrock is core", "amazon-bedrock", true},
43+
44+
// Aliases with BaseURL (should be included)
45+
{"mistral has BaseURL", "mistral", true},
46+
{"xai has BaseURL", "xai", true},
47+
{"nebius has BaseURL", "nebius", true},
48+
{"requesty has BaseURL", "requesty", true},
49+
{"ollama has BaseURL", "ollama", true},
50+
51+
// Aliases without BaseURL (should be excluded)
52+
{"azure has no BaseURL", "azure", false},
53+
54+
// Unknown providers
55+
{"unknown provider", "unknown", false},
56+
{"cohere not supported", "cohere", false},
57+
}
58+
59+
for _, tt := range tests {
60+
t.Run(tt.name, func(t *testing.T) {
61+
t.Parallel()
62+
got := IsCatalogProvider(tt.provider)
63+
assert.Equal(t, tt.want, got)
64+
})
65+
}
66+
}

pkg/modelsdev/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Provider struct {
2323
type Model struct {
2424
ID string `json:"id"`
2525
Name string `json:"name"`
26+
Family string `json:"family,omitempty"`
2627
Attachment bool `json:"attachment"`
2728
Reasoning bool `json:"reasoning"`
2829
Temperature bool `json:"temperature"`

pkg/runtime/model_switcher.go

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ import (
44
"context"
55
"fmt"
66
"log/slog"
7+
"slices"
78
"strings"
89

910
"github.com/docker/cagent/pkg/config/latest"
1011
"github.com/docker/cagent/pkg/environment"
1112
"github.com/docker/cagent/pkg/model/provider"
1213
"github.com/docker/cagent/pkg/model/provider/options"
14+
"github.com/docker/cagent/pkg/modelsdev"
1315
)
1416

1517
// ModelChoice represents a model available for selection in the TUI picker.
@@ -28,6 +30,8 @@ type ModelChoice struct {
2830
IsCurrent bool
2931
// IsCustom indicates this is a custom model from the session history (not from config)
3032
IsCustom bool
33+
// IsCatalog indicates this is a model from the models.dev catalog
34+
IsCatalog bool
3135
}
3236

3337
// ModelSwitcher is an optional interface for runtimes that support changing the model
@@ -249,7 +253,7 @@ func (r *LocalRuntime) createProvidersFromAlloyConfig(ctx context.Context, alloy
249253
}
250254

251255
// AvailableModels implements ModelSwitcher for LocalRuntime.
252-
func (r *LocalRuntime) AvailableModels(_ context.Context) []ModelChoice {
256+
func (r *LocalRuntime) AvailableModels(ctx context.Context) []ModelChoice {
253257
var choices []ModelChoice
254258

255259
if r.modelSwitcherCfg == nil {
@@ -273,9 +277,178 @@ func (r *LocalRuntime) AvailableModels(_ context.Context) []ModelChoice {
273277
})
274278
}
275279

280+
// Append models.dev catalog entries filtered by available credentials
281+
catalogChoices := r.buildCatalogChoices(ctx)
282+
choices = append(choices, catalogChoices...)
283+
276284
return choices
277285
}
278286

287+
// CatalogStore is an extended interface for model stores that support fetching the full database.
288+
type CatalogStore interface {
289+
ModelStore
290+
GetDatabase(ctx context.Context) (*modelsdev.Database, error)
291+
}
292+
293+
// buildCatalogChoices builds ModelChoice entries from the models.dev catalog,
294+
// filtered by supported providers and available credentials.
295+
func (r *LocalRuntime) buildCatalogChoices(ctx context.Context) []ModelChoice {
296+
// Check if modelsStore supports GetDatabase
297+
catalogStore, ok := r.modelsStore.(CatalogStore)
298+
if !ok {
299+
slog.Debug("Models store does not support GetDatabase, skipping catalog")
300+
return nil
301+
}
302+
303+
db, err := catalogStore.GetDatabase(ctx)
304+
if err != nil {
305+
slog.Debug("Failed to get models.dev database for catalog", "error", err)
306+
return nil
307+
}
308+
309+
// Build set of existing model refs to avoid duplicates
310+
existingRefs := make(map[string]bool)
311+
for name, cfg := range r.modelSwitcherCfg.Models {
312+
existingRefs[name] = true
313+
if cfg.Provider != "" && cfg.Model != "" {
314+
existingRefs[cfg.Provider+"/"+cfg.Model] = true
315+
}
316+
}
317+
318+
// Check which providers the user has credentials for
319+
availableProviders := r.getAvailableProviders(ctx)
320+
if len(availableProviders) == 0 {
321+
slog.Debug("No provider credentials available, skipping catalog")
322+
return nil
323+
}
324+
325+
var choices []ModelChoice
326+
for providerID, prov := range db.Providers {
327+
// Check if this provider is supported and user has credentials
328+
cagentProvider, supported := mapModelsDevProvider(providerID)
329+
if !supported {
330+
continue
331+
}
332+
if !availableProviders[cagentProvider] {
333+
continue
334+
}
335+
336+
for modelID, model := range prov.Models {
337+
// Skip models that don't output text (not suitable for chat)
338+
if !slices.Contains(model.Modalities.Output, "text") {
339+
continue
340+
}
341+
// Skip embedding models (not suitable for chat)
342+
if isEmbeddingModel(model.Family, model.Name) {
343+
continue
344+
}
345+
346+
ref := cagentProvider + "/" + modelID
347+
if existingRefs[ref] {
348+
continue
349+
}
350+
existingRefs[ref] = true
351+
352+
choices = append(choices, ModelChoice{
353+
Name: model.Name,
354+
Ref: ref,
355+
Provider: cagentProvider,
356+
Model: modelID,
357+
IsCatalog: true,
358+
})
359+
}
360+
}
361+
362+
slog.Debug("Built catalog choices", "count", len(choices), "available_providers", len(availableProviders))
363+
return choices
364+
}
365+
366+
// mapModelsDevProvider maps a models.dev provider ID to a cagent provider name.
367+
// Returns the cagent provider name and whether it's supported.
368+
// Uses provider.IsCatalogProvider to dynamically include all core providers
369+
// and aliases with defined base URLs.
370+
func mapModelsDevProvider(providerID string) (string, bool) {
371+
if provider.IsCatalogProvider(providerID) {
372+
return providerID, true
373+
}
374+
return "", false
375+
}
376+
377+
// isEmbeddingModel returns true if the model is an embedding model
378+
// based on its family or name fields from models.dev.
379+
func isEmbeddingModel(family, name string) bool {
380+
familyLower := strings.ToLower(family)
381+
nameLower := strings.ToLower(name)
382+
return strings.Contains(familyLower, "embed") || strings.Contains(nameLower, "embed")
383+
}
384+
385+
// getAvailableProviders returns a map of provider names that the user has credentials for.
386+
func (r *LocalRuntime) getAvailableProviders(ctx context.Context) map[string]bool {
387+
available := make(map[string]bool)
388+
env := r.modelSwitcherCfg.EnvProvider
389+
390+
// If using a models gateway, check for Docker token
391+
if r.modelSwitcherCfg.ModelsGateway != "" {
392+
if token, _ := env.Get(ctx, environment.DockerDesktopTokenEnv); token != "" {
393+
// Gateway supports all providers
394+
available["openai"] = true
395+
available["anthropic"] = true
396+
available["google"] = true
397+
available["mistral"] = true
398+
available["xai"] = true
399+
}
400+
return available
401+
}
402+
403+
// Check credentials for each provider
404+
providerEnvVars := map[string]string{
405+
"openai": "OPENAI_API_KEY",
406+
"anthropic": "ANTHROPIC_API_KEY",
407+
"google": "GOOGLE_API_KEY",
408+
"mistral": "MISTRAL_API_KEY",
409+
"xai": "XAI_API_KEY",
410+
"nebius": "NEBIUS_API_KEY",
411+
"requesty": "REQUESTY_API_KEY",
412+
"azure": "AZURE_API_KEY",
413+
}
414+
415+
for providerName, envVar := range providerEnvVars {
416+
if key, _ := env.Get(ctx, envVar); key != "" {
417+
available[providerName] = true
418+
}
419+
}
420+
421+
// DMR and ollama don't require credentials (local models)
422+
available["dmr"] = true
423+
available["ollama"] = true
424+
425+
// Amazon Bedrock uses AWS credentials which can come from many sources.
426+
// We do a quick heuristic check for common indicators without blocking:
427+
// - AWS_ACCESS_KEY_ID: explicit access key
428+
// - AWS_PROFILE / AWS_DEFAULT_PROFILE: named profile (credentials in ~/.aws/)
429+
// - AWS_WEB_IDENTITY_TOKEN_FILE: EKS/IRSA web identity
430+
// - AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: ECS task role
431+
// - AWS_ROLE_ARN: assumed role
432+
// Note: This won't catch all cases (e.g., EC2 instance profiles, SSO) but
433+
// those require network calls which would block the UI.
434+
awsCredentialIndicators := []string{
435+
"AWS_ACCESS_KEY_ID",
436+
"AWS_PROFILE",
437+
"AWS_DEFAULT_PROFILE",
438+
"AWS_WEB_IDENTITY_TOKEN_FILE",
439+
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
440+
"AWS_ROLE_ARN",
441+
}
442+
for _, indicator := range awsCredentialIndicators {
443+
if val, _ := env.Get(ctx, indicator); val != "" {
444+
available["amazon-bedrock"] = true
445+
break
446+
}
447+
}
448+
449+
return available
450+
}
451+
279452
// createProviderFromConfig creates a provider from a ModelConfig using the runtime's configuration.
280453
func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) {
281454
opts := []options.Opt{

0 commit comments

Comments
 (0)