I have extended nn.Module
to implement my network, whose forward function is like this:
def forward(self, X, **kwargs):
    batch_size, seq_len = X.size()
    length = kwargs.get('length')  # lengths of the unpadded sequences
    embedded = self.embedding(X)  # [batch_size, seq_len, embedding_dim]
    if self.use_padding:
        if length is None:
            raise AttributeError("Length must be a tensor when using padding")
        embedded = nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=True)
        # print("Size of Embedded packed", embedded[0].size())
    hidden, cell = self.init_hidden(batch_size)
    if self.rnn_unit == 'rnn':
        out, _ = self.rnn(embedded, hidden)
    elif self.rnn_unit == 'lstm':
        out, (hidden, cell) = self.rnn(embedded, (hidden, cell))
    # unpack if padding was used
    if self.use_padding:
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
    return out
I initialized a skorch NeuralNetClassifier like this:

net = NeuralNetClassifier(
    model,
    criterion=nn.CrossEntropyLoss,
    optimizer=Adam,
    max_epochs=8,
    lr=0.01,
    batch_size=32
)
Now if I call net.fit(X, y, length=X_len), it throws an error:

TypeError: __call__() got an unexpected keyword argument 'length'
According to the documentation, the fit function expects a fit_params dictionary:

**fit_params : dict Additional parameters passed to the ``forward`` method of the module and to the ``self.train_split`` call.

and the source code always sends my parameters to train_split, where my keyword argument would obviously not be recognized.
Is there any way to pass these arguments through to my forward function?
The fit_params parameter is intended for passing information that is relevant to data splits and the model alike, like split groups.
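For example, passing split groups could look roughly like this (a sketch, not a verified recipe: it assumes skorch's CVSplit forwards a groups keyword to the underlying scikit-learn splitter, and group_labels is a hypothetical array assigning each sample to a group):

from sklearn.model_selection import GroupKFold
from skorch.dataset import CVSplit

net = NeuralNetClassifier(
    model,
    train_split=CVSplit(cv=GroupKFold(n_splits=5)),
)
# groups is relevant to the data split, so fit_params is the right channel
# for it; note that it is also forwarded to the module's forward, where a
# **kwargs signature absorbs it
net.fit(X, y, groups=group_labels)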
In your case, you are passing additional data to the module via fit_params, which is not what it is intended for. In fact, you could easily run into trouble doing this if you, for example, enable batch shuffling on the train data loader, since then your lengths and your data are misaligned.

The best way to do this is already described in the answer to your question on the issue tracker:
Since skorch supports dicts, you can simply add the lengths to your input dict and have both passed to the module, nicely batched and sent through the same data loader. In your module you can then access them via the parameters in forward:
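A minimal sketch using the names from your question (the only requirement is that the dict keys match the parameter names of your module's forward):

# X: the padded token tensor, X_len: the tensor of sequence lengths
X_dict = {'X': X, 'length': X_len}

# skorch slices every value of the dict into batches and calls
# module(X=batch_X, length=batch_length), so the existing
# forward(self, X, **kwargs) receives the lengths as kwargs['length']
net.fit(X_dict, y)

Because the lengths now travel through the same data loader as the sequences they describe, they stay aligned even when batch shuffling is enabled.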
Further documentation of this behaviour can be found in the docs.