import math
from dataclasses import dataclass, field
from typing import List, Union

from pydantic import BaseModel, Field, field_validator


@dataclass
class Chunk:
    """A unit of source content with its type and associated metadata."""

    id: str
    content: str
    type: str
    metadata: dict = field(default_factory=dict)

    @staticmethod
    def from_dict(key: str, data: dict) -> "Chunk":
        # Everything except "content" (including "type") is preserved as metadata.
        return Chunk(
            id=key,
            content=data.get("content", ""),
            type=data.get("type", "text"),
            metadata={k: v for k, v in data.items() if k != "content"},
        )


@dataclass
class QAPair:
    """
    A pair of question and answer.
    """

    question: str
    answer: str


@dataclass
class Token:
    """A generated token with its probability and optional decoding statistics."""

    text: str
    prob: float
    top_candidates: List = field(default_factory=list)
    ppl: Union[float, None] = field(default=None)  # perplexity, if available

    @property
    def logprob(self) -> float:
        # Natural log of the token probability; prob must be strictly positive.
        return math.log(self.prob)


@dataclass
class Community:
    """A graph community: a set of node ids and the edges among them."""

    id: Union[int, str]
    nodes: List[str] = field(default_factory=list)
    edges: List[tuple] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)


class Node(BaseModel):
    """A single operator node in the computation graph."""

    id: str = Field(..., description="unique node id")
    op_name: str = Field(..., description="operator name")
    type: str = Field(
        ..., description="task type, e.g., map, filter, flatmap, aggregate, map_batch"
    )
    params: dict = Field(default_factory=dict, description="operator parameters")
    dependencies: List[str] = Field(
        default_factory=list, description="list of dependent node ids"
    )
    execution_params: dict = Field(
        default_factory=dict, description="execution parameters like replicas, batch_size"
    )

    # In pydantic v2, @field_validator must be applied above @classmethod.
    @field_validator("type")
    @classmethod
    def validate_type(cls, v: str) -> str:
        valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"}
        if v not in valid_types:
            raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.")
        return v


class Config(BaseModel):
    """Top-level configuration: global parameters plus the computation graph nodes."""

    global_params: dict = Field(
        default_factory=dict, description="global context for the computation graph"
    )

    nodes: List[Node] = Field(
        ..., min_length=1, description="list of nodes in the computation graph"
    )

    # As above, @field_validator goes above @classmethod so the validator is registered.
    @field_validator("nodes")
    @classmethod
    def validate_unique_ids(cls, v: List[Node]) -> List[Node]:
        ids = [node.id for node in v]
        if len(ids) != len(set(ids)):
            duplicates = {id_ for id_ in ids if ids.count(id_) > 1}
            raise ValueError(f"Duplicate node ids found: {duplicates}")
        return v
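

# Illustrative usage sketch: builds a minimal two-node graph to exercise the
# validators above. The ids, op names, and field values here are hypothetical
# examples chosen for demonstration only.
if __name__ == "__main__":
    chunk = Chunk.from_dict("c1", {"content": "hello", "type": "text", "source": "demo"})
    print(chunk.metadata)  # {'type': 'text', 'source': 'demo'}

    config = Config(
        global_params={"language": "en"},
        nodes=[
            Node(id="read", op_name="reader", type="map"),
            Node(id="clean", op_name="cleaner", type="filter", dependencies=["read"]),
        ],
    )
    print([n.id for n in config.nodes])  # ['read', 'clean']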