Spaces:
Runtime error
Runtime error
| import os | |
| import tempfile | |
| import zipfile | |
| from pathlib import Path | |
| import gradio as gr | |
| import librosa | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from loguru import logger | |
| from pyannote.audio import Inference, Model | |
# --- Configuration ---------------------------------------------------------
# Hugging Face dataset repository holding the zipped audio files.
HF_REPO_ID = "litagin/voice-samples-22050"
# Directory containing the precomputed embeddings and the matching file list.
RESNET34_ROOT = Path("./embeddings")
# Expected length of a single embedding vector (checked in ``get_emb``).
RESNET34_DIM = 256
# Local directory the dataset snapshot is downloaded into.
AUDIO_ZIP_DIR = Path("./audio_files_zipped_by_game_22_050")

# --- Dataset download ------------------------------------------------------
# The download is skipped whenever the target directory already exists.
if not AUDIO_ZIP_DIR.exists():
    logger.info("Downloading audio files...")
    token = os.getenv("HF_TOKEN")
    snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        local_dir=AUDIO_ZIP_DIR,
        token=token,
    )
else:
    logger.info("Audio files already downloaded. Skip downloading.")

# --- Device selection ------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Device: {device}")

# --- Precomputed embeddings ------------------------------------------------
logger.info("Loading resnet34 vectors...")
resnet34_embs = np.load(RESNET34_ROOT / "all_embs.npy")
# Scale every row to unit L2 norm so that a dot product with a unit-length
# query vector yields the cosine similarity.
resnet34_embs_normalized = resnet34_embs / np.linalg.norm(
    resnet34_embs,
    axis=1,
    keepdims=True,
)

# --- Embedding model -------------------------------------------------------
logger.info("Loading resnet34 model...")
model_resnet34 = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
inference = Inference(model_resnet34, window="whole")
inference.to(device)

# --- File list -------------------------------------------------------------
# One path per line, with surrounding whitespace stripped.
# NOTE(review): presumably line i corresponds to row i of ``all_embs.npy`` --
# the search code indexes both with the same indices; confirm.
logger.info("Loading filelist...")
with open(RESNET34_ROOT / "all_filelists.txt", "r", encoding="utf-8") as file:
    files = list(map(str.strip, file))
def get_speaker_name(file_idx: int):
    """Return the ``"{game_name}/{speaker_name}"`` label of an indexed file.

    The entry ``files[file_idx]`` is interpreted as a path; the speaker name
    is the name of the file's parent directory and the game name is the name
    of the directory above that.
    """
    speaker_dir = Path(files[file_idx]).parent
    return f"{speaker_dir.parent.name}/{speaker_dir.name}"
# Build the speaker <-> integer id lookup tables and the per-file id array.
logger.info("Getting speaker ids...")
# Resolve every file's speaker label exactly once; the same list feeds both
# the set of unique speakers and the per-file id array below.
_speaker_names = [get_speaker_name(i) for i in range(len(files))]
all_speaker_set = set(_speaker_names)
# Ids are assigned in sorted label order so the mapping is deterministic.
id2speaker = dict(enumerate(sorted(all_speaker_set)))
num_speakers = len(id2speaker)
speaker2id = {speaker: i for i, speaker in id2speaker.items()}
# speaker_id_array[i] is the speaker id of files[i].
speaker_id_array = np.array([speaker2id[name] for name in _speaker_names])
| # def get_zip_archive_path_and_internal_path(file_path: Path) -> tuple[str, str]: | |
| # # 構造: audio_files/{game_name}/{speaker_name}/{audio_file} | |
| # game_name = file_path.parent.parent.name | |
| # speaker_name = file_path.parent.name | |
| # archive_path = AUDIO_ZIP_DIR / game_name / f"{speaker_name}.zip" | |
| # internal_path = file_path.name # ZIP内のパスはファイル名のみ | |
| # return str(archive_path), str(internal_path) | |
def get_zip_archive_path_and_internal_path(file_path: Path) -> tuple[str, str]:
    """Map a dataset file path to its ZIP archive and its member name.

    The input is expected to follow the layout
    ``audio_files/{game_name}/{speaker_name}/{audio_file}``.

    Returns:
        A ``(archive_path, internal_path)`` pair of strings, where
        ``archive_path`` is ``AUDIO_ZIP_DIR / "{game_name}.zip"`` and
        ``internal_path`` is ``"{speaker_name}/{audio_file}"``.
    """
    speaker_dir = file_path.parent
    zip_location = AUDIO_ZIP_DIR / f"{speaker_dir.parent.name}.zip"
    # Inside the archive, files are addressed as "speaker_name/file_name".
    member_name = f"{speaker_dir.name}/{file_path.name}"
    return str(zip_location), member_name
