tricks for PyTorch

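Problem: the GPU runs out of memory, and the free GPU memory shrinks with every epoch.

Solution: torch.cuda.empty_cache() returns the unused blocks held by PyTorch's caching allocator to the GPU. It cannot free memory that live tensors still reference, so if usage keeps growing, look for tensors kept alive across iterations (for example, summing loss instead of loss.item(), which keeps every step's autograd graph in memory).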
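A minimal sketch of where the call usually goes, assuming a toy linear model and random data (the train function, the data shapes, batch size and learning rate are all hypothetical). The loop accumulates loss.item() rather than the loss tensor, and calls empty_cache() once per epoch; on a machine without CUDA the call is a harmless no-op.

```python
import torch
import torch.nn as nn

# Use the GPU if one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, data, targets, epochs=3, lr=0.1, batch_size=4):
    """Hypothetical toy training loop; returns the mean loss of each epoch as plain floats."""
    model = model.to(device)
    data, targets = data.to(device), targets.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    n_batches = len(data.split(batch_size))
    history = []
    for _ in range(epochs):
        running = 0.0
        for x, y in zip(data.split(batch_size), targets.split(batch_size)):
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            # .item() copies the value into a Python float, so this step's graph can be freed
            running += loss.item()
        history.append(running / n_batches)
        # Release cached-but-unused blocks; does nothing if CUDA was never initialized
        torch.cuda.empty_cache()
    return history

torch.manual_seed(0)
X = torch.randn(16, 3)           # hypothetical toy inputs
y = X.sum(dim=1, keepdim=True)   # hypothetical target: the sum of the features
losses = train(nn.Linear(3, 1), X, y)
print(losses)
```

Because history only ever holds Python floats, nothing from earlier epochs pins GPU memory.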

show the elements of a dictionary

def show_dict(dic):  # function name chosen for illustration
    for k, v in dic.items():
        print(k, v)
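In PyTorch the dictionary you most often want to inspect is a model's state_dict(). A small sketch, assuming a plain nn.Linear layer (describe_state_dict is a hypothetical helper name), that lists each parameter name with its shape instead of dumping whole tensors:

```python
import torch.nn as nn

def describe_state_dict(model):
    """Hypothetical helper: return one 'name: shape' line per entry of model.state_dict()."""
    lines = []
    for name, tensor in model.state_dict().items():
        lines.append(f"{name}: {tuple(tensor.shape)}")
    return lines

lines = describe_state_dict(nn.Linear(3, 2))
for line in lines:
    print(line)
# weight: (2, 3)
# bias: (2,)
```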

put the model on the GPU

model = model.cuda()
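A device-agnostic variant, sketched under the assumption that the code should also run on machines without a GPU (the layer and input sizes are arbitrary). The inputs must live on the same device as the model; note that nn.Module.to() moves the parameters in place, while Tensor.to() returns a new tensor, so tensors must be reassigned.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2).to(device)   # moves the parameters in place and returns the module
x = torch.randn(8, 4).to(device)     # returns a copy on `device`, so keep the result
out = model(x)                       # model and input are on the same device
print(out.device)
```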

convert the dtype of a tensor to float (float32)

X = X.float()
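A typical reason this is needed, shown as a small sketch with random data: NumPy arrays default to float64, torch.from_numpy keeps that dtype, and nn layers hold float32 parameters, so passing the float64 tensor straight into the layer raises a dtype-mismatch RuntimeError.

```python
import numpy as np
import torch
import torch.nn as nn

arr = np.random.rand(5, 3)    # NumPy creates float64 arrays by default
X = torch.from_numpy(arr)
print(X.dtype)                # torch.float64

layer = nn.Linear(3, 1)       # layer parameters are float32 by default
X = X.float()                 # same as X.to(torch.float32)
out = layer(X)                # dtypes now match
print(X.dtype, out.dtype)     # torch.float32 torch.float32
```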

convert between tensors and NumPy arrays


b = a.numpy()
# numpy to tensor
a = torch.from_numpy(b)
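Both conversions share memory rather than copying, which is easy to trip over. A small sketch with toy values: writing through one side shows up in the other, torch.tensor() makes an independent copy, and a tensor that requires grad has to be detached first.

```python
import numpy as np
import torch

a = torch.zeros(3)
b = a.numpy()                   # b shares memory with a
a[0] = 1.0
print(b)                        # [1. 0. 0.]

arr = np.arange(3.0)
shared = torch.from_numpy(arr)  # shares memory with arr
copied = torch.tensor(arr)      # always copies the data
arr[1] = 10.0
print(shared[1].item(), copied[1].item())  # 10.0 1.0

w = torch.ones(2, requires_grad=True)
# w.numpy() would raise here; detach from autograd and move to the CPU first
w_np = w.detach().cpu().numpy()
```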

save/load the model

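import torch

# save: the weights plus any extra info (here, the epoch) in one dict
torch.save({'state': model.state_dict(), 'epoch': epoch}, path)
# load: map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load(path, map_location='cpu')
model.load_state_dict(checkpoint['state'])
epoch = checkpoint['epoch']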
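A self-contained round trip, sketched under a few assumptions: the checkpoint also stores the optimizer state so training can resume (the 'optim' key and the save_checkpoint/load_checkpoint helpers are hypothetical; 'state' and 'epoch' follow the snippet above), and the file goes to a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn

def save_checkpoint(path, model, optimizer, epoch):
    # Store state_dicts rather than whole objects; loading then only needs the class definitions
    torch.save({'state': model.state_dict(),
                'optim': optimizer.state_dict(),   # hypothetical key for the optimizer state
                'epoch': epoch}, path)

def load_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state'])
    optimizer.load_state_dict(checkpoint['optim'])
    return checkpoint['epoch']

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
path = os.path.join(tempfile.mkdtemp(), 'ckpt.pt')
save_checkpoint(path, model, optimizer, epoch=5)

# A freshly initialized model picks up the saved weights
restored = nn.Linear(3, 1)
restored_opt = torch.optim.SGD(restored.parameters(), lr=0.1)
start_epoch = load_checkpoint(path, restored, restored_opt)
print(start_epoch)                                   # 5
print(torch.equal(model.weight, restored.weight))    # True
```

Saving the optimizer state matters for optimizers with internal buffers (momentum, Adam moments); for plain SGD without momentum it only restores the hyperparameters.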