Use torch.inference_mode() and disable gradient checkpointing

#4, opened by prathamj31
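
For anyone reading only the title, here is a rough sketch of what the two changes usually look like in inference code. It is not the actual diff from this PR: the small `nn.Sequential` model and the input shape are hypothetical stand-ins, and the `hasattr` guard is there because `gradient_checkpointing_disable()` is a method on Hugging Face `transformers` models, not on plain `nn.Module` objects.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; the real model from this repo is not shown here.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
model.eval()  # switch layers such as dropout to evaluation behaviour

# transformers' PreTrainedModel provides gradient_checkpointing_disable();
# a plain nn.Module does not, so the call is guarded in this sketch.
if hasattr(model, "gradient_checkpointing_disable"):
    model.gradient_checkpointing_disable()

x = torch.randn(4, 8)  # assumed batch of 4 inputs with 8 features

# inference_mode() turns off autograd tracking for everything inside the block.
with torch.inference_mode():
    out = model(x)

print(out.shape, out.requires_grad)
```

Gradient checkpointing trades extra compute for lower memory during the backward pass, so it has no benefit when no backward pass runs. `torch.inference_mode()` is a stricter variant of `torch.no_grad()`; tensors created inside it cannot later be used in autograd.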