File size: 1,813 Bytes
18e4106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch

from PIL import Image
from struct_eqtable import build_model
from pdf_extract_kit.registry.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register("table_parsing_struct_eqtable")
class TableParsingStructEqTable:
    """Table-parsing task wrapper around a StructEqTable model.

    The class is registered in ``MODEL_REGISTRY`` under the name
    ``"table_parsing_struct_eqtable"``. The model is built once in
    ``__init__`` and moved to the GPU; ``predict`` feeds it table images.
    """

    # Values accepted for an explicitly passed ``output_format`` in predict().
    SUPPORTED_FORMATS = ('latex', 'markdown', 'html')

    def __init__(self, config):
        """
        Initialize the TableParsingStructEqTable class.

        Args:
            config (dict): Configuration dictionary containing model parameters.
                Recognized keys and their defaults:

                - ``model_path``: checkpoint passed to ``build_model``
                  (default ``'U4R/StructTable-InternVL2-1B'``).
                - ``max_new_tokens``: default ``1024``.
                - ``max_time``: default ``30``.
                - ``lmdeploy``: default ``False``.
                - ``flash_attn``: default ``True``.
                - ``batch_size``: default ``1``.
                - ``output_format``: format used by ``predict`` when the
                  caller does not pass one (default ``'latex'``).

        Raises:
            AssertionError: If CUDA is not available.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O; the exception type is kept for
        # backward compatibility with existing callers.
        if not torch.cuda.is_available():
            raise AssertionError("CUDA must be available for StructEqTable model.")

        self.model_dir = config.get('model_path', 'U4R/StructTable-InternVL2-1B')
        self.max_new_tokens = config.get('max_new_tokens', 1024)
        self.max_time = config.get('max_time', 30)

        self.lmdeploy = config.get('lmdeploy', False)
        self.flash_attn = config.get('flash_attn', True)
        self.batch_size = config.get('batch_size', 1)
        self.default_format = config.get('output_format', 'latex')

        # Load the StructEqTable model
        self.model = build_model(
            model_ckpt=self.model_dir,
            max_new_tokens=self.max_new_tokens,
            max_time=self.max_time,
            lmdeploy=self.lmdeploy,
            flash_attn=self.flash_attn,
            batch_size=self.batch_size,
        ).cuda()

    def predict(self, images, result_path, output_format=None, **kwargs):
        """
        Run the model on a collection of table images.

        Args:
            images: Iterable of image file paths (anything ``PIL.Image.open``
                accepts).
            result_path: Not used by this method; kept in the signature for
                callers that pass it.
            output_format (str, optional): One of ``'latex'``, ``'markdown'``
                or ``'html'``. When ``None``, the ``output_format`` value from
                the configuration is used without further validation.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever the underlying model call returns for the loaded images
            (presumably one result per image -- confirm against
            ``struct_eqtable``).

        Raises:
            ValueError: If an explicitly given ``output_format`` is not
                supported.
        """
        # Validate before any file I/O so a bad format fails fast.
        if output_format is None:
            output_format = self.default_format
        elif output_format not in self.SUPPORTED_FORMATS:
            raise ValueError(f"Output format {output_format} is not supported.")

        load_images = []
        for image_path in images:
            # Image.open is lazy and keeps the file open; force the pixel
            # data into memory and let the context manager release the file
            # handle so handles do not accumulate across calls.
            with Image.open(image_path) as image:
                image.load()
            load_images.append(image)

        results = self.model(
            load_images, output_format=output_format
        )

        return results