"""Generate text, image, and multimodal item embeddings with a Mllama
(Llama 3.2 Vision) checkpoint from item metadata (title, description, image URL)."""

import argparse
import os

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

from utils import *
|
def load_data(args):
    # Item metadata: JSON mapping item id -> feature dict (title, description, image URLs, ...).
    item2feature = load_json(args.data_path)
    return item2feature

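# For reference, generate_feature() below assumes the metadata JSON roughly follows this
# shape (the ids and values here are illustrative placeholders, not taken from the dataset):
#
#   {
#       "B0002E1G5C": {
#           "title": "Acoustic guitar strings",
#           "description": "Light gauge phosphor bronze strings ...",
#           "imageH": ["https://example.com/image_hi_res.jpg"]
#       },
#       ...
#   }
#
# Text fields are cleaned and concatenated; for image-like keys only the first URL is kept.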
|
def generate_feature(item2feature, features):
    item_feature_list = []

    for item in item2feature:
        data = item2feature[item]
        text = []
        image = []
        for meta_key in features:
            if meta_key in data:
                if 'image' in meta_key:
                    # Image-like keys hold a list of URLs; keep only the first one.
                    image.append(data[meta_key][0])
                else:
                    meta_value = clean_text(data[meta_key])
                    text.append(meta_value.strip())

        item_feature_list.append([item, text, image])

    return item_feature_list

|
def preprocess_feature(args):
    print('Process feature data ...')

    item2feature = load_data(args)
    # Features used per item: title, description, and the 'imageH' image field.
    item_feature_list = generate_feature(item2feature, ['title', 'description', 'imageH'])

    return item_feature_list

|
def generate_item_embedding(args, item_text_list, processor, model, word_drop_ratio=-1, save_path=''):
    # word_drop_ratio and save_path are accepted for interface parity; the paths below use args.save_path.
    print('Generate feature embedding ...')

    items, texts, images = zip(*item_text_list)
    # Placeholders; every slot must be overwritten by the reordering below.
    order_texts, order_images = [[0]] * len(items), [[0]] * len(items)

    # Map raw item ids to their row index in the final embedding matrices.
    item_order_mapping = load_json(args.order_path)

    for item, text, image in zip(items, texts, images):
        order_texts[int(item_order_mapping[item])] = text
        order_images[int(item_order_mapping[item])] = image
    for text in order_texts:
        assert text != [0]
    for image in order_images:
        assert image != [0]

    text_emb_result = []
    image_emb_result = []
    multi_modal_emb_result = []
    start, batch_size = 0, 1
    with torch.no_grad():
        while start < len(order_texts):
            if (start + 1) % 100 == 0:
                print("==>", start + 1)

            item_text, item_image = ' '.join(order_texts[start]), order_images[start][0]

            # Text embedding: mean-pool the language model's last hidden states over non-padding tokens.
            processed_text = processor(text=item_text, return_tensors='pt').to(args.device)
            text_output = model.language_model.model(**processed_text)
            text_masked_output = text_output.last_hidden_state * processed_text['attention_mask'].unsqueeze(-1)
            text_mean_output = text_masked_output.sum(dim=1) / processed_text['attention_mask'].sum(dim=-1, keepdim=True)
            text_emb_result.append(text_mean_output.detach().cpu())

            # Image embedding: mean-pool the vision encoder's last hidden states.
            open_image = Image.open(requests.get(item_image, stream=True).raw)
            processed_image = processor(images=open_image, return_tensors='pt').to(args.device)
            image_output = model.vision_model(**processed_image)
            image_mean_output = image_output.last_hidden_state.squeeze().mean(dim=0)
            image_mean_output = image_mean_output.mean(dim=0, keepdim=True)
            image_emb_result.append(image_mean_output.detach().cpu())

            # Multimodal embedding: run image + text through the full model and
            # mean-pool the final hidden layer over all tokens.
            prompt = '<|image|>' + item_text
            inputs = processor(text=prompt, images=open_image, return_tensors='pt').to(args.device)
            multi_modal_output = model(**inputs, output_hidden_states=True)
            multi_modal_mean_output = multi_modal_output.hidden_states[-1].mean(dim=1)
            multi_modal_emb_result.append(multi_modal_mean_output.detach().cpu())

            start += batch_size

    text_embeddings = torch.cat(text_emb_result, dim=0).numpy()
    print('Text-Embeddings shape: ', text_embeddings.shape)
    image_embeddings = torch.cat(image_emb_result, dim=0).numpy()
    print('Image-Embeddings shape: ', image_embeddings.shape)
    multi_modal_embeddings = torch.cat(multi_modal_emb_result, dim=0).numpy()
    print('Multimodal-Embeddings shape: ', multi_modal_embeddings.shape)

    file = os.path.join(args.save_path, 'Musical_Instruments.emb.text.npy')
    np.save(file, text_embeddings)
    file = os.path.join(args.save_path, 'Musical_Instruments.emb.image.npy')
    np.save(file, image_embeddings)
    file = os.path.join(args.save_path, 'Musical_Instruments.emb.multimodal.npy')
    np.save(file, multi_modal_embeddings)

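# Quick sanity check of the saved arrays (illustrative usage, not part of the pipeline):
#
#   emb = np.load(os.path.join(args.save_path, 'Musical_Instruments.emb.text.npy'))
#   print(emb.shape)  # expected: (num_items, hidden_size)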
|
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_id', type=int, default=0, help='ID of running GPU')
    parser.add_argument('--plm_name', type=str, default='llama')
    parser.add_argument('--plm_checkpoint', type=str, default='')
    parser.add_argument('--max_sent_len', type=int, default=2048)
    parser.add_argument('--word_drop_ratio', type=float, default=-1, help='word drop ratio, do not drop by default')
    parser.add_argument('--data_path', type=str, default='')
    parser.add_argument('--order_path', type=str, default='')
    parser.add_argument('--save_path', type=str, default='')
    return parser.parse_args()

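# Example invocation (script name, checkpoint, and paths below are illustrative assumptions):
#
#   python generate_emb.py --gpu_id 0 \
#       --plm_checkpoint meta-llama/Llama-3.2-11B-Vision-Instruct \
#       --data_path ./data/Musical_Instruments/item2feature.json \
#       --order_path ./data/Musical_Instruments/item2index.json \
#       --save_path ./data/Musical_Instruments/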
|
if __name__ == '__main__':
    args = parse_args()

    device = set_device(args.gpu_id)
    args.device = device

    item_feature_list = preprocess_feature(args)

    model = MllamaForConditionalGeneration.from_pretrained(args.plm_checkpoint, torch_dtype=torch.float16)
    processor = AutoProcessor.from_pretrained(args.plm_checkpoint)

    model = model.to(device)

    generate_item_embedding(
        args,
        item_feature_list,
        processor,
        model,
        word_drop_ratio=args.word_drop_ratio,
        save_path=args.save_path
    )