| from io import StringIO | |
| from typing import Optional | |
| import gradio as gr | |
| import pandas as pd | |
| from utils.similarity import batch_cos_sim | |
def read_data(filepath: str) -> pd.DataFrame:
    """Load a CSV or Excel file into a DataFrame.

    The reader is chosen from the file extension, compared case-insensitively:
    ``.xlsx`` / ``.xls`` go through :func:`pandas.read_excel` and ``.csv``
    through :func:`pandas.read_csv`.

    Args:
        filepath: Path of the file to load.

    Returns:
        The parsed DataFrame (never ``None``).

    Raises:
        ValueError: if the extension is not one of the supported ones.
    """
    lowered = filepath.lower()
    if lowered.endswith(('.xlsx', '.xls')):
        # Excel workbooks are binary; read_csv cannot parse them.
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)
    raise ValueError('File type not supported')
def process(model_name: str,
            text: str,
            file=None,
            ):
    """Run ``batch_cos_sim`` on tabular input and save the result as CSV.

    Input comes from the uploaded ``file`` if one is given; otherwise ``text``
    is parsed as CSV (header row first).

    Args:
        model_name: Model identifier forwarded to ``batch_cos_sim``.
        text: CSV text, used only when no file is uploaded.
        file: The uploaded file. This is either a plain path string or an
            object exposing the path as ``.name``; which one depends on the
            Gradio version and the component's ``type`` setting.

    Returns:
        A ``(preview, path)`` tuple. ``preview`` is ``str(df)`` of the
        resulting DataFrame. ``path`` is the CSV it was written to
        (``output.csv``, encoded as ``utf-8-sig``, i.e. with a BOM).

    Raises:
        ValueError: if neither a file nor non-blank text is provided.
    """
    if file:
        # Accept both a path string and a file wrapper with ``.name``.
        filepath = getattr(file, 'name', file)
        df = read_data(filepath)
    elif text and text.strip():
        df = pd.read_csv(StringIO(text))
    else:
        # Also covers whitespace-only text, which read_csv would reject
        # with a less helpful parser error.
        raise ValueError('No input provided')
    df = batch_cos_sim(df, model_name)
    path = 'output.csv'
    df.to_csv(path, index=False, encoding='utf-8-sig')
    return str(df), path
# Default model name, shared by both model-name widgets so they cannot drift.
DEFAULT_MODEL_NAME = 'paraphrase-multilingual-MiniLM-L12-v2'

# NOTE(review): this free-text model box is built but not passed to the
# Interface (the dropdown below is used instead); kept in case it is
# referenced elsewhere.
model_name_input = gr.components.Textbox(
    value=DEFAULT_MODEL_NAME,
    lines=1,
    type='text',
)
# Model selector; its value is passed to ``process`` as ``model_name``.
model_name_option = gr.components.Dropdown(
    label='Model Name',
    value=DEFAULT_MODEL_NAME,
    choices=[
        DEFAULT_MODEL_NAME,
        'paraphrase-multilingual-mpnet-base-v2',
        'cyclone/simcse-chinese-roberta-wwm-ext',
    ],
)
# CSV text input, pre-filled with the ``prompt,response`` header row.
text_input = gr.components.Textbox(
    value='prompt,response\n',
    lines=10,
    type='text',
)
# Shows the string preview returned by ``process``.
text_output = gr.components.Textbox(
    label='Output',
    type='text',
)
# Download of the result file; ``process`` always writes a CSV.
file_output = gr.components.File(
    label='Output File',
    file_count='single',
    file_types=['.csv'],
)
# Wire the UI to ``process``. The inputs map positionally onto
# (model_name, text, file). Uploads are limited to spreadsheet/CSV extensions.
app = gr.Interface(
    fn=process,
    inputs=[
        model_name_option,
        text_input,
        gr.components.File(
            label='Input File',
            file_count='single',
            file_types=['.csv', '.xls', '.xlsx'],
        ),
    ],
    outputs=[text_output, file_output],
)

# Only start the server when run as a script, not on import.
if __name__ == '__main__':
    app.launch()