# Wayra-Perplexity-Estimator-55M / tensorrt_inference.py
# Author: oflorez
# Commit 140a80a (verified): "Upload TensorRT WayraPPL model optimized for A100 GPUs"
"""
TensorRT Inference Example for WayraPPL
Requires A100 GPU with TensorRT 10.13+
"""
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
from transformers import AutoTokenizer
import torch
import time
class WayraPPLTensorRT:
    """Thin wrapper that runs a serialized WayraPPL TensorRT engine.

    The engine is deserialized once in ``__init__``. Each :meth:`infer` call
    allocates device buffers sized for that batch, runs the engine on a
    dedicated CUDA stream, copies every output tensor back to host memory and
    frees the device buffers again, also when an error occurs.
    """

    # Input tensor names this wrapper binds; the engine must declare both.
    _INPUT_NAMES = ("input_ids", "attention_mask")

    def __init__(self, engine_path: str):
        """Deserialize the engine at ``engine_path`` and create a context.

        Args:
            engine_path: Path to a serialized TensorRT engine file.

        Raises:
            RuntimeError: If the engine cannot be deserialized or an
                execution context cannot be created.
        """
        trt_logger = trt.Logger(trt.Logger.INFO)
        # Kept on the instance defensively so the runtime stays referenced
        # while the engine it produced is in use.
        self._runtime = trt.Runtime(trt_logger)
        with open(engine_path, "rb") as f:
            engine_data = f.read()
        self.engine = self._runtime.deserialize_cuda_engine(engine_data)
        # deserialize_cuda_engine reports failure by returning None rather
        # than raising; fail here instead of with an AttributeError later.
        if self.engine is None:
            raise RuntimeError(
                f"Failed to deserialize TensorRT engine from {engine_path!r}"
            )
        self.context = self.engine.create_execution_context()
        if self.context is None:
            raise RuntimeError("Failed to create TensorRT execution context")
        self.stream = cuda.Stream()

    def _numpy_dtype(self, tensor_name: str) -> np.dtype:
        """Return the NumPy dtype the engine declares for ``tensor_name``."""
        return np.dtype(trt.nptype(self.engine.get_tensor_dtype(tensor_name)))

    def _output_shape(self, tensor_name: str, batch_size: int) -> tuple:
        """Return the concrete host shape of an output once inputs are shaped.

        Raises:
            ValueError: If any dimension is still unresolved (negative).
        """
        shape = tuple(self.context.get_tensor_shape(tensor_name))
        # A still-dynamic leading dimension is taken to be the batch size.
        if shape and shape[0] == -1:
            shape = (batch_size,) + shape[1:]
        if any(dim < 0 for dim in shape):
            raise ValueError(
                f"Output {tensor_name!r} has unresolved shape {shape}"
            )
        return shape

    def infer(self, input_ids: np.ndarray, attention_mask: np.ndarray) -> dict:
        """Run one forward pass and return every engine output on the host.

        Args:
            input_ids: 2-D array of token ids, unpacked as (batch, seq_len).
                It is converted to the dtype the engine declares for
                ``"input_ids"``.
            attention_mask: Attention mask, converted to the dtype the engine
                declares for ``"attention_mask"``.

        Returns:
            Dict mapping each output tensor name to a host NumPy array with
            the shape and dtype the engine reports for this batch.

        Raises:
            ValueError: If ``input_ids`` is not 2-D, the engine rejects an
                input shape, or an output shape cannot be resolved.
            RuntimeError: If TensorRT fails to enqueue the inference.
        """
        batch_size, _seq_len = input_ids.shape

        # Convert each input to the engine's dtype *before* sizing device
        # buffers, so the allocation always matches the bytes copied. The
        # arrays are held in this dict until after the stream synchronizes,
        # so they outlive the async host-to-device copies.
        host_inputs = {}
        for name, array in zip(self._INPUT_NAMES, (input_ids, attention_mask)):
            host = np.ascontiguousarray(array, dtype=self._numpy_dtype(name))
            if not self.context.set_input_shape(name, host.shape):
                raise ValueError(
                    f"Engine rejected shape {host.shape} for input {name!r}"
                )
            host_inputs[name] = host

        allocations = []  # every device buffer; all freed in `finally`
        outputs = {}
        try:
            for name, host in host_inputs.items():
                device = cuda.mem_alloc(host.nbytes)
                allocations.append(device)
                cuda.memcpy_htod_async(device, host, self.stream)
                self.context.set_tensor_address(name, int(device))

            device_outputs = {}
            for index in range(self.engine.num_io_tensors):
                name = self.engine.get_tensor_name(index)
                if self.engine.get_tensor_mode(name) != trt.TensorIOMode.OUTPUT:
                    continue
                host = np.empty(
                    self._output_shape(name, batch_size),
                    dtype=self._numpy_dtype(name),
                )
                device = cuda.mem_alloc(host.nbytes)
                allocations.append(device)
                outputs[name] = host
                device_outputs[name] = device
                self.context.set_tensor_address(name, int(device))

            if not self.context.execute_async_v3(stream_handle=self.stream.handle):
                raise RuntimeError("TensorRT failed to enqueue inference")

            for name, host in outputs.items():
                cuda.memcpy_dtoh_async(host, device_outputs[name], self.stream)
        finally:
            # Wait for all queued work (including the output copies) before
            # releasing the buffers it uses; this also makes `outputs` valid.
            self.stream.synchronize()
            for device in allocations:
                device.free()
        return outputs
