Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import re | |
| from transformers import AutoTokenizer, T5EncoderModel | |
| import torch.nn as nn | |
# Class definition from training. The attribute names (encoder, dropout,
# classifier) determine the state_dict keys, so they must stay unchanged for
# the saved checkpoint to load.
class FlanT5Classifier(nn.Module):
    """Classification head on top of a pretrained T5 encoder.

    The encoder's hidden state at the first sequence position is passed
    through dropout and a linear layer that yields ``num_labels`` logits.
    """

    def __init__(self, base_model_name="google/flan-t5-base", num_labels=4):
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(base_model_name)
        hidden_size = self.encoder.config.d_model
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Compute classification logits for a batch of token-id sequences.

        Args:
            input_ids: Token ids; presumably shaped (batch, seq_len), since
                the ``[:, 0]`` indexing below needs at least two dimensions.
            attention_mask: Optional mask, forwarded unchanged to the encoder.

        Returns:
            A dict with the single key ``"logits"``.
        """
        hidden = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # T5 has no dedicated [CLS] token; the first position serves as the
        # sequence summary (presumably exactly as during training).
        first_position = hidden[:, 0]
        return {"logits": self.classifier(self.dropout(first_position))}
# Load the tokenizer of the fine-tuned checkpoint. classify_tokens wraps the
# target token in "<TSTART>"/"<TEND>"; presumably these markers were added to
# this tokenizer's vocabulary during training — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained("pepegiallo/flan-t5-base_ner")

# Instantiate the classifier, then resize its embedding table to the
# tokenizer's vocabulary size. This happens before the weights are loaded
# because the checkpoint was presumably saved with the resized embeddings.
model = FlanT5Classifier()
model.encoder.resize_token_embeddings(new_num_tokens=len(tokenizer))

# Restore the trained weights, mapped onto the CPU.
# NOTE(review): torch.load unpickles the file; acceptable for this repo's own
# checkpoint, but pass weights_only=True if it could ever come from elsewhere.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
# Switch to inference mode (train(False) is exactly what eval() does);
# this disables dropout.
model.train(False)
# Mapping from classifier output index to NER tag. The order presumably
# mirrors the label ids used at training time (num_labels=4).
id2label = dict(enumerate(("LOC", "ORG", "PER", "O")))
# Tokenizer helpers.
# A token is either a maximal run of word characters or a single character
# that is neither a word character nor whitespace (punctuation, symbols).
_TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)


def custom_tokenize(text):
    """Split *text* into word and punctuation tokens, dropping whitespace.

    >>> custom_tokenize("Hi, Bob!")
    ['Hi', ',', 'Bob', '!']
    """
    return _TOKEN_PATTERN.findall(text)
def custom_detokenize(tokens):
    """Join tokens back into a single string (a rough inverse of custom_tokenize).

    A single space is inserted before every token except the first whose
    first character is a word character. All other tokens are appended
    directly to the preceding text; this includes punctuation and markers
    such as ``"<TSTART>"``.

    >>> custom_detokenize(["Hi", ",", "Bob", "!"])
    'Hi, Bob!'

    Args:
        tokens: Iterable of token strings.

    Returns:
        The reassembled string; ``""`` for an empty input.
    """
    # Collect pieces and join once instead of growing a string with `+=`,
    # which is quadratic in the worst case.
    pieces = []
    for index, token in enumerate(tokens):
        # This spacing rule shapes the prompt text the model sees, so it is
        # kept exactly as it was.
        if index > 0 and re.match(r"\w", token):
            pieces.append(" ")
        pieces.append(token)
    return "".join(pieces)
# Classification function
def classify_tokens(text, batch_size=32):
    """Tag every token of *text* with one of the labels in ``id2label``.

    For each token produced by ``custom_tokenize``, the sentence is rebuilt
    with that token wrapped in ``<TSTART>`` / ``<TEND>`` markers and prefixed
    with ``"classify token in: "``. The model predicts a single label for
    each such prompt. Prompts are sent through the model up to ``batch_size``
    at a time instead of one forward pass per token.

    Args:
        text: Input sentence (the Gradio textbox value). An empty string or
            ``None`` yields an empty result.
        batch_size: Maximum number of prompts per forward pass. It bounds peak
            memory and is not meant to change the predicted labels.

    Returns:
        List of ``(token, label)`` tuples in input order, the value format
        rendered by ``gr.HighlightedText``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    if not text:
        return []

    tokens = custom_tokenize(text)
    prompts = [
        "classify token in: "
        + custom_detokenize(tokens[:i] + ["<TSTART>", token, "<TEND>"] + tokens[i + 1:])
        for i, token in enumerate(tokens)
    ]

    labels = []
    for start in range(0, len(prompts), batch_size):
        # Every prompt is still padded to the same fixed length, so each row
        # gets the same input it would have gotten if classified on its own.
        # NOTE(review): with the tokenizer's default (right-side) truncation
        # only the first 128 subword tokens are kept. For long inputs the
        # <TSTART>/<TEND> markers of later tokens can be cut off, which makes
        # those predictions unreliable.
        inputs = tokenizer(
            prompts[start:start + batch_size],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
        )
        with torch.no_grad():
            logits = model(**inputs)["logits"]
        labels.extend(id2label[pred_id] for pred_id in logits.argmax(dim=-1).tolist())

    return list(zip(tokens, labels))
# --- Gradio UI --------------------------------------------------------------
# One free-text box as input; the (token, label) pairs returned by
# classify_tokens are shown as highlighted spans.
demo = gr.Interface(
    classify_tokens,
    inputs=gr.Textbox(lines=3, placeholder="Enter a sentence..."),
    outputs=gr.HighlightedText(label="Token Classification Output"),
    title="Flan-T5 Token Classification (NER)",
    description=(
        "Classifies each token in the input text as "
        "LOC, ORG, PER, or O."
    ),
)
# Start the web app; left unguarded because the Space executes this file directly.
demo.launch()