import os
import sys
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
import warnings

import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
from datasets import Dataset, DatasetDict
from transformers import AutoModel, AutoTokenizer

warnings.filterwarnings('ignore')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('embedding_generation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class AffiliationEmbedder:
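    """Generate embeddings for affiliation strings with a locally stored transformer model."""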
    def __init__(
        self,
        model_path: str = "./affiliation-clustering-0.3b",
        device: Optional[str] = None,
        batch_size: int = 32,
        max_length: int = 512,
        use_fp16: bool = False
    ):
        self.model_path = model_path
        self.batch_size = batch_size
        self.max_length = max_length
        self.use_fp16 = use_fp16

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        logger.info(f"Using device: {self.device}")
        if self.device.type == 'cuda':
            logger.info(f"GPU: {torch.cuda.get_device_name()}")
            logger.info(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        self._load_model()
    def _load_model(self):
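        """Load the tokenizer and model, move the model to the target device, and switch to eval mode."""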
        logger.info(f"Loading model from {self.model_path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )
            self.model = AutoModel.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )
            self.model = self.model.to(self.device)
            if self.use_fp16 and self.device.type == 'cuda':
                self.model = self.model.half()
                logger.info("Using FP16 mixed precision")
            self.model.eval()
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
    def encode_batch(self, texts: List[str]) -> np.ndarray:
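        """Tokenize a batch of strings and return L2-normalized embeddings as a float32 NumPy array."""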
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = self.model(**encoded)

        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            embeddings = outputs.pooler_output
        else:
            # Mean pooling over non-padding tokens.
            token_embeddings = outputs.last_hidden_state
            attention_mask = encoded['attention_mask'].unsqueeze(-1).to(token_embeddings.dtype)
            masked_embeddings = token_embeddings * attention_mask
            # Clamp guards against division by zero for an all-padding row.
            embeddings = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1).clamp(min=1e-9)

        embeddings = F.normalize(embeddings, p=2, dim=1)
        embeddings = embeddings.cpu().numpy()
        if self.use_fp16:
            embeddings = embeddings.astype(np.float32)
        return embeddings
    def process_dataset(
        self,
        data_path: str,
        output_path: str,
        checkpoint_interval: int = 1000
    ) -> None:
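        """Embed every affiliation in a parquet file, checkpointing every `checkpoint_interval` rows."""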
        logger.info(f"Processing dataset: {data_path}")
        df = pd.read_parquet(data_path)
        logger.info(f"Loaded {len(df)} samples")

        checkpoint_path = output_path.replace('.parquet', '_checkpoint.parquet')
        start_idx = 0
        processed_rows = []
        if os.path.exists(checkpoint_path):
            logger.info(f"Found checkpoint at {checkpoint_path}")
            checkpoint_df = pd.read_parquet(checkpoint_path)
            # Keep the rows that were already embedded so they end up in the final output.
            processed_rows = checkpoint_df.to_dict('records')
            start_idx = len(checkpoint_df)
            logger.info(f"Resuming from index {start_idx}")

        total_batches = (len(df) - start_idx + self.batch_size - 1) // self.batch_size
        with tqdm(total=total_batches, desc="Generating embeddings") as pbar:
            for i in range(start_idx, len(df), self.batch_size):
                batch_df = df.iloc[i:i + self.batch_size]
                texts = batch_df['affiliation_name'].tolist()
                try:
                    batch_embeddings = self.encode_batch(texts)
                    for j, embedding in enumerate(batch_embeddings):
                        row_idx = i + j
                        row_data = df.iloc[row_idx].to_dict()
                        row_data['embedding'] = embedding
                        processed_rows.append(row_data)
                        if len(processed_rows) % checkpoint_interval == 0:
                            self._save_checkpoint(processed_rows, checkpoint_path)
                            logger.info(f"Checkpoint saved at {len(processed_rows)} samples")
                    pbar.update(1)
                except Exception as e:
                    logger.error(f"Error processing batch at index {i}: {e}")
                    if processed_rows:
                        self._save_checkpoint(processed_rows, checkpoint_path)
                    raise
        result_df = pd.DataFrame(processed_rows)
        logger.info(f"Saving embeddings to {output_path}")
        result_df.to_parquet(output_path, compression='snappy')

        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)
            logger.info("Checkpoint file removed")

        logger.info(f"Successfully generated embeddings for {len(result_df)} samples")
        embedding_dim = len(result_df['embedding'].iloc[0])
        logger.info(f"Embedding dimension: {embedding_dim}")
        logger.info(f"Output file size: {os.path.getsize(output_path) / 1e6:.2f} MB")
    def _save_checkpoint(self, processed_rows: List[Dict], checkpoint_path: str):
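        """Persist the rows processed so far so an interrupted run can resume."""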
        checkpoint_df = pd.DataFrame(processed_rows)
        checkpoint_df.to_parquet(checkpoint_path, compression='snappy')


def main():
    parser = argparse.ArgumentParser(
        description="Generate embeddings for affiliation strings"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="./affiliation-clustering-0.3b",
        help="Path to the pre-trained model directory"
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./20250727-unique-openalex-affiliations-w-ror-ids-top-1K-ror-ids-100-per-sample",
        help="Directory containing the input parquet files"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./20250727-unique-openalex-affiliations-w-ror-ids-top-1K-ror-ids-100-per-sample-embeddings",
        help="Directory to save the output embeddings"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for processing"
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum sequence length for tokenization"
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (cuda/cpu, auto-detect if not specified)"
    )
    parser.add_argument(
        "--use-fp16",
        action="store_true",
        help="Use FP16 mixed precision for faster processing"
    )
    parser.add_argument(
        "--checkpoint-interval",
        type=int,
        default=1000,
        help="Save a checkpoint every N processed samples"
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the resulting dataset to Hugging Face Hub"
    )
    parser.add_argument(
        "--hub-dataset-id",
        type=str,
        default=None,
        help="Hugging Face Hub dataset ID (required if push-to-hub is set)"
    )
    args = parser.parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    embedder = AffiliationEmbedder(
        model_path=args.model_path,
        device=args.device,
        batch_size=args.batch_size,
        max_length=args.max_length,
        use_fp16=args.use_fp16
    )

    data_dir = Path(args.data_dir)
    train_files = sorted(data_dir.glob("*_train.parquet"))
    test_files = sorted(data_dir.glob("*_test.parquet"))
    if not train_files or not test_files:
        logger.error(f"Could not find *_train.parquet and *_test.parquet files in {data_dir}")
        sys.exit(1)
    train_file = train_files[0]
    test_file = test_files[0]

    train_output = output_dir / "train_embeddings.parquet"
    test_output = output_dir / "test_embeddings.parquet"
    logger.info("Processing training dataset...")
    embedder.process_dataset(
        str(train_file),
        str(train_output),
        checkpoint_interval=args.checkpoint_interval
    )

    logger.info("Processing test dataset...")
    embedder.process_dataset(
        str(test_file),
        str(test_output),
        checkpoint_interval=args.checkpoint_interval
    )
    if args.push_to_hub:
        if not args.hub_dataset_id:
            logger.error("--hub-dataset-id is required when --push-to-hub is set")
            sys.exit(1)
        logger.info(f"Pushing dataset to Hugging Face Hub: {args.hub_dataset_id}")
        try:
            from huggingface_hub import login

            token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
            if token:
                login(token=token)
                logger.info("Authenticated with Hugging Face Hub using token")
            else:
                logger.info("No HF token found in environment, attempting to use existing credentials")

            logger.info("Loading generated embeddings...")
            train_df = pd.read_parquet(train_output)
            test_df = pd.read_parquet(test_output)
            logger.info(f"Train dataset: {len(train_df)} samples")
            logger.info(f"Test dataset: {len(test_df)} samples")

            logger.info("Creating dataset dictionary...")
            dataset_dict = DatasetDict({
                'train': Dataset.from_pandas(train_df),
                'test': Dataset.from_pandas(test_df)
            })

            logger.info(f"Pushing to hub: {args.hub_dataset_id}")
            dataset_dict.push_to_hub(
                args.hub_dataset_id,
                private=False,
                commit_message="Add affiliation embeddings generated with affiliation-clustering-0.3b model"
            )
            logger.info(f"Dataset successfully pushed to {args.hub_dataset_id}")
            logger.info(f"View at: https://huggingface.co/datasets/{args.hub_dataset_id}")
        except ImportError as e:
            logger.error(f"Failed to import required libraries: {e}")
            logger.error("Make sure huggingface_hub and datasets are installed")
            sys.exit(1)
        except Exception as e:
            logger.error(f"Failed to push dataset to hub: {e}")
            logger.error(f"Error type: {type(e).__name__}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            sys.exit(1)
    logger.info("Embedding generation completed successfully!")


if __name__ == "__main__":
    main()