from dataclasses import dataclass

from sentence_transformers import SentenceTransformer

from binary_shield.comparison import compute_similarity, hamming_distance
from binary_shield.embedding import extract_embedding
from binary_shield.privacy import apply_randomized_response
from binary_shield.quantization import BinaryPackedEmbedding, binary_quantize


@dataclass
class BinaryFingerprint:
    """A binary-quantized fingerprint plus the epsilon used to produce it."""

    # Quantized embedding; randomized response has already been applied to it
    # whenever ``epsilon`` is not None.
    fingerprint: BinaryPackedEmbedding
    # Epsilon handed to ``apply_randomized_response``; None means that step
    # was skipped.
    epsilon: float | None


@dataclass
class ComparisonResult:
    """Outcome of comparing two fingerprints."""

    # Value returned by ``hamming_distance`` for the two fingerprints.
    hamming_distance: int
    # Value returned by ``compute_similarity`` for the two fingerprints.
    similarity: float
    # True when ``similarity`` is greater than or equal to the threshold.
    is_match: bool


class BinaryShield:
    """Generate binary fingerprints from text and compare them."""

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        epsilon: float | None = None,
    ) -> None:
        """Load the embedding model and remember the randomization setting.

        Args:
            model_name: Name passed to ``SentenceTransformer`` to load the
                embedding model.
            epsilon: Parameter forwarded to ``apply_randomized_response``
                when fingerprints are generated. ``None`` disables that step.
        """
        self.model = SentenceTransformer(model_name)
        self.epsilon = epsilon

    def generate_fingerprint(self, text: str) -> BinaryFingerprint:
        """Embed ``text``, binary-quantize it and optionally randomize it.

        Args:
            text: Input text to fingerprint.

        Returns:
            A ``BinaryFingerprint`` holding the quantized embedding and the
            epsilon configured on this instance.
        """
        eps = self.epsilon
        bits = binary_quantize(extract_embedding(text, self.model))

        # Randomized response is opt-in: it only runs when an epsilon was
        # configured on the instance.
        if eps is not None:
            bits = apply_randomized_response(bits, eps)

        return BinaryFingerprint(fingerprint=bits, epsilon=eps)

    @staticmethod
    def compare(
        fp1: BinaryFingerprint,
        fp2: BinaryFingerprint,
        threshold: float = 0.8,
    ) -> ComparisonResult:
        """Compare two fingerprints by Hamming distance and similarity.

        Args:
            fp1: First fingerprint.
            fp2: Second fingerprint.
            threshold: Minimum similarity (inclusive) for the pair to count
                as a match.

        Returns:
            A ``ComparisonResult`` with the distance, the similarity and
            whether the similarity reached ``threshold``.
        """
        left, right = fp1.fingerprint, fp2.fingerprint
        distance = hamming_distance(left, right)
        score = compute_similarity(left, right)
        matched = score >= threshold
        return ComparisonResult(
            hamming_distance=distance,
            similarity=score,
            is_match=matched,
        )