Spaces:
Sleeping
Sleeping
import csv
import json
import os
import re
import tempfile
import time

import pandas as pd
import torch
from google.colab import drive, files
from pydub import AudioSegment
from tqdm import tqdm
from transformers import AutoModel, pipeline
# =============================================================================
# CONFIGURATION
# =============================================================================
# Mount Google Drive (Colab only) so the dataset paths below resolve.
drive.mount("/content/drive")

# Inputs live on Google Drive; outputs are written to the working directory.
AUDIO_FOLDER_PATH = "/content/drive/MyDrive/SuperAI/Season 5/Level 2/Hack 4/speechs/test/"
METADATA_CSV_PATH = "/content/drive/MyDrive/SuperAI/Season 5/Level 2/Hack 4/test.csv"
OUTPUT_JSON_PATH = "hackathon_results.json"
OUTPUT_CSV_PATH = "output.csv"

# Evaluation criteria: Thai key (used as the output column/field name) mapped
# to the description that is inserted into the LLM prompt. The insertion order
# of this dict defines the criterion order everywhere else.
CRITERIA_DESCRIPTIONS = {
    "กล่าวสวัสดี": "RM กล่าวสวัสดี (RM greets the client). Example: 'สวัสดีครับ/ค่ะ'",
    "แนะนำชื่อและนามสกุล": "RM แนะนำชื่อและนามสกุล (RM introduces their full name - first name and surname). Example: 'ผมชื่อ... นามสกุล...' Not just first name or nickname.",
    "บอกประเภทใบอนุญาตและเลขที่ใบอนุญาตที่ยังไม่หมดอายุ": "RM บอกประเภทใบอนุญาต เลขที่ใบอนุญาต และแจ้งว่าใบอนุญาตยังไม่หมดอายุ (RM states their license type, license number, and confirms it's not expired). Example: 'มีใบอนุญาต... เลขที่... ใบอนุญาตยังไม่หมดอายุ'",
    "บอกวัตถุประสงค์ของการเข้าพบครั้งนี้": "RM บอกวัตถุประสงค์ของการเข้าพบครั้งนี้ (RM states the purpose of this meeting). Examples: 'เข้าพบ เพื่ออัพเดทพอร์ตการลงทุน', 'เข้าพบ เพื่ออัพเดทสภาวะตลาด'",
    "เน้นประโยชน์ว่าลูกค้าได้ประโยชน์อะไรจากการเข้าพบครั้งนี้": "RM เน้นประโยชน์ว่าลูกค้าได้ประโยชน์อะไรจากการเข้าพบครั้งนี้ (RM explains the benefits for the client from this meeting). Examples: 'ปรับสัดส่วนการลงทุน...', 'ปรับเปลี่ยนการลงทุนตามสภาวะตลาด...'",
    "บอกระยะเวลาที่ใช้ในการเข้าพบ": "RM บอกระยะเวลาที่ใช้ในการเข้าพบ (RM states the estimated duration of the meeting). Examples: 'ขอเวลา 1 ชั่วโมง สะดวกมั้ย', 'ขอเวลา 30 นาที'",
}

# Criterion keys in prompt/output order. They are derived from the dict above,
# so the key list and the descriptions can never drift apart.
JSON_CRITERIA_KEYS = list(CRITERIA_DESCRIPTIONS)
# =============================================================================
# MODEL LOADING
# =============================================================================
def load_models():
    """Load the speech-recognition pipeline and the Typhoon analysis model.

    Both models share one device/dtype choice:

    * CUDA with native bf16 support -> bfloat16
    * CUDA without it (e.g. older cards) -> float16
    * CPU -> float32

    Returns:
        tuple: ``(asr_pipe, typhoon_model)``.

        - ``asr_pipe`` is a transformers ASR pipeline for
          ``nectec/Pathumma-whisper-th-large-v3``, forced to Thai transcription.
        - ``typhoon_model`` is ``scb10x/llama3.1-typhoon2-audio-8b-instruct``
          (loaded with ``trust_remote_code=True``) placed on the same device.
    """
    print("Loading models...")
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    if use_cuda:
        # bfloat16 only where the GPU supports it natively; otherwise float16.
        torch_dtype = (
            torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        )
    else:
        torch_dtype = torch.float32

    # Pathumma (Thai fine-tuned Whisper large-v3) speech recognizer.
    asr_pipe = pipeline(
        task="automatic-speech-recognition",
        model="nectec/Pathumma-whisper-th-large-v3",
        torch_dtype=torch_dtype,
        device=device,
    )
    # Force the decoder to Thai transcription instead of language detection
    # or translation.
    asr_pipe.model.config.forced_decoder_ids = (
        asr_pipe.tokenizer.get_decoder_prompt_ids(language="th", task="transcribe")
    )

    # Without an explicit dtype and .to(device), this 8B-parameter model would
    # stay on the CPU in float32, regardless of the GPU chosen above.
    typhoon_model = AutoModel.from_pretrained(
        "scb10x/llama3.1-typhoon2-audio-8b-instruct",
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).to(device)
    return asr_pipe, typhoon_model
