Spaces:
Runtime error
Runtime error
Update app2.py
Browse files
app2.py
CHANGED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from transformers import (
|
| 4 |
+
AutoModelForSeq2SeqLM,
|
| 5 |
+
AutoModelForTableQuestionAnswering,
|
| 6 |
+
AutoTokenizer,
|
| 7 |
+
pipeline,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
model_tapex = "microsoft/tapex-large-finetuned-wtq"
|
| 11 |
+
tokenizer_tapex = AutoTokenizer.from_pretrained(model_tapex)
|
| 12 |
+
model_tapex = AutoModelForSeq2SeqLM.from_pretrained(model_tapex)
|
| 13 |
+
pipe_tapex = pipeline(
|
| 14 |
+
"table-question-answering", model=model_tapex, tokenizer=tokenizer_tapex
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
model_tapas = "google/tapas-large-finetuned-wtq"
|
| 18 |
+
tokenizer_tapas = AutoTokenizer.from_pretrained(model_tapas)
|
| 19 |
+
model_tapas = AutoModelForTableQuestionAnswering.from_pretrained(model_tapas)
|
| 20 |
+
pipe_tapas = pipeline(
|
| 21 |
+
"table-question-answering", model=model_tapas, tokenizer=tokenizer_tapas
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def process(query, file, correct_answer, rows=20):
|
| 26 |
+
table = pd.read_csv(file.name, header=0).astype(str)
|
| 27 |
+
table = table[:rows]
|
| 28 |
+
result_tapex = pipe_tapex(table=table, query=query)
|
| 29 |
+
result_tapas = pipe_tapas(table=table, query=query)
|
| 30 |
+
return result_tapex["answer"], result_tapas["answer"], correct_answer
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Inputs
|
| 34 |
+
query_text = gr.Text(label="Enter a question")
|
| 35 |
+
input_file = gr.File(label="Upload a CSV file", type="file")
|
| 36 |
+
rows_slider = gr.Slider(label="Number of rows")
|
| 37 |
+
|
| 38 |
+
# Output
|
| 39 |
+
answer_text_tapex = gr.Text(label="TAPEX answer")
|
| 40 |
+
answer_text_tapas = gr.Text(label="TAPAS answer")
|
| 41 |
+
|
| 42 |
+
description = "This Space lets you ask questions on CSV documents with Microsoft [TAPEX-Large](https://huggingface.co/microsoft/tapex-large-finetuned-wtq) and Google [TAPAS-Large](https://huggingface.co/google/tapas-large-finetuned-wtq). \
|
| 43 |
+
Both have been fine-tuned on the [WikiTableQuestions](https://huggingface.co/datasets/wikitablequestions) dataset. \n\n\
|
| 44 |
+
A sample file with football statistics is available in the repository: \n\n\
|
| 45 |
+
* Which team has the most wins? Answer: Manchester City FC\n\
|
| 46 |
+
* Which team has the most wins: Chelsea, Liverpool or Everton? Answer: Liverpool\n\
|
| 47 |
+
* Which teams have scored less than 40 goals? Answer: Cardiff City FC, Fulham FC, Brighton & Hove Albion FC, Huddersfield Town FC\n\
|
| 48 |
+
* What is the average number of wins? Answer: 16 (rounded)\n\n\
|
| 49 |
+
You can also upload your own CSV file. Please note that maximum sequence length for both models is 1024 tokens, \
|
| 50 |
+
so you may need to limit the number of rows in your CSV file. Chunking is not implemented yet."
|
| 51 |
+
|
| 52 |
+
iface = gr.Interface(
|
| 53 |
+
theme="huggingface",
|
| 54 |
+
description=description,
|
| 55 |
+
layout="vertical",
|
| 56 |
+
fn=process,
|
| 57 |
+
inputs=[query_text, input_file, rows_slider],
|
| 58 |
+
outputs=[answer_text_tapex, answer_text_tapas],
|
| 59 |
+
examples=[
|
| 60 |
+
["Which team has the most wins?", "default_file.csv", 20],
|
| 61 |
+
[
|
| 62 |
+
"Which team has the most wins: Chelsea, Liverpool or Everton?",
|
| 63 |
+
"default_file.csv",
|
| 64 |
+
20,
|
| 65 |
+
],
|
| 66 |
+
["Which teams have scored less than 40 goals?", "default_file.csv", 20],
|
| 67 |
+
["What is the average number of wins?", "default_file.csv", 20],
|
| 68 |
+
],
|
| 69 |
+
allow_flagging="never",
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
iface.launch()
|