Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -195,12 +195,13 @@ def get_embedding_for_text(text, tokenizer, model):
|
|
| 195 |
continue
|
| 196 |
|
| 197 |
if chunk_embeddings:
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
| 204 |
def format_topics(topic_model, topic_counts):
|
| 205 |
"""Format topics for display."""
|
| 206 |
formatted_topics = []
|
|
@@ -252,41 +253,40 @@ def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, to
|
|
| 252 |
topic_model_params["nr_topics"] = "auto"
|
| 253 |
|
| 254 |
topic_model = BERTopic(
|
| 255 |
-
embedding_model=
|
| 256 |
**topic_model_params
|
| 257 |
)
|
| 258 |
|
| 259 |
-
vectorizer = CountVectorizer(
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
| 262 |
topic_model.vectorizer_model = vectorizer
|
| 263 |
|
| 264 |
-
# Create a placeholder for the progress bar
|
| 265 |
progress_placeholder = st.empty()
|
| 266 |
progress_bar = progress_placeholder.progress(0)
|
| 267 |
-
|
| 268 |
-
# Create status message placeholder
|
| 269 |
status_message = st.empty()
|
| 270 |
|
| 271 |
for country, group in df.groupby('country'):
|
| 272 |
-
# Clear memory at the start of each country's processing
|
| 273 |
gc.collect()
|
| 274 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 275 |
|
| 276 |
status_message.text(f"Processing poems for {country}...")
|
| 277 |
texts = [clean_arabic_text(poem) for poem in group['poem'].dropna()]
|
| 278 |
all_emotions = []
|
|
|
|
| 279 |
|
| 280 |
-
# Use cached embeddings with progress tracking
|
| 281 |
-
embeddings = []
|
| 282 |
total_texts = len(texts)
|
| 283 |
for i, text in enumerate(texts):
|
| 284 |
try:
|
| 285 |
embedding = cache_embeddings(text, bert_tokenizer, bert_model)
|
| 286 |
if embedding is not None and not np.isnan(embedding).any():
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
-
# Update progress more frequently
|
| 290 |
if i % max(1, total_texts // 100) == 0:
|
| 291 |
progress = (i + 1) / total_texts * 0.4
|
| 292 |
progress_bar.progress(progress)
|
|
@@ -296,7 +296,7 @@ def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, to
|
|
| 296 |
st.warning(f"Error processing poem {i+1} in {country}: {str(e)}")
|
| 297 |
continue
|
| 298 |
|
| 299 |
-
# Process emotions
|
| 300 |
for i, text in enumerate(texts):
|
| 301 |
try:
|
| 302 |
emotion = cache_emotion_classification(text, emotion_classifier)
|
|
@@ -316,30 +316,32 @@ def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, to
|
|
| 316 |
st.warning(f"Not enough documents for {country} to generate meaningful topics (minimum {min_topic_size} required)")
|
| 317 |
continue
|
| 318 |
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
except Exception as e:
|
| 335 |
st.warning(f"Could not generate topics for {country}: {str(e)}")
|
| 336 |
continue
|
| 337 |
|
| 338 |
-
# Clear progress for next country
|
| 339 |
progress_placeholder.empty()
|
| 340 |
status_message.empty()
|
| 341 |
-
|
| 342 |
-
# Create new progress bar for next country
|
| 343 |
progress_placeholder = st.empty()
|
| 344 |
progress_bar = progress_placeholder.progress(0)
|
| 345 |
status_message = st.empty()
|
|
|
|
| 195 |
continue
|
| 196 |
|
| 197 |
if chunk_embeddings:
|
| 198 |
+
# Convert to numpy array and ensure 2D shape
|
| 199 |
+
chunk_embeddings = np.array(chunk_embeddings)
|
| 200 |
+
if len(chunk_embeddings.shape) == 1:
|
| 201 |
+
chunk_embeddings = chunk_embeddings.reshape(1, -1)
|
| 202 |
+
return chunk_embeddings
|
| 203 |
+
return np.zeros((1, model.config.hidden_size))
|
| 204 |
+
|
| 205 |
def format_topics(topic_model, topic_counts):
|
| 206 |
"""Format topics for display."""
|
| 207 |
formatted_topics = []
|
|
|
|
| 253 |
topic_model_params["nr_topics"] = "auto"
|
| 254 |
|
| 255 |
topic_model = BERTopic(
|
| 256 |
+
embedding_model=None, # Set to None since we're providing embeddings
|
| 257 |
**topic_model_params
|
| 258 |
)
|
| 259 |
|
| 260 |
+
vectorizer = CountVectorizer(
|
| 261 |
+
stop_words=list(ARABIC_STOP_WORDS),
|
| 262 |
+
min_df=1,
|
| 263 |
+
max_df=1.0
|
| 264 |
+
)
|
| 265 |
topic_model.vectorizer_model = vectorizer
|
| 266 |
|
|
|
|
| 267 |
progress_placeholder = st.empty()
|
| 268 |
progress_bar = progress_placeholder.progress(0)
|
|
|
|
|
|
|
| 269 |
status_message = st.empty()
|
| 270 |
|
| 271 |
for country, group in df.groupby('country'):
|
|
|
|
| 272 |
gc.collect()
|
| 273 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 274 |
|
| 275 |
status_message.text(f"Processing poems for {country}...")
|
| 276 |
texts = [clean_arabic_text(poem) for poem in group['poem'].dropna()]
|
| 277 |
all_emotions = []
|
| 278 |
+
embeddings_list = []
|
| 279 |
|
|
|
|
|
|
|
| 280 |
total_texts = len(texts)
|
| 281 |
for i, text in enumerate(texts):
|
| 282 |
try:
|
| 283 |
embedding = cache_embeddings(text, bert_tokenizer, bert_model)
|
| 284 |
if embedding is not None and not np.isnan(embedding).any():
|
| 285 |
+
# Ensure embedding is 2D
|
| 286 |
+
if len(embedding.shape) == 1:
|
| 287 |
+
embedding = embedding.reshape(1, -1)
|
| 288 |
+
embeddings_list.append(embedding)
|
| 289 |
|
|
|
|
| 290 |
if i % max(1, total_texts // 100) == 0:
|
| 291 |
progress = (i + 1) / total_texts * 0.4
|
| 292 |
progress_bar.progress(progress)
|
|
|
|
| 296 |
st.warning(f"Error processing poem {i+1} in {country}: {str(e)}")
|
| 297 |
continue
|
| 298 |
|
| 299 |
+
# Process emotions
|
| 300 |
for i, text in enumerate(texts):
|
| 301 |
try:
|
| 302 |
emotion = cache_emotion_classification(text, emotion_classifier)
|
|
|
|
| 316 |
st.warning(f"Not enough documents for {country} to generate meaningful topics (minimum {min_topic_size} required)")
|
| 317 |
continue
|
| 318 |
|
| 319 |
+
if embeddings_list:
|
| 320 |
+
# Stack all embeddings into a single 2D array
|
| 321 |
+
embeddings = np.vstack(embeddings_list)
|
| 322 |
+
|
| 323 |
+
topics, probs = topic_model.fit_transform(texts, embeddings)
|
| 324 |
+
topic_counts = Counter(topics)
|
| 325 |
+
|
| 326 |
+
top_topics = format_topics(topic_model, topic_counts.most_common(top_n))
|
| 327 |
+
top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
|
| 328 |
+
|
| 329 |
+
summaries.append({
|
| 330 |
+
'country': country,
|
| 331 |
+
'total_poems': len(texts),
|
| 332 |
+
'top_topics': top_topics,
|
| 333 |
+
'top_emotions': top_emotions
|
| 334 |
+
})
|
| 335 |
+
progress_bar.progress(1.0, text="Processing complete!")
|
| 336 |
+
else:
|
| 337 |
+
st.warning(f"No valid embeddings generated for {country}")
|
| 338 |
+
|
| 339 |
except Exception as e:
|
| 340 |
st.warning(f"Could not generate topics for {country}: {str(e)}")
|
| 341 |
continue
|
| 342 |
|
|
|
|
| 343 |
progress_placeholder.empty()
|
| 344 |
status_message.empty()
|
|
|
|
|
|
|
| 345 |
progress_placeholder = st.empty()
|
| 346 |
progress_bar = progress_placeholder.progress(0)
|
| 347 |
status_message = st.empty()
|