mich123geb commited on
Commit
e86c9cb
·
verified ·
1 Parent(s): 021e9c8

Upload 43 files

Browse files
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.pkl
2
+ *.jpg
3
+ *.mp4
4
+ *.pth
5
+ *.pyc
6
+ __pycache__
7
+ *.h5
8
+ *.avi
9
+ *.wav
10
+ filelists/*.txt
11
+ evaluation/test_filelists/lr*.txt
12
+ *.pyc
13
+ *.mkv
14
+ *.gif
15
+ *.webm
16
+ *.mp3
README.md CHANGED
@@ -1,15 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: Wav2lip Api
3
- emoji: 🦀
4
- colorFrom: yellow
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.36.2
8
- app_file: app.py
9
- pinned: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  ---
11
- # Wav2Lip on CPU (Hugging Face Free Tier)
12
 
13
- Upload an image and a WAV file to generate a talking video. Expect 2–4 minutes per video on free CPU.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # **Wav2Lip**: *Accurately Lip-syncing Videos In The Wild*
2
+
3
+ # Commercial Version
4
+
5
+ Create your first lipsync generation in minutes. Please note, the commercial version is of a much higher quality than the old open source model!
6
+
7
+ ## Create your API Key
8
+
9
+ Create your API key from the [Dashboard](https://sync.so/keys). You will use this key to securely access the Sync API.
10
+
11
+ ## Make your first generation
12
+
13
+ The following example shows how to make a lipsync generation using the Sync API.
14
+
15
+ ### Python
16
+
17
+ #### Step 1: Install Sync SDK
18
+
19
+ ```bash
20
+ pip install syncsdk
21
+ ```
22
+
23
+ #### Step 2: Make your first generation
24
+
25
+ Copy the following code into a file `quickstart.py` and replace `YOUR_API_KEY_HERE` with your generated API key.
26
+
27
+ ```python
28
+ # quickstart.py
29
+ import time
30
+ from sync import Sync
31
+ from sync.common import Audio, GenerationOptions, Video
32
+ from sync.core.api_error import ApiError
33
+
34
+ # ---------- UPDATE API KEY ----------
35
+ # Replace with your Sync.so API key
36
+ api_key = "YOUR_API_KEY_HERE"
37
+
38
+ # ----------[OPTIONAL] UPDATE INPUT VIDEO AND AUDIO URL ----------
39
+ # URL to your source video
40
+ video_url = "https://assets.sync.so/docs/example-video.mp4"
41
+ # URL to your audio file
42
+ audio_url = "https://assets.sync.so/docs/example-audio.wav"
43
+ # ----------------------------------------
44
+
45
+ client = Sync(
46
+ base_url="https://api.sync.so",
47
+ api_key=api_key
48
+ ).generations
49
+
50
+ print("Starting lip sync generation job...")
51
+
52
+ try:
53
+ response = client.create(
54
+ input=[Video(url=video_url),Audio(url=audio_url)],
55
+ model="lipsync-2",
56
+ options=GenerationOptions(sync_mode="cut_off"),
57
+ outputFileName="quickstart"
58
+ )
59
+ except ApiError as e:
60
+ print(f'create generation request failed with status code {e.status_code} and error {e.body}')
61
+ exit()
62
+
63
+ job_id = response.id
64
+ print(f"Generation submitted successfully, job id: {job_id}")
65
+
66
+ generation = client.get(job_id)
67
+ status = generation.status
68
+ while status not in ['COMPLETED', 'FAILED']:
69
+ print('polling status for generation', job_id)
70
+ time.sleep(10)
71
+ generation = client.get(job_id)
72
+ status = generation.status
73
+
74
+ if status == 'COMPLETED':
75
+ print('generation', job_id, 'completed successfully, output url:', generation.output_url)
76
+ else:
77
+ print('generation', job_id, 'failed')
78
+ ```
79
+
80
+ Run the script:
81
+
82
+ ```bash
83
+ python quickstart.py
84
+ ```
85
+
86
+ #### Step 3: Done!
87
+
88
+ It may take a few minutes for the generation to complete. You should see the generated video URL in the terminal post completion.
89
+
90
  ---
91
+
92
+ ### TypeScript
93
+
94
+ #### Step 1: Install dependencies
95
+
96
+ ```bash
97
+ npm i @sync.so/sdk
98
+ ```
99
+
100
+ #### Step 2: Make your first generation
101
+
102
+ Copy the following code into a file `quickstart.ts` and replace `YOUR_API_KEY_HERE` with your generated API key.
103
+
104
+ ```typescript
105
+ // quickstart.ts
106
+ import { SyncClient, SyncError } from "@sync.so/sdk";
107
+
108
+ // ---------- UPDATE API KEY ----------
109
+ // Replace with your Sync.so API key
110
+ const apiKey = "YOUR_API_KEY_HERE";
111
+
112
+ // ----------[OPTIONAL] UPDATE INPUT VIDEO AND AUDIO URL ----------
113
+ // URL to your source video
114
+ const videoUrl = "https://assets.sync.so/docs/example-video.mp4";
115
+ // URL to your audio file
116
+ const audioUrl = "https://assets.sync.so/docs/example-audio.wav";
117
+ // ----------------------------------------
118
+
119
+ const client = new SyncClient({ apiKey });
120
+
121
+ async function main() {
122
+ console.log("Starting lip sync generation job...");
123
+
124
+ let jobId: string;
125
+ try {
126
+ const response = await client.generations.create({
127
+ input: [
128
+ {
129
+ type: "video",
130
+ url: videoUrl,
131
+ },
132
+ {
133
+ type: "audio",
134
+ url: audioUrl,
135
+ },
136
+ ],
137
+ model: "lipsync-2",
138
+ options: {
139
+ sync_mode: "cut_off",
140
+ },
141
+ outputFileName: "quickstart"
142
+ });
143
+ jobId = response.id;
144
+ console.log(`Generation submitted successfully, job id: ${jobId}`);
145
+ } catch (err) {
146
+ if (err instanceof SyncError) {
147
+ console.error(`create generation request failed with status code ${err.statusCode} and error ${JSON.stringify(err.body)}`);
148
+ } else {
149
+ console.error('An unexpected error occurred:', err);
150
+ }
151
+ return;
152
+ }
153
+
154
+ let generation;
155
+ let status;
156
+ while (status !== 'COMPLETED' && status !== 'FAILED') {
157
+ console.log(`polling status for generation ${jobId}...`);
158
+ try {
159
+ await new Promise(resolve => setTimeout(resolve, 10000));
160
+ generation = await client.generations.get(jobId);
161
+ status = generation.status;
162
+ } catch (err) {
163
+ if (err instanceof SyncError) {
164
+ console.error(`polling failed with status code ${err.statusCode} and error ${JSON.stringify(err.body)}`);
165
+ } else {
166
+ console.error('An unexpected error occurred during polling:', err);
167
+ }
168
+ status = 'FAILED';
169
+ }
170
+ }
171
+
172
+ if (status === 'COMPLETED') {
173
+ console.log(`generation ${jobId} completed successfully, output url: ${generation?.outputUrl}`);
174
+ } else {
175
+ console.log(`generation ${jobId} failed`);
176
+ }
177
+ }
178
+
179
+ main();
180
+ ```
181
+
182
+ Run the script:
183
+
184
+ ```bash
185
+ npx tsx quickstart.ts -y
186
+ ```
187
+
188
+ #### Step 3: Done!
189
+
190
+ You should see the generated video URL in the terminal.
191
+
192
  ---
 
193
 
194
+ ## Next Steps
195
+
196
+ Well done! You've just made your first lipsync generation with sync.so!
197
+
198
+ Ready to unlock the full potential of lipsync? Dive into our interactive [Studio](https://sync.so/login) to experiment with all available models, or explore our [API Documentation](/api-reference) to take your lip-sync generations to the next level!
199
+
200
+ ## Contact
201
+ - prady@sync.so
202
+ - pavan@sync.so
203
+ - sanjit@sync.so
204
+
205
+
206
+
207
+ # Non Commercial Open-source Version
208
+
209
+ This code is part of the paper: _A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild_ published at ACM Multimedia 2020.
210
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrs2)](https://paperswithcode.com/sota/lip-sync-on-lrs2?p=a-lip-sync-expert-is-all-you-need-for-speech)
211
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrs3)](https://paperswithcode.com/sota/lip-sync-on-lrs3?p=a-lip-sync-expert-is-all-you-need-for-speech)
212
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrw)](https://paperswithcode.com/sota/lip-sync-on-lrw?p=a-lip-sync-expert-is-all-you-need-for-speech)
213
+ |📑 Original Paper|📰 Project Page|🌀 Demo|⚡ Live Testing|📔 Colab Notebook
214
+ |:-:|:-:|:-:|:-:|:-:|
215
+ [Paper](http://arxiv.org/abs/2008.10010) | [Project Page](http://cvit.iiit.ac.in/research/projects/cvit-projects/a-lip-sync-expert-is-all-you-need-for-speech-to-lip-generation-in-the-wild/) | [Demo Video](https://youtu.be/0fXaDCZNOJc) | [Interactive Demo](https://synclabs.so/) | [Colab Notebook](https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing) /[Updated Collab Notebook](https://colab.research.google.com/drive/1IjFW1cLevs6Ouyu4Yht4mnR4yeuMqO7Y#scrollTo=MH1m608OymLH)
216
+
217
+ ![Logo](https://drive.google.com/uc?export=view&id=1Wn0hPmpo4GRbCIJR8Tf20Akzdi1qjjG9)
218
+ ----------
219
+ **Highlights**
220
+ ----------
221
+ - Weights of the visual quality disc has been updated in readme!
222
+ - Lip-sync videos to any target speech with high accuracy :100:. Try our [interactive demo](https://sync.so/).
223
+ - :sparkles: Works for any identity, voice, and language. Also works for CGI faces and synthetic voices.
224
+ - Complete training code, inference code, and pretrained models are available :boom:
225
+ - Or, quick-start with the Google Colab Notebook: [Link](https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing). Checkpoints and samples are available in a Google Drive [folder](https://drive.google.com/drive/folders/1I-0dNLfFOSFwrfqjNa-SXuwaURHE5K4k?usp=sharing) as well. There is also a [tutorial video](https://www.youtube.com/watch?v=Ic0TBhfuOrA) on this, courtesy of [What Make Art](https://www.youtube.com/channel/UCmGXH-jy0o2CuhqtpxbaQgA). Also, thanks to [Eyal Gruss](https://eyalgruss.com), there is a more accessible [Google Colab notebook](https://j.mp/wav2lip) with more useful features. A tutorial collab notebook is present at this [link](https://colab.research.google.com/drive/1IjFW1cLevs6Ouyu4Yht4mnR4yeuMqO7Y#scrollTo=MH1m608OymLH).
226
+ - :fire: :fire: Several new, reliable evaluation benchmarks and metrics [[`evaluation/` folder of this repo]](https://github.com/Rudrabha/Wav2Lip/tree/master/evaluation) released. Instructions to calculate the metrics reported in the paper are also present.
227
+ --------
228
+ **Disclaimer**
229
+ --------
230
+ All results from this open-source code or our [demo website](https://bhaasha.iiit.ac.in/lipsync) should only be used for research/academic/personal purposes only. As the models are trained on the <a href="http://www.robots.ox.ac.uk/~vgg/data/lip_reading/lrs2.html">LRS2 dataset</a>, any form of commercial use is strictly prohibited. For commercial requests please contact us directly!
231
+ Prerequisites
232
+ -------------
233
+ - `Python 3.6`
234
+ - ffmpeg: `sudo apt-get install ffmpeg`
235
+ - Install necessary packages using `pip install -r requirements.txt`. Alternatively, instructions for using a docker image is provided [here](https://gist.github.com/xenogenesi/e62d3d13dadbc164124c830e9c453668). Have a look at [this comment](https://github.com/Rudrabha/Wav2Lip/issues/131#issuecomment-725478562) and comment on [the gist](https://gist.github.com/xenogenesi/e62d3d13dadbc164124c830e9c453668) if you encounter any issues.
236
+ - Face detection [pre-trained model](https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth) should be downloaded to `face_detection/detection/sfd/s3fd.pth`. Alternative [link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/prajwal_k_research_iiit_ac_in/EZsy6qWuivtDnANIG73iHjIBjMSoojcIV0NULXV-yiuiIg?e=qTasa8) if the above does not work.
237
+ Getting the weights
238
+ ----------
239
+ | Model | Description | Link to the model |
240
+ | :-------------: | :---------------: | :---------------: |
241
+ | Wav2Lip | Highly accurate lip-sync | [Link](https://drive.google.com/drive/folders/153HLrqlBNxzZcHi17PEvP09kkAfzRshM?usp=share_link) |
242
+ | Wav2Lip + GAN | Slightly inferior lip-sync, but better visual quality | [Link](https://drive.google.com/file/d/15G3U08c8xsCkOqQxE38Z2XXDnPcOptNk/view?usp=share_link) |
243
+
244
 
245
+ Lip-syncing videos using the pre-trained models (Inference)
246
+ -------
247
+ You can lip-sync any video to any audio:
248
+ ```bash
249
+ python inference.py --checkpoint_path <ckpt> --face <video.mp4> --audio <an-audio-source>
250
+ ```
251
+ The result is saved (by default) in `results/result_voice.mp4`. You can specify it as an argument, similar to several other available options. The audio source can be any file supported by `FFMPEG` containing audio data: `*.wav`, `*.mp3` or even a video file, from which the code will automatically extract the audio.
252
+ ##### Tips for better results:
253
+ - Experiment with the `--pads` argument to adjust the detected face bounding box. Often leads to improved results. You might need to increase the bottom padding to include the chin region. E.g. `--pads 0 20 0 0`.
254
+ - If you see the mouth position dislocated or some weird artifacts such as two mouths, then it can be because of over-smoothing the face detections. Use the `--nosmooth` argument and give it another try.
255
+ - Experiment with the `--resize_factor` argument, to get a lower-resolution video. Why? The models are trained on faces that were at a lower resolution. You might get better, visually pleasing results for 720p videos than for 1080p videos (in many cases, the latter works well too).
256
+ - The Wav2Lip model without GAN usually needs more experimenting with the above two to get the most ideal results, and sometimes, can give you a better result as well.
257
+ Preparing LRS2 for training
258
+ ----------
259
+ Our models are trained on LRS2. See [here](#training-on-datasets-other-than-lrs2) for a few suggestions regarding training on other datasets.
260
+ ##### LRS2 dataset folder structure
261
+ ```
262
+ data_root (mvlrs_v1)
263
+ ├── main, pretrain (we use only main folder in this work)
264
+ | ├── list of folders
265
+ | │ ├── five-digit numbered video IDs ending with (.mp4)
266
+ ```
267
+ Place the LRS2 filelists (train, val, test) `.txt` files in the `filelists/` folder.
268
+ ##### Preprocess the dataset for fast training
269
+ ```bash
270
+ python preprocess.py --data_root data_root/main --preprocessed_root lrs2_preprocessed/
271
+ ```
272
+ Additional options like `batch_size` and the number of GPUs to use in parallel to use can also be set.
273
+ ##### Preprocessed LRS2 folder structure
274
+ ```
275
+ preprocessed_root (lrs2_preprocessed)
276
+ ├── list of folders
277
+ | ├── Folders with five-digit numbered video IDs
278
+ | │ ├── *.jpg
279
+ | │ ├── audio.wav
280
+ ```
281
+ Train!
282
+ ----------
283
+ There are two major steps: (i) Train the expert lip-sync discriminator, (ii) Train the Wav2Lip model(s).
284
+ ##### Training the expert discriminator
285
+ You can download [the pre-trained weights](#getting-the-weights) if you want to skip this step. To train it:
286
+ ```bash
287
+ python color_syncnet_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints>
288
+ ```
289
+ ##### Training the Wav2Lip models
290
+ You can either train the model without the additional visual quality discriminator (< 1 day of training) or use the discriminator (~2 days). For the former, run:
291
+ ```bash
292
+ python wav2lip_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints> --syncnet_checkpoint_path <path_to_expert_disc_checkpoint>
293
+ ```
294
+ To train with the visual quality discriminator, you should run `hq_wav2lip_train.py` instead. The arguments for both files are similar. In both cases, you can resume training as well. Look at `python wav2lip_train.py --help` for more details. You can also set additional less commonly-used hyper-parameters at the bottom of the `hparams.py` file.
295
+ Training on datasets other than LRS2
296
+ ------------------------------------
297
+ Training on other datasets might require modifications to the code. Please read the following before you raise an issue:
298
+ - You might not get good results by training/fine-tuning on a few minutes of a single speaker. This is a separate research problem, to which we do not have a solution yet. Thus, we would most likely not be able to resolve your issue.
299
+ - You must train the expert discriminator for your own dataset before training Wav2Lip.
300
+ - If it is your own dataset downloaded from the web, in most cases, needs to be sync-corrected.
301
+ - Be mindful of the FPS of the videos of your dataset. Changes to FPS would need significant code changes.
302
+ - The expert discriminator's eval loss should go down to ~0.25 and the Wav2Lip eval sync loss should go down to ~0.2 to get good results.
303
+ When raising an issue on this topic, please let us know that you are aware of all these points.
304
+ We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model.
305
+ Evaluation
306
+ ----------
307
+ Please check the `evaluation/` folder for the instructions.
308
+ License and Citation
309
+ ----------
310
+ This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact us directly at rudrabha@synclabs.so or prajwal@synclabs.so. We have a turn-key hosted API with new and improved lip-syncing models here: https://synclabs.so/
311
+ The size of the generated face will be 192 x 288 in our new models. Please cite the following paper if you use this repository:
312
+ ```
313
+ @inproceedings{10.1145/3394171.3413532,
314
+ author = {Prajwal, K R and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P. and Jawahar, C.V.},
315
+ title = {A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild},
316
+ year = {2020},
317
+ isbn = {9781450379885},
318
+ publisher = {Association for Computing Machinery},
319
+ address = {New York, NY, USA},
320
+ url = {https://doi.org/10.1145/3394171.3413532},
321
+ doi = {10.1145/3394171.3413532},
322
+ booktitle = {Proceedings of the 28th ACM International Conference on Multimedia},
323
+ pages = {484–492},
324
+ numpages = {9},
325
+ keywords = {lip sync, talking face generation, video generation},
326
+ location = {Seattle, WA, USA},
327
+ series = {MM '20}
328
+ }
329
+ ```
330
+ Acknowledgments
331
+ ----------
332
+ Parts of the code structure are inspired by this [TTS repository](https://github.com/r9y9/deepvoice3_pytorch). We thank the author for this wonderful code. The code for Face Detection has been taken from the [face_alignment](https://github.com/1adrianb/face-alignment) repository. We thank the authors for releasing their code and models. We thank [zabique](https://github.com/zabique) for the tutorial collab notebook.
333
+ ## Acknowledgements
334
+ - [Awesome Readme Templates](https://awesomeopensource.com/project/elangosundar/awesome-README-templates)
335
+ - [Awesome README](https://github.com/matiassingers/awesome-readme)
336
+ - [How to write a Good readme](https://bulldogjob.com/news/449-how-to-write-a-good-readme-for-your-github-project)
audio.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ # import tensorflow as tf
5
+ from scipy import signal
6
+ from scipy.io import wavfile
7
+ from hparams import hparams as hp
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)[0]
11
+
12
+ def save_wav(wav, path, sr):
13
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14
+ #proposed by @dsmiller
15
+ wavfile.write(path, sr, wav.astype(np.int16))
16
+
17
+ def save_wavenet_wav(wav, path, sr):
18
+ librosa.output.write_wav(path, wav, sr=sr)
19
+
20
+ def preemphasis(wav, k, preemphasize=True):
21
+ if preemphasize:
22
+ return signal.lfilter([1, -k], [1], wav)
23
+ return wav
24
+
25
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
26
+ if inv_preemphasize:
27
+ return signal.lfilter([1], [1, -k], wav)
28
+ return wav
29
+
30
+ def get_hop_size():
31
+ hop_size = hp.hop_size
32
+ if hop_size is None:
33
+ assert hp.frame_shift_ms is not None
34
+ hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
35
+ return hop_size
36
+
37
+ def linearspectrogram(wav):
38
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
39
+ S = _amp_to_db(np.abs(D)) - hp.ref_level_db
40
+
41
+ if hp.signal_normalization:
42
+ return _normalize(S)
43
+ return S
44
+
45
+ def melspectrogram(wav):
46
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
47
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
48
+
49
+ if hp.signal_normalization:
50
+ return _normalize(S)
51
+ return S
52
+
53
+ def _lws_processor():
54
+ import lws
55
+ return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
56
+
57
+ def _stft(y):
58
+ if hp.use_lws:
59
+ return _lws_processor(hp).stft(y).T
60
+ else:
61
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
62
+
63
+ ##########################################################
64
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
65
+ def num_frames(length, fsize, fshift):
66
+ """Compute number of time frames of spectrogram
67
+ """
68
+ pad = (fsize - fshift)
69
+ if length % fshift == 0:
70
+ M = (length + pad * 2 - fsize) // fshift + 1
71
+ else:
72
+ M = (length + pad * 2 - fsize) // fshift + 2
73
+ return M
74
+
75
+
76
+ def pad_lr(x, fsize, fshift):
77
+ """Compute left and right padding
78
+ """
79
+ M = num_frames(len(x), fsize, fshift)
80
+ pad = (fsize - fshift)
81
+ T = len(x) + 2 * pad
82
+ r = (M - 1) * fshift + fsize - T
83
+ return pad, pad + r
84
+ ##########################################################
85
+ #Librosa correct padding
86
+ def librosa_pad_lr(x, fsize, fshift):
87
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
88
+
89
+ # Conversions
90
+ _mel_basis = None
91
+
92
+ def _linear_to_mel(spectogram):
93
+ global _mel_basis
94
+ if _mel_basis is None:
95
+ _mel_basis = _build_mel_basis()
96
+ return np.dot(_mel_basis, spectogram)
97
+
98
+ def _build_mel_basis():
99
+ assert hp.fmax <= hp.sample_rate // 2
100
+ return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels,
101
+ fmin=hp.fmin, fmax=hp.fmax)
102
+
103
+ def _amp_to_db(x):
104
+ min_level = np.exp(hp.min_level_db / 20 * np.log(10))
105
+ return 20 * np.log10(np.maximum(min_level, x))
106
+
107
+ def _db_to_amp(x):
108
+ return np.power(10.0, (x) * 0.05)
109
+
110
+ def _normalize(S):
111
+ if hp.allow_clipping_in_normalization:
112
+ if hp.symmetric_mels:
113
+ return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
114
+ -hp.max_abs_value, hp.max_abs_value)
115
+ else:
116
+ return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
117
+
118
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
119
+ if hp.symmetric_mels:
120
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
121
+ else:
122
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
123
+
124
+ def _denormalize(D):
125
+ if hp.allow_clipping_in_normalization:
126
+ if hp.symmetric_mels:
127
+ return (((np.clip(D, -hp.max_abs_value,
128
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
129
+ + hp.min_level_db)
130
+ else:
131
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
132
+
133
+ if hp.symmetric_mels:
134
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
135
+ else:
136
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
checkpoints/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Place all your checkpoints (.pth files) here.
color_syncnet_train.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import dirname, join, basename, isfile
2
+ from tqdm import tqdm
3
+
4
+ from models import SyncNet_color as SyncNet
5
+ import audio
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch import optim
10
+ import torch.backends.cudnn as cudnn
11
+ from torch.utils import data as data_utils
12
+ import numpy as np
13
+
14
+ from glob import glob
15
+
16
+ import os, random, cv2, argparse
17
+ from hparams import hparams, get_image_list
18
+
19
+ parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')
20
+
21
+ parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)
22
+
23
+ parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
24
+ parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str)
25
+
26
+ args = parser.parse_args()
27
+
28
+
29
+ global_step = 0
30
+ global_epoch = 0
31
+ use_cuda = torch.cuda.is_available()
32
+ print('use_cuda: {}'.format(use_cuda))
33
+
34
+ syncnet_T = 5
35
+ syncnet_mel_step_size = 16
36
+
37
+ class Dataset(object):
38
+ def __init__(self, split):
39
+ self.all_videos = get_image_list(args.data_root, split)
40
+
41
+ def get_frame_id(self, frame):
42
+ return int(basename(frame).split('.')[0])
43
+
44
+ def get_window(self, start_frame):
45
+ start_id = self.get_frame_id(start_frame)
46
+ vidname = dirname(start_frame)
47
+
48
+ window_fnames = []
49
+ for frame_id in range(start_id, start_id + syncnet_T):
50
+ frame = join(vidname, '{}.jpg'.format(frame_id))
51
+ if not isfile(frame):
52
+ return None
53
+ window_fnames.append(frame)
54
+ return window_fnames
55
+
56
+ def crop_audio_window(self, spec, start_frame):
57
+ # num_frames = (T x hop_size * fps) / sample_rate
58
+ start_frame_num = self.get_frame_id(start_frame)
59
+ start_idx = int(80. * (start_frame_num / float(hparams.fps)))
60
+
61
+ end_idx = start_idx + syncnet_mel_step_size
62
+
63
+ return spec[start_idx : end_idx, :]
64
+
65
+
66
+ def __len__(self):
67
+ return len(self.all_videos)
68
+
69
+ def __getitem__(self, idx):
70
+ while 1:
71
+ idx = random.randint(0, len(self.all_videos) - 1)
72
+ vidname = self.all_videos[idx]
73
+
74
+ img_names = list(glob(join(vidname, '*.jpg')))
75
+ if len(img_names) <= 3 * syncnet_T:
76
+ continue
77
+ img_name = random.choice(img_names)
78
+ wrong_img_name = random.choice(img_names)
79
+ while wrong_img_name == img_name:
80
+ wrong_img_name = random.choice(img_names)
81
+
82
+ if random.choice([True, False]):
83
+ y = torch.ones(1).float()
84
+ chosen = img_name
85
+ else:
86
+ y = torch.zeros(1).float()
87
+ chosen = wrong_img_name
88
+
89
+ window_fnames = self.get_window(chosen)
90
+ if window_fnames is None:
91
+ continue
92
+
93
+ window = []
94
+ all_read = True
95
+ for fname in window_fnames:
96
+ img = cv2.imread(fname)
97
+ if img is None:
98
+ all_read = False
99
+ break
100
+ try:
101
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
102
+ except Exception as e:
103
+ all_read = False
104
+ break
105
+
106
+ window.append(img)
107
+
108
+ if not all_read: continue
109
+
110
+ try:
111
+ wavpath = join(vidname, "audio.wav")
112
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
113
+
114
+ orig_mel = audio.melspectrogram(wav).T
115
+ except Exception as e:
116
+ continue
117
+
118
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
119
+
120
+ if (mel.shape[0] != syncnet_mel_step_size):
121
+ continue
122
+
123
+ # H x W x 3 * T
124
+ x = np.concatenate(window, axis=2) / 255.
125
+ x = x.transpose(2, 0, 1)
126
+ x = x[:, x.shape[1]//2:]
127
+
128
+ x = torch.FloatTensor(x)
129
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
130
+
131
+ return x, mel, y
132
+
133
+ logloss = nn.BCELoss()
134
+ def cosine_loss(a, v, y):
135
+ d = nn.functional.cosine_similarity(a, v)
136
+ loss = logloss(d.unsqueeze(1), y)
137
+
138
+ return loss
139
+
140
+ def train(device, model, train_data_loader, test_data_loader, optimizer,
141
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
142
+
143
+ global global_step, global_epoch
144
+ resumed_step = global_step
145
+
146
+ while global_epoch < nepochs:
147
+ running_loss = 0.
148
+ prog_bar = tqdm(enumerate(train_data_loader))
149
+ for step, (x, mel, y) in prog_bar:
150
+ model.train()
151
+ optimizer.zero_grad()
152
+
153
+ # Transform data to CUDA device
154
+ x = x.to(device)
155
+
156
+ mel = mel.to(device)
157
+
158
+ a, v = model(mel, x)
159
+ y = y.to(device)
160
+
161
+ loss = cosine_loss(a, v, y)
162
+ loss.backward()
163
+ optimizer.step()
164
+
165
+ global_step += 1
166
+ cur_session_steps = global_step - resumed_step
167
+ running_loss += loss.item()
168
+
169
+ if global_step == 1 or global_step % checkpoint_interval == 0:
170
+ save_checkpoint(
171
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
172
+
173
+ if global_step % hparams.syncnet_eval_interval == 0:
174
+ with torch.no_grad():
175
+ eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
176
+
177
+ prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1)))
178
+
179
+ global_epoch += 1
180
+
181
+ def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
182
+ eval_steps = 1400
183
+ print('Evaluating for {} steps'.format(eval_steps))
184
+ losses = []
185
+ while 1:
186
+ for step, (x, mel, y) in enumerate(test_data_loader):
187
+
188
+ model.eval()
189
+
190
+ # Transform data to CUDA device
191
+ x = x.to(device)
192
+
193
+ mel = mel.to(device)
194
+
195
+ a, v = model(mel, x)
196
+ y = y.to(device)
197
+
198
+ loss = cosine_loss(a, v, y)
199
+ losses.append(loss.item())
200
+
201
+ if step > eval_steps: break
202
+
203
+ averaged_loss = sum(losses) / len(losses)
204
+ print(averaged_loss)
205
+
206
+ return
207
+
208
+ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
209
+
210
+ checkpoint_path = join(
211
+ checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
212
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
213
+ torch.save({
214
+ "state_dict": model.state_dict(),
215
+ "optimizer": optimizer_state,
216
+ "global_step": step,
217
+ "global_epoch": epoch,
218
+ }, checkpoint_path)
219
+ print("Saved checkpoint:", checkpoint_path)
220
+
221
+ def _load(checkpoint_path):
222
+ if use_cuda:
223
+ checkpoint = torch.load(checkpoint_path)
224
+ else:
225
+ checkpoint = torch.load(checkpoint_path,
226
+ map_location=lambda storage, loc: storage)
227
+ return checkpoint
228
+
229
+ def load_checkpoint(path, model, optimizer, reset_optimizer=False):
230
+ global global_step
231
+ global global_epoch
232
+
233
+ print("Load checkpoint from: {}".format(path))
234
+ checkpoint = _load(path)
235
+ model.load_state_dict(checkpoint["state_dict"])
236
+ if not reset_optimizer:
237
+ optimizer_state = checkpoint["optimizer"]
238
+ if optimizer_state is not None:
239
+ print("Load optimizer state from {}".format(path))
240
+ optimizer.load_state_dict(checkpoint["optimizer"])
241
+ global_step = checkpoint["global_step"]
242
+ global_epoch = checkpoint["global_epoch"]
243
+
244
+ return model
245
+
246
+ if __name__ == "__main__":
247
+ checkpoint_dir = args.checkpoint_dir
248
+ checkpoint_path = args.checkpoint_path
249
+
250
+ if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)
251
+
252
+ # Dataset and Dataloader setup
253
+ train_dataset = Dataset('train')
254
+ test_dataset = Dataset('val')
255
+
256
+ train_data_loader = data_utils.DataLoader(
257
+ train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
258
+ num_workers=hparams.num_workers)
259
+
260
+ test_data_loader = data_utils.DataLoader(
261
+ test_dataset, batch_size=hparams.syncnet_batch_size,
262
+ num_workers=8)
263
+
264
+ device = torch.device("cuda" if use_cuda else "cpu")
265
+
266
+ # Model
267
+ model = SyncNet().to(device)
268
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
269
+
270
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
271
+ lr=hparams.syncnet_lr)
272
+
273
+ if checkpoint_path is not None:
274
+ load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)
275
+
276
+ train(device, model, train_data_loader, test_data_loader, optimizer,
277
+ checkpoint_dir=checkpoint_dir,
278
+ checkpoint_interval=hparams.syncnet_checkpoint_interval,
279
+ nepochs=hparams.nepochs)
evaluation/README.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Novel Evaluation Framework, new filelists, and using the LSE-D and LSE-C metric.
2
+
3
+ Our paper also proposes a novel evaluation framework (Section 4). To evaluate on LRS2, LRS3, and LRW, the filelists are present in the `test_filelists` folder. Please use `gen_videos_from_filelist.py` script to generate the videos. After that, you can calculate the LSE-D and LSE-C scores using the instructions below. Please see [this thread](https://github.com/Rudrabha/Wav2Lip/issues/22#issuecomment-712825380) on how to calculate the FID scores.
4
+
5
+ The videos of the ReSyncED benchmark for real-world evaluation will be released soon.
6
+
7
+ ### Steps to set-up the evaluation repository for LSE-D and LSE-C metric:
8
+ We use the pre-trained syncnet model available in this [repository](https://github.com/joonson/syncnet_python).
9
+
10
+ * Clone the SyncNet repository.
11
+ ```
12
+ git clone https://github.com/joonson/syncnet_python.git
13
+ ```
14
+ * Follow the procedure given in the above linked [repository](https://github.com/joonson/syncnet_python) to download the pretrained models and set up the dependencies.
15
+ * **Note: Please install a separate virtual environment for the evaluation scripts. The versions used by Wav2Lip and the publicly released code of SyncNet is different and can cause version mis-match issues. To avoid this, we suggest the users to install a separate virtual environment for the evaluation scripts**
16
+ ```
17
+ cd syncnet_python
18
+ pip install -r requirements.txt
19
+ sh download_model.sh
20
+ ```
21
+ * The above step should ensure that all the dependencies required by the repository is installed and the pre-trained models are downloaded.
22
+
23
+ ### Running the evaluation scripts:
24
+ * Copy our evaluation scripts given in this folder to the cloned repository.
25
+ ```
26
+ cd Wav2Lip/evaluation/scores_LSE/
27
+ cp *.py syncnet_python/
28
+ cp *.sh syncnet_python/
29
+ ```
30
+ **Note: We will release the test filelists for LRW, LRS2 and LRS3 shortly once we receive permission from the dataset creators. We will also release the Real World Dataset we have collected shortly.**
31
+
32
+ * Our evaluation technique does not require ground-truth of any sorts. Given lip-synced videos we can directly calculate the scores from only the generated videos. Please store the generated videos (from our test sets or your own generated videos) in the following folder structure.
33
+ ```
34
+ video data root (Folder containing all videos)
35
+ ├── All .mp4 files
36
+ ```
37
+ * Change the folder back to the cloned repository.
38
+ ```
39
+ cd syncnet_python
40
+ ```
41
+ * To run evaluation on the LRW, LRS2 and LRS3 test files, please run the following command:
42
+ ```
43
+ python calculate_scores_LRS.py --data_root /path/to/video/data/root --tmp_dir tmp_dir/
44
+ ```
45
+
46
+ * To run evaluation on the ReSynced dataset or your own generated videos, please run the following command:
47
+ ```
48
+ sh calculate_scores_real_videos.sh /path/to/video/data/root
49
+ ```
50
+ * The generated scores will be present in the all_scores.txt generated in the ```syncnet_python/``` folder
51
+
52
+ # Evaluation of image quality using FID metric.
53
+ We use the [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository for calculating the FID metrics. We dump all the frames in both ground-truth and generated videos and calculate the FID score.
54
+
55
+
56
+ # Opening issues related to evaluation scripts
57
+ * Please open the issues with the "Evaluation" label if you face any issues in the evaluation scripts.
58
+
59
+ # Acknowledgements
60
+ Our evaluation pipeline in based on two existing repositories. LSE metrics are based on the [syncnet_python](https://github.com/joonson/syncnet_python) repository and the FID score is based on [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository. We thank the authors of both the repositories for releasing their wonderful code.
61
+
62
+
63
+
evaluation/gen_videos_from_filelist.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import listdir, path
2
+ import numpy as np
3
+ import scipy, cv2, os, sys, argparse
4
+ import dlib, json, subprocess
5
+ from tqdm import tqdm
6
+ from glob import glob
7
+ import torch
8
+
9
+ sys.path.append('../')
10
+ import audio
11
+ import face_detection
12
+ from models import Wav2Lip
13
+
14
+ parser = argparse.ArgumentParser(description='Code to generate results for test filelists')
15
+
16
+ parser.add_argument('--filelist', type=str,
17
+ help='Filepath of filelist file to read', required=True)
18
+ parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
19
+ required=True)
20
+ parser.add_argument('--data_root', type=str, required=True)
21
+ parser.add_argument('--checkpoint_path', type=str,
22
+ help='Name of saved checkpoint to load weights from', required=True)
23
+
24
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 0, 0, 0],
25
+ help='Padding (top, bottom, left, right)')
26
+ parser.add_argument('--face_det_batch_size', type=int,
27
+ help='Single GPU batch size for face detection', default=64)
28
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
29
+
30
+ # parser.add_argument('--resize_factor', default=1, type=int)
31
+
32
+ args = parser.parse_args()
33
+ args.img_size = 96
34
+
35
+ def get_smoothened_boxes(boxes, T):
36
+ for i in range(len(boxes)):
37
+ if i + T > len(boxes):
38
+ window = boxes[len(boxes) - T:]
39
+ else:
40
+ window = boxes[i : i + T]
41
+ boxes[i] = np.mean(window, axis=0)
42
+ return boxes
43
+
44
+ def face_detect(images):
45
+ batch_size = args.face_det_batch_size
46
+
47
+ while 1:
48
+ predictions = []
49
+ try:
50
+ for i in range(0, len(images), batch_size):
51
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
52
+ except RuntimeError:
53
+ if batch_size == 1:
54
+ raise RuntimeError('Image too big to run face detection on GPU')
55
+ batch_size //= 2
56
+ args.face_det_batch_size = batch_size
57
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
58
+ continue
59
+ break
60
+
61
+ results = []
62
+ pady1, pady2, padx1, padx2 = args.pads
63
+ for rect, image in zip(predictions, images):
64
+ if rect is None:
65
+ raise ValueError('Face not detected!')
66
+
67
+ y1 = max(0, rect[1] - pady1)
68
+ y2 = min(image.shape[0], rect[3] + pady2)
69
+ x1 = max(0, rect[0] - padx1)
70
+ x2 = min(image.shape[1], rect[2] + padx2)
71
+
72
+ results.append([x1, y1, x2, y2])
73
+
74
+ boxes = get_smoothened_boxes(np.array(results), T=5)
75
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
76
+
77
+ return results
78
+
79
+ def datagen(frames, face_det_results, mels):
80
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
81
+
82
+ for i, m in enumerate(mels):
83
+ if i >= len(frames): raise ValueError('Equal or less lengths only')
84
+
85
+ frame_to_save = frames[i].copy()
86
+ face, coords, valid_frame = face_det_results[i].copy()
87
+ if not valid_frame:
88
+ continue
89
+
90
+ face = cv2.resize(face, (args.img_size, args.img_size))
91
+
92
+ img_batch.append(face)
93
+ mel_batch.append(m)
94
+ frame_batch.append(frame_to_save)
95
+ coords_batch.append(coords)
96
+
97
+ if len(img_batch) >= args.wav2lip_batch_size:
98
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
99
+
100
+ img_masked = img_batch.copy()
101
+ img_masked[:, args.img_size//2:] = 0
102
+
103
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
104
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
105
+
106
+ yield img_batch, mel_batch, frame_batch, coords_batch
107
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
108
+
109
+ if len(img_batch) > 0:
110
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
111
+
112
+ img_masked = img_batch.copy()
113
+ img_masked[:, args.img_size//2:] = 0
114
+
115
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
116
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
117
+
118
+ yield img_batch, mel_batch, frame_batch, coords_batch
119
+
120
+ fps = 25
121
+ mel_step_size = 16
122
+ mel_idx_multiplier = 80./fps
123
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
124
+ print('Using {} for inference.'.format(device))
125
+
126
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
127
+ flip_input=False, device=device)
128
+
129
+ def _load(checkpoint_path):
130
+ if device == 'cuda':
131
+ checkpoint = torch.load(checkpoint_path)
132
+ else:
133
+ checkpoint = torch.load(checkpoint_path,
134
+ map_location=lambda storage, loc: storage)
135
+ return checkpoint
136
+
137
+ def load_model(path):
138
+ model = Wav2Lip()
139
+ print("Load checkpoint from: {}".format(path))
140
+ checkpoint = _load(path)
141
+ s = checkpoint["state_dict"]
142
+ new_s = {}
143
+ for k, v in s.items():
144
+ new_s[k.replace('module.', '')] = v
145
+ model.load_state_dict(new_s)
146
+
147
+ model = model.to(device)
148
+ return model.eval()
149
+
150
+ model = load_model(args.checkpoint_path)
151
+
152
+ def main():
153
+ assert args.data_root is not None
154
+ data_root = args.data_root
155
+
156
+ if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)
157
+
158
+ with open(args.filelist, 'r') as filelist:
159
+ lines = filelist.readlines()
160
+
161
+ for idx, line in enumerate(tqdm(lines)):
162
+ audio_src, video = line.strip().split()
163
+
164
+ audio_src = os.path.join(data_root, audio_src) + '.mp4'
165
+ video = os.path.join(data_root, video) + '.mp4'
166
+
167
+ command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
168
+ subprocess.call(command, shell=True)
169
+ temp_audio = '../temp/temp.wav'
170
+
171
+ wav = audio.load_wav(temp_audio, 16000)
172
+ mel = audio.melspectrogram(wav)
173
+ if np.isnan(mel.reshape(-1)).sum() > 0:
174
+ continue
175
+
176
+ mel_chunks = []
177
+ i = 0
178
+ while 1:
179
+ start_idx = int(i * mel_idx_multiplier)
180
+ if start_idx + mel_step_size > len(mel[0]):
181
+ break
182
+ mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
183
+ i += 1
184
+
185
+ video_stream = cv2.VideoCapture(video)
186
+
187
+ full_frames = []
188
+ while 1:
189
+ still_reading, frame = video_stream.read()
190
+ if not still_reading or len(full_frames) > len(mel_chunks):
191
+ video_stream.release()
192
+ break
193
+ full_frames.append(frame)
194
+
195
+ if len(full_frames) < len(mel_chunks):
196
+ continue
197
+
198
+ full_frames = full_frames[:len(mel_chunks)]
199
+
200
+ try:
201
+ face_det_results = face_detect(full_frames.copy())
202
+ except ValueError as e:
203
+ continue
204
+
205
+ batch_size = args.wav2lip_batch_size
206
+ gen = datagen(full_frames.copy(), face_det_results, mel_chunks)
207
+
208
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
209
+ if i == 0:
210
+ frame_h, frame_w = full_frames[0].shape[:-1]
211
+ out = cv2.VideoWriter('../temp/result.avi',
212
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
213
+
214
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
215
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
216
+
217
+ with torch.no_grad():
218
+ pred = model(mel_batch, img_batch)
219
+
220
+
221
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
222
+
223
+ for pl, f, c in zip(pred, frames, coords):
224
+ y1, y2, x1, x2 = c
225
+ pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
226
+ f[y1:y2, x1:x2] = pl
227
+ out.write(f)
228
+
229
+ out.release()
230
+
231
+ vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
232
+
233
+ command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format(temp_audio,
234
+ '../temp/result.avi', vid)
235
+ subprocess.call(command, shell=True)
236
+
237
+ if __name__ == '__main__':
238
+ main()
evaluation/real_videos_inference.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import listdir, path
2
+ import numpy as np
3
+ import scipy, cv2, os, sys, argparse
4
+ import dlib, json, subprocess
5
+ from tqdm import tqdm
6
+ from glob import glob
7
+ import torch
8
+
9
+ sys.path.append('../')
10
+ import audio
11
+ import face_detection
12
+ from models import Wav2Lip
13
+
14
+ parser = argparse.ArgumentParser(description='Code to generate results on ReSyncED evaluation set')
15
+
16
+ parser.add_argument('--mode', type=str,
17
+ help='random | dubbed | tts', required=True)
18
+
19
+ parser.add_argument('--filelist', type=str,
20
+ help='Filepath of filelist file to read', default=None)
21
+
22
+ parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
23
+ required=True)
24
+ parser.add_argument('--data_root', type=str, required=True)
25
+ parser.add_argument('--checkpoint_path', type=str,
26
+ help='Name of saved checkpoint to load weights from', required=True)
27
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
28
+ help='Padding (top, bottom, left, right)')
29
+
30
+ parser.add_argument('--face_det_batch_size', type=int,
31
+ help='Single GPU batch size for face detection', default=16)
32
+
33
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
34
+ parser.add_argument('--face_res', help='Approximate resolution of the face at which to test', default=180)
35
+ parser.add_argument('--min_frame_res', help='Do not downsample further below this frame resolution', default=480)
36
+ parser.add_argument('--max_frame_res', help='Downsample to at least this frame resolution', default=720)
37
+ # parser.add_argument('--resize_factor', default=1, type=int)
38
+
39
+ args = parser.parse_args()
40
+ args.img_size = 96
41
+
42
+ def get_smoothened_boxes(boxes, T):
43
+ for i in range(len(boxes)):
44
+ if i + T > len(boxes):
45
+ window = boxes[len(boxes) - T:]
46
+ else:
47
+ window = boxes[i : i + T]
48
+ boxes[i] = np.mean(window, axis=0)
49
+ return boxes
50
+
51
+ def rescale_frames(images):
52
+ rect = detector.get_detections_for_batch(np.array([images[0]]))[0]
53
+ if rect is None:
54
+ raise ValueError('Face not detected!')
55
+ h, w = images[0].shape[:-1]
56
+
57
+ x1, y1, x2, y2 = rect
58
+
59
+ face_size = max(np.abs(y1 - y2), np.abs(x1 - x2))
60
+
61
+ diff = np.abs(face_size - args.face_res)
62
+ for factor in range(2, 16):
63
+ downsampled_res = face_size // factor
64
+ if min(h//factor, w//factor) < args.min_frame_res: break
65
+ if np.abs(downsampled_res - args.face_res) >= diff: break
66
+
67
+ factor -= 1
68
+ if factor == 1: return images
69
+
70
+ return [cv2.resize(im, (im.shape[1]//(factor), im.shape[0]//(factor))) for im in images]
71
+
72
+
73
+ def face_detect(images):
74
+ batch_size = args.face_det_batch_size
75
+ images = rescale_frames(images)
76
+
77
+ while 1:
78
+ predictions = []
79
+ try:
80
+ for i in range(0, len(images), batch_size):
81
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
82
+ except RuntimeError:
83
+ if batch_size == 1:
84
+ raise RuntimeError('Image too big to run face detection on GPU')
85
+ batch_size //= 2
86
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
87
+ continue
88
+ break
89
+
90
+ results = []
91
+ pady1, pady2, padx1, padx2 = args.pads
92
+ for rect, image in zip(predictions, images):
93
+ if rect is None:
94
+ raise ValueError('Face not detected!')
95
+
96
+ y1 = max(0, rect[1] - pady1)
97
+ y2 = min(image.shape[0], rect[3] + pady2)
98
+ x1 = max(0, rect[0] - padx1)
99
+ x2 = min(image.shape[1], rect[2] + padx2)
100
+
101
+ results.append([x1, y1, x2, y2])
102
+
103
+ boxes = get_smoothened_boxes(np.array(results), T=5)
104
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
105
+
106
+ return results, images
107
+
108
+ def datagen(frames, face_det_results, mels):
109
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
110
+
111
+ for i, m in enumerate(mels):
112
+ if i >= len(frames): raise ValueError('Equal or less lengths only')
113
+
114
+ frame_to_save = frames[i].copy()
115
+ face, coords, valid_frame = face_det_results[i].copy()
116
+ if not valid_frame:
117
+ continue
118
+
119
+ face = cv2.resize(face, (args.img_size, args.img_size))
120
+
121
+ img_batch.append(face)
122
+ mel_batch.append(m)
123
+ frame_batch.append(frame_to_save)
124
+ coords_batch.append(coords)
125
+
126
+ if len(img_batch) >= args.wav2lip_batch_size:
127
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
128
+
129
+ img_masked = img_batch.copy()
130
+ img_masked[:, args.img_size//2:] = 0
131
+
132
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
133
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
134
+
135
+ yield img_batch, mel_batch, frame_batch, coords_batch
136
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
137
+
138
+ if len(img_batch) > 0:
139
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
140
+
141
+ img_masked = img_batch.copy()
142
+ img_masked[:, args.img_size//2:] = 0
143
+
144
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
145
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
146
+
147
+ yield img_batch, mel_batch, frame_batch, coords_batch
148
+
149
+ def increase_frames(frames, l):
150
+ ## evenly duplicating frames to increase length of video
151
+ while len(frames) < l:
152
+ dup_every = float(l) / len(frames)
153
+
154
+ final_frames = []
155
+ next_duplicate = 0.
156
+
157
+ for i, f in enumerate(frames):
158
+ final_frames.append(f)
159
+
160
+ if int(np.ceil(next_duplicate)) == i:
161
+ final_frames.append(f)
162
+
163
+ next_duplicate += dup_every
164
+
165
+ frames = final_frames
166
+
167
+ return frames[:l]
168
+
169
+ mel_step_size = 16
170
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
171
+ print('Using {} for inference.'.format(device))
172
+
173
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
174
+ flip_input=False, device=device)
175
+
176
+ def _load(checkpoint_path):
177
+ if device == 'cuda':
178
+ checkpoint = torch.load(checkpoint_path)
179
+ else:
180
+ checkpoint = torch.load(checkpoint_path,
181
+ map_location=lambda storage, loc: storage)
182
+ return checkpoint
183
+
184
+ def load_model(path):
185
+ model = Wav2Lip()
186
+ print("Load checkpoint from: {}".format(path))
187
+ checkpoint = _load(path)
188
+ s = checkpoint["state_dict"]
189
+ new_s = {}
190
+ for k, v in s.items():
191
+ new_s[k.replace('module.', '')] = v
192
+ model.load_state_dict(new_s)
193
+
194
+ model = model.to(device)
195
+ return model.eval()
196
+
197
+ model = load_model(args.checkpoint_path)
198
+
199
+ def main():
200
+ if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)
201
+
202
+ if args.mode == 'dubbed':
203
+ files = listdir(args.data_root)
204
+ lines = ['{} {}'.format(f, f) for f in files]
205
+
206
+ else:
207
+ assert args.filelist is not None
208
+ with open(args.filelist, 'r') as filelist:
209
+ lines = filelist.readlines()
210
+
211
+ for idx, line in enumerate(tqdm(lines)):
212
+ video, audio_src = line.strip().split()
213
+
214
+ audio_src = os.path.join(args.data_root, audio_src)
215
+ video = os.path.join(args.data_root, video)
216
+
217
+ command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
218
+ subprocess.call(command, shell=True)
219
+ temp_audio = '../temp/temp.wav'
220
+
221
+ wav = audio.load_wav(temp_audio, 16000)
222
+ mel = audio.melspectrogram(wav)
223
+
224
+ if np.isnan(mel.reshape(-1)).sum() > 0:
225
+ raise ValueError('Mel contains nan!')
226
+
227
+ video_stream = cv2.VideoCapture(video)
228
+
229
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
230
+ mel_idx_multiplier = 80./fps
231
+
232
+ full_frames = []
233
+ while 1:
234
+ still_reading, frame = video_stream.read()
235
+ if not still_reading:
236
+ video_stream.release()
237
+ break
238
+
239
+ if min(frame.shape[:-1]) > args.max_frame_res:
240
+ h, w = frame.shape[:-1]
241
+ scale_factor = min(h, w) / float(args.max_frame_res)
242
+ h = int(h/scale_factor)
243
+ w = int(w/scale_factor)
244
+
245
+ frame = cv2.resize(frame, (w, h))
246
+ full_frames.append(frame)
247
+
248
+ mel_chunks = []
249
+ i = 0
250
+ while 1:
251
+ start_idx = int(i * mel_idx_multiplier)
252
+ if start_idx + mel_step_size > len(mel[0]):
253
+ break
254
+ mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
255
+ i += 1
256
+
257
+ if len(full_frames) < len(mel_chunks):
258
+ if args.mode == 'tts':
259
+ full_frames = increase_frames(full_frames, len(mel_chunks))
260
+ else:
261
+ raise ValueError('#Frames, audio length mismatch')
262
+
263
+ else:
264
+ full_frames = full_frames[:len(mel_chunks)]
265
+
266
+ try:
267
+ face_det_results, full_frames = face_detect(full_frames.copy())
268
+ except ValueError as e:
269
+ continue
270
+
271
+ batch_size = args.wav2lip_batch_size
272
+ gen = datagen(full_frames.copy(), face_det_results, mel_chunks)
273
+
274
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
275
+ if i == 0:
276
+ frame_h, frame_w = full_frames[0].shape[:-1]
277
+
278
+ out = cv2.VideoWriter('../temp/result.avi',
279
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
280
+
281
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
282
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
283
+
284
+ with torch.no_grad():
285
+ pred = model(mel_batch, img_batch)
286
+
287
+
288
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
289
+
290
+ for pl, f, c in zip(pred, frames, coords):
291
+ y1, y2, x1, x2 = c
292
+ pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
293
+ f[y1:y2, x1:x2] = pl
294
+ out.write(f)
295
+
296
+ out.release()
297
+
298
+ vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
299
+ command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format('../temp/temp.wav',
300
+ '../temp/result.avi', vid)
301
+ subprocess.call(command, shell=True)
302
+
303
+
304
+ if __name__ == '__main__':
305
+ main()
evaluation/scores_LSE/SyncNetInstance_calc_scores.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+ # Video 25 FPS, Audio 16000HZ
4
+
5
+ import torch
6
+ import numpy
7
+ import time, pdb, argparse, subprocess, os, math, glob
8
+ import cv2
9
+ import python_speech_features
10
+
11
+ from scipy import signal
12
+ from scipy.io import wavfile
13
+ from SyncNetModel import *
14
+ from shutil import rmtree
15
+
16
+
17
+ # ==================== Get OFFSET ====================
18
+
19
+ def calc_pdist(feat1, feat2, vshift=10):
20
+
21
+ win_size = vshift*2+1
22
+
23
+ feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift))
24
+
25
+ dists = []
26
+
27
+ for i in range(0,len(feat1)):
28
+
29
+ dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:]))
30
+
31
+ return dists
32
+
33
+ # ==================== MAIN DEF ====================
34
+
35
+ class SyncNetInstance(torch.nn.Module):
36
+
37
+ def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024):
38
+ super(SyncNetInstance, self).__init__();
39
+
40
+ self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda();
41
+
42
+ def evaluate(self, opt, videofile):
43
+
44
+ self.__S__.eval();
45
+
46
+ # ========== ==========
47
+ # Convert files
48
+ # ========== ==========
49
+
50
+ if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)):
51
+ rmtree(os.path.join(opt.tmp_dir,opt.reference))
52
+
53
+ os.makedirs(os.path.join(opt.tmp_dir,opt.reference))
54
+
55
+ command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg')))
56
+ output = subprocess.call(command, shell=True, stdout=None)
57
+
58
+ command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav')))
59
+ output = subprocess.call(command, shell=True, stdout=None)
60
+
61
+ # ========== ==========
62
+ # Load video
63
+ # ========== ==========
64
+
65
+ images = []
66
+
67
+ flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg'))
68
+ flist.sort()
69
+
70
+ for fname in flist:
71
+ img_input = cv2.imread(fname)
72
+ img_input = cv2.resize(img_input, (224,224)) #HARD CODED, CHANGE BEFORE RELEASE
73
+ images.append(img_input)
74
+
75
+ im = numpy.stack(images,axis=3)
76
+ im = numpy.expand_dims(im,axis=0)
77
+ im = numpy.transpose(im,(0,3,4,1,2))
78
+
79
+ imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
80
+
81
+ # ========== ==========
82
+ # Load audio
83
+ # ========== ==========
84
+
85
+ sample_rate, audio = wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))
86
+ mfcc = zip(*python_speech_features.mfcc(audio,sample_rate))
87
+ mfcc = numpy.stack([numpy.array(i) for i in mfcc])
88
+
89
+ cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0)
90
+ cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())
91
+
92
+ # ========== ==========
93
+ # Check audio and video input length
94
+ # ========== ==========
95
+
96
+ #if (float(len(audio))/16000) != (float(len(images))/25) :
97
+ # print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25))
98
+
99
+ min_length = min(len(images),math.floor(len(audio)/640))
100
+
101
+ # ========== ==========
102
+ # Generate video and audio feats
103
+ # ========== ==========
104
+
105
+ lastframe = min_length-5
106
+ im_feat = []
107
+ cc_feat = []
108
+
109
+ tS = time.time()
110
+ for i in range(0,lastframe,opt.batch_size):
111
+
112
+ im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
113
+ im_in = torch.cat(im_batch,0)
114
+ im_out = self.__S__.forward_lip(im_in.cuda());
115
+ im_feat.append(im_out.data.cpu())
116
+
117
+ cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
118
+ cc_in = torch.cat(cc_batch,0)
119
+ cc_out = self.__S__.forward_aud(cc_in.cuda())
120
+ cc_feat.append(cc_out.data.cpu())
121
+
122
+ im_feat = torch.cat(im_feat,0)
123
+ cc_feat = torch.cat(cc_feat,0)
124
+
125
+ # ========== ==========
126
+ # Compute offset
127
+ # ========== ==========
128
+
129
+ #print('Compute time %.3f sec.' % (time.time()-tS))
130
+
131
+ dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift)
132
+ mdist = torch.mean(torch.stack(dists,1),1)
133
+
134
+ minval, minidx = torch.min(mdist,0)
135
+
136
+ offset = opt.vshift-minidx
137
+ conf = torch.median(mdist) - minval
138
+
139
+ fdist = numpy.stack([dist[minidx].numpy() for dist in dists])
140
+ # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
141
+ fconf = torch.median(mdist).numpy() - fdist
142
+ fconfm = signal.medfilt(fconf,kernel_size=9)
143
+
144
+ numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format})
145
+ #print('Framewise conf: ')
146
+ #print(fconfm)
147
+ #print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf))
148
+
149
+ dists_npy = numpy.array([ dist.numpy() for dist in dists ])
150
+ return offset.numpy(), conf.numpy(), minval.numpy()
151
+
152
+ def extract_feature(self, opt, videofile):
153
+
154
+ self.__S__.eval();
155
+
156
+ # ========== ==========
157
+ # Load video
158
+ # ========== ==========
159
+ cap = cv2.VideoCapture(videofile)
160
+
161
+ frame_num = 1;
162
+ images = []
163
+ while frame_num:
164
+ frame_num += 1
165
+ ret, image = cap.read()
166
+ if ret == 0:
167
+ break
168
+
169
+ images.append(image)
170
+
171
+ im = numpy.stack(images,axis=3)
172
+ im = numpy.expand_dims(im,axis=0)
173
+ im = numpy.transpose(im,(0,3,4,1,2))
174
+
175
+ imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
176
+
177
+ # ========== ==========
178
+ # Generate video feats
179
+ # ========== ==========
180
+
181
+ lastframe = len(images)-4
182
+ im_feat = []
183
+
184
+ tS = time.time()
185
+ for i in range(0,lastframe,opt.batch_size):
186
+
187
+ im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
188
+ im_in = torch.cat(im_batch,0)
189
+ im_out = self.__S__.forward_lipfeat(im_in.cuda());
190
+ im_feat.append(im_out.data.cpu())
191
+
192
+ im_feat = torch.cat(im_feat,0)
193
+
194
+ # ========== ==========
195
+ # Compute offset
196
+ # ========== ==========
197
+
198
+ print('Compute time %.3f sec.' % (time.time()-tS))
199
+
200
+ return im_feat
201
+
202
+
203
+ def loadParameters(self, path):
204
+ loaded_state = torch.load(path, map_location=lambda storage, loc: storage);
205
+
206
+ self_state = self.__S__.state_dict();
207
+
208
+ for name, param in loaded_state.items():
209
+
210
+ self_state[name].copy_(param);
evaluation/scores_LSE/calculate_scores_LRS.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+
4
+ import time, pdb, argparse, subprocess
5
+ import glob
6
+ import os
7
+ from tqdm import tqdm
8
+
9
+ from SyncNetInstance_calc_scores import *
10
+
11
+ # ==================== LOAD PARAMS ====================
12
+
13
+
14
+ parser = argparse.ArgumentParser(description = "SyncNet");
15
+
16
+ parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help='');
17
+ parser.add_argument('--batch_size', type=int, default='20', help='');
18
+ parser.add_argument('--vshift', type=int, default='15', help='');
19
+ parser.add_argument('--data_root', type=str, required=True, help='');
20
+ parser.add_argument('--tmp_dir', type=str, default="data/work/pytmp", help='');
21
+ parser.add_argument('--reference', type=str, default="demo", help='');
22
+
23
+ opt = parser.parse_args();
24
+
25
+
26
+ # ==================== RUN EVALUATION ====================
27
+
28
+ s = SyncNetInstance();
29
+
30
+ s.loadParameters(opt.initial_model);
31
+ #print("Model %s loaded."%opt.initial_model);
32
+ path = os.path.join(opt.data_root, "*.mp4")
33
+
34
+ all_videos = glob.glob(path)
35
+
36
+ prog_bar = tqdm(range(len(all_videos)))
37
+ avg_confidence = 0.
38
+ avg_min_distance = 0.
39
+
40
+
41
+ for videofile_idx in prog_bar:
42
+ videofile = all_videos[videofile_idx]
43
+ offset, confidence, min_distance = s.evaluate(opt, videofile=videofile)
44
+ avg_confidence += confidence
45
+ avg_min_distance += min_distance
46
+ prog_bar.set_description('Avg Confidence: {}, Avg Minimum Dist: {}'.format(round(avg_confidence / (videofile_idx + 1), 3), round(avg_min_distance / (videofile_idx + 1), 3)))
47
+ prog_bar.refresh()
48
+
49
+ print ('Average Confidence: {}'.format(avg_confidence/len(all_videos)))
50
+ print ('Average Minimum Distance: {}'.format(avg_min_distance/len(all_videos)))
51
+
52
+
53
+
evaluation/scores_LSE/calculate_scores_real_videos.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #-*- coding: utf-8 -*-
3
+
4
+ import time, pdb, argparse, subprocess, pickle, os, gzip, glob
5
+
6
+ from SyncNetInstance_calc_scores import *
7
+
8
+ # ==================== PARSE ARGUMENT ====================
9
+
10
+ parser = argparse.ArgumentParser(description = "SyncNet");
11
+ parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help='');
12
+ parser.add_argument('--batch_size', type=int, default='20', help='');
13
+ parser.add_argument('--vshift', type=int, default='15', help='');
14
+ parser.add_argument('--data_dir', type=str, default='data/work', help='');
15
+ parser.add_argument('--videofile', type=str, default='', help='');
16
+ parser.add_argument('--reference', type=str, default='', help='');
17
+ opt = parser.parse_args();
18
+
19
+ setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi'))
20
+ setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp'))
21
+ setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork'))
22
+ setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop'))
23
+
24
+
25
+ # ==================== LOAD MODEL AND FILE LIST ====================
26
+
27
+ s = SyncNetInstance();
28
+
29
+ s.loadParameters(opt.initial_model);
30
+ #print("Model %s loaded."%opt.initial_model);
31
+
32
+ flist = glob.glob(os.path.join(opt.crop_dir,opt.reference,'0*.avi'))
33
+ flist.sort()
34
+
35
+ # ==================== GET OFFSETS ====================
36
+
37
+ dists = []
38
+ for idx, fname in enumerate(flist):
39
+ offset, conf, dist = s.evaluate(opt,videofile=fname)
40
+ print (str(dist)+" "+str(conf))
41
+
42
+ # ==================== PRINT RESULTS TO FILE ====================
43
+
44
+ #with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'wb') as fil:
45
+ # pickle.dump(dists, fil)
evaluation/scores_LSE/calculate_scores_real_videos.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ rm all_scores.txt
2
+ yourfilenames=`ls $1`
3
+
4
+ for eachfile in $yourfilenames
5
+ do
6
+ python run_pipeline.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir
7
+ python calculate_scores_real_videos.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir >> all_scores.txt
8
+ done
evaluation/test_filelists/README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This folder contains the filelists for the new evaluation framework proposed in the paper.
2
+
3
+ ## Test filelists for LRS2, LRS3, and LRW.
4
+
5
+ This folder contains three filelists, each containing a list of names of audio-video pairs from the test sets of LRS2, LRS3, and LRW. The LRS2 and LRW filelists are strictly "Copyright BBC" and can only be used for “non-commercial research by applicants who have an agreement with the BBC to access the Lip Reading in the Wild and/or Lip Reading Sentences in the Wild datasets”. Please follow this link for more details: [https://www.bbc.co.uk/rd/projects/lip-reading-datasets](https://www.bbc.co.uk/rd/projects/lip-reading-datasets).
6
+
7
+
8
+ ## ReSynCED benchmark
9
+
10
+ The sub-folder `ReSynCED` contains filelists for our own Real-world lip-Sync Evaluation Dataset (ReSyncED).
11
+
12
+
13
+ #### Instructions on how to use the above two filelists are available in the README of the parent folder.
evaluation/test_filelists/ReSyncED/random_pairs.txt ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sachin.mp4 emma_cropped.mp4
2
+ sachin.mp4 mourinho.mp4
3
+ sachin.mp4 elon.mp4
4
+ sachin.mp4 messi2.mp4
5
+ sachin.mp4 cr1.mp4
6
+ sachin.mp4 sachin.mp4
7
+ sachin.mp4 sg.mp4
8
+ sachin.mp4 fergi.mp4
9
+ sachin.mp4 spanish_lec1.mp4
10
+ sachin.mp4 bush_small.mp4
11
+ sachin.mp4 macca_cut.mp4
12
+ sachin.mp4 ca_cropped.mp4
13
+ sachin.mp4 lecun.mp4
14
+ sachin.mp4 spanish_lec0.mp4
15
+ srk.mp4 emma_cropped.mp4
16
+ srk.mp4 mourinho.mp4
17
+ srk.mp4 elon.mp4
18
+ srk.mp4 messi2.mp4
19
+ srk.mp4 cr1.mp4
20
+ srk.mp4 srk.mp4
21
+ srk.mp4 sachin.mp4
22
+ srk.mp4 sg.mp4
23
+ srk.mp4 fergi.mp4
24
+ srk.mp4 spanish_lec1.mp4
25
+ srk.mp4 bush_small.mp4
26
+ srk.mp4 macca_cut.mp4
27
+ srk.mp4 ca_cropped.mp4
28
+ srk.mp4 guardiola.mp4
29
+ srk.mp4 lecun.mp4
30
+ srk.mp4 spanish_lec0.mp4
31
+ cr1.mp4 emma_cropped.mp4
32
+ cr1.mp4 elon.mp4
33
+ cr1.mp4 messi2.mp4
34
+ cr1.mp4 cr1.mp4
35
+ cr1.mp4 spanish_lec1.mp4
36
+ cr1.mp4 bush_small.mp4
37
+ cr1.mp4 macca_cut.mp4
38
+ cr1.mp4 ca_cropped.mp4
39
+ cr1.mp4 lecun.mp4
40
+ cr1.mp4 spanish_lec0.mp4
41
+ macca_cut.mp4 emma_cropped.mp4
42
+ macca_cut.mp4 elon.mp4
43
+ macca_cut.mp4 messi2.mp4
44
+ macca_cut.mp4 spanish_lec1.mp4
45
+ macca_cut.mp4 macca_cut.mp4
46
+ macca_cut.mp4 ca_cropped.mp4
47
+ macca_cut.mp4 spanish_lec0.mp4
48
+ lecun.mp4 emma_cropped.mp4
49
+ lecun.mp4 elon.mp4
50
+ lecun.mp4 messi2.mp4
51
+ lecun.mp4 spanish_lec1.mp4
52
+ lecun.mp4 macca_cut.mp4
53
+ lecun.mp4 ca_cropped.mp4
54
+ lecun.mp4 lecun.mp4
55
+ lecun.mp4 spanish_lec0.mp4
56
+ messi2.mp4 emma_cropped.mp4
57
+ messi2.mp4 elon.mp4
58
+ messi2.mp4 messi2.mp4
59
+ messi2.mp4 spanish_lec1.mp4
60
+ messi2.mp4 macca_cut.mp4
61
+ messi2.mp4 ca_cropped.mp4
62
+ messi2.mp4 spanish_lec0.mp4
63
+ ca_cropped.mp4 emma_cropped.mp4
64
+ ca_cropped.mp4 elon.mp4
65
+ ca_cropped.mp4 spanish_lec1.mp4
66
+ ca_cropped.mp4 ca_cropped.mp4
67
+ ca_cropped.mp4 spanish_lec0.mp4
68
+ spanish_lec1.mp4 spanish_lec1.mp4
69
+ spanish_lec1.mp4 spanish_lec0.mp4
70
+ elon.mp4 elon.mp4
71
+ elon.mp4 spanish_lec1.mp4
72
+ elon.mp4 spanish_lec0.mp4
73
+ guardiola.mp4 emma_cropped.mp4
74
+ guardiola.mp4 mourinho.mp4
75
+ guardiola.mp4 elon.mp4
76
+ guardiola.mp4 messi2.mp4
77
+ guardiola.mp4 cr1.mp4
78
+ guardiola.mp4 sachin.mp4
79
+ guardiola.mp4 sg.mp4
80
+ guardiola.mp4 fergi.mp4
81
+ guardiola.mp4 spanish_lec1.mp4
82
+ guardiola.mp4 bush_small.mp4
83
+ guardiola.mp4 macca_cut.mp4
84
+ guardiola.mp4 ca_cropped.mp4
85
+ guardiola.mp4 guardiola.mp4
86
+ guardiola.mp4 lecun.mp4
87
+ guardiola.mp4 spanish_lec0.mp4
88
+ fergi.mp4 emma_cropped.mp4
89
+ fergi.mp4 mourinho.mp4
90
+ fergi.mp4 elon.mp4
91
+ fergi.mp4 messi2.mp4
92
+ fergi.mp4 cr1.mp4
93
+ fergi.mp4 sachin.mp4
94
+ fergi.mp4 sg.mp4
95
+ fergi.mp4 fergi.mp4
96
+ fergi.mp4 spanish_lec1.mp4
97
+ fergi.mp4 bush_small.mp4
98
+ fergi.mp4 macca_cut.mp4
99
+ fergi.mp4 ca_cropped.mp4
100
+ fergi.mp4 lecun.mp4
101
+ fergi.mp4 spanish_lec0.mp4
102
+ spanish.mp4 emma_cropped.mp4
103
+ spanish.mp4 spanish.mp4
104
+ spanish.mp4 mourinho.mp4
105
+ spanish.mp4 elon.mp4
106
+ spanish.mp4 messi2.mp4
107
+ spanish.mp4 cr1.mp4
108
+ spanish.mp4 srk.mp4
109
+ spanish.mp4 sachin.mp4
110
+ spanish.mp4 sg.mp4
111
+ spanish.mp4 fergi.mp4
112
+ spanish.mp4 spanish_lec1.mp4
113
+ spanish.mp4 bush_small.mp4
114
+ spanish.mp4 macca_cut.mp4
115
+ spanish.mp4 ca_cropped.mp4
116
+ spanish.mp4 guardiola.mp4
117
+ spanish.mp4 lecun.mp4
118
+ spanish.mp4 spanish_lec0.mp4
119
+ bush_small.mp4 emma_cropped.mp4
120
+ bush_small.mp4 elon.mp4
121
+ bush_small.mp4 messi2.mp4
122
+ bush_small.mp4 spanish_lec1.mp4
123
+ bush_small.mp4 bush_small.mp4
124
+ bush_small.mp4 macca_cut.mp4
125
+ bush_small.mp4 ca_cropped.mp4
126
+ bush_small.mp4 lecun.mp4
127
+ bush_small.mp4 spanish_lec0.mp4
128
+ emma_cropped.mp4 emma_cropped.mp4
129
+ emma_cropped.mp4 elon.mp4
130
+ emma_cropped.mp4 spanish_lec1.mp4
131
+ emma_cropped.mp4 spanish_lec0.mp4
132
+ sg.mp4 emma_cropped.mp4
133
+ sg.mp4 mourinho.mp4
134
+ sg.mp4 elon.mp4
135
+ sg.mp4 messi2.mp4
136
+ sg.mp4 cr1.mp4
137
+ sg.mp4 sachin.mp4
138
+ sg.mp4 sg.mp4
139
+ sg.mp4 fergi.mp4
140
+ sg.mp4 spanish_lec1.mp4
141
+ sg.mp4 bush_small.mp4
142
+ sg.mp4 macca_cut.mp4
143
+ sg.mp4 ca_cropped.mp4
144
+ sg.mp4 lecun.mp4
145
+ sg.mp4 spanish_lec0.mp4
146
+ spanish_lec0.mp4 spanish_lec0.mp4
147
+ mourinho.mp4 emma_cropped.mp4
148
+ mourinho.mp4 mourinho.mp4
149
+ mourinho.mp4 elon.mp4
150
+ mourinho.mp4 messi2.mp4
151
+ mourinho.mp4 cr1.mp4
152
+ mourinho.mp4 sachin.mp4
153
+ mourinho.mp4 sg.mp4
154
+ mourinho.mp4 fergi.mp4
155
+ mourinho.mp4 spanish_lec1.mp4
156
+ mourinho.mp4 bush_small.mp4
157
+ mourinho.mp4 macca_cut.mp4
158
+ mourinho.mp4 ca_cropped.mp4
159
+ mourinho.mp4 lecun.mp4
160
+ mourinho.mp4 spanish_lec0.mp4
evaluation/test_filelists/ReSyncED/tts_pairs.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_1.mp4 andreng_optimization.wav
2
+ agad_2.mp4 agad_2.wav
3
+ agad_1.mp4 agad_1.wav
4
+ agad_3.mp4 agad_3.wav
5
+ rms_prop_1.mp4 rms_prop_tts.wav
6
+ tf_1.mp4 tf_1.wav
7
+ tf_2.mp4 tf_2.wav
8
+ andrew_ng_ai_business.mp4 andrewng_business_tts.wav
9
+ covid_autopsy_1.mp4 autopsy_tts.wav
10
+ news_1.mp4 news_tts.wav
11
+ andrew_ng_fund_1.mp4 andrewng_ai_fund.wav
12
+ covid_treatments_1.mp4 covid_tts.wav
13
+ pytorch_v_tf.mp4 pytorch_vs_tf_eng.wav
14
+ pytorch_1.mp4 pytorch.wav
15
+ pkb_1.mp4 pkb_1.wav
16
+ ss_1.mp4 ss_1.wav
17
+ carlsen_1.mp4 carlsen_eng.wav
18
+ french.mp4 french.wav
evaluation/test_filelists/lrs2.txt ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/test_filelists/lrs3.txt ADDED
The diff for this file is too large to render. See raw diff
 
evaluation/test_filelists/lrw.txt ADDED
The diff for this file is too large to render. See raw diff
 
filelists/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Place LRS2 (and any other) filelists here for training.
hq_wav2lip_train.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import dirname, join, basename, isfile
2
+ from tqdm import tqdm
3
+
4
+ from models import SyncNet_color as SyncNet
5
+ from models import Wav2Lip, Wav2Lip_disc_qual
6
+ import audio
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from torch import optim
12
+ import torch.backends.cudnn as cudnn
13
+ from torch.utils import data as data_utils
14
+ import numpy as np
15
+
16
+ from glob import glob
17
+
18
+ import os, random, cv2, argparse
19
+ from hparams import hparams, get_image_list
20
+
21
+ parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator')
22
+
23
+ parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
24
+
25
+ parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
26
+ parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
27
+
28
+ parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
29
+ parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)
30
+
31
+ args = parser.parse_args()
32
+
33
+
34
+ global_step = 0
35
+ global_epoch = 0
36
+ use_cuda = torch.cuda.is_available()
37
+ print('use_cuda: {}'.format(use_cuda))
38
+
39
+ syncnet_T = 5
40
+ syncnet_mel_step_size = 16
41
+
42
+ class Dataset(object):
43
+ def __init__(self, split):
44
+ self.all_videos = get_image_list(args.data_root, split)
45
+
46
+ def get_frame_id(self, frame):
47
+ return int(basename(frame).split('.')[0])
48
+
49
+ def get_window(self, start_frame):
50
+ start_id = self.get_frame_id(start_frame)
51
+ vidname = dirname(start_frame)
52
+
53
+ window_fnames = []
54
+ for frame_id in range(start_id, start_id + syncnet_T):
55
+ frame = join(vidname, '{}.jpg'.format(frame_id))
56
+ if not isfile(frame):
57
+ return None
58
+ window_fnames.append(frame)
59
+ return window_fnames
60
+
61
+ def read_window(self, window_fnames):
62
+ if window_fnames is None: return None
63
+ window = []
64
+ for fname in window_fnames:
65
+ img = cv2.imread(fname)
66
+ if img is None:
67
+ return None
68
+ try:
69
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
70
+ except Exception as e:
71
+ return None
72
+
73
+ window.append(img)
74
+
75
+ return window
76
+
77
+ def crop_audio_window(self, spec, start_frame):
78
+ if type(start_frame) == int:
79
+ start_frame_num = start_frame
80
+ else:
81
+ start_frame_num = self.get_frame_id(start_frame)
82
+ start_idx = int(80. * (start_frame_num / float(hparams.fps)))
83
+
84
+ end_idx = start_idx + syncnet_mel_step_size
85
+
86
+ return spec[start_idx : end_idx, :]
87
+
88
+ def get_segmented_mels(self, spec, start_frame):
89
+ mels = []
90
+ assert syncnet_T == 5
91
+ start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
92
+ if start_frame_num - 2 < 0: return None
93
+ for i in range(start_frame_num, start_frame_num + syncnet_T):
94
+ m = self.crop_audio_window(spec, i - 2)
95
+ if m.shape[0] != syncnet_mel_step_size:
96
+ return None
97
+ mels.append(m.T)
98
+
99
+ mels = np.asarray(mels)
100
+
101
+ return mels
102
+
103
+ def prepare_window(self, window):
104
+ # 3 x T x H x W
105
+ x = np.asarray(window) / 255.
106
+ x = np.transpose(x, (3, 0, 1, 2))
107
+
108
+ return x
109
+
110
+ def __len__(self):
111
+ return len(self.all_videos)
112
+
113
+ def __getitem__(self, idx):
114
+ while 1:
115
+ idx = random.randint(0, len(self.all_videos) - 1)
116
+ vidname = self.all_videos[idx]
117
+ img_names = list(glob(join(vidname, '*.jpg')))
118
+ if len(img_names) <= 3 * syncnet_T:
119
+ continue
120
+
121
+ img_name = random.choice(img_names)
122
+ wrong_img_name = random.choice(img_names)
123
+ while wrong_img_name == img_name:
124
+ wrong_img_name = random.choice(img_names)
125
+
126
+ window_fnames = self.get_window(img_name)
127
+ wrong_window_fnames = self.get_window(wrong_img_name)
128
+ if window_fnames is None or wrong_window_fnames is None:
129
+ continue
130
+
131
+ window = self.read_window(window_fnames)
132
+ if window is None:
133
+ continue
134
+
135
+ wrong_window = self.read_window(wrong_window_fnames)
136
+ if wrong_window is None:
137
+ continue
138
+
139
+ try:
140
+ wavpath = join(vidname, "audio.wav")
141
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
142
+
143
+ orig_mel = audio.melspectrogram(wav).T
144
+ except Exception as e:
145
+ continue
146
+
147
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
148
+
149
+ if (mel.shape[0] != syncnet_mel_step_size):
150
+ continue
151
+
152
+ indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
153
+ if indiv_mels is None: continue
154
+
155
+ window = self.prepare_window(window)
156
+ y = window.copy()
157
+ window[:, :, window.shape[2]//2:] = 0.
158
+
159
+ wrong_window = self.prepare_window(wrong_window)
160
+ x = np.concatenate([window, wrong_window], axis=0)
161
+
162
+ x = torch.FloatTensor(x)
163
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
164
+ indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
165
+ y = torch.FloatTensor(y)
166
+ return x, indiv_mels, mel, y
167
+
168
+ def save_sample_images(x, g, gt, global_step, checkpoint_dir):
169
+ x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
170
+ g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
171
+ gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
172
+
173
+ refs, inps = x[..., 3:], x[..., :3]
174
+ folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
175
+ if not os.path.exists(folder): os.mkdir(folder)
176
+ collage = np.concatenate((refs, inps, g, gt), axis=-2)
177
+ for batch_idx, c in enumerate(collage):
178
+ for t in range(len(c)):
179
+ cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
180
+
181
+ logloss = nn.BCELoss()
182
+ def cosine_loss(a, v, y):
183
+ d = nn.functional.cosine_similarity(a, v)
184
+ loss = logloss(d.unsqueeze(1), y)
185
+
186
+ return loss
187
+
188
+ device = torch.device("cuda" if use_cuda else "cpu")
189
+ syncnet = SyncNet().to(device)
190
+ for p in syncnet.parameters():
191
+ p.requires_grad = False
192
+
193
+ recon_loss = nn.L1Loss()
194
+ def get_sync_loss(mel, g):
195
+ g = g[:, :, :, g.size(3)//2:]
196
+ g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
197
+ # B, 3 * T, H//2, W
198
+ a, v = syncnet(mel, g)
199
+ y = torch.ones(g.size(0), 1).float().to(device)
200
+ return cosine_loss(a, v, y)
201
+
202
+ def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
203
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
204
+ global global_step, global_epoch
205
+ resumed_step = global_step
206
+
207
+ while global_epoch < nepochs:
208
+ print('Starting Epoch: {}'.format(global_epoch))
209
+ running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
210
+ running_disc_real_loss, running_disc_fake_loss = 0., 0.
211
+ prog_bar = tqdm(enumerate(train_data_loader))
212
+ for step, (x, indiv_mels, mel, gt) in prog_bar:
213
+ disc.train()
214
+ model.train()
215
+
216
+ x = x.to(device)
217
+ mel = mel.to(device)
218
+ indiv_mels = indiv_mels.to(device)
219
+ gt = gt.to(device)
220
+
221
+ ### Train generator now. Remove ALL grads.
222
+ optimizer.zero_grad()
223
+ disc_optimizer.zero_grad()
224
+
225
+ g = model(indiv_mels, x)
226
+
227
+ if hparams.syncnet_wt > 0.:
228
+ sync_loss = get_sync_loss(mel, g)
229
+ else:
230
+ sync_loss = 0.
231
+
232
+ if hparams.disc_wt > 0.:
233
+ perceptual_loss = disc.perceptual_forward(g)
234
+ else:
235
+ perceptual_loss = 0.
236
+
237
+ l1loss = recon_loss(g, gt)
238
+
239
+ loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
240
+ (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
241
+
242
+ loss.backward()
243
+ optimizer.step()
244
+
245
+ ### Remove all gradients before Training disc
246
+ disc_optimizer.zero_grad()
247
+
248
+ pred = disc(gt)
249
+ disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
250
+ disc_real_loss.backward()
251
+
252
+ pred = disc(g.detach())
253
+ disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
254
+ disc_fake_loss.backward()
255
+
256
+ disc_optimizer.step()
257
+
258
+ running_disc_real_loss += disc_real_loss.item()
259
+ running_disc_fake_loss += disc_fake_loss.item()
260
+
261
+ if global_step % checkpoint_interval == 0:
262
+ save_sample_images(x, g, gt, global_step, checkpoint_dir)
263
+
264
+ # Logs
265
+ global_step += 1
266
+ cur_session_steps = global_step - resumed_step
267
+
268
+ running_l1_loss += l1loss.item()
269
+ if hparams.syncnet_wt > 0.:
270
+ running_sync_loss += sync_loss.item()
271
+ else:
272
+ running_sync_loss += 0.
273
+
274
+ if hparams.disc_wt > 0.:
275
+ running_perceptual_loss += perceptual_loss.item()
276
+ else:
277
+ running_perceptual_loss += 0.
278
+
279
+ if global_step == 1 or global_step % checkpoint_interval == 0:
280
+ save_checkpoint(
281
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
282
+ save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')
283
+
284
+
285
+ if global_step % hparams.eval_interval == 0:
286
+ with torch.no_grad():
287
+ average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)
288
+
289
+ if average_sync_loss < .75:
290
+ hparams.set_hparam('syncnet_wt', 0.03)
291
+
292
+ prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
293
+ running_sync_loss / (step + 1),
294
+ running_perceptual_loss / (step + 1),
295
+ running_disc_fake_loss / (step + 1),
296
+ running_disc_real_loss / (step + 1)))
297
+
298
+ global_epoch += 1
299
+
300
+ def eval_model(test_data_loader, global_step, device, model, disc):
301
+ eval_steps = 300
302
+ print('Evaluating for {} steps'.format(eval_steps))
303
+ running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []
304
+ while 1:
305
+ for step, (x, indiv_mels, mel, gt) in enumerate((test_data_loader)):
306
+ model.eval()
307
+ disc.eval()
308
+
309
+ x = x.to(device)
310
+ mel = mel.to(device)
311
+ indiv_mels = indiv_mels.to(device)
312
+ gt = gt.to(device)
313
+
314
+ pred = disc(gt)
315
+ disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
316
+
317
+ g = model(indiv_mels, x)
318
+ pred = disc(g)
319
+ disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
320
+
321
+ running_disc_real_loss.append(disc_real_loss.item())
322
+ running_disc_fake_loss.append(disc_fake_loss.item())
323
+
324
+ sync_loss = get_sync_loss(mel, g)
325
+
326
+ if hparams.disc_wt > 0.:
327
+ perceptual_loss = disc.perceptual_forward(g)
328
+ else:
329
+ perceptual_loss = 0.
330
+
331
+ l1loss = recon_loss(g, gt)
332
+
333
+ loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
334
+ (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
335
+
336
+ running_l1_loss.append(l1loss.item())
337
+ running_sync_loss.append(sync_loss.item())
338
+
339
+ if hparams.disc_wt > 0.:
340
+ running_perceptual_loss.append(perceptual_loss.item())
341
+ else:
342
+ running_perceptual_loss.append(0.)
343
+
344
+ if step > eval_steps: break
345
+
346
+ print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss),
347
+ sum(running_sync_loss) / len(running_sync_loss),
348
+ sum(running_perceptual_loss) / len(running_perceptual_loss),
349
+ sum(running_disc_fake_loss) / len(running_disc_fake_loss),
350
+ sum(running_disc_real_loss) / len(running_disc_real_loss)))
351
+ return sum(running_sync_loss) / len(running_sync_loss)
352
+
353
+
354
+ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
355
+ checkpoint_path = join(
356
+ checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, global_step))
357
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
358
+ torch.save({
359
+ "state_dict": model.state_dict(),
360
+ "optimizer": optimizer_state,
361
+ "global_step": step,
362
+ "global_epoch": epoch,
363
+ }, checkpoint_path)
364
+ print("Saved checkpoint:", checkpoint_path)
365
+
366
+ def _load(checkpoint_path):
367
+ if use_cuda:
368
+ checkpoint = torch.load(checkpoint_path)
369
+ else:
370
+ checkpoint = torch.load(checkpoint_path,
371
+ map_location=lambda storage, loc: storage)
372
+ return checkpoint
373
+
374
+
375
+ def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
376
+ global global_step
377
+ global global_epoch
378
+
379
+ print("Load checkpoint from: {}".format(path))
380
+ checkpoint = _load(path)
381
+ s = checkpoint["state_dict"]
382
+ new_s = {}
383
+ for k, v in s.items():
384
+ new_s[k.replace('module.', '')] = v
385
+ model.load_state_dict(new_s)
386
+ if not reset_optimizer:
387
+ optimizer_state = checkpoint["optimizer"]
388
+ if optimizer_state is not None:
389
+ print("Load optimizer state from {}".format(path))
390
+ optimizer.load_state_dict(checkpoint["optimizer"])
391
+ if overwrite_global_states:
392
+ global_step = checkpoint["global_step"]
393
+ global_epoch = checkpoint["global_epoch"]
394
+
395
+ return model
396
+
397
+ if __name__ == "__main__":
398
+ checkpoint_dir = args.checkpoint_dir
399
+
400
+ # Dataset and Dataloader setup
401
+ train_dataset = Dataset('train')
402
+ test_dataset = Dataset('val')
403
+
404
+ train_data_loader = data_utils.DataLoader(
405
+ train_dataset, batch_size=hparams.batch_size, shuffle=True,
406
+ num_workers=hparams.num_workers)
407
+
408
+ test_data_loader = data_utils.DataLoader(
409
+ test_dataset, batch_size=hparams.batch_size,
410
+ num_workers=4)
411
+
412
+ device = torch.device("cuda" if use_cuda else "cpu")
413
+
414
+ # Model
415
+ model = Wav2Lip().to(device)
416
+ disc = Wav2Lip_disc_qual().to(device)
417
+
418
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
419
+ print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))
420
+
421
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
422
+ lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
423
+ disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
424
+ lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))
425
+
426
+ if args.checkpoint_path is not None:
427
+ load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
428
+
429
+ if args.disc_checkpoint_path is not None:
430
+ load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer,
431
+ reset_optimizer=False, overwrite_global_states=False)
432
+
433
+ load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True,
434
+ overwrite_global_states=False)
435
+
436
+ if not os.path.exists(checkpoint_dir):
437
+ os.mkdir(checkpoint_dir)
438
+
439
+ # Train!
440
+ train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
441
+ checkpoint_dir=checkpoint_dir,
442
+ checkpoint_interval=hparams.checkpoint_interval,
443
+ nepochs=hparams.nepochs)
preprocess.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ if sys.version_info[0] < 3 and sys.version_info[1] < 2:
4
+ raise Exception("Must be using >= Python 3.2")
5
+
6
+ from os import listdir, path
7
+
8
+ if not path.isfile('face_detection/detection/sfd/s3fd.pth'):
9
+ raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth \
10
+ before running this script!')
11
+
12
+ import multiprocessing as mp
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
+ import numpy as np
15
+ import argparse, os, cv2, traceback, subprocess
16
+ from tqdm import tqdm
17
+ from glob import glob
18
+ import audio
19
+ from hparams import hparams as hp
20
+
21
+ import face_detection
22
+
23
+ parser = argparse.ArgumentParser()
24
+
25
+ parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int)
26
+ parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int)
27
+ parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True)
28
+ parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True)
29
+
30
+ args = parser.parse_args()
31
+
32
+ fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
33
+ device='cuda:{}'.format(id)) for id in range(args.ngpu)]
34
+
35
+ template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
36
+ # template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'
37
+
38
+ def process_video_file(vfile, args, gpu_id):
39
+ video_stream = cv2.VideoCapture(vfile)
40
+
41
+ frames = []
42
+ while 1:
43
+ still_reading, frame = video_stream.read()
44
+ if not still_reading:
45
+ video_stream.release()
46
+ break
47
+ frames.append(frame)
48
+
49
+ vidname = os.path.basename(vfile).split('.')[0]
50
+ dirname = vfile.split('/')[-2]
51
+
52
+ fulldir = path.join(args.preprocessed_root, dirname, vidname)
53
+ os.makedirs(fulldir, exist_ok=True)
54
+
55
+ batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)]
56
+
57
+ i = -1
58
+ for fb in batches:
59
+ preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))
60
+
61
+ for j, f in enumerate(preds):
62
+ i += 1
63
+ if f is None:
64
+ continue
65
+
66
+ x1, y1, x2, y2 = f
67
+ cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])
68
+
69
+ def process_audio_file(vfile, args):
70
+ vidname = os.path.basename(vfile).split('.')[0]
71
+ dirname = vfile.split('/')[-2]
72
+
73
+ fulldir = path.join(args.preprocessed_root, dirname, vidname)
74
+ os.makedirs(fulldir, exist_ok=True)
75
+
76
+ wavpath = path.join(fulldir, 'audio.wav')
77
+
78
+ command = template.format(vfile, wavpath)
79
+ subprocess.call(command, shell=True)
80
+
81
+
82
+ def mp_handler(job):
83
+ vfile, args, gpu_id = job
84
+ try:
85
+ process_video_file(vfile, args, gpu_id)
86
+ except KeyboardInterrupt:
87
+ exit(0)
88
+ except:
89
+ traceback.print_exc()
90
+
91
+ def main(args):
92
+ print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu))
93
+
94
+ filelist = glob(path.join(args.data_root, '*/*.mp4'))
95
+
96
+ jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)]
97
+ p = ThreadPoolExecutor(args.ngpu)
98
+ futures = [p.submit(mp_handler, j) for j in jobs]
99
+ _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]
100
+
101
+ print('Dumping audios...')
102
+
103
+ for vfile in tqdm(filelist):
104
+ try:
105
+ process_audio_file(vfile, args)
106
+ except KeyboardInterrupt:
107
+ exit(0)
108
+ except:
109
+ traceback.print_exc()
110
+ continue
111
+
112
+ if __name__ == '__main__':
113
+ main(args)
requirements.txt CHANGED
@@ -1,10 +1,8 @@
1
- torch
2
- numpy
3
- scipy
4
- opencv-python-headless
5
- moviepy
6
- numba
7
- pillow
8
- pydub
9
- soundfile
10
- gradio
 
