# Zero-Shot Text Classification using `facebook/bart-large-mnli`

This repository demonstrates how to use the [`facebook/bart-large-mnli`](https://huggingface.co/facebook/bart-large-mnli) model for **zero-shot text classification** based on **natural language inference (NLI)**.

We extend the base usage by:

- Using a labeled dataset for benchmarking
- Performing optional fine-tuning
- Quantizing the model to FP16
- Scoring model performance

---
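A minimal sketch of the FP16 step, assuming a PyTorch backend: passing `torch_dtype=torch.float16` to `from_pretrained` loads the weights in half precision. Note that FP16 inference is generally run on a GPU; on CPU, full precision is usually kept.

```python
import torch
from transformers import AutoModelForSequenceClassification

# Load the model weights in half precision (FP16); this roughly halves
# memory use compared to the default FP32 checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli",
    torch_dtype=torch.float16,
)
print(model.dtype)
```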
## Model Description

- **Model:** `facebook/bart-large-mnli`
- **Type:** NLI-based zero-shot classifier
- **Architecture:** BART (Bidirectional and Auto-Regressive Transformers)
- **Usage:** Classifies text by scoring label hypotheses as NLI entailment

---
## Dataset

We use the [`yahoo_answers_topics`](https://huggingface.co/datasets/yahoo_answers_topics) dataset from Hugging Face for evaluation. It contains questions categorized into 10 topics.

```python
from datasets import load_dataset

dataset = load_dataset("yahoo_answers_topics")
```
## Zero-Shot Classification Logic

The model checks whether a text entails a hypothesis such as:

> "This text is about sports."

Each candidate label (e.g., "sports", "education", "health") is converted into such a hypothesis, and the model scores how strongly the text entails it.
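Under the hood, this reduces to plain NLI. A sketch of the scoring, assuming the model's standard label order (`contradiction`, `neutral`, `entailment`): for each label, the neutral logit is dropped and the remaining two are softmaxed, as the Hugging Face zero-shot pipeline does.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")

premise = "The team played well and won the championship."
labels = ["sports", "education", "health"]

scores = {}
for label in labels:
    hypothesis = f"This text is about {label}."
    inputs = tokenizer(premise, hypothesis, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # [contradiction, neutral, entailment]
    # Drop the neutral logit and softmax over contradiction vs. entailment.
    entail_contra = logits[0, [0, 2]]
    scores[label] = entail_contra.softmax(dim=0)[1].item()

print(max(scores, key=scores.get))
```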
## Example: Inference with Zero-Shot Pipeline

```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

sequence = "The team played well and won the championship."
labels = ["sports", "politics", "education", "technology"]

result = classifier(sequence, candidate_labels=labels)
print(result)
```
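The pipeline also accepts a custom `hypothesis_template` and a `multi_label` flag. With `multi_label=True`, each label is scored independently (entailment vs. contradiction), so the scores no longer sum to 1. A quick sketch:

```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

sequence = "The team played well and won the championship."
labels = ["sports", "politics", "education", "technology"]

result = classifier(
    sequence,
    candidate_labels=labels,
    # {} is replaced by each candidate label when building the hypothesis.
    hypothesis_template="This text is about {}.",
    multi_label=True,  # score each label independently
)
print(result["labels"])
print(result["scores"])
```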
## Scoring / Evaluation

Evaluate zero-shot classification using accuracy or top-k accuracy:

```python
from sklearn.metrics import accuracy_score

def evaluate_zero_shot(dataset, labels):
    y_true, y_pred = [], []
    for example in dataset:
        result = classifier(example["question_content"], candidate_labels=labels)
        y_pred.append(result["labels"][0])  # candidate labels come back sorted by score
        y_true.append(labels[example["topic"]])
    return accuracy_score(y_true, y_pred)

# Topic names in the dataset's class order (index matches the `topic` field)
labels = ["Society & Culture", "Science & Mathematics", "Health", "Education & Reference",
          "Computers & Internet", "Sports", "Business & Finance", "Entertainment & Music",
          "Family & Relationships", "Politics & Government"]

acc = evaluate_zero_shot(dataset["test"].select(range(100)), labels)
print(f"Accuracy: {acc:.2%}")
```