JacobLinCool committed on
Commit 7c94b61 · verified · 1 Parent(s): 59122de

Create app.py

Files changed (1)
  1. app.py +243 -0
app.py ADDED
@@ -0,0 +1,243 @@
import spaces
from transformers import (
    AutoImageProcessor,
    AutoModelForCausalLM,
)
import gradio as gr
import torch
from accelerate import Accelerator
import numpy as np
import cv2
from PIL import Image
import zipfile
import io
import tempfile
import os

DEVICE = Accelerator().device
MODEL_NAME = "qihoo360/fg-clip2-so400m"

# Load the FG-CLIP2 model (custom code from the Hub) and its image processor
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True).to(
    DEVICE
)
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)


def determine_max_value(image):
    """Determine max_num_patches based on image size."""
    w, h = image.size
    max_val = (w // 16) * (h // 16)
    if max_val > 784:
        return 1024
    elif max_val > 576:
        return 784
    elif max_val > 256:
        return 576
    elif max_val > 128:
        return 256
    else:
        return 128


@spaces.GPU
def generate_image_embeddings(zip_file):
    """
    Generate embeddings from images in a zip file.

    Args:
        zip_file: Uploaded zip file containing images

    Returns:
        Tuple of (embeddings as numpy file, status message)
    """
    try:
        # Extract images from zip
        images = []
        with zipfile.ZipFile(zip_file.name, "r") as zip_ref:
            for file_info in zip_ref.filelist:
                if file_info.filename.lower().endswith(
                    (".png", ".jpg", ".jpeg", ".bmp", ".webp")
                ):
                    with zip_ref.open(file_info) as img_file:
                        img = Image.open(io.BytesIO(img_file.read())).convert("RGB")
                        images.append(img)

        if len(images) == 0:
            return None, "❌ No valid images found in the zip file"

        # Generate embeddings
        embeddings = []
        with torch.no_grad():
            for i, image in enumerate(images):
                image_input = image_processor(
                    images=image,
                    max_num_patches=determine_max_value(image),
                    return_tensors="pt",
                ).to(DEVICE)
                image_feature = model.get_image_features(**image_input)

                # Normalize the embedding
                normalized_features = image_feature / image_feature.norm(
                    dim=-1, keepdim=True
                )
                embeddings.append(normalized_features.cpu().numpy())

        embeddings = np.vstack(embeddings)

        # Save embeddings to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as tmp:
            np.save(tmp.name, embeddings)
            output_path = tmp.name

        message = f"✅ Successfully generated embeddings for {len(images)} images\nShape: {embeddings.shape}"
        return output_path, message

    except Exception as e:
        return None, f"❌ Error: {str(e)}"


def extract_frames(video_path: str, fps: int = 4):
    """
    Extract frames from video at specified fps.

    Args:
        video_path: Path to the video file
        fps: Frames per second to sample

    Returns:
        List of PIL Images
    """
    cap = cv2.VideoCapture(video_path)
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    # Guard against a zero interval when the requested fps exceeds the video frame rate
    frame_interval = max(1, int(round(video_fps) / fps))

    frames = []
    frame_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        if frame_count % frame_interval == 0:
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)
            frames.append(pil_image)

        frame_count += 1

    cap.release()
    return frames


@spaces.GPU
def generate_video_embeddings(video_file, fps):
    """
    Generate embeddings from video frames.

    Args:
        video_file: Uploaded video file
        fps: Frames per second to extract

    Returns:
        Tuple of (embeddings as numpy file, status message)
    """
    try:
        # Extract frames
        frames = extract_frames(video_file.name, fps)

        if len(frames) == 0:
            return None, "❌ No frames could be extracted from the video"

        # Generate embeddings
        embeddings = []
        with torch.no_grad():
            for i, frame in enumerate(frames):
                image_input = image_processor(
                    images=frame,
                    max_num_patches=determine_max_value(frame),
                    return_tensors="pt",
                ).to(DEVICE)
                image_feature = model.get_image_features(**image_input)

                # Normalize the embedding
                normalized_features = image_feature / image_feature.norm(
                    dim=-1, keepdim=True
                )
                embeddings.append(normalized_features.cpu().numpy())

        embeddings = np.vstack(embeddings)

        # Save embeddings to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as tmp:
            np.save(tmp.name, embeddings)
            output_path = tmp.name

        message = f"✅ Successfully generated embeddings for {len(frames)} frames (extracted at {fps} fps)\nShape: {embeddings.shape}"
        return output_path, message

    except Exception as e:
        return None, f"❌ Error: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="Video & Image Embedding Generator") as demo:
    gr.Markdown("# 🎬 Video & Image Embedding Generator")
    gr.Markdown(f"Generate embeddings using the **{MODEL_NAME}** model")

    with gr.Tab("📦 Images from ZIP"):
        gr.Markdown("Upload a ZIP file containing images to generate embeddings")
        with gr.Row():
            with gr.Column():
                zip_input = gr.File(label="Upload ZIP file", file_types=[".zip"])
                img_submit_btn = gr.Button("Generate Embeddings", variant="primary")
            with gr.Column():
                img_output = gr.File(label="Download Embeddings (.npy)")
                img_status = gr.Textbox(label="Status", lines=3)

        img_submit_btn.click(
            fn=generate_image_embeddings,
            inputs=[zip_input],
            outputs=[img_output, img_status],
        )

    with gr.Tab("🎥 Video Frames"):
        gr.Markdown(
            "Upload a video and specify FPS to extract frames and generate embeddings"
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                fps_input = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=4,
                    step=1,
                    label="Frames per Second (FPS)",
                )
                vid_submit_btn = gr.Button("Generate Embeddings", variant="primary")
            with gr.Column():
                vid_output = gr.File(label="Download Embeddings (.npy)")
                vid_status = gr.Textbox(label="Status", lines=3)

        vid_submit_btn.click(
            fn=generate_video_embeddings,
            inputs=[video_input, fps_input],
            outputs=[vid_output, vid_status],
        )

    gr.Markdown(
        """
        ### 📝 Notes:
        - Images in ZIP: Supports PNG, JPG, JPEG, BMP, WEBP formats
        - Video: Supports common video formats (MP4, AVI, MOV, etc.)
        - Output: NumPy array file (.npy) containing normalized embeddings
        - Load embeddings: `embeddings = np.load('embeddings.npy')`
        """
    )


if __name__ == "__main__":
    demo.launch()
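
The exported .npy file is self-contained, so downstream code can work with the embeddings without reloading the model. A minimal sketch of how the output might be consumed (the file name "embeddings.npy" and the presence of at least two items are assumptions, not part of the app):

import numpy as np

# Load the embeddings produced by the app (use whatever file name was downloaded)
embeddings = np.load("embeddings.npy")  # shape: (num_items, embedding_dim)

# The vectors are already L2-normalized, so cosine similarity reduces to a dot product
similarity = embeddings @ embeddings.T

# Example: the item most similar to the first image/frame, excluding itself
# (assumes the file holds at least two embeddings)
best = int(np.argsort(similarity[0])[::-1][1])
print(f"Most similar to item 0: item {best} (score={similarity[0, best]:.3f})")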