import torch
import gradio as gr
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import base64

# Import your handler
from handler import ModelHandler

# Create model handler instance
model_handler = ModelHandler()
model_handler.initialize(None)  # We'll handle device placement manually
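# NOTE: the calls below assume the project-specific ModelHandler exposes a
# `model` whose `tokenize()` accepts either a list of strings or a list of
# PIL images, and whose `forward(features, task=...)` returns a dict with a
# "sentence_embedding" tensor. Adjust if your handler's interface differs.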

def cosine_similarity(embedding1, embedding2):
    """Calculate cosine similarity between two 1-D embeddings"""
    embedding1 = np.array(embedding1)
    embedding2 = np.array(embedding2)
    denom = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
    if denom == 0:
        return 0.0  # degenerate (all-zero) vector; avoid division by zero
    return float(np.dot(embedding1, embedding2) / denom)

def process_image(image_input):
    """Process an image input (PIL image, URL, base64 data URI, or local path)"""
    if isinstance(image_input, Image.Image):
        # Already a PIL image (e.g. from a Gradio Image component with type="pil")
        return image_input.convert("RGB")
    if isinstance(image_input, str):
        if image_input.startswith("http"):
            # URL
            response = requests.get(image_input, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif image_input.startswith("data:image"):
            # Base64 data URI
            image_data = image_input.split(",")[1]
            image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
        else:
            # Local path
            image = Image.open(image_input).convert("RGB")
    else:
        # File path or file-like object (e.g. an uploaded file)
        image = Image.open(image_input).convert("RGB")
    return image

def generate_embeddings(inputs, task="retrieval", input_type="text"):
    """Generate embeddings for a batch of text or image inputs.

    Returns a list of embedding vectors, or None if the model produced no
    "sentence_embedding" output.
    """
    try:
        # Preprocess according to the input modality
        if input_type == "text":
            features = model_handler.model.tokenize(inputs)
        else:  # image
            processed_images = [process_image(img) for img in inputs]
            features = model_handler.model.tokenize(processed_images)

        # Run the model without tracking gradients
        with torch.no_grad():
            outputs = model_handler.model.forward(features, task=task)
            embeddings = outputs.get("sentence_embedding")

        if embeddings is None:
            return None
        return embeddings.cpu().numpy().tolist()
    except Exception as e:
        # Surface the failure in the UI. Returning {"error": str(e)} here, as
        # before, broke callers that index the result with [0].
        raise gr.Error(f"Embedding generation failed: {e}")

def text_to_embedding(text, task="retrieval"):
    """Convert a single text string to an embedding"""
    if not text or not text.strip():
        return None
    embeddings = generate_embeddings([text], task=task, input_type="text")
    return embeddings[0] if embeddings else None

def image_to_embedding(image, task="retrieval"):
    """Convert a single image to an embedding"""
    if image is None:
        return None
    embeddings = generate_embeddings([image], task=task, input_type="image")
    return embeddings[0] if embeddings else None

def compare_embeddings(embedding1, embedding2):
    """Compare two embeddings and return similarity"""
    if embedding1 is None or embedding2 is None:
        return "Please generate both embeddings first"
    similarity = cosine_similarity(embedding1, embedding2)
    return f"Cosine Similarity: {similarity:.4f}"

# Create Gradio interface
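# Three tabs: text-to-text, image-to-image, and cross-modal (text vs. image)
# comparison. Hidden gr.JSON components hold the generated embeddings between
# the "Generate" and "Compare" button clicks.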
with gr.Blocks(title="Embedding Model Demo") as demo:
    gr.Markdown("# Embedding Model Demo")
    gr.Markdown("Generate and compare embeddings for text and images")
    
    with gr.Tab("Text Embeddings"):
        with gr.Row():
            with gr.Column():
                text_input1 = gr.Textbox(label="Text Input 1", lines=5)
                task_dropdown1 = gr.Dropdown(
                    ["retrieval", "text-matching", "code"], 
                    label="Task", 
                    value="retrieval"
                )
                text_embed_btn1 = gr.Button("Generate Embedding 1")
            
            with gr.Column():
                text_input2 = gr.Textbox(label="Text Input 2", lines=5)
                task_dropdown2 = gr.Dropdown(
                    ["retrieval", "text-matching", "code"], 
                    label="Task", 
                    value="retrieval"
                )
                text_embed_btn2 = gr.Button("Generate Embedding 2")
        
        embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
        embedding_output2 = gr.JSON(label="Embedding 2", visible=False)
        
        compare_btn = gr.Button("Compare Embeddings")
        similarity_output = gr.Textbox(label="Similarity Result")
    
    with gr.Tab("Image Embeddings"):
        with gr.Row():
            with gr.Column():
                image_input1 = gr.Image(label="Image Input 1", type="pil")
                image_task_dropdown1 = gr.Dropdown(
                    ["retrieval"], 
                    label="Task", 
                    value="retrieval"
                )
                image_embed_btn1 = gr.Button("Generate Embedding 1")
            
            with gr.Column():
                image_input2 = gr.Image(label="Image Input 2", type="pil")
                image_task_dropdown2 = gr.Dropdown(
                    ["retrieval"], 
                    label="Task", 
                    value="retrieval"
                )
                image_embed_btn2 = gr.Button("Generate Embedding 2")
        
        image_embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
        image_embedding_output2 = gr.JSON(label="Embedding 2", visible=False)
        
        image_compare_btn = gr.Button("Compare Embeddings")
        image_similarity_output = gr.Textbox(label="Similarity Result")
    
    with gr.Tab("Cross-Modal Comparison"):
        with gr.Row():
            with gr.Column():
                cross_text_input = gr.Textbox(label="Text Input", lines=5)
                cross_text_task = gr.Dropdown(
                    ["retrieval"], 
                    label="Task", 
                    value="retrieval"
                )
                cross_text_btn = gr.Button("Generate Text Embedding")
            
            with gr.Column():
                cross_image_input = gr.Image(label="Image Input", type="pil")
                cross_image_task = gr.Dropdown(
                    ["retrieval"], 
                    label="Task", 
                    value="retrieval"
                )
                cross_image_btn = gr.Button("Generate Image Embedding")
        
        cross_text_embedding = gr.JSON(label="Text Embedding", visible=False)
        cross_image_embedding = gr.JSON(label="Image Embedding", visible=False)
        
        cross_compare_btn = gr.Button("Compare Text and Image")
        cross_similarity_output = gr.Textbox(label="Similarity Result")
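
    # Event wiring: each "Generate" button stores its embedding in a hidden
    # gr.JSON component; the matching "Compare" button reads those values back.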
    
    # Text tab events
    text_embed_btn1.click(
        fn=text_to_embedding, 
        inputs=[text_input1, task_dropdown1], 
        outputs=embedding_output1
    )
    
    text_embed_btn2.click(
        fn=text_to_embedding, 
        inputs=[text_input2, task_dropdown2], 
        outputs=embedding_output2
    )
    
    compare_btn.click(
        fn=compare_embeddings,
        inputs=[embedding_output1, embedding_output2],
        outputs=similarity_output
    )
    
    # Image tab events
    image_embed_btn1.click(
        fn=image_to_embedding,
        inputs=[image_input1, image_task_dropdown1],
        outputs=image_embedding_output1
    )
    
    image_embed_btn2.click(
        fn=image_to_embedding,
        inputs=[image_input2, image_task_dropdown2],
        outputs=image_embedding_output2
    )
    
    image_compare_btn.click(
        fn=compare_embeddings,
        inputs=[image_embedding_output1, image_embedding_output2],
        outputs=image_similarity_output
    )
    
    # Cross-modal tab events
    cross_text_btn.click(
        fn=text_to_embedding,
        inputs=[cross_text_input, cross_text_task],
        outputs=cross_text_embedding
    )
    
    cross_image_btn.click(
        fn=image_to_embedding,
        inputs=[cross_image_input, cross_image_task],
        outputs=cross_image_embedding
    )
    
    cross_compare_btn.click(
        fn=compare_embeddings,
        inputs=[cross_text_embedding, cross_image_embedding],
        outputs=cross_similarity_output
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
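    # To serve on a local network instead, launch() also accepts options such
    # as demo.launch(server_name="0.0.0.0", server_port=7860).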