### This is an example of the script that will be run in the test environment.
### Some parts of the code are compulsory and you should NOT CHANGE THEM.
### They are enclosed between the '''---compulsory---''' markers.
### You can change the rest of the code to define and test your solution.
### However, you should not change the signature of the provided function.
### The script should save a "submission.parquet" file; this example writes it
### into the directory given by params['output_path'].
### You can use any additional files and subdirectories to organize your code.
| '''---compulsory---''' | |
| import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE | |
| '''---compulsory---''' | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import pandas as pd | |
| import numpy as np | |
def empty_solution(sample):
    """Build the smallest wireframe the submission accepts.

    ``sample`` is ignored. It is accepted only so that this function has
    the same call shape as a real predictor that would replace it.

    Returns:
        A tuple ``(vertices, edges)``:
        ``vertices`` is a ``(2, 3)`` float array filled with zeros, and
        ``edges`` is a list holding the single pair ``(0, 1)`` that links
        those two vertices.
    """
    n_vertices, n_dims = 2, 3
    vertices = np.zeros((n_vertices, n_dims))
    edges = [(0, 1)]
    return vertices, edges
if __name__ == "__main__":
    print("------------ Loading dataset------------ ")
    params = hoho.get_params()

    # The usual call is `hoho.get_dataset(split='all')`. This example never
    # inspects a sample's contents, so it passes `decode=None` instead.
    # For local testing, `split` can be set to 'train' or 'val'
    # (locally, 'all' falls back to 'train'), e.g.
    #
    #     dataset = hoho.get_dataset(split='val', decode=None)
    #
    # On the test server `split` MUST be 'all' so that both the public and
    # the private leaderboards are computed.
    dataset = hoho.get_dataset(split='all', decode=None)

    print('------------ Now you can do your solution ---------------')
    solution = []
    for sample in tqdm(dataset):
        # Replace `empty_solution` with your own predictor; it must return
        # a (vertices, edges) pair just like `empty_solution` does.
        pred_vertices, pred_edges = empty_solution(sample)
        solution.append({
            '__key__': sample['__key__'],
            # np.asarray lets the predictor return either an ndarray or a
            # nested list; both end up as plain Python lists for parquet.
            'wf_vertices': np.asarray(pred_vertices).tolist(),
            'wf_edges': pred_edges,
        })

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["__key__", "wf_vertices", "wf_edges"])
    output_dir = Path(params['output_path'])
    # Make sure the destination directory exists so the write below does not
    # fail on a missing path; this is a no-op when it already exists.
    output_dir.mkdir(parents=True, exist_ok=True)
    sub.to_parquet(output_dir / "submission.parquet")
    print("------------ Done ------------ ")