import numpy as np
from sklearn.model_selection import KFold, cross_val_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor  # or the scikeras package, depending on version

def build_model():
    # stateful LSTM: the batch size (12) is fixed by batch_input_shape
    model2 = Sequential()
    model2.add(LSTM(8, batch_input_shape=(12, 12, 1), stateful=True))
    model2.add(Dense(8))
    model2.add(Dense(8))
    model2.add(Dense(1))
    model2.compile(loss='mse', optimizer='adam')
    return model2

model = KerasRegressor(build_fn=build_model, epochs=50, batch_size=12, verbose=0)
kfold = KFold(n_splits=5, random_state=np.random.seed(7))
score = cross_val_score(model, ts_x, ts_y, cv=kfold, scoring='neg_mean_squared_error')
ts_x.shape is (228,12,1)
ts_y.shape is (228,1,1)
As you can see, I have 228 samples now, but when I run it I get:
ValueError: In a stateful network, you should only pass inputs with a number of samples that can be divided by the batch size. Found: 183 samples.
I want to know: why did it find 183 samples instead of 228?
What the error means:

The batch_size you have provided is 12, that is, 12 records are fed to the network on each training step. Your total of 228 records is actually divisible by 12 (19 full batches), so the full dataset on its own is not the problem. The problem is the 5-fold cross-validation: your dataset is divided into 5 parts, out of which 1 part is kept untouched as a validation set while the model trains on the other 4. Since 228 / 5 = 45.6, the validation folds contain 45 or 46 samples, and each training split therefore contains 228 - 45 = 183 or 228 - 46 = 182 samples. The 183 in your error message is one of those training splits, and 183 is not a multiple of 12, which is exactly what the stateful-network check rejects.
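If you want to check this yourself, here is a quick sketch (it only needs NumPy and scikit-learn, not your actual data) that prints the size of every split 5-fold cross-validation produces for 228 samples:

# Sketch: reproduce the 182/183 training-split sizes that KFold(5) yields
# for 228 samples, and show that neither is a multiple of 12.
import numpy as np
from sklearn.model_selection import KFold

n_samples = 228
for train_idx, val_idx in KFold(n_splits=5).split(np.zeros((n_samples, 1))):
    print(len(train_idx), len(val_idx), len(train_idx) % 12)
# -> training splits of 182 or 183 samples; the remainder column is never 0,
#    so the stateful LSTM (fixed batch size of 12) rejects them.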
Potential solution:

You can try setting the batch_size to a factor of 183 (1, 3, 61, 183), which doesn't give you many reasonable options; and since the other training splits have 182 samples, only batch_size = 1 would actually work for every fold. So you can instead try changing your n_splits to something close, like 6, so that 228 * (n_splits - 1) / n_splits has factors close to 10: with n_splits = 6 every training split has 190 samples, and 10 is one of the possible batch sizes.
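If you'd rather not factor these numbers by hand, here is a sketch of a small helper (usable_batch_sizes is just a name I made up, not anything from Keras or scikit-learn) that lists every batch size which evenly divides all the training splits for a given n_splits:

# Hypothetical helper (sketch): list the batch sizes that divide every
# training split KFold would produce, so a stateful model never sees a
# partial batch during fit().
import numpy as np
from sklearn.model_selection import KFold

def usable_batch_sizes(n_samples, n_splits):
    train_sizes = {len(train_idx) for train_idx, _ in
                   KFold(n_splits=n_splits).split(np.zeros((n_samples, 1)))}
    return [b for b in range(1, min(train_sizes) + 1)
            if all(size % b == 0 for size in train_sizes)]

print(usable_batch_sizes(228, 5))  # [1] -- splits of 182 and 183 share no other divisor
print(usable_batch_sizes(228, 6))  # [1, 2, 5, 10, 19, 38, 95, 190]

Keep in mind that with a stateful LSTM the batch size is also hard-coded in batch_input_shape, so whatever value you pick has to be changed there as well as in the batch_size argument.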
Apart from that, I am sorry, I don't have much experience with TensorFlow since I use PyTorch, and PyTorch doesn't raise an error even if the last batch isn't a full batch. Still, you could look at TensorFlow's documentation and their own Q&A forums to get another answer.
I hope this solves your problem or at least guides you in the right direction towards a solution.