Commit c25fc46 (verified)
wangpangintsig committed
1 Parent(s): 55c1d88

Update app.py

Files changed (1)
  1. app.py +21 -5
app.py CHANGED
@@ -1,5 +1,6 @@
 import torch
 from torch import Tensor, nn
+import gradio as gr
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 
@@ -42,9 +43,24 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmb
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
     return HFEmbedder("city96/t5-v1_1-xxl-encoder-bf16", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
 
+def run_t5_and_save(text):
+    try:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {device}")
+        t5 = load_t5(device, max_length=512)
+        embeddings = t5([text])
+        output_path = "/home/user/embeddings.pt"
+        torch.save(embeddings, output_path)
+        return f"Embedding shape: {embeddings.shape}", output_path
+    except Exception as e:
+        return f"Runtime error: {e}", None
+
 if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    t5 = load_t5(device, max_length=256)
-    texts = ["This is a test sentence."]
-    embeddings = t5(texts)
-    print(f"Embeddings shape: {embeddings.shape}")
+    iface = gr.Interface(
+        fn=run_t5_and_save,
+        inputs=gr.Textbox(label="Input text"),
+        outputs=[gr.Textbox(label="Result"), gr.File(label="Download embedding file")],
+        title="T5 Embedder",
+        description="Enter text to generate a T5 embedding and save it to a file"
+    )
+    iface.launch()
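
Note on the added handler: as written, `run_t5_and_save` calls `load_t5` on every Gradio submission, so the bf16 T5-XXL encoder is re-instantiated per request. A minimal sketch of caching the embedder at module scope instead, assuming the `load_t5` helper defined earlier in app.py (the `get_t5` wrapper and `_T5_CACHE` name are illustrative, not part of this commit):

```python
import torch

# Illustrative sketch, not part of the commit: keep one embedder instance alive
# across Gradio requests instead of reloading it inside run_t5_and_save.
_T5_CACHE = None  # hypothetical module-level cache

def get_t5(max_length: int = 512):
    """Return a cached embedder, loading it on first use via app.py's load_t5."""
    global _T5_CACHE
    if _T5_CACHE is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _T5_CACHE = load_t5(device, max_length=max_length)  # load_t5 from app.py above
    return _T5_CACHE
```

`run_t5_and_save` could then call `get_t5()` in place of `load_t5(device, max_length=512)` with no other changes.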
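
Since the interface hands `/home/user/embeddings.pt` back through `gr.File`, the downloaded artifact is just a serialized tensor. A quick check after download (local file name assumed):

```python
import torch

# Load the tensor saved by run_t5_and_save; map to CPU so no GPU is required.
embeddings = torch.load("embeddings.pt", map_location="cpu")
print(embeddings.dtype, embeddings.shape)  # dtype and shape of whatever the embedder returned
```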