File size: 1,982 Bytes
9e67c3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
from typing import Any

import ray

from graphgen.models import JsonKVStorage, JsonListStorage, NetworkXStorage


@ray.remote
class StorageManager:
    """
    Ray actor that owns every storage backend shared by the operators.

    One instance is created up front; its actor handle is passed to tasks and
    other actors, which fetch a backend by name through ``get_storage``.

    Example Usage:
    --------------
    # init
    storage_manager = StorageManager.remote(working_dir="/path/to/dir", unique_id=123)

    # visit storage in tasks
    @ray.remote
    def some_task(storage_manager):
        full_docs_storage = ray.get(storage_manager.get_storage.remote("full_docs"))

    # visit storage in other actors
    @ray.remote
    class SomeOperator:
        def __init__(self, storage_manager):
            self.storage_manager = storage_manager
        def some_method(self):
            full_docs_storage = ray.get(self.storage_manager.get_storage.remote("full_docs"))

    NOTE(review): ``ray.get`` gives the caller a deserialized copy of the
    object returned by the actor, not a live reference into it. Changes made
    to that copy are presumably only visible elsewhere if the storage class
    persists them under its directory and others reload -- confirm.
    """

    def __init__(self, working_dir: str, unique_id: int):
        """
        Construct all storage backends and keep them keyed by namespace.

        Args:
            working_dir: Root directory. The "full_docs", "chunks", "graph",
                "rephrase" and "partition" backends are created directly in it.
            unique_id: Run identifier. The "search", "extraction" and "qa"
                backends are created in
                ``<working_dir>/data/graphgen/<unique_id>``.
        """
        self.working_dir = working_dir
        self.unique_id = unique_id

        # Per-run directory shared by the last three backends below.
        run_dir = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")

        # (namespace, storage class, base directory), listed in the exact
        # order the backends are constructed and inserted into the mapping.
        layout = (
            ("full_docs", JsonKVStorage, working_dir),
            ("chunks", JsonKVStorage, working_dir),
            ("graph", NetworkXStorage, working_dir),
            ("rephrase", JsonKVStorage, working_dir),
            ("partition", JsonListStorage, working_dir),
            ("search", JsonKVStorage, run_dir),
            ("extraction", JsonKVStorage, run_dir),
            ("qa", JsonListStorage, run_dir),
        )
        self.storages = {
            namespace: storage_cls(base_dir, namespace=namespace)
            for namespace, storage_cls, base_dir in layout
        }

    def get_storage(self, name: str) -> Any:
        """
        Return the storage backend registered under ``name``.

        Args:
            name: One of the namespaces created in ``__init__``.

        Returns:
            The matching storage object, or ``None`` if ``name`` is not a
            known namespace.
        """
        return self.storages.get(name, None)