# =============================================================================
# AUDIO PROCESSING
# =============================================================================
def transcribe_audio(audio_path, asr_pipe, chunk_length_ms=27000):
    """Transcribe an audio file with the ASR pipeline, one fixed-length chunk at a time.

    The audio is decoded with pydub and cut into ``chunk_length_ms`` pieces.
    The 27 s default is presumably chosen to stay under Whisper's 30 s input
    window. Each piece is exported as a WAV file into a private temporary
    directory, which is always removed. Each WAV is passed to ``asr_pipe``,
    and the chunk texts are joined with single spaces.

    Args:
        audio_path: Path to an audio file pydub can decode.
        asr_pipe: Callable such that ``asr_pipe(wav_path)["text"]`` is the text.
        chunk_length_ms: Chunk length in milliseconds.

    Returns:
        The joined transcription, with ``"[ERROR]"`` standing in for any chunk
        that failed. Returns ``""`` for zero-length audio. Returns ``None`` if
        the file could not be processed or if every chunk failed.
    """
    try:
        audio = AudioSegment.from_file(audio_path)
        pieces = []
        any_success = False
        # A private temp dir avoids name collisions in the working directory
        # and guarantees cleanup even if an unexpected exception escapes.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for i, start in enumerate(range(0, len(audio), chunk_length_ms)):
                chunk_path = os.path.join(tmp_dir, f"temp_chunk_{i}.wav")
                audio[start : start + chunk_length_ms].export(chunk_path, format="wav")
                try:
                    text = asr_pipe(chunk_path)["text"].strip()
                except Exception as e:
                    # Best effort: keep going and mark the gap in the output.
                    print(f"Error on chunk {i} of {audio_path}: {e}")
                    pieces.append("[ERROR]")
                    continue
                pieces.append(text)
                any_success = True
        if pieces and not any_success:
            # A string made only of "[ERROR]" markers is not a transcription;
            # returning None lets the caller report the failure.
            print(f"Error processing {audio_path}: every chunk failed to transcribe")
            return None
        return " ".join(pieces).strip()
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None
# =============================================================================
# TEXT ANALYSIS
# =============================================================================
# A standalone "true"/"false" (any case) that is not part of a longer Latin
# word. Matches "True", "1. False" or "กล่าวสวัสดี: TRUE", but not "untrue".
_VERDICT_RE = re.compile(r"(?<![A-Za-z])(true|false)(?![A-Za-z])", re.IGNORECASE)


