embedding / app.py
Purgatorium's picture
Create app.py
810908b verified
import gradio as gr
from sentence_transformers import SentenceTransformer, util
import torch
# Shared embedding model (Qwen3-Embedding-0.6B), created once when the module
# is imported and reused by every handler below.
# trust_remote_code stays enabled because some Qwen architectures require it.
model = SentenceTransformer(
    "Qwen/Qwen3-Embedding-0.6B",
    trust_remote_code=True,
)
def embed_text(text, is_query=True):
    """
    API 1: Text Embedding

    Encode ``text`` with the module-level model and return the result as
    plain Python lists (via ``tolist()``) so it can be serialized.

    Args:
        text: The text to embed.
        is_query: When truthy, the model's "query" prompt is applied (Qwen3
            benefits from it for retrieval tasks); otherwise no prompt name
            is passed to the encoder.

    Returns:
        The embedding converted from a tensor with ``tolist()``.
    """
    if is_query:
        prompt = "query"
    else:
        prompt = None

    vector = model.encode(text, convert_to_tensor=True, prompt_name=prompt)
    return vector.tolist()
def calculate_similarity(text_a, text_b, is_query=True):
    """
    API 2: Embedding Similarity

    Encode both texts with the module-level model and compare the two
    embeddings by cosine similarity.

    Args:
        text_a: First text to compare.
        text_b: Second text to compare.
        is_query: When truthy, both texts are encoded with the model's
            "query" prompt; otherwise no prompt name is passed.

    Returns:
        A ``(score, percentage)`` tuple. ``score`` is the cosine similarity
        clamped to the float range [0.0, 1.0]; ``percentage`` is that same
        score rendered as a string with two decimals and a trailing ``%``.
    """
    prompt_name = "query" if is_query else None

    # Encode both texts with the same prompt setting so they are comparable.
    emb_a = model.encode(text_a, prompt_name=prompt_name, convert_to_tensor=True)
    emb_b = model.encode(text_b, prompt_name=prompt_name, convert_to_tensor=True)

    # Compute cosine similarity; .item() unwraps the single-element result
    # into a plain Python number.
    similarity = util.cos_sim(emb_a, emb_b).item()

    # Clamp to [0, 1] for "percentage" logic. Float bounds keep the score a
    # float even when the raw value falls outside the range (int bounds
    # would hand back the ints 0 or 1).
    score = max(0.0, min(1.0, similarity))
    percentage = f"{score * 100:.2f}%"
    return score, percentage
# Build the Gradio UI: one tab per API, plus usage notes for programmatic
# access. Each button is also exposed as a named API endpoint via api_name.
with gr.Blocks(title="Qwen3 Embedding API") as demo:
    gr.Markdown("# Qwen3-Embedding-0.6B API & UI")
    gr.Markdown("This space provides high-quality text embeddings and similarity scores.")

    # Tab 1: embed a single text (endpoint "/embed").
    with gr.Tab("Text Embedding"):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text to embed...",
                )
                is_query_toggle = gr.Checkbox(
                    label="Is this a search query?",
                    value=True,
                    info="Uses Qwen's specific query prompt for better retrieval.",
                )
                btn_embed = gr.Button("Generate Embedding", variant="primary")
            with gr.Column():
                output_vec = gr.JSON(label="Embedding Vector (Truncated in UI)")
        btn_embed.click(
            fn=embed_text,
            inputs=[input_text, is_query_toggle],
            outputs=output_vec,
            api_name="embed",
        )

    # Tab 2: compare two texts (endpoint "/similarity").
    with gr.Tab("Similarity Score"):
        with gr.Row():
            with gr.Column():
                text_a = gr.Textbox(label="Text A", placeholder="First sentence...")
                text_b = gr.Textbox(label="Text B", placeholder="Second sentence...")
                is_query_sim = gr.Checkbox(label="Use query prompts?", value=True)
                btn_sim = gr.Button("Compare Texts", variant="primary")
            with gr.Column():
                sim_float = gr.Number(label="Similarity Score (0-1)")
                sim_percent = gr.Label(label="Match Percentage")
        # API returns both the float and the label string
        btn_sim.click(
            fn=calculate_similarity,
            inputs=[text_a, text_b, is_query_sim],
            outputs=[sim_float, sim_percent],
            api_name="similarity",
        )

    # Usage notes. The text is kept flush-left so the string content carries
    # no leading indentation into the rendered Markdown.
    # The example imports from `gradio_client`, the Gradio Python client
    # package (previously misspelled here as `gradion_client`).
    gr.Markdown("""
### How to use the API
You can call these endpoints programmatically using the Gradio Python Client:
```python
from gradio_client import Client
client = Client("your-username/your-space-name")
# API 1: Embedding
result = client.predict("Hello world", True, api_name="/embed")
# API 2: Similarity
score, percent = client.predict("Text A", "Text B", True, api_name="/similarity")
```
""")

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()