File size: 1,880 Bytes
e587c3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from huggingface_hub import snapshot_download
import os
import shutil
from sentence_transformers import SentenceTransformer
import torch

# Model repo and local directory
repo_id = "google/embeddinggemma-300m-qat-q4_0-unquantized"
local_dir = "embeddinggemma-300m"

# Download every repo file except model.safetensors into a scratch directory,
# then move over only the files that local_dir does not already have, so
# existing local files are never overwritten.
# local_dir may not exist yet on a first run.
os.makedirs(local_dir, exist_ok=True)

# Download snapshot to a temp dir. The dir is deliberately left behind if the
# download fails, so a re-run can pick up where it stopped.
temp_dir = "_hf_temp_download"
os.makedirs(temp_dir, exist_ok=True)
snapshot_download(
    repo_id,
    local_dir=temp_dir,
    ignore_patterns=["model.safetensors"],
)

# Merge file-by-file (not top-level-entry-by-entry) so that a sub-directory
# that already exists locally but is missing some files still gets them.
for root, dirs, files in os.walk(temp_dir):
    # huggingface_hub stores its download metadata under .cache/ inside the
    # target dir; that is bookkeeping, not part of the model.
    dirs[:] = [d for d in dirs if d != ".cache"]
    rel_root = os.path.relpath(root, temp_dir)
    for fname in files:
        rel_path = os.path.normpath(os.path.join(rel_root, fname))
        dest = os.path.join(local_dir, rel_path)
        if os.path.exists(dest):
            print(f"Already exists: {rel_path}")
            continue
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.move(os.path.join(root, fname), dest)
        print(f"Downloaded: {rel_path}")

# Clean up temp dir
shutil.rmtree(temp_dir)
print("Done.")

# Export Dense layers from SentenceTransformer to ONNX
st_model = SentenceTransformer(repo_id)

onnx_dir = os.path.join(local_dir, "onnx")
os.makedirs(onnx_dir, exist_ok=True)


def _export_dense(dense, path):
    """Export one sentence-transformers ``Dense`` module to an ONNX file at *path*.

    ``Dense`` computes ``activation_function(linear(x))``. When the activation
    is the identity, only the linear layer is exported (same graph as before).
    Otherwise the activation is exported too, so the ONNX graph matches the
    PyTorch module. The batch axis is marked dynamic so the graph accepts any
    batch size, not only the batch of 1 used for tracing.
    """
    linear = dense.linear
    activation = dense.activation_function
    if isinstance(activation, torch.nn.Identity):
        module = linear
    else:
        module = torch.nn.Sequential(linear, activation)
    # Trace on CPU in float32 so the dummy input and the weights agree.
    module = module.to(device="cpu", dtype=torch.float32).eval()
    dummy_input = torch.randn(1, linear.in_features)
    torch.onnx.export(
        module,
        dummy_input,
        path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        opset_version=14,
    )


# Pipeline indices 2 and 3 are assumed to be the model's two Dense modules
# (presumably following Transformer and Pooling) — TODO confirm against the
# repo's modules.json.
for index, filename in ((2, "dense1.onnx"), (3, "dense2.onnx")):
    _export_dense(st_model[index], os.path.join(onnx_dir, filename))
    print(f"Exported {filename}")