Question on the number of samples in LSTM

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import KFold, cross_val_score

def build_model():

  model2 = Sequential()
  # stateful LSTM: the batch size is fixed to 12 by batch_input_shape
  model2.add(LSTM(8, batch_input_shape=(12, 12, 1), stateful=True))
  model2.add(Dense(8))
  model2.add(Dense(8))
  model2.add(Dense(1))
  model2.compile(loss='mse', optimizer='adam')
  return model2

model = KerasRegressor(build_fn=build_model, epochs=50, batch_size=12, verbose=0)
# np.random.seed(7) returns None, so pass the seed directly
# (random_state only takes effect when shuffle=True)
kfold = KFold(n_splits=5, shuffle=True, random_state=7)
score = cross_val_score(model, ts_x, ts_y, cv=kfold, scoring='neg_mean_squared_error')

ts_x.shape is (228, 12, 1) and ts_y.shape is (228, 1, 1). As you can see, I have 228 samples, but when I run it I get:

ValueError: In a stateful network, you should only pass inputs with a number of samples that can be divided by the batch size. Found: 183 samples.

I want to know why it found 183 samples instead of 228 samples.

What the error means:
The batch_size you have provided is 12, meaning 12 records are fed to the network at a time, and a stateful model requires every batch to be full. Your total dataset is 228 records, which actually is a multiple of 12 (12 × 19 = 228), so training on the whole dataset would not trigger this error.
The problem is the 5-fold cross-validation. Your dataset is divided into 5 parts; in each round one part is held out as a validation set while the model trains on the other 4. Since 228 is not divisible by 5, KFold produces folds of 46 or 45 samples, leaving training sets of 228 − 46 = 182 or 228 − 45 = 183 samples.
So each training run actually sees 182 or 183 records, and neither is a multiple of 12; the 183 in your error message is one of these training-fold sizes.
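To see where 183 comes from concretely, here is a minimal sketch (using a dummy array with the same sample count as your ts_x) that prints the size of each training fold:

import numpy as np
from sklearn.model_selection import KFold

X = np.zeros((228, 12, 1))  # dummy data, same number of samples as ts_x
kfold = KFold(n_splits=5)
for train_idx, test_idx in kfold.split(X):
    # training folds have 182 or 183 samples; neither divides evenly by 12
    print(len(train_idx), len(train_idx) % 12)

This prints training-fold sizes of 182, 182, 182, 183 and 183, so whichever fold runs, the stateful LSTM's divisibility check fails (182 % 12 = 2, 183 % 12 = 3).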

Potential solution:
You could try setting batch_size to a factor of the training-fold size, but 183 = 3 × 61 (factors 1, 3, 61, 183), and the other folds train on 182 samples, whose only common factor with 183 is 1, so that doesn't give you any reasonable option.
Instead, try changing n_splits to something nearby, like 6: then every training fold has exactly 228 × 5/6 = 190 samples, and 10 is one of the possible batch sizes. (Bear in mind that a stateful model applies the same divisibility check when predicting on the held-out 38-sample fold, so a batch size dividing both 190 and 38, such as 19, is safer; remember to update batch_input_shape in the model to match, as sketched below.)
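As a sketch of that fix, reusing your ts_x and ts_y and the (now deprecated) tf.keras scikit-learn wrapper from your snippet, with n_splits=6 and batch_size=19 both the 190-sample training folds and the 38-sample test folds divide evenly:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor
from sklearn.model_selection import KFold, cross_val_score

def build_model():
    model = Sequential()
    # batch dimension must match the batch_size passed to the wrapper
    model.add(LSTM(8, batch_input_shape=(19, 12, 1), stateful=True))
    model.add(Dense(8))
    model.add(Dense(8))
    model.add(Dense(1))
    model.compile(loss='mse', optimizer='adam')
    return model

# each training fold is 190 samples (10 batches of 19);
# each test fold is 38 samples (2 batches of 19)
model = KerasRegressor(build_fn=build_model, epochs=50, batch_size=19, verbose=0)
kfold = KFold(n_splits=6, shuffle=True, random_state=7)
score = cross_val_score(model, ts_x, ts_y, cv=kfold, scoring='neg_mean_squared_error')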

Apart from that, I'm sorry I don't have experience with TensorFlow since I use PyTorch, and PyTorch doesn't raise an error even when the last batch isn't full. Still, you could look at TensorFlow's documentation and its Q&A forums for another answer.

I hope this solves your problem or at least guides you in the right direction.