Krokodilpirat committed
Commit 45326b4 · verified · 1 Parent(s): 18a6e82

Update app.py

Files changed (1):
  1. app.py +125 -75
app.py CHANGED
@@ -7,7 +7,9 @@ import torch
 import numpy as np
 import gradio as gr
 import subprocess
+import urllib.request
 import requests
+from urllib.parse import urlparse
 from huggingface_hub import hf_hub_download
 from video_depth_anything.video_depth import VideoDepthAnything
 from utils.dc_utils import read_video_frames, save_video
@@ -19,20 +21,27 @@ os.environ["HF_HOME"] = "/tmp/huggingface"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
 os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"

-# Patch Gradio schema bug
+# Patch for Gradio schema bug
 def patch_gradio_utils():
     try:
         from gradio_client import utils
         original_get_type = utils.get_type
+
         def patched_get_type(schema):
-            if isinstance(schema, bool): return "boolean"
-            if not isinstance(schema, dict): return "any"
+            if isinstance(schema, bool):
+                return "boolean"
+            if not isinstance(schema, dict):
+                return "any"
             return original_get_type(schema)
+
         utils.get_type = patched_get_type
-    except: pass
+        print("Successfully patched Gradio utils.get_type")
+    except Exception as e:
+        print(f"Could not patch Gradio utils: {e}")
+
 patch_gradio_utils()

-# Load BLIP
+# Load BLIP model
 blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cpu")
@@ -43,103 +52,116 @@ def generate_blip_name(frame: np.ndarray) -> str:
     caption = blip_processor.decode(out[0], skip_special_tokens=True).lower()
     stopwords = {"a", "an", "the", "in", "on", "at", "with", "by", "of", "for", "under", "through", "and", "is"}
     words = [w for w in caption.split() if w not in stopwords and w.isalpha()]
-    return "_".join(words[:3])[:30]
+    trimmed = "_".join(words[:3])
+    return trimmed[:30]

 # Load depth model
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-video_depth_anything = VideoDepthAnything(encoder='vitl', features=256, out_channels=[256,512,1024,1024])
-ckpt_path = hf_hub_download("depth-anything/Video-Depth-Anything-Large", filename="video_depth_anything_vitl.pth", cache_dir="/tmp/huggingface")
+encoder = 'vitl'
+model_name = 'Large'
+model_configs = {
+    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+}
+video_depth_anything = VideoDepthAnything(**model_configs[encoder])
+ckpt_path = hf_hub_download(repo_id=f"depth-anything/Video-Depth-Anything-{model_name}",
+                            filename=f"video_depth_anything_{encoder}.pth",
+                            cache_dir="/tmp/huggingface")
 video_depth_anything.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
 video_depth_anything = video_depth_anything.to(DEVICE).eval()

-# Proxy MJ download
-def download_video_from_url(url):
-    proxy = "https://9cee417c-5874-4e53-939a-52ad3f6f2f30-00-16i6nbwyeqga.picard.replit.dev/"
-    full = f"{proxy}?url={url}"
-    temp = "temp_video.mp4"
-    with requests.get(full, stream=True, timeout=20) as r:
-        r.raise_for_status()
-        with open(temp, "wb") as f:
-            for chunk in r.iter_content(chunk_size=8192):
-                if chunk: f.write(chunk)
-    return temp
-
-# Trigger: Clear upload if MJ
-def clear_uploaded_video(url):
-    return None, "Downloading MJ video...", None
-
-# Trigger: MJ download + optional BLIP
-def handle_video_url(url, use_blip):
-    path = download_video_from_url(url)
-    blip = ""
-    if use_blip:
-        frames, _ = read_video_frames(path, 999, -1, 480)
-        frame = frames[len(frames)//2]
-        blip = generate_blip_name(frame)
-    return path, blip
-
-# Trigger: Upload + optional BLIP
-def handle_upload(path, use_blip):
-    blip = ""
-    if use_blip:
-        frames, _ = read_video_frames(path, 999, -1, 480)
-        frame = frames[len(frames)//2]
-        blip = generate_blip_name(frame)
-    return blip
-
-# Main process
-
-def infer_video_depth_from_source(upload_video, video_url, custom_name, use_blip,
-                                  max_len, target_fps, max_res, stitch, grayscale, convert_from_color, blur):
-    input_path = upload_video or download_video_from_url(video_url)
-    base_name = os.path.splitext(os.path.basename(input_path))[0]
+# MJ proxy download
+def download_video_from_url(original_url):
+    try:
+        proxy_base = "https://9cee417c-5874-4e53-939a-52ad3f6f2f30-00-16i6nbwyeqga.picard.replit.dev/"
+        proxy_url = f"{proxy_base}?url={original_url}"
+        temp_path = "temp_video.mp4"
+        with requests.get(proxy_url, stream=True, timeout=20) as response:
+            response.raise_for_status()
+            with open(temp_path, "wb") as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    if chunk:
+                        f.write(chunk)
+        return temp_path
+    except Exception as e:
+        raise RuntimeError(f"Proxy download failed: {e}")
+
+# Inference
+def infer_video_depth_from_source(upload_video, video_url, custom_name, use_blip, *args):
+    max_len, target_fps, max_res, stitch, grayscale, convert_from_color, blur = args
+
+    if upload_video:
+        input_path = upload_video
+        base_name = os.path.splitext(os.path.basename(input_path))[0]
+    elif video_url:
+        input_path = download_video_from_url(video_url)
+        base_name = os.path.splitext(os.path.basename(input_path))[0]
+    else:
+        raise ValueError("No video source provided.")
+
+    blip_name = ""
     if custom_name:
         base_name = custom_name.strip().replace(" ", "_")[:30]
     elif use_blip:
         frames, _ = read_video_frames(input_path, 999, -1, 480)
-        frame = frames[len(frames)//2]
+        frame = frames[len(frames) // 2]
        base_name = generate_blip_name(frame)
+        blip_name = base_name
+    else:
+        base_name = os.path.splitext(os.path.basename(input_path))[0]
+
     output_dir = "./outputs"
     os.makedirs(output_dir, exist_ok=True)

-    stitched_path = os.path.join(output_dir, base_name + "_RGBD.mp4")
-    vis_path = os.path.join(output_dir, base_name + "_vis.mp4")
+    stitched_video_path = os.path.join(output_dir, base_name + "_RGBD.mp4")
+    vis_video_path = os.path.join(output_dir, base_name + "_vis.mp4")
+
     frames, target_fps = read_video_frames(input_path, max_len, target_fps, max_res)
     depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=518, device=DEVICE)
-    save_video(depths, vis_path, fps=fps, is_depths=True)
+    save_video(depths, vis_video_path, fps=fps, is_depths=True)

     if stitch:
         full_frames, _ = read_video_frames(input_path, max_len, target_fps, max_res=-1)
         d_min, d_max = depths.min(), depths.max()
         stitched_frames = []
+
         for i in range(min(len(full_frames), len(depths))):
             rgb = full_frames[i]
             depth = ((depths[i] - d_min) / (d_max - d_min) * 255).astype(np.uint8)
             if grayscale:
-                import matplotlib
-                cmap = matplotlib.colormaps.get_cmap("inferno")
-                depth_color = (cmap(depth / 255.0)[..., :3] * 255).astype(np.uint8)
-                gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
-                depth_vis = np.stack([gray]*3, axis=-1) if convert_from_color else np.stack([depth]*3, axis=-1)
+                if convert_from_color:
+                    import matplotlib
+                    cmap = matplotlib.colormaps.get_cmap("inferno")
+                    depth_color = (cmap(depth / 255.0)[..., :3] * 255).astype(np.uint8)
+                    gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
+                    depth_vis = np.stack([gray]*3, axis=-1)
+                else:
+                    depth_vis = np.stack([depth]*3, axis=-1)
             else:
                 import matplotlib
                 cmap = matplotlib.colormaps.get_cmap("inferno")
                 depth_vis = (cmap(depth / 255.0)[..., :3] * 255).astype(np.uint8)
             if blur > 0:
-                k = int(blur * 20) * 2 + 1
-                depth_vis = cv2.GaussianBlur(depth_vis, (k, k), 0)
+                kernel = int(blur * 20) * 2 + 1
+                depth_vis = cv2.GaussianBlur(depth_vis, (kernel, kernel), 0)
             depth_resized = cv2.resize(depth_vis, (rgb.shape[1], rgb.shape[0]))
-            stitched_frames.append(cv2.hconcat([rgb, depth_resized]))
-        save_video(np.array(stitched_frames), stitched_path, fps=fps)
+            stitched = cv2.hconcat([rgb, depth_resized])
+            stitched_frames.append(stitched)

-        temp_audio = stitched_path.replace('_RGBD.mp4', '_RGBD_audio.mp4')
-        cmd = ["ffmpeg", "-y", "-i", stitched_path, "-i", input_path, "-c:v", "copy", "-c:a", "aac",
-               "-map", "0:v:0", "-map", "1:a:0?", "-shortest", temp_audio]
-        subprocess.run(cmd)
-        os.replace(temp_audio, stitched_path)
+        save_video(np.array(stitched_frames), stitched_video_path, fps=fps)

-    gc.collect(); torch.cuda.empty_cache()
-    return vis_path, stitched_path, input_path, base_name
+        temp_audio_path = stitched_video_path.replace('_RGBD.mp4', '_RGBD_audio.mp4')
+        cmd = [
+            "ffmpeg", "-y", "-i", stitched_video_path, "-i", input_path,
+            "-c:v", "copy", "-c:a", "aac", "-map", "0:v:0", "-map", "1:a:0?",
+            "-shortest", temp_audio_path
+        ]
+        subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        os.replace(temp_audio_path, stitched_video_path)
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return vis_video_path, stitched_video_path

 # Gradio UI
 with gr.Blocks(analytics_enabled=False, css="""
@@ -152,7 +174,8 @@ with gr.Blocks(analytics_enabled=False, css="""
 """) as demo:

     gr.Markdown("# Video Depth Anything + RGBD sbs output")
-    gr.Markdown("Upload a video or paste a URL to generate RGBD output.")
+    gr.Markdown("Upload a video or paste a URL to generate RGBD output.\n[Project Page](https://videodepthanything.github.io/)")

     with gr.Row(equal_height=True):
         upload_video = gr.Video(label="Upload Video", height=360, scale=1)
@@ -160,14 +183,40 @@ with gr.Blocks(analytics_enabled=False, css="""
     rgbd_out = gr.Video(label="RGBD Output", interactive=False, autoplay=True, show_share_button=True, height=360, scale=2)

     with gr.Row():
-        video_url = gr.Textbox(label="Paste MJ video URL", scale=3)
+        video_url = gr.Textbox(label="Paste MJ video URL (experimental)", scale=3)
         use_blip = gr.Checkbox(label="Use BLIP for automatic file name", value=True, scale=1)
         blip_name_display = gr.Textbox(label="BLIP file name", interactive=False, scale=2)
         custom_name = gr.Textbox(label="Custom file name", scale=3)

-    video_url.change(fn=clear_uploaded_video, inputs=[video_url], outputs=[upload_video, blip_name_display, custom_name])
-    video_url.change(fn=handle_video_url, inputs=[video_url, use_blip], outputs=[upload_video, blip_name_display])
-    upload_video.change(fn=handle_upload, inputs=[upload_video, use_blip], outputs=[blip_name_display])
+    # New triggers
+    def handle_mj_url(url, use_blip):
+        if not url.strip():
+            return None, ""
+        try:
+            temp_path = download_video_from_url(url)
+            frames, _ = read_video_frames(temp_path, 999, -1, 480)
+            blip = generate_blip_name(frames[len(frames) // 2]) if use_blip else ""
+            return temp_path, blip
+        except Exception as e:
+            return None, f"Download error: {e}"
+
+    video_url.change(
+        fn=handle_mj_url,
+        inputs=[video_url, use_blip],
+        outputs=[upload_video, blip_name_display]
+    )
+
+    def handle_upload(path, use_blip):
+        if not path or not use_blip:
+            return ""
+        frames, _ = read_video_frames(path, 999, -1, 480)
+        return generate_blip_name(frames[len(frames) // 2])
+
+    upload_video.change(
+        fn=handle_upload,
+        inputs=[upload_video, use_blip],
+        outputs=[blip_name_display]
+    )

     with gr.Accordion("Advanced Settings", open=False):
         max_len = gr.Slider(label="Max process length", minimum=-1, maximum=1000, value=-1, step=1)
@@ -179,10 +228,11 @@ with gr.Blocks(analytics_enabled=False, css="""
         blur = gr.Slider(label="Blur (for edge smoothing)", minimum=0, maximum=1, value=0.3, step=0.01)

     run_btn = gr.Button("Generate")
+
     run_btn.click(
         fn=infer_video_depth_from_source,
         inputs=[upload_video, video_url, custom_name, use_blip, max_len, target_fps, max_res, stitch, grayscale, convert_from_color, blur],
-        outputs=[depth_out, rgbd_out, upload_video, blip_name_display]
+        outputs=[depth_out, rgbd_out]
     )

     demo.queue()
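
Note on the audio-mux step: the ffmpeg command added in this commit keeps the stitched video stream as-is and copies the audio, if any, from the original input. A minimal standalone sketch of the same invocation; the file names stitched.mp4, source.mp4, and muxed.mp4 are illustrative placeholders, not paths from the app:

    import subprocess

    # Mirrors the commit's mux command, with each flag annotated.
    cmd = [
        "ffmpeg", "-y",        # overwrite the output file without prompting
        "-i", "stitched.mp4",  # input 0: video whose frames we keep
        "-i", "source.mp4",    # input 1: donor for the audio track
        "-c:v", "copy",        # pass the video stream through without re-encoding
        "-c:a", "aac",         # encode the audio track as AAC
        "-map", "0:v:0",       # select the first video stream of input 0
        "-map", "1:a:0?",      # first audio stream of input 1; the '?' makes it
                               # optional, so silent source videos do not fail
        "-shortest",           # end the output with the shorter stream
        "muxed.mp4",
    ]
    subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

Relatedly, the blur slider maps to a Gaussian kernel size via int(blur * 20) * 2 + 1, which always yields an odd width (blur=0.3 gives 13); cv2.GaussianBlur requires odd kernel dimensions.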