import os
import streamlit as st
import numpy as np
import torch
import whisper
from transformers import pipeline, AutoModelForAudioClassification, AutoFeatureExtractor
from deepface import DeepFace
import logging
import soundfile as sf
import tempfile
import cv2
from moviepy.editor import VideoFileClip
import time
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

# Create a cross-platform, writable cache directory for all libraries
CACHE_DIR = os.path.join(tempfile.gettempdir(), "affectlink_cache")
DEEPFACE_CACHE_PATH = os.path.join(CACHE_DIR, ".deepface", "weights")
os.makedirs(DEEPFACE_CACHE_PATH, exist_ok=True) # Proactively create the full path

os.environ['DEEPFACE_HOME'] = CACHE_DIR
os.environ['HF_HOME'] = CACHE_DIR

# --- Page Configuration ---
st.set_page_config(page_title="AffectLink Demo", page_icon="😊", layout="wide")
st.title("AffectLink: Post-Hoc Emotion Analysis")
st.write("Upload a short video clip (under 30 seconds) to see a multimodal emotion analysis.")

# --- Logger Configuration ---
logging.basicConfig(level=logging.INFO)

# --- Emotion Mappings ---
UNIFIED_EMOTIONS = ['angry', 'happy', 'sad', 'neutral']
TEXT_TO_UNIFIED = {'neutral': 'neutral', 'joy': 'happy', 'sadness': 'sad', 'anger': 'angry'}
SER_TO_UNIFIED = {'neu': 'neutral', 'hap': 'happy', 'sad': 'sad', 'ang': 'angry'}
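# Facial classes with no counterpart in the unified set map to None and are dropped by the helpers below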
FACIAL_TO_UNIFIED = {'neutral': 'neutral', 'happy': 'happy', 'sad': 'sad', 'angry': 'angry', 'fear':None, 'surprise':None, 'disgust':None}
AUDIO_SAMPLE_RATE = 16000

# --- Model Loading (Staged) ---
@st.cache_resource
def load_audio_models():
    with st.spinner("Loading audio analysis models..."):
        whisper_model = whisper.load_model("tiny.en", download_root=os.path.join(CACHE_DIR, "whisper"))
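        # top_k=None makes the pipeline return scores for every emotion label, not just the top prediction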
        text_classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=None)
        ser_model_name = "superb/hubert-large-superb-er"
        ser_feature_extractor = AutoFeatureExtractor.from_pretrained(ser_model_name)
        ser_model = AutoModelForAudioClassification.from_pretrained(ser_model_name)
        return whisper_model, text_classifier, ser_model, ser_feature_extractor

# Models will be loaded on demand

# --- Helper Functions for Analysis ---
def create_unified_vector(scores_dict, mapping_dict):
    vector = np.zeros(len(UNIFIED_EMOTIONS))
    total_score = 0
    # Use .items() to iterate over keys and values
    for label, score in scores_dict.items():
        unified_label = mapping_dict.get(label)
        if unified_label in UNIFIED_EMOTIONS:
            vector[UNIFIED_EMOTIONS.index(unified_label)] += score
            total_score += score
    if total_score > 0:
        vector /= total_score
    return vector

def get_consistency_level(cosine_sim):
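    # Coarse, heuristic bands used to report consistency in the UI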
    if np.isnan(cosine_sim): return "N/A"
    if cosine_sim >= 0.8: return "High"
    if cosine_sim >= 0.6: return "Medium"
    if cosine_sim >= 0.3: return "Low"
    return "Very Low"

# --- Helper Functions for Results Display ---
def process_timeline_to_df(timeline, mapping):
    if not timeline: return pd.DataFrame(columns=UNIFIED_EMOTIONS)
    df = pd.DataFrame.from_dict(timeline, orient='index')
    df_unified = pd.DataFrame(0.0, index=df.index, columns=UNIFIED_EMOTIONS)  # float-typed zeros avoid object-dtype NaN columns
    for raw_col in df.columns:
        unified_col = mapping.get(raw_col)
        if unified_col:
            df_unified[unified_col] += df[raw_col]
    return df_unified

def get_dominant_emotion_from_df(df):
    if df.empty or df.sum().sum() == 0: return "N/A"
    return df.sum().idxmax().capitalize()

def get_avg_unified_scores(df):
    return df.mean().to_dict() if not df.empty else {}

def display_results():
    """Display the final analysis results using data from session state"""
    st.header("Analysis Results")
    
    # Get data from session state
    full_transcription = st.session_state.get('full_transcription', 'No speech detected.')
    ser_timeline = st.session_state.get('ser_timeline', {})
    ter_timeline = st.session_state.get('ter_timeline', {})
    fer_timeline = st.session_state.get('fer_timeline', {})
    duration = st.session_state.get('duration', 0)
    
    # Process timelines
    fer_df = process_timeline_to_df(fer_timeline, FACIAL_TO_UNIFIED)
    ser_df = process_timeline_to_df(ser_timeline, SER_TO_UNIFIED)
    ter_df = process_timeline_to_df(ter_timeline, TEXT_TO_UNIFIED)
    
    # Get dominant emotions
    dominant_fer = get_dominant_emotion_from_df(fer_df)
    dominant_ser = get_dominant_emotion_from_df(ser_df)
    dominant_text = get_dominant_emotion_from_df(ter_df)
    
    # Get average scores
    fer_avg_scores = get_avg_unified_scores(fer_df)
    ser_avg_scores = get_avg_unified_scores(ser_df)
    ter_avg_scores = get_avg_unified_scores(ter_df)
    
    # Calculate vectors and similarity
    fer_vector = create_unified_vector(fer_avg_scores, {e:e for e in UNIFIED_EMOTIONS})
    ser_vector = create_unified_vector(ser_avg_scores, {e:e for e in UNIFIED_EMOTIONS})
    text_vector = create_unified_vector(ter_avg_scores, {e:e for e in UNIFIED_EMOTIONS})
    
    similarities = [
        cosine_similarity([fer_vector], [text_vector])[0][0],
        cosine_similarity([fer_vector], [ser_vector])[0][0],
        cosine_similarity([ser_vector], [text_vector])[0][0],
    ]
    avg_similarity = np.nanmean([s for s in similarities if not np.isnan(s)])
    
    # Display transcription
    st.subheader("Transcription")
    st.markdown(f"> *{full_transcription}*")
    st.divider()
    
    # Display summary and timeline
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Multimodal Summary")
        st.metric("Dominant Facial Emotion", dominant_fer)
        st.metric("Dominant Text Emotion", dominant_text)
        st.metric("Dominant Speech Emotion", dominant_ser)
        st.metric("Emotion Consistency", get_consistency_level(avg_similarity), f"{avg_similarity:.2f} Avg. Cosine Similarity")
    
    with col2:
        st.subheader("Unified Emotion Timeline")
        
        if duration > 0:
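            # Build a common 0.5-second time grid onto which every modality is resampled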
            full_index = np.arange(0, duration, 0.5)
            combined_df = pd.DataFrame(index=full_index)
            
            # ECI Timeline Calculation
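            # ECI: mean pairwise cosine similarity across whichever modality vectors are available at each timestamp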
            eci_timeline = {}
            for t_stamp in full_index:
                vectors = []
                
                # Interpolate to get a value for any timestamp
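                # (reindex to include t_stamp, interpolate linearly, then read the row at t_stamp)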
                fer_scores = fer_df.reindex(fer_df.index.union([t_stamp])).interpolate(method='linear').loc[t_stamp]
                if not fer_scores.isnull().all():
                    vectors.append(create_unified_vector(fer_scores.to_dict(), {e:e for e in UNIFIED_EMOTIONS}))

                if int(t_stamp) in ser_df.index:
                    vectors.append(create_unified_vector(ser_df.loc[int(t_stamp)].to_dict(), {e:e for e in UNIFIED_EMOTIONS}))
                
                if int(t_stamp) in ter_df.index:
                    vectors.append(create_unified_vector(ter_df.loc[int(t_stamp)].to_dict(), {e:e for e in UNIFIED_EMOTIONS}))
                
                if len(vectors) >= 2:
                    sims = [cosine_similarity([v1], [v2])[0][0] for i, v1 in enumerate(vectors) for v2 in vectors[i+1:]]
                    eci_timeline[t_stamp] = np.mean(sims)

            if not fer_df.empty:
                fer_df_resampled = fer_df.reindex(fer_df.index.union(full_index)).interpolate(method='linear').reindex(full_index)
                for e in UNIFIED_EMOTIONS: combined_df[f'Facial_{e}'] = fer_df_resampled.get(e, 0.0)
            
            if not ser_df.empty:
                ser_df_resampled = ser_df.reindex(ser_df.index.union(full_index)).interpolate(method='linear').reindex(full_index)
                for e in UNIFIED_EMOTIONS: combined_df[f'Speech_{e}'] = ser_df_resampled.get(e, 0.0)

            if not ter_df.empty:
                ter_df_resampled = ter_df.reindex(ter_df.index.union(full_index)).interpolate(method='linear').reindex(full_index)
                for e in UNIFIED_EMOTIONS: combined_df[f'Text_{e}'] = ter_df_resampled.get(e, 0.0)
            
            if eci_timeline:
                eci_series = pd.Series(eci_timeline).reindex(full_index).interpolate(method='linear')
                combined_df['ECI'] = eci_series

            combined_df.fillna(0, inplace=True)
            
            if not combined_df.empty:
                fig, ax = plt.subplots(figsize=(10, 5))
                colors = {'happy': 'green', 'sad': 'blue', 'angry': 'red', 'neutral': 'gray'}
                styles = {'Facial': '-', 'Speech': '--', 'Text': ':'}

                for col in combined_df.columns:
                    if col == 'ECI': continue
                    modality, emotion = col.split('_')
                    if emotion in colors:
                        ax.plot(combined_df.index, combined_df[col], label=f'{modality} {emotion.capitalize()}', color=colors[emotion], linestyle=styles[modality], alpha=0.7)
                
                if 'ECI' in combined_df.columns:
                    ax.plot(combined_df.index, combined_df['ECI'], label='Emotion Consistency', color='black', linewidth=2.5, alpha=0.9)

                ax.set_title("Emotion Confidence Over Time (Normalized)")
                ax.set_xlabel("Time (seconds)")
                ax.set_ylabel("Confidence Score (0-1)")
                ax.set_ylim(0, 1)
                ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
                ax.grid(True, which='both', linestyle='--', linewidth=0.5)
                plt.tight_layout()
                st.pyplot(fig)
            else:
                st.write("No emotion data available to plot.")
        else:
            st.write("No timeline data available.")

# --- Two-Stage UI and Processing Logic ---
uploaded_file = st.file_uploader("Choose a video file...", type=["mp4", "mov", "avi", "mkv"])

# Initialize session state variables
if 'temp_video_path' not in st.session_state:
    st.session_state.temp_video_path = None
if 'uploaded_file_id' not in st.session_state:
    st.session_state.uploaded_file_id = None

# Clear previous results when a new file is uploaded
if uploaded_file is not None:
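    # Prefer Streamlit's file_id when available; otherwise fall back to a hash of the file name and size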
    file_id = uploaded_file.file_id if hasattr(uploaded_file, 'file_id') else str(hash(uploaded_file.name + str(uploaded_file.size)))
    
    if st.session_state.uploaded_file_id != file_id:
        # New file uploaded, clear previous results
        st.session_state.uploaded_file_id = file_id
        for key in ['stage1_complete', 'stage2_complete', 'full_transcription', 'ser_timeline', 'ter_timeline', 'fer_timeline', 'duration']:
            if key in st.session_state:
                del st.session_state[key]
        
        # Save the video file
        if st.session_state.temp_video_path and os.path.exists(st.session_state.temp_video_path):
            try:
                os.unlink(st.session_state.temp_video_path)
            except Exception:
                pass
        
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
            tfile.write(uploaded_file.read())
            st.session_state.temp_video_path = tfile.name

if uploaded_file is not None and st.session_state.temp_video_path:
    st.video(st.session_state.temp_video_path)
    
    # Stage 1: Audio & Text Analysis
    if not st.session_state.get('stage1_complete', False):
        if st.button("🎡 Step 1: Analyze Audio & Text", type="primary"):
            try:
                # Load audio models
                whisper_model, text_classifier, ser_model, ser_feature_extractor = load_audio_models()
                
                ser_timeline, ter_timeline = {}, {}
                full_transcription = "No speech detected."
                
                video_clip = VideoFileClip(st.session_state.temp_video_path)
                duration = video_clip.duration
                st.session_state.duration = duration
                
                with st.spinner("Analyzing audio and text..."):
                    if video_clip.audio:
                        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as taudio:
                            video_clip.audio.write_audiofile(taudio.name, fps=AUDIO_SAMPLE_RATE, logger=None)
                            temp_audio_path = taudio.name

                        # Transcription
                        whisper_result = whisper_model.transcribe(
                            temp_audio_path, 
                            word_timestamps=True, 
                            fp16=False,
                            condition_on_previous_text=False
                        )
                        full_transcription = whisper_result['text'].strip()
                        
                        # Speech emotion recognition
                        audio_array, _ = sf.read(temp_audio_path, dtype='float32')
                        if audio_array.ndim == 2: 
                            audio_array = audio_array.mean(axis=1)

                        for i in range(int(duration)):
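                            # Classify speech emotion over non-overlapping 1-second windows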
                            start_sample, end_sample = i * AUDIO_SAMPLE_RATE, (i + 1) * AUDIO_SAMPLE_RATE
                            chunk = audio_array[start_sample:end_sample]
                            
                            if len(chunk) > 400:  # skip trailing fragments shorter than 400 samples (~25 ms at 16 kHz)
                                inputs = ser_feature_extractor(chunk, sampling_rate=AUDIO_SAMPLE_RATE, return_tensors="pt", padding=True)
                                with torch.no_grad():
                                    logits = ser_model(**inputs).logits
                                scores = torch.nn.functional.softmax(logits, dim=1).squeeze()
                                ser_timeline[i] = {ser_model.config.id2label[k]: score.item() for k, score in enumerate(scores)}

                            # Text emotion recognition on the words whose start time falls in this 1-second window
                            words_in_segment = [w['word'] for seg in whisper_result.get('segments', [])
                                                for w in seg.get('words', []) if i <= w['start'] < i + 1]
                            segment_text = " ".join(words_in_segment).strip()
                            if segment_text:
                                text_emotions = text_classifier(segment_text)[0]
                                ter_timeline[i] = {emo['label']: emo['score'] for emo in text_emotions}
                        
                        # Clean up audio file
                        if os.path.exists(temp_audio_path):
                            os.unlink(temp_audio_path)
                
                video_clip.close()
                
                # Store results in session state
                st.session_state.full_transcription = full_transcription
                st.session_state.ser_timeline = ser_timeline
                st.session_state.ter_timeline = ter_timeline
                st.session_state.stage1_complete = True
                
                st.success("βœ… Audio analysis complete! Speech and text emotions have been analyzed.")
                st.rerun()
                
            except Exception as e:
                st.error(f"Error during audio analysis: {str(e)}")
    
    else:
        st.success("βœ… Stage 1 (Audio & Text Analysis) - Complete!")
    
    # Stage 2: Facial Analysis
    if st.session_state.get('stage1_complete', False) and not st.session_state.get('stage2_complete', False):
        if st.button("😊 Step 2: Analyze Facial Expressions", type="primary"):
            try:
                fer_timeline = {}
                
                with st.spinner("Analyzing facial expressions..."):
                    cap = cv2.VideoCapture(st.session_state.temp_video_path)
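                    # Fall back to 30 fps if OpenCV cannot read the frame rate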
                    fps = cap.get(cv2.CAP_PROP_FPS) or 30
                    frame_count = 0
                    
                    while cap.isOpened():
                        ret, frame = cap.read()
                        if not ret: 
                            break
                        timestamp = frame_count / fps
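                        # Sample roughly one frame per second for facial analysis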
                        if frame_count % int(fps) == 0:
                            analysis = DeepFace.analyze(frame, actions=['emotion'], enforce_detection=False, silent=True)
                            if isinstance(analysis, list) and len(analysis) > 0:
                                fer_timeline[timestamp] = {k: v / 100.0 for k, v in analysis[0]['emotion'].items()}
                        frame_count += 1
                    cap.release()
                
                # Store results in session state
                st.session_state.fer_timeline = fer_timeline
                st.session_state.stage2_complete = True
                
                st.success("βœ… Facial analysis complete! All analyses are now finished.")
                st.rerun()
                
            except Exception as e:
                st.error(f"Error during facial analysis: {str(e)}")
    
    elif st.session_state.get('stage2_complete', False):
        st.success("βœ… Stage 2 (Facial Expression Analysis) - Complete!")
    
    # Display results if both stages are complete
    if st.session_state.get('stage1_complete', False) and st.session_state.get('stage2_complete', False):
        display_results()

# Cleanup on app restart or when session ends
if st.session_state.temp_video_path and not uploaded_file:
    try:
        if os.path.exists(st.session_state.temp_video_path):
            os.unlink(st.session_state.temp_video_path)
        st.session_state.temp_video_path = None
    except Exception:
        pass