File size: 2,610 Bytes
acd7cf4
 
fb9c306
d2a63cc
fb9c306
acd7cf4
 
 
 
9e67c3b
acd7cf4
 
 
 
 
 
 
 
 
 
9e67c3b
acd7cf4
 
9e67c3b
acd7cf4
 
9e67c3b
acd7cf4
 
9e67c3b
acd7cf4
 
 
 
 
 
 
 
 
 
 
9e67c3b
283e483
 
9e67c3b
acd7cf4
 
9e67c3b
acd7cf4
f1eedd1
 
acd7cf4
 
9e67c3b
f1eedd1
 
fb9c306
 
 
 
8c66169
 
fb9c306
 
 
 
 
 
 
 
 
 
 
9e67c3b
fb9c306
 
9e67c3b
fb9c306
 
9e67c3b
fb9c306
 
 
 
9e67c3b
fb9c306
 
9e67c3b
fb9c306
 
 
 
9e67c3b
fb9c306
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from dataclasses import dataclass

from graphgen.bases.base_storage import BaseKVStorage, BaseListStorage
from graphgen.utils import load_json, logger, write_json


@dataclass
class JsonKVStorage(BaseKVStorage):
    """Key-value storage backed by a single JSON file.

    The backing file is ``<working_dir>/<namespace>.json``. It is loaded into
    memory on construction; all reads and writes operate on the in-memory
    dict until :meth:`index_done_callback` writes it back to disk.
    """

    # In-memory mapping of key -> record; replaced in __post_init__.
    _data: dict[str, dict] = None

    def __post_init__(self):
        # NOTE(review): working_dir / namespace are not declared in this class;
        # presumably they are dataclass fields inherited from BaseKVStorage.
        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
        # Fall back to an empty dict whenever load_json returns something falsy
        # (presumably None for a missing file — confirm in graphgen.utils).
        self._data = load_json(self._file_name) or {}
        logger.info("Load KV %s with %d data", self.namespace, len(self._data))

    @property
    def data(self):
        """The live in-memory dict (not a copy)."""
        return self._data

    def all_keys(self) -> list[str]:
        """Return a new list of every key currently stored."""
        return list(self._data.keys())

    def index_done_callback(self):
        """Persist the in-memory dict to the backing JSON file."""
        write_json(self._data, self._file_name)

    def get_by_id(self, id):
        """Return the record stored under ``id``, or ``None`` if absent."""
        return self._data.get(id, None)

    def get_by_ids(self, ids, fields=None) -> list:
        """Look up several records at once, preserving the order of ``ids``.

        Args:
            ids: Keys to look up.
            fields: Optional collection of field names. When given, each found
                record is projected down to just those fields.

        Returns:
            A list aligned with ``ids``; missing keys yield ``None``.
        """
        if fields is None:
            return [self._data.get(key, None) for key in ids]
        projected = []
        for key in ids:
            record = self._data.get(key)
            # Explicit None check: an empty record ({}) is present and must
            # project to {} rather than be reported as missing.
            if record is None:
                projected.append(None)
            else:
                projected.append({k: v for k, v in record.items() if k in fields})
        return projected

    def get_all(self) -> dict[str, dict]:
        """Return the live in-memory dict of all records (not a copy)."""
        return self._data

    def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of ``data`` whose keys are not yet stored."""
        return {s for s in data if s not in self._data}

    def upsert(self, data: dict):
        """Insert entries whose keys are not already stored.

        Despite the name, existing keys are NOT overwritten; their stored
        records are left untouched.

        Returns:
            A dict of only the entries that were newly inserted.
        """
        left_data = {k: v for k, v in data.items() if k not in self._data}
        if left_data:
            self._data.update(left_data)
        return left_data

    def drop(self):
        """Remove all records from memory, clearing the dict in place.

        The backing file is not modified until :meth:`index_done_callback`
        is called.
        """
        if self._data:
            self._data.clear()


@dataclass
class JsonListStorage(BaseListStorage):
    """List storage backed by a single JSON file.

    The backing file is ``<working_dir>/<namespace>.json``. It is loaded into
    memory on construction; all operations act on the in-memory list until
    :meth:`index_done_callback` writes it back to disk.
    """

    working_dir: str = None
    namespace: str = None
    # In-memory list of items; replaced in __post_init__.
    _data: list = None

    def __post_init__(self):
        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
        # Fall back to an empty list whenever load_json returns something falsy
        # (presumably None for a missing file — confirm in graphgen.utils).
        self._data = load_json(self._file_name) or []
        logger.info("Load List %s with %d data", self.namespace, len(self._data))

    @property
    def data(self):
        """The live in-memory list (not a copy)."""
        return self._data

    def all_items(self) -> list:
        """Return the live in-memory list of all items (not a copy)."""
        return self._data

    def index_done_callback(self):
        """Persist the in-memory list to the backing JSON file."""
        write_json(self._data, self._file_name)

    def get_by_index(self, index: int):
        """Return the item at ``index``, or ``None`` if it is out of range.

        Negative indices are treated as out of range instead of counting
        from the end of the list.
        """
        if index < 0 or index >= len(self._data):
            return None
        return self._data[index]

    def append(self, data):
        """Append ``data`` unconditionally (no duplicate check)."""
        self._data.append(data)

    def upsert(self, data: list):
        """Append the items of ``data`` that are not already stored.

        Items already present (by ``==``) are skipped, and duplicates inside
        ``data`` itself are added only once.

        Returns:
            The list of items that were newly appended, in input order.
        """
        left_data = []
        for item in data:
            # Check against the live list, which already contains items accepted
            # earlier in this batch. A linear scan is used because items may be
            # unhashable (e.g. dicts).
            if item not in self._data:
                self._data.append(item)
                left_data.append(item)
        return left_data

    def drop(self):
        """Discard all items in memory by rebinding to a fresh empty list.

        References obtained earlier via ``data`` or ``all_items()`` keep the old
        contents. The backing file is not modified until
        :meth:`index_done_callback` is called.
        """
        self._data = []