def analyze_transcription(
    transcription,
    typhoon_model,
    criteria_keys=None,
    criteria_descriptions=None,
    max_new_tokens=256,
):
    """Ask the Typhoon model which evaluation criteria a transcription satisfies.

    A numbered prompt listing each criterion and its description is sent as a
    single user turn via ``typhoon_model.generate(conversation=...)``.

    The reply is parsed line by line. Every line containing a standalone
    "true"/"false" contributes one verdict; the last one on the line wins.

    Args:
        transcription: Transcribed conversation text.
        typhoon_model: Object whose ``generate`` accepts ``conversation`` and
            ``max_new_tokens`` and returns a dict with a ``"text"`` key (or a
            plain string).
        criteria_keys: Criterion keys in order; defaults to JSON_CRITERIA_KEYS.
        criteria_descriptions: Mapping key -> description; defaults to
            CRITERIA_DESCRIPTIONS.
        max_new_tokens: Generation budget for the answer only. Unlike
            ``max_length``, this excludes the (long) prompt tokens.

    Returns:
        list[bool] with one verdict per criterion, or None if the reply did
        not contain exactly one verdict line per criterion or any error occurred.
    """
    keys = JSON_CRITERIA_KEYS if criteria_keys is None else list(criteria_keys)
    descriptions = (
        CRITERIA_DESCRIPTIONS if criteria_descriptions is None else criteria_descriptions
    )
    try:
        prompt = "\n".join(
            [
                "Analyze the Thai conversation transcription for the following criteria.",
                "Return 'True' or 'False' for each, one per line, in this order:",
                *[f"{i + 1}. {key}: {descriptions[key]}" for i, key in enumerate(keys)],
                "Transcription:\n" + transcription,
            ]
        )
        conversation = [
            {"role": "user", "content": [{"type": "text", "text": prompt}]}
        ]
        # Greedy decoding: this is a classification, not creative generation.
        response = typhoon_model.generate(
            conversation=conversation,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
        )
        text = response if isinstance(response, str) else response.get("text", "")

        verdicts = []
        for line in str(text).splitlines():
            found = _VERDICT_RE.findall(line)
            if found:
                verdicts.append(found[-1].lower() == "true")
        return verdicts if len(verdicts) == len(keys) else None
    except Exception as e:
        print(f"Error analyzing transcription: {e}")
        return None
# The RM asking for the client's time, for example:
#   "ขอเวลา 30 นาที", "ขอเวลา 1 ชั่วโมง", "ขอเวลาสัก 15 นาที",
#   "ขอเวลาประมาณครึ่งชั่วโมง"
# An optional hedge word (สัก/ซัก/ประมาณ/แค่/ราวๆ/ราว/ไม่เกิน) may precede the
# duration. In Python 3 str patterns, \d also matches Thai digits (๐-๙).
_TIME_REQUEST_RE = re.compile(
    r"ขอเวลา\s*"
    r"(?:(?:สัก|ซัก|ประมาณ|แค่|ราวๆ|ราว|ไม่เกิน)\s*)?"
    r"(?:\d+\s*(?:นาที|ชั่วโมง)|ครึ่ง\s*ชั่วโมง)"
)


def has_request_time(text):
    """Return True if *text* contains a request for meeting time.

    Recognized forms are "ขอเวลา" followed by a number of minutes/hours or
    "half an hour" (ครึ่งชั่วโมง), optionally with a hedge word in between.
    Every phrase matched by the plain ``ขอเวลา <n> นาที|ชั่วโมง`` form still
    matches. ``text`` is converted with ``str()``, so None gives False.
    """
    return _TIME_REQUEST_RE.search(str(text)) is not None
def _clean_name(value):
    """Return *value* as a stripped string, or "" when it is missing (None/NaN/NA)."""
    if value is None or (not isinstance(value, str) and pd.isna(value)):
        return ""
    return str(value).strip()


def name_match(row, transcription):
    """Check whether both the first name and the surname occur in the transcription.

    Args:
        row: Mapping (e.g. a pandas row) with "first_name" and "last_name".
        transcription: Transcribed text; converted with ``str()``.

    Returns:
        True only when both names are non-empty after stripping and each
        appears as a substring of the transcription. Missing names (None/NaN)
        and blank names never match. Otherwise ``""`` or ``"nan"`` would be
        searched for, and an empty name is trivially found in any text.
    """
    first = _clean_name(row["first_name"])
    last = _clean_name(row["last_name"])
    if not first or not last:
        return False
    text = str(transcription)
    return first in text and last in text
# =============================================================================
# MAIN PROCESSING PIPELINE
# =============================================================================
# Extensions tried, in this order, when looking for "<id><ext>" in the audio folder.
_AUDIO_EXTENSIONS = (".mp3", ".wav", ".m4a", ".flac", ".aac", ".ogg")


def _find_audio_file(file_id):
    """Return the first existing ``AUDIO_FOLDER_PATH/<file_id><ext>`` path, or None."""
    for ext in _AUDIO_EXTENSIONS:
        candidate = os.path.join(AUDIO_FOLDER_PATH, f"{file_id}{ext}")
        if os.path.exists(candidate):
            return candidate
    return None


