import torch
from torch_geometric.datasets import QM9 as QM9_geometric
from torch_geometric.nn.models.schnet import qm9_target_dict
from torch_geometric.transforms import Compose

class QM9(QM9_geometric):
    """QM9 wrapper that trains on a single target property.

    ``dataset_arg`` selects one of the property names in
    ``qm9_target_dict``; a label transform reduces ``y`` to that column.
    """

    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None, dataset_arg=None):
        assert dataset_arg is not None, (
            "Please pass the desired property to "
            'train on via "dataset_arg". Available '
            f'properties are {", ".join(qm9_target_dict.values())}.'
        )

        self.label = dataset_arg
        # Invert qm9_target_dict (index -> name) so the property name
        # maps back to its column index in y.
        label2idx = dict(zip(qm9_target_dict.values(), qm9_target_dict.keys()))
        self.label_idx = label2idx[self.label]

        # Append the label filter to any user-supplied transform so every
        # sample's y is reduced to the selected property column.
        if transform is None:
            transform = self._filter_label
        else:
            transform = Compose([transform, self._filter_label])

        super().__init__(root, transform=transform,
                         pre_transform=pre_transform, pre_filter=pre_filter)

    def get_atomref(self, max_z=100):
        """Return per-element reference values for the selected target,
        zero-padded or truncated to ``max_z`` rows, or None if the target
        has no atomrefs."""
        atomref = self.atomref(self.label_idx)
        if atomref is None:
            return None
        if atomref.size(0) != max_z:
            tmp = torch.zeros(max_z).unsqueeze(1)
            idx = min(max_z, atomref.size(0))
            tmp[:idx] = atomref[:idx]
            return tmp
        return atomref

    def _filter_label(self, batch):
        # Keep only the selected target column; shape (num_graphs, 1).
        batch.y = batch.y[:, self.label_idx].unsqueeze(1)
        return batch
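

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original module):
# "data/qm9" is a hypothetical download path, and "energy_U0" is assumed to
# be one of the property names listed in qm9_target_dict.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dataset = QM9(root="data/qm9", dataset_arg="energy_U0")
    print(f"training target: {dataset.label} (column {dataset.label_idx})")
    print(dataset[0].y.shape)  # torch.Size([1, 1]) after _filter_label
    atomref = dataset.get_atomref()
    if atomref is not None:
        print(atomref.shape)  # torch.Size([100, 1])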