Spaces:
Runtime error
Runtime error
change defaults etc
Browse files
app.py
CHANGED
|
@@ -5,14 +5,20 @@ import plotly.express as px
|
|
| 5 |
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
|
| 6 |
from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier
|
| 7 |
|
| 8 |
-
if "current_model" not in st.session_state:
|
| 9 |
-
st.session_state["current_model"] = None
|
| 10 |
|
| 11 |
-
|
| 12 |
-
st.session_state
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def load_model(model_option: str, method_option: str, random_state: int = 0):
|
|
@@ -66,13 +72,11 @@ method_option = st.radio(
|
|
| 66 |
)
|
| 67 |
if method_option == METHOD_OPTIONS["nli"]:
|
| 68 |
model_option = st.selectbox(
|
| 69 |
-
"Select a natural language inference model.",
|
| 70 |
-
NLI_MODEL_OPTIONS,
|
| 71 |
)
|
| 72 |
if method_option == METHOD_OPTIONS["nsp"]:
|
| 73 |
model_option = st.selectbox(
|
| 74 |
-
"Select a BERT model for next sentence prediction.",
|
| 75 |
-
NSP_MODEL_OPTIONS,
|
| 76 |
)
|
| 77 |
|
| 78 |
if model_option != st.session_state.current_model_option:
|
|
@@ -105,17 +109,25 @@ prompt_template = col2.text_area(
|
|
| 105 |
key="current_template",
|
| 106 |
)
|
| 107 |
col2.header("")
|
|
|
|
|
|
|
| 108 |
make_pred = col1.button("Predict")
|
| 109 |
if make_pred:
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
| 114 |
)
|
| 115 |
-
if "scores" in
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
)
|
| 121 |
-
col2.plotly_chart(
|
|
|
|
| 5 |
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
|
| 6 |
from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier
|
| 7 |
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
def init_state(key: str):
|
| 10 |
+
if key not in st.session_state:
|
| 11 |
+
st.session_state[key] = None
|
| 12 |
|
| 13 |
+
|
| 14 |
+
for k in [
|
| 15 |
+
"current_model",
|
| 16 |
+
"current_model_option",
|
| 17 |
+
"current_method_option",
|
| 18 |
+
"current_prediction",
|
| 19 |
+
"current_chart",
|
| 20 |
+
]:
|
| 21 |
+
init_state(k)
|
| 22 |
|
| 23 |
|
| 24 |
def load_model(model_option: str, method_option: str, random_state: int = 0):
|
|
|
|
| 72 |
)
|
| 73 |
if method_option == METHOD_OPTIONS["nli"]:
|
| 74 |
model_option = st.selectbox(
|
| 75 |
+
"Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3
|
|
|
|
| 76 |
)
|
| 77 |
if method_option == METHOD_OPTIONS["nsp"]:
|
| 78 |
model_option = st.selectbox(
|
| 79 |
+
"Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0
|
|
|
|
| 80 |
)
|
| 81 |
|
| 82 |
if model_option != st.session_state.current_model_option:
|
|
|
|
| 109 |
key="current_template",
|
| 110 |
)
|
| 111 |
col2.header("")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
make_pred = col1.button("Predict")
|
| 115 |
if make_pred:
|
| 116 |
+
st.session_state.current_prediction = (
|
| 117 |
+
st.session_state.current_model.predict_on_texts(
|
| 118 |
+
[st.session_state.current_text],
|
| 119 |
+
candidate_labels=st.session_state.current_labels.split(","),
|
| 120 |
+
prompt_template=st.session_state.current_template,
|
| 121 |
+
)
|
| 122 |
)
|
| 123 |
+
if "scores" in st.session_state.current_prediction[0]:
|
| 124 |
+
st.session_state.current_chart = visualize_output(
|
| 125 |
+
st.session_state.current_prediction[0]["labels"],
|
| 126 |
+
st.session_state.current_prediction[0]["scores"],
|
| 127 |
+
)
|
| 128 |
+
elif "probabilities" in st.session_state.current_prediction[0]:
|
| 129 |
+
st.session_state.current_chart = visualize_output(
|
| 130 |
+
st.session_state.current_prediction[0]["labels"],
|
| 131 |
+
st.session_state.current_prediction[0]["probabilities"],
|
| 132 |
)
|
| 133 |
+
col2.plotly_chart(st.session_state.current_chart, use_container_width=True)
|