InternAgent
/
examples
/AutoPower_IEEE39_bus
/AdaptiveHierarchicalGraphTransformer
/src
/dataset
/powerflow_dataset.py
| import os | |
| import torch | |
| from torch.utils.data import Dataset | |
| import json | |
| from torch_geometric.data import HeteroData | |
| import networkx as nx | |
class PowerFlowDataset(Dataset):
    """Dataset of power-flow graphs loaded lazily from files saved with torch.

    The split file is read as JSON Lines: each non-blank line is a JSON object
    that must contain a ``file_path`` key and may contain a ``masked_node``
    list. Each sample is loaded with ``torch.load`` and is expected to expose
    ``'PQ'``, ``'PV'`` and ``'Slack'`` node stores, each with a feature
    matrix ``x`` whose first four columns are Vm, Va, P_net and Q_net.

    NOTE(review): ``data_root`` is stored but is not used when resolving
    sample paths; ``file_path`` entries are loaded exactly as written in the
    split file. Confirm whether they are meant to be relative to
    ``data_root``.
    """

    def __init__(self, data_root, split_txt, pq_len, pv_len, slack_len, mask_num=0):
        """Read the split file and store the configuration.

        Args:
            data_root: Root directory of the dataset (stored only).
            split_txt: Path to a JSON Lines file describing the samples.
            pq_len: Stored as ``self.pq_len``; not used inside this class.
            pv_len: Stored as ``self.pv_len``; not used inside this class.
            slack_len: Stored as ``self.slack_len``; not used inside this class.
            mask_num: Number of PQ nodes whose Q_net input is hidden per
                sample; ``0`` disables masking.
        """
        self.data_root = data_root
        with open(split_txt, 'r') as f:
            # Skip blank lines so stray empty lines in the split file do not
            # make json.loads raise.
            self.file_list = [json.loads(line) for line in f if line.strip()]
        self.pq_len = pq_len
        self.pv_len = pv_len
        self.slack_len = slack_len
        self.mask_num = mask_num
        # for shortest path
        self.flag_distance_once_calculated = False
        self.shortest_paths = None
        self.node_type_to_global_index = None
        self.max_depth = 16

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.file_list)

    def update_max_depth(self):
        """Lower ``max_depth`` to the largest value in ``shortest_paths``.

        ``max_depth`` is only ever decreased. When ``shortest_paths`` has not
        been populated yet (``None``) or is empty, nothing is changed.
        """
        if not self.shortest_paths:
            return
        longest = max(self.shortest_paths.values())
        self.max_depth = min(self.max_depth, longest)

    def __getitem__(self, idx):
        """Load sample ``idx`` and turn its features into model inputs.

        For every node type the true ``[Vm, Va, P_net, Q_net]`` columns are
        copied into ``.y`` before the unknown quantities in ``.x`` are
        overwritten with initial values.

        Returns:
            The loaded object with ``.y`` added to the ``'PQ'``, ``'PV'`` and
            ``'Slack'`` stores, ``q_mask`` added to ``'PQ'``, and ``.x``
            modified in place.
        """
        file_dict = self.file_list[idx]
        data = torch.load(os.path.join(file_dict['file_path']))

        # Column indices into each node feature matrix (columns 4 and 5 hold
        # Gs and Bs in the original layout and are left untouched here).
        Vm, Va, P_net, Q_net = 0, 1, 2, 3
        target_cols = [Vm, Va, P_net, Q_net]

        pq = data['PQ']
        pv = data['PV']
        slack = data['Slack']
        pq_num = pq.x.shape[0]
        # Voltage angle of the first slack node; used as the initial angle
        # for every PQ and PV node.
        slack_va = slack.x[0, Va].item()

        # ------- add initial values --------
        # y = Vm, Va, P_net, Q_net
        pq.y = pq.x[:, target_cols].clone().detach()
        pq.x[:, Vm] = 1.0  # Vm unknown
        pq.x[:, Va] = slack_va  # Va unknown, uses value from Slack

        # q_mask is True where Q_net stays visible to the model.
        pq.q_mask = torch.ones((pq_num,), dtype=torch.bool)
        if self.mask_num > 0:
            if file_dict.get('masked_node') is None:
                # No fixed mask given: randomly pick up to mask_num PQ nodes
                # among those whose Q_net is non-zero.
                non_zero_indices = torch.nonzero(pq.x[:, Q_net]).flatten()
                shuffled = torch.randperm(non_zero_indices.shape[0])
                mask_indices = non_zero_indices[shuffled[:self.mask_num]]
            else:
                mask_indices = file_dict['masked_node'][:self.mask_num]
            pq.q_mask[mask_indices] = False
            pq.x[~pq.q_mask, Q_net] = 0

        pv.y = pv.x[:, target_cols].clone().detach()
        pv.x[:, Va] = slack_va  # Va unknown, uses value from Slack
        pv.x[:, Q_net] = 0  # Q unknown for PV node, set to 0

        slack.y = slack.x[:, target_cols].clone().detach()
        slack.x[:, P_net] = 0  # P_net unknown for slack node
        slack.x[:, Q_net] = 0  # Q_net unknown for slack node
        return data