Commit
·
316b0d8
1
Parent(s):
90ed2ce
feat: Remove SHAP
Browse files- app.py +3 -39
- requirements.txt +0 -1
app.py
CHANGED
|
@@ -5,8 +5,6 @@ from numba.core.errors import NumbaDeprecationWarning
|
|
| 5 |
warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
|
| 6 |
import gradio as gr
|
| 7 |
from transformers import pipeline
|
| 8 |
-
from shap import Explainer
|
| 9 |
-
import numpy as np
|
| 10 |
from typing import Tuple, Dict, List
|
| 11 |
|
| 12 |
|
|
@@ -30,47 +28,13 @@ def main():
|
|
| 30 |
|
| 31 |
def classification(text) -> Tuple[Dict[str, float], dict]:
|
| 32 |
output: List[dict] = pipe(text)[0]
|
| 33 |
-
print(output)
|
|
|
|
| 34 |
|
| 35 |
-
explainer = Explainer(pipe)
|
| 36 |
-
explanation = explainer([text])
|
| 37 |
-
shap_values = explanation.values[0].sum(axis=1)
|
| 38 |
-
|
| 39 |
-
# Find the SHAP boundary
|
| 40 |
-
boundary = 0.03
|
| 41 |
-
if np.abs(shap_values).max() <= boundary:
|
| 42 |
-
boundary = np.abs(shap_values).max() - 1e-6
|
| 43 |
-
|
| 44 |
-
words: List[str] = explanation.data[0]
|
| 45 |
-
records = list()
|
| 46 |
-
char_idx = 0
|
| 47 |
-
for word, shap_value in zip(words, shap_values):
|
| 48 |
-
|
| 49 |
-
if abs(shap_value) <= boundary:
|
| 50 |
-
entity = 'O'
|
| 51 |
-
else:
|
| 52 |
-
entity = output['label'].lower().replace(' ', '-')
|
| 53 |
-
|
| 54 |
-
if len(word):
|
| 55 |
-
start = char_idx
|
| 56 |
-
char_idx += len(word)
|
| 57 |
-
end = char_idx
|
| 58 |
-
records.append(dict(
|
| 59 |
-
entity=entity,
|
| 60 |
-
word=word,
|
| 61 |
-
score=abs(shap_value),
|
| 62 |
-
start=start,
|
| 63 |
-
end=end,
|
| 64 |
-
))
|
| 65 |
-
print(records)
|
| 66 |
-
|
| 67 |
-
return ({output["label"]: output["score"]}, dict(text=text, entities=records))
|
| 68 |
-
|
| 69 |
-
color_map = {"offensive": "red", "not-offensive": "green", 'O': 'white'}
|
| 70 |
demo = gr.Interface(
|
| 71 |
fn=classification,
|
| 72 |
inputs=gr.Textbox(placeholder="Enter sentence here...", value=examples[0]),
|
| 73 |
-
outputs=
|
| 74 |
examples=examples,
|
| 75 |
title="Danish Offensive Text Detection",
|
| 76 |
description="""
|
|
|
|
| 5 |
warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
|
| 6 |
import gradio as gr
|
| 7 |
from transformers import pipeline
|
|
|
|
|
|
|
| 8 |
from typing import Tuple, Dict, List
|
| 9 |
|
| 10 |
|
|
|
|
| 28 |
|
| 29 |
def classification(text) -> Tuple[Dict[str, float], dict]:
|
| 30 |
output: List[dict] = pipe(text)[0]
|
| 31 |
+
print(text, output)
|
| 32 |
+
return {output["label"]: output["score"]}
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
demo = gr.Interface(
|
| 35 |
fn=classification,
|
| 36 |
inputs=gr.Textbox(placeholder="Enter sentence here...", value=examples[0]),
|
| 37 |
+
outputs=gr.Label(),
|
| 38 |
examples=examples,
|
| 39 |
title="Danish Offensive Text Detection",
|
| 40 |
description="""
|
requirements.txt
CHANGED
|
@@ -88,7 +88,6 @@ rfc3986==1.5.0
|
|
| 88 |
scikit-learn==1.2.2
|
| 89 |
scipy==1.10.1
|
| 90 |
semantic-version==2.10.0
|
| 91 |
-
shap==0.41.0
|
| 92 |
six==1.16.0
|
| 93 |
slicer==0.0.7
|
| 94 |
sniffio==1.3.0
|
|
|
|
| 88 |
scikit-learn==1.2.2
|
| 89 |
scipy==1.10.1
|
| 90 |
semantic-version==2.10.0
|
|
|
|
| 91 |
six==1.16.0
|
| 92 |
slicer==0.0.7
|
| 93 |
sniffio==1.3.0
|