import pickle
from typing import List, Union

import pandas as pd
import ray
from ray.data import Dataset

from graphgen.bases.base_reader import BaseReader
from graphgen.utils import logger


class PickleReader(BaseReader):
    """
    Read pickle files, requiring the schema to be restored to List[Dict[str, Any]].
    Each pickle file should contain a list of dictionaries with at least:
    - type: The type of the document (e.g., "text", "image", etc.)
    - if type is "text", "content" column must be present.

    Note: Uses ray.data.read_binary_files as ray.data.read_pickle is not available.
    For Ray >= 2.5, consider using read_pickle if available in your version.
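
    Example payload (illustrative, matching the schema above):
        [
            {"type": "text", "content": "hello world"},
            {"type": "text", "content": "another document"},
        ]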
    """

    def read(
        self,
        input_path: Union[str, List[str]],
    ) -> Dataset:
        """
        Read Pickle files using Ray Data.

        :param input_path: Path to pickle file or list of pickle files.
        :return: Ray Dataset containing validated documents.
        """
        if not ray.is_initialized():
            ray.init()

        # Use read_binary_files since Ray Data has no native pickle reader
        ds = ray.data.read_binary_files(input_path, include_paths=True)
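        # With include_paths=True, each row of ds carries a "path" column and
        # a "bytes" column holding the raw file contents.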

        # Deserialize pickle files and flatten into individual records
        def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame:
            all_records = []
            for _, row in batch.iterrows():
                try:
                    # Load pickle data from bytes
                    data = pickle.loads(row["bytes"])

                    # Validate structure
                    if not isinstance(data, list):
                        logger.error(
                            "Pickle file %s must contain a list, got %s",
                            row["path"],
                            type(data),
                        )
                        continue

                    if not all(isinstance(item, dict) for item in data):
                        logger.error(
                            "Pickle file %s must contain a list of dictionaries",
                            row["path"],
                        )
                        continue

                    # Flatten: each dict in the list becomes a separate row
                    all_records.extend(data)
                except Exception as e:
                    logger.error(
                        "Failed to deserialize pickle file %s: %s", row["path"], str(e)
                    )
                    continue

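            # If every file in the batch failed validation, this returns an
            # empty, column-less DataFrame; downstream steps should tolerate
            # empty blocks.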
            return pd.DataFrame(all_records)

        # Apply deserialization and flattening
        ds = ds.map_batches(deserialize_batch, batch_format="pandas")

        # Validate the schema
        ds = ds.map_batches(self._validate_batch, batch_format="pandas")

        # Filter valid items
        ds = ds.filter(self._should_keep_item)
        return ds
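

# Usage sketch (illustrative, not part of the original module). It assumes
# PickleReader can be constructed with no arguments; adjust to however
# BaseReader is actually configured in this repo.
#
#     import pickle
#     import tempfile
#
#     records = [{"type": "text", "content": "hello world"}]
#     with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
#         pickle.dump(records, f)
#
#     reader = PickleReader()
#     ds = reader.read(f.name)
#     print(ds.take_all())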