"""Gradio demo for generating and comparing text and image embeddings."""

import base64
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image

from handler import ModelHandler
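
# Load the model handler once at startup so every UI callback reuses the
# same initialized model.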
model_handler = ModelHandler()
model_handler.initialize(None)

def cosine_similarity(embedding1, embedding2):
    """Calculate cosine similarity between two embeddings."""
    embedding1 = np.array(embedding1)
    embedding2 = np.array(embedding2)
    return np.dot(embedding1, embedding2) / (
        np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
    )

def process_image(image_input):
    """Load an image from a URL, base64 data URI, file path, or PIL image."""
    if isinstance(image_input, str):
        if image_input.startswith("http"):
            # Fetch the image from a remote URL.
            response = requests.get(image_input, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif image_input.startswith("data:image"):
            # Decode a base64 data URI of the form "data:image/...;base64,...".
            image_data = image_input.split(",")[1]
            image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
        else:
            # Treat any other string as a local file path.
            image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        # Already a PIL image (e.g. from gr.Image(type="pil")): just normalize the mode.
        image = image_input.convert("RGB")
    else:
        # Fall back to file-like objects.
        image = Image.open(image_input).convert("RGB")
    return image

def generate_embeddings(inputs, task="retrieval", input_type="text"):
    """Generate embeddings for a batch of text or image inputs."""
    try:
        if input_type == "text":
            features = model_handler.model.tokenize(inputs)
        else:
            processed_images = [process_image(img) for img in inputs]
            features = model_handler.model.tokenize(processed_images)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model_handler.model.forward(features, task=task)
            embeddings = outputs.get("sentence_embedding", None)

        if embeddings is not None:
            return embeddings.cpu().numpy().tolist()
        return None
    except Exception as e:
        return {"error": str(e)}

def text_to_embedding(text, task="retrieval"):
    """Convert a single text to an embedding."""
    if not text or not text.strip():
        return None
    result = generate_embeddings([text], task=task, input_type="text")
    # generate_embeddings returns None or an error dict on failure; only
    # index into a successful (list) result.
    return result[0] if isinstance(result, list) else None

def image_to_embedding(image, task="retrieval"):
    """Convert a single image to an embedding."""
    if image is None:
        return None
    result = generate_embeddings([image], task=task, input_type="image")
    return result[0] if isinstance(result, list) else None

def compare_embeddings(embedding1, embedding2):
    """Compare two embeddings and return their similarity."""
    if embedding1 is None or embedding2 is None:
        return "Please generate both embeddings first"
    similarity = cosine_similarity(embedding1, embedding2)
    return f"Cosine Similarity: {similarity:.4f}"

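# Build the UI: three tabs covering text-to-text, image-to-image, and
# cross-modal (text-to-image) comparison.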
with gr.Blocks(title="Embedding Model Demo") as demo:
    gr.Markdown("# Embedding Model Demo")
    gr.Markdown("Generate and compare embeddings for text and images")

with gr.Tab("Text Embeddings"):
|
|
|
with gr.Row():
|
|
|
with gr.Column():
|
|
|
text_input1 = gr.Textbox(label="Text Input 1", lines=5)
|
|
|
task_dropdown1 = gr.Dropdown(
|
|
|
["retrieval", "text-matching", "code"],
|
|
|
label="Task",
|
|
|
value="retrieval"
|
|
|
)
|
|
|
text_embed_btn1 = gr.Button("Generate Embedding 1")
|
|
|
|
|
|
with gr.Column():
|
|
|
text_input2 = gr.Textbox(label="Text Input 2", lines=5)
|
|
|
task_dropdown2 = gr.Dropdown(
|
|
|
["retrieval", "text-matching", "code"],
|
|
|
label="Task",
|
|
|
value="retrieval"
|
|
|
)
|
|
|
text_embed_btn2 = gr.Button("Generate Embedding 2")
|
|
|
|
|
|
embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
|
|
|
embedding_output2 = gr.JSON(label="Embedding 2", visible=False)
|
|
|
|
|
|
compare_btn = gr.Button("Compare Embeddings")
|
|
|
similarity_output = gr.Textbox(label="Similarity Result")
|
|
|
|
|
|
with gr.Tab("Image Embeddings"):
|
|
|
with gr.Row():
|
|
|
with gr.Column():
|
|
|
image_input1 = gr.Image(label="Image Input 1", type="pil")
|
|
|
image_task_dropdown1 = gr.Dropdown(
|
|
|
["retrieval"],
|
|
|
label="Task",
|
|
|
value="retrieval"
|
|
|
)
|
|
|
image_embed_btn1 = gr.Button("Generate Embedding 1")
|
|
|
|
|
|
with gr.Column():
|
|
|
image_input2 = gr.Image(label="Image Input 2", type="pil")
|
|
|
image_task_dropdown2 = gr.Dropdown(
|
|
|
["retrieval"],
|
|
|
label="Task",
|
|
|
value="retrieval"
|
|
|
)
|
|
|
image_embed_btn2 = gr.Button("Generate Embedding 2")
|
|
|
|
|
|
image_embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
|
|
|
image_embedding_output2 = gr.JSON(label="Embedding 2", visible=False)
|
|
|
|
|
|
image_compare_btn = gr.Button("Compare Embeddings")
|
|
|
image_similarity_output = gr.Textbox(label="Similarity Result")
|
|
|
|
|
|
with gr.Tab("Cross-Modal Comparison"):
|
|
|
with gr.Row():
|
|
|
with gr.Column():
|
|
|
cross_text_input = gr.Textbox(label="Text Input", lines=5)
|
|
|
cross_text_task = gr.Dropdown(
|
|
|
["retrieval"],
|
|
|
label="Task",
|
|
|
value="retrieval"
|
|
|
)
|
|
|
cross_text_btn = gr.Button("Generate Text Embedding")
|
|
|
|
|
|
with gr.Column():
|
|
|
cross_image_input = gr.Image(label="Image Input", type="pil")
|
|
|
cross_image_task = gr.Dropdown(
|
|
|
["retrieval"],
|
|
|
label="Task",
|
|
|
value="retrieval"
|
|
|
)
|
|
|
cross_image_btn = gr.Button("Generate Image Embedding")
|
|
|
|
|
|
cross_text_embedding = gr.JSON(label="Text Embedding", visible=False)
|
|
|
cross_image_embedding = gr.JSON(label="Image Embedding", visible=False)
|
|
|
|
|
|
cross_compare_btn = gr.Button("Compare Text and Image")
|
|
|
cross_similarity_output = gr.Textbox(label="Similarity Result")
|
|
|
|
|
|
|
|
|
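    # Wire each button to its callback; the hidden JSON components pass the
    # generated embeddings into the compare handlers.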
    text_embed_btn1.click(
        fn=text_to_embedding,
        inputs=[text_input1, task_dropdown1],
        outputs=embedding_output1,
    )
    text_embed_btn2.click(
        fn=text_to_embedding,
        inputs=[text_input2, task_dropdown2],
        outputs=embedding_output2,
    )
    compare_btn.click(
        fn=compare_embeddings,
        inputs=[embedding_output1, embedding_output2],
        outputs=similarity_output,
    )

    image_embed_btn1.click(
        fn=image_to_embedding,
        inputs=[image_input1, image_task_dropdown1],
        outputs=image_embedding_output1,
    )
    image_embed_btn2.click(
        fn=image_to_embedding,
        inputs=[image_input2, image_task_dropdown2],
        outputs=image_embedding_output2,
    )
    image_compare_btn.click(
        fn=compare_embeddings,
        inputs=[image_embedding_output1, image_embedding_output2],
        outputs=image_similarity_output,
    )

    cross_text_btn.click(
        fn=text_to_embedding,
        inputs=[cross_text_input, cross_text_task],
        outputs=cross_text_embedding,
    )
    cross_image_btn.click(
        fn=image_to_embedding,
        inputs=[cross_image_input, cross_image_task],
        outputs=cross_image_embedding,
    )
    cross_compare_btn.click(
        fn=compare_embeddings,
        inputs=[cross_text_embedding, cross_image_embedding],
        outputs=cross_similarity_output,
    )

if __name__ == "__main__":
    demo.launch()