File size: 7,387 Bytes
86dcc8f
 
 
d34fc56
 
 
c7a551c
86dcc8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7a551c
 
86dcc8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import streamlit as st
import torch
import torchaudio
# --- Baris BARU yang sudah diperbaiki ---
from speechbrain.inference.speaker import EncoderClassifier
from speechbrain.inference.enhancement import SpectralMaskEnhancement
from speechbrain.inference.classifiers import AudioClassifier
import os
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# --- Konfigurasi dan Pemuatan Model (Dijalankan sekali) ---

@st.cache_resource
def load_models():
    """Memuat model verifikasi speaker dan KWS."""
    # Model untuk Verifikasi Speaker (Tahap 1)
    spk_model = EncoderClassifier.from_hparams(
        source="speechbrain/spkrec-xvect-voxceleb",
        savedir="pretrained_models/spkrec-xvect-voxceleb"
    )
    
    # Model untuk Deteksi Perintah (Tahap 2)
# Model untuk Deteksi Perintah (Tahap 2)
    kws_model = AudioClassifier.from_hparams(
        source="speechbrain/google_speech_command_xvector",
        savedir="pretrained_models/google_speech_command_xvector"
    )
    
    # Model untuk membersihkan audio (Opsional tapi bagus)
    enhancer = SpectralMaskEnhancement.from_hparams(
        source="speechbrain/metricgan-plus-voicebank",
        savedir="pretrained_models/metricgan-plus-voicebank"
    )
    return spk_model, kws_model, enhancer

# Memuat model
spk_model, kws_model, enhancer = load_models()

# Direktori pendaftaran
ENROLL_DIR = "enroll/"
THRESHOLD = 0.85 # Ambang batas kemiripan

# --- Fungsi Helper ---

def preprocess_audio(wav_file):
    """Memuat, membersihkan, dan mengubah sample rate audio."""
    try:
        # Muat audio dari file yang di-upload
        sig, fs = torchaudio.load(wav_file)

        # Bersihkan noise (jika model enhancer ada)
        if enhancer:
            enhanced_sig = enhancer.enhance_batch(sig, lengths=torch.tensor([sig.shape[1]]))
            sig = enhanced_sig.squeeze(0)

        # Resample ke 16kHz (wajib untuk model)
        if fs != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)
            sig = resampler(sig)
        
        return sig
    except Exception as e:
        st.error(f"Error memproses audio: {e}")
        return None

@st.cache_data
def get_enrollment_embeddings():
    """
    Membuat embedding (sidik jari suara) rata-rata 
    untuk setiap pengguna di folder /enroll.
    """
    enrollment_data = {}
    if not os.path.exists(ENROLL_DIR):
        st.warning(f"Folder '{ENROLL_DIR}' tidak ditemukan.")
        return {}

    for speaker_name in os.listdir(ENROLL_DIR):
        speaker_dir = os.path.join(ENROLL_DIR, speaker_name)
        if os.path.isdir(speaker_dir):
            embeddings = []
            for wav_file in os.listdir(speaker_dir):
                if wav_file.endswith(".wav"):
                    wav_path = os.path.join(speaker_dir, wav_file)
                    try:
                        sig, fs = torchaudio.load(wav_path)
                        if fs != 16000:
                            resampler = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)
                            sig = resampler(sig)
                        
                        # Buat embedding
                        with torch.no_grad():
                            emb = spk_model.encode_batch(sig)
                            emb = emb.squeeze()
                            embeddings.append(emb.numpy())
                    except Exception as e:
                        st.error(f"Gagal memproses {wav_path}: {e}")
            
            if embeddings:
                # Ambil rata-rata embedding untuk speaker ini
                enrollment_data[speaker_name] = np.mean(embeddings, axis=0)
    
    return enrollment_data

# --- Antarmuka Streamlit ---
st.title("Sistem Verifikasi Perintah Suara πŸ”")
st.write("Unggah file .wav untuk verifikasi.")

# Muat data pendaftaran
enrollment_embeddings = get_enrollment_embeddings()

if not enrollment_embeddings:
    st.error("Tidak ada data pendaftaran yang ditemukan. Pastikan folder 'enroll' ada dan berisi file .wav.")
else:
    st.success(f"Berhasil memuat data pendaftaran untuk: {list(enrollment_embeddings.keys())}")

uploaded_file = st.file_uploader("Pilih file audio...", type=["wav"])

if uploaded_file is not None:
    st.audio(uploaded_file, format="audio/wav")
    
    if st.button("Verifikasi Sekarang"):
        with st.spinner("Memproses audio..."):
            signal = preprocess_audio(uploaded_file)
        
        if signal is not None:
            # --- TAHAP 1: VERIFIKASI SPEAKER (SIAPA?) ---
            st.subheader("Tahap 1: Verifikasi Speaker")
            
            with torch.no_grad():
                upload_embedding = spk_model.encode_batch(signal).squeeze().numpy()
            
            best_score = 0
            best_match = "Tidak Dikenali"
            
            # Bandingkan dengan setiap speaker yang terdaftar
            for speaker_name, enrolled_emb in enrollment_embeddings.items():
                score = cosine_similarity(
                    upload_embedding.reshape(1, -1),
                    enrolled_emb.reshape(1, -1)
                )[0][0]
                
                st.write(f"Skor kemiripan dengan {speaker_name}: **{score:.2f}**")
                
                if score > best_score:
                    best_score = score
                    best_match = speaker_name

            # --- KEPUTUSAN TAHAP 1 ---
            if best_score > THRESHOLD:
                st.success(f"βœ… **Akses Diberikan**: Dikenali sebagai **{best_match}** (Skor: {best_score:.2f})")
                
                # --- TAHAP 2: DETEKSI PERINTAH (APA?) ---
                st.subheader("Tahap 2: Deteksi Perintah")
                with st.spinner("Mendeteksi perintah..."):
                    with torch.no_grad():
                        # Model KWS memprediksi probabilitas
                        prediction = kws_model.classify_batch(signal)
                        
                        # Ambil label dengan probabilitas tertinggi
                        # 'prediction[0]' adalah tensor probabilitas
                        # 'prediction[3]' adalah labelnya
                        top_prob = torch.max(prediction[0]).item()
                        top_label = prediction[3][0]

                        # Logika untuk perintah "Buka" (Up) atau "Tutup" (Down)
                        # Catatan: Sesuaikan label ini ("Up", "Down") dengan output model KWS Anda
                        # Model Google Speech Command menggunakan "Up", "Down", "Left", "Right", "Yes", "No", dll.
                        
                        st.write(f"Perintah terdeteksi: **{top_label}** (Keyakinan: {top_prob:.2f})")

                        if top_label.lower() == "up": # Asumsikan 'Up' = 'Buka'
                            st.balloons()
                            st.success(f"πŸ”“ **Perintah Diterima**: `{best_match}` berkata 'BUKA'.")
                        elif top_label.lower() == "down": # Asumsikan 'Down' = 'Tutup'
                            st.success(f"πŸ”’ **Perintah Diterima**: `{best_match}` berkata 'TUTUP'.")
                        else:
                            st.warning(f"Perintah '{top_label}' tidak dikenali sebagai 'Buka' atau 'Tutup'.")

            else:
                st.error(f"❌ **Akses Ditolak**: Suara tidak dikenali (Skor tertinggi: {best_score:.2f})")