[BUG] loss results are different even though random seed is set #1770
Comments
Can you change the input image size when training a classification network with EfficientFormerV2? If I change the input to something other than 224, an error is reported.
@ljm565 there is non-determinism inherent in default PyTorch models, see https://pytorch.org/docs/stable/notes/randomness.html. There is also likely to be some in most train scripts. timm isn't doing anything unusual; I typically don't find it worth going to extremes in this area, but you are welcome to try.
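For reference, a minimal sketch of the non-default settings the linked PyTorch randomness notes describe, beyond plain seeding (assumes a reasonably recent PyTorch; these settings trade speed for reproducibility):

```python
import os
import random

import numpy as np
import torch

# Needed by some deterministic CUDA (cuBLAS) kernels.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # also seeds all CUDA devices

# Disable cuDNN autotuning and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True, warn_only=True)
```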
@hacktmz Actually, our data size is 3 * 224 * 224, the same as the pre-trained EfficientFormer-V2. We have to use the 224*224 size if we want to use the pre-trained model.
@rwightman If I don't inherit from timm's EfficientFormer-V2 (just using nn.Linear and nn.Conv), all trials produce the same results, because I set the torch and random module seeds. Actually, Hugging Face transformer-based models do not show this problem...
@ljm565 as per the randomness notes for PyTorch, you don't get true determinism in PyTorch unless you change default flags; not sure if transformers changes anything by default. Also, these models have batchnorm, so you do have to flip between .train() and .eval()
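A minimal illustration of the batchnorm point: running statistics update only in .train() mode, so the train/eval toggle has to happen consistently for two runs to stay comparable.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8)
x = torch.randn(4, 8, 16, 16)

bn.train()
bn(x)                    # updates running_mean / running_var
bn.eval()
with torch.no_grad():
    out = bn(x)          # uses stored running statistics, no update
```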
@ljm565 if you can, compare your model with two others, a …
@rwightman Yes, of course, I applied model.train() and model.eval() in the different phases. Also, I have now tested a resnet from torchvision.models, and that model shows the same results in every training run. The snippets below have the seeds applied.
[Attached: EfficientFormer-v2 snippet]
[Attached: resnet50 snippet]
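A hypothetical reconstruction of the comparison described above: the same seed is set before building each backbone, and then an identical training loop is run with EfficientFormerV2 and with torchvision's resnet50 (the class count and weight arguments here are illustrative assumptions).

```python
import timm
import torch
import torchvision

torch.manual_seed(0)
effformer = timm.create_model('efficientformerv2_s0', pretrained=False, num_classes=10)

torch.manual_seed(0)
resnet = torchvision.models.resnet50(weights=None)
resnet.fc = torch.nn.Linear(resnet.fc.in_features, 10)  # match the 10-class head
```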
@ljm565 hmmm, does it happen if you disable the attention bias cache for both attention modules, i.e., in the two locations, change

```python
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
    if torch.jit.is_tracing() or self.training:
        return self.attention_biases[:, self.attention_bias_idxs]
    else:
        device_key = str(device)
        if device_key not in self.attention_bias_cache:
            self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
        return self.attention_bias_cache[device_key]
```

to

```python
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
    return self.attention_biases[:, self.attention_bias_idxs]
```
@rwightman I changed the function in both locations, but the loss results are still different. Since all loss results at step 0 are the same, model initialization always seems to be identical thanks to the random seed.
@ljm565 hmm, this is definitely odd. Even without all those flags but with the same seeds, the typical level of non-determinism inherent in benchmarking, cudnn, etc. is quite small, and results don't diverge as much as you see there. For good measure, have you forced …? The other thought is that maybe there is some numeric stability issue making it sensitive to very small changes. Does enabling gradient clipping, or lowering the adam/adamw beta from .999/.99 -> .95 (if using adam), change anything?
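A sketch of the two mitigations suggested above; the stand-in model and hyperparameter values are illustrative assumptions, not tuned settings.

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)  # stand-in for the actual network
# Lower beta2 so the second-moment estimate reacts faster to noisy gradients.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95))

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
# Clip the global gradient norm before stepping the optimizer.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```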
@rwightman According to my code in the question, I used efficientformerv2_s0, so its drop_path_rate may be 0.
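A small check of the point above (assumes a timm version that includes efficientformerv2_s0): drop_path_rate can be set explicitly at creation time instead of relying on the default.

```python
import timm

# Disable stochastic depth explicitly so it cannot contribute any randomness.
model = timm.create_model('efficientformerv2_s0', pretrained=False, drop_path_rate=0.0)
print(sum(p.numel() for p in model.parameters()))  # sanity-check the model built
```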
@ljm565 try changing …
Unfortunately, it still shows different results on every trial... Tested OS: Ubuntu 22.04. -------------added-------------
@ljm565 to be clear, on mac the losses are the same? It may not be your environment but cuda/cudnn instead of CPU. Did you try forcing CPU on ubuntu? I guess a fresh environment with pytorch 2.0 might be worth a check too.
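One way to force a CPU-only A/B run, as an illustration: hide the GPUs before torch initializes CUDA, so cuDNN kernels are taken out of the picture entirely.

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # must be set before torch initializes CUDA

import torch
print(torch.cuda.is_available())  # expected: False
device = torch.device("cpu")
```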
Loss results differ across training runs despite using the EfficientFormer V2 model with a fixed random seed.
More specifically, each training epoch has 75 steps, and after a few steps (around step 30) the loss results start to differ.
I think the issue comes from timm randomness, because when I use our custom model the loss results are the same.
Is there any solution for this issue?
Do I have to use only the train.py that timm provides?
Thanks