
PyTorch: Different Forward Methods For Train And Test/Validation

I'm currently trying to extend a model that is based on FairSeq/PyTorch. During training I need to train two encoders: one with the target sample, and the original one with the sou

Solution 1:

First of all, you should always define and use forward, not some other method that you call on the torch.nn.Module instance.

Definitely do not override eval() as shown by trsvchn, as it is an evaluation method already defined by PyTorch (see the torch.nn.Module.eval documentation). This method puts the layers inside your model into evaluation mode (for example, it switches Dropout and BatchNorm to their inference behavior).
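To make the role of eval() concrete, here is a minimal sketch. The two-layer model is a hypothetical toy and not part of the question; it only shows that train() and eval() flip the training flag on every submodule and that Dropout becomes deterministic in eval mode.

```python
import torch

# Hypothetical toy model, used only to show what train()/eval() toggle.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Dropout(p=0.5),
)
x = torch.ones(1, 4)

model.train()  # sets training=True on the model and all submodules
train_flags = [m.training for m in model.modules()]

model.eval()   # equivalent to model.train(False)
eval_flags = [m.training for m in model.modules()]

# In eval mode Dropout acts as the identity, so repeated calls agree.
with torch.no_grad():
    out_a = model(x)
    out_b = model(x)
outputs_match_in_eval = torch.equal(out_a, out_b)
```

Overriding eval() with a custom forward pass would remove this mode switch, which is why it is better left alone.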

Furthermore, you should invoke the model through the __call__ magic method, i.e. model(x) rather than model.forward(x). Why? Because that is how hooks and other PyTorch-specific machinery get run properly.
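The difference can be seen with a forward hook. This is a small sketch on a single Linear layer; the hook function and the calls list are illustrative names, not something from the question.

```python
import torch

layer = torch.nn.Linear(3, 2)
calls = []


def record_output_shape(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(tuple(output.shape))


handle = layer.register_forward_hook(record_output_shape)
x = torch.randn(5, 3)

layer(x)          # goes through __call__, so the hook fires
layer.forward(x)  # bypasses __call__, so the hook does not fire

handle.remove()
```

After both calls, the calls list holds a single entry, because only the layer(x) invocation ran the hook. The same applies to any custom method you call directly instead of forward.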

Secondly, do not use an external mode string variable as suggested by @Anant Mittal. That is what the training attribute in PyTorch is for; checking it is the standard way to tell whether a model is in eval mode or train mode.
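One pitfall is worth spelling out: the boolean flag is called training, while train is the method that sets it. A short sketch on a stand-alone Dropout module (chosen only for illustration):

```python
import torch

module = torch.nn.Dropout()

module.eval()
flag_in_eval = module.training         # the boolean mode flag
method_is_truthy = bool(module.train)  # a bound method is always truthy

module.train()
flag_in_train = module.training
```

So a condition written as `if self.train:` is always taken, regardless of mode; `if self.training:` is the check that actually follows train() and eval().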

That being said, you are best off doing it like this:

import torch


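class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions, but both should be called by forward
    def forward(
        self,
        src_tokens=None,
        src_lengths=None,
        prev_output_tokens=None,
        tgt_tokens=None,  # assumption: passed in explicitly, it was undefined before
        **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        # self.training is the boolean flag toggled by train()/eval();
        # self.train is a method and would always be truthy
        if not self.training:
            # Test/validation: decode from the source encoding only
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        # Training: also encode the target sample, as described in the question
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)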

You could (and arguably should) split the above into two separate methods, but it is not too bad as is, since the function is short and readable. Stick to PyTorch's way of handling things where that is easy, rather than ad-hoc solutions. And no, there will be no problem with backpropagation: autograd only records the operations that actually run in a given call, so a branch inside forward does not interfere with it.
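For a version that can actually be run, here is a hedged sketch of the same idea with the two paths split into private helpers that forward dispatches to. Everything in it is a stand-in: the Linear layers replace the real FairSeq encoder and decoder, torch.cat plus a hypothetical merge layer replaces some_concatination_func, and the assumption (taken from the question) is that the training path uses both encodings.

```python
import torch


class TwoPathNetwork(torch.nn.Module):
    """Toy stand-in: Linear layers replace the real encoder and decoder."""

    def __init__(self, dim=8):
        super().__init__()
        self.encoder = torch.nn.Linear(dim, dim)
        self.merge = torch.nn.Linear(2 * dim, dim)  # hypothetical concat handler
        self.decoder = torch.nn.Linear(dim, dim)

    def forward(self, src, tgt=None):
        # Dispatch on the standard PyTorch mode flag.
        if self.training:
            return self._forward_train(src, tgt)
        return self._forward_eval(src)

    def _forward_train(self, src, tgt):
        encoder_out = self.encoder(src)
        autoencoder_out = self.encoder(tgt)
        merged = self.merge(torch.cat([encoder_out, autoencoder_out], dim=-1))
        return self.decoder(merged)

    def _forward_eval(self, src):
        return self.decoder(self.encoder(src))


model = TwoPathNetwork()
src, tgt = torch.randn(4, 8), torch.randn(4, 8)

model.train()
loss = model(src, tgt).sum()
loss.backward()  # gradients flow through the training branch as usual
encoder_has_grad = model.encoder.weight.grad is not None

model.eval()
with torch.no_grad():
    eval_out = model(src)  # same call site, evaluation branch
```

Both paths are still reached through model(...), so hooks keep working, and the call site does not need to know which mode the model is in.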

Solution 2:

By default, calling model() invokes the forward method, which is the training forward in your case, so you just need to define a new method for your test/eval path inside your model class, something like this:

Code:

from torch import nn


class FooBar(nn.Module):
    """Dummy Net for testing/debugging.
    """

    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        # here will be train forward
        ...

    def evaltest(self, x):
        # here will be eval/test forward
        ...

Examples:

model = FooBar()  # initialize model

# train time
pred = model(x)   # calls forward() method under the hood

# test/eval time
test_pred = model.evaltest(x)

Comment: I would recommend splitting these two forward paths into two separate methods, because that is easier to debug and avoids some possible problems when backpropagating.
