HuaminChen committed
Commit 6445452 · verified · 1 Parent(s): 56f3801

Add standalone loading option using only transformers

Files changed (1)
  1. README.md +56 -35
README.md CHANGED
@@ -84,7 +84,11 @@ pip install torch transformers pillow safetensors
 
 ### Load Model
 
-First, clone the repository and install:
+Two checkpoint formats are available:
+- `model.pt` (932 MB) - PyTorch format, smaller due to shared tensors
+- `model.safetensors` (1.35 GB) - SafeTensors format, recommended for production
+
+**Option 1: Using the source repository (full features)**
 
 ```bash
 git clone https://github.com/semantic-router/2DMSE-Multimodal-Embedder.git
@@ -92,50 +96,67 @@ cd 2DMSE-Multimodal-Embedder
 pip install -e .
 ```
 
-Two checkpoint formats are available:
-- `model.pt` (932 MB) - PyTorch format, smaller due to shared tensors
-- `model.safetensors` (1.35 GB) - SafeTensors format, recommended for production
+```python
+from src.models import MultimodalEmbedder
+model = MultimodalEmbedder.from_pretrained("llm-semantic-router/multi-modal-embed-small")
+```
+
+**Option 2: Standalone with transformers (no repo needed)**
 
 ```python
 import torch
-import json
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoModel, AutoTokenizer, SiglipModel, SiglipProcessor, WhisperModel, WhisperFeatureExtractor
 from huggingface_hub import hf_hub_download
-from src.models import MultimodalEmbedder
 
-# Download checkpoint and config
+# Download weights
 checkpoint_path = hf_hub_download(
     repo_id="llm-semantic-router/multi-modal-embed-small",
-    filename="model.pt"  # or "model.safetensors"
-)
-config_path = hf_hub_download(
-    repo_id="llm-semantic-router/multi-modal-embed-small",
-    filename="config.json"
+    filename="model.pt"
 )
-
-# Load config and create model
-with open(config_path) as f:
-    config = json.load(f)
-
-model = MultimodalEmbedder(
-    text_encoder_name=config["text_encoder_name"],
-    image_encoder_name=config["image_encoder_name"],
-    audio_encoder_name=config["audio_encoder_name"],
-    output_dim=config["output_dim"],
-    fusion_type=config["fusion_type"],
-    num_fusion_layers=config["num_fusion_layers"],
-)
-
-# Load weights (works with both .pt and .safetensors)
-if checkpoint_path.endswith(".safetensors"):
-    from safetensors.torch import load_file
-    state_dict = load_file(checkpoint_path)
-else:
-    state_dict = torch.load(checkpoint_path, map_location="cpu")
-
-model.load_state_dict(state_dict)
-model.eval()
+state_dict = torch.load(checkpoint_path, map_location="cpu")
+
+# Load text encoder
+text_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+text_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+
+# Load image encoder
+image_processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-512")
+image_encoder = SiglipModel.from_pretrained("google/siglip-base-patch16-512").vision_model
+
+# Load audio encoder
+audio_processor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
+audio_encoder = WhisperModel.from_pretrained("openai/whisper-tiny").encoder
+
+# Load trained projection weights from checkpoint
+# Text projection: state_dict keys starting with "text_encoder.projection"
+# Image projection: state_dict keys starting with "image_encoder.projection"
+# Audio projection: state_dict keys starting with "audio_encoder.projection"
+
+def encode_text(texts, tokenizer=text_tokenizer, encoder=text_encoder):
+    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
+    with torch.no_grad():
+        outputs = encoder(**inputs)
+    embeddings = outputs.last_hidden_state.mean(dim=1)  # Mean pooling
+    embeddings = F.normalize(embeddings, p=2, dim=-1)
+    return embeddings
+
+def encode_image(images, processor=image_processor, encoder=image_encoder):
+    inputs = processor(images=images, return_tensors="pt")
+    with torch.no_grad():
+        outputs = encoder(inputs.pixel_values)
+    embeddings = outputs.pooler_output
+    embeddings = F.normalize(embeddings, p=2, dim=-1)
+    return embeddings
+
+# Example usage
+text_emb = encode_text(["A photo of a cat"])
+print(f"Text embedding shape: {text_emb.shape}")
 ```
 
+> **Note**: Option 2 loads the base encoders but not the trained projection layers. For the full model with trained weights, use Option 1.
+
 ### Text Embedding
 
 ```python
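
The Option 2 snippet in the new README downloads only `model.pt`. For the `model.safetensors` checkpoint listed above as recommended for production, a minimal sketch of the equivalent download-and-load step (same `repo_id`, with `safetensors.torch.load_file` in place of `torch.load`) could look like this:

```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download the SafeTensors checkpoint instead of model.pt
checkpoint_path = hf_hub_download(
    repo_id="llm-semantic-router/multi-modal-embed-small",
    filename="model.safetensors",
)
state_dict = load_file(checkpoint_path)  # plain dict: parameter name -> tensor
```

The rest of Option 2 is unchanged; only the checkpoint-loading call differs.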
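
The `encode_text` helper in Option 2 mean-pools over every token position, padding included. The diff does not say how the trained model pools, but for `sentence-transformers/all-MiniLM-L6-v2` a mask-aware mean pool is the usual choice; a sketch that reuses the Option 2 objects (`text_tokenizer`, `text_encoder`):

```python
import torch
import torch.nn.functional as F

def encode_text_masked(texts, tokenizer=text_tokenizer, encoder=text_encoder):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state        # (batch, seq, dim)
    mask = inputs["attention_mask"].unsqueeze(-1).float()   # (batch, seq, 1)
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    return F.normalize(pooled, p=2, dim=-1)
```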
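
Finally, the note at the end of the diff points out that Option 2 never applies the trained projection layers, even though their keys are present in the downloaded `state_dict` under the `text_encoder.projection`, `image_encoder.projection`, and `audio_encoder.projection` prefixes. A hedged sketch of wiring one in, assuming each projection is a single linear layer stored as `<prefix>.weight` / `<prefix>.bias` (the real checkpoint layout may differ; inspect `state_dict.keys()` first):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def build_projection(state_dict, prefix):
    # Assumed layout: one nn.Linear per modality under "<prefix>.weight" / "<prefix>.bias"
    weight = state_dict[f"{prefix}.weight"]
    bias = state_dict.get(f"{prefix}.bias")
    proj = nn.Linear(weight.shape[1], weight.shape[0], bias=bias is not None)
    with torch.no_grad():
        proj.weight.copy_(weight)
        if bias is not None:
            proj.bias.copy_(bias)
    return proj.eval()

# Hypothetical usage: project the standalone text embeddings
# text_proj = build_projection(state_dict, "text_encoder.projection")
# text_emb = F.normalize(text_proj(encode_text(["A photo of a cat"])), p=2, dim=-1)
```

If the keys do not match this layout, Option 1 remains the reliable path to the trained weights, as the README note says.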