|
|
import torch |
|
|
|
|
|
from PIL import Image |
|
|
from struct_eqtable import build_model |
|
|
from pdf_extract_kit.registry.registry import MODEL_REGISTRY |
|
|
|
|
|
|
|
|
@MODEL_REGISTRY.register("table_parsing_struct_eqtable")
class TableParsingStructEqTable:
    """Table-parsing wrapper around the StructEqTable model.

    The class is registered in ``MODEL_REGISTRY`` under the name
    ``"table_parsing_struct_eqtable"``. It builds the model once at
    construction time and exposes ``predict`` to run it on image files.
    """

    # Output formats accepted by ``predict``.
    SUPPORTED_FORMATS = ('latex', 'markdown', 'html')

    def __init__(self, config):
        """
        Initialize the TableParsingStructEqTable class.

        Args:
            config (dict): Configuration dictionary containing model
                parameters. Recognized keys and their defaults:

                - ``model_path``: checkpoint passed to ``build_model``
                  (default ``'U4R/StructTable-InternVL2-1B'``).
                - ``max_new_tokens``: default ``1024``.
                - ``max_time``: default ``30`` (presumably seconds --
                  confirm against ``build_model``).
                - ``lmdeploy``: default ``False``.
                - ``flash_attn``: default ``True``.
                - ``batch_size``: default ``1``.
                - ``output_format``: format used by ``predict`` when the
                  caller does not pass one (default ``'latex'``).

        Raises:
            AssertionError: If CUDA is not available.
        """
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # kept so existing callers see the same failure.
        if not torch.cuda.is_available():
            raise AssertionError("CUDA must be available for StructEqTable model.")

        self.model_dir = config.get('model_path', 'U4R/StructTable-InternVL2-1B')
        self.max_new_tokens = config.get('max_new_tokens', 1024)
        self.max_time = config.get('max_time', 30)

        self.lmdeploy = config.get('lmdeploy', False)
        self.flash_attn = config.get('flash_attn', True)
        self.batch_size = config.get('batch_size', 1)
        self.default_format = config.get('output_format', 'latex')

        # Build the model once and move it to the GPU.
        self.model = build_model(
            model_ckpt=self.model_dir,
            max_new_tokens=self.max_new_tokens,
            max_time=self.max_time,
            lmdeploy=self.lmdeploy,
            flash_attn=self.flash_attn,
            batch_size=self.batch_size,
        ).cuda()

    def predict(self, images, result_path, output_format=None, **kwargs):
        """
        Run the model on a collection of image files.

        Args:
            images (Iterable): Paths (or file objects) accepted by
                ``PIL.Image.open``, one per table image.
            result_path: Accepted for interface compatibility; not used by
                this method.
            output_format (str | None): One of ``'latex'``, ``'markdown'``
                or ``'html'``. When ``None``, the ``output_format`` value
                from the config is used.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Whatever the underlying model call returns for the given images
            (presumably one result per image -- confirm against the model).

        Raises:
            ValueError: If the resolved output format is not supported.
        """
        if output_format is None:
            output_format = self.default_format
        # Validate the resolved format (including the configured default)
        # before any file is opened, so a bad format never leaks handles.
        if output_format not in self.SUPPORTED_FORMATS:
            raise ValueError(f"Output format {output_format} is not supported.")

        load_images = []
        try:
            for image_path in images:
                load_images.append(Image.open(image_path))
            return self.model(load_images, output_format=output_format)
        finally:
            # ``Image.open`` keeps the underlying file open; release every
            # handle once inference is done, or if opening/inference failed.
            for image in load_images:
                image.close()
|
|
|