Skip to content

Commit b43177b

Browse files
authored
Merge pull request #138 from aws-samples/spy_dev
add source pypi
2 parents f9a5cc5 + 7fd64d5 commit b43177b

13 files changed

+76
-27
lines changed

application/Dockerfile

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,10 @@ WORKDIR /app
44

55
RUN adduser --disabled-password --gecos '' appuser
66

7-
RUN apt-get update && apt-get install -y \
8-
build-essential \
9-
curl \
10-
software-properties-common \
11-
&& rm -rf /var/lib/apt/lists/*
12-
137
WORKDIR /app
148

159
COPY requirements.txt /app/
16-
RUN pip3 install -r requirements.txt
10+
RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
1711

1812
COPY . /app/
1913

application/Dockerfile-api

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ FROM public.ecr.aws/docker/library/python:3.10-slim
33
WORKDIR /app
44

55
COPY . /app/
6-
RUN pip3 install -r requirements-api.txt
6+
RUN pip3 install -r requirements-api.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
77

88
EXPOSE 8000
99

application/api/schemas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class Question(BaseModel):
99
visualize_results_flag: bool = True
1010
intent_ner_recognition_flag: bool = True
1111
agent_cot_flag: bool = True
12-
profile_name: str = "shopping-demo"
12+
profile_name: str
1313
explain_gen_process_flag: bool = True
1414
gen_suggested_question_flag: bool = False
1515
answer_with_insights: bool = False

application/pages/1_🌍_Generative_BI_Playground.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def recurrent_display(messages, i):
127127
return i
128128

129129

130-
def normal_text_search_streamlit(search_box, model_type, database_profile, entity_slot, opensearch_info, selected_profile,
130+
def normal_text_search_streamlit(search_box, model_type, database_profile, entity_slot, opensearch_info,
131+
selected_profile,
131132
use_rag,
132133
model_provider=None):
133134
entity_slot_retrieve = []
@@ -241,6 +242,12 @@ def main():
241242
if 'current_profile' not in st.session_state:
242243
st.session_state['current_profile'] = ''
243244

245+
if 'current_model_id' not in st.session_state:
246+
st.session_state['current_model_id'] = ''
247+
248+
if 'config_data_with_analyse' not in st.session_state:
249+
st.session_state['config_data_with_analyse'] = False
250+
244251
if 'nlq_chain' not in st.session_state:
245252
st.session_state['nlq_chain'] = None
246253

@@ -250,7 +257,8 @@ def main():
250257
if "current_sql_result" not in st.session_state:
251258
st.session_state.current_sql_result = {}
252259

253-
model_ids = ['anthropic.claude-3-sonnet-20240229-v1:0', 'anthropic.claude-3-5-sonnet-20240620-v1:0', 'anthropic.claude-3-opus-20240229-v1:0',
260+
model_ids = ['anthropic.claude-3-sonnet-20240229-v1:0', 'anthropic.claude-3-5-sonnet-20240620-v1:0',
261+
'anthropic.claude-3-opus-20240229-v1:0',
254262
'anthropic.claude-3-haiku-20240307-v1:0', 'mistral.mixtral-8x7b-instruct-v0:1',
255263
'meta.llama3-70b-instruct-v1:0']
256264

@@ -263,8 +271,15 @@ def main():
263271
with st.sidebar:
264272
st.title('Setting')
265273
# The default option can be the first one in the profiles dictionary, if exists
266-
267-
selected_profile = st.selectbox("Data Profile", list(st.session_state.get('profiles', {}).keys()))
274+
session_state_list = list(st.session_state.get('profiles', {}).keys())
275+
if st.session_state.current_profile != "":
276+
if st.session_state.current_profile in session_state_list:
277+
profile_index = session_state_list.index(st.session_state.current_profile)
278+
selected_profile = st.selectbox("Data Profile", session_state_list, index=profile_index)
279+
else:
280+
selected_profile = st.selectbox("Data Profile", session_state_list)
281+
else:
282+
selected_profile = st.selectbox("Data Profile", session_state_list)
268283
if selected_profile != st.session_state.current_profile:
269284
# clear session state
270285
st.session_state.selected_sample = ''
@@ -273,7 +288,11 @@ def main():
273288
st.session_state.messages[selected_profile] = []
274289
st.session_state.nlq_chain = NLQChain(selected_profile)
275290

276-
model_type = st.selectbox("Choose your model", model_ids)
291+
if st.session_state.current_model_id != "" and st.session_state.current_model_id in model_ids:
292+
model_index = model_ids.index(st.session_state.current_model_id)
293+
model_type = st.selectbox("Choose your model", model_ids, index=model_index)
294+
else:
295+
model_type = st.selectbox("Choose your model", model_ids)
277296

278297
use_rag_flag = st.checkbox("Using RAG from Q/A Embedding", True)
279298
visualize_results_flag = st.checkbox("Visualize Results", True)
@@ -363,10 +382,13 @@ def main():
363382
database_profile['db_url'] = db_url
364383
database_profile['db_type'] = ConnectionManagement.get_db_type_by_name(conn_name)
365384
prompt_map = database_profile['prompt_map']
385+
prompt_map_flag = False
366386
for key in prompt_map_dict:
367387
if key not in prompt_map:
368388
prompt_map[key] = prompt_map_dict[key]
369-
ProfileManagement.update_table_prompt_map(selected_profile, prompt_map)
389+
prompt_map_flag = True
390+
if prompt_map_flag:
391+
ProfileManagement.update_table_prompt_map(selected_profile, prompt_map)
370392

371393
# Multiple rounds of dialogue, query rewriting
372394
user_query_history = get_user_history(selected_profile)

application/pages/4_🪙_Schema_Description_Management.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ def main():
1414

1515
with st.sidebar:
1616
st.title("Schema Management")
17-
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
17+
all_profiles_list = ProfileManagement.get_all_profiles()
18+
if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list:
19+
profile_index = all_profiles_list.index(st.session_state.current_profile)
20+
current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index)
21+
else:
22+
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
1823
index=None,
1924
placeholder="Please select data profile...", key='current_profile_name')
2025

application/pages/5_🪙_Prompt_Management.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ def main():
1616

1717
with st.sidebar:
1818
st.title("Prompt Management")
19-
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
19+
all_profiles_list = ProfileManagement.get_all_profiles()
20+
if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list:
21+
profile_index = all_profiles_list.index(st.session_state.current_profile)
22+
current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index)
23+
else:
24+
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
2025
index=None,
2126
placeholder="Please select data profile...", key='current_profile_name')
2227

@@ -29,6 +34,9 @@ def main():
2934
format_func=lambda x: prompt_map[x].get('title'),
3035
placeholder="Please select a prompt type")
3136

37+
profile_detail = ProfileManagement.get_profile_by_name(current_profile)
38+
prompt_map = profile_detail.prompt_map
39+
3240
if prompt_type_selected_table is not None:
3341
single_type_prompt_map = prompt_map.get(prompt_type_selected_table)
3442
system_prompt = single_type_prompt_map.get('system_prompt')

application/pages/6_📚_Index_Management.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@ def main():
2626

2727
with st.sidebar:
2828
st.title("Index Management")
29-
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
29+
all_profiles_list = ProfileManagement.get_all_profiles()
30+
if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list:
31+
profile_index = all_profiles_list.index(st.session_state.current_profile)
32+
current_profile = st.selectbox("My Data Profiles", all_profiles_list, index = profile_index)
33+
else:
34+
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
3035
index=None,
3136
placeholder="Please select data profile...", key='current_profile_name')
3237

application/pages/7_📚_Entity_Management.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ def main():
2727

2828
with st.sidebar:
2929
st.title("Entity Management")
30-
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
30+
all_profiles_list = ProfileManagement.get_all_profiles()
31+
if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list:
32+
profile_index = all_profiles_list.index(st.session_state.current_profile)
33+
current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index)
34+
else:
35+
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
3136
index=None,
3237
placeholder="Please select data profile...", key='current_profile_name')
3338

application/pages/8_📚_Agent_Cot_Management.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@ def main():
2626

2727
with st.sidebar:
2828
st.title("Agent Cot Management")
29-
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
29+
all_profiles_list = ProfileManagement.get_all_profiles()
30+
if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list:
31+
profile_index = all_profiles_list.index(st.session_state.current_profile)
32+
current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index)
33+
else:
34+
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
3035
index=None,
3136
placeholder="Please select data profile...", key='current_profile_name')
3237

application/utils/env_var.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242

4343
BEDROCK_SECRETS_AK_SK = os.getenv('BEDROCK_SECRETS_AK_SK', '')
4444

45+
BEDROCK_EMBEDDING_MODEL = os.getenv('BEDROCK_EMBEDDING_MODEL', '')
46+
47+
SAGEMAKER_ENDPOINT_EMBEDDING = os.getenv('SAGEMAKER_ENDPOINT_EMBEDDING', '')
4548

4649
def get_opensearch_parameter():
4750
try:

0 commit comments

Comments
 (0)