|
7 | 7 |
|
8 | 8 | "github.com/Tencent/WeKnora/internal/models/provider" |
9 | 9 | "github.com/Tencent/WeKnora/internal/models/utils/ollama" |
10 | | - "github.com/Tencent/WeKnora/internal/runtime" |
11 | 10 | "github.com/Tencent/WeKnora/internal/types" |
12 | 11 | ) |
13 | 12 |
|
@@ -51,15 +50,13 @@ type Config struct { |
51 | 50 | } |
52 | 51 |
|
53 | 52 | // NewEmbedder creates an embedder based on the configuration |
54 | | -func NewEmbedder(config Config) (Embedder, error) { |
| 53 | +func NewEmbedder(config Config, pooler EmbedderPooler, ollamaService *ollama.OllamaService) (Embedder, error) { |
55 | 54 | var embedder Embedder |
56 | 55 | var err error |
57 | 56 | switch strings.ToLower(string(config.Source)) { |
58 | 57 | case string(types.ModelSourceLocal): |
59 | | - runtime.GetContainer().Invoke(func(pooler EmbedderPooler, ollamaService *ollama.OllamaService) { |
60 | | - embedder, err = NewOllamaEmbedder(config.BaseURL, |
61 | | - config.ModelName, config.TruncatePromptTokens, config.Dimensions, config.ModelID, pooler, ollamaService) |
62 | | - }) |
| 58 | + embedder, err = NewOllamaEmbedder(config.BaseURL, |
| 59 | + config.ModelName, config.TruncatePromptTokens, config.Dimensions, config.ModelID, pooler, ollamaService) |
63 | 60 | return embedder, err |
64 | 61 | case string(types.ModelSourceRemote): |
65 | 62 | // Detect or use configured provider for routing |
@@ -88,66 +85,56 @@ func NewEmbedder(config Config) (Embedder, error) { |
88 | 85 | baseURL = strings.Replace(baseURL, "/compatible-mode/v1", "", 1) |
89 | 86 | baseURL = strings.Replace(baseURL, "/compatible-mode", "", 1) |
90 | 87 | } |
91 | | - runtime.GetContainer().Invoke(func(pooler EmbedderPooler) { |
92 | | - embedder, err = NewAliyunEmbedder(config.APIKey, |
93 | | - baseURL, |
94 | | - config.ModelName, |
95 | | - config.TruncatePromptTokens, |
96 | | - config.Dimensions, |
97 | | - config.ModelID, |
98 | | - pooler) |
99 | | - }) |
| 88 | + embedder, err = NewAliyunEmbedder(config.APIKey, |
| 89 | + baseURL, |
| 90 | + config.ModelName, |
| 91 | + config.TruncatePromptTokens, |
| 92 | + config.Dimensions, |
| 93 | + config.ModelID, |
| 94 | + pooler) |
100 | 95 | } else { |
101 | 96 | baseURL := config.BaseURL |
102 | 97 | if baseURL == "" || !strings.Contains(baseURL, "/compatible-mode/") { |
103 | 98 | baseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" |
104 | 99 | } |
105 | | - runtime.GetContainer().Invoke(func(pooler EmbedderPooler) { |
106 | | - embedder, err = NewOpenAIEmbedder(config.APIKey, |
107 | | - baseURL, |
108 | | - config.ModelName, |
109 | | - config.TruncatePromptTokens, |
110 | | - config.Dimensions, |
111 | | - config.ModelID, |
112 | | - pooler) |
113 | | - }) |
114 | | - } |
115 | | - return embedder, err |
116 | | - case provider.ProviderVolcengine: |
117 | | - // Volcengine Ark uses multimodal embedding API |
118 | | - runtime.GetContainer().Invoke(func(pooler EmbedderPooler) { |
119 | | - embedder, err = NewVolcengineEmbedder(config.APIKey, |
120 | | - config.BaseURL, |
| 100 | + embedder, err = NewOpenAIEmbedder(config.APIKey, |
| 101 | + baseURL, |
121 | 102 | config.ModelName, |
122 | 103 | config.TruncatePromptTokens, |
123 | 104 | config.Dimensions, |
124 | 105 | config.ModelID, |
125 | 106 | pooler) |
126 | | - }) |
| 107 | + } |
| 108 | + return embedder, err |
| 109 | + case provider.ProviderVolcengine: |
| 110 | + // Volcengine Ark uses multimodal embedding API |
| 111 | + embedder, err = NewVolcengineEmbedder(config.APIKey, |
| 112 | + config.BaseURL, |
| 113 | + config.ModelName, |
| 114 | + config.TruncatePromptTokens, |
| 115 | + config.Dimensions, |
| 116 | + config.ModelID, |
| 117 | + pooler) |
127 | 118 | return embedder, err |
128 | 119 | case provider.ProviderJina: |
129 | 120 | // Jina AI uses different API format (truncate instead of truncate_prompt_tokens) |
130 | | - runtime.GetContainer().Invoke(func(pooler EmbedderPooler) { |
131 | | - embedder, err = NewJinaEmbedder(config.APIKey, |
132 | | - config.BaseURL, |
133 | | - config.ModelName, |
134 | | - config.TruncatePromptTokens, |
135 | | - config.Dimensions, |
136 | | - config.ModelID, |
137 | | - pooler) |
138 | | - }) |
| 121 | + embedder, err = NewJinaEmbedder(config.APIKey, |
| 122 | + config.BaseURL, |
| 123 | + config.ModelName, |
| 124 | + config.TruncatePromptTokens, |
| 125 | + config.Dimensions, |
| 126 | + config.ModelID, |
| 127 | + pooler) |
139 | 128 | return embedder, err |
140 | 129 | default: |
141 | 130 | // Use OpenAI-compatible embedder for other providers |
142 | | - runtime.GetContainer().Invoke(func(pooler EmbedderPooler) { |
143 | | - embedder, err = NewOpenAIEmbedder(config.APIKey, |
144 | | - config.BaseURL, |
145 | | - config.ModelName, |
146 | | - config.TruncatePromptTokens, |
147 | | - config.Dimensions, |
148 | | - config.ModelID, |
149 | | - pooler) |
150 | | - }) |
| 131 | + embedder, err = NewOpenAIEmbedder(config.APIKey, |
| 132 | + config.BaseURL, |
| 133 | + config.ModelName, |
| 134 | + config.TruncatePromptTokens, |
| 135 | + config.Dimensions, |
| 136 | + config.ModelID, |
| 137 | + pooler) |
151 | 138 | return embedder, err |
152 | 139 | } |
153 | 140 | default: |
|
0 commit comments