GraphGen / graphgen /bases /datatypes.py
github-actions[bot]
Auto-sync from demo at Tue Dec 16 08:21:05 UTC 2025
31086ae
import math
from dataclasses import dataclass, field
from typing import List, Union
from pydantic import BaseModel, Field, field_validator
@dataclass
class Chunk:
    """A single piece of content identified by ``id``.

    ``metadata`` defaults to a fresh empty dict for every instance.
    """

    id: str
    content: str
    type: str
    metadata: dict = field(default_factory=dict)

    @staticmethod
    def from_dict(key: str, data: dict) -> "Chunk":
        """Build a ``Chunk`` from a storage key and its raw record.

        Args:
            key: Value used as the chunk ``id``.
            data: Raw record. ``"content"`` defaults to ``""`` and ``"type"``
                defaults to ``"text"`` when absent.

        Returns:
            A ``Chunk`` whose ``metadata`` holds every entry of ``data``
            except ``"content"`` (so a ``"type"`` entry, if present, is also
            kept in the metadata).
        """
        body = data.get("content", "")
        kind = data.get("type", "text")
        # Everything that is not the content itself travels along as metadata.
        extras = {
            name: value for name, value in data.items() if name != "content"
        }
        return Chunk(id=key, content=body, type=kind, metadata=extras)
@dataclass
class QAPair:
    """One question together with its answer."""

    question: str
    answer: str
@dataclass
class Token:
    """A piece of text together with its probability.

    Attributes:
        text: The token text.
        prob: Probability of the token.
        top_candidates: Optional list attached to the token (presumably the
            alternative candidates at this position; confirm against callers).
            Defaults to a fresh empty list.
        ppl: Optional float, ``None`` by default (presumably a perplexity
            value; confirm against callers).
    """

    text: str
    prob: float
    top_candidates: List = field(default_factory=list)
    ppl: Union[float, None] = field(default=None)

    @property
    def logprob(self) -> float:
        """Return the natural logarithm of ``prob``.

        A probability of exactly zero (e.g. after floating-point underflow)
        yields ``-inf`` instead of raising a math domain error.

        Raises:
            ValueError: If ``prob`` is negative.
        """
        # math.log(0) raises ValueError; the mathematical limit is -inf.
        if self.prob == 0:
            return -math.inf
        return math.log(self.prob)
@dataclass
class Community:
    """A group of nodes and edges identified by ``id``.

    Attributes:
        id: Community identifier, either an int or a string.
        nodes: Node names belonging to the community.
        edges: Edges of the community, stored as tuples.
        metadata: Free-form extra information.

    Every container field defaults to its own fresh empty instance.
    """

    id: Union[int, str]
    nodes: List[str] = field(default_factory=list)
    edges: List[tuple] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)
class Node(BaseModel):
    """One operator node of the computation graph.

    ``type`` is validated against the supported task types when the model is
    constructed.
    """

    id: str = Field(..., description="unique node id")
    op_name: str = Field(..., description="operator name")
    type: str = Field(
        ..., description="task type, e.g., map, filter, flatmap, aggregate, map_batch"
    )
    params: dict = Field(default_factory=dict, description="operator parameters")
    dependencies: List[str] = Field(
        default_factory=list, description="list of dependent node ids"
    )
    execution_params: dict = Field(
        default_factory=dict, description="execution parameters like replicas, batch_size"
    )

    # NOTE: ``@field_validator`` must be the outermost decorator. With
    # ``@classmethod`` on top, the class attribute is a plain classmethod
    # wrapper that pydantic does not register, so the check would never run.
    @field_validator("type")
    @classmethod
    def validate_type(cls, v: str) -> str:
        """Reject task types outside the supported set.

        Args:
            v: The value supplied for ``type``.

        Returns:
            ``v`` unchanged when it is a supported task type.

        Raises:
            ValueError: If ``v`` is not one of the supported task types.
        """
        valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"}
        if v not in valid_types:
            raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.")
        return v
class Config(BaseModel):
    """Top-level description of a computation graph.

    Holds the global parameters and the non-empty list of nodes; node ids are
    validated to be unique when the model is constructed.
    """

    global_params: dict = Field(
        default_factory=dict, description="global context for the computation graph"
    )
    nodes: List[Node] = Field(
        ..., min_length=1, description="list of nodes in the computation graph"
    )

    # NOTE: ``@field_validator`` must be the outermost decorator. With
    # ``@classmethod`` on top, the class attribute is a plain classmethod
    # wrapper that pydantic does not register, so the check would never run.
    @field_validator("nodes")
    @classmethod
    def validate_unique_ids(cls, v: List[Node]) -> List[Node]:
        """Ensure that no two nodes share the same id.

        Args:
            v: The validated list of nodes.

        Returns:
            ``v`` unchanged when all node ids are distinct.

        Raises:
            ValueError: If at least one node id occurs more than once; the
                message lists the duplicated ids.
        """
        seen = set()
        duplicates = set()
        # Single pass: an id already seen once is a duplicate.
        for node in v:
            if node.id in seen:
                duplicates.add(node.id)
            else:
                seen.add(node.id)
        if duplicates:
            raise ValueError(f"Duplicate node ids found: {duplicates}")
        return v