Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import pipeline | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import io | |
| from PIL import Image | |
# Hugging Face model used for zero-shot classification.
_MODEL_NAME = "facebook/bart-large-mnli"

# Lazily created pipeline shared by all requests (see _get_classifier).
_CLASSIFIER = None


def _get_classifier():
    """Return the zero-shot classification pipeline, creating it on first use.

    Building the pipeline loads the model weights, which is far too slow to
    repeat on every request, so the instance is cached at module level.
    """
    global _CLASSIFIER
    if _CLASSIFIER is None:
        _CLASSIFIER = pipeline("zero-shot-classification", model=_MODEL_NAME)
    return _CLASSIFIER


def _parse_labels(labels):
    """Split a comma-separated string into unique, non-empty, stripped labels."""
    candidates = (part.strip() for part in (labels or "").split(","))
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(c for c in candidates if c))


def _figure_to_array(fig):
    """Render *fig* as PNG in memory and return the decoded pixels as an array."""
    buf = io.BytesIO()
    # bbox_inches='tight' keeps the pie legend drawn outside the axes in frame.
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    with Image.open(buf) as img:
        # np.array forces the pixel data to load before the image is closed.
        return np.array(img)


def _pie_chart(labels, scores):
    """Draw a pie chart of *scores* with a legend of labels and percentages."""
    fig, ax = plt.subplots()
    try:
        # Generate a distinct viridis colour for each label.
        colors = plt.cm.viridis(np.linspace(0, 1, len(labels)))
        wedges, _ = ax.pie(scores, startangle=140, colors=colors)
        ax.axis('equal')  # Equal aspect ratio ensures the pie chart is circular.
        ax.set_title('Pie Chart')
        legend_labels = [
            '{0} - {1:1.2f} %'.format(label, score * 100)
            for label, score in zip(labels, scores)
        ]
        ax.legend(
            wedges,
            legend_labels,
            title="Labels with Scores",
            loc="center left",
            bbox_to_anchor=(1, 0.5),
        )
        return _figure_to_array(fig)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def _bar_chart(labels, scores):
    """Draw a bar chart of *scores*, one bar per label."""
    fig, ax = plt.subplots()
    try:
        positions = np.arange(len(labels))
        ax.bar(positions, scores, align='center', alpha=0.7, color='blue')
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_ylabel('Scores')
        ax.set_title('Bar Chart')
        return _figure_to_array(fig)
    finally:
        plt.close(fig)


# Function to perform classification and create pie and bar charts
def classify_and_plot(text, labels):
    """Classify *text* against user-supplied labels and chart the scores.

    Args:
        text: The text to classify.
        labels: Comma-separated candidate labels, e.g. "sports, politics".
            Whitespace around each label is stripped; empty and duplicate
            entries are ignored.

    Returns:
        A ``(pie_chart, bar_chart)`` tuple of NumPy image arrays decoded
        from PNG renderings of the two charts.

    Raises:
        ValueError: If *text* is blank or *labels* contains no usable label.
    """
    if not text or not text.strip():
        raise ValueError("Please enter some text to classify.")
    labels_list = _parse_labels(labels)
    if not labels_list:
        raise ValueError("Please enter at least one comma-separated label.")

    result = _get_classifier()(text, labels_list)
    # Use a distinct name so the `labels` parameter is not shadowed.
    ranked_labels = result['labels']
    scores = result['scores']

    pie_chart_array = _pie_chart(ranked_labels, scores)
    bar_chart_array = _bar_chart(ranked_labels, scores)
    return pie_chart_array, bar_chart_array
# Create a Gradio interface
iface = gr.Interface(
    fn=classify_and_plot,
    # Labelled components (equivalent to the "text"/"image" shortcuts) so
    # users can tell the two input boxes and the two charts apart.
    inputs=[
        gr.Textbox(label="Text"),
        gr.Textbox(label="Labels (comma-separated)"),
    ],
    outputs=[
        gr.Image(label="Pie Chart"),
        gr.Image(label="Bar Chart"),
    ],
    title="Zero-Shot Classification with Pie and Bar Charts",
    description="Enter text and comma-separated labels for classification using the facebook/bart-large-mnli model. The outputs will be separate pie and bar charts representing the classification scores."
)

# Launch only when run as a script, so importing this module (e.g. from
# tests or another app) does not start a server as a side effect.
if __name__ == "__main__":
    # NOTE(review): Gradio documents share=True as unsupported on Hugging
    # Face Spaces (it warns and is ignored there); it only creates a public
    # link for local runs — confirm against the pinned Gradio version.
    iface.launch(share=True)