@@ -127,7 +127,8 @@ def recurrent_display(messages, i):
127127 return i
128128
129129
130- def normal_text_search_streamlit (search_box , model_type , database_profile , entity_slot , opensearch_info , selected_profile ,
130+ def normal_text_search_streamlit (search_box , model_type , database_profile , entity_slot , opensearch_info ,
131+ selected_profile ,
131132 use_rag ,
132133 model_provider = None ):
133134 entity_slot_retrieve = []
@@ -241,6 +242,12 @@ def main():
241242 if 'current_profile' not in st .session_state :
242243 st .session_state ['current_profile' ] = ''
243244
245+ if 'current_model_id' not in st .session_state :
246+ st .session_state ['current_model_id' ] = ''
247+
248+ if 'config_data_with_analyse' not in st .session_state :
249+ st .session_state ['config_data_with_analyse' ] = False
250+
244251 if 'nlq_chain' not in st .session_state :
245252 st .session_state ['nlq_chain' ] = None
246253
@@ -250,7 +257,8 @@ def main():
250257 if "current_sql_result" not in st .session_state :
251258 st .session_state .current_sql_result = {}
252259
253- model_ids = ['anthropic.claude-3-sonnet-20240229-v1:0' , 'anthropic.claude-3-5-sonnet-20240620-v1:0' , 'anthropic.claude-3-opus-20240229-v1:0' ,
260+ model_ids = ['anthropic.claude-3-sonnet-20240229-v1:0' , 'anthropic.claude-3-5-sonnet-20240620-v1:0' ,
261+ 'anthropic.claude-3-opus-20240229-v1:0' ,
254262 'anthropic.claude-3-haiku-20240307-v1:0' , 'mistral.mixtral-8x7b-instruct-v0:1' ,
255263 'meta.llama3-70b-instruct-v1:0' ]
256264
@@ -263,8 +271,15 @@ def main():
263271 with st .sidebar :
264272 st .title ('Setting' )
265273 # The default option can be the first one in the profiles dictionary, if exists
266-
267- selected_profile = st .selectbox ("Data Profile" , list (st .session_state .get ('profiles' , {}).keys ()))
274+ session_state_list = list (st .session_state .get ('profiles' , {}).keys ())
275+ if st .session_state .current_profile != "" :
276+ if st .session_state .current_profile in session_state_list :
277+ profile_index = session_state_list .index (st .session_state .current_profile )
278+ selected_profile = st .selectbox ("Data Profile" , session_state_list , index = profile_index )
279+ else :
280+ selected_profile = st .selectbox ("Data Profile" , session_state_list )
281+ else :
282+ selected_profile = st .selectbox ("Data Profile" , session_state_list )
268283 if selected_profile != st .session_state .current_profile :
269284 # clear session state
270285 st .session_state .selected_sample = ''
@@ -273,7 +288,11 @@ def main():
273288 st .session_state .messages [selected_profile ] = []
274289 st .session_state .nlq_chain = NLQChain (selected_profile )
275290
276- model_type = st .selectbox ("Choose your model" , model_ids )
291+ if st .session_state .current_model_id != "" and st .session_state .current_model_id in model_ids :
292+ model_index = model_ids .index (st .session_state .current_model_id )
293+ model_type = st .selectbox ("Choose your model" , model_ids , index = model_index )
294+ else :
295+ model_type = st .selectbox ("Choose your model" , model_ids )
277296
278297 use_rag_flag = st .checkbox ("Using RAG from Q/A Embedding" , True )
279298 visualize_results_flag = st .checkbox ("Visualize Results" , True )
@@ -363,10 +382,13 @@ def main():
363382 database_profile ['db_url' ] = db_url
364383 database_profile ['db_type' ] = ConnectionManagement .get_db_type_by_name (conn_name )
365384 prompt_map = database_profile ['prompt_map' ]
385+ prompt_map_flag = False
366386 for key in prompt_map_dict :
367387 if key not in prompt_map :
368388 prompt_map [key ] = prompt_map_dict [key ]
369- ProfileManagement .update_table_prompt_map (selected_profile , prompt_map )
389+ prompt_map_flag = True
390+ if prompt_map_flag :
391+ ProfileManagement .update_table_prompt_map (selected_profile , prompt_map )
370392
371393 # Multiple rounds of dialogue, query rewriting
372394 user_query_history = get_user_history (selected_profile )
0 commit comments