# Source: Hugging Face Space file view (Space status at capture time: Sleeping).
# File size: 1,649 bytes; commit 8972ad7.
from sentence_transformers import SentenceTransformer
from dataclasses import dataclass
from binary_shield.comparison import compute_similarity, hamming_distance
from binary_shield.embedding import extract_embedding
from binary_shield.privacy import apply_randomized_response
from binary_shield.quantization import BinaryPackedEmbedding, binary_quantize
@dataclass
class BinaryFingerprint:
    """A packed binary fingerprint plus the randomized-response setting used.

    Instances are what ``BinaryShield.generate_fingerprint`` returns and what
    ``BinaryShield.compare`` consumes.

    Attributes:
        fingerprint: Bit-packed embedding produced by ``binary_quantize``,
            possibly perturbed by ``apply_randomized_response``.
        epsilon: The epsilon forwarded to ``apply_randomized_response`` when the
            fingerprint was generated, or ``None`` when that step was skipped.
    """

    # Packed quantized embedding; its concrete layout is defined in
    # binary_shield.quantization (not visible here).
    fingerprint: BinaryPackedEmbedding
    # None => bits were left as quantized (no randomized response applied).
    epsilon: float | None
@dataclass
class ComparisonResult:
    """Outcome of comparing two fingerprints with ``BinaryShield.compare``.

    Attributes:
        hamming_distance: Value reported by ``hamming_distance`` for the two
            packed fingerprints (presumably the count of differing bits).
        similarity: Value reported by ``compute_similarity`` for the same pair.
        is_match: ``True`` when ``similarity`` is greater than or equal to the
            threshold supplied to the comparison.
    """

    # Raw distance between the two packed bit vectors.
    hamming_distance: int
    # Similarity score; presumably normalized to [0, 1] — confirm in
    # binary_shield.comparison.
    similarity: float
    # Thresholded decision derived from ``similarity`` (inclusive comparison).
    is_match: bool
class BinaryShield:
    """Generate and compare binary fingerprints of text.

    The pipeline for one text is:

    1. Embed it with a ``SentenceTransformer``.
    2. Binary-quantize the embedding into a packed bit vector.
    3. Optionally perturb the bits with randomized response when an
       ``epsilon`` was configured.

    ``compare`` scores two such fingerprints against a similarity threshold.
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        epsilon: float | None = None,
    ) -> None:
        """Load the embedding model and record the privacy setting.

        Args:
            model_name: Identifier handed directly to ``SentenceTransformer``.
                Loading happens eagerly here, which may fetch model weights if
                they are not cached locally.
            epsilon: Value forwarded to ``apply_randomized_response`` for every
                fingerprint; ``None`` disables that step. Presumably the privacy
                budget (smaller = noisier) — confirm in binary_shield.privacy.
        """
        self.model = SentenceTransformer(model_name)
        self.epsilon = epsilon

    def generate_fingerprint(self, text: str) -> BinaryFingerprint:
        """Embed, quantize and (optionally) privatize ``text``.

        Args:
            text: Input string to fingerprint.

        Returns:
            A ``BinaryFingerprint`` carrying the packed bits and the epsilon
            that was applied (``None`` if randomized response was skipped).
        """
        quantized = binary_quantize(extract_embedding(text, self.model))
        eps = self.epsilon
        # Randomized response only runs when a privacy parameter was configured.
        released = (
            quantized
            if eps is None
            else apply_randomized_response(quantized, eps)
        )
        return BinaryFingerprint(fingerprint=released, epsilon=eps)

    @staticmethod
    def compare(
        fp1: BinaryFingerprint,
        fp2: BinaryFingerprint,
        threshold: float = 0.8,
    ) -> ComparisonResult:
        """Score two fingerprints and decide whether they match.

        Args:
            fp1: First fingerprint.
            fp2: Second fingerprint.
            threshold: Minimum similarity (inclusive) for ``is_match``.

        Returns:
            A ``ComparisonResult`` with the Hamming distance, the similarity
            score, and the thresholded match decision.
        """
        # NOTE(review): the fingerprints' ``epsilon`` values (and the models
        # that produced them) are not checked for agreement — confirm that
        # comparing mixed settings is intended.
        left, right = fp1.fingerprint, fp2.fingerprint
        bit_distance = hamming_distance(left, right)
        score = compute_similarity(left, right)
        return ComparisonResult(
            hamming_distance=bit_distance,
            similarity=score,
            is_match=score >= threshold,
        )
|