Suppose I have data of shape (4, 2), and my goal is to predict three 0-or-1 values (a simple sigmoid problem).
Since I know this data should be handled with a CNN, I will use keras.layers.Conv1D and so on.
Here's a minimal example.
import tensorflow as tf
from tensorflow import keras

BATCH_SIZE = 2
sample_inputs = keras.Input((4, 2), batch_size=BATCH_SIZE)  # shape: (BATCH_SIZE, 4, 2)
cnn_layer = keras.layers.Conv1D(3, kernel_size=2, use_bias=False)  # no bias, for fewer trainable variables
cnn_outputs = cnn_layer(sample_inputs)  # shape: (BATCH_SIZE, 3, 3)
dense_layer = keras.layers.Dense(1, use_bias=False)
final_outputs = dense_layer(cnn_outputs)  # shape: (BATCH_SIZE, 3, 1)
dummy_model = keras.Model(sample_inputs, final_outputs)
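For sanity, a quick check that the shapes match the comments above (just calling the model on random inputs):

check = dummy_model(tf.random.normal((BATCH_SIZE, 4, 2)))
print(check.shape)  # (2, 3, 1)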
But here is my problem: I know that a certain part of y_true is wrong!
Let me illustrate:
train_x = tf.random.normal((BATCH_SIZE*5, 4, 2))  # say we have 10 samples
train_y = tf.reshape(tf.random.categorical(tf.math.log([[0.5, 0.5]]),
                                           num_samples=BATCH_SIZE*5*3*1, dtype=tf.int32),
                     (BATCH_SIZE*5, 3, 1))  # to make the values 0 or 1
Say train_y[2:3, :, :] is a tensor of shape (1, 3, 1), dtype=int32, with values [[[1], [0], [0]]].
But I'm sure that the first element of this tensor is contaminated. So I don't want gradients to be backpropagated from that element, but I still want the trainable_variables of dummy_model to be trained on the second and third elements, since every data point is valuable.
Question 1. Is my reasoning sound? I can still compute y_pred for the contaminated sample by feeding the corresponding slice of train_x through the model, and I can skip backpropagation for just that one element when it is unavoidable, right?
(I'm aware this may introduce some skew, or amount to cherry-picking the data in some sense.)
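To make what I mean by "skipping an element" concrete, here is a minimal sketch: a per-element loss that is zeroed out wherever the label is bad. (masked_bce and the mask argument are my own invention, not an established API; since dummy_model's last layer has no activation, I treat its output as logits.)

def masked_bce(labels, logits, mask):
    # Per-element binary cross-entropy on logits, shape (BATCH_SIZE, 3, 1)
    elem = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    # Zero out the contaminated positions so they contribute nothing
    masked = elem * mask
    # Average over the clean elements only
    return tf.reduce_sum(masked) / tf.maximum(tf.reduce_sum(mask), 1.0)

Since the masked elements contribute exactly zero to the loss, no gradient flows from them, which (if I understand correctly) is the same effect as skipping their backpropagation.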
Question 2. TensorFlow implementation.
After a lot of research, my strategy is to use the stop_recording() method of tf.GradientTape.
I would tag (or edit) the contaminated entries with an integer 100, and whenever the loop notices a 100 in y_batch via tf.where(y_batch == 100), set cond=True, let tf.GradientTape record as usual up to that point, substitute tf.UnconnectedGradients.ZERO for the contaminated element, and so on.
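Here is a rough sketch of the training loop I have in mind, written with the masked loss above rather than stop_recording() (SENTINEL, train_y_tagged, and the Adam optimizer are all my own assumptions, not settled choices):

SENTINEL = 100  # my tag value for contaminated labels

# Tag the element I believe is contaminated: train_y[2, 0, 0]
train_y_tagged = tf.tensor_scatter_nd_update(train_y, [[2, 0, 0]], [SENTINEL])

dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y_tagged)).batch(BATCH_SIZE)
optimizer = keras.optimizers.Adam()

for x_batch, y_batch in dataset:
    with tf.GradientTape() as tape:
        logits = dummy_model(x_batch, training=True)        # (BATCH_SIZE, 3, 1)
        mask = tf.cast(y_batch != SENTINEL, tf.float32)     # 0.0 where tagged
        labels = tf.where(y_batch == SENTINEL,
                          tf.zeros_like(y_batch), y_batch)  # filler; masked out anyway
        loss = masked_bce(tf.cast(labels, tf.float32), logits, mask)
    grads = tape.gradient(loss, dummy_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, dummy_model.trainable_variables))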
I know this sounds crazy, and I'm not even sure it would be worth the effort. But at least, am I thinking along the right lines?