def load_audio_from_zip(file_path: Path) -> tuple[np.ndarray, int]:
    """Load one audio clip that is stored inside a ZIP archive.

    The archive and member name are resolved with
    ``get_zip_archive_path_and_internal_path``. The member's bytes are copied
    to a temporary file (keeping the original suffix) and decoded with
    ``librosa.load(..., sr=None)``.

    Args:
        file_path: Dataset path of the clip, as listed in ``files``.

    Returns:
        ``(waveform, sample_rate)`` as returned by ``librosa.load``, with the
        sample rate converted to ``int``.

    Raises:
        Whatever ``zipfile`` or ``librosa`` raise for a missing archive,
        a missing member, or an undecodable file. The temporary file is
        removed in every case.
    """
    archive_path, internal_path = get_zip_archive_path_and_internal_path(file_path)
    with zipfile.ZipFile(archive_path, "r") as zf:
        audio_bytes = zf.read(internal_path)

    # The bytes are written to a real file and loaded from its path.
    # delete=False lets the file be closed before it is reopened by path;
    # removal is therefore done manually in the ``finally`` below.
    tmp_file = tempfile.NamedTemporaryFile(
        delete=False, suffix=Path(internal_path).suffix
    )
    tmp_file_path = Path(tmp_file.name)
    try:
        with tmp_file:
            tmp_file.write(audio_bytes)
        waveform, sample_rate = librosa.load(str(tmp_file_path), sr=None)
    finally:
        # Always clean up, even when writing or decoding fails; previously a
        # failure in librosa.load left the temporary file behind.
        tmp_file_path.unlink(missing_ok=True)
    return waveform, int(sample_rate)
def get_emb(audio_path: Path | str) -> np.ndarray:
    """Compute the embedding vector of a single audio file.

    The path is converted to ``str`` and passed to the module-level
    ``inference`` object. The result is asserted to be a NumPy array of
    shape ``(RESNET34_DIM,)`` before it is returned.
    """
    expected_shape = (RESNET34_DIM,)
    embedding = inference(str(audio_path))
    assert isinstance(embedding, np.ndarray)
    assert embedding.shape == expected_shape
    return embedding
