Spaces:
Sleeping
Sleeping
Update my_model/state_manager.py
Browse files- my_model/state_manager.py +17 -14
my_model/state_manager.py
CHANGED
|
@@ -4,12 +4,15 @@ import streamlit as st
|
|
| 4 |
from my_model.utilities.gen_utilities import free_gpu_resources
|
| 5 |
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
|
| 6 |
|
| 7 |
-
|
| 8 |
-
col1, col2, col3 = st.columns([0.2, 0.6, 0.2])
|
| 9 |
|
| 10 |
|
| 11 |
class StateManager:
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def initialize_state(self):
|
| 14 |
if 'images_data' not in st.session_state:
|
| 15 |
st.session_state['images_data'] = {}
|
|
@@ -26,18 +29,18 @@ class StateManager:
|
|
| 26 |
|
| 27 |
def set_up_widgets(self):
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
|
| 35 |
# Conditional display of model settings
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
|
| 42 |
|
| 43 |
|
|
@@ -54,11 +57,11 @@ class StateManager:
|
|
| 54 |
|
| 55 |
|
| 56 |
def display_model_settings(self):
|
| 57 |
-
|
| 58 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
|
| 59 |
df = pd.DataFrame(data)
|
| 60 |
styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
|
| 61 |
-
col3.table(styled_df)
|
| 62 |
|
| 63 |
|
| 64 |
def display_session_state(self):
|
|
@@ -109,7 +112,7 @@ class StateManager:
|
|
| 109 |
if self.is_model_loaded():
|
| 110 |
prepare_kbvqa_model(only_reload_detection_model=True)
|
| 111 |
st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
|
| 112 |
-
col1.success("Model reloaded with updated settings and ready for inference.")
|
| 113 |
free_gpu_resources()
|
| 114 |
except Exception as e:
|
| 115 |
st.error(f"Error reloading detection model: {e}")
|
|
|
|
| 4 |
from my_model.utilities.gen_utilities import free_gpu_resources
|
| 5 |
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
|
| 6 |
|
| 7 |
+
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class StateManager:
|
| 11 |
|
| 12 |
+
def __init__(self):
|
| 13 |
+
# Create three columns with different widths
|
| 14 |
+
self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])
|
| 15 |
+
|
| 16 |
def initialize_state(self):
|
| 17 |
if 'images_data' not in st.session_state:
|
| 18 |
st.session_state['images_data'] = {}
|
|
|
|
| 29 |
|
| 30 |
def set_up_widgets(self):
|
| 31 |
|
| 32 |
+
|
| 33 |
+
self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
|
| 34 |
+
detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
|
| 35 |
+
default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
|
| 36 |
+
self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.1, slider_key_name='confidence_level', col=self.col1)
|
| 37 |
|
| 38 |
# Conditional display of model settings
|
| 39 |
|
| 40 |
+
|
| 41 |
+
show_model_settings = self.col3.checkbox("Show Model Settings", False)
|
| 42 |
+
if show_model_settings:
|
| 43 |
+
self.display_model_settings()
|
| 44 |
|
| 45 |
|
| 46 |
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
def display_model_settings(self):
|
| 60 |
+
self.col3.write("##### Current Model Settings:")
|
| 61 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
|
| 62 |
df = pd.DataFrame(data)
|
| 63 |
styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
|
| 64 |
+
self.col3.table(styled_df)
|
| 65 |
|
| 66 |
|
| 67 |
def display_session_state(self):
|
|
|
|
| 112 |
if self.is_model_loaded():
|
| 113 |
prepare_kbvqa_model(only_reload_detection_model=True)
|
| 114 |
st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
|
| 115 |
+
self.col1.success("Model reloaded with updated settings and ready for inference.")
|
| 116 |
free_gpu_resources()
|
| 117 |
except Exception as e:
|
| 118 |
st.error(f"Error reloading detection model: {e}")
|