tips on randomness of pytorch

In this post, you will learn how seeding works in PyTorch, including how random, torch.random and numpy.random behave under single- and multi-process data loading. You will also learn how to keep the randomness distinct across processes (the workers of a DataLoader) and how to make different experiments reproducible.

To begin with, let’s go through some strange phenomena. Run each of the following four snippets at least twice, and check whether repeated runs of the same snippet print the same values.

```python
# Snippet 1: torch.manual_seed(555), num_workers=0
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
import random

torch.manual_seed(555)

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
    def __len__(self):
        return 10
    def __getitem__(self, index):
        x = random.randint(0, 10)
        return x

m = MyDataset()
data_loader = DataLoader(m, num_workers=0)
for x in data_loader:
    print(x)
```
```python
# Snippet 2: torch.manual_seed(555), num_workers=1
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
import random

torch.manual_seed(555)

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
    def __len__(self):
        return 10
    def __getitem__(self, index):
        x = random.randint(0, 10)
        return x

m = MyDataset()
data_loader = DataLoader(m, num_workers=1)
for x in data_loader:
    print(x)
```
```python
# Snippet 3: random.seed(555), num_workers=0
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
import random

random.seed(555)

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
    def __len__(self):
        return 10
    def __getitem__(self, index):
        x = random.randint(0, 10)
        return x

m = MyDataset()
data_loader = DataLoader(m, num_workers=0)
for x in data_loader:
    print(x)
```
```python
# Snippet 4: random.seed(555), num_workers=1
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
import random

random.seed(555)

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
    def __len__(self):
        return 10
    def __getitem__(self, index):
        x = random.randint(0, 10)
        return x

m = MyDataset()
data_loader = DataLoader(m, num_workers=1)
for x in data_loader:
    print(x)
```

You will observe that:

  1. torch.manual_seed(555), num_workers=0 cannot reproduce
  2. torch.manual_seed(555), num_workers=1 can reproduce
  3. random.seed(555), num_workers=0 can reproduce
  4. random.seed(555), num_workers=1 cannot reproduce

Why does this happen? How can torch.manual_seed(555) control random.randint when num_workers=1, and why can't random.seed(555)? And why do the results flip when num_workers=0?

Let’s start with PyTorch's DataLoader.

DataLoader in PyTorch

In PyTorch, users usually implement their own Dataset and wrap it in a DataLoader to get a ready-to-use data iterator. DataLoader takes many useful parameters, such as batch_size, shuffle, num_workers and worker_init_fn. num_workers selects single- or multi-process loading. With num_workers=0, data is loaded in the main process, so the other computing steps have to wait for it. With num_workers>=1, PyTorch starts that many worker processes to load data in the background.

As the PyTorch documentation puts it, by default each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by the main process using its RNG (thereby consuming an RNG state mandatorily). Python's random module in each worker is seeded with the same value.
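To make this rule concrete, here is a minimal sketch that mimics (it does not call) the worker-side seeding: the main process draws one base seed from its torch RNG, and the worker reseeds both random and torch with base_seed + worker_id. The helper name `mimic_worker_draws` is hypothetical, and the exact way PyTorch draws base_seed differs between versions, so only the pattern matters here.

```python
import random

import torch


def mimic_worker_draws(main_seed, worker_id, n=10):
    """Hypothetical helper: imitates how a DataLoader worker is seeded."""
    # Main process: seed torch, then draw a 64-bit base seed from its RNG.
    # This draw is why torch.manual_seed() in the main process matters.
    torch.manual_seed(main_seed)
    base_seed = int(torch.empty((), dtype=torch.int64).random_().item())

    # Worker: reseed both `random` and `torch` with base_seed + worker_id.
    # Any random.seed() call made earlier in the main process is overwritten here.
    worker_seed = base_seed + worker_id
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
    return [random.randint(0, 10) for _ in range(n)]


random.seed(123)  # has no effect on the result below
print(mimic_worker_draws(555, worker_id=0))
```

Calling it twice with the same main seed gives the same list no matter what random.seed was set to, which mirrors snippets 2 and 4.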

explanations of the outputs

So, when num_workers is 0, no worker is created, and random.randint is controlled only by random.seed, which has nothing to do with torch.manual_seed(). That is why the following two phenomena happen:

  • torch.manual_seed(555), num_workers=0 cannot reproduce
  • random.seed(555), num_workers=0 can reproduce

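But when num_workers is 1, PyTorch reseeds both random and torch.random in the worker with base_seed + worker_id, and base_seed is drawn from the main process's torch RNG. Seeding that RNG with torch.manual_seed therefore fixes the worker's random sequence, while a random.seed call in the main process is simply overwritten. That is why the following two phenomena happen: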

  • torch.manual_seed(555), num_workers=1 can reproduce
  • random.seed(555), num_workers=1 cannot reproduce

This implicit mechanism is not well documented by PyTorch. My guess is that the default torchvision.transforms relied on the random package (at least in older torchvision versions), and users often call torch.manual_seed(_) but not random.seed(_), since they never use random directly. So PyTorch seeds random by default to avoid such issues.

distinctness of different workers

So far we have only considered reproducibility across experiments. But there is another problem: do different workers get different seeds? For random and torch.random, we already know that the seed is base_seed + worker_id. What about third-party packages such as numpy.random? Let’s try the following code:

```python
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
    def __len__(self):
        return 4
    def __getitem__(self, index):
        x = np.random.randint(0, 10)
        return x

m = MyDataset()
data_loader = DataLoader(m, num_workers=2)
for x in data_loader:
    print(x)
```

You will see something like this:

```
tensor([6])
tensor([6])
tensor([4])
tensor([4])
```

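So different workers generate exactly the same random sequence! (The loader takes items from worker 0 and worker 1 in turn, so each identical pair is the same draw made by both workers.) This is a long-standing problem with numpy: when workers are created by forking (the default on Linux), each subprocess inherits a copy of the main process's RNG state, and nothing reseeds numpy.random afterwards. Recent PyTorch releases (1.9 onward, as far as I know) also seed numpy.random in each worker, so there you may see distinct values, but other third-party RNGs are still left alone. Looking back at the DataLoader parameters, worker_init_fn is the key to solving this problem.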
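The effect of every worker inheriting a copy of the parent's RNG can be reproduced without any multiprocessing. The sketch below (the helper name `simulate_forked_workers` is hypothetical) hands each simulated worker the same saved numpy.random state, which is roughly what a forked worker starts from when nothing reseeds numpy.

```python
import numpy as np


def simulate_forked_workers(num_workers=2, draws_per_worker=2):
    """Hypothetical helper: every simulated worker starts from the parent's RNG state."""
    parent_state = np.random.get_state()  # the state a forked child would inherit
    outputs = []
    for _ in range(num_workers):
        np.random.set_state(parent_state)  # same starting point in every "worker"
        outputs.append([int(np.random.randint(0, 10)) for _ in range(draws_per_worker)])
    return outputs


print(simulate_forked_workers())  # all rows are identical
```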

worker_init_fn (any callable: a function, or an instance of a class implementing __call__) only takes effect when num_workers>=1. As the official documentation states, it is called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading.

Here, "after seeding" refers to the seeding of random and torch.random described above, and "before data loading" gives us a chance to set up numpy.random before any dataset code uses it. We can therefore seed numpy manually here to avoid identical sequences across workers.

```python
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np

def my_worker_init_fn(worker_id):
    # Seed PyTorch assigned to this worker (base_seed + worker_id)
    seed = torch.utils.data.get_worker_info().seed
    print(seed)
    # Give each worker its own numpy seed
    np.random.seed(worker_id)

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
    def __len__(self):
        return 4
    def __getitem__(self, index):
        x = np.random.randint(0, 10)
        return x

m = MyDataset()
data_loader = DataLoader(m, num_workers=2, worker_init_fn=my_worker_init_fn)
for x in data_loader:
    print(x)
```
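Since worker_init_fn may be any callable, an instance of a class with __call__ also works, which is handy when the seeding needs extra configuration. The class name `NumpyWorkerSeeder` and its `offset` parameter are hypothetical; this is just a sketch that seeds numpy.random with offset + worker_id, so workers differ from each other while reruns with the same offset repeat.

```python
import numpy as np


class NumpyWorkerSeeder:
    """Hypothetical callable for DataLoader(worker_init_fn=...)."""

    def __init__(self, offset=0):
        self.offset = offset  # lets different experiments use different seed ranges

    def __call__(self, worker_id):
        # Called once in each worker process with worker_id in [0, num_workers - 1].
        np.random.seed(self.offset + worker_id)


# Usage (not executed here):
# DataLoader(m, num_workers=2, worker_init_fn=NumpyWorkerSeeder(offset=100))
```

The instance must be picklable when workers are spawned rather than forked, which a module-level class like this one is.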

distinctness and reproducibility of numpy

So here comes the last problem: how to make different experiments reproducible while keeping the workers distinct for third-party packages like numpy and imgaug. With the instructions above, this is rather trivial: we can just seed numpy.random with worker_id, which gives every worker a different seed and every run the same seeds.

```python
import numpy as np

def my_worker_init_fn(worker_id):
    np.random.seed(worker_id)
```
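One caveat: seeding with worker_id alone replays the same numpy sequence in every epoch, because (unless persistent_workers=True) the workers are re-created at the start of each epoch and run worker_init_fn again. A sketch of an alternative, following the seed_worker pattern from PyTorch's reproducibility notes, derives the numpy seed from the worker's torch seed instead. That seed is base_seed + worker_id, and base_seed is drawn again from the main process RNG for every epoch, so the result stays distinct across workers and epochs yet reproducible under torch.manual_seed.

```python
import random

import numpy as np
import torch


def seed_worker(worker_id):
    # Inside a worker, torch.initial_seed() returns base_seed + worker_id.
    # np.random.seed() only accepts values in [0, 2**32 - 1], hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)  # reseed `random` too, as the PyTorch notes do


# Usage (not executed here):
# DataLoader(m, num_workers=2, worker_init_fn=seed_worker)
```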

cudnn issues

Actually, to fully ensure reproducibility, we should also configure cuDNN (torch.backends.cudnn). Normally, if you set cudnn.benchmark=False and cudnn.deterministic=True, cuDNN's behavior is fixed. However, in my experiments, some operators still could not be made deterministic even with these two settings. In such cases you might have to disable cuDNN entirely with cudnn.enabled=False, but that can slow down the forward pass considerably.
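For reference, here is a sketch that collects these switches together with torch.use_deterministic_algorithms(True), which newer PyTorch releases (1.8 and later, as far as I know) provide. Instead of silently running a nondeterministic kernel, it raises an error for operations that have no deterministic implementation, which makes the offending operators easy to find. The CUBLAS_WORKSPACE_CONFIG value follows PyTorch's reproducibility notes for CUDA 10.2+ and should be set before any CUDA work; the helper name `make_deterministic` is hypothetical.

```python
import os

import torch
from torch.backends import cudnn


def make_deterministic():
    """Hypothetical helper collecting the determinism switches in one place."""
    # Required by some cuBLAS ops once deterministic algorithms are enforced on CUDA;
    # keep any value that is already set.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    cudnn.benchmark = False     # do not autotune (and thus vary) convolution algorithms
    cudnn.deterministic = True  # restrict cuDNN to deterministic algorithms
    # Raise an error for ops without a deterministic implementation
    # instead of running them nondeterministically.
    torch.use_deterministic_algorithms(True)


make_deterministic()
```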

conclusions

Now, a brief summary. In this post we discussed two concepts related to randomness: distinctness across workers and reproducibility across experiments, for both single- and multi-process loading.

If you want different workers to be distinct (multi-processing), you should manually seed every third-party package with worker_init_fn.

If you want different experiments to be reproducible (single-processing, num_workers=0), do the following:

```python
import random

import numpy as np
import torch
from torch.backends import cudnn

torch.manual_seed(233)  # network initialization
torch.cuda.manual_seed_all(233)
random.seed(233)
np.random.seed(233)

cudnn.benchmark = False
cudnn.deterministic = True
```
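A quick way to check the single-process recipe: the sketch below (the dataset `MixedDataset` and the function `run_once` are hypothetical) seeds everything, iterates a num_workers=0 loader whose dataset draws from random, numpy.random and torch, and compares two runs.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class MixedDataset(Dataset):
    """Hypothetical dataset drawing from all three RNGs."""

    def __len__(self):
        return 4

    def __getitem__(self, index):
        return (
            random.randint(0, 10),
            int(np.random.randint(0, 10)),
            int(torch.randint(0, 10, (1,))),
        )


def run_once(seed=233):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    loader = DataLoader(MixedDataset(), num_workers=0)
    # Each batch is a list of three 1-element tensors; turn it into plain ints.
    return [tuple(int(v) for v in batch) for batch in loader]


print(run_once() == run_once())  # expected: True
```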

If you want different experiments to be reproducible (multi-processing), do the following:

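```python
import numpy as np
import torch
from torch.backends import cudnn

torch.manual_seed(233)  # network initialization; also fixes the workers' base_seed
torch.cuda.manual_seed_all(233)

cudnn.benchmark = False
cudnn.deterministic = True

def _worker_init_fn(worker_id):
    np.random.seed(worker_id)

# pass worker_init_fn=_worker_init_fn to the DataLoader
```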

tips

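  • use random in your Dataset when you can: PyTorch already seeds it per worker, so it is distinct across workers and reproducible under torch.manual_seed.
  • most third-party packages have the same issue as numpy, so be careful when writing your worker_init_fn.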