| | """ |
| | Multi-Domain Classifier - Inference Example |
| | Repository: https://huggingface.co/ovinduG/multi-domain-classifier-phi3 |
| | """ |
| |
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | from peft import PeftModel |
| | import torch |
| | import json |
| |
|
class MultiDomainClassifier:
    """Classify free-text queries into domains using a PEFT-adapted Phi-3 model.

    The adapter published at ``model_id`` is loaded on top of a base causal
    language model, and :meth:`predict` prompts it to answer with a JSON
    object describing the query's primary (and optional secondary) domains.
    """

    # Output instructions appended after the query. Kept out of any f-string
    # so the literal JSON braces need no escaping (inside an f-string they
    # would be parsed as replacement fields and the prompt could not be built).
    _OUTPUT_FORMAT = "\n".join([
        "Output JSON format:",
        "{",
        '"primary_domain": "domain_name",',
        '"primary_confidence": 0.95,',
        '"is_multi_domain": true/false,',
        '"secondary_domains": [{"domain": "name", "confidence": 0.85}]',
        "}",
    ])

    def __init__(
        self,
        model_id: str = "ovinduG/multi-domain-classifier-phi3",
        base_model_id: str = "microsoft/Phi-3-mini-4k-instruct",
    ) -> None:
        """Load the base model, attach the adapter, and load the tokenizer.

        Args:
            model_id: Hugging Face repo (or local path) of the PEFT adapter;
                the tokenizer is loaded from the same location.
            base_model_id: Base causal LM the adapter is applied to.
        """
        print("Loading model...")

        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model = PeftModel.from_pretrained(self.base_model, model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model.eval()

        print("✅ Model loaded!")

    @classmethod
    def _build_prompt(cls, query: str) -> str:
        """Return the classification prompt for ``query``."""
        return f"Classify this query: {query}\n\n{cls._OUTPUT_FORMAT}"

    @staticmethod
    def _parse_response(text: str) -> dict:
        """Return the first JSON object found in ``text``.

        Every ``{`` is tried as a candidate start, so chatter or a stray brace
        before the JSON does not prevent a parse; text after the object is
        ignored.

        Returns:
            The decoded object, or ``{"error": "Failed to parse response",
            "raw": text}`` when no position in ``text`` starts a valid JSON
            object.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                result, _ = decoder.raw_decode(text[start:])
                return result
            except json.JSONDecodeError:
                start = text.find("{", start + 1)
        return {"error": "Failed to parse response", "raw": text}

    def predict(self, query: str, max_new_tokens: int = 200) -> dict:
        """Classify a query into domains.

        Args:
            query: Free-text query to classify.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The JSON object emitted by the model. The prompt requests the keys
            ``primary_domain``, ``primary_confidence``, ``is_multi_domain`` and
            ``secondary_domains``, but the model is not guaranteed to comply.
            If no JSON object can be parsed, returns
            ``{"error": "Failed to parse response", "raw": <generated text>}``.
        """
        prompt = self._build_prompt(query)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                # Greedy decoding. ``temperature`` only applies when sampling,
                # so it is not passed (transformers warns if it is).
                do_sample=False,
                # NOTE(review): disabling the KV cache makes generation much
                # slower; presumably a workaround for a Phi-3 / transformers
                # cache incompatibility -- confirm before re-enabling.
                use_cache=False,
            )

        # For decoder-only models generate() returns prompt + continuation.
        # Decode only the new tokens so the JSON template inside the prompt
        # is never mistaken for the model's answer.
        response = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        return self._parse_response(response)
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Demo: load the classifier once, then classify a few sample queries
    # covering single-domain and multi-domain requests.
    clf = MultiDomainClassifier()

    sample_queries = (
        "Write a Python function to calculate factorial",
        "Build ML model to analyze sales data and create API endpoints",
        "What is quantum entanglement?",
        "Create a REST API for healthcare diabetes prediction",
    )

    banner = "=" * 80
    print(f"\n{banner}\nCLASSIFICATION EXAMPLES\n{banner}")

    for text in sample_queries:
        print(f"\nQuery: {text}")
        prediction = clf.predict(text)
        print("Result: " + json.dumps(prediction, indent=2))
        print("-" * 80)
| |
|