File size: 1,925 Bytes
d2a63cc
31086ae
 
 
 
d2a63cc
 
 
 
e4316f1
52419fe
31086ae
52419fe
 
 
 
 
31086ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import json
from typing import List, Union

import ray
import ray.data

from graphgen.bases.base_reader import BaseReader


class JSONReader(BaseReader):
    """
    Reader for JSON and JSONL files.
    Columns:
        - type: The type of the document (e.g., "text", "image", etc.)
        - if type is "text", "content" column must be present.
    """

    def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset:
        """
        Read JSON/JSONL file(s) and return a Ray Dataset.

        When two or more modalities are configured (``self.modalities``), each
        file is parsed in Python with the ``json`` module, dict-valued
        ``content`` fields are serialised to JSON strings, and the per-file
        datasets are unioned. Otherwise the path(s) are handed directly to
        ``ray.data.read_json``.

        In both cases the dataset is then passed through
        ``self._validate_batch`` and ``self._should_keep_item``, which are
        defined outside this class body.

        :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files.
        :return: Ray Dataset containing validated and filtered data.
        """
        if self.modalities and len(self.modalities) >= 2:
            files = input_path if isinstance(input_path, list) else [input_path]
            ds: ray.data.Dataset = ray.data.from_items([])
            for file in files:
                file_ds: ray.data.Dataset = ray.data.from_items(
                    self._load_records(file)
                )
                ds = ds.union(file_ds)  # type: ignore
        else:
            ds = ray.data.read_json(input_path)
        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
        ds = ds.filter(self._should_keep_item)
        return ds

    def _load_records(self, file: str) -> list:
        """
        Parse a single ``.jsonl`` or JSON file into a list of records.

        Files whose name ends with ``.jsonl`` are parsed one JSON document per
        line; anything else is parsed as a single JSON document. Both formats
        go through ``_unify_schema`` so that they yield the same row shape.

        :param file: Path of the file to parse.
        :return: List of parsed records.
        :raises json.JSONDecodeError: If the file contains invalid JSON.
        """
        with open(file, "r", encoding="utf-8") as f:
            if file.endswith(".jsonl"):
                # Blank lines (typically a trailing newline at EOF) are not
                # records; json.loads would raise on them.
                data = [json.loads(line) for line in f if line.strip()]
            else:
                data = json.load(f)
        return self._unify_schema(data)

    @staticmethod
    def _unify_schema(data):
        """
        Unify schema for JSON data.

        Every record whose ``content`` value is a dict has that value replaced
        in place by its JSON string form, so ``content`` is never a nested
        object. NOTE(review): presumably this keeps the column type uniform
        across modalities for Ray -- confirm.

        :param data: A list of records, or a single record (top-level JSON
            object), which is wrapped into a one-element list.
        :return: List of records; non-dict entries are passed through as-is.
        """
        if isinstance(data, dict):
            # A JSON file holding one top-level object is a single record;
            # iterating it directly would walk its keys instead.
            data = [data]
        for item in data:
            if isinstance(item, dict) and isinstance(item.get("content"), dict):
                item["content"] = json.dumps(item["content"])
        return data