Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import psutil | |
| import pandas as pd | |
| import streamlit as st | |
| import plotly.express as px | |
| from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS | |
| from zeroshot_classification.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier | |
# Report the host's total physical memory (psutil returns it in bytes) once at
# startup; presumably kept as a diagnostic for resource limits on the hosting
# platform — confirm before removing.
_total_memory = psutil.virtual_memory().total
print(f"Total mem: {_total_memory}")
def init_state(key: str):
    """Make sure ``key`` exists in Streamlit's session state.

    A missing key is seeded with ``None``; an existing value is left as is, so
    this is safe to call on every script rerun.

    Args:
        key: Name of the session-state entry to initialise.
    """
    if key in st.session_state:
        return
    st.session_state[key] = None
# Session-state slots read and written by the rest of the script; each one
# starts out as None until a model is loaded or a prediction is made.
_STATE_KEYS = (
    "current_model",
    "current_model_option",
    "current_method_option",
    "current_prediction",
    "current_chart",
)
for _state_key in _STATE_KEYS:
    init_state(_state_key)
def load_model(model_option: str, method_option: str, random_state: int = 0):
    """Instantiate the selected zero-shot classifier and store it in session state.

    The new classifier replaces ``st.session_state.current_model``. A spinner is
    shown while it is constructed, and a success message afterwards.

    Args:
        model_option: Model identifier, passed to the classifier as ``model_name``.
        method_option: Selected method label. ``METHOD_OPTIONS["nli"]`` selects
            ``NLIZeroshotClassifier``; any other value selects
            ``NSPZeroshotClassifier``.
        random_state: Forwarded unchanged to the classifier constructor.
    """
    # Dispatch on the same METHOD_OPTIONS entry that the method radio widget
    # offers, rather than a duplicated string literal that could drift from it.
    if method_option == METHOD_OPTIONS["nli"]:
        classifier_cls = NLIZeroshotClassifier
    else:
        classifier_cls = NSPZeroshotClassifier

    with st.spinner("Loading selected model..."):
        st.session_state.current_model = classifier_cls(
            model_name=model_option, random_state=random_state
        )
    st.success("Model loaded!")
def visualize_output(labels: list[str], probabilities: list[float]):
    """Build a horizontal Plotly bar chart of per-label probabilities.

    Rows are sorted by probability (highest first) before plotting. Each label
    gets its own colour, the legend is hidden, and the y-axis has no title.

    Args:
        labels: Candidate label names.
        probabilities: One value per label, aligned with ``labels``.

    Returns:
        The configured Plotly figure.
    """
    frame = pd.DataFrame({"labels": labels, "probability": probabilities})
    frame = frame.sort_values(by="probability", ascending=False)

    layout = {
        "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
        "yaxis": {"title": None, "visible": True, "showticklabels": True},
        # Tight margins on the left/right/bottom, with extra room at the top.
        "margin": {"l": 10, "r": 10, "t": 50, "b": 10},
        "showlegend": False,
    }

    figure = px.bar(
        frame,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    )
    return figure.update_layout(layout)
st.title("Zero-shot Turkish Text Classification")

method_option = st.radio(
    "Select a zero-shot classification method.",
    [
        METHOD_OPTIONS["nli"],
        METHOD_OPTIONS["nsp"],
    ],
)

# The radio offers exactly two methods, so an if/else guarantees that
# model_option is always bound. Two independent `if`s would leave it undefined
# for any unexpected value.
if method_option == METHOD_OPTIONS["nli"]:
    model_option = st.selectbox(
        "Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3
    )
else:
    model_option = st.selectbox(
        "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0
    )

# Streamlit reruns this script on every interaction. Only (re)load the
# classifier when the model OR the method differs from what is already loaded,
# so a method switch is never missed even if a model name appears in both
# option lists.
if (
    model_option != st.session_state.current_model_option
    or method_option != st.session_state.current_method_option
):
    st.session_state.current_model_option = model_option
    st.session_state.current_method_option = method_option
    load_model(
        st.session_state.current_model_option, st.session_state.current_method_option
    )
st.header("Configure prompts and labels")
col1, col2 = st.columns(2)

# Left column: the candidate labels and the text to classify. The widget keys
# expose their values through st.session_state for the prediction step.
with col1:
    st.subheader("Candidate labels")
    labels = st.text_area(
        label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
        value="spor,dünya,siyaset,ekonomi,sanat",
        key="current_labels",
    )
    st.header("Make predictions")
    text = st.text_area(
        "Enter a sentence or a paragraph to classify.",
        value="Ian Anderson, Jethro Tull konserinde yan flüt çalarak zeybek oynadı.",
        key="current_text",
    )

# Right column: the prompt template into which each label is substituted.
with col2:
    st.subheader("Prompt template")
    prompt_template = st.text_area(
        label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
        value="Bu metin {} kategorisine aittir",
        key="current_template",
    )
# Empty header used as a vertical spacer in the right-hand column.
col2.header("")
make_pred = col1.button("Predict")

if make_pred:
    # Trim whitespace around each comma-separated label and drop empty entries,
    # so "spor, dünya," yields ["spor", "dünya"] instead of
    # ["spor", " dünya", ""] (the stray spaces would leak into the prompt).
    candidate_labels = [
        label.strip()
        for label in st.session_state.current_labels.split(",")
        if label.strip()
    ]
    if not candidate_labels:
        col1.warning("Please enter at least one candidate label.")
    elif not st.session_state.current_text.strip():
        col1.warning("Please enter a text to classify.")
    else:
        st.session_state.current_prediction = (
            st.session_state.current_model.predict_on_texts(
                [st.session_state.current_text],
                candidate_labels=candidate_labels,
                prompt_template=st.session_state.current_template,
            )
        )
        first_prediction = st.session_state.current_prediction[0]
        # The result dict may carry its values under either "scores" or
        # "probabilities"; chart whichever is present. Otherwise clear the
        # chart rather than keep showing the previous prediction's.
        if "scores" in first_prediction:
            st.session_state.current_chart = visualize_output(
                first_prediction["labels"], first_prediction["scores"]
            )
        elif "probabilities" in first_prediction:
            st.session_state.current_chart = visualize_output(
                first_prediction["labels"], first_prediction["probabilities"]
            )
        else:
            st.session_state.current_chart = None

# current_chart stays None until a prediction succeeds. st.plotly_chart expects
# a Plotly figure (or dict/list), not None, so only render once a chart exists.
if st.session_state.current_chart is not None:
    col2.plotly_chart(st.session_state.current_chart, use_container_width=True)