Spaces:
Running
Running
File size: 2,610 Bytes
acd7cf4 fb9c306 d2a63cc fb9c306 acd7cf4 9e67c3b acd7cf4 9e67c3b acd7cf4 9e67c3b acd7cf4 9e67c3b acd7cf4 9e67c3b acd7cf4 9e67c3b 283e483 9e67c3b acd7cf4 9e67c3b acd7cf4 f1eedd1 acd7cf4 9e67c3b f1eedd1 fb9c306 8c66169 fb9c306 9e67c3b fb9c306 9e67c3b fb9c306 9e67c3b fb9c306 9e67c3b fb9c306 9e67c3b fb9c306 9e67c3b fb9c306 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import os
from dataclasses import dataclass
from graphgen.bases.base_storage import BaseKVStorage, BaseListStorage
from graphgen.utils import load_json, logger, write_json
@dataclass
class JsonKVStorage(BaseKVStorage):
    """Key-value storage kept in memory and persisted as a single JSON file.

    The backing file is ``<working_dir>/<namespace>.json``. ``working_dir``
    and ``namespace`` are not declared in this class and are presumably
    provided by ``BaseKVStorage`` (confirm against the base class).
    """

    # In-memory mapping of key -> record; replaced in __post_init__.
    _data: dict[str, dict] = None

    def __post_init__(self):
        """Resolve the backing file path and load its contents into memory."""
        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
        # A falsy result from load_json (presumably a missing or empty file;
        # confirm against the helper) falls back to an empty mapping.
        self._data = load_json(self._file_name) or {}
        logger.info("Load KV %s with %d data", self.namespace, len(self._data))

    @property
    def data(self):
        """The live in-memory mapping (not a copy)."""
        return self._data

    def all_keys(self) -> list[str]:
        """Return a list of every key currently held."""
        return list(self._data)

    def index_done_callback(self):
        """Hand the in-memory mapping and the file name to ``write_json``."""
        write_json(self._data, self._file_name)

    def get_by_id(self, id):
        """Return the record stored under ``id``, or ``None`` if absent."""
        # The parameter keeps its original name (shadowing the builtin) so
        # that callers passing ``id=...`` by keyword keep working.
        return self._data.get(id)

    def get_by_ids(self, ids, fields=None) -> list:
        """Return the records for ``ids``, in order, optionally projected.

        Args:
            ids: Iterable of keys to look up.
            fields: Optional container of field names. When given, each
                found record is reduced to the entries whose key is in
                ``fields``.

        Returns:
            One entry per requested key: the record (or its projection), or
            ``None`` when the key is not stored.
        """
        if fields is None:
            return [self._data.get(key) for key in ids]

        records = (self._data.get(key) for key in ids)
        # Test for absence explicitly rather than by truthiness: a stored but
        # empty record must yield an empty projection, consistent with the
        # unprojected branch above, instead of being reported as missing.
        return [
            None
            if record is None
            else {k: v for k, v in record.items() if k in fields}
            for record in records
        ]

    def get_all(self) -> dict[str, dict]:
        """Return the live in-memory mapping of all records (not a copy)."""
        return self._data

    def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of ``data`` whose keys are not stored yet."""
        return {key for key in data if key not in self._data}

    def upsert(self, data: dict):
        """Insert the entries of ``data`` whose keys are not stored yet.

        Entries whose key already exists are left untouched (they are not
        overwritten).

        Returns:
            A dict holding only the entries that were actually inserted.
        """
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        return left_data

    def drop(self):
        """Remove every record from memory, emptying the mapping in place."""
        self._data.clear()
@dataclass
class JsonListStorage(BaseListStorage):
    """List storage kept in memory and persisted as a single JSON file.

    The backing file is ``<working_dir>/<namespace>.json``.
    """

    # Directory that holds the backing JSON file.
    working_dir: str = None
    # Base name (without extension) of the backing JSON file.
    namespace: str = None
    # In-memory list of items; replaced in __post_init__.
    _data: list = None

    def __post_init__(self):
        """Resolve the backing file path and load its contents into memory."""
        json_name = f"{self.namespace}.json"
        self._file_name = os.path.join(self.working_dir, json_name)
        loaded = load_json(self._file_name)
        # Any falsy result from load_json is replaced by a fresh empty list.
        self._data = loaded if loaded else []
        logger.info("Load List %s with %d data", self.namespace, len(self._data))

    @property
    def data(self):
        """The live in-memory list (not a copy)."""
        return self._data

    def all_items(self) -> list:
        """Return the live in-memory list of all items (not a copy)."""
        return self._data

    def index_done_callback(self):
        """Hand the in-memory list and the file name to ``write_json``."""
        write_json(self._data, self._file_name)

    def get_by_index(self, index: int):
        """Return the item at ``index``, or ``None`` when out of range.

        Negative indices are treated as out of range rather than counting
        from the end.
        """
        items = self._data
        out_of_range = index < 0 or index >= len(items)
        return None if out_of_range else items[index]

    def append(self, data):
        """Add ``data`` as a new item at the end of the list."""
        self._data.append(data)

    def upsert(self, data: list):
        """Append the items of ``data`` that are not stored yet.

        Each candidate is checked against the items stored before this call,
        so repeated items inside ``data`` itself are all appended.

        Returns:
            The list of items that were actually appended.
        """
        stored = self._data
        fresh = [candidate for candidate in data if candidate not in stored]
        stored.extend(fresh)
        return fresh

    def drop(self):
        """Discard all items by rebinding the storage to a new empty list."""
        self._data = list()
|