# Usage examples with multilingual text
if __name__ == "__main__":
    # Load the engine and the tokenizer stored next to this script.
    model = WayraPPLTensorRT("wayrappl_fp16_bs2048.engine")
    tokenizer = AutoTokenizer.from_pretrained(".")

    # Multilingual examples
    texts = [
        # Spanish
        "La inteligencia artificial está transformando el mundo de manera profunda y acelerada.",
        "El análisis de datos permite descubrir patrones ocultos en grandes volúmenes de información.",
        # Portuguese
        "A tecnologia blockchain promete revolucionar sistemas financeiros tradicionais.",
        "Machine learning possibilita a automação de processos complexos em diversas indústrias.",
        # English
        "Natural language processing enables computers to understand human communication.",
        "Deep learning algorithms require massive computational resources for training.",
    ]

    # Single inference
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    outputs = model.infer(inputs["input_ids"].numpy(), inputs["attention_mask"].numpy())
    print("Perplexity scores:")
    for text, ppl in zip(texts, outputs["ppl"]):
        print(f"{text[:50]}... -> PPL: {ppl:.2f}")

    # Performance comparison: 100K examples
    print("\n" + "=" * 50)
    print("PERFORMANCE COMPARISON: 100K Examples")
    print("=" * 50)

    # Generate 100K examples
    large_texts = texts * 16667  # ~100K examples

    # TensorRT benchmark. The timing is end-to-end: it includes tokenization
    # and host<->device copies, not only engine execution.
    batch_size = 2048
    total_processed = 0
    start_time = time.perf_counter()
    for i in range(0, len(large_texts), batch_size):
        batch = large_texts[i:i + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
        outputs = model.infer(inputs["input_ids"].numpy(), inputs["attention_mask"].numpy())
        total_processed += len(batch)
    tensorrt_time = time.perf_counter() - start_time  # seconds
    tensorrt_throughput = total_processed / tensorrt_time

    print("TensorRT Results:")
    print(f" Time: {tensorrt_time:.2f} seconds")
    print(f" Throughput: {tensorrt_throughput:.0f} samples/sec")
    print(f" Total processed: {total_processed:,} examples")

    # Estimated PyTorch comparison (assumed throughput, not measured here)
    pytorch_throughput = 1000  # samples/sec (estimated)
    pytorch_time = total_processed / pytorch_throughput  # seconds, same unit as tensorrt_time
    print("\nPyTorch Estimated:")
    print(f" Time: {pytorch_time:.2f} seconds")
    print(f" Throughput: {pytorch_throughput} samples/sec")

    # Both times are in seconds, so the ratio is dimensionless.
    speedup = pytorch_time / tensorrt_time
    print(f"\nSpeedup: {speedup:.1f}x faster with TensorRT")