# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Speaker-similarity (SIM) evaluation between reference and generated wavs.

Each pair listed in ``--test-lst`` is scored by embedding both waveforms with
the ``MODEL_NAME`` speaker-verification model (loaded via ``init_model``) and
taking the cosine similarity of the two embeddings. Pairs are distributed over
one worker process per visible device through a managed queue, and the
average score is logged and returned by ``main``.
"""
import argparse
from tqdm import tqdm
import logging
import os
from verification import init_model, MODEL_LIST
import soundfile as sf
import torch
import numpy as np
import torch.nn.functional as F
from torchaudio.transforms import Resample
import torch.multiprocessing as mp

console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)
# Iterate over a snapshot: removeHandler() mutates logging.root.handlers, and
# removing from the list being iterated would skip every other handler.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)
logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)

MODEL_NAME = "wavlm_large"
# Sample rate every waveform is resampled to before it is fed to the model.
TARGET_SAMPLE_RATE = 16000

S3PRL_PATH = os.environ.get("S3PRL_PATH")
if S3PRL_PATH is not None:
    import patch_unispeech

    logging.info("Applying Patches for unispeech!!!")
    patch_unispeech.patch_for_npu()


def _basename_stem(path):
    """Return the file name of ``path`` up to its first '.' ('a/b.x.wav' -> 'b')."""
    return path.split("/")[-1].split(".")[0]


def get_ref_and_gen_files(test_lst, test_folder, task_queue):
    """Enqueue one ``(ref_file, gen_file)`` pair per line of ``test_lst``.

    Each line is '|'-separated: field 0 names the reference utterance and
    field 2 the generated one. Only their base names (up to the first '.')
    are used; the wavs themselves are expected in ``test_folder`` as
    ``<ref_name>_ref.wav`` and ``<gen_name>_gen.wav``.

    Blank lines are skipped; lines with fewer than three fields are logged
    and skipped instead of aborting the whole run.
    """
    with open(test_lst, "r", encoding="utf-8") as fp:
        for line in fp:
            line = line.strip()
            if not line:
                continue
            fields = line.split("|")
            if len(fields) < 3:
                logging.warning("SKIP malformed line in %s: %r", test_lst, line)
                continue
            ref_file = f"{test_folder}/{_basename_stem(fields[0])}_ref.wav"
            gen_file = f"{test_folder}/{_basename_stem(fields[2])}_gen.wav"
            task_queue.put((ref_file, gen_file))


def _load_waveform(path, target_sr=TARGET_SAMPLE_RATE):
    """Read an audio file as a mono float32 tensor of shape (1, num_samples).

    Resamples to ``target_sr``. soundfile returns a (frames, channels) array
    for multi-channel files, so the channels are averaged first; otherwise the
    channel axis would end up as the time axis after ``unsqueeze``.
    """
    wav, sr = sf.read(path)
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    wav = torch.from_numpy(wav).unsqueeze(0).float()
    return Resample(orig_freq=sr, new_freq=target_sr)(wav)


def eval_speaker_similarity(model, wav1, wav2, rank):
    """Return the cosine similarity of the speaker embeddings of two wav files.

    Args:
        model: embedding model returned by ``init_model``; it is called on a
            (1, num_samples) waveform tensor.
        wav1, wav2: paths of the two audio files to compare.
        rank: index of the device to run on (``cuda:{rank}``).

    Returns:
        The similarity score as a Python float, in [-1.0, 1.0].
    """
    device = f"cuda:{rank}"
    wav1 = _load_waveform(wav1).cuda(device)
    wav2 = _load_waveform(wav2).cuda(device)

    model.eval()
    with torch.no_grad():
        emb1 = model(wav1)
        emb2 = model(wav2)
        sim = F.cosine_similarity(emb1, emb2)
    score = sim[0].item()
    logging.info("The similarity score between two audios is %.4f (-1.0, 1.0).", score)
    return score


def eval_proc(model_path, task_queue, rank, sim_list):
    """Worker loop: score queued (ref, gen) pairs on device ``cuda:{rank}``.

    Loads the model once, then pulls pairs from ``task_queue`` until a
    ``None`` sentinel arrives. Each successful score is appended to the shared
    ``sim_list`` as ``(sim, ref, gen)``. Pairs with a missing file, or whose
    evaluation raises, are logged and skipped so one bad file does not stop
    the worker.
    """
    if MODEL_NAME not in MODEL_LIST:
        raise ValueError("The model_name should be in {}".format(MODEL_LIST))
    model = init_model(MODEL_NAME, model_path)
    model.to(f"cuda:{rank}")

    while True:
        # Kept outside the try: if the manager queue itself breaks, retrying
        # would loop forever, so let the worker fail loudly instead.
        new_record = task_queue.get()
        if new_record is None:
            logging.info("FINISH processing all inputs")
            break
        ref, gen = new_record
        logging.info(f"eval SIM: {ref} v.s. {gen}")
        if not os.path.exists(ref) or not os.path.exists(gen):
            logging.warning(f"MISSING: {ref} v.s. {gen}")
            continue
        try:
            sim = eval_speaker_similarity(model, ref, gen, rank)
        except Exception:
            # Best-effort per pair: record the traceback and move on.
            logging.exception(f"FAIL to eval SIM: {ref} v.s. {gen}")
            continue
        sim_list.append((sim, ref, gen))


def _visible_devices():
    """Return the entries of CUDA_VISIBLE_DEVICES, or ['0'] when unset/empty.

    Entries are kept as strings so GPU UUIDs are accepted; only the count is
    used to size the worker pool.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    entries = [x.strip() for x in visible.split(",") if x.strip()]
    return entries or ["0"]


