Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
from torch import Tensor, nn
|
|
|
|
| 3 |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
| 4 |
|
| 5 |
|
|
@@ -42,9 +43,24 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmb
|
|
| 42 |
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
| 43 |
return HFEmbedder("city96/t5-v1_1-xxl-encoder-bf16", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
if __name__ == "__main__":
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch import Tensor, nn
|
| 3 |
+
import gradio as gr
|
| 4 |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
| 5 |
|
| 6 |
|
|
|
|
| 43 |
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
| 44 |
return HFEmbedder("city96/t5-v1_1-xxl-encoder-bf16", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
| 45 |
|
| 46 |
+
def run_t5_and_save(text):
|
| 47 |
+
try:
|
| 48 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 49 |
+
print(f"使用设备: {device}")
|
| 50 |
+
t5 = load_t5(device, max_length=512)
|
| 51 |
+
embeddings = t5([text])
|
| 52 |
+
output_path = "/home/user/embeddings.pt"
|
| 53 |
+
torch.save(embeddings, output_path)
|
| 54 |
+
return f"嵌入形状: {embeddings.shape}", output_path
|
| 55 |
+
except Exception as e:
|
| 56 |
+
return f"运行错误: {e}", None
|
| 57 |
+
|
| 58 |
if __name__ == "__main__":
|
| 59 |
+
iface = gr.Interface(
|
| 60 |
+
fn=run_t5_and_save,
|
| 61 |
+
inputs=gr.Textbox(label="输入文本"),
|
| 62 |
+
outputs=[gr.Textbox(label="结果"), gr.File(label="下载嵌入文件")],
|
| 63 |
+
title="T5 Embedder",
|
| 64 |
+
description="输入文本,生成 T5 嵌入并保存为文件"
|
| 65 |
+
)
|
| 66 |
+
iface.launch()
|