Skip to content

Commit 8abd071

Browse files
committed
refactor: restructure model provider handling
1 parent 8e77060 commit 8abd071

File tree

21 files changed

+401
-184
lines changed

21 files changed

+401
-184
lines changed

packages/global/core/ai/model.ts

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import { i18nT } from '../../../web/i18n/utils';
22
import type { LLMModelItemType, STTModelType, EmbeddingModelItemType } from './model.d';
3-
import { getModelProvider, type ModelProviderIdType } from './provider';
43

54
export enum ModelTypeEnum {
65
llm = 'llm',
@@ -54,29 +53,6 @@ export const defaultSTTModels: STTModelType[] = [
5453
}
5554
];
5655

57-
export const getModelFromList = (
58-
modelList: { provider: ModelProviderIdType; name: string; model: string }[],
59-
model: string,
60-
language: string
61-
):
62-
| {
63-
avatar: string;
64-
provider: ModelProviderIdType;
65-
name: string;
66-
model: string;
67-
}
68-
| undefined => {
69-
const modelData = modelList.find((item) => item.model === model) ?? modelList[0];
70-
if (!modelData) {
71-
return;
72-
}
73-
const provider = getModelProvider(modelData.provider, language);
74-
return {
75-
...modelData,
76-
avatar: provider.avatar
77-
};
78-
};
79-
8056
export const modelTypeList = [
8157
{ label: i18nT('common:model.type.chat'), value: ModelTypeEnum.llm },
8258
{ label: i18nT('common:model.type.embedding'), value: ModelTypeEnum.embedding },
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import type { I18nStringType } from '../../../common/i18n/type';
2+
import { getProviderList } from '../../../../service/core/app/tool/api';
3+
import type { ModelProviderCacheType, ModelProviderType } from './type';
4+
import type { ModelProviderIdType } from '@fastgpt-sdk/plugin';
5+
6+
export const defaultProvider: ModelProviderType = {
7+
id: 'Other',
8+
name: { en: 'Other' } as I18nStringType,
9+
avatar: 'model/other',
10+
order: 0
11+
};
12+
13+
const defaultMapData = [
14+
{
15+
id: 'Other',
16+
name: 'Other',
17+
avatar: 'model/other',
18+
provider: 'Other' as ModelProviderIdType
19+
}
20+
];
21+
22+
function getCachedModelProviders(): ModelProviderCacheType {
23+
if (!global.modelProviders_cache) {
24+
global.modelProviders_cache = {
25+
expires: 0,
26+
listData: [],
27+
mapData: []
28+
};
29+
}
30+
return global.modelProviders_cache;
31+
}
32+
33+
function createCacheWithDefaults(): ModelProviderCacheType {
34+
return {
35+
expires: Date.now() + 60 * 60 * 1000,
36+
listData: [defaultProvider],
37+
mapData: defaultMapData
38+
};
39+
}
40+
41+
// Preload model providers
42+
export async function preloadModelProviders(): Promise<void> {
43+
try {
44+
const res = await getProviderList();
45+
46+
if (res.mapData && res.listData) {
47+
const transformedListData = res.listData.map(
48+
(
49+
item: {
50+
id: string;
51+
info: I18nStringType;
52+
},
53+
index: number
54+
) => ({
55+
id: item.id,
56+
name: item.info,
57+
avatar: `/api/system/plugin/models/${item.id}.svg`,
58+
order: index
59+
})
60+
);
61+
62+
const transformedMapData = res.mapData.map(
63+
(item: {
64+
id: string;
65+
info: {
66+
name: I18nStringType | string;
67+
provider: string;
68+
avatar?: string;
69+
};
70+
}) => ({
71+
id: item.id,
72+
name: item.info.name,
73+
avatar: item.info.avatar || defaultProvider.avatar,
74+
provider: item.info.provider
75+
})
76+
);
77+
78+
global.modelProviders_cache = {
79+
expires: Date.now() + 60 * 60 * 1000,
80+
listData: transformedListData,
81+
mapData: transformedMapData
82+
};
83+
} else {
84+
global.modelProviders_cache = createCacheWithDefaults();
85+
}
86+
} catch (error) {
87+
Promise.reject(error);
88+
global.modelProviders_cache = createCacheWithDefaults();
89+
}
90+
}
91+
92+
function hasValidTranslation(name: I18nStringType | string, language: string): boolean {
93+
if (typeof name === 'string') return true;
94+
return Boolean(name[language as keyof I18nStringType]);
95+
}
96+
97+
// Get model providers
98+
export async function getModelProviders(language: string = 'en') {
99+
const cache = getCachedModelProviders();
100+
101+
if (cache.listData.length === 0 || (cache.expires > 0 && Date.now() > cache.expires)) {
102+
await preloadModelProviders();
103+
}
104+
105+
const updatedCache = getCachedModelProviders();
106+
107+
return {
108+
listData: updatedCache.listData.filter((item) => hasValidTranslation(item.name, language)),
109+
mapData: updatedCache.mapData.filter((item) => hasValidTranslation(item.name, language))
110+
};
111+
}
112+
113+
export const getModelProvider = async (provider?: ModelProviderIdType, language = 'en') => {
114+
const { listData } = await getModelProviders(language);
115+
if (!provider) {
116+
return defaultProvider;
117+
}
118+
return listData.find((item) => item.id === provider) || defaultProvider;
119+
};
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import type { I18nStringType } from '../../../common/i18n/type';
2+
3+
export type ModelShowType = {
4+
id: string;
5+
name: string;
6+
avatar: string;
7+
order: number;
8+
};
9+
10+
export type ModelProviderType = {
11+
id: string;
12+
name: I18nStringType;
13+
avatar: string;
14+
order: number;
15+
};
16+
17+
export type ModelProviderListType = {
18+
id: string;
19+
name: I18nStringType | string;
20+
avatar: string;
21+
provider: string;
22+
};
23+
24+
export type ModelProviderCacheType = {
25+
expires: number;
26+
listData: Array<ModelProviderType>;
27+
mapData: Array<ModelProviderListType>;
28+
};
29+
30+
declare global {
31+
var modelProviders_cache: ModelProviderCacheType | undefined;
32+
}

packages/service/core/ai/config/utils.ts

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import {
99
type RerankModelItemType
1010
} from '@fastgpt/global/core/ai/model.d';
1111
import { debounce } from 'lodash';
12-
import { getModelProvider } from '@fastgpt/global/core/ai/provider';
1312
import { findModelFromAlldata } from '../model';
1413
import {
1514
reloadFastGPTConfigBuffer,
@@ -18,6 +17,7 @@ import {
1817
import { delay } from '@fastgpt/global/common/system/utils';
1918
import { pluginClient } from '../../../thirdProvider/fastgptPlugin';
2019
import { setCron } from '../../../common/system/cron';
20+
import { getModelProvider } from '@fastgpt/global/core/app/model/controller';
2121

2222
export const loadSystemModels = async (init = false, language = 'en') => {
2323
const pushModel = (model: SystemModelItemType) => {
@@ -113,9 +113,8 @@ export const loadSystemModels = async (init = false, language = 'en') => {
113113
const modelData: any = {
114114
...model,
115115
...dbModel?.metadata,
116-
provider: getModelProvider(
117-
dbModel?.metadata?.provider || (model.provider as any),
118-
language
116+
provider: (
117+
await getModelProvider(dbModel?.metadata?.provider || model.provider, language)
119118
).id,
120119
type: dbModel?.metadata?.type || model.type,
121120
isCustom: false,
@@ -171,10 +170,15 @@ export const loadSystemModels = async (init = false, language = 'en') => {
171170
}
172171

173172
// Sort model list
173+
const providerOrderMap = new Map<string, number>();
174+
for (const model of global.systemActiveModelList) {
175+
const provider = await getModelProvider(model.provider, language);
176+
providerOrderMap.set(model.provider, provider.order);
177+
}
174178
global.systemActiveModelList.sort((a, b) => {
175-
const providerA = getModelProvider(a.provider, language);
176-
const providerB = getModelProvider(b.provider, language);
177-
return providerA.order - providerB.order;
179+
const orderA = providerOrderMap.get(a.provider) ?? 0;
180+
const orderB = providerOrderMap.get(b.provider) ?? 0;
181+
return orderA - orderB;
178182
});
179183
global.systemActiveDesensitizedModels = global.systemActiveModelList.map((model) => ({
180184
...model,

packages/service/core/app/tool/api.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ export async function APIGetSystemToolList() {
2222
return Promise.reject(res.body);
2323
}
2424

25+
export async function getProviderList() {
26+
const res = await pluginClient.model.provider();
27+
return res.status === 200 ? res.body : Promise.reject(res.body);
28+
}
29+
2530
const runToolInstance = new RunToolWithStream({
2631
baseUrl: BASE_URL,
2732
token: TOKEN

projects/app/src/components/Select/AIModelSelector.tsx

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@ import { HUGGING_FACE_ICON } from '@fastgpt/global/common/system/constants';
77
import { Box, Flex } from '@chakra-ui/react';
88
import Avatar from '@fastgpt/web/components/common/Avatar';
99
import MyTooltip from '@fastgpt/web/components/common/MyTooltip';
10-
import { getModelProviders } from '@fastgpt/global/core/ai/provider';
1110
import MultipleRowSelect from '@fastgpt/web/components/common/MySelect/MultipleRowSelect';
12-
import { getModelFromList } from '@fastgpt/global/core/ai/model';
1311
import type { ResponsiveValue } from '@chakra-ui/system';
12+
import type { I18nStringType } from '@fastgpt/global/common/i18n/type';
1413

1514
type Props = SelectProps & {
1615
disableTip?: string;
@@ -19,8 +18,14 @@ type Props = SelectProps & {
1918

2019
const OneRowSelector = ({ list, onChange, disableTip, noOfLines, ...props }: Props) => {
2120
const { t, i18n } = useTranslation();
22-
const { llmModelList, embeddingModelList, ttsModelList, sttModelList, reRankModelList } =
23-
useSystemStore();
21+
const {
22+
llmModelList,
23+
embeddingModelList,
24+
ttsModelList,
25+
sttModelList,
26+
reRankModelList,
27+
modelProviders
28+
} = useSystemStore();
2429
const language = i18n.language;
2530

2631
const avatarSize = useMemo(() => {
@@ -43,17 +48,19 @@ const OneRowSelector = ({ list, onChange, disableTip, noOfLines, ...props }: Pro
4348
];
4449
return list
4550
.map((item) => {
46-
const modelData = getModelFromList(allModels, item.value, language)!;
51+
const modelData = allModels.find((model) => model.model === item.value);
4752
if (!modelData) return;
4853

54+
const provider = modelProviders.listData.find((p) => p.id === modelData.provider);
55+
4956
return {
5057
value: item.value,
5158
label: (
5259
<Flex alignItems={'center'} py={1}>
5360
<Avatar
5461
borderRadius={'0'}
5562
mr={2}
56-
src={modelData?.avatar || HUGGING_FACE_ICON}
63+
src={provider?.avatar || HUGGING_FACE_ICON}
5764
w={avatarSize}
5865
fallbackSrc={HUGGING_FACE_ICON}
5966
/>
@@ -74,9 +81,9 @@ const OneRowSelector = ({ list, onChange, disableTip, noOfLines, ...props }: Pro
7481
sttModelList,
7582
reRankModelList,
7683
list,
77-
language,
7884
avatarSize,
79-
noOfLines
85+
noOfLines,
86+
modelProviders
8087
]);
8188

8289
return (
@@ -113,8 +120,14 @@ const MultipleRowSelector = ({
113120
...props
114121
}: Props) => {
115122
const { t, i18n } = useTranslation();
116-
const { llmModelList, embeddingModelList, ttsModelList, sttModelList, reRankModelList } =
117-
useSystemStore();
123+
const {
124+
llmModelList,
125+
embeddingModelList,
126+
ttsModelList,
127+
sttModelList,
128+
reRankModelList,
129+
modelProviders
130+
} = useSystemStore();
118131
const language = i18n.language;
119132
const modelList = useMemo(() => {
120133
const allModels = [
@@ -125,7 +138,9 @@ const MultipleRowSelector = ({
125138
...reRankModelList
126139
];
127140

128-
return list.map((item) => getModelFromList(allModels, item.value, language)!).filter(Boolean);
141+
return list
142+
.map((item) => allModels.find((model) => model.model === item.value))
143+
.filter(Boolean);
129144
}, [
130145
llmModelList,
131146
embeddingModelList,
@@ -149,7 +164,7 @@ const MultipleRowSelector = ({
149164
}, [props.size]);
150165

151166
const selectorList = useMemo(() => {
152-
const renderList = getModelProviders(language).map<{
167+
const renderList = modelProviders.listData.map<{
153168
label: React.JSX.Element;
154169
value: string;
155170
children: { label: string | React.ReactNode; value: string }[];
@@ -163,15 +178,15 @@ const MultipleRowSelector = ({
163178
fallbackSrc={HUGGING_FACE_ICON}
164179
w={avatarSize}
165180
/>
166-
<Box>{provider.name}</Box>
181+
<Box>{provider.name[language as keyof I18nStringType]}</Box>
167182
</Flex>
168183
),
169184
value: provider.id,
170185
children: []
171186
}));
172187

173188
for (const item of list) {
174-
const modelData = getModelFromList(modelList, item.value, language);
189+
const modelData = modelList.find((model) => model?.model === item.value);
175190
if (!modelData) continue;
176191
const provider =
177192
renderList.find((item) => item.value === (modelData?.provider || 'Other')) ??
@@ -184,7 +199,7 @@ const MultipleRowSelector = ({
184199
}
185200

186201
return renderList.filter((item) => item.children.length > 0);
187-
}, [avatarSize, list, modelList, t, language]);
202+
}, [avatarSize, list, modelList, language, modelProviders]);
188203

189204
const onSelect = useCallback(
190205
(e: string[]) => {
@@ -195,7 +210,7 @@ const MultipleRowSelector = ({
195210

196211
const SelectedLabel = useMemo(() => {
197212
if (!props.value) return <>{t('common:not_model_config')}</>;
198-
const modelData = getModelFromList(modelList, props.value, language);
213+
const modelData = modelList.find((model) => model?.model === props.value);
199214

200215
if (!modelData) return <>{t('common:not_model_config')}</>;
201216

0 commit comments

Comments
 (0)