mugu5 commited on
Commit
e6c25d5
·
verified ·
1 Parent(s): 9ea42e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -40
app.py CHANGED
@@ -1,40 +1,40 @@
1
- import streamlit as st
2
- import torch
3
- from Models.All_Model import *
4
-
5
- # model init
6
- model = BertForMultiLabel()
7
- # Load fine-tuned weights
8
- state_dict = torch.load(BERT_MODEL_PATH, map_location="cpu")
9
- model.load_state_dict(state_dict)
10
- model.eval()
11
-
12
- # -------------------------------
13
- # Streamlit App
14
- # -------------------------------
15
- st.title("Emotion Classification with fine‑tuned BERT")
16
-
17
- # Input text box
18
- text = st.text_area("Enter text to analyze five different emotions:")
19
-
20
- if st.button("Predict"):
21
- if text.strip():
22
- # Tokenize input
23
- inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
24
-
25
- with torch.no_grad():
26
- # logits = model(**inputs) for Ro berta
27
- logits = model(input_ids=inputs["input_ids"],
28
- attention_mask=inputs["attention_mask"])
29
-
30
- probs = torch.sigmoid(logits).cpu().numpy().tolist()[0]
31
-
32
- emotions = ["anger", "fear", "joy", "sadness", "surprise"]
33
- result = dict(zip(emotions, probs))
34
-
35
- # Display results
36
- st.subheader("Predicted Emotion Probabilities")
37
- for emotion, prob in result.items():
38
- st.write(f"{emotion} : {prob:.4f}")
39
- else:
40
- st.warning("Please enter some text before predicting.")
 
1
+ import streamlit as st
2
+ import torch
3
+ from All_Model import BertForMultiLabel ,bert_tokenizer
4
+
5
+ # model init
6
+ model = BertForMultiLabel()
7
+ # Load fine-tuned weights
8
+ state_dict = torch.load(BERT_MODEL_PATH, map_location="cpu")
9
+ model.load_state_dict(state_dict)
10
+ model.eval()
11
+
12
+ # -------------------------------
13
+ # Streamlit App
14
+ # -------------------------------
15
+ st.title("Emotion Classification with fine‑tuned BERT")
16
+
17
+ # Input text box
18
+ text = st.text_area("Enter text to analyze five different emotions:")
19
+
20
+ if st.button("Predict"):
21
+ if text.strip():
22
+ # Tokenize input
23
+ inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
24
+
25
+ with torch.no_grad():
26
+ # logits = model(**inputs) for Ro berta
27
+ logits = model(input_ids=inputs["input_ids"],
28
+ attention_mask=inputs["attention_mask"])
29
+
30
+ probs = torch.sigmoid(logits).cpu().numpy().tolist()[0]
31
+
32
+ emotions = ["anger", "fear", "joy", "sadness", "surprise"]
33
+ result = dict(zip(emotions, probs))
34
+
35
+ # Display results
36
+ st.subheader("Predicted Emotion Probabilities")
37
+ for emotion, prob in result.items():
38
+ st.write(f"{emotion} : {prob:.4f}")
39
+ else:
40
+ st.warning("Please enter some text before predicting.")