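"""Gradio demo for generating and comparing embeddings.

Provides three tabs: text-to-text, image-to-image, and cross-modal
(text-to-image) comparison via cosine similarity. Model inference is
delegated to the ModelHandler defined in handler.py.
"""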
import os
import torch
import gradio as gr
import numpy as np
from typing import List, Union, Optional
from PIL import Image
import requests
from io import BytesIO
import base64
# Import your handler
from handler import ModelHandler
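# Assumed ModelHandler interface, inferred from how it is used below; the
# actual implementation lives in handler.py and may differ:
#   handler.initialize(context)             -> loads the model
#   handler.model.tokenize(texts_or_images) -> batched model features
#   handler.model.forward(features, task=x) -> dict with "sentence_embedding"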
# Create model handler instance
model_handler = ModelHandler()
model_handler.initialize(None) # We'll handle device placement manually
def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity between two embedding vectors."""
    embedding1 = np.array(embedding1)
    embedding2 = np.array(embedding2)
    denom = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
    if denom == 0:
        return 0.0  # degenerate case: at least one vector is all zeros
    return float(np.dot(embedding1, embedding2) / denom)
def process_image(image_input):
    """Load an image from a PIL image, URL, base64 data URI, local path, or file object."""
    if isinstance(image_input, Image.Image):
        # Already a PIL image (e.g. from gr.Image(type="pil"))
        return image_input.convert("RGB")
    if isinstance(image_input, str):
        if image_input.startswith("http"):
            # Remote URL
            response = requests.get(image_input, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif image_input.startswith("data:image"):
            # Base64 data URI: strip the "data:image/...;base64," prefix
            image_data = image_input.split(",")[1]
            image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
        else:
            # Local file path
            image = Image.open(image_input).convert("RGB")
    else:
        # File-like object (e.g. an uploaded file)
        image = Image.open(image_input).convert("RGB")
    return image
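# Illustrative inputs process_image accepts (values are placeholders, not
# assets belonging to this project):
#   process_image("https://example.com/cat.jpg")       # remote URL
#   process_image("data:image/png;base64,iVBORw0...")  # base64 data URI
#   process_image("local/photo.png")                   # local file path
#   process_image(Image.new("RGB", (8, 8)))            # PIL image (Gradio)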
def generate_embeddings(inputs, task="retrieval", input_type="text"):
    """Generate embeddings for a batch of text or image inputs.

    Returns a list of embedding vectors (one per input), or None on failure.
    """
    try:
        # Tokenize/preprocess depending on modality
        if input_type == "text":
            features = model_handler.model.tokenize(inputs)
        else:  # image
            processed_images = [process_image(img) for img in inputs]
            features = model_handler.model.tokenize(processed_images)
        # Run the model without tracking gradients
        with torch.no_grad():
            outputs = model_handler.model.forward(features, task=task)
        embeddings = outputs.get("sentence_embedding", None)
        if embeddings is not None:
            return embeddings.cpu().numpy().tolist()
        return None
    except Exception as e:
        print(f"Embedding generation failed: {e}")
        return None
def text_to_embedding(text, task="retrieval"):
    """Convert a single text string to an embedding vector."""
    if not text or not text.strip():
        return None
    embeddings = generate_embeddings([text], task=task, input_type="text")
    return embeddings[0] if embeddings else None
def image_to_embedding(image, task="retrieval"):
    """Convert a single image to an embedding vector."""
    if image is None:
        return None
    embeddings = generate_embeddings([image], task=task, input_type="image")
    return embeddings[0] if embeddings else None
def compare_embeddings(embedding1, embedding2):
    """Compare two embeddings and return a formatted similarity string."""
    if embedding1 is None or embedding2 is None:
        return "Please generate both embeddings first"
    similarity = cosine_similarity(embedding1, embedding2)
    return f"Cosine Similarity: {similarity:.4f}"
# Create Gradio interface
with gr.Blocks(title="Embedding Model Demo") as demo:
    gr.Markdown("# Embedding Model Demo")
    gr.Markdown("Generate and compare embeddings for text and images")

    with gr.Tab("Text Embeddings"):
        with gr.Row():
            with gr.Column():
                text_input1 = gr.Textbox(label="Text Input 1", lines=5)
                task_dropdown1 = gr.Dropdown(
                    ["retrieval", "text-matching", "code"],
                    label="Task",
                    value="retrieval"
                )
                text_embed_btn1 = gr.Button("Generate Embedding 1")
            with gr.Column():
                text_input2 = gr.Textbox(label="Text Input 2", lines=5)
                task_dropdown2 = gr.Dropdown(
                    ["retrieval", "text-matching", "code"],
                    label="Task",
                    value="retrieval"
                )
                text_embed_btn2 = gr.Button("Generate Embedding 2")
        embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
        embedding_output2 = gr.JSON(label="Embedding 2", visible=False)
        compare_btn = gr.Button("Compare Embeddings")
        similarity_output = gr.Textbox(label="Similarity Result")

    with gr.Tab("Image Embeddings"):
        with gr.Row():
            with gr.Column():
                image_input1 = gr.Image(label="Image Input 1", type="pil")
                image_task_dropdown1 = gr.Dropdown(
                    ["retrieval"],
                    label="Task",
                    value="retrieval"
                )
                image_embed_btn1 = gr.Button("Generate Embedding 1")
            with gr.Column():
                image_input2 = gr.Image(label="Image Input 2", type="pil")
                image_task_dropdown2 = gr.Dropdown(
                    ["retrieval"],
                    label="Task",
                    value="retrieval"
                )
                image_embed_btn2 = gr.Button("Generate Embedding 2")
        image_embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
        image_embedding_output2 = gr.JSON(label="Embedding 2", visible=False)
        image_compare_btn = gr.Button("Compare Embeddings")
        image_similarity_output = gr.Textbox(label="Similarity Result")

    with gr.Tab("Cross-Modal Comparison"):
        with gr.Row():
            with gr.Column():
                cross_text_input = gr.Textbox(label="Text Input", lines=5)
                cross_text_task = gr.Dropdown(
                    ["retrieval"],
                    label="Task",
                    value="retrieval"
                )
                cross_text_btn = gr.Button("Generate Text Embedding")
            with gr.Column():
                cross_image_input = gr.Image(label="Image Input", type="pil")
                cross_image_task = gr.Dropdown(
                    ["retrieval"],
                    label="Task",
                    value="retrieval"
                )
                cross_image_btn = gr.Button("Generate Image Embedding")
        cross_text_embedding = gr.JSON(label="Text Embedding", visible=False)
        cross_image_embedding = gr.JSON(label="Image Embedding", visible=False)
        cross_compare_btn = gr.Button("Compare Text and Image")
        cross_similarity_output = gr.Textbox(label="Similarity Result")
    # Text tab events
    text_embed_btn1.click(
        fn=text_to_embedding,
        inputs=[text_input1, task_dropdown1],
        outputs=embedding_output1
    )
    text_embed_btn2.click(
        fn=text_to_embedding,
        inputs=[text_input2, task_dropdown2],
        outputs=embedding_output2
    )
    compare_btn.click(
        fn=compare_embeddings,
        inputs=[embedding_output1, embedding_output2],
        outputs=similarity_output
    )

    # Image tab events
    image_embed_btn1.click(
        fn=image_to_embedding,
        inputs=[image_input1, image_task_dropdown1],
        outputs=image_embedding_output1
    )
    image_embed_btn2.click(
        fn=image_to_embedding,
        inputs=[image_input2, image_task_dropdown2],
        outputs=image_embedding_output2
    )
    image_compare_btn.click(
        fn=compare_embeddings,
        inputs=[image_embedding_output1, image_embedding_output2],
        outputs=image_similarity_output
    )

    # Cross-modal tab events
    cross_text_btn.click(
        fn=text_to_embedding,
        inputs=[cross_text_input, cross_text_task],
        outputs=cross_text_embedding
    )
    cross_image_btn.click(
        fn=image_to_embedding,
        inputs=[cross_image_input, cross_image_task],
        outputs=cross_image_embedding
    )
    cross_compare_btn.click(
        fn=compare_embeddings,
        inputs=[cross_text_embedding, cross_image_embedding],
        outputs=cross_similarity_output
    )
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
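# To expose the demo beyond localhost, Gradio also supports options such as
# demo.launch(share=True) or demo.launch(server_name="0.0.0.0"); see the
# Gradio docs for details.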