Skip to content

Commit b1633c8

Browse files
chengjoeylyingbug
authored and committed
fix: concurrent map writes due to dig.Container is not thread safe
Signed-off-by: joeyczheng <[email protected]>
1 parent b648bd4 commit b1633c8

File tree

4 files changed

+51
-62
lines changed

4 files changed

+51
-62
lines changed

internal/application/service/model.go

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,13 +20,15 @@ var ErrModelNotFound = errors.New("model not found")
2020
type modelService struct {
2121
repo interfaces.ModelRepository
2222
ollamaService *ollama.OllamaService
23+
pooler embedding.EmbedderPooler
2324
}
2425

2526
// NewModelService creates a new model service instance
26-
func NewModelService(repo interfaces.ModelRepository, ollamaService *ollama.OllamaService) interfaces.ModelService {
27+
func NewModelService(repo interfaces.ModelRepository, ollamaService *ollama.OllamaService, pooler embedding.EmbedderPooler) interfaces.ModelService {
2728
return &modelService{
2829
repo: repo,
2930
ollamaService: ollamaService,
31+
pooler: pooler,
3032
}
3133
}
3234

@@ -253,7 +255,7 @@ func (s *modelService) GetEmbeddingModel(ctx context.Context, modelId string) (e
253255
Dimensions: model.Parameters.EmbeddingParameters.Dimension,
254256
TruncatePromptTokens: model.Parameters.EmbeddingParameters.TruncatePromptTokens,
255257
Provider: model.Parameters.Provider,
256-
})
258+
}, s.pooler, s.ollamaService)
257259
if err != nil {
258260
logger.ErrorWithFields(ctx, err, map[string]interface{}{
259261
"model_id": model.ID,
@@ -335,7 +337,7 @@ func (s *modelService) GetChatModel(ctx context.Context, modelId string) (chat.C
335337
BaseURL: model.Parameters.BaseURL,
336338
ModelName: model.Name,
337339
Source: model.Source,
338-
})
340+
}, s.ollamaService)
339341
if err != nil {
340342
logger.ErrorWithFields(ctx, err, map[string]interface{}{
341343
"model_id": model.ID,

internal/handler/initialization.go

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,7 @@ type InitializationHandler struct {
5858
knowledgeService interfaces.KnowledgeService
5959
ollamaService *ollama.OllamaService
6060
docReaderClient *client.Client
61+
pooler embedding.EmbedderPooler
6162
}
6263

6364
// NewInitializationHandler 创建初始化处理器
@@ -70,6 +71,7 @@ func NewInitializationHandler(
7071
knowledgeService interfaces.KnowledgeService,
7172
ollamaService *ollama.OllamaService,
7273
docReaderClient *client.Client,
74+
pooler embedding.EmbedderPooler,
7375
) *InitializationHandler {
7476
return &InitializationHandler{
7577
config: config,
@@ -80,6 +82,7 @@ func NewInitializationHandler(
8082
knowledgeService: knowledgeService,
8183
ollamaService: ollamaService,
8284
docReaderClient: docReaderClient,
85+
pooler: pooler,
8386
}
8487
}
8588

@@ -1544,7 +1547,7 @@ func (h *InitializationHandler) TestEmbeddingModel(c *gin.Context) {
15441547
Provider: req.Provider,
15451548
}
15461549

1547-
emb, err := embedding.NewEmbedder(cfg)
1550+
emb, err := embedding.NewEmbedder(cfg, h.pooler, h.ollamaService)
15481551
if err != nil {
15491552
logger.ErrorWithFields(ctx, err, map[string]interface{}{"model": utils.SanitizeForLog(req.ModelName)})
15501553
c.JSON(http.StatusOK, gin.H{
@@ -1588,7 +1591,7 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
15881591
}
15891592

15901593
// 创建聊天实例
1591-
chatInstance, err := chat.NewChat(chatConfig)
1594+
chatInstance, err := chat.NewChat(chatConfig, h.ollamaService)
15921595
if err != nil {
15931596
return false, fmt.Sprintf("创建聊天实例失败: %v", err)
15941597
}
@@ -2087,7 +2090,7 @@ func (h *InitializationHandler) extractRelationsFromText(
20872090
BaseURL: llm.BaseUrl,
20882091
ModelName: llm.ModelName,
20892092
Source: types.ModelSource(llm.Source),
2090-
})
2093+
}, h.ollamaService)
20912094
if err != nil {
20922095
logger.Error(ctx, "初始化模型服务失败", err)
20932096
return nil, err
@@ -2169,7 +2172,7 @@ func (h *InitializationHandler) fabriText(ctx context.Context, tags []string, ll
21692172
BaseURL: llm.BaseUrl,
21702173
ModelName: llm.ModelName,
21712174
Source: types.ModelSource(llm.Source),
2172-
})
2175+
}, h.ollamaService)
21732176
if err != nil {
21742177
logger.Error(ctx, "初始化模型服务失败", err)
21752178
return "", err

internal/models/chat/chat.go

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@ import (
77
"strings"
88

99
"github.com/Tencent/WeKnora/internal/models/utils/ollama"
10-
"github.com/Tencent/WeKnora/internal/runtime"
1110
"github.com/Tencent/WeKnora/internal/types"
1211
)
1312

@@ -87,14 +86,12 @@ type ChatConfig struct {
8786
}
8887

8988
// NewChat 创建聊天实例
90-
func NewChat(config *ChatConfig) (Chat, error) {
89+
func NewChat(config *ChatConfig, ollamaService *ollama.OllamaService) (Chat, error) {
9190
var chat Chat
9291
var err error
9392
switch strings.ToLower(string(config.Source)) {
9493
case string(types.ModelSourceLocal):
95-
runtime.GetContainer().Invoke(func(ollamaService *ollama.OllamaService) {
96-
chat, err = NewOllamaChat(config, ollamaService)
97-
})
94+
chat, err = NewOllamaChat(config, ollamaService)
9895
if err != nil {
9996
return nil, err
10097
}

internal/models/embedding/embedder.go

Lines changed: 37 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@ import (
77

88
"github.com/Tencent/WeKnora/internal/models/provider"
99
"github.com/Tencent/WeKnora/internal/models/utils/ollama"
10-
"github.com/Tencent/WeKnora/internal/runtime"
1110
"github.com/Tencent/WeKnora/internal/types"
1211
)
1312

@@ -51,15 +50,13 @@ type Config struct {
5150
}
5251

5352
// NewEmbedder creates an embedder based on the configuration
54-
func NewEmbedder(config Config) (Embedder, error) {
53+
func NewEmbedder(config Config, pooler EmbedderPooler, ollamaService *ollama.OllamaService) (Embedder, error) {
5554
var embedder Embedder
5655
var err error
5756
switch strings.ToLower(string(config.Source)) {
5857
case string(types.ModelSourceLocal):
59-
runtime.GetContainer().Invoke(func(pooler EmbedderPooler, ollamaService *ollama.OllamaService) {
60-
embedder, err = NewOllamaEmbedder(config.BaseURL,
61-
config.ModelName, config.TruncatePromptTokens, config.Dimensions, config.ModelID, pooler, ollamaService)
62-
})
58+
embedder, err = NewOllamaEmbedder(config.BaseURL,
59+
config.ModelName, config.TruncatePromptTokens, config.Dimensions, config.ModelID, pooler, ollamaService)
6360
return embedder, err
6461
case string(types.ModelSourceRemote):
6562
// Detect or use configured provider for routing
@@ -88,66 +85,56 @@ func NewEmbedder(config Config) (Embedder, error) {
8885
baseURL = strings.Replace(baseURL, "/compatible-mode/v1", "", 1)
8986
baseURL = strings.Replace(baseURL, "/compatible-mode", "", 1)
9087
}
91-
runtime.GetContainer().Invoke(func(pooler EmbedderPooler) {
92-
embedder, err = NewAliyunEmbedder(config.APIKey,
93-
baseURL,
94-
config.ModelName,
95-
config.TruncatePromptTokens,
96-
config.Dimensions,
97-
config.ModelID,
98-
pooler)
99-
})
88+
embedder, err = NewAliyunEmbedder(config.APIKey,
89+
baseURL,
90+
config.ModelName,
91+
config.TruncatePromptTokens,
92+
config.Dimensions,
93+
config.ModelID,
94+
pooler)
10095
} else {
10196
baseURL := config.BaseURL
10297
if baseURL == "" || !strings.Contains(baseURL, "/compatible-mode/") {
10398
baseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
10499
}
105-
runtime.GetContainer().Invoke(func(pooler EmbedderPooler) {
106-
embedder, err = NewOpenAIEmbedder(config.APIKey,
107-
baseURL,
108-
config.ModelName,
109-
config.TruncatePromptTokens,
110-
config.Dimensions,
111-
config.ModelID,
112-
pooler)
113-
})
114-
}
115-
return embedder, err
116-
case provider.ProviderVolcengine:
117-
// Volcengine Ark uses multimodal embedding API
118-
runtime.GetContainer().Invoke(func(pooler EmbedderPooler) {
119-
embedder, err = NewVolcengineEmbedder(config.APIKey,
120-
config.BaseURL,
100+
embedder, err = NewOpenAIEmbedder(config.APIKey,
101+
baseURL,
121102
config.ModelName,
122103
config.TruncatePromptTokens,
123104
config.Dimensions,
124105
config.ModelID,
125106
pooler)
126-
})
107+
}
108+
return embedder, err
109+
case provider.ProviderVolcengine:
110+
// Volcengine Ark uses multimodal embedding API
111+
embedder, err = NewVolcengineEmbedder(config.APIKey,
112+
config.BaseURL,
113+
config.ModelName,
114+
config.TruncatePromptTokens,
115+
config.Dimensions,
116+
config.ModelID,
117+
pooler)
127118
return embedder, err
128119
case provider.ProviderJina:
129120
// Jina AI uses different API format (truncate instead of truncate_prompt_tokens)
130-
runtime.GetContainer().Invoke(func(pooler EmbedderPooler) {
131-
embedder, err = NewJinaEmbedder(config.APIKey,
132-
config.BaseURL,
133-
config.ModelName,
134-
config.TruncatePromptTokens,
135-
config.Dimensions,
136-
config.ModelID,
137-
pooler)
138-
})
121+
embedder, err = NewJinaEmbedder(config.APIKey,
122+
config.BaseURL,
123+
config.ModelName,
124+
config.TruncatePromptTokens,
125+
config.Dimensions,
126+
config.ModelID,
127+
pooler)
139128
return embedder, err
140129
default:
141130
// Use OpenAI-compatible embedder for other providers
142-
runtime.GetContainer().Invoke(func(pooler EmbedderPooler) {
143-
embedder, err = NewOpenAIEmbedder(config.APIKey,
144-
config.BaseURL,
145-
config.ModelName,
146-
config.TruncatePromptTokens,
147-
config.Dimensions,
148-
config.ModelID,
149-
pooler)
150-
})
131+
embedder, err = NewOpenAIEmbedder(config.APIKey,
132+
config.BaseURL,
133+
config.ModelName,
134+
config.TruncatePromptTokens,
135+
config.Dimensions,
136+
config.ModelID,
137+
pooler)
151138
return embedder, err
152139
}
153140
default:

0 commit comments

Comments (0)