def main(args):
    """Score every pair in ``args.test_lst`` in parallel and return the mean SIM.

    Spawns one worker per visible device. Per-pair results and the average
    (rounded to 3 decimals) are logged to the console and, when
    ``args.log_file`` is given, to that file. Returns NaN if nothing could be
    evaluated.
    """
    if args.log_file:
        handler = logging.FileHandler(filename=args.log_file, mode="w")
        logging.root.addHandler(handler)

    device_list = _visible_devices()
    logging.info(f"Using devices: {device_list}")
    n_procs = len(device_list)

    ctx = mp.get_context("spawn")
    with ctx.Manager() as manager:
        sim_list = manager.list()
        task_queue = manager.Queue()
        get_ref_and_gen_files(args.test_lst, args.test_path, task_queue)
        # One sentinel per worker so that every worker eventually receives a None.
        for _ in range(n_procs):
            task_queue.put(None)

        # CUDA renumbers the devices listed in CUDA_VISIBLE_DEVICES as 0..n-1,
        # so the worker index itself is the device ordinal to use.
        processes = [
            ctx.Process(target=eval_proc, args=(args.model_path, task_queue, rank, sim_list))
            for rank in range(n_procs)
        ]
        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()
            if proc.exitcode != 0:
                logging.warning("worker %s exited with code %s", proc.name, proc.exitcode)

        # Copy out of the manager proxy before the manager shuts down.
        results = list(sim_list)

    for sim, ref, gen in results:
        logging.info(f"{ref} vs {gen} : {sim}")
    sim_scores = [sim for sim, _, _ in results]
    logging.info("total evaluated wav pairs: %d" % (len(results)))
    if not sim_scores:
        logging.warning("No wav pairs were evaluated; the average similarity is undefined.")
        return float("nan")

    avg_sim = round(float(np.mean(sim_scores)), 3)
    logging.info("The average similarity score of %s is %.4f (-1.0, 1.0)." % (args.test_path, avg_sim))
    return avg_sim


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        help="path to log file (optional)",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="./wavlm-sv",
        help="path to sv model",
    )
    args = parser.parse_args()
    main(args)