Inputs must have identical ragged splits. Condition x == y did not hold element-wise. Correct way to concatenate ragged tensors?


I have medical data consisting of several dataframes of medical tests and treatments. A patient can have zero, one, or many observations in each dataframe. For example, in the following sample data, patient 0 has one row in the first dataframe and two rows in the second, while patient 1 has one and zero rows, respectively:

patient_id   |  date       | test_value  |
-----------------------------------------
0            |  2019-08-13 | 0.8         |
1            |  2019-02-05 | 0.5         |
2            |  2019-12-01 | 0.6         |

patient_id   |  treatment_description  | start_date | end_date   |  treatment_value |
------------------------------------------------------------------------------------|
0            |  ...                    | 2019-02-05 | 2019-02-15 |  0.5             |
0            |  ...                    | 2019-12-06 | 2019-04-05 |  0.2             |
2            |  ...                    | 2019-04-14 | 2019-10-12 |  0.3             |
3            |  ...                    | 2019-05-04 | 2019-09-10 |  0.4             |
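For reference, the two sample frames can be rebuilt like this (the names tests and treatments are only for this snippet, and treatment_description is left as a placeholder string):

import pandas as pd

tests = pd.DataFrame({
    'patient_id': [0, 1, 2],
    'date': ['2019-08-13', '2019-02-05', '2019-12-01'],
    'test_value': [0.8, 0.5, 0.6],
})

treatments = pd.DataFrame({
    'patient_id': [0, 0, 2, 3],
    'treatment_description': ['...', '...', '...', '...'],
    'start_date': ['2019-02-05', '2019-12-06', '2019-04-14', '2019-05-04'],
    'end_date': ['2019-02-15', '2019-04-05', '2019-10-12', '2019-09-10'],
    'treatment_value': [0.5, 0.2, 0.3, 0.4],
})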

I need to build a binary classifier of patients, and I am approaching the problem with a ragged multi-input Keras model. To that end, I defined my training tensors as:

X_1 = tf.ragged.constant(X_1, ragged_rank=1)
X_2 = tf.ragged.constant(X_2, ragged_rank=1)

Each input layer i accepts ragged tensors of shape (None, None, num_features_i), where the first None is the patient axis, the second is the number of observations for that patient, and the third is the number of features (i.e. the number of columns in the sample data). To sum up, I defined my model as follows:

import tensorflow as tf

input_layer_1 = tf.keras.layers.Input(shape=(None, 3), ragged=True)
input_layer_2 = tf.keras.layers.Input(shape=(None, 4), ragged=True)

x = tf.keras.layers.Concatenate(axis=-1)([input_layer_1, input_layer_2])
x = tf.keras.layers.LSTM(8, return_sequences=False)(x)
x = tf.keras.layers.Dense(8, activation='relu')(x)

output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(x)
model = tf.keras.Model(inputs=[input_layer_1, input_layer_2], outputs=output_layer)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipnorm=0.01),
    loss='binary_crossentropy',
    metrics=['accuracy', 'AUC'],
)

model.summary() gives the following:

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, None, 3)]    0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, None, 4)]    0           []                               
                                                                                                  
 concatenate (Concatenate)      (None, None, 7)      0           ['input_1[0][0]',                
                                                                  'input_2[0][0]']                
                                                                                                  
 lstm (LSTM)                    (None, 8)            512         ['concatenate[0][0]']            
                                                                                                  
 dense (Dense)                  (None, 8)            72          ['lstm[0][0]']                   
                                                                                                  
 dense_1 (Dense)                (None, 1)            9           ['dense[0][0]']                  
                                                                                                  
==================================================================================================
Total params: 593
Trainable params: 593
Non-trainable params: 0
__________________________________________________________________________________________________
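
Training is launched with what amounts to the following call (simplified; in the notebook the input tuple is assembled dynamically from the dataframe names, which is why an eval shows up in the traceback below):

history = model.fit((X_1_train, X_2_train), y_train, epochs=10, batch_size=8)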

Everything seems to work just fine until model.fit(). Sometimes it outputs:

InvalidArgumentError: Graph execution error:

Detected at node 'model/concatenate/RaggedConcat/assert_equal_1/Assert/AssertGuard/Assert' defined at (most recent call last):
    File "/opt/conda/envs/Python-3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/opt/conda/envs/Python-3.10/lib/python3.10/runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/ipykernel/__main__.py", line 3, in <module>
      app.launch_new_instance()
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 677, in start
      self.io_loop.start()
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/opt/conda/envs/Python-3.10/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
      self._run_once()
    File "/opt/conda/envs/Python-3.10/lib/python3.10/asyncio/base_events.py", line 1906, in _run_once
      handle._run()
    File "/opt/conda/envs/Python-3.10/lib/python3.10/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 471, in dispatch_queue
      await self.process_one()
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 460, in process_one
      await dispatch(*args)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 367, in dispatch_shell
      await result
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 662, in execute_request
      reply_content = await reply_content
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 360, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 532, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2881, in run_cell
      result = self._run_cell(
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2936, in _run_cell
      return runner(coro)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3135, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3338, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3398, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/wsuser/ipykernel_933/3825575363.py", line 7, in <cell line: 7>
      history = eval('model.fit((' + ', '.join([f'X_{name}_train' for name in all_dfs.keys()]) + '), y_train, epochs = 10, batch_size = 8)')
    File "<string>", line 1, in <module>
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/training.py", line 1409, in fit
      tmp_logs = self.train_function(iterator)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/training.py", line 1051, in train_function
      return step_function(self, iterator)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/training.py", line 1040, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/training.py", line 1030, in run_step
      outputs = model.train_step(data)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/training.py", line 889, in train_step
      y_pred = self(x, training=True)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/training.py", line 490, in __call__
      return super().__call__(*args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/functional.py", line 458, in call
      return self._run_internal_graph(
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/functional.py", line 596, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/layers/merging/base_merge.py", line 178, in call
      return self._merge_function(inputs)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/layers/merging/concatenate.py", line 126, in _merge_function
      return backend.concatenate(inputs, axis=self.axis)
    File "/opt/conda/envs/Python-3.10/lib/python3.10/site-packages/keras/backend.py", line 3311, in concatenate
      return tf.concat(tensors, axis)
Node: 'model/concatenate/RaggedConcat/assert_equal_1/Assert/AssertGuard/Assert'
assertion failed: [Inputs must have identical ragged splits] [Condition x == y did not hold element-wise:] [x (RaggedFromVariant/RaggedTensorFromVariant:0) = ] [0 0 2...] [y (RaggedFromVariant_1/RaggedTensorFromVariant:0) = ] [0 0 0...]
     [[{{node model/concatenate/RaggedConcat/assert_equal_1/Assert/AssertGuard/Assert}}]] [Op:__inference_train_function_3962]

And sometimes:

InvalidArgumentError: Graph execution error:

Tried to stack list which only contains uninitialized tensors and has a non-fully-defined element_shape: [?,8]
     [[{{node TensorArrayV2Stack/TensorListStack}}]]
     [[model/lstm/PartitionedCall]] [Op:__inference_train_function_3962]

If I put an x = tf.keras.layers.GlobalMaxPooling1D()(x) layer in place of the LSTM, the behavior is similar. I am pretty sure that the axis of concatenation should be -1 for my problem; however, if I simply do:

import tensorflow as tf

rt1 = tf.ragged.constant([
    [[1, 2], [2, 3], [4, 5]], 
    [[1, 2], [2, 3]], 
])
rt2 = tf.ragged.constant([
    [[6, 7, 8], [8, 9, 10]], 
    [[6, 7, 8], [8, 9, 10], [1, 2, 3]], 
])

concatenated = tf.concat([rt1, rt2], axis=-1)

print(concatenated)

I still get

InvalidArgumentError: Inputs must have identical ragged splits
Condition x == y did not hold.
Indices of first 1 different values:
[[1]]
Corresponding x values:
[3]
Corresponding y values:
[2]
First 3 elements of x:
[0 3 5]
First 3 elements of y:
[0 2 5]
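
By contrast, the same concat goes through when the two ragged tensors have identical row splits along axis 1, i.e. the same number of inner rows in each outer entry, for example:

rt1 = tf.ragged.constant([
    [[1, 2], [2, 3], [4, 5]],
    [[1, 2], [2, 3]],
])
rt2 = tf.ragged.constant([
    [[6, 7, 8], [8, 9, 10], [1, 2, 3]],
    [[6, 7, 8], [8, 9, 10]],
])

# Row lengths along axis 1 are (3, 2) for both tensors, so concatenating the
# innermost lists succeeds and yields rows like [1, 2, 6, 7, 8].
print(tf.concat([rt1, rt2], axis=-1))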

If I do:

import numpy as np

X_1 = tf.ragged.constant([
    [[1, 2, 3], [1, 2, 3], [1, 2, 3]], 
    [[1, 2, 3], [1, 2, 3]], 
], ragged_rank=1)

X_2 = tf.ragged.constant([
    [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], 
    [[1, 2, 3, 4], [1, 2, 3, 4]], 
], ragged_rank=1)
y = np.array([0, 1])

model.fit((X_1, X_2), y) works fine. If instead:

X_1 = tf.ragged.constant([
    [[1, 2, 3], [1, 2, 3], [1, 2, 3]], 
    [[1, 2, 3]], 
], ragged_rank=1)

X_2 = tf.ragged.constant([
    [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], 
    [[1, 2, 3, 4], [1, 2, 3, 4]], 
], ragged_rank=1)
y = np.array([0, 1])

model.fit((X_1, X_2), y) outputs the same error again. So the problem seems to be patients having a different number of observations in the two dataframes... but that is exactly the situation I am trying to model here...
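
To confirm, the per-patient sequence lengths of the failing pair really do differ (row_lengths() is a RaggedTensor method):

print(X_1.row_lengths())  # tf.Tensor([3 1], ...) -> patient 1 has one observation in the first input
print(X_2.row_lengths())  # tf.Tensor([3 2], ...) -> but two observations in the second input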

So what am I missing here? Why can I not concatenate along the last dimension? Should I concatenate along the last dimension at all?