1
+ librosa==0.7.0
2
+ numpy==1.17.1
3
+ opencv-contrib-python>=4.2.0.34
4
+ opencv-python==4.1.0.25
5
+ torch==1.1.0
6
+ torchvision==0.3.0
7
+ tqdm==4.45.0
8
+ numba==0.48
 
 
results/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Generated results will be placed in this folder by default.
temp/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Temporary files at the time of inference/testing will be saved here. You can ignore them.
wav2lip_train.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import dirname, join, basename, isfile
2
+ from tqdm import tqdm
3
+
4
+ from models import SyncNet_color as SyncNet
5
+ from models import Wav2Lip as Wav2Lip
6
+ import audio
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch import optim
11
+ import torch.backends.cudnn as cudnn
12
+ from torch.utils import data as data_utils
13
+ import numpy as np
14
+
15
+ from glob import glob
16
+
17
+ import os, random, cv2, argparse
18
+ from hparams import hparams, get_image_list
19
+
20
+ parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model without the visual quality discriminator')
21
+
22
+ parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
23
+
24
+ parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
25
+ parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
26
+
27
+ parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None, type=str)
28
+
29
+ args = parser.parse_args()
30
+
31
+
32
+ global_step = 0
33
+ global_epoch = 0
34
+ use_cuda = torch.cuda.is_available()
35
+ print('use_cuda: {}'.format(use_cuda))
36
+
37
+ syncnet_T = 5
38
+ syncnet_mel_step_size = 16
39
+
40
+ class Dataset(object):
41
+ def __init__(self, split):
42
+ self.all_videos = get_image_list(args.data_root, split)
43
+
44
+ def get_frame_id(self, frame):
45
+ return int(basename(frame).split('.')[0])
46
+
47
+ def get_window(self, start_frame):
48
+ start_id = self.get_frame_id(start_frame)
49
+ vidname = dirname(start_frame)
50
+
51
+ window_fnames = []
52
+ for frame_id in range(start_id, start_id + syncnet_T):
53
+ frame = join(vidname, '{}.jpg'.format(frame_id))
54
+ if not isfile(frame):
55
+ return None
56
+ window_fnames.append(frame)
57
+ return window_fnames
58
+
59
+ def read_window(self, window_fnames):
60
+ if window_fnames is None: return None
61
+ window = []
62
+ for fname in window_fnames:
63
+ img = cv2.imread(fname)
64
+ if img is None:
65
+ return None
66
+ try:
67
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
68
+ except Exception as e:
69
+ return None
70
+
71
+ window.append(img)
72
+
73
+ return window
74
+
75
+ def crop_audio_window(self, spec, start_frame):
76
+ if type(start_frame) == int:
77
+ start_frame_num = start_frame
78
+ else:
79
+ start_frame_num = self.get_frame_id(start_frame) # 0-indexing ---> 1-indexing
80
+ start_idx = int(80. * (start_frame_num / float(hparams.fps)))
81
+
82
+ end_idx = start_idx + syncnet_mel_step_size
83
+
84
+ return spec[start_idx : end_idx, :]
85
+
86
+ def get_segmented_mels(self, spec, start_frame):
87
+ mels = []
88
+ assert syncnet_T == 5
89
+ start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
90
+ if start_frame_num - 2 < 0: return None
91
+ for i in range(start_frame_num, start_frame_num + syncnet_T):
92
+ m = self.crop_audio_window(spec, i - 2)
93
+ if m.shape[0] != syncnet_mel_step_size:
94
+ return None
95
+ mels.append(m.T)
96
+
97
+ mels = np.asarray(mels)
98
+
99
+ return mels
100
+
101
+ def prepare_window(self, window):
102
+ # 3 x T x H x W
103
+ x = np.asarray(window) / 255.
104
+ x = np.transpose(x, (3, 0, 1, 2))
105
+
106
+ return x
107
+
108
+ def __len__(self):
109
+ return len(self.all_videos)
110
+
111
+ def __getitem__(self, idx):
112
+ while 1:
113
+ idx = random.randint(0, len(self.all_videos) - 1)
114
+ vidname = self.all_videos[idx]
115
+ img_names = list(glob(join(vidname, '*.jpg')))
116
+ if len(img_names) <= 3 * syncnet_T:
117
+ continue
118
+
119
+ img_name = random.choice(img_names)
120
+ wrong_img_name = random.choice(img_names)
121
+ while wrong_img_name == img_name:
122
+ wrong_img_name = random.choice(img_names)
123
+
124
+ window_fnames = self.get_window(img_name)
125
+ wrong_window_fnames = self.get_window(wrong_img_name)
126
+ if window_fnames is None or wrong_window_fnames is None:
127
+ continue
128
+
129
+ window = self.read_window(window_fnames)
130
+ if window is None:
131
+ continue
132
+
133
+ wrong_window = self.read_window(wrong_window_fnames)
134
+ if wrong_window is None:
135
+ continue
136
+
137
+ try:
138
+ wavpath = join(vidname, "audio.wav")
139
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
140
+
141
+ orig_mel = audio.melspectrogram(wav).T
142
+ except Exception as e:
143
+ continue
144
+
145
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
146
+
147
+ if (mel.shape[0] != syncnet_mel_step_size):
148
+ continue
149
+
150
+ indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
151
+ if indiv_mels is None: continue
152
+
153
+ window = self.prepare_window(window)
154
+ y = window.copy()
155
+ window[:, :, window.shape[2]//2:] = 0.
156
+
157
+ wrong_window = self.prepare_window(wrong_window)
158
+ x = np.concatenate([window, wrong_window], axis=0)
159
+
160
+ x = torch.FloatTensor(x)
161
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
162
+ indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
163
+ y = torch.FloatTensor(y)
164
+ return x, indiv_mels, mel, y
165
+
166
+ def save_sample_images(x, g, gt, global_step, checkpoint_dir):
167
+ x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
168
+ g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
169
+ gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
170
+
171
+ refs, inps = x[..., 3:], x[..., :3]
172
+ folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
173
+ if not os.path.exists(folder): os.mkdir(folder)
174
+ collage = np.concatenate((refs, inps, g, gt), axis=-2)
175
+ for batch_idx, c in enumerate(collage):
176
+ for t in range(len(c)):
177
+ cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
178
+
179
+ logloss = nn.BCELoss()
180
+ def cosine_loss(a, v, y):
181
+ d = nn.functional.cosine_similarity(a, v)
182
+ loss = logloss(d.unsqueeze(1), y)
183
+
184
+ return loss
185
+
186
+ device = torch.device("cuda" if use_cuda else "cpu")
187
+ syncnet = SyncNet().to(device)
188
+ for p in syncnet.parameters():
189
+ p.requires_grad = False
190
+
191
+ recon_loss = nn.L1Loss()
192
+ def get_sync_loss(mel, g):
193
+ g = g[:, :, :, g.size(3)//2:]
194
+ g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
195
+ # B, 3 * T, H//2, W
196
+ a, v = syncnet(mel, g)
197
+ y = torch.ones(g.size(0), 1).float().to(device)
198
+ return cosine_loss(a, v, y)
199
+
200
+ def train(device, model, train_data_loader, test_data_loader, optimizer,
201
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
202
+
203
+ global global_step, global_epoch
204
+ resumed_step = global_step
205
+
206
+ while global_epoch < nepochs:
207
+ print('Starting Epoch: {}'.format(global_epoch))
208
+ running_sync_loss, running_l1_loss = 0., 0.
209
+ prog_bar = tqdm(enumerate(train_data_loader))
210
+ for step, (x, indiv_mels, mel, gt) in prog_bar:
211
+ model.train()
212
+ optimizer.zero_grad()
213
+
214
+ # Move data to CUDA device
215
+ x = x.to(device)
216
+ mel = mel.to(device)
217
+ indiv_mels = indiv_mels.to(device)
218
+ gt = gt.to(device)
219
+
220
+ g = model(indiv_mels, x)
221
+
222
+ if hparams.syncnet_wt > 0.:
223
+ sync_loss = get_sync_loss(mel, g)
224
+ else:
225
+ sync_loss = 0.
226
+
227
+ l1loss = recon_loss(g, gt)
228
+
229
+ loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss
230
+ loss.backward()
231
+ optimizer.step()
232
+
233
+ if global_step % checkpoint_interval == 0:
234
+ save_sample_images(x, g, gt, global_step, checkpoint_dir)
235
+
236
+ global_step += 1
237
+ cur_session_steps = global_step - resumed_step
238
+
239
+ running_l1_loss += l1loss.item()
240
+ if hparams.syncnet_wt > 0.:
241
+ running_sync_loss += sync_loss.item()
242
+ else:
243
+ running_sync_loss += 0.
244
+
245
+ if global_step == 1 or global_step % checkpoint_interval == 0:
246
+ save_checkpoint(
247
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
248
+
249
+ if global_step == 1 or global_step % hparams.eval_interval == 0:
250
+ with torch.no_grad():
251
+ average_sync_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
252
+
253
+ if average_sync_loss < .75:
254
+ hparams.set_hparam('syncnet_wt', 0.01) # without image GAN a lesser weight is sufficient
255
+
256
+ prog_bar.set_description('L1: {}, Sync Loss: {}'.format(running_l1_loss / (step + 1),
257
+ running_sync_loss / (step + 1)))
258
+
259
+ global_epoch += 1
260
+
261
+
262
+ def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
263
+ eval_steps = 700
264
+ print('Evaluating for {} steps'.format(eval_steps))
265
+ sync_losses, recon_losses = [], []
266
+ step = 0
267
+ while 1:
268
+ for x, indiv_mels, mel, gt in test_data_loader:
269
+ step += 1
270
+ model.eval()
271
+
272
+ # Move data to CUDA device
273
+ x = x.to(device)
274
+ gt = gt.to(device)
275
+ indiv_mels = indiv_mels.to(device)
276
+ mel = mel.to(device)
277
+
278
+ g = model(indiv_mels, x)
279
+
280
+ sync_loss = get_sync_loss(mel, g)
281
+ l1loss = recon_loss(g, gt)
282
+
283
+ sync_losses.append(sync_loss.item())
284
+ recon_losses.append(l1loss.item())
285
+
286
+ if step > eval_steps:
287
+ averaged_sync_loss = sum(sync_losses) / len(sync_losses)
288
+ averaged_recon_loss = sum(recon_losses) / len(recon_losses)
289
+
290
+ print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, averaged_sync_loss))
291
+
292
+ return averaged_sync_loss
293
+
294
+ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
295
+
296
+ checkpoint_path = join(
297
+ checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
298
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
299
+ torch.save({
300
+ "state_dict": model.state_dict(),
301
+ "optimizer": optimizer_state,
302
+ "global_step": step,
303
+ "global_epoch": epoch,
304
+ }, checkpoint_path)
305
+ print("Saved checkpoint:", checkpoint_path)
306
+
307
+
308
+ def _load(checkpoint_path):
309
+ if use_cuda:
310
+ checkpoint = torch.load(checkpoint_path)
311
+ else:
312
+ checkpoint = torch.load(checkpoint_path,
313
+ map_location=lambda storage, loc: storage)
314
+ return checkpoint
315
+
316
+ def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
317
+ global global_step
318
+ global global_epoch
319
+
320
+ print("Load checkpoint from: {}".format(path))
321
+ checkpoint = _load(path)
322
+ s = checkpoint["state_dict"]
323
+ new_s = {}
324
+ for k, v in s.items():
325
+ new_s[k.replace('module.', '')] = v
326
+ model.load_state_dict(new_s)
327
+ if not reset_optimizer:
328
+ optimizer_state = checkpoint["optimizer"]
329
+ if optimizer_state is not None:
330
+ print("Load optimizer state from {}".format(path))
331
+ optimizer.load_state_dict(checkpoint["optimizer"])
332
+ if overwrite_global_states:
333
+ global_step = checkpoint["global_step"]
334
+ global_epoch = checkpoint["global_epoch"]
335
+
336
+ return model
337
+
338
+ if __name__ == "__main__":
339
+ checkpoint_dir = args.checkpoint_dir
340
+
341
+ # Dataset and Dataloader setup
342
+ train_dataset = Dataset('train')
343
+ test_dataset = Dataset('val')
344
+
345
+ train_data_loader = data_utils.DataLoader(
346
+ train_dataset, batch_size=hparams.batch_size, shuffle=True,
347
+ num_workers=hparams.num_workers)
348
+
349
+ test_data_loader = data_utils.DataLoader(
350
+ test_dataset, batch_size=hparams.batch_size,
351
+ num_workers=4)
352
+
353
+ device = torch.device("cuda" if use_cuda else "cpu")
354
+
355
+ # Model
356
+ model = Wav2Lip().to(device)
357
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
358
+
359
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
360
+ lr=hparams.initial_learning_rate)
361
+
362
+ if args.checkpoint_path is not None:
363
+ load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
364
+
365
+ load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)
366
+
367
+ if not os.path.exists(checkpoint_dir):
368
+ os.mkdir(checkpoint_dir)
369
+
370
+ # Train!
371
+ train(device, model, train_data_loader, test_data_loader, optimizer,
372
+ checkpoint_dir=checkpoint_dir,
373
+ checkpoint_interval=hparams.checkpoint_interval,
374
+ nepochs=hparams.nepochs)