def _apply_rule_checks(result, row, transcription):
    """Set to True the criteria that simple text rules detect in the transcription."""
    if "สวัสดี" in transcription:
        result["กล่าวสวัสดี"] = True
    if has_request_time(transcription):
        result["บอกระยะเวลาที่ใช้ในการเข้าพบ"] = True
    if name_match(row, transcription):
        result["แนะนำชื่อและนามสกุล"] = True


def _process_row(row, asr_pipe, typhoon_model):
    """Transcribe and score the audio for one metadata row; return its result dict."""
    file_id = str(row["id"])
    result = {
        "id": file_id,
        "first_name": str(row["first_name"]),
        "last_name": str(row["last_name"]),
        **{key: False for key in JSON_CRITERIA_KEYS},
        "processing_status": "Pending",
    }

    audio_path = _find_audio_file(file_id)
    if audio_path is None:
        result["processing_status"] = "Error: Audio file not found"
        return result

    try:
        transcription = transcribe_audio(audio_path, asr_pipe)
        if not transcription:
            result["processing_status"] = "Error: Transcription failed"
            return result

        _apply_rule_checks(result, row, transcription)

        booleans = analyze_transcription(transcription, typhoon_model)
        if booleans is None:
            # Keep the rule-based verdicts, but make the missing LLM step
            # visible instead of reporting the row as a full success.
            result["processing_status"] = "Partial: LLM analysis failed"
        else:
            for key, verdict in zip(JSON_CRITERIA_KEYS, booleans):
                # Logical OR: a rule-based hit is never overturned by the LLM.
                result[key] = result[key] or verdict
            result["processing_status"] = "Success"
    except Exception as e:
        result["processing_status"] = f"Error: {e}"
    return result


def _save_results(results):
    """Write results to OUTPUT_JSON_PATH and OUTPUT_CSV_PATH, then trigger Colab downloads."""
    try:
        with open(OUTPUT_JSON_PATH, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        columns = ["id", "first_name", "last_name", *JSON_CRITERIA_KEYS, "processing_status"]
        # utf-8-sig writes a BOM; presumably so spreadsheet tools detect UTF-8
        # and show the Thai headers correctly.
        with open(OUTPUT_CSV_PATH, "w", newline="", encoding="utf-8-sig") as f:
            writer = csv.DictWriter(f, fieldnames=columns)
            writer.writeheader()
            for entry in results:
                writer.writerow({key: entry.get(key, "") for key in columns})

        files.download(OUTPUT_JSON_PATH)
        files.download(OUTPUT_CSV_PATH)
    except Exception as e:
        print(f"Error saving results: {e}")


def process_all_files():
    """Run the full pipeline over every row of the metadata CSV.

    Steps:

    1. Read METADATA_CSV_PATH and check that AUDIO_FOLDER_PATH exists. Both
       are checked before loading any model, so bad inputs fail fast.
    2. Load the models.
    3. For each row, locate ``<id><ext>``, transcribe it, apply the
       rule-based checks, then merge the LLM verdicts.
    4. Write the results (one per row, in metadata order) as JSON and CSV,
       offer them for download, and print the elapsed time.

    Returns None. On a metadata or audio-folder error, it prints a message
    and returns early without writing any output.
    """
    start_time = time.perf_counter()

    try:
        metadata_df = pd.read_csv(METADATA_CSV_PATH)
    except Exception as e:
        print(f"Error reading metadata CSV: {e}")
        return
    if not os.path.isdir(AUDIO_FOLDER_PATH):
        print(f"Audio folder not found at {AUDIO_FOLDER_PATH}")
        return

    asr_pipe, typhoon_model = load_models()

    # One result per row, produced in metadata order, so no re-sorting is
    # needed. No sleep between files: every model runs locally, so there is
    # no remote rate limit to respect.
    all_results = [
        _process_row(row, asr_pipe, typhoon_model)
        for _, row in tqdm(
            metadata_df.iterrows(), total=len(metadata_df), desc="Processing files"
        )
    ]

    _save_results(all_results)

    elapsed_time = time.perf_counter() - start_time
    print(f"\nProcessing complete! Time taken: {elapsed_time:.2f} seconds")


# =============================================================================
# EXECUTION
# =============================================================================
if __name__ == "__main__":
    process_all_files()