PyTorch: Different Forward Methods for Train and Test/Validation
Solution 1:
First of all, you should always define and use forward, not some other method that you call on the torch.nn.Module instance.
Definitely do not overload eval() as shown by trsvchn, as it is an evaluation method defined by PyTorch (see the PyTorch documentation for torch.nn.Module.eval). This method puts the layers inside your model into evaluation mode, which changes how specific layers behave at inference time (for example, Dropout is disabled and BatchNorm uses its running statistics instead of per-batch statistics).
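A minimal sketch of what eval() does, using an arbitrary toy nn.Sequential (the layer sizes are assumptions for illustration): it recursively flips the training flag on every submodule, after which Dropout and BatchNorm behave deterministically.

```python
import torch
from torch import nn

# Toy model containing layers whose behaviour depends on train/eval mode
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(p=0.5))

model.eval()  # recursively sets .training = False on every submodule
print(all(not m.training for m in model.modules()))  # True

x = torch.randn(8, 4)
with torch.no_grad():
    out1 = model(x)
    out2 = model(x)
# In eval mode Dropout is a no-op and BatchNorm uses its running statistics,
# so two calls on the same input give identical results.
print(torch.equal(out1, out2))  # True

model.train()  # switch back: .training = True everywhere
print(all(m.training for m in model.modules()))  # True
```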
Furthermore, you should invoke the model through the __call__ magic method, i.e. model(x) rather than model.forward(x). Why? Because that is the path through which hooks and other PyTorch-specific machinery are run properly.
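A small sketch of the difference, assuming a throwaway nn.Linear and a forward hook that simply records that it ran: the hook fires when the module is called as model(x), but not when forward is called directly.

```python
import torch
from torch import nn

calls = []

model = nn.Linear(3, 2)
# A forward hook runs after forward() whenever the module is invoked via __call__
model.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 3)
model(x)          # goes through __call__: forward() runs, then the hook fires
model.forward(x)  # bypasses __call__: forward() runs, the hook does NOT fire

print(calls)  # ['hook']
```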
Secondly, do not use some external mode string variable as suggested by @Anant Mittal. That is what the training attribute in PyTorch is for: it is the standard way to tell whether the model is in eval mode or train mode, and it is set for you by model.train() and model.eval().
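A minimal sketch of branching on the flag, using a hypothetical Toy module that just returns a label for the path it took. It also shows a common pitfall: model.train is a method, so it is always truthy and cannot be used as the mode check.

```python
from torch import nn


class Toy(nn.Module):
    def forward(self, x):
        # self.training is the boolean flag toggled by .train() / .eval()
        return "train path" if self.training else "eval path"


model = Toy()
print(model(0))        # train path  (modules start in training mode)
model.eval()
print(model(0))        # eval path
print(model.training)  # False

# Pitfall: model.train is a bound method, so it is truthy even in eval mode.
# Branch on self.training, never on self.train.
print(bool(model.train))  # True
```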
That being said, you are best off doing it like this:
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions, but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None,
        tgt_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        # Use self.training (a bool set by model.train()/model.eval()),
        # not self.train, which is a method and therefore always truthy
        if not self.training:
            # eval/test: decode from the source encoding only
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        # training: also encode the target and combine both encodings
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        # some_concatination_func is a placeholder for your own combining function
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
You could (and arguably should) split the above into two separate methods, but keeping it in one is not too bad, as the function is short and readable. Just stick to PyTorch's way of handling things whenever possible rather than ad-hoc solutions. And no, there will be no problem with backpropagation: autograd records only the operations that actually run, so gradients simply flow through whichever branch forward took.
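A runnable toy version of the pattern above, as a sketch under stated assumptions: nn.Linear layers stand in for the real encoder and decoder, and element-wise addition stands in for the hypothetical combining function (so both paths feed the decoder the same shape). It checks that gradients reach the encoder in training mode and that no target is needed at eval time.

```python
import torch
from torch import nn


class Network(nn.Module):
    """Toy stand-in: Linear layers replace the real encoder/decoder (assumption)."""

    def __init__(self, dim=8):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)
        self.decoder = nn.Linear(dim, dim)

    def forward(self, src, tgt=None):
        encoder_out = self.encoder(src)
        if not self.training:
            # eval/test path: source encoding only
            return self.decoder(encoder_out)
        # training path: encode the target too and combine (addition is an assumption)
        autoencoder_out = self.encoder(tgt)
        combined = encoder_out + autoencoder_out
        return self.decoder(combined)


torch.manual_seed(0)
model = Network()
src, tgt = torch.randn(4, 8), torch.randn(4, 8)

model.train()
loss = model(src, tgt).pow(2).mean()
loss.backward()  # autograd follows whichever branch actually ran
print(model.encoder.weight.grad is not None)  # True

model.eval()
with torch.no_grad():
    pred = model(src)  # tgt is not needed at eval time
print(pred.shape)  # torch.Size([4, 8])
```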
Solution 2:
By default, calling model() invokes the forward method, which is the train forward in your case, so you just need to define a new method for your test/eval path inside your model class, something like this:
Code:
from torch import nn


class FooBar(nn.Module):
    """Dummy Net for testing/debugging.
    """

    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        # here will be train forward
        ...

    def evaltest(self, x):
        # here will be eval/test forward
        ...
Examples:
model = FooBar()  # initialize model

# train time
pred = model(x)  # calls forward() method under the hood

# test/eval time
test_pred = model.evaltest(x)
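A filled-in sketch of the FooBar skeleton, assuming a Linear layer plus Dropout as placeholder contents (the real layers are not specified). Note that model.evaltest(x) is a plain method call, so it does not go through __call__ and any registered hooks will not run; wrapping it in torch.no_grad() avoids building an autograd graph at test time.

```python
import torch
from torch import nn


class FooBar(nn.Module):
    """Dummy net; the layers below are placeholder assumptions."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(4, 4)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        # train forward: applies dropout
        return self.dropout(self.body(x))

    def evaltest(self, x):
        # eval/test forward: skips dropout entirely
        return self.body(x)


model = FooBar()
x = torch.randn(2, 4)
pred = model(x)  # calls forward() through __call__

model.eval()  # still worth calling: other layers (e.g. BatchNorm) read the flag
with torch.no_grad():
    test_pred = model.evaltest(x)  # plain method call: no hooks, no autograd graph
print(test_pred.requires_grad)  # False
```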
Comment: I would recommend splitting these two forward paths into two separate methods, because that is easier to debug and avoids some possible problems when backpropagating.
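One way to keep the two paths in separate methods while still calling the model only through forward is to let forward dispatch on self.training. This is a sketch; the method names _forward_train and _forward_eval and the layer contents are hypothetical.

```python
import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x):
        # Single public entry point: calling net(x) still runs hooks via __call__,
        # while each path lives in its own method for easier debugging.
        if self.training:
            return self._forward_train(x)
        return self._forward_eval(x)

    def _forward_train(self, x):
        # hypothetical training-only path
        return self.layer(x) * 2

    def _forward_eval(self, x):
        # hypothetical eval-only path
        return self.layer(x)


net = Net()
x = torch.randn(1, 4)
train_out = net(x)  # dispatches to _forward_train
net.eval()
eval_out = net(x)   # dispatches to _forward_eval
print(torch.allclose(train_out, 2 * eval_out))  # True
```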