# Mental_health/app.py
import gradio as gr
from transformers import LongformerTokenizer, pipeline
from PIL import Image
import pytesseract
import cv2
import re
import torch
import matplotlib.pyplot as plt
import math
from typing import Dict, List, Any
import numpy as np
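# Use GPU (device index 0) for inference when available, otherwise fall back to CPU (-1).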
device = 0 if torch.cuda.is_available() else -1
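# The Longformer tokenizer is used only to count tokens when chunking long transcripts;
# the DistilRoBERTa classifier below does the actual emotion scoring.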
model_id = "allenai/longformer-base-4096"
tok = LongformerTokenizer.from_pretrained(model_id)
emo_head = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    return_all_scores=True,
    device=device,
)
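
# Matches timestamps like "12:30 PM" and dates like "1/2/2024" in chat exports (currently unused).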
time_regex = re.compile(r"(\d{1,2}[:]\d{2}\s*(AM|PM|am|pm)?)|(\d{1,2}[/]\d{1,2}[/]\d{2,4})")
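# Classifier emotion labels grouped by valence; the negative set drives the risk score.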
negative_keys = {"anger", "sadness", "fear", "disgust"}
positive_keys = {"joy", "surprise"}
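
# Map real speaker names to anonymous labels (User_1, User_2, ...) when privacy masking is on.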
def mask_names(names: List[str]) -> Dict[str, str]:
    return {n: f"User_{i+1}" for i, n in enumerate(names)}
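
# Extract text from a screenshot with Tesseract OCR.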
def ocr_image_path(path: str) -> str:
    img = Image.open(path).convert("RGB")
    return pytesseract.image_to_string(img)
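
# Extract text from a video by running OCR on sampled frames.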
def ocr_video_path(path: str) -> str:
    cap = cv2.VideoCapture(path)
    texts = []
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Only OCR every 25th frame to keep processing time manageable.
        if idx % 25 == 0:
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(rgb)
            t = pytesseract.image_to_string(img)
            if t.strip():
                texts.append(t)
        idx += 1
    cap.release()
    return "\n".join(texts)
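
# Split the combined transcript into per-speaker text using "Name: message" lines;
# if no such lines are found, treat the whole text as a single "User".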
def split_by_speaker(text: str, privacy: bool) -> Dict[str, str]:
    speakers: Dict[str, List[str]] = {}
    for raw in text.splitlines():
        if ":" in raw:
            name, msg = raw.split(":", 1)
            name, msg = name.strip(), msg.strip()
            if msg:
                speakers.setdefault(name, []).append(msg)
    if not speakers:
        speakers["User"] = [text]
    if privacy:
        mapping = mask_names(list(speakers.keys()))
        return {mapping[k]: " ".join(v) for k, v in speakers.items()}
    return {k: " ".join(v) for k, v in speakers.items()}
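
# Greedily pack words into chunks of at most max_tokens Longformer tokens.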
def chunk_text(text: str, max_tokens: int = 2048) -> List[str]:
    words = text.split()
    chunks: List[str] = []
    temp: List[str] = []
    for w in words:
        temp.append(w)
        enc = tok(" ".join(temp), truncation=True, max_length=max_tokens)
        if len(enc["input_ids"]) >= max_tokens:
            temp.pop()
            chunks.append(" ".join(temp))
            temp = [w]
    if temp:
        chunks.append(" ".join(temp))
    return chunks
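
# Classify one chunk of text and return {emotion label: score}.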
def emotion_scores(text: str) -> Dict[str, float]:
    # Truncate to the emotion model's maximum input length so long chunks do not raise errors.
    res = emo_head(text, truncation=True)[0]
    return {x["label"]: float(x["score"]) for x in res}
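
# Average the emotion scores across all of a speaker's chunks.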
def emotions_over_chunks(chunks: List[str]) -> Dict[str, float]:
    if not chunks:
        return {}
    sums: Dict[str, float] = {}
    count = 0
    for c in chunks:
        e = emotion_scores(c)
        for k, v in e.items():
            sums[k] = sums.get(k, 0.0) + v
        count += 1
    return {k: v / count for k, v in sums.items()} if count else {}
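
# Risk heuristic: 0.7 * total negative emotion + 0.3 * strongest single negative emotion, clipped to [0, 1].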
def compute_risk(emotions: Dict[str, float]) -> float:
    neg = sum(emotions.get(k, 0.0) for k in negative_keys)
    strongest_neg = max((emotions.get(k, 0.0) for k in negative_keys), default=0.0)
    risk = 0.7 * neg + 0.3 * strongest_neg
    return max(0.0, min(1.0, risk))
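
# Main Gradio callback: gather text and OCR output, score each speaker, and build the two result figures.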
def analyze(text_input, image_paths, video_paths, privacy_choice):
    collected: List[str] = []
    if text_input and text_input.strip():
        collected.append(text_input)
    if image_paths:
        for p in image_paths:
            t = ocr_image_path(p)
            if t.strip():
                collected.append(t)
    if video_paths:
        for p in video_paths:
            t = ocr_video_path(p)
            if t.strip():
                collected.append(t)
    if not collected:
        return None, None
    combined = "\n".join(collected)
    speakers = split_by_speaker(combined, privacy_choice == "ON")
    # Chunk, score, and compute a risk value for each speaker.
    results: List[Dict[str, Any]] = []
    for name, txt in speakers.items():
        chunks = chunk_text(txt)
        emos = emotions_over_chunks(chunks)
        risk = compute_risk(emos)
        results.append(
            {
                "name": name,
                "risk": risk,
                "emotions": emos,
            }
        )
    # Figure 1: per-speaker risk bars (left) and group-average emotion profile (right).
    plt.style.use("default")
    fig, ax = plt.subplots(1, 2, figsize=(11, 4))
    fig.patch.set_facecolor("white")
    names = [x["name"] for x in results]
    scores = [x["risk"] for x in results]
    ax[0].bar(names, scores, color="#DC2626", alpha=0.8)
    ax[0].set_ylim(0, 1)
    ax[0].set_title("Risk Levels", fontweight="bold", fontsize=12, color="#1F2937")
    ax[0].set_ylabel("Risk Score", fontsize=10, color="#4B5563")
    ax[0].set_facecolor("white")
    ax[0].grid(axis="y", alpha=0.2, linestyle="--")
    ax[0].spines["top"].set_visible(False)
    ax[0].spines["right"].set_visible(False)
    group_emo: Dict[str, float] = {}
    for r in results:
        for k, v in r["emotions"].items():
            group_emo[k] = group_emo.get(k, 0.0) + v
    group_emo = {k: v / len(results) for k, v in group_emo.items()}
    colors = ["#10B981", "#3B82F6", "#8B5CF6", "#F59E0B", "#EC4899", "#06B6D4"]
    ax[1].bar(list(group_emo.keys()), list(group_emo.values()), color=colors[: len(group_emo)], alpha=0.8)
    ax[1].set_ylim(0, 1)
    ax[1].set_title("Group Emotion", fontweight="bold", fontsize=12, color="#1F2937")
    ax[1].set_ylabel("Intensity", fontsize=10, color="#4B5563")
    ax[1].set_facecolor("white")
    ax[1].grid(axis="y", alpha=0.2, linestyle="--")
    ax[1].spines["top"].set_visible(False)
    ax[1].spines["right"].set_visible(False)
    ax[1].tick_params(axis="x", rotation=45)
    plt.tight_layout()
    # Figure 2: one emotion-profile subplot per speaker, arranged in up to 3 columns.
    n = len(results)
    cols = min(3, n)
    rows = math.ceil(n / cols)
    fig2, ax2 = plt.subplots(rows, cols, figsize=(5 * cols, 3 * rows))
    fig2.patch.set_facecolor("white")
    axlist = [ax2] if n == 1 else ax2.flatten()
    emotion_colors = {
        "anger": "#EF4444",
        "sadness": "#3B82F6",
        "fear": "#8B5CF6",
        "disgust": "#F59E0B",
        "joy": "#10B981",
        "surprise": "#EC4899",
    }
    for i, r in enumerate(results):
        axp = axlist[i]
        emotions = list(r["emotions"].keys())
        values = list(r["emotions"].values())
        bar_colors = [emotion_colors.get(e, "#6B7280") for e in emotions]
        axp.bar(emotions, values, color=bar_colors, alpha=0.8)
        axp.set_ylim(0, 1)
        axp.set_title(r["name"], fontweight="bold", fontsize=11, color="#1F2937")
        axp.set_ylabel("Intensity", fontsize=9, color="#4B5563")
        axp.set_facecolor("white")
        axp.grid(axis="y", alpha=0.2, linestyle="--")
        axp.spines["top"].set_visible(False)
        axp.spines["right"].set_visible(False)
        axp.tick_params(axis="x", rotation=45, labelsize=9)
    # Hide any unused subplot cells in the grid.
    for j in range(len(axlist) - n):
        axlist[n + j].axis("off")
    fig2.tight_layout()
    return fig, fig2
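
# Custom CSS: white background, Inter font, styled buttons, tabs, and upload areas; hides the default Gradio footer.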
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');
* {
font-family: 'Inter', sans-serif !important;
}
body {
background: white !important;
}
.gradio-container {
max-width: 1400px !important;
margin: 0 auto !important;
background: white !important;
}
.main {
background: white !important;
}
.contain {
background: white !important;
}
.gr-button-primary {
background: linear-gradient(135deg, #667EEA 0%, #764BA2 100%) !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
color: white !important;
border: none !important;
}
.gr-button-primary:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.5) !important;
}
.gr-box, .gr-form, .gr-panel {
background: white !important;
border: 1px solid #E5E7EB !important;
border-radius: 12px !important;
box-shadow: 0 1px 3px rgba(0,0,0,0.05) !important;
}
.gr-input, .gr-textarea {
background: white !important;
border: 1px solid #D1D5DB !important;
border-radius: 8px !important;
color: #1F2937 !important;
}
.gr-input:focus, .gr-textarea:focus {
border-color: #667EEA !important;
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
}
.gr-input::placeholder, .gr-textarea::placeholder {
color: #9CA3AF !important;
}
label {
color: #374151 !important;
font-weight: 600 !important;
font-size: 14px !important;
}
.tabs {
background: white !important;
border: 1px solid #E5E7EB !important;
border-radius: 12px !important;
}
.tab-nav {
background: #F9FAFB !important;
border-bottom: 1px solid #E5E7EB !important;
padding: 8px !important;
}
.tab-nav button {
color: #6B7280 !important;
font-weight: 600 !important;
background: transparent !important;
border-radius: 8px !important;
padding: 10px 20px !important;
}
.tab-nav button.selected {
background: white !important;
color: #667EEA !important;
border-bottom: 2px solid #667EEA !important;
box-shadow: 0 1px 3px rgba(0,0,0,0.05) !important;
}
.gr-accordion {
background: white !important;
border: 1px solid #E5E7EB !important;
border-radius: 10px !important;
}
.gr-file {
background: white !important;
border: 2px dashed #D1D5DB !important;
border-radius: 10px !important;
}
.gr-file:hover {
border-color: #667EEA !important;
}
.gr-radio {
background: white !important;
}
.gr-radio label {
background: white !important;
border: 1px solid #D1D5DB !important;
border-radius: 8px !important;
padding: 10px 16px !important;
color: #4B5563 !important;
}
.gr-radio label.selected {
background: #EEF2FF !important;
border-color: #667EEA !important;
color: #667EEA !important;
}
.gr-plot {
background: white !important;
border-radius: 12px !important;
padding: 16px !important;
}
footer {
display: none !important;
}
"""

# Build the Gradio interface: header HTML, input controls, and two tabs of output plots.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Mental Health Chat Analyzer") as demo:
    gr.HTML(
        """
<div style='text-align: center; padding: 80px 20px 60px; background: white; border-bottom: 2px solid #E5E7EB; margin-bottom: 40px;'>
<h1 style='color: #1F2937; font-size: 48px; font-weight: 800; margin: 0 0 20px 0;'>
Mental Health Chat Analyzer
</h1>
<p style='color: #6B7280; font-size: 20px; max-width: 700px; margin: 0 auto 30px; line-height: 1.6;'>
AI-powered emotional intelligence that analyzes conversations to provide emotion and risk insights
</p>
</div>
<div style='max-width: 1000px; margin: 0 auto 60px; padding: 0 20px; background: white;'>
<div style='display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 30px;'>
<div style='text-align: center; padding: 30px; background: white; border-radius: 12px; box-shadow: 0 2px 8px rgba(0,0,0,0.06); border: 1px solid #E5E7EB;'>
<div style='font-size: 40px; margin-bottom: 15px;'>🧠</div>
<h3 style='color: #1F2937; margin-bottom: 10px; font-weight: 700; font-size: 18px;'>AI Analysis</h3>
<p style='color: #6B7280; font-size: 15px;'>Emotion and risk detection</p>
</div>
<div style='text-align: center; padding: 30px; background: white; border-radius: 12px; box-shadow: 0 2px 8px rgba(0,0,0,0.06); border: 1px solid #E5E7EB;'>
<div style='font-size: 40px; margin-bottom: 15px;'>📊</div>
<h3 style='color: #1F2937; margin-bottom: 10px; font-weight: 700; font-size: 18px;'>Visual Insights</h3>
<p style='color: #6B7280; font-size: 15px;'>Risk and emotion charts</p>
</div>
<div style='text-align: center; padding: 30px; background: white; border-radius: 12px; box-shadow: 0 2px 8px rgba(0,0,0,0.06); border: 1px solid #E5E7EB;'>
<div style='font-size: 40px; margin-bottom: 15px;'>🔒</div>
<h3 style='color: #1F2937; margin-bottom: 10px; font-weight: 700; font-size: 18px;'>Privacy First</h3>
<p style='color: #6B7280; font-size: 15px;'>Your data stays on-device</p>
</div>
</div>
</div>
"""
)
    gr.HTML(
        "<h2 style='text-align: center; font-size: 32px; margin-bottom: 30px; color: #1F2937; background: white; font-weight: 700;'>Start Your Analysis</h2>"
    )
    with gr.Row():
        with gr.Column():
            privacy = gr.Radio(
                choices=["OFF", "ON"],
                value="OFF",
                label="Privacy Masking",
                info="Enable to anonymize participant names",
            )
            text_in = gr.Textbox(
                label="Conversation Text",
                placeholder="Format: Name: message\n\nJohn: I'm stressed about work\nMary: Let's talk about it",
                lines=10,
            )
            with gr.Accordion("Upload Files (Optional)", open=False):
                img_in = gr.File(
                    label="Screenshots",
                    file_types=["image"],
                    file_count="multiple",
                    type="filepath",
                )
                vid_in = gr.File(
                    label="Videos",
                    file_count="multiple",
                    type="filepath",
                )
            analyze_btn = gr.Button("Analyze Conversation", variant="primary", size="lg")
    with gr.Tabs():
        with gr.Tab("Risk Assessment"):
            plot1 = gr.Plot()
        with gr.Tab("Individual Profiles"):
            plot2 = gr.Plot()
    analyze_btn.click(
        analyze,
        inputs=[text_in, img_in, vid_in, privacy],
        outputs=[plot1, plot2],
    )
if __name__ == "__main__":
    demo.launch()