maxious committed
Commit e587c3d · verified · 1 Parent(s): 9635ff7

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+embeddinggemma-300m/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,57 @@
+ # Gemma3 Embedding Model: ONNX Conversion Demonstration
+
+ This repository demonstrates the conversion of the Gemma3 embedding model from Hugging Face to ONNX format using optimum-onnx, and compares the result against the original PyTorch model. It includes scripts for both ONNX and PyTorch inference pipelines, as well as a comparison of their outputs.
+
+ ## Files
+
+ - `onnx_gemma3_pipeline.py`: Runs the Gemma3 embedding model using ONNXRuntime, including post-processing steps (Pooling, Dense, Normalize) with the ONNX-exported Dense layers.
+ - `pytorch_gemma3_pipeline.py`: Runs the original Gemma3 embedding model using PyTorch and SentenceTransformer modules for reference.
+ - `compare_gemma3_onnx_vs_pytorch.py`: Compares the output embeddings and cosine similarities between the ONNX and PyTorch pipelines.
+ - `download_missing_hf_files.py`: Downloads the required files from Hugging Face and exports the Dense layers to ONNX.
+ - `gemma3_mean_pooling_basic.py`: The most basic implementation, running Gemma3 ONNX inference with only mean pooling (no Dense or Normalize stages).
+
+ ## Pipeline Differences
+
+ Both pipelines use ONNXRuntime for transformer inference via `ORTModelForFeatureExtraction`. The key difference is in post-processing:
+
+ - **ONNX pipeline** (`onnx_gemma3_pipeline.py`): Uses ONNXRuntime for both the transformer and the Dense layers (exported to ONNX), making most of the pipeline ONNX-based except for normalization.
+ - **PyTorch pipeline** (`pytorch_gemma3_pipeline.py`): Uses ONNXRuntime for the transformer, but all post-processing (Pooling, Dense, Normalize) is performed with PyTorch modules from SentenceTransformer.
+
+ This demonstrates how ONNX conversion can offload more of the computation for faster, hardware-agnostic inference, while the PyTorch pipeline serves as the reference implementation.
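+
+ As an illustration of the ONNX post-processing, here is a minimal sketch of running a pooled embedding through an ONNX-exported Dense layer (assuming `embeddinggemma-300m/onnx/dense1.onnx` has already been produced by `download_missing_hf_files.py`):
+
+ ```python
+ import numpy as np
+ import onnxruntime
+
+ # Load the exported Dense layer (768 -> 3072) as an ONNX session
+ session = onnxruntime.InferenceSession(
+     "embeddinggemma-300m/onnx/dense1.onnx", providers=["CPUExecutionProvider"]
+ )
+ pooled = np.random.rand(1, 768).astype(np.float32)  # placeholder for a mean-pooled embedding
+ out = session.run(None, {session.get_inputs()[0].name: pooled})[0]
+ print(out.shape)  # (1, 3072)
+ ```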
+
+ ## Setup
+
+ 1. Install dependencies:
+    ```sh
+    pip install git+https://github.com/simondanielsson/optimum-onnx.git@feature/add-gemma3-export
+    pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
+    pip install sentence-transformers onnxruntime safetensors huggingface_hub
+    ```
+ 2. Export the ONNX model (the scripts expect the export in `embeddinggemma-300m`):
+    ```sh
+    optimum-cli export onnx --model google/embeddinggemma-300m-qat-q4_0-unquantized embeddinggemma-300m
+    python download_missing_hf_files.py
+    ```
+
+ ## Usage
+
+ - Run the ONNX pipeline:
+   ```sh
+   python onnx_gemma3_pipeline.py
+   ```
+ - Run the PyTorch pipeline:
+   ```sh
+   python pytorch_gemma3_pipeline.py
+   ```
+ - Compare outputs:
+   ```sh
+   python compare_gemma3_onnx_vs_pytorch.py
+   ```
+
+ ## Results
+
+ The comparison script prints cosine similarities between sample word embeddings (e.g., "apple", "banana", "car") for both the ONNX and PyTorch pipelines, demonstrating the fidelity of the ONNX conversion.
+
+ ## References
+ - [Optimum-ONNX Gemma3 PR](https://github.com/huggingface/optimum-onnx/pull/50)
+ - [Gemma3 Model](https://huggingface.co/google/embeddinggemma-300m-qat-q4_0-unquantized)
__pycache__/onnx_gemma3_pipeline.cpython-312.pyc ADDED
Binary file (6.63 kB).
 
compare_gemma3_onnx_vs_pytorch.py ADDED
@@ -0,0 +1,66 @@
+ from sentence_transformers import SentenceTransformer, models
+ from transformers import AutoTokenizer
+ from optimum.onnxruntime import ORTModelForFeatureExtraction
+ import torch
+ import numpy as np
+
+ from onnx_gemma3_pipeline import onnx_st
+
+ # Words to compare
+ words = ["apple", "banana", "car"]
+
+ # Load the original SentenceTransformer (PyTorch; CUDA if available)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ st_model = SentenceTransformer("google/embeddinggemma-300m-qat-q4_0-unquantized").to(device)
+
+ # Get PyTorch embeddings
+ with torch.no_grad():
+     pt_embeddings = st_model.encode(words, convert_to_numpy=True, device=device)
+
+ # Basic ONNX implementation: transformer inference followed by mean pooling only
+ def basic_mean_pooling(words):
+     tokenizer = AutoTokenizer.from_pretrained("./embeddinggemma-300m")
+     model = ORTModelForFeatureExtraction.from_pretrained("./embeddinggemma-300m")
+     embeddings = []
+     for word in words:
+         inputs = tokenizer(word, return_tensors="pt")
+         input_ids = inputs['input_ids']
+         sequence_length = input_ids.shape[1]
+         # The ONNX export expects explicit position_ids
+         position_ids = np.arange(sequence_length)[None, :]
+         position_ids = np.tile(position_ids, (input_ids.shape[0], 1))
+         inputs['position_ids'] = torch.tensor(position_ids, dtype=torch.long)
+         outputs = model(**inputs)
+         last_hidden = outputs.last_hidden_state
+         attention_mask = inputs['attention_mask']
+         pooling = models.Pooling(word_embedding_dimension=last_hidden.shape[-1], pooling_mode_mean_tokens=True)
+         features = {'token_embeddings': last_hidden, 'attention_mask': attention_mask}
+         pooled = pooling(features)['sentence_embedding']
+         embeddings.append(pooled[0].detach().cpu().numpy())
+     return np.stack(embeddings)
+
+ # Full ONNX pipeline embeddings (Pooling + Dense + Normalize)
+ onnx_embeddings = onnx_st.encode(words)
+
+ # Cosine similarity function
+ def cosine_similarity(a, b):
+     a = a.flatten()
+     b = b.flatten()
+     return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+
+ print("Safetensors (PyTorch) cosine similarities:")
+ print(f"apple vs banana: {cosine_similarity(pt_embeddings[0], pt_embeddings[1]):.4f}")
+ print(f"apple vs car: {cosine_similarity(pt_embeddings[0], pt_embeddings[2]):.4f}")
+ print(f"banana vs car: {cosine_similarity(pt_embeddings[1], pt_embeddings[2]):.4f}")
+
+ print("\nONNX cosine similarities:")
+ print(f"apple vs banana: {cosine_similarity(onnx_embeddings[0], onnx_embeddings[1]):.4f}")
+ print(f"apple vs car: {cosine_similarity(onnx_embeddings[0], onnx_embeddings[2]):.4f}")
+ print(f"banana vs car: {cosine_similarity(onnx_embeddings[1], onnx_embeddings[2]):.4f}")
+
+ # Basic mean pooling ONNX pipeline
+ basic_embeddings = basic_mean_pooling(words)
+ print("\nBasic ONNX (mean pooling only) cosine similarities:")
+ print(f"apple vs banana: {cosine_similarity(basic_embeddings[0], basic_embeddings[1]):.4f}")
+ print(f"apple vs car: {cosine_similarity(basic_embeddings[0], basic_embeddings[2]):.4f}")
+ print(f"banana vs car: {cosine_similarity(basic_embeddings[1], basic_embeddings[2]):.4f}")
download_missing_hf_files.py ADDED
@@ -0,0 +1,69 @@
+ from huggingface_hub import snapshot_download
+ import os
+ import shutil
+ from sentence_transformers import SentenceTransformer
+ import torch
+
+ # Model repo and local directory
+ repo_id = "google/embeddinggemma-300m-qat-q4_0-unquantized"
+ local_dir = "embeddinggemma-300m"
+ os.makedirs(local_dir, exist_ok=True)
+
+ # Download all files except model.safetensors and those already present
+ existing_files = set(os.listdir(local_dir))
+
+ # Download snapshot to a temp dir
+ temp_dir = "_hf_temp_download"
+ os.makedirs(temp_dir, exist_ok=True)
+ snapshot_download(
+     repo_id,
+     local_dir=temp_dir,
+     ignore_patterns=["model.safetensors"],
+     resume_download=True,
+     allow_patterns=None
+ )
+
+ # Copy missing files
+ for fname in os.listdir(temp_dir):
+     if fname not in existing_files:
+         shutil.move(os.path.join(temp_dir, fname), os.path.join(local_dir, fname))
+         print(f"Downloaded: {fname}")
+     else:
+         print(f"Already exists: {fname}")
+
+ # Clean up temp dir
+ shutil.rmtree(temp_dir)
+ print("Done.")
+
+ # Export the two Dense layers from the SentenceTransformer pipeline to ONNX
+ st_model = SentenceTransformer(repo_id)
+ dense1 = st_model[2].linear  # 2_Dense: 768 -> 3072
+ dense2 = st_model[3].linear  # 3_Dense: 3072 -> 768
+
+ onnx_dir = os.path.join(local_dir, "onnx")
+ os.makedirs(onnx_dir, exist_ok=True)
+
+ # Export Dense1
+ dummy_input1 = torch.randn(1, dense1.in_features)
+ dense1 = dense1.to(dummy_input1.device)
+ torch.onnx.export(
+     dense1,
+     dummy_input1,
+     os.path.join(onnx_dir, "dense1.onnx"),
+     input_names=["input"],
+     output_names=["output"],
+     opset_version=14
+ )
+ print("Exported dense1.onnx")
+
+ # Export Dense2
+ dummy_input2 = torch.randn(1, dense2.in_features)
+ dense2 = dense2.to(dummy_input2.device)
+ torch.onnx.export(
+     dense2,
+     dummy_input2,
+     os.path.join(onnx_dir, "dense2.onnx"),
+     input_names=["input"],
+     output_names=["output"],
+     opset_version=14
+ )
+ print("Exported dense2.onnx")
embeddinggemma-300m/.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
embeddinggemma-300m/1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "word_embedding_dimension": 768,
+   "pooling_mode_cls_token": false,
+   "pooling_mode_mean_tokens": true,
+   "pooling_mode_max_tokens": false,
+   "pooling_mode_mean_sqrt_len_tokens": false,
+   "pooling_mode_weightedmean_tokens": false,
+   "pooling_mode_lasttoken": false,
+   "include_prompt": true
+ }
embeddinggemma-300m/2_Dense/config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "in_features": 768,
+   "out_features": 3072,
+   "bias": false,
+   "activation_function": "torch.nn.modules.linear.Identity"
+ }
embeddinggemma-300m/3_Dense/config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "in_features": 3072,
+   "out_features": 768,
+   "bias": false,
+   "activation_function": "torch.nn.modules.linear.Identity"
+ }
embeddinggemma-300m/README.md ADDED
@@ -0,0 +1,444 @@
+ ---
+ license: gemma
+ pipeline_tag: sentence-similarity
+ library_name: sentence-transformers
+ tags:
+ - sentence-transformers
+ - sentence-similarity
+ - feature-extraction
+ extra_gated_heading: Access EmbeddingGemma on Hugging Face
+ extra_gated_prompt: To access EmbeddingGemma on Hugging Face, you’re required to review
+   and agree to Google’s usage license. To do this, please ensure you’re logged in to
+   Hugging Face and click below. Requests are processed immediately.
+ extra_gated_button_content: Acknowledge license
+ ---
+
+ # EmbeddingGemma model card
+
+ **Model Page**: [EmbeddingGemma](https://ai.google.dev/gemma/docs/embeddinggemma)
+
+ **Resources and Technical Documentation**:
+
+ * [Responsible Generative AI Toolkit](https://ai.google.dev/responsible)
+ * [EmbeddingGemma on Kaggle](https://www.kaggle.com/models/google/embeddinggemma/)
+ * [EmbeddingGemma on Vertex Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/embeddinggemma)
+
+ **Terms of Use**: [Terms](https://ai.google.dev/gemma/terms)
+
+ **Authors**: Google DeepMind
+
+ ## Model Information
+
+ ### Description
+
+ EmbeddingGemma is a 300M parameter, state-of-the-art for its size, open embedding model from Google, built from Gemma 3 (with T5Gemma initialization) and the same research and technology used to create Gemini models. EmbeddingGemma produces vector representations of text, making it well-suited for search and retrieval tasks, as well as classification, clustering, and semantic similarity search. This model was trained with data in 100+ spoken languages.
+
+ The small size and on-device focus make it possible to deploy in environments with limited resources such as mobile phones, laptops, or desktops, democratizing access to state-of-the-art AI models and helping foster innovation for everyone.
+
+ ### Inputs and outputs
+
+ - **Input:**
+   - Text string, such as a question, a prompt, or a document to be embedded
+   - Maximum input context length of 2048 tokens
+
+ - **Output:**
+   - Numerical vector representations of input text data
+   - Output embedding dimension size of 768, with smaller options available (512, 256, or 128) via Matryoshka Representation Learning (MRL). MRL allows users to truncate the output embedding of size 768 to their desired size and then re-normalize for efficient and accurate representation, as in the sketch below.
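+
+ A minimal sketch of the MRL truncate-and-renormalize step (the input vector here is just a placeholder):
+
+ ```python
+ import numpy as np
+
+ emb = np.random.rand(768).astype(np.float32)  # stand-in for a 768-d embedding
+ truncated = emb[:256]                         # keep the first 256 dimensions
+ truncated /= np.linalg.norm(truncated)        # re-normalize to unit length
+ ```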
+
+ ### Usage
+
+ These model weights are designed to be used with [Sentence Transformers](https://www.SBERT.net), using the [Gemma 3](https://huggingface.co/docs/transformers/main/en/model_doc/gemma3) implementation from [Hugging Face Transformers](https://huggingface.co/docs/transformers/en/index) as the backbone.
+
+ First install the Sentence Transformers library:
+
+ ```bash
+ pip install -U sentence-transformers
+ ```
+
+ Then you can load this model and run inference.
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ # Download from the 🤗 Hub
+ model = SentenceTransformer("google/embeddinggemma-300m")
+
+ # Run inference with queries and documents
+ query = "Which planet is known as the Red Planet?"
+ documents = [
+     "Venus is often called Earth's twin because of its similar size and proximity.",
+     "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
+     "Jupiter, the largest planet in our solar system, has a prominent red spot.",
+     "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
+ ]
+ query_embeddings = model.encode_query(query)
+ document_embeddings = model.encode_document(documents)
+ print(query_embeddings.shape, document_embeddings.shape)
+ # (768,) (4, 768)
+
+ # Compute similarities to determine a ranking
+ similarities = model.similarity(query_embeddings, document_embeddings)
+ print(similarities)
+ # tensor([[0.3011, 0.6359, 0.4930, 0.4889]])
+ ```
+
+ **NOTE**: EmbeddingGemma activations do not support `float16`. Please use `float32` or `bfloat16` as appropriate for your hardware.
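+
+ For instance, to load with `bfloat16` (a sketch; `model_kwargs` is forwarded by Sentence Transformers to the underlying Hugging Face model):
+
+ ```python
+ import torch
+ from sentence_transformers import SentenceTransformer
+
+ # Use bfloat16 activations instead of the default float32
+ model = SentenceTransformer(
+     "google/embeddinggemma-300m",
+     model_kwargs={"torch_dtype": torch.bfloat16},
+ )
+ ```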
+
+ ## Model Data
+
+ ### Training Dataset
+
+ This model was trained on a dataset of text data that includes a wide variety of sources totaling approximately 320 billion tokens. Here are the key components:
+
+ - **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. The training dataset includes content in over 100 languages.
+ - **Code and Technical Documents**: Exposing the model to code and technical documentation helps it learn the structure and patterns of programming languages and specialized scientific content, which improves its understanding of code and technical questions.
+ - **Synthetic and Task-Specific Data**: Synthetic training data helps to teach the model specific skills. This includes curated data for tasks like information retrieval, classification, and sentiment analysis, which helps to fine-tune its performance for common embedding applications.
+
+ The combination of these diverse data sources is crucial for training a powerful multilingual embedding model that can handle a wide variety of tasks and data formats.
+
+ ### Data Preprocessing
+
+ Here are the key data cleaning and filtering methods applied to the training data:
+
+ - CSAM Filtering: Rigorous CSAM (Child Sexual Abuse Material) filtering was applied at multiple stages in the data preparation process to ensure the exclusion of harmful and illegal content.
+ - Sensitive Data Filtering: As part of making Gemma pre-trained models safe and reliable, automated techniques were used to filter out certain personal information and other sensitive data from training sets.
+ - Additional methods: Filtering based on content quality and safety in line with [our policies](https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf).
+
+ ## Model Development
+
+ ### Hardware
+
+ EmbeddingGemma was trained using the latest generation of [Tensor Processing Unit (TPU)](https://cloud.google.com/tpu/docs/intro-to-tpu) hardware (TPUv5e); for more details, refer to the [Gemma 3 model card](https://ai.google.dev/gemma/docs/core/model_card_3).
+
+ ### Software
+
+ Training was done using [JAX](https://github.com/jax-ml/jax) and [ML Pathways](https://blog.google/technology/ai/introducing-pathways-next-generation-ai-architecture/). For more details, refer to the [Gemma 3 model card](https://ai.google.dev/gemma/docs/core/model_card_3).
+
+ ## Evaluation
+
+ ### Benchmark Results
+
+ The model was evaluated against a large collection of different datasets and metrics to cover different aspects of text understanding.
+
+ #### Full Precision Checkpoint
+
+ **MTEB (Multilingual, v2)**
+
+ | Dimensionality | Mean (Task) | Mean (TaskType) |
+ |----------------|-------------|-----------------|
+ | 768d           | 61.15       | 54.31           |
+ | 512d           | 60.71       | 53.89           |
+ | 256d           | 59.68       | 53.01           |
+ | 128d           | 58.23       | 51.77           |
+
+ **MTEB (English, v2)**
+
+ | Dimensionality | Mean (Task) | Mean (TaskType) |
+ |----------------|-------------|-----------------|
+ | 768d           | 68.36       | 64.15           |
+ | 512d           | 67.80       | 63.59           |
+ | 256d           | 66.89       | 62.94           |
+ | 128d           | 65.09       | 61.56           |
+
+ **MTEB (Code, v1)**
+
+ | Dimensionality | Mean (Task) | Mean (TaskType) |
+ |----------------|-------------|-----------------|
+ | 768d           | 68.76       | 68.76           |
+ | 512d           | 68.48       | 68.48           |
+ | 256d           | 66.74       | 66.74           |
+ | 128d           | 62.96       | 62.96           |
+
+ #### QAT Checkpoints
+
+ **MTEB (Multilingual, v2)**
+
+ | Quant config (dimensionality) | Mean (Task) | Mean (TaskType) |
+ |-------------------------------|-------------|-----------------|
+ | ***Q4_0 (768d)***             | ***60.62*** | ***53.61***     |
+ | Q8_0 (768d)                   | 60.93       | 53.95           |
+ | Mixed Precision\* (768d)      | 60.69       | 53.82           |
+
+ **MTEB (English, v2)**
+
+ | Quant config (dimensionality) | Mean (Task) | Mean (TaskType) |
+ |-------------------------------|-------------|-----------------|
+ | ***Q4_0 (768d)***             | ***67.91*** | ***63.64***     |
+ | Q8_0 (768d)                   | 68.13       | 63.85           |
+ | Mixed Precision\* (768d)      | 67.95       | 63.83           |
+
+ **MTEB (Code, v1)**
+
+ | Quant config (dimensionality) | Mean (Task) | Mean (TaskType) |
+ |-------------------------------|-------------|-----------------|
+ | ***Q4_0 (768d)***             | ***67.99*** | ***67.99***     |
+ | Q8_0 (768d)                   | 68.70       | 68.70           |
+ | Mixed Precision\* (768d)      | 68.03       | 68.03           |
+
+ Note: QAT models are evaluated after quantization.
+
+ \* Mixed Precision refers to per-channel quantization with int4 for embeddings, feedforward, and projection layers, and int8 for attention (e4_a8_f4_p4).
+
+ ### Prompt Instructions
+
+ EmbeddingGemma can generate optimized embeddings for various use cases—such as document retrieval, question answering, and fact verification—or for specific input types—either a query or a document—using prompts that are prepended to the input strings.
+ Query prompts follow the form `task: {task description} | query: ` where the task description varies by the use case, with the default task description being `search result`. Document-style prompts follow the form `title: {title | "none"} | text: ` where the title is either `none` (the default) or the actual title of the document. Note that providing a title, if available, will improve model performance for document prompts but may require manual formatting.
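+
+ For example, a retrieval query and a document (with no title) are embedded from strings like:
+
+ ```
+ task: search result | query: Which planet is known as the Red Planet?
+ title: none | text: Mars, known for its reddish appearance, is often referred to as the Red Planet.
+ ```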
+
+ Use the following prompts based on your use case and input data type. These may already be available in the EmbeddingGemma configuration in your modeling framework of choice; a usage sketch follows the table.
+
+ <table>
+   <thead>
+     <tr>
+       <th><strong>Use Case (task type enum)</strong></th>
+       <th><strong>Description</strong></th>
+       <th><strong>Recommended Prompt</strong></th>
+     </tr>
+   </thead>
+   <tbody>
+     <tr>
+       <td>Retrieval (Query)</td>
+       <td rowspan="4">Used to generate embeddings that are optimized for document search or information retrieval</td>
+       <td>task: search result | query: {content}</td>
+     </tr>
+     <tr>
+       <td>Retrieval (Document)</td>
+       <td>title: {title | "none"} | text: {content}</td>
+     </tr>
+     <tr>
+       <td>Question Answering</td>
+       <td>task: question answering | query: {content}</td>
+     </tr>
+     <tr>
+       <td>Fact Verification</td>
+       <td>task: fact checking | query: {content}</td>
+     </tr>
+     <tr>
+       <td>Classification</td>
+       <td>Used to generate embeddings that are optimized to classify texts according to preset labels</td>
+       <td>task: classification | query: {content}</td>
+     </tr>
+     <tr>
+       <td>Clustering</td>
+       <td>Used to generate embeddings that are optimized to cluster texts based on their similarities</td>
+       <td>task: clustering | query: {content}</td>
+     </tr>
+     <tr>
+       <td>Semantic Similarity</td>
+       <td>Used to generate embeddings that are optimized to assess text similarity. This is not intended for retrieval use cases.</td>
+       <td>task: sentence similarity | query: {content}</td>
+     </tr>
+     <tr>
+       <td>Code Retrieval</td>
+       <td>Used to retrieve a code block based on a natural language query, such as <em>sort an array</em> or <em>reverse a linked list</em>. Embeddings of the code blocks are computed using retrieval_document.</td>
+       <td>task: code retrieval | query: {content}</td>
+     </tr>
+   </tbody>
+ </table>
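+
+ For example, with Sentence Transformers the prompts can be applied by name (a sketch; the prompt names come from this model's `config_sentence_transformers.json`):
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ model = SentenceTransformer("google/embeddinggemma-300m")
+ # "Clustering" expands to the prompt "task: clustering | query: ",
+ # which is prepended to the input before embedding
+ emb = model.encode("rainbow over the mountains", prompt_name="Clustering")
+ ```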
+
+ ## Usage and Limitations
+
+ These models have certain limitations that users should be aware of.
+
+ ### Intended Usage
+
+ Open embedding models have a wide range of applications across various industries and domains. The following list of potential uses is not comprehensive. The purpose of this list is to provide contextual information about the possible use-cases that the model creators considered as part of model training and development.
+
+ - **Semantic Similarity**: Embeddings optimized to assess text similarity, such as recommendation systems and duplicate detection
+ - **Classification**: Embeddings optimized to classify texts according to preset labels, such as sentiment analysis and spam detection
+ - **Clustering**: Embeddings optimized to cluster texts based on their similarities, such as document organization, market research, and anomaly detection
+ - **Retrieval**
+   - **Document**: Embeddings optimized for document search, such as indexing articles, books, or web pages for search
+   - **Query**: Embeddings optimized for general search queries, such as custom search
+   - **Code Query**: Embeddings optimized for retrieval of code blocks based on natural language queries, such as code suggestions and search
+ - **Question Answering**: Embeddings for questions in a question-answering system, optimized for finding documents that answer the question, such as a chatbot.
+ - **Fact Verification**: Embeddings for statements that need to be verified, optimized for retrieving documents that contain evidence supporting or refuting the statement, such as automated fact-checking systems.
+
+ ### Limitations
+
+ - Training Data
+   - The quality and diversity of the training data significantly influence the model's capabilities. Biases or gaps in the training data can lead to limitations in the model's responses.
+   - The scope of the training dataset determines the subject areas the model can handle effectively.
+ - Language Ambiguity and Nuance
+   - Natural language is inherently complex. Models might struggle to grasp subtle nuances, sarcasm, or figurative language.
+
+ ### Ethical Considerations and Risks
+
+ Risks identified and mitigations:
+
+ - **Perpetuation of biases**: Continuous monitoring (using evaluation metrics and human review) and the exploration of de-biasing techniques are encouraged during model training, fine-tuning, and other use cases.
+ - **Misuse for malicious purposes**: Technical limitations and developer and end-user education can help mitigate against malicious applications of embeddings. Educational resources and reporting mechanisms for users to flag misuse are provided. Prohibited uses of Gemma models are outlined in the [Gemma Prohibited Use Policy](https://ai.google.dev/gemma/prohibited_use_policy).
+ - **Privacy violations**: Models were trained on data filtered for removal of certain personal information and other sensitive data. Developers are encouraged to adhere to privacy regulations with privacy-preserving techniques.
+
+ ### Benefits
+
+ At the time of release, this family of models provides high-performance open embedding model implementations designed from the ground up for responsible AI development. Using the benchmark evaluation metrics described in this document, these models have shown superior performance to other, comparably sized open model alternatives.
embeddinggemma-300m/added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "<image_soft_token>": 262144
+ }
embeddinggemma-300m/config.json ADDED
@@ -0,0 +1,61 @@
+ {
+   "_sliding_window_pattern": 6,
+   "architectures": [
+     "Gemma3TextModel"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "attn_logit_softcapping": null,
+   "bos_token_id": 2,
+   "dtype": "float32",
+   "eos_token_id": 1,
+   "final_logit_softcapping": null,
+   "head_dim": 256,
+   "hidden_activation": "gelu_pytorch_tanh",
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 1152,
+   "layer_types": [
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention"
+   ],
+   "max_position_embeddings": 2048,
+   "model_type": "gemma3_text",
+   "num_attention_heads": 3,
+   "num_hidden_layers": 24,
+   "num_key_value_heads": 1,
+   "pad_token_id": 0,
+   "query_pre_attn_scalar": 256,
+   "rms_norm_eps": 1e-06,
+   "rope_local_base_freq": 10000.0,
+   "rope_scaling": null,
+   "rope_theta": 1000000.0,
+   "sliding_window": 512,
+   "torch_dtype": "float32",
+   "transformers_version": "4.53.3",
+   "use_bidirectional_attention": true,
+   "use_cache": true,
+   "vocab_size": 262144
+ }
embeddinggemma-300m/config_sentence_transformers.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "model_type": "SentenceTransformer",
+   "__version__": {
+     "sentence_transformers": "5.1.0",
+     "transformers": "4.57.0.dev0",
+     "pytorch": "2.8.0+cu128"
+   },
+   "prompts": {
+     "query": "task: search result | query: ",
+     "document": "title: none | text: ",
+     "BitextMining": "task: search result | query: ",
+     "Clustering": "task: clustering | query: ",
+     "Classification": "task: classification | query: ",
+     "InstructionRetrieval": "task: code retrieval | query: ",
+     "MultilabelClassification": "task: classification | query: ",
+     "PairClassification": "task: sentence similarity | query: ",
+     "Reranking": "task: search result | query: ",
+     "Retrieval": "task: search result | query: ",
+     "Retrieval-query": "task: search result | query: ",
+     "Retrieval-document": "title: none | text: ",
+     "STS": "task: sentence similarity | query: ",
+     "Summarization": "task: summarization | query: "
+   },
+   "default_prompt_name": null,
+   "similarity_fn_name": "cosine"
+ }
embeddinggemma-300m/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cache_implementation": "hybrid",
+   "do_sample": true,
+   "top_k": 64,
+   "top_p": 0.95,
+   "transformers_version": "4.57.0.dev0"
+ }
embeddinggemma-300m/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dee985629c11dd0f70531093aeb8e8f7f5ddfb403f6c2705db340d58e4e03ffb
+ size 1212541436
embeddinggemma-300m/modules.json ADDED
@@ -0,0 +1,32 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": "",
+     "type": "sentence_transformers.models.Transformer"
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_Pooling",
+     "type": "sentence_transformers.models.Pooling"
+   },
+   {
+     "idx": 2,
+     "name": "2",
+     "path": "2_Dense",
+     "type": "sentence_transformers.models.Dense"
+   },
+   {
+     "idx": 3,
+     "name": "3",
+     "path": "3_Dense",
+     "type": "sentence_transformers.models.Dense"
+   },
+   {
+     "idx": 4,
+     "name": "4",
+     "path": "4_Normalize",
+     "type": "sentence_transformers.models.Normalize"
+   }
+ ]
embeddinggemma-300m/onnx/dense1.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0131f901cb1759fd45439258c08d0a0bd1eea8aaae8b25556bd64550426cbf9
+ size 9437360
embeddinggemma-300m/onnx/dense2.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:47c5394b30c83bfd957392ef027c6ce52f6771c05c7af2fc627537ef3377cde9
+ size 9437360
embeddinggemma-300m/sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "max_seq_length": 2048,
+   "do_lower_case": false
+ }
embeddinggemma-300m/special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "boi_token": "<start_of_image>",
+   "bos_token": {
+     "content": "<bos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eoi_token": "<end_of_image>",
+   "eos_token": {
+     "content": "<eos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "image_token": "<image_soft_token>",
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
embeddinggemma-300m/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6852f8d561078cc0cebe70ca03c5bfdd0d60a45f9d2e0e1e4cc05b68e9ec329e
+ size 33385008
embeddinggemma-300m/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
+ size 4689074
embeddinggemma-300m/tokenizer_config.json ADDED
The diff for this file is too large to render.
 
gemma3_mean_pooling_basic.py ADDED
@@ -0,0 +1,24 @@
+ from transformers import AutoTokenizer
+ from optimum.onnxruntime import ORTModelForFeatureExtraction
+ from sentence_transformers import models
+ import numpy as np
+ import torch
+
+ tokenizer = AutoTokenizer.from_pretrained("./embeddinggemma-300m")
+ model = ORTModelForFeatureExtraction.from_pretrained("./embeddinggemma-300m")
+
+ inputs = tokenizer("apple", return_tensors="pt")
+ print(inputs)
+ input_ids = inputs['input_ids']
+ sequence_length = input_ids.shape[1]
+ # The ONNX export expects explicit position_ids, so build them manually
+ position_ids = np.arange(sequence_length)[None, :]
+ position_ids = np.tile(position_ids, (input_ids.shape[0], 1))
+ inputs['position_ids'] = torch.tensor(position_ids, dtype=torch.long)
+ outputs = model(**inputs)
+ last_hidden = outputs.last_hidden_state
+ attention_mask = inputs['attention_mask']
+ # Use SentenceTransformer's Pooling module for mean pooling
+ pooling = models.Pooling(word_embedding_dimension=last_hidden.shape[-1], pooling_mode_mean_tokens=True)
+ features = {'token_embeddings': last_hidden, 'attention_mask': attention_mask}
+ pooled = pooling(features)['sentence_embedding']
+ print("Mean pooled:", pooled[0][:5].detach().cpu().numpy())
onnx_gemma3_pipeline.py ADDED
@@ -0,0 +1,91 @@
+ from sentence_transformers import models
+ import torch
+ from transformers import AutoTokenizer
+ from optimum.onnxruntime import ORTModelForFeatureExtraction
+ import numpy as np
+ import os
+ import onnxruntime
+
+ # ONNX pipeline for the Gemma3 embedding model
+ model_dir = "embeddinggemma-300m"
+ tokenizer = AutoTokenizer.from_pretrained("google/embeddinggemma-300m-qat-q4_0-unquantized")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ onnx_model = ORTModelForFeatureExtraction.from_pretrained(
+     model_dir,
+     file_name="model.onnx"
+ ).to(device)
+
+ class ONNXTransformer:
+     def __init__(self, onnx_model, tokenizer, max_seq_length=2048):
+         self.onnx_model = onnx_model
+         self.tokenizer = tokenizer
+         self.max_seq_length = max_seq_length
+
+     def encode(self, sentences):
+         inputs = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=self.max_seq_length)
+         input_ids = inputs['input_ids']
+         sequence_length = input_ids.shape[1]
+         # The ONNX export expects explicit position_ids
+         position_ids = torch.arange(sequence_length)[None, :].expand(input_ids.shape[0], sequence_length)
+         inputs['position_ids'] = position_ids.to(input_ids.device)
+         with torch.no_grad():
+             outputs = self.onnx_model(**inputs)
+         # Also return the attention mask so pooling can ignore padding tokens
+         return outputs.last_hidden_state, inputs['attention_mask']
+
+ modules = []
+ onnx_transformer = ONNXTransformer(onnx_model, tokenizer, max_seq_length=2048)
+ modules.append(onnx_transformer)
+ for idx, name in [(1, "Pooling"), (2, "Dense"), (3, "Dense"), (4, "Normalize")]:
+     module_path = os.path.join(model_dir, f"{idx}_{name}")
+     if name == "Pooling":
+         modules.append(models.Pooling.load(module_path))
+     elif name == "Dense":
+         # Use ONNXRuntime for the Dense layers (2_Dense -> dense1.onnx, 3_Dense -> dense2.onnx)
+         dense_onnx_path = os.path.join(model_dir, "onnx", f"dense{idx-1}.onnx")
+         modules.append(onnxruntime.InferenceSession(dense_onnx_path, providers=["CPUExecutionProvider"]))
+     elif name == "Normalize":
+         modules.append(models.Normalize())
+
+ class ONNXSentenceTransformer:
+     def __init__(self, modules):
+         self.modules = modules
+
+     def encode(self, sentences):
+         features, attention_mask = self.modules[0].encode(sentences)
+         for module in self.modules[1:]:
+             if isinstance(module, models.Pooling):
+                 features = module({'token_embeddings': features, 'attention_mask': attention_mask.to(features.device)})['sentence_embedding']
+             elif isinstance(module, onnxruntime.InferenceSession):
+                 # The ONNX Dense layer expects shape [1, in_features], so process each embedding separately
+                 if isinstance(features, torch.Tensor):
+                     features = features.cpu().detach().numpy()
+                 outputs = []
+                 for vec in features:
+                     ort_inputs = {module.get_inputs()[0].name: vec.reshape(1, -1)}
+                     out = module.run(None, ort_inputs)[0]
+                     outputs.append(out.squeeze(0))
+                 features = np.stack(outputs, axis=0)
+             elif isinstance(module, models.Normalize):
+                 # Normalize still uses PyTorch
+                 if not isinstance(features, torch.Tensor):
+                     features = torch.from_numpy(features)
+                 features = module({'sentence_embedding': features})['sentence_embedding']
+         if isinstance(features, torch.Tensor):
+             return features.cpu().detach().numpy()
+         return features
+
+ onnx_st = ONNXSentenceTransformer(modules)
+
+ def cosine_similarity(a, b):
+     a = a.flatten()
+     b = b.flatten()
+     return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+
+ if __name__ == "__main__":
+     words = ["apple", "banana", "car"]
+     embeddings = onnx_st.encode(words)
+     print(embeddings)
+     for idx, embedding in enumerate(embeddings):
+         print(f"Embedding {idx+1}: {embedding.shape}")
+
+     print("\nCosine similarities:")
+     print(f"apple vs banana: {cosine_similarity(embeddings[0], embeddings[1]):.4f}")
+     print(f"apple vs car: {cosine_similarity(embeddings[0], embeddings[2]):.4f}")
+     print(f"banana vs car: {cosine_similarity(embeddings[1], embeddings[2]):.4f}")
pytorch_gemma3_pipeline.py ADDED
@@ -0,0 +1,58 @@
+ from sentence_transformers import models
+ import torch
+ from transformers import AutoTokenizer
+ from optimum.onnxruntime import ORTModelForFeatureExtraction
+ import numpy as np
+
+ # Load tokenizer and ONNX model
+ model_path = "./embeddinggemma-300m"
+ tokenizer = AutoTokenizer.from_pretrained("google/embeddinggemma-300m-qat-q4_0-unquantized")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_path).to(device)
+
+ class ONNXSentenceTransformer:
+     def __init__(self, model, tokenizer):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.word_embedding_dimension = 768
+         # Mean pooling is performed with SentenceTransformer's PyTorch Pooling module
+         self.pooling = models.Pooling(word_embedding_dimension=self.word_embedding_dimension, pooling_mode_mean_tokens=True)
+
+     def encode(self, sentences, batch_size=32):
+         if isinstance(sentences, str):
+             sentences = [sentences]
+         embeddings = []
+         for i in range(0, len(sentences), batch_size):
+             batch = sentences[i:i+batch_size]
+             inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
+             input_ids = inputs['input_ids']
+             sequence_length = input_ids.shape[1]
+             # The ONNX export expects explicit position_ids
+             position_ids = torch.arange(sequence_length)[None, :].expand(input_ids.shape[0], sequence_length)
+             inputs['position_ids'] = position_ids
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+             last_hidden = outputs.last_hidden_state
+             attention_mask = inputs['attention_mask'].to(last_hidden.device)
+             features = {'token_embeddings': last_hidden, 'attention_mask': attention_mask}
+             pooled = self.pooling(features)['sentence_embedding']
+             embeddings.append(pooled)
+         return torch.cat(embeddings, dim=0).cpu().detach().numpy()
+
+
+ # Usage example
+ onnx_st = ONNXSentenceTransformer(onnx_model, tokenizer)
+
+ words = ["apple", "banana", "car"]
+ embeddings = onnx_st.encode(words)
+ print(embeddings)
+ for idx, embedding in enumerate(embeddings):
+     print(f"Embedding {idx+1}: {embedding.shape}")
+
+ # Cosine similarity demonstration
+ def cosine_similarity(a, b):
+     a = a.flatten()
+     b = b.flatten()
+     return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+
+ print("\nCosine similarities:")
+ print(f"apple vs banana: {cosine_similarity(embeddings[0], embeddings[1]):.4f}")
+ print(f"apple vs car: {cosine_similarity(embeddings[0], embeddings[2]):.4f}")
+ print(f"banana vs car: {cosine_similarity(embeddings[1], embeddings[2]):.4f}")