def search(audio_path: str | None):
    """Find the dataset clips whose embeddings are most similar to the query.

    The query embedding is L2-normalized and compared against
    ``resnet34_embs_normalized`` with a dot product (cosine similarity). The
    10 highest-scoring files are loaded from their ZIP archives.

    Args:
        audio_path: Path of the user-supplied audio file. Gradio passes
            ``None`` when no audio has been provided.

    Returns:
        A list of ``gr.Audio`` components, best match first, each labelled
        with its rank, ``game/speaker`` name, and similarity score.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if not audio_path:
        # Without this guard the missing input is turned into the literal
        # path "None" and fails later with an obscure loader error.
        raise gr.Error("音声ファイルを入力してください。")

    logger.info("Computing embeddings...")
    emb = get_emb(audio_path)  # embedding of the user's audio file
    emb = emb.reshape(1, -1)  # (1, dim)
    logger.success("Embeddings computed.")

    # Normalize the query vector so the dot product is a cosine similarity.
    logger.info("Computing similarities...")
    emb_normalized = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    similarities = np.dot(resnet34_embs_normalized, emb_normalized.T).flatten()
    logger.success("Similarities computed.")

    # Indices of the highest-similarity files, best first.
    top_k = 10
    top_k_indices = np.argsort(similarities)[::-1][:top_k]
    top_k_files = [files[file_idx] for file_idx in top_k_indices]
    top_k_scores = similarities[top_k_indices]

    logger.info("Fetching audio files...")
    result = []
    for rank, (f, file_idx, score) in enumerate(
        zip(top_k_files, top_k_indices, top_k_scores), start=1
    ):
        waveform_np, sample_rate = load_audio_from_zip(Path(f))
        result.append(
            gr.Audio(
                value=(sample_rate, waveform_np),
                label=f"Top {rank}: {get_speaker_name(file_idx)}, {score:.4f}",
            )
        )
    logger.success("Audio files fetched.")
    return result
def get_label(audio_path: str | None, num_top_classes: int = 10):
    """Rank speakers by how similar their clips are to the query audio.

    For every speaker, the ``num_top_classes`` highest cosine similarities
    between the query embedding and that speaker's clips are averaged. The
    10 speakers with the highest averages are returned.

    Args:
        audio_path: Path of the user-supplied audio file. Gradio passes
            ``None`` when no audio has been provided.
        num_top_classes: How many of a speaker's best similarities are
            averaged into that speaker's score. Despite its name, it does
            not control how many speakers are returned.

    Returns:
        A dict mapping ``"game/speaker"`` to its average score as a plain
        ``float``, ordered from highest to lowest score.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if not audio_path:
        # Without this guard the missing input is turned into the literal
        # path "None" and fails later with an obscure loader error.
        raise gr.Error("音声ファイルを入力してください。")

    logger.info("Computing embeddings...")
    emb = get_emb(audio_path)  # embedding of the user's audio file
    emb = emb.reshape(1, -1)  # (1, dim)
    logger.success("Embeddings computed.")

    # Normalize the query vector so the dot product is a cosine similarity.
    emb_normalized = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    similarities = np.dot(resnet34_embs_normalized, emb_normalized.T).flatten()

    logger.info("Calculating average scores...")
    speaker_scores = {}
    for character_id in range(num_speakers):
        # Positions of all clips belonging to this speaker.
        character_indices = np.where(speaker_id_array == character_id)[0]
        # This speaker's best similarities, highest first.
        top_similarities = np.sort(similarities[character_indices])[::-1][
            :num_top_classes
        ]
        # Store a plain Python float rather than a NumPy scalar so the
        # result handed to gr.Label contains only built-in types.
        speaker_scores[id2speaker[character_id]] = float(np.mean(top_similarities))

    # Keep only the best-scoring speakers, highest score first.
    num_returned_speakers = 10
    sorted_scores = dict(
        sorted(speaker_scores.items(), key=lambda item: item[1], reverse=True)[
            :num_returned_speakers
        ]
    )
    logger.success("Average scores calculated.")
    return sorted_scores
# Gradio UI: one audio input, a column of the most similar clips, and a
# column with the most similar speakers.
# NOTE(review): the original indentation was lost in this copy of the file;
# the nesting below follows the usual Gradio layout (widgets inside their
# columns, event wiring inside the Blocks context, launch outside) -- confirm
# against the deployed app.
with gr.Blocks() as app:
    input_audio = gr.Audio(type="filepath")

    with gr.Row():
        # Left column: clip-level search results.
        with gr.Column():
            btn_audio = gr.Button("似ている音声を検索")
            # Must match the number of components returned by ``search``.
            top_k = 10
            components = [
                gr.Audio(label=f"Top {rank}") for rank in range(1, top_k + 1)
            ]
        # Right column: speaker-level ranking.
        with gr.Column():
            btn_label = gr.Button("似ている話者を検索")
            label = gr.Label(num_top_classes=10)

    btn_audio.click(search, inputs=[input_audio], outputs=components)
    btn_label.click(get_label, inputs=[input_audio], outputs=[label